Keras: weighted binary crossentropy

You can use the sklearn module to automatically calculate the weights for each class like this:

# Import
import numpy as np
from sklearn.utils import class_weight
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))

# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

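# Calculate the weights for each class so that we can balance the data
# (recent scikit-learn versions require keyword arguments here)
classes = np.unique(y_train)
weights = class_weight.compute_class_weight(class_weight='balanced',
                                            classes=classes,
                                            y=y_train)

# Keras expects class_weight as a dict mapping class index to weight
class_weights = {int(c): float(w) for c, w in zip(classes, weights)}

# Add the class weights to the training
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=class_weights)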

Note that the output of class_weight.compute_class_weight() is a NumPy array like this: [2.57569845 0.68250928]. Its entries are in the same order as np.unique(y_train).
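To make the 'balanced' option less of a black box, here is a minimal sketch of the heuristic it is documented to use, n_samples / (n_classes * count_of_class), written in plain Python. The helper name balanced_class_weights and the 89/11 label split are made up for illustration; they are not part of scikit-learn or of the answer above.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Sketch of the 'balanced' heuristic: n_samples / (n_classes * count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    # One weight per class; rarer classes get larger weights
    return {cls: n_samples / (n_classes * count)
            for cls, count in sorted(counts.items())}


# Hypothetical labels: 89 zeros and 11 ones
y_example = [0] * 89 + [1] * 11
example_weights = balanced_class_weights(y_example)
print(example_weights)
```

The result is already a dict keyed by class, which is the shape Keras expects for class_weight.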


Normally, the minority class will have a higher class weight. It'll be better to use one_weight=0.89, zero_weight=0.11 (btw, you can use class_weight={0: 0.11, 1: 0.89}, as suggested in the comment).

Under class imbalance, your model sees many more zeros than ones. It will also learn to predict more zeros than ones, because doing so minimizes the training loss. That's also why you're seeing an accuracy close to the proportion 0.11. If you average the model's predictions, the result should be very close to zero.

The purpose of using class weights is to change the loss function so that the training loss cannot be minimized by the "easy solution" (i.e., predicting zeros), and that's why it'll be better to use a higher weight for ones.
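A small worked example may help here. The sketch below assumes a hypothetical dataset of 89 zeros and 11 ones and a model that outputs the same probability p for every sample, then searches for the p with the lowest mean (weighted) binary crossentropy. The function names are invented for this illustration.

```python
import math


def mean_weighted_bce(p, n_zeros, n_ones, w0=1.0, w1=1.0):
    """Mean weighted binary crossentropy when every sample gets prediction p."""
    total = -(w1 * n_ones * math.log(p) + w0 * n_zeros * math.log(1.0 - p))
    return total / (n_zeros + n_ones)


def best_constant_prediction(n_zeros, n_ones, w0=1.0, w1=1.0):
    """Grid-search the constant p in {0.01, ..., 0.99} with the lowest loss."""
    grid = [i / 100 for i in range(1, 100)]
    return min(grid, key=lambda p: mean_weighted_bce(p, n_zeros, n_ones, w0, w1))


# Without weights, the cheapest constant answer is "almost always zero"
p_unweighted = best_constant_prediction(89, 11)
# With zero_weight=0.11 and one_weight=0.89, both classes pull equally
p_weighted = best_constant_prediction(89, 11, w0=0.11, w1=0.89)
print(p_unweighted, p_weighted)  # 0.11 0.5
```

With the weights applied, the "easy solution" no longer has the lowest loss: the best constant prediction moves from 0.11 to 0.5, so the model has to use the features to do better.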

Note that the best weights are not necessarily 0.89 and 0.11. Sometimes you might have to try something like taking logarithms or square roots (or any weights satisfying one_weight > zero_weight) to make it work.
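One possible reading of "taking logarithms or square roots" is to dampen the raw imbalance ratio before using it as the weight for ones. The sketch below is only that: the helper name, the choice of log(1 + ratio), and the 89/11 counts are assumptions for illustration, and the best variant still has to be found by trying them.

```python
import math


def dampened_weights(n_zeros, n_ones):
    """Candidate class_weight dicts, from strongest to more dampened."""
    ratio = n_zeros / n_ones  # how many zeros there are per one
    return {
        "inverse_frequency": {0: 1.0, 1: ratio},
        "square_root": {0: 1.0, 1: math.sqrt(ratio)},
        # log(1 + ratio) only exceeds 1 when ratio is above roughly 1.72
        "logarithm": {0: 1.0, 1: math.log(1.0 + ratio)},
    }


candidates = dampened_weights(n_zeros=89, n_ones=11)
for name, cw in candidates.items():
    print(name, cw)
```

Each dict can be passed as class_weight to model.fit; all three keep one_weight > zero_weight for this example.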


I think using class_weight in model.fit is not correct. In {0: 0.11, 1: 0.89}, the 0 is the index, not the 0 class. From the Keras documentation (https://keras.io/models/sequential/): "class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class."


Using class_weight in model.fit is slightly different: it reweights whole samples (each sample's loss is scaled by the weight of its class) rather than computing a weighted loss inside the loss function.
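As a rough sketch of that difference, and assuming class_weight behaves like a lookup that turns each sample's label into a per-sample weight, the effect can be imitated in plain Python. The function name and the loss values below are made up for illustration; this is not Keras code.

```python
def class_weight_to_sample_weight(labels, class_weight):
    """Look up one weight per sample from its (integer) class label."""
    return [class_weight[int(label)] for label in labels]


# Hypothetical labels and per-sample losses from some model
labels = [0, 1, 0, 0]
per_sample_loss = [0.2, 1.5, 0.1, 0.3]

sample_weight = class_weight_to_sample_weight(labels, {0: 0.11, 1: 0.89})
# Each sample's loss is scaled as a whole, not label by label
scaled_loss = [w * l for w, l in zip(sample_weight, per_sample_loss)]
print(sample_weight)
print(scaled_loss)
```

This works for a single 0/1 label per sample, but it has no way to weight the individual entries of a multi-hot label vector, which is what a custom loss function is for.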

I also found that class_weight, as well as sample_weight, is ignored in TF 2.0.0 when x is passed to model.fit as a tf.data.Dataset or a generator. I believe this is fixed in TF 2.1.0+.

Here is my weighted binary cross entropy function for multi-hot encoded labels.

import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np
# weighted loss functions


def weighted_binary_cross_entropy(weights: dict, from_logits: bool = False):
    '''
    Return a function for calculating weighted binary cross entropy
    It should be used for multi-hot encoded labels

    # Example
    y_true = tf.convert_to_tensor([1, 0, 0, 0, 0, 0], dtype=tf.int64)
    y_pred = tf.convert_to_tensor([0.6, 0.1, 0.1, 0.9, 0.1, 0.], dtype=tf.float32)
    weights = {
        0: 1.,
        1: 2.
    }
    # with weights
    loss_fn = weighted_binary_cross_entropy(weights=weights, from_logits=False)
    loss = loss_fn(y_true, y_pred)
    print(loss)
    # tf.Tensor(0.6067193, shape=(), dtype=float32)

    # without weights (equal weights for both labels)
    loss_fn = weighted_binary_cross_entropy(weights={0: 1., 1: 1.})
    loss = loss_fn(y_true, y_pred)
    print(loss)
    # tf.Tensor(0.52158177, shape=(), dtype=float32)

    # Another example
    y_true = tf.convert_to_tensor([[0., 1.], [0., 0.]], dtype=tf.float32)
    y_pred = tf.convert_to_tensor([[0.6, 0.4], [0.4, 0.6]], dtype=tf.float32)
    weights = {
        0: 1.,
        1: 2.
    }
    # with weights
    loss_fn = weighted_binary_cross_entropy(weights=weights, from_logits=False)
    loss = loss_fn(y_true, y_pred)
    print(loss)
    # tf.Tensor(1.0439969, shape=(), dtype=float32)

    # without weights (equal weights for both labels)
    loss_fn = weighted_binary_cross_entropy(weights={0: 1., 1: 1.})
    loss = loss_fn(y_true, y_pred)
    print(loss)
    # tf.Tensor(0.81492424, shape=(), dtype=float32)

    @param weights A dict setting the weights for the 0 and 1 labels, e.g.
        {
            0: 1.,
            1: 8.
        }
        In this case we want to emphasise the true (1) labels,
        because there are many more false (0) labels, e.g.
            [
                [0 1 0 0 0 0 0 0 0 1]
                [0 0 0 0 1 0 0 0 0 0]
                [0 0 0 0 1 0 0 0 0 0]
            ]



    @param from_logits If True, y_pred is treated as logits and a sigmoid is
        applied internally; if False, y_pred must already be probabilities
    @return A function to calculate (weighted) binary cross entropy
    '''
    assert 0 in weights
    assert 1 in weights

    def weighted_cross_entropy_fn(y_true, y_pred):
        tf_y_true = tf.cast(y_true, dtype=y_pred.dtype)
        tf_y_pred = tf.cast(y_pred, dtype=y_pred.dtype)

        weights_v = tf.where(tf.equal(tf_y_true, 1), weights[1], weights[0])
        ce = K.binary_crossentropy(tf_y_true, tf_y_pred, from_logits=from_logits)
        loss = K.mean(tf.multiply(ce, weights_v))
        return loss

    return weighted_cross_entropy_fn
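
To sanity-check the numbers in the docstring without TensorFlow, here is a plain-Python sketch of the same calculation for probability inputs (from_logits=False). The function name reference_weighted_bce is made up, and the clipping value eps=1e-7 is an assumption meant to mirror the Keras backend's default epsilon.

```python
import math


def reference_weighted_bce(y_true, y_pred, weights, eps=1e-7):
    """Mean of per-label binary crossentropy, scaled by the label's weight."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        ce = -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
        total += weights[1 if t == 1 else 0] * ce
    return total / len(y_true)


# First example from the docstring above
y_true = [1, 0, 0, 0, 0, 0]
y_pred = [0.6, 0.1, 0.1, 0.9, 0.1, 0.0]
weighted = reference_weighted_bce(y_true, y_pred, {0: 1.0, 1: 2.0})
unweighted = reference_weighted_bce(y_true, y_pred, {0: 1.0, 1: 1.0})
print(round(weighted, 4), round(unweighted, 4))  # 0.6067 0.5216
```

In a model, the TensorFlow version would be passed as the loss, e.g. model.compile(optimizer='rmsprop', loss=weighted_binary_cross_entropy({0: 1., 1: 8.})), where the weights shown are only the example values from the docstring.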