TensorFlow 2.0 custom loss function with multiple inputs
This problem can be solved with custom training in TF2. You only need to compute your two-component loss function within a tf.GradientTape
context and then pass the resulting gradients to an optimizer. For example, you could create a function custom_loss
which computes both losses given the arguments to each:
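import tensorflow as tf

def custom_loss(model, loss1_args, loss2_args):
    # model: tf.keras.Model
    # loss1_args: arguments to loss_1, as tuple.
    # loss2_args: arguments to loss_2, as tuple.
    with tf.GradientTape() as tape:
        # The tape only records operations executed inside this context, so
        # loss_1 / loss_2 must run the model's forward pass here; otherwise
        # the gradients below come back as None.
        l1_value = loss_1(*loss1_args)
        l2_value = loss_2(*loss2_args)
        loss_value = [l1_value, l2_value]
    # Given a list of targets, tape.gradient sums their gradients.
    return loss_value, tape.gradient(loss_value, model.trainable_variables)

# In training loop:
loss_values, grads = custom_loss(model, loss1_args, loss2_args)
optimizer.apply_gradients(zip(grads, model.trainable_variables))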
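In this way, each loss function can take an arbitrary number of eager tensors, regardless of whether they are inputs to or outputs of the model. Gradients flow only through operations recorded by the tape, so any model output that a loss depends on has to be computed inside the GradientTape context. The sets of arguments to each loss function need not be disjoint, as they appear in this example; the same tensor may be passed to both.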
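As a rough end-to-end illustration of the approach above, here is a minimal sketch of one training step. The two-headed model, the layer sizes, the bodies of loss_1 and loss_2, and the random data are all hypothetical and chosen only to make the example runnable. Note that loss_1 mixes a target with a model output, while loss_2 uses both outputs, so the two argument sets overlap.

```python
import tensorflow as tf

# Hypothetical two-headed model: one input, two outputs.
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
out_a = tf.keras.layers.Dense(1)(hidden)
out_b = tf.keras.layers.Dense(1)(hidden)
model = tf.keras.Model(inputs=inputs, outputs=[out_a, out_b])
optimizer = tf.keras.optimizers.Adam()

def loss_1(target, pred_a):
    # Illustrative only: squared error between a target and the first output.
    return tf.reduce_mean(tf.square(target - pred_a))

def loss_2(pred_a, pred_b):
    # Illustrative only: penalize disagreement between the two outputs.
    return tf.reduce_mean(tf.abs(pred_a - pred_b))

def train_step(features, target):
    with tf.GradientTape() as tape:
        # The forward pass runs inside the tape so that it is recorded.
        pred_a, pred_b = model(features, training=True)
        loss_value = loss_1(target, pred_a) + loss_2(pred_a, pred_b)
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss_value, grads

features = tf.random.normal((16, 4))
target = tf.random.normal((16, 1))
loss_value, grads = train_step(features, target)
```

Summing the two components into one scalar before calling tape.gradient is one possible design; weighting each term differently would be a straightforward variation.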
To expand on Jon's answer: if you still want the benefits of a Keras Model, you can subclass the model class and write your own custom train_step:
import tensorflow as tf
from tensorflow import keras
# Note: data_adapter lives in a private TensorFlow module, not the public API.
from tensorflow.python.keras.engine import data_adapter

# Custom loss function that takes two outputs of the model as input
# parameters, which would otherwise not be possible.
def custom_loss(gt, x, y):
    return tf.reduce_mean(x) + tf.reduce_mean(y)

class CustomModel(keras.Model):
    def compile(self, optimizer, my_loss):
        super().compile(optimizer=optimizer)
        self.my_loss = my_loss

    def train_step(self, data):
        data = data_adapter.expand_1d(data)
        input_data, gt, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
        with tf.GradientTape() as tape:
            # Forward pass inside the tape so gradients can be computed.
            y_pred = self(input_data, training=True)
            loss_value = self.my_loss(gt, y_pred[0], y_pred[1])
        grads = tape.gradient(loss_value, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss_value": loss_value}
...
model = CustomModel(inputs=input_tensor0, outputs=[x, y])
model.compile(optimizer=tf.keras.optimizers.Adam(), my_loss=custom_loss)
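To see the subclassing approach run through fit, here is a self-contained sketch. It is not the answer's exact code: the network, the data (features, targets), and the loss body are hypothetical, and it swaps the private data_adapter import for tf.keras.utils.unpack_x_y_sample_weight, assuming the installed TensorFlow version exposes that public helper.

```python
import numpy as np
import tensorflow as tf

def custom_loss(gt, x, y):
    # Hypothetical loss: squared error of each output against the same target.
    return tf.reduce_mean(tf.square(x - gt)) + tf.reduce_mean(tf.square(y - gt))

class CustomModel(tf.keras.Model):
    def compile(self, optimizer, my_loss):
        super().compile(optimizer=optimizer)
        self.my_loss = my_loss

    def train_step(self, data):
        # Public helper (assumed available) that splits the batch fit() provides.
        input_data, gt, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
        with tf.GradientTape() as tape:
            y_pred = self(input_data, training=True)
            loss_value = self.my_loss(gt, y_pred[0], y_pred[1])
        grads = tape.gradient(loss_value, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss_value": loss_value}

# Hypothetical two-output functional graph.
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
x = tf.keras.layers.Dense(1)(hidden)
y = tf.keras.layers.Dense(1)(hidden)

model = CustomModel(inputs=inputs, outputs=[x, y])
model.compile(optimizer=tf.keras.optimizers.Adam(), my_loss=custom_loss)

# Random stand-in data, only to exercise the training loop.
features = np.random.rand(32, 4).astype("float32")
targets = np.random.rand(32, 1).astype("float32")
history = model.fit(features, targets, batch_size=8, epochs=2, verbose=0)
```

Because train_step returns a dictionary, the value reported during training appears under the key "loss_value" in history.history rather than under the usual "loss".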