How to set TensorFlow dynamic_rnn zero_state without a fixed batch_size?
You can specify batch_size as a placeholder rather than a constant. Just make sure to feed the actual batch size in feed_dict, which can differ between training and testing. Importantly, give the placeholder [] as its shape (i.e. a scalar), because you may get errors if you specify None, as is customary elsewhere. So something like this should work:
batch_size = tf.placeholder(tf.int32, [], name='batch_size')
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
                                         initial_state=init_state,
                                         time_major=False)
# ... rest of your code ...

# Feed whatever batch size matches the inputs at run time
# (your input placeholder, e.g. X_in, is fed here as well):
out = sess.run(outputs, feed_dict={batch_size: 100})
out = sess.run(outputs, feed_dict={batch_size: 10})
Obviously, make sure that the batch size you feed matches the shape of your inputs, which dynamic_rnn will interpret as [batch_size, seq_len, features], or [seq_len, batch_size, features] if time_major is set to True.
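For reference, here is a minimal self-contained sketch of the same idea, assuming the TF 1.x API; the cell size, sequence length, and feature count (num_units, seq_len, features) are hypothetical stand-ins:

import numpy as np
import tensorflow as tf  # TF 1.x API

num_units, seq_len, features = 64, 20, 8  # hypothetical sizes

batch_size = tf.placeholder(tf.int32, [], name='batch_size')
X_in = tf.placeholder(tf.float32, [None, seq_len, features], name='X_in')

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
                                         initial_state=init_state,
                                         time_major=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Same graph, two different batch sizes at run time
    big = sess.run(outputs, feed_dict={batch_size: 100,
                                       X_in: np.zeros((100, seq_len, features))})
    small = sess.run(outputs, feed_dict={batch_size: 10,
                                         X_in: np.zeros((10, seq_len, features))})
    print(big.shape, small.shape)  # (100, 20, 64) (10, 20, 64)

As an alternative, you could skip the extra placeholder entirely and derive the batch size from the input tensor itself with tf.shape(X_in)[0], but feeding an explicit scalar keeps the example closest to the snippet above.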