Merge TensorFlow models
source link: https://www.chunyangwen.com/blog/tensorflow/merge-tensorflow-models.html
Sometimes you want to transfer certain weights from multiple models into a single model or just want to merge multiple models. There are at least two ways to do that:
init_from_checkpoint
You can refer to tf.train.init_from_checkpoint.
tf.train.init_from_checkpoint(
    ckpt_dir_or_file, assignment_map
)
ckpt_dir_or_file
- Can be a checkpoint directory or file: /path/to/checkpoint_dir, /path/to/checkpoint_dir/model-1234
- Can be the variables of a saved model: /path/to/saved_model/variables/variables (the second "variables" is the filename prefix of variables.index)
assignment_map
- key: a variable name or a scope (ending in /) in the checkpoint
- value: a variable name, a scope, or a variable reference in the current graph
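The scope form of assignment_map can be sketched as follows. This is a minimal example, assuming the TensorFlow 1.x graph-mode API (imported through tensorflow.compat.v1 so that it should also run on TensorFlow 2.x). The scope names old and new, the variable names w and v, and the temporary directory are made up for illustration. The variables are created with tf.get_variable because, as far as I can tell, a scope entry is matched against variables registered that way.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions

# Illustrative location only; any writable directory works.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model")

# Save a checkpoint whose variables live under the scope "old".
with tf.Session(graph=tf.Graph()) as session:
    with tf.variable_scope("old"):
        tf.get_variable("w", initializer=tf.constant(3.0))
        tf.get_variable("v", initializer=tf.constant(4.0))
    saver = tf.train.Saver()
    session.run(tf.global_variables_initializer())
    saver.save(session, ckpt_prefix)

# Build a new graph that uses the scope "new" and map old/ -> new/.
with tf.Session(graph=tf.Graph()) as session:
    with tf.variable_scope("new"):
        w = tf.get_variable("w", initializer=tf.constant(0.0))
        v = tf.get_variable("v", initializer=tf.constant(0.0))
    # Key: scope in the checkpoint. Value: scope in the current graph.
    tf.train.init_from_checkpoint(ckpt_prefix, {"old/": "new/"})
    session.run(tf.global_variables_initializer())
    restored_w, restored_v = session.run([w, v])

print(restored_w, restored_v)
```

A single entry covers every variable under the scope, which is convenient when a whole sub-network is transferred under a new name.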
It is very flexible. Under the hood, init_from_checkpoint modifies the initializer of a variable. When we run tf.global_variables_initializer(), the related restore op is executed.
variable._initializer_op = init_op
restore_op.set_shape(variable.shape)
variable._initial_value = restore_op
If you have a user-defined variable type, for example an AwesomeVariable that behaves like a TensorFlow Variable but uses a different back-end storage, you can write a similar function that creates your own initializer op.
# Create your restore op using io_ops.restore_v2,
# then replace the initializer op
variable._initializer_op = init_op
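A fuller version of that idea might look like the sketch below. It assumes the TensorFlow 1.x graph-mode API (through tensorflow.compat.v1) and relies on the private _initializer_op attribute, so it may break between TensorFlow versions. The helper name init_variable_from_checkpoint is hypothetical, and a plain tf.Variable stands in for the custom variable type.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import io_ops

tf.disable_v2_behavior()  # use TF1-style graphs and sessions


def init_variable_from_checkpoint(variable, ckpt_prefix, tensor_name):
    """Hypothetical helper: initialize `variable` from `tensor_name` in a checkpoint."""
    # Read a single tensor; the empty slice spec "" requests the whole tensor.
    restore_op = io_ops.restore_v2(
        ckpt_prefix, [tensor_name], [""], [variable.dtype.base_dtype])[0]
    restore_op.set_shape(variable.shape)
    # Replace the initializer op (private attribute, see the note above).
    variable._initializer_op = tf.assign(variable, restore_op).op


# Illustrative location only; any writable directory works.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model")

# Save a checkpoint that contains one variable named "a".
with tf.Session(graph=tf.Graph()) as session:
    tf.Variable(3.0, name="a")
    saver = tf.train.Saver()
    session.run(tf.global_variables_initializer())
    saver.save(session, ckpt_prefix)

# Initialize a differently named variable from the checkpoint tensor "a".
with tf.Session(graph=tf.Graph()) as session:
    renamed = tf.Variable(0.0, name="renamed")
    init_variable_from_checkpoint(renamed, ckpt_prefix, "a")
    session.run(tf.global_variables_initializer())
    restored_value = session.run(renamed)

print(restored_value)
```

For a real custom variable, the tf.assign call would be replaced by whatever op writes a tensor into its back-end storage.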
init_from_checkpoint can be called multiple times with different ckpt_dir_or_file values. As a result, a single model's variables can be initialized from different source checkpoints or saved models.
# TensorFlow 1.x API (tf.Session)
import os
import tensorflow as tf

os.makedirs("./models/a", exist_ok=True)
os.makedirs("./models/b", exist_ok=True)

# Model a: a single variable "a" with value 3.
with tf.Session(graph=tf.Graph()) as session:
    tf.Variable(3, name="a")
    saver = tf.train.Saver()
    session.run(tf.global_variables_initializer())
    saver.save(session, "./models/a/model-a")

# Model b: a single variable "b" with value 4.
with tf.Session(graph=tf.Graph()) as session:
    tf.Variable(4, name="b")
    saver = tf.train.Saver()
    session.run(tf.global_variables_initializer())
    saver.save(session, "./models/b/model-b")

# Target model: each variable is initialized from a different checkpoint.
with tf.Session(graph=tf.Graph()) as session:
    a = tf.Variable(1, name="a")
    b = tf.Variable(1, name="b")
    tf.train.init_from_checkpoint("./models/a/model-a", {"a": a})
    tf.train.init_from_checkpoint("./models/b/model-b", {"b": b})
    session.run(tf.global_variables_initializer())
    print(session.run(a))  # 3
    print(session.run(b))  # 4
gen_io_ops.merge_v2_checkpoints
When TensorFlow loads a model, the only requirement is that every variable in the model has a valid value in the checkpoint. So we can first merge the checkpoints of the different models and then load just once.
import tensorflow as tf
from tensorflow.python.ops import gen_io_ops

# Merge the two source checkpoints into one checkpoint prefix.
src = tf.constant(["./models/a/model-a", "./models/b/model-b"])
target = tf.constant("./models/merged_model")
op = gen_io_ops.merge_v2_checkpoints(src, target, delete_old_dirs=False)
with tf.Session() as session:
    session.run(op)

# Restore every variable from the merged checkpoint in one call.
with tf.Session(graph=tf.Graph()) as session:
    a = tf.Variable(1, name="a")
    b = tf.Variable(1, name="b")
    saver = tf.train.Saver()
    saver.restore(session, "./models/merged_model")
    print(session.run(a))  # 3
    print(session.run(b))  # 4
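Before restoring, it can help to check that the merged checkpoint really covers every variable of the model. The sketch below assumes the TensorFlow 1.x graph-mode API (through tensorflow.compat.v1). The helpers missing_from_checkpoint and save_single_variable are hypothetical names, and the checkpoints are written to a temporary directory so that the example does not touch ./models.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import gen_io_ops

tf.disable_v2_behavior()  # use TF1-style graphs and sessions


def missing_from_checkpoint(variable_names, ckpt_prefix):
    """Hypothetical helper: names in `variable_names` that the checkpoint lacks."""
    # tf.train.list_variables returns (name, shape) pairs.
    stored = {name for name, _ in tf.train.list_variables(ckpt_prefix)}
    return sorted(set(variable_names) - stored)


def save_single_variable(name, value, ckpt_prefix):
    """Save a one-variable model, like models a and b above."""
    with tf.Session(graph=tf.Graph()) as session:
        tf.Variable(value, name=name)
        saver = tf.train.Saver()
        session.run(tf.global_variables_initializer())
        saver.save(session, ckpt_prefix)


root = tempfile.mkdtemp()  # illustrative location only
prefix_a = os.path.join(root, "a", "model-a")
prefix_b = os.path.join(root, "b", "model-b")
merged_prefix = os.path.join(root, "merged_model")
os.makedirs(os.path.dirname(prefix_a))
os.makedirs(os.path.dirname(prefix_b))
save_single_variable("a", 3, prefix_a)
save_single_variable("b", 4, prefix_b)

# Merge both checkpoints, as in the example above.
with tf.Session(graph=tf.Graph()) as session:
    session.run(gen_io_ops.merge_v2_checkpoints(
        [prefix_a, prefix_b], merged_prefix, delete_old_dirs=False))

covered = missing_from_checkpoint(["a", "b"], merged_prefix)
not_covered = missing_from_checkpoint(["a", "b", "c"], merged_prefix)
print(covered, not_covered)
```

An empty list means the model can be restored from the merged checkpoint; any name in the list would make saver.restore fail.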
Be careful: merge_v2_checkpoints deletes the source .index and .data files no matter whether delete_old_dirs is set to True or False.
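Because the source files disappear after the merge, one option is to copy them elsewhere first. Here is a minimal sketch that uses only the standard library. The helper backup_checkpoint and the file names in the demo are hypothetical, and the demo writes empty stand-in files rather than real checkpoints.

```python
import glob
import os
import shutil
import tempfile


def backup_checkpoint(ckpt_prefix, backup_dir):
    """Hypothetical helper: copy all files named <ckpt_prefix>.* into backup_dir."""
    os.makedirs(backup_dir, exist_ok=True)
    copied = []
    # Matches e.g. model-a.index and model-a.data-00000-of-00001.
    for path in sorted(glob.glob(glob.escape(ckpt_prefix) + ".*")):
        copied.append(shutil.copy2(path, backup_dir))
    return copied


# Demo with empty stand-in files (not real checkpoints).
root = tempfile.mkdtemp()
for file_name in ("model-a.index", "model-a.data-00000-of-00001", "other.index"):
    open(os.path.join(root, file_name), "w").close()

backup_dir = os.path.join(root, "backup")
backup_checkpoint(os.path.join(root, "model-a"), backup_dir)
backed_up = sorted(os.listdir(backup_dir))
print(backed_up)
```

Calling backup_checkpoint on each source prefix before running the merge op keeps a usable copy of every original checkpoint.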