PyTorch: Correct way to use custom weight maps in U-Net architecture
The weighting part looks like plain class-weighted cross entropy, which takes one weight per class (2 classes in the example below):
import torch
import torch.nn as nn
weights = torch.tensor([0.3, 0.7])
loss_func = nn.CrossEntropyLoss(weight=weights)
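To show how this class-weighted loss applies to segmentation output, here is a minimal sketch. It assumes logits of shape (N, C, H, W) and integer targets of shape (N, H, W); the shapes, the random data and the name `class_weights` are illustrative and not taken from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative shapes (assumed): batch of 4, 2 classes, 8x8 images
logits = torch.randn(4, 2, 8, 8)          # raw network output, (N, C, H, W)
target = torch.randint(0, 2, (4, 8, 8))   # class index per pixel, (N, H, W)

# One weight per class; pixels labelled as class 1 count more than class 0
class_weights = torch.tensor([0.3, 0.7])
loss_func = nn.CrossEntropyLoss(weight=class_weights)

# With the default reduction='mean', the weighted per-pixel losses are
# summed and divided by the sum of the weights of the target pixels
loss = loss_func(logits, target)
print(loss.item())
```

Note that this only weights by class: every pixel of the same class gets the same weight, so it cannot express a spatial weight map.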
EDIT:
Have you seen this implementation from Patrick Black? It weights each pixel's log probability with a weight map instead of using one weight per class.
import torch
import torch.nn.functional as F
# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10
# Initialize logits, targets and weights with random values
logits = torch.randn(batch_size, out_channels, H, W)
target = torch.randint(0, out_channels, (batch_size, H, W))
weights = torch.randint(1, 3, (batch_size, 1, H, W)).float()
# Calculate log probabilities over the class dimension
logp = F.log_softmax(logits, dim=1)
# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))
# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)
# Normalize by the total weight per sample so the loss stays in roughly the same range
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)
# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()
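The same computation can probably be written more compactly with `F.cross_entropy(..., reduction="none")`, which returns the unreduced per-pixel loss. The sketch below is one possible wrapper, not the original author's code: the function name `weight_map_loss` is hypothetical, and it assumes the weight map has shape (N, H, W) rather than (N, 1, H, W).

```python
import torch
import torch.nn.functional as F


def weight_map_loss(logits, target, weight_map):
    """Cross entropy weighted per pixel (hypothetical helper).

    logits:     (N, C, H, W) float tensor of raw scores
    target:     (N, H, W) long tensor of class indices
    weight_map: (N, H, W) float tensor of per-pixel weights (assumed shape)
    """
    # Unreduced negative log-likelihood per pixel, shape (N, H, W)
    pixel_loss = F.cross_entropy(logits, target, reduction="none")
    # Weighted sum per sample, normalized by that sample's total weight
    weighted = (pixel_loss * weight_map).flatten(1).sum(1)
    norm = weight_map.flatten(1).sum(1)
    # Average over the mini-batch
    return (weighted / norm).mean()


torch.manual_seed(0)
logits = torch.randn(10, 2, 10, 10)
target = torch.randint(0, 2, (10, 10, 10))
weight_map = torch.randint(1, 3, (10, 10, 10)).float()

loss = weight_map_loss(logits, target, weight_map)
print(loss.item())
```

In a U-Net setting, `weight_map` would come from the dataset (for example a precomputed border-emphasis map) and be loaded alongside each target mask.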