Calculate recall for each class after each epoch in Tensorflow 2

We can use scikit-learn's classification_report together with a Keras Callback to achieve this.

Working code sample (with comments)

import tensorflow as tf
import keras
from tensorflow.python.keras.layers import Dense, Input
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.callbacks import Callback
from sklearn.metrics import recall_score, classification_report
from sklearn.datasets import make_classification
import numpy as np
import matplotlib.pyplot as plt

# Both classifiers take 2-feature inputs and share the same hidden stack
# (16 -> 8 ReLU units); they differ only in the output layer and the loss.

# Model -- Binary classifier: a single sigmoid unit, binary cross-entropy.
binary_model = Sequential([
    Dense(16, input_shape=(2,), activation='relu'),
    Dense(8, activation='relu'),
    Dense(1, activation='sigmoid'),
])
binary_model.compile(optimizer='adam', loss='binary_crossentropy')

# Model -- Multiclass classifier: three softmax units, categorical cross-entropy.
multiclass_model = Sequential([
    Dense(16, input_shape=(2,), activation='relu'),
    Dense(8, activation='relu'),
    Dense(3, activation='softmax'),
])
multiclass_model.compile(optimizer='adam', loss='categorical_crossentropy')

# callback to find metrics at epoch end
class Metrics(Callback):
    """Keras callback that records an sklearn classification report per epoch.

    At the end of every epoch the model predicts on ``x``. The predictions are
    converted to class labels and compared with ``y`` using
    ``sklearn.metrics.classification_report(..., output_dict=True)``. Each
    epoch's report dict is appended to ``self.reports``.
    """

    def __init__(self, x, y):
        """
        Args:
            x: Inputs to predict on at the end of every epoch.
            y: Ground-truth targets for ``x``. Accepted forms are 1-D labels,
                a single column of labels, or one row of class scores per
                sample (e.g. one-hot), which is reduced to labels with argmax.
        """
        # Let the Keras base class initialise its own attributes
        # (model, params, ...) before we add ours.
        super().__init__()
        self.x = x
        y = np.asarray(y)
        if y.ndim == 1 or y.shape[1] == 1:
            # Flatten a (n, 1) column so sklearn gets plain 1-D labels.
            self.y = y.ravel()
        else:
            self.y = np.argmax(y, axis=1)
        self.reports = []

    def on_epoch_end(self, epoch, logs=None):
        """Predict on ``self.x`` and store this epoch's classification report."""
        # verbose=0: avoid printing a prediction progress bar every epoch.
        y_hat = np.asarray(self.model.predict(self.x, verbose=0))
        if y_hat.ndim == 1 or y_hat.shape[1] == 1:
            # Single sigmoid-style output: threshold the score at 0.5.
            y_hat = np.where(y_hat.ravel() > 0.5, 1, 0)
        else:
            # One score per class: pick the highest-scoring class.
            y_hat = np.argmax(y_hat, axis=1)
        report = classification_report(self.y, y_hat, output_dict=True)
        self.reports.append(report)

    # Utility method
    def get(self, metrics, of_class):
        """Return the per-epoch history of one metric for one class.

        Args:
            metrics: Key inside a per-class report entry, e.g. 'precision',
                'recall', 'f1-score' or 'support'.
            of_class: Class label. It is converted with ``str`` because the
                report dict keys classes by their string form.

        Returns:
            A list with one value per recorded epoch.
        """
        key = str(of_class)
        return [report[key][metrics] for report in self.reports]
    
# Shared make_classification settings: 2 informative features, one cluster
# per class, and a fixed seed so both datasets are reproducible.
dataset_kwargs = dict(n_features=2, n_redundant=0, n_informative=2,
                      random_state=1, n_clusters_per_class=1)

# Binary (2-class) train data: 1-D integer labels are fed to the model and
# to the callback as-is.
x, y = make_classification(**dataset_kwargs)
metrics_binary = Metrics(x, y)
binary_model.fit(x, y, epochs=30, callbacks=[metrics_binary])

# 3-class train data: one-hot encode for categorical_crossentropy; the
# Metrics callback turns the one-hot rows back into labels via argmax.
x, y = make_classification(n_classes=3, **dataset_kwargs)
y = keras.utils.to_categorical(y, 3)
metrics_multiclass = Metrics(x, y)
multiclass_model.fit(x, y, epochs=30, callbacks=[metrics_multiclass])

# Plotting: one curve per (metric, class) pair, one point per epoch.
plt.close('all')
for metric_name in ('recall', 'precision', 'f1-score'):
    for label in (0, 1):
        plt.plot(metrics_binary.get(metric_name, label),
                 label='Class {0} {1}'.format(label, metric_name))
plt.legend(loc='lower right')
plt.show()

plt.close('all')
for metric_name in ('recall', 'precision', 'f1-score'):
    for label in range(3):
        plt.plot(metrics_multiclass.get(metric_name, label),
                 label='Class {0} {1}'.format(label, metric_name))
