Debugging NaNs in the backward pass
Maybe you could add Print ops to the suspect ops to print their values, something like this:
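import tensorflow as tf  # TensorFlow 1.x graph API (tf.Print, tf.group, tf.Session)

# `ops` is your list of suspect tensors; tf.Print passes its input through
# unchanged and logs the listed values (to stderr) each time it is evaluated.
print_ops = []
for op in ops:
    print_ops.append(tf.Print(op, [op],
                              message='%s :' % op.name, summarize=10))
print_op = tf.group(*print_ops)
sess.run([train_op, print_op])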
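Since a NaN in the backward pass shows up in the gradient tensors, another option is to print the gradients themselves. The sketch below assumes the TF 1.x graph API through `tensorflow.compat.v1` and a toy model invented for illustration (`loss = sum(log(w))`, so the gradient is `1 / w`). Because tf.Print is an identity op, feeding its output into `apply_gradients` makes the print run as part of `train_op`, with no separate `tf.group` needed.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x-style graph mode (tf.Print, tf.Session)

# Toy model (assumption, for illustration only): d(loss)/dw = 1 / w.
w = tf.Variable([1.0, 2.0, 4.0], name='w')
loss = tf.reduce_sum(tf.log(w))

opt = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = opt.compute_gradients(loss)

# Wrap each gradient in tf.Print (assumes every gradient is a dense Tensor,
# not None or IndexedSlices). The Print output is the gradient itself, so
# routing it into apply_gradients logs the values on every training step.
printed = [
    (tf.Print(g, [g], message='grad of %s: ' % v.op.name, summarize=10), v)
    for g, v in grads_and_vars
]
train_op = opt.apply_gradients(printed)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)  # gradient values are logged to stderr
    new_w = sess.run(w)
```

If a gradient turns into NaN here, the log line names the variable it belongs to, which narrows down where in the backward pass the problem starts.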
To add Print ops to every op, you could write a loop along the lines of tf.add_check_numerics_ops().
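A loop modeled on tf.add_check_numerics_ops() could look roughly like this. It is a sketch assuming the TF 1.x graph API through `tensorflow.compat.v1`; `add_print_ops` is a hypothetical helper name. Like add_check_numerics_ops, it only wraps float16/32/64 tensors that already exist when it is called, and it does not try to handle graphs that use tf.cond or tf.while_loop.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x-style graph mode (tf.Print, tf.Session)


def add_print_ops(summarize=10):
    """Hypothetical helper: wrap each float tensor of the default graph in tf.Print.

    Visits every op and, for each float16/32/64 output, adds a Print op
    labelled with the tensor name. Returns one grouped op that triggers all
    the prints when run.
    """
    print_ops = []
    # get_operations() returns a list snapshot, so the Print ops created
    # inside this loop are not visited again.
    for op in tf.get_default_graph().get_operations():
        for output in op.outputs:
            if output.dtype in (tf.float16, tf.float32, tf.float64):
                print_ops.append(
                    tf.Print(output, [output],
                             message='%s :' % output.name,
                             summarize=summarize))
    return tf.group(*print_ops)
```

Call it after the whole model (including the optimizer's gradient ops) has been built, then run the returned op next to `train_op`. Expect a lot of output: every float tensor is printed on each run, and any placeholders among them must be fed.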
Debugging NaNs can be tricky, especially if you have a large network. tf.add_check_numerics_ops() adds ops to the graph that assert that each floating-point tensor in the graph does not contain any NaN (or Inf) values. These checks are not run by default; instead, the function returns a single op that you can run periodically, or on every step, as follows:
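import tensorflow as tf  # TensorFlow 1.x graph API
train_op = ...
# Build the checks after train_op, so backward-pass (gradient) tensors are covered
check_op = tf.add_check_numerics_ops()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([train_op, check_op])  # Runs training and checks for NaN/Inf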
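To see the check catch a backward-pass problem, here is a small sketch, assuming `tensorflow.compat.v1` and a toy graph invented for illustration: the gradient of sqrt(x) is 0.5 / sqrt(x), so at x = 0 the forward value is a finite 0.0 while the gradient is Inf. `run_with_checks` is a hypothetical helper that returns the error message from the failing check, or None if all checks pass.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x-style graph mode and tf.Session

# Toy graph (assumption): forward is finite at x = 0, backward is Inf.
x = tf.placeholder(tf.float32, shape=[], name='x')
y = tf.sqrt(x)
(dy_dx,) = tf.gradients(y, [x])

# Create the checks only after the gradient ops exist; tensors added to the
# graph later are not covered by this check_op.
check_op = tf.add_check_numerics_ops()


def run_with_checks(sess, fetches, feed_dict):
    """Hypothetical helper: run fetches plus check_op, return error text or None."""
    try:
        sess.run([fetches, check_op], feed_dict=feed_dict)
        return None
    except tf.errors.InvalidArgumentError as e:
        # The message includes the name of the tensor whose check failed.
        return e.message


with tf.Session() as sess:
    ok = run_with_checks(sess, dy_dx, {x: 4.0})   # gradient 0.25, checks pass
    bad = run_with_checks(sess, dy_dx, {x: 0.0})  # gradient Inf, a check fails
```

To check only periodically, fetch `check_op` together with `train_op` on, say, every Nth step and run `train_op` alone otherwise; the check ops are only executed on the steps where `check_op` is fetched.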