Calculate recall for each class after each epoch in Tensorflow 2
We can use `classification_report` from sklearn together with a Keras `Callback` to achieve this.
Working code sample (with comments)
import tensorflow as tf
import keras
from tensorflow.python.keras.layers import Dense, Input
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.callbacks import Callback
from sklearn.metrics import recall_score, classification_report
from sklearn.datasets import make_classification
import numpy as np
import matplotlib.pyplot as plt
# Binary classifier: two ReLU hidden layers feeding a single sigmoid unit.
binary_model = Sequential([
    Dense(16, input_shape=(2,), activation='relu'),
    Dense(8, activation='relu'),
    Dense(1, activation='sigmoid'),
])
binary_model.compile(optimizer='adam', loss='binary_crossentropy')

# Multiclass classifier: same hidden stack, three-way softmax output.
multiclass_model = Sequential([
    Dense(16, input_shape=(2,), activation='relu'),
    Dense(8, activation='relu'),
    Dense(3, activation='softmax'),
])
multiclass_model.compile(optimizer='adam', loss='categorical_crossentropy')
# Callback that records per-class metrics at the end of every epoch
class Metrics(Callback):
    """Store a ``classification_report`` dict after each training epoch.

    The model is evaluated on the ``x``/``y`` pair given at construction
    time, so passing training data yields training metrics and passing
    held-out data yields validation metrics.

    Attributes:
        x: Inputs fed to ``self.model.predict`` at every epoch end.
        y: 1-D array of class labels; one-hot targets are converted with
            ``argmax`` and ``(n, 1)`` column vectors are flattened.
        reports: One ``classification_report(..., output_dict=True)`` result
            per completed epoch, in epoch order.
    """

    def __init__(self, x, y):
        """Remember the evaluation data and normalise ``y`` to class labels.

        Args:
            x: Model inputs to evaluate on.
            y: Targets, either class labels (shape ``(n,)`` or ``(n, 1)``)
                or one-hot rows (shape ``(n, n_classes)``).
        """
        # Let the Keras base class initialise its own bookkeeping attributes.
        super().__init__()
        self.x = x
        y = np.asarray(y)
        if y.ndim == 1 or y.shape[1] == 1:
            # Already label-encoded; flatten so sklearn receives a 1-D array.
            self.y = y.ravel()
        else:
            # One-hot rows -> index of the hot column.
            self.y = np.argmax(y, axis=1)
        self.reports = []

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the stored data and append a new report.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric dict supplied by Keras (unused).
        """
        # verbose=0 keeps predict's own progress output out of the fit() log.
        y_hat = np.asarray(self.model.predict(self.x, verbose=0))
        if y_hat.ndim == 1 or y_hat.shape[1] == 1:
            # Single output unit: threshold the score at 0.5 to get 0/1 labels.
            y_hat = np.where(y_hat > 0.5, 1, 0).ravel()
        else:
            # One score per class: the predicted label is the best-scoring column.
            y_hat = np.argmax(y_hat, axis=1)
        report = classification_report(self.y, y_hat, output_dict=True)
        self.reports.append(report)

    # Utility method
    def get(self, metrics, of_class):
        """Return the per-epoch history of one metric for one class.

        Args:
            metrics: Key inside a class entry of the report, e.g. ``'recall'``,
                ``'precision'`` or ``'f1-score'``.
            of_class: Class label; it is looked up in each report as
                ``str(of_class)``.

        Returns:
            A list with one value per recorded epoch.
        """
        return [report[str(of_class)][metrics] for report in self.reports]
# Two-class toy dataset: fit the binary model and record a report per epoch.
x, y = make_classification(
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    random_state=1,
)
metrics_binary = Metrics(x, y)
binary_model.fit(x, y, epochs=30, callbacks=[metrics_binary])

# Three-class toy dataset: targets are one-hot encoded for the softmax model.
x, y = make_classification(
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    n_classes=3,
    random_state=1,
)
y = keras.utils.to_categorical(y, 3)
metrics_multiclass = Metrics(x, y)
multiclass_model.fit(x, y, epochs=30, callbacks=[metrics_multiclass])

# Plot the binary model's curves: (metric, class, legend label) per line.
plt.close('all')
binary_curves = [
    ('recall', 0, 'Class 0 recall'),
    ('recall', 1, 'Class 1 recall'),
    ('precision', 0, 'Class 0 precision'),
    ('precision', 1, 'Class 1 precision'),
    ('f1-score', 0, 'Class 0 f1-score'),
    ('f1-score', 1, 'Class 1 f1-score'),
]
for metric_name, class_id, curve_label in binary_curves:
    plt.plot(metrics_binary.get(metric_name, class_id), label=curve_label)
plt.legend(loc='lower right')
plt.show()

# Plot the multiclass model's curves: every metric for every class.
plt.close('all')
for m in ('recall', 'precision', 'f1-score'):
    for c in range(3):
        plt.plot(metrics_multiclass.get(m, c), label='Class {0} {1}'.format(c, m))
plt.legend(loc='lower right')
plt.show()
Output
Advantages:
- `classification_report` provides lots of metrics.
- Metrics can be calculated on validation data or on training data by passing the corresponding data to the `Metrics` constructor.
There are multiple ways to do this but using a callback
seems the best and most kerasy way of doing it. One side-note before I show you how:
> I am also not clear on if I can use Keras metrics (as they are calculated at the end of each batch and then averaged) or if I need to use Keras callbacks (which can run at the end of each epoch).
This is not quite accurate: callbacks are not restricted to running at the end of each epoch. A Keras callback can implement any of the following methods:
- on_epoch_begin: called at the beginning of every epoch.
- on_epoch_end: called at the end of every epoch.
- on_batch_begin: called at the beginning of every batch.
- on_batch_end: called at the end of every batch.
- on_train_begin: called at the beginning of model training.
- on_train_end: called at the end of model training.
This is true regardless of whether you are using `keras` or `tf.keras`.
Below you can find my implementation of a custom callback.
class RecallHistory(keras.callbacks.Callback):
    """Skeleton callback for collecting per-class recall during training.

    ``self.recall`` is a dict that is re-created at the start of every
    training run and is meant to be filled in at the end of each epoch.
    """

    def on_train_begin(self, logs=None):
        """Start each training run with an empty ``recall`` store."""
        self.recall = {}

    def on_epoch_end(self, epoch, logs=None):
        """Record recall for the epoch that just finished.

        Args:
            epoch: Index of the finished epoch.
            logs: Metric dict supplied by Keras (unused in this skeleton).
        """
        # Compute and store recall for each class here.
        # Placeholder: replace the ``...`` key and the constant with real values.
        self.recall[...] = 42


# Usage sketch: the ``...`` stands for the usual fit() arguments.
history = RecallHistory()
model.fit(..., callbacks=[history])
print(history.recall)
In TF2, `tf.keras.metrics.Recall` gained a `class_id` argument that lets you do just that. Example using FashionMNIST:
import tensorflow as tf
# Load the training split only; the test split is discarded.
(x_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()

# Append a trailing channel axis and rescale pixel values by 1/255.
x_train = x_train[..., None].astype('float32') / 255
# One-hot encode the integer labels to match the categorical loss.
y_train = tf.keras.utils.to_categorical(y_train)
input_shape = x_train.shape[1:]

# Small CNN: two conv/pool/dropout stages followed by a dense classifier.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(
    filters=64, kernel_size=2, padding='same', activation='relu',
    input_shape=input_shape))
model.add(tf.keras.layers.MaxPool2D(pool_size=2))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Conv2D(
    filters=32, kernel_size=2, padding='same', activation='relu'))
model.add(tf.keras.layers.MaxPool2D(pool_size=2))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units=256, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(units=10, activation='softmax'))

# One Recall metric per class, selected through class_id.
per_class_recall = [tf.keras.metrics.Recall(class_id=i) for i in range(10)]
model.compile(
    loss='categorical_crossentropy',
    optimizer='Adam',
    metrics=per_class_recall,
)
model.fit(x_train, y_train, batch_size=128, epochs=50)
In TF 1.13, `tf.keras.metrics.Recall` does not have this `class_id` argument, but it can be added by subclassing (something that, somewhat surprisingly, seems impossible in the alpha release of TF2).
class Recall(tf.keras.metrics.Recall):
    """Recall metric restricted to a single class.

    Args:
        class_id: Keyword-only index of the class to track.
        **kwargs: Forwarded unchanged to ``tf.keras.metrics.Recall``.
    """

    def __init__(self, *, class_id, **kwargs):
        super().__init__(**kwargs)
        self.class_id = class_id

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Reduce the batch to a binary problem for ``class_id`` and delegate.

        ``y_true`` is indexed as ``[:, class_id]``, so it is presumably a
        2-D one-hot (or indicator) array -- confirm against the caller.
        """
        # Ground truth: the column belonging to the tracked class.
        true_column = y_true[:, self.class_id]
        # Prediction: 1.0 where the arg-max class is the tracked class, else 0.0.
        predicted_class = tf.math.argmax(y_pred, axis=-1)
        is_tracked_class = tf.equal(predicted_class, self.class_id)
        predicted_column = tf.cast(is_tracked_class, dtype=tf.float32)
        return super().update_state(true_column, predicted_column, sample_weight)