Commit e51249d7 authored by Paul Shved's avatar Paul Shved
Browse files

Ask the model to optimize offsets more if class conf is good

parent 31063385
......@@ -128,8 +128,12 @@ class Model(tf.keras.Model):
return z
@tf.function
def loss(self, gt_features, y_pred,
def loss(self,
gt_features, y_pred,
debug_print=False,
task_suppression='sum',
alpha=1.0,
alpha_extra=10.0, # For 'limit' suppression.
hard_negative_min_per_batch=4,
hard_negative_mining_ratio=3.):
"""Returns the loss function.
......@@ -148,7 +152,6 @@ class Model(tf.keras.Model):
L_conf = 0
L_loc = 0
alpha = 1.0
smooth_l1 = tf.losses.Huber(delta=tf.cast(1.0, loss_dtype)) # Lol, need to cast 1.0 to float16 or the Mul inside won't work.
N = y_pred.shape[0]
......@@ -276,9 +279,30 @@ class Model(tf.keras.Model):
y_pred_params * gt_pos_params_mask,
y_gt_params * gt_pos_params_mask)
loss = 1.0 / num_matched_boxes * (L_conf + alpha * L_loc)
if debug_print:
tf.print("Conf then Loc loss", output_stream=sys.stderr)
tf.print(L_conf, output_stream=sys.stderr)
tf.print(alpha * L_loc, output_stream=sys.stderr)
if task_suppression == 'sum':
loss = 1.0 / num_matched_boxes * (L_conf + alpha * L_loc)
elif task_suppression == 'limit':
conf_ok = tf.reduce_max(tf.math.softmax(class_probs_pred * gt_pos_cls_mask, axis=1), axis=1)
good_probs = gt_pos_mask & (conf_ok > 0.999)
box_params_focus_mask = expand_mask(good_probs, self.detector.num_box_params)
L_extra = tf.cast(alpha_extra, tf.float16) * smooth_l1(
y_pred_params * box_params_focus_mask,
y_gt_params * box_params_focus_mask)
if debug_print:
tf.print("Extra loss", output_stream=sys.stderr)
tf.print(L_extra, output_stream=sys.stderr)
loss = 1.0 / num_matched_boxes * (L_conf + alpha * L_loc + alpha * L_extra)
else:
raise ValueError("Unknown task_suppression: {}".format(task_suppression))
if True or debug_print:
tf.print("KEY then LOSS for 0", output_stream=sys.stderr)
tf.print(gt_features['key'][0], output_stream=sys.stderr)
tf.print(gt_features['num_matched_boxes'][0], output_stream=sys.stderr)
......
......@@ -696,6 +696,9 @@ class SmileBotTrainer(object):
tape.watch(X)
y_pred = model(X, training=True)
loss = model.loss(gt_features, y_pred,
task_suppression=self.train_kwargs.get('task_suppression', 'sum'),
alpha=self.train_kwargs.get('alpha', 1.0),
alpha_extra=self.train_kwargs.get('alpha_extra', 1.0),
hard_negative_mining_ratio=hard_negative_mining_ratio,
debug_print=self.train_kwargs.get('debug_print', False))
......
Markdown is supported
0% — Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.