Stable Baselines: exporting a model to TensorFlow (code example)

Example: exporting a Stable Baselines model to TensorFlow

# Export step. `model` is a stable_baselines model created elsewhere.
# The SavedModel is written to the 'tensorflow_model' directory with one
# input named "obs" and one output named "action".
with model.graph.as_default():
    # Signature tensors are taken from the model's acting policy.
    signature_inputs = {"obs": model.act_model.obs_ph}
    signature_outputs = {"action": model.act_model._policy_proba}
    tf.saved_model.simple_save(
        model.sess,
        'tensorflow_model',
        inputs=signature_inputs,
        outputs=signature_outputs,
    )

# Import step (TF1): read the SavedModel back and keep only its
# 'serving_default' signature as the callable `model_loaded`.
global_session = tf.Session()

with global_session.as_default():
    model_loaded = (
        tf.saved_model.load_v2('tensorflow_model')
        .signatures['serving_default']
    )

# important step: the global initializer must be run in the session before
# the loaded signature is evaluated.
# NOTE(review): presumably this is what populates the loaded model's
# variables in graph mode -- confirm against the TF SavedModel docs.
init = tf.global_variables_initializer()
global_session.run(init)

# Inference step: feed `grid` (defined elsewhere) through the loaded
# signature and convert the result back to the original format.

# Build the graph node for the "action" output of the signature.
observation = tf.convert_to_tensor(grid)
signature_result = model_loaded(observation)
computation_graph = signature_result['action']

# Evaluate the node in the session that holds the initialized variables.
result = computation_graph.eval(session=global_session)

# Index of the largest entry of the evaluated output.
col = result.argmax()

Tags:

Misc Example