Using multiple validation sets with Keras

I ended up writing my own Callback, based on the History callback, to solve the problem. I'm not sure this is the best approach, but the following Callback records the losses and metrics for the training and validation sets, just as the History callback does, and additionally records the losses and metrics for any extra validation sets passed to its constructor.

class AdditionalValidationSets(Callback):
    """Callback that evaluates the model on extra validation sets each epoch.

    Like the ``History`` callback, it stores everything found in ``logs`` at
    the end of every epoch in ``self.history``. In addition, it evaluates the
    model on each additional validation set and stores the results under the
    keys ``'<validation_set_name>_<metric_name>'``.
    """

    def __init__(self, validation_sets, verbose=0, batch_size=None):
        """
        :param validation_sets:
        an iterable of 3-tuples (validation_data, validation_targets, validation_set_name)
        or 4-tuples (validation_data, validation_targets, sample_weights, validation_set_name)
        :param verbose:
        verbosity mode, 1 or 0
        :param batch_size:
        batch size to be used when evaluating on the additional datasets
        :raises ValueError:
        if an entry of ``validation_sets`` is neither a 3-tuple nor a 4-tuple
        """
        super(AdditionalValidationSets, self).__init__()
        # Materialize once so a one-shot iterable (e.g. a generator) is not
        # exhausted by the validation loop below.
        self.validation_sets = list(validation_sets)
        for validation_set in self.validation_sets:
            # Fail early, at construction time, on malformed entries.
            self._unpack(validation_set)
        self.epoch = []
        self.history = {}
        self.verbose = verbose
        self.batch_size = batch_size

    @staticmethod
    def _unpack(validation_set):
        """Return (data, targets, sample_weights, name) for one validation set.

        ``sample_weights`` is ``None`` when a 3-tuple was given.

        :raises ValueError: if the entry does not have 3 or 4 elements
        """
        if len(validation_set) == 3:
            validation_data, validation_targets, validation_set_name = validation_set
            return validation_data, validation_targets, None, validation_set_name
        if len(validation_set) == 4:
            return tuple(validation_set)
        raise ValueError(
            'each validation set must be a 3-tuple (data, targets, name) or a '
            '4-tuple (data, targets, sample_weights, name); got %d elements'
            % len(validation_set))

    def on_train_begin(self, logs=None):
        """Reset the recorded epochs and history at the start of training."""
        self.epoch = []
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        """Record ``logs`` and the results on every additional validation set.

        :param epoch: index of the epoch that just finished
        :param logs: metrics reported for this epoch (may be ``None``)
        """
        logs = logs or {}
        self.epoch.append(epoch)

        # record the same values as History() as well
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        # evaluate on the additional validation sets
        for validation_set in self.validation_sets:
            (validation_data, validation_targets,
             sample_weights, validation_set_name) = self._unpack(validation_set)

            results = self.model.evaluate(x=validation_data,
                                          y=validation_targets,
                                          verbose=self.verbose,
                                          sample_weight=sample_weights,
                                          batch_size=self.batch_size)

            # evaluate() returns a bare scalar (the loss) when the model has
            # no metrics; wrap it so it can be zipped with metrics_names.
            if not isinstance(results, (list, tuple)):
                results = [results]

            for metric, result in zip(self.model.metrics_names, results):
                valuename = validation_set_name + '_' + metric
                self.history.setdefault(valuename, []).append(result)

I then use it like this:

# Evaluate on the extra set 'val2' at the end of every epoch; its results are
# stored in history.history under keys prefixed with 'val2_'.
history = AdditionalValidationSets([(validation_data2, validation_targets2, 'val2')])
model.fit(train_data, train_targets,
          epochs=epochs,
          batch_size=batch_size,
          validation_data=(validation_data1, validation_targets1),
          callbacks=[history],
          shuffle=True)

I tested this on TensorFlow 2 and it worked. You can evaluate on as many validation sets as you want at the end of each epoch:

class MyCustomCallback(tf.keras.callbacks.Callback):
    """Evaluate the model on two extra test sets after every epoch.

    The data (``X_test_1``/``y_test_1`` and ``X_test_2``/``y_test_2``) is
    looked up by name from the enclosing scope when the callback runs.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Run both evaluations silently, then print each result."""
        first_scores = self.model.evaluate(X_test_1, y_test_1, verbose=0)
        second_scores = self.model.evaluate(X_test_2, y_test_2, verbose=0)
        # Print only after both evaluations have completed.
        for scores in (first_scores, second_scores):
            print(scores)

And later:

# Create the callback once; it is handed to fit() below.
my_val_callback = MyCustomCallback()
# Your model creation code
# `...` stands in for your usual fit() arguments (data, epochs, ...).
model.fit(
    ...,
    callbacks=[my_val_callback],
)