Loading two models from Saver in the same TensorFlow session

Solving this problem took a long time so I'm posting my likely imperfect solution in case anyone else needs it.

To diagnose the problem, I manually looped through the variables and assigned them one by one. I then noticed that a variable's name would change after it was assigned. This is described here: TensorFlow checkpoint save and read
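As an illustration of how names can drift when two models share one graph, here is a minimal sketch. It assumes the renaming comes from TensorFlow making op names unique within a graph, which may not be the only cause in the original setup. The variable name "w" and both helper functions are hypothetical, and the tensorflow.compat.v1 import is an assumption so that the snippet also runs on TensorFlow 2.x.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: use TF1-style graphs even on TensorFlow 2.x


def build_twice_in_one_graph():
    """Create two variables with the same requested name in one graph."""
    graph = tf.Graph()
    with graph.as_default():
        first = tf.Variable(0.0, name="w")
        second = tf.Variable(0.0, name="w")  # same requested name as `first`
    return first.op.name, second.op.name


def build_in_separate_graphs():
    """Create one variable named "w" in each of two independent graphs."""
    names = []
    for _ in range(2):
        with tf.Graph().as_default():
            names.append(tf.Variable(0.0, name="w").op.name)
    return tuple(names)


print(build_twice_in_one_graph())   # names within a shared graph
print(build_in_separate_graphs())   # names across separate graphs
```

If the second name in the shared graph comes back with a suffix, a Saver built on that graph would look for the suffixed name in the checkpoint, which is one reason separate graphs sidestep the problem.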

Based on the advice in that post, I built each model in its own graph. Because a session is bound to a single graph, each graph also needed its own session, which meant handling session management differently.

First I created two graphs

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)

Then two sessions

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

Then I initialised the variables in each session and restored each graph separately

with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)

From here, whenever a session was needed, I wrapped any tf calls for that session in a with sess.as_default(): block (or with adv_sess.as_default(): for the adversary). At the end I closed the sessions manually:

sess.close()
adv_sess.close()
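The whole flow can be tried end to end with a toy model. The following is only a sketch under stated assumptions: save_toy_model and restore_toy_model are hypothetical helpers standing in for Model and Adversary, each "model" is a single variable named "w", the checkpoints go to temporary directories, and tensorflow.compat.v1 is imported so the code also runs on TensorFlow 2.x.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: use TF1-style graphs even on TensorFlow 2.x


def save_toy_model(save_dir, value):
    """Stand-in for a trained model: one variable "w" saved under save_dir."""
    graph = tf.Graph()
    with graph.as_default():
        tf.Variable(value, name="w")
        saver = tf.train.Saver()
        with tf.Session(graph=graph) as s:
            s.run(tf.global_variables_initializer())
            saver.save(s, os.path.join(save_dir, "model.ckpt"))


def restore_toy_model(save_dir):
    """Rebuild the toy model in a fresh graph and restore it in its own session."""
    graph = tf.Graph()
    with graph.as_default():
        w = tf.Variable(0.0, name="w")
        saver = tf.train.Saver()
        s = tf.Session(graph=graph)
        ckpt = tf.train.get_checkpoint_state(save_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise IOError("no checkpoint found in %s" % save_dir)
        saver.restore(s, ckpt.model_checkpoint_path)
    return s, w


# Two checkpoints that both contain a variable called "w".
model_dir = tempfile.mkdtemp()
adv_dir = tempfile.mkdtemp()
save_toy_model(model_dir, 1.0)
save_toy_model(adv_dir, 2.0)

# One graph and one session per model, so the identical names never collide.
sess, model_w = restore_toy_model(model_dir)
adv_sess, adv_w = restore_toy_model(adv_dir)

model_value = sess.run(model_w)
adv_value = adv_sess.run(adv_w)
print(model_value, adv_value)

sess.close()
adv_sess.close()
```

Checking the result of tf.train.get_checkpoint_state for None before restoring is an addition of this sketch; it returns None when the directory holds no checkpoint, which otherwise surfaces as an attribute error.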

The answer marked as correct does not explicitly show how to load two different models into one session. Here is my answer:

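  1. Create two different name scopes for the models you want to load. (If the models create their variables with tf.get_variable, use tf.variable_scope instead, because tf.get_variable ignores tf.name_scope.)

  2. Initialize two savers, each of which loads the parameters for the variables of one network.

  3. Restore each saver from its corresponding checkpoint file.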

with tf.Session() as sess:
    with tf.name_scope("net1"):
      net1 = Net1()
    with tf.name_scope("net2"):
      net2 = Net2()

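    # Map checkpoint names (without the scope prefix) to the scoped variables.
    # Slice the prefix off: str.lstrip("net1/") strips any leading run of the
    # characters n, e, t, 1 and /, which can eat into the variable name itself.
    net1_varlist = {v.op.name[len("net1/"):]: v
                    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net1/")}
    net1_saver = tf.train.Saver(var_list=net1_varlist)

    net2_varlist = {v.op.name[len("net2/"):]: v
                    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net2/")}
    net2_saver = tf.train.Saver(var_list=net2_varlist)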

    net1_saver.restore(sess, "net1.ckpt")
    net2_saver.restore(sess, "net2.ckpt")
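One detail worth checking when building the name-to-variable mapping for each saver: str.lstrip takes a set of characters rather than a prefix, so it can remove more than the scope name. The sketch below compares it with slicing the prefix off. The variable names and the strip_scope helper are hypothetical and only serve to show the difference.

```python
def strip_scope(name, scope):
    """Remove the leading "<scope>/" from a variable name, if present."""
    prefix = scope + "/"
    return name[len(prefix):] if name.startswith(prefix) else name


# Hypothetical variable names as they might appear under the "net1" scope.
names = ["net1/conv1/weights", "net1/normalize/beta", "net1/encoder/kernel"]

# lstrip removes every leading character found in the set {n, e, t, 1, /}.
with_lstrip = [n.lstrip("net1/") for n in names]

# Slicing removes exactly the "net1/" prefix and nothing else.
with_slice = [strip_scope(n, "net1") for n in names]

print(with_lstrip)  # ['conv1/weights', 'ormalize/beta', 'coder/kernel']
print(with_slice)   # ['conv1/weights', 'normalize/beta', 'encoder/kernel']
```

A key that has lost its leading characters will not match any name stored in the checkpoint, so the restore fails for that variable.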