Deep-learning NaN loss reasons
If you're training with cross entropy, you want to add a small number like 1e-8 to your output probabilities.
Because log(0) is negative infinity: once your model has trained enough, the output distribution will be very skewed. For instance, in a 4-class problem, the probabilities at the start of training might look like
0.25 0.25 0.25 0.25
but toward the end they will probably look like
1.0 0 0 0
If you take the cross entropy of this distribution, everything will explode. The fix is to artificially add a small number to all of the terms to prevent this.
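A minimal sketch of the idea in plain NumPy (the epsilon value and the clipping approach are just illustrative, not any particular framework's internals):
import numpy as np

def cross_entropy(probs, target_index, eps=1e-8):
    # Clip probabilities away from exact 0 (and 1) so log never sees a zero.
    probs = np.clip(probs, eps, 1.0 - eps)
    return -np.log(probs[target_index])

print(cross_entropy(np.array([0.25, 0.25, 0.25, 0.25]), 0))  # ~1.386
print(cross_entropy(np.array([1.0, 0.0, 0.0, 0.0]), 1))      # large but finite, not inf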
There are lots of things I have seen make a model diverge.
Too high of a learning rate. You can often tell if this is the case if the loss begins to increase and then diverges to infinity.
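If you suspect this, a quick experiment is to compile with an explicitly smaller learning rate instead of the default (a Keras-flavoured sketch; 1e-4 is just a value to try, and model stands for whatever model you are compiling):
from keras.optimizers import Adam
# Adam defaults to 1e-3; try something smaller and watch whether the loss still blows up
model.compile(optimizer=Adam(1e-4), loss="sparse_categorical_crossentropy", metrics=["accuracy"])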
I am not too familiar with the DNNClassifier, but I am guessing it uses the categorical cross-entropy cost function. This involves taking the log of the prediction, which diverges as the prediction approaches zero. That is why people usually add a small epsilon value to the prediction to prevent this divergence. I am guessing the DNNClassifier probably does this or uses the TensorFlow op for it. Probably not the issue.
Other numerical stability issues can exist, such as division by zero, where adding the epsilon can help. Another, less obvious one is the square root, whose derivative can diverge if not properly simplified when dealing with finite-precision numbers. Yet again, I doubt this is the issue in the case of the DNNClassifier.
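As a small illustration of that epsilon trick (just a NumPy sketch, not what DNNClassifier does internally):
import numpy as np
eps = 1e-8
x = np.array([0.0, 1e-12, 0.5])
safe_div = x / (x.sum() + eps)    # a zero denominator would give inf/nan
safe_sqrt = np.sqrt(x + eps)      # the derivative of sqrt at exactly 0 is infinite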
You may have an issue with the input data. Try calling
assert not np.any(np.isnan(x))
on the input data to make sure you are not introducing NaNs. Also make sure all of the target values are valid. Finally, make sure the data is properly normalized: you probably want to have the pixels in the range [-1, 1] and not [0, 255]. The labels must also be in the domain of the loss function, so if using a logarithmic loss function all labels must be non-negative (as noted by evan pu and the comments below).
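A few quick sanity checks along those lines (a sketch; x and y here stand for your own input array and integer labels):
import numpy as np
assert not np.any(np.isnan(x)), "NaNs in the input data"
assert not np.any(np.isinf(x)), "infs in the input data"
assert np.all(y >= 0), "negative labels will break a log-based loss"
x = (x.astype("float32") / 255.0) * 2.0 - 1.0   # rescale pixels from [0, 255] to [-1, 1]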
In my case I got NaNs when using distant integer labels, i.e.:
- with labels in [0..100] the training was OK,
- with labels in [0..100] plus one additional label 8000, I got NaNs.
So, don't use a very distant label.
EDIT: You can see the effect in the following simple code:
from keras.models import Sequential
from keras.layers import Dense, Activation
import numpy as np

# 20 random samples with 5 features each, and integer labels in 0..4
X = np.random.random(size=(20, 5))
y = np.random.randint(0, high=5, size=(20, 1))

model = Sequential([
    Dense(10, input_dim=X.shape[1]),
    Activation('relu'),
    Dense(5),
    Activation('softmax'),
])
model.compile(optimizer="Adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

print('fit model with labels in range 0..5')
history = model.fit(X, y, epochs=5)

# add one extra sample whose label (8000) is far outside the 5-class range
X = np.vstack((X, np.random.random(size=(1, 5))))
y = np.vstack((y, [[8000]]))
print('fit model with labels in range 0..5 plus 8000')
history = model.fit(X, y, epochs=5)
The result shows NaNs after adding the label 8000:
fit model with labels in range 0..5
Epoch 1/5
20/20 [==============================] - 0s 25ms/step - loss: 1.8345 - acc: 0.1500
Epoch 2/5
20/20 [==============================] - 0s 150us/step - loss: 1.8312 - acc: 0.1500
Epoch 3/5
20/20 [==============================] - 0s 151us/step - loss: 1.8273 - acc: 0.1500
Epoch 4/5
20/20 [==============================] - 0s 198us/step - loss: 1.8233 - acc: 0.1500
Epoch 5/5
20/20 [==============================] - 0s 151us/step - loss: 1.8192 - acc: 0.1500
fit model with labels in range 0..5 plus 8000
Epoch 1/5
21/21 [==============================] - 0s 142us/step - loss: nan - acc: 0.1429
Epoch 2/5
21/21 [==============================] - 0s 238us/step - loss: nan - acc: 0.2381
Epoch 3/5
21/21 [==============================] - 0s 191us/step - loss: nan - acc: 0.2381
Epoch 4/5
21/21 [==============================] - 0s 191us/step - loss: nan - acc: 0.2381
Epoch 5/5
21/21 [==============================] - 0s 188us/step - loss: nan - acc: 0.2381
If using integers as targets, make sure they aren't symmetric around 0.
I.e., don't use classes -1, 0, 1. Use 0, 1, 2 instead.
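A quick way to remap such labels before training (a sketch; y stands for your own integer targets):
import numpy as np
y = np.array([-1, 0, 1, 1, -1])
y = y - y.min()   # shift so the smallest class becomes 0: classes are now 0, 1, 2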