Limit neural network output to subset of trained classes

First of all, I will loosely go through the options you have listed and add some viable alternatives with their pros and cons. It's kinda hard to structure this answer, but I hope you'll get what I'm trying to say:

1. Multiply the restricted outputs by zero before sending them through softmax.

As you have written, this may give a higher chance to the zeroed-out entries (a logit of 0 can still be larger than the remaining logits), so it seems like a flawed approach from the start.

Alternative: replace the impossible values with the smallest logit value. This is similar to softmax(output[1:]), though the network will be even more uncertain about the results. Example PyTorch implementation:

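import torch

logits = torch.tensor([5.39413513, 3.81445419, 3.75369546, 1.02716988, 0.39189373])
# Overwrite the impossible class (index 0) with the smallest logit
logits[0] = logits.min()
# Convert the modified logits to probabilities
print(torch.softmax(logits, dim=0))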

which yields:

tensor([0.0158, 0.4836, 0.4551, 0.0298, 0.0158])
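For comparison, the softmax(output[1:]) variant mentioned above could be sketched as follows. This is only a minimal illustration reusing the logits from the example; the `allowed` mask is a hypothetical name, and filling the excluded entries with -inf is one common way (an assumption here, not something from your question) to get the same renormalisation while keeping one entry per trained class.

```python
import torch

logits = torch.tensor([5.39413513, 3.81445419, 3.75369546, 1.02716988, 0.39189373])

# Hypothetical mask: True for classes that may still be predicted
allowed = torch.tensor([False, True, True, True, True])

# Variant A: drop the excluded class and renormalise over the rest
probs_sliced = torch.softmax(logits[1:], dim=0)

# Variant B: keep the shape, push excluded logits to -inf so they get probability 0
masked_logits = logits.masked_fill(~allowed, float("-inf"))
probs_masked = torch.softmax(masked_logits, dim=0)

print(probs_sliced)
print(probs_masked)
```

Both variants assign the same probabilities to the remaining classes; the second is convenient when downstream code expects the original number of outputs.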


  • Citing you: "In the original output the softmax gives .70 that the answer is [1,0,0,0,0] but if that's an invalid answer and thus removed the redistribution how assigns the 4 remaining options with under 50% probability which could easily be ignored as too low to use."

Yes, and you would be right to do so. Even more so, the network's actual probability for the best remaining class is far lower, around 14% (tensor([0.7045, 0.1452, 0.1366, 0.0089, 0.0047])). By manually changing the output you are essentially destroying the properties this NN has learned (and its output distribution), rendering some part of your computations pointless. This points to another problem, this time stated in the bounty:

2. NNs are known to be overconfident in classification problems

I can imagine this being solved in multiple ways:

2.1 Ensemble

Create multiple neural networks and ensemble them by summing their logits and taking argmax at the end (or softmax and then argmax). A hypothetical situation with 3 different models with different predictions:

import torch

predicted_logits_1 = torch.tensor([5.39413513, 3.81419, 3.7546, 1.02716988, 0.39189373])
predicted_logits_2 = torch.tensor([3.357895, 4.0165, 4.569546, 0.02716988, -0.189373])
predicted_logits_3 = torch.tensor([2.989513, 5.814459, 3.55369546, 3.06988, -5.89473])

# Sum the logits of all models, then turn the result into probabilities
combined_logits = predicted_logits_1 + predicted_logits_2 + predicted_logits_3
print(torch.softmax(combined_logits, dim=0))

This would give us the following probabilities after softmax:

[0.11291057 0.7576356 0.1293983 0.00005554 0.]

(notice that the second class, not the first, is now the most probable)
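The two combination rules mentioned above (summing logits, or applying softmax per model first) could be compared with a small helper like the one below. This is a rough sketch: `ensemble_predict` is a hypothetical name, and averaging the per-model probabilities is just one way to read "softmax and then argmax".

```python
import torch

def ensemble_predict(logits_per_model):
    """Combine per-model logits, each of shape (num_classes,), in two ways."""
    stacked = torch.stack(logits_per_model)  # (num_models, num_classes)
    # Rule 1: sum the raw logits across models, then take argmax
    by_logits = stacked.sum(dim=0).argmax().item()
    # Rule 2: softmax each model separately, average, then take argmax
    mean_probs = torch.softmax(stacked, dim=1).mean(dim=0)
    by_probs = mean_probs.argmax().item()
    return by_logits, by_probs, mean_probs

models_logits = [
    torch.tensor([5.39413513, 3.81419, 3.7546, 1.02716988, 0.39189373]),
    torch.tensor([3.357895, 4.0165, 4.569546, 0.02716988, -0.189373]),
    torch.tensor([2.989513, 5.814459, 3.55369546, 3.06988, -5.89473]),
]
by_logits, by_probs, mean_probs = ensemble_predict(models_logits)
print(by_logits, by_probs, mean_probs)
```

For these three hypothetical models both rules pick the same class (index 1), but averaging probabilities tends to give a less peaked distribution than summing logits.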

You can use bootstrap aggregating and other ensembling techniques to improve predictions. This approach makes the classification decision surface smoother and lets the classifiers correct one another's errors (provided their predictions vary enough). It would take many posts to describe this in any greater detail (or a separate question with a specific problem would be needed); here or here are some which might get you started.
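To give a rough idea of the bootstrap part only: each ensemble member is trained on a resample of the training set drawn with replacement. The sketch below covers just that resampling step; `bootstrap_indices` and its parameters are hypothetical names, and the actual training loop is left out.

```python
import torch

def bootstrap_indices(num_samples, num_models, seed=0):
    """Return one index tensor per model, each sampled with replacement."""
    generator = torch.Generator().manual_seed(seed)
    return [
        torch.randint(0, num_samples, (num_samples,), generator=generator)
        for _ in range(num_models)
    ]

# Example: 3 models, each seeing its own resample of a 10-example dataset
resamples = bootstrap_indices(num_samples=10, num_models=3)
for indices in resamples:
    print(indices.tolist())
```

Each index tensor would then be used to select the training examples for one model, so the models end up seeing slightly different data.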

Still, I would not mix this approach with manual selection of outputs.

2.2 Transform the problem into binary

This approach might yield better inference time and maybe even better training time if you can distribute it over multiple GPUs.

Basically, each class of yours can either be present (1) or absent (0). In principle you could train N neural networks for N classes, each outputting a single unbounded number (logit). This single number tells whether the network thinks the example should be classified as its class or not.

If you are sure a certain class won't be the outcome, you simply do not run the network responsible for detecting it. After obtaining predictions from all the networks (or a subset of them), you choose the highest value (or the highest probability if you use a sigmoid activation, though that would be computationally wasteful).

An additional benefit would be the simplicity of these networks (easier training and fine-tuning) and easy switch-like behavior if needed.
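A minimal sketch of this setup, under loud assumptions: the per-class networks below are untrained single linear layers standing in for whatever architecture you would really use, and `binary_nets`, `predict`, `num_classes` and `num_features` are hypothetical names chosen for illustration.

```python
import torch
from torch import nn

num_classes, num_features = 5, 8  # illustrative sizes, not from the question

torch.manual_seed(0)
# One small binary network per class, each emitting a single logit (untrained here)
binary_nets = [nn.Linear(num_features, 1) for _ in range(num_classes)]

def predict(x, allowed_classes):
    """Run only the networks of the allowed classes and return the best class index."""
    with torch.no_grad():
        scores = {c: binary_nets[c](x).item() for c in allowed_classes}
    # Highest raw logit wins; no sigmoid needed for the comparison
    return max(scores, key=scores.get)

x = torch.randn(num_features)
# Class 0 is known to be impossible, so its network is never evaluated
print(predict(x, allowed_classes=[1, 2, 3, 4]))
```

Since sigmoid is monotonic, comparing raw logits gives the same winner as comparing probabilities, which is why applying it just for the selection would be wasted work.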


If I were you, I would go with the approach outlined in 2.2, as it could easily save you some inference time and would allow you to "choose outputs" in a sensible manner.

If this approach is not enough, you may consider N ensembles of networks, i.e. a mix of 2.2 and 2.1, together with bootstrap or other ensembling techniques. This should improve your accuracy as well.