

Average weights of two Pytorch models
source link: https://donghao.org/2022/07/14/average-weights-of-two-pytorch-models/
After reading this paper, I started an experiment on it. Referencing this snippet, I wrote the following code:
net1 = model_builder.build_model()
net2 = model_builder.build_model()
output = model_builder.build_model()
net1.load_state_dict(torch.load(args.model1, map_location="cpu"))
net2.load_state_dict(torch.load(args.model2, map_location="cpu"))
# Average
sd1 = net1.named_parameters()
sd2 = net2.named_parameters()
sdo = dict(sd2)
for name, param in sd1:
    sdo[name].data.copy_(0.5*param.data + 0.5*sdo[name].data)
output.load_state_dict(sdo)
torch.save(output, args.output)
# here is a test
output.load_state_dict(torch.load(args.output))
But after generating the new model with averaged weights, PyTorch failed to load it:
Traceback (most recent call last):
  File "average_models.py", line 43, in <module>
    output.load_state_dict(torch.load(args.output))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1534, in load_state_dict
    state_dict = state_dict.copy()
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1186, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'RegNet' object has no attribute 'copy'
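The traceback shows load_state_dict calling .copy() on whatever torch.load returned. The following is a minimal sketch of the difference between a module and its state_dict, using a small nn.Linear as a hypothetical stand-in for the RegNet model (the layer sizes and the in-memory buffer are assumptions made only for illustration):

```python
import io

import torch
from torch import nn

# Hypothetical stand-in for the real model; any nn.Module behaves the same way.
model = nn.Linear(4, 2)

# A state_dict is a plain mapping from parameter names to tensors.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Loading it back yields a dict-like object that load_state_dict accepts.
loaded = torch.load(buffer, map_location="cpu")
restored = nn.Linear(4, 2)
restored.load_state_dict(loaded)

# The module itself is not a mapping and has no .copy() method,
# which is what the AttributeError above complains about.
module_has_copy = hasattr(model, "copy")
state_dict_has_copy = hasattr(loaded, "copy")
```

Here module_has_copy is False while state_dict_has_copy is True, so passing a whole module where a state_dict is expected cannot work.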
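The reason for the failure is quite simple: torch.save(output, args.output) saved the whole model object, so torch.load returned a RegNet module rather than a state_dict, and load_state_dict cannot use it. We only need to save the state_dict of the model instead of all information (since I am using FP16 format). Therefore the correct code should be: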
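# Build two models with the same architecture and load their checkpoints
net1 = model_builder.build_model()
net2 = model_builder.build_model()
net1.load_state_dict(torch.load(args.model1, map_location="cpu"))
net2.load_state_dict(torch.load(args.model2, map_location="cpu"))
# Average
sd1 = net1.named_parameters()
sd2 = net2.named_parameters()
sdo = dict(sd2)
for name, param in sd1:
    # Overwrite net2's parameter in place with the mean of both models
    sdo[name].data.copy_(0.5*param.data + 0.5*sdo[name].data)
# Save only the name-to-parameter dict, not the whole model object
torch.save(sdo, args.output)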
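One caveat: named_parameters() yields only learnable parameters, so buffers such as BatchNorm running statistics are not part of the saved dict. Below is a minimal sketch of an alternative that averages the state_dict() entries instead, assuming both checkpoints share the same architecture. The helper name average_state_dicts, the alpha argument, and the toy make_net model are hypothetical and not part of the original code; copying integer buffers from the first model is likewise just one possible choice.

```python
import torch
from torch import nn


def average_state_dicts(sd_a, sd_b, alpha=0.5):
    """Return alpha * sd_a + (1 - alpha) * sd_b, entry by entry."""
    if sd_a.keys() != sd_b.keys():
        raise ValueError("state_dicts have different keys")
    averaged = {}
    for name, tensor_a in sd_a.items():
        tensor_b = sd_b[name]
        if tensor_a.is_floating_point():
            averaged[name] = alpha * tensor_a + (1.0 - alpha) * tensor_b
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) are
            # copied from the first model rather than averaged.
            averaged[name] = tensor_a.clone()
    return averaged


def make_net():
    # Toy model with a BatchNorm layer so that buffers are present.
    return nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))


net1, net2, output = make_net(), make_net(), make_net()
avg = average_state_dicts(net1.state_dict(), net2.state_dict())
output.load_state_dict(avg)
```

The resulting dict can be written with torch.save(avg, path) and read back with load_state_dict, in the same way as the corrected code above.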
BTW, in my experiment, averaging the two models did not improve accuracy the way the paper suggests.
July 14, 2022 - 23:40
RobinDong
machine learning
PyTorch