Calling a Keras model on a TensorFlow tensor while keeping the weights
Finally found the answer. There are two problems in the example from the question.
1:
The first and most obvious problem was that I called tf.global_variables_initializer(), which re-initializes all variables in the session, including the weights that had just been loaded. Instead I should have called tf.variables_initializer(var_list), where var_list is a list of the variables to initialize.
2:
The second problem was that Keras did not use the same session as the native TensorFlow objects. This meant that to be able to run the TensorFlow node model2(v) with my session sess, the model's variables would have had to be initialized again in that session. Again, the tutorial "Keras as a simplified interface to TensorFlow" was able to help:
We should start by creating a TensorFlow session and registering it with Keras. This means that Keras will use the session we registered to initialize all variables that it creates internally.
import tensorflow as tf
sess = tf.Session()
from keras import backend as K
K.set_session(sess)
If we apply these changes to the example provided in my question, we get the following code, which does exactly what is expected:
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense

sess = tf.Session()
# Register session with Keras
K.set_session(sess)

# Reference model: kernel of ones and zero bias, so it maps 1 -> 1
model = Sequential()
model.add(Dense(1, input_dim=1, kernel_initializer='ones',
                bias_initializer='zeros'))
model.compile(loss='binary_crossentropy', optimizer='adam',
              metrics=['accuracy'])
model.save_weights('test')

# Second model starts from zeros, then loads the saved weights
model2 = Sequential()
model2.add(Dense(1, input_dim=1, kernel_initializer='zeros',
                 bias_initializer='zeros'))
model2.compile(loss='binary_crossentropy', optimizer='adam',
               metrics=['accuracy'])
model2.load_weights('test')

v = tf.Variable([[1, ], ], dtype=tf.float32)
node = model2(v)
# Initialize only v, so the weights loaded into model2 are left untouched
init = tf.variables_initializer([v, ])
sess.run(init)
print(sess.run(node), model2.predict(np.array([[1, ], ])))
# Both results are [[1.]], i.e. the loaded weights were kept
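The difference between the two initializers from point 1 can also be seen in isolation, without Keras. The following is only a minimal sketch: the variables a and b are made up for illustration, and I import tensorflow.compat.v1 on the assumption that this should also run on a TensorFlow 2.x install (on TensorFlow 1.x a plain import tensorflow as tf, as in the code above, should behave the same way).

```python
# Assumption: tensorflow.compat.v1 is used so the sketch also runs on TF 2.x.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # graph/session mode, as in the answer above

# Hypothetical variables, used only for illustration.
a = tf.Variable(1.0, name="a")
b = tf.Variable(2.0, name="b")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Stand-in for "weights were loaded": give a a non-initial value.
    sess.run(a.assign(10.0))

    # Initializing only b leaves the value of a untouched.
    sess.run(tf.variables_initializer([b]))
    a_after_partial = sess.run(a)

    # The global initializer resets every variable, a included.
    sess.run(tf.global_variables_initializer())
    a_after_global = sess.run(a)

print(a_after_partial, a_after_global)
```

Here a plays the role of the loaded Keras weights: it survives the selective initializer but is reset to its initial value by the global one.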
Conclusion:
The lesson is that when mixing TensorFlow and Keras, make sure everything uses the same session, and initialize only the variables you created yourself.