Average weights of two PyTorch models

source link: https://donghao.org/2022/07/14/average-weights-of-two-pytorch-models/

After reading this paper, I began an experiment based on it. Using this snippet as a reference, I wrote my code:

    net1 = model_builder.build_model()
    net2 = model_builder.build_model()
    output = model_builder.build_model()
    net1.load_state_dict(torch.load(args.model1, map_location="cpu"))
    net2.load_state_dict(torch.load(args.model2, map_location="cpu"))
    
    # Average
    sd1 = net1.named_parameters()
    sd2 = net2.named_parameters()
    sdo = dict(sd2)
    for name, param in sd1:
        sdo[name].data.copy_(0.5*param.data + 0.5*sdo[name].data)

    output.load_state_dict(sdo)
    torch.save(output, args.output)
    
    # here is a test
    output.load_state_dict(torch.load(args.output))

But after generating the new model with averaged weights, PyTorch failed to load it:

Traceback (most recent call last):
  File "average_models.py", line 43, in <module>
    output.load_state_dict(torch.load(args.output))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1534, in load_state_dict
    state_dict = state_dict.copy()
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1186, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'RegNet' object has no attribute 'copy'
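
This kind of failure can be reproduced without any checkpoint files. The sketch below is only an illustration under stated assumptions: it uses a tiny nn.Linear as a stand-in for the RegNet model, and the helper try_load is a hypothetical name introduced here. It passes a module object where load_state_dict expects a state dict. The exception type seems to depend on the PyTorch version (older releases raise the AttributeError shown above, while newer ones appear to raise a TypeError), so the sketch catches both.

```python
import torch.nn as nn

# Tiny stand-in models; nn.Linear is an assumption for illustration,
# not the RegNet from the traceback above.
saved = nn.Linear(4, 2)   # plays the role of the object returned by torch.load
target = nn.Linear(4, 2)


def try_load(model, obj):
    """Return None on success, or the exception raised by load_state_dict."""
    try:
        model.load_state_dict(obj)
        return None
    except (AttributeError, TypeError) as exc:
        return exc


# Passing a whole module fails; passing its state dict works.
err_module = try_load(target, saved)
err_dict = try_load(target, saved.state_dict())

print(type(err_module).__name__, err_module)
print(err_dict)
```

The second call succeeds because state_dict() returns a plain mapping from names to tensors, which is what load_state_dict is designed to consume.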

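The reason for the failure is quite simple: torch.save(output, args.output) saves the whole model object, so torch.load returns a RegNet module rather than a state_dict, and load_state_dict cannot treat it as one. We only need to save the state_dict of the model instead of all of its information (I am using FP16 format). Therefore the correct code should be: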

    net1 = model_builder.build_model()
    net2 = model_builder.build_model()
    net1.load_state_dict(torch.load(args.model1, map_location="cpu"))
    net2.load_state_dict(torch.load(args.model2, map_location="cpu"))

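    # Average: blend net1's parameters into net2's parameters in place.
    # named_parameters() yields (name, parameter) pairs; buffers are not included.
    sd1 = net1.named_parameters()
    sd2 = net2.named_parameters()
    sdo = dict(sd2)
    for name, param in sd1:
        sdo[name].data.copy_(0.5*param.data + 0.5*sdo[name].data)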

    torch.save(sdo, args.output)
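
One caveat with the code above: named_parameters() does not include buffers such as BatchNorm running statistics, so the saved dict may be missing keys that a strict load_state_dict expects. A possible alternative, sketched here under assumptions, is to average the full state_dict() of both models. The build_model function below is a hypothetical stand-in for model_builder.build_model(), the helper average_state_dicts and the temporary file path are made up for the example, and copying integer buffers from the first model instead of averaging them is a design choice, not something from the paper.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in for model_builder.build_model()
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))


def average_state_dicts(sd1, sd2):
    """Average two state dicts that share the same keys and shapes."""
    averaged = {}
    for name, t1 in sd1.items():
        t2 = sd2[name]
        if t1.is_floating_point():
            averaged[name] = 0.5 * t1 + 0.5 * t2
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked)
            # are copied from the first model rather than averaged.
            averaged[name] = t1.clone()
    return averaged


net1, net2, output = build_model(), build_model(), build_model()
avg = average_state_dicts(net1.state_dict(), net2.state_dict())

# Save only the averaged state dict, then reload it as a test.
path = os.path.join(tempfile.mkdtemp(), "averaged.pth")
torch.save(avg, path)
output.load_state_dict(torch.load(path, map_location="cpu"))
```

Because the saved file holds every key of the state dict, it loads back with the default strict checking, and nothing but tensors is pickled.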

BTW, in my experiment, averaging my models doesn't raise accuracy as the paper suggests.

July 14, 2022 - 23:40 RobinDong machine learning
PyTorch