plt.legend(loc='lower right')
plt.show()

Output

[Figure: per-class recall, precision and f1-score per epoch for the binary classifier]

[Figure: per-class recall, precision and f1-score per epoch for the 3-class classifier]

Advantages:

  • classification_report provides many per-class metrics (precision, recall, f1-score and support)
  • Metrics can be calculated on validation data as well as on training data: just pass the corresponding arrays to the Metrics constructor.

There are multiple ways to do this, but using a callback seems the best and most Keras-idiomatic way. One side note before I show you how:

> I am also not clear on if I can use Keras metrics (as they are calculated at the end of each batch and then averaged) or if I need to use Keras callbacks (which can run at the end of each epoch).

This is not true: Keras callbacks are not limited to running at the end of each epoch. They can implement any of the following methods:

  • on_epoch_begin: called at the beginning of every epoch.
  • on_epoch_end: called at the end of every epoch.
  • on_batch_begin: called at the beginning of every batch.
  • on_batch_end: called at the end of every batch.
  • on_train_begin: called at the beginning of model training.
  • on_train_end: called at the end of model training.

This is true regardless of whether you are using keras or tf.keras.

Below you can find my implementation of a custom callback.

class RecallHistory(keras.callbacks.Callback):
    """Skeleton callback for collecting per-class recall at every epoch end.

    ``self.recall`` is reset at the start of each training run; fill it in
    ``on_epoch_end`` with your own per-class recall computation.
    """

    def on_train_begin(self, logs=None):
        # (Re)create the storage at the start of every training run.
        self.recall = {}

    def on_epoch_end(self, epoch, logs=None):
        # Compute and store recall for each class here.
        # Placeholder: `...` stands for the key you choose (e.g. a class id)
        # and 42 for the computed recall value.
        self.recall[...] = 42

# Illustrative usage: `...` stands for the usual fit() arguments (data,
# epochs, ...) and `model` for an already compiled Keras model.
history = RecallHistory()
model.fit(..., callbacks=[history])

print(history.recall)

In TF2, tf.keras.metrics.Recall gained a class_id argument that makes this possible directly. Example using FashionMNIST:

import tensorflow as tf

# Load FashionMNIST, add a trailing channel axis, scale pixels to [0, 1] and
# one-hot encode the labels (Recall(class_id=i) reads column i of y_true/y_pred).
(x_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train[..., None].astype('float32') / 255
y_train = tf.keras.utils.to_categorical(y_train)
num_classes = y_train.shape[1]

input_shape = x_train.shape[1:]
model = tf.keras.Sequential([
  tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=input_shape),
  tf.keras.layers.MaxPool2D(pool_size=2),
  tf.keras.layers.Dropout(0.3),

  tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'),
  tf.keras.layers.MaxPool2D(pool_size=2),
  tf.keras.layers.Dropout(0.3),

  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(units=256, activation='relu'),
  tf.keras.layers.Dropout(0.5),
  tf.keras.layers.Dense(units=num_classes, activation='softmax')])

# One Recall metric per class, with an explicit name so each class's recall
# is unambiguous in the training logs and in history.history.
# Note: with Recall's default threshold (0.5), a sample counts as predicted
# class i only if its softmax probability for i exceeds 0.5, not via argmax.
model.compile(loss='categorical_crossentropy', optimizer='Adam',
  metrics=[tf.keras.metrics.Recall(class_id=i, name='recall_class_{}'.format(i))
           for i in range(num_classes)])
model.fit(x_train, y_train, batch_size=128, epochs=50)

In TF 1.13, tf.keras.metrics.Recall does not have this class_id argument, but it can be added by subclassing (something that, somewhat surprisingly, seems impossible in the alpha release of TF2).

class Recall(tf.keras.metrics.Recall):
  """Recall for a single class, for TF versions whose Recall lacks class_id.

  A sample counts as predicted ``class_id`` when that class has the highest
  score in ``y_pred`` (argmax). Column ``class_id`` of ``y_true`` (presumably
  one-hot labels; confirm for your data) is used as the binary target.
  """

  def __init__(self, *, class_id, **kwargs):
    super().__init__(**kwargs)
    self.class_id= class_id

  def update_state(self, y_true, y_pred, sample_weight=None):
    # Binary ground truth: the `class_id` column of y_true.
    y_true = y_true[:, self.class_id]
    # Binary prediction: 1.0 where `class_id` is the arg-max class, else 0.0.
    y_pred = tf.cast(tf.equal(
      tf.math.argmax(y_pred, axis=-1), self.class_id), dtype=tf.float32)
    return super().update_state(y_true, y_pred, sample_weight)

  def get_config(self):
    # __init__ requires class_id, so it must be part of the config for
    # from_config() (e.g. when loading a saved model) to re-create the metric.
    config = super().get_config()
    config['class_id'] = self.class_id
    return config