Stable Baselines: exporting a model to TensorFlow (code example)
Example: export a Stable Baselines model as a TensorFlow SavedModel and load it back in a TF1 session
# Export a trained stable_baselines model as a TensorFlow SavedModel, then
# reload it in a plain TF1 session and query it for an action.
#
# Not defined in this snippet (must be provided by the surrounding code):
#   - `tf`: TensorFlow 1.x (>= 1.14, so that `tf.saved_model.load_v2` exists)
#   - `model`: a trained stable_baselines model
#   - `grid`: the observation to run inference on

# --- Saving ---
# `model` is a stable_baselines model; export from its own graph and session.
# NOTE(review): the SavedModel builder refuses to write into an existing,
# non-empty directory, so 'tensorflow_model' should not exist yet — confirm
# and clean up before re-running.
with model.graph.as_default():
    tf.saved_model.simple_save(
        model.sess,
        'tensorflow_model',
        # The policy's observation placeholder is exported under the key "obs"...
        inputs={"obs": model.act_model.obs_ph},
        # ...and its action probabilities under the key "action".
        # NOTE(review): `_policy_proba` is a private stable_baselines attribute
        # and may change between library versions.
        outputs={"action": model.act_model._policy_proba},
    )

# --- Loading model into TF1 ---
global_session = tf.Session()
with global_session.as_default():
    model_loaded = tf.saved_model.load_v2('tensorflow_model')
    # simple_save's inputs/outputs are looked up under the 'serving_default' signature.
    model_loaded = model_loaded.signatures['serving_default']

# Important step: the loaded variables must be initialized before use.
# NOTE(review): in graph mode load_v2 appears to route the checkpoint restore
# through the variables' initializers, so running the global initializer here
# is what loads the saved weights (not a random re-init) — confirm for the TF
# version in use.
init = tf.global_variables_initializer()
global_session.run(init)

# --- Calling the model and converting the result to the original format ---
# Pass the input by the same key used at export time ("obs"): loaded signature
# functions are meant to be called with keyword arguments.
# NOTE(review): `grid` presumably has to match obs_ph's dtype and include the
# batch dimension — TODO confirm against the policy's observation space.
computation_graph = model_loaded(obs=tf.convert_to_tensor(grid))['action']
with global_session.as_default():
    result = computation_graph.eval()
# argmax over the flattened probabilities: the index of the most likely action
# (for a single observation, the chosen column).
col = result.argmax()