Model saving error when using Apex

source link: https://donghao.org/2022/06/03/model-saving-error-when-using-apex/

Apex is a tool from NVIDIA that enables mixed-precision training. Typical usage looks like this:

import torch
import apex.amp as amp

# Wrap the model and optimizer for mixed precision (opt_level "O2")
net, optimizer = amp.initialize(net, optimizer, opt_level="O2")

# forward
outputs = net(inputs)
loss = criterion(outputs, targets)
optimizer.zero_grad()

# float16 backward on the scaled loss
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
...
# Save the whole model object
torch.save(net, "model.pth")

After I switched my code to Apex, saving the model with torch.save(net, "model.pth") raised an error:

AttributeError: Can't pickle local object '_initialize.<locals>.patch_forward.<locals>.new_fwd'
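The message points at a function named new_fwd defined inside patch_forward, which suggests the model's forward was replaced by a nested function, and pickle cannot serialize functions defined inside other functions. The following is a minimal sketch of that failure mode using only the standard library. TinyModel, patch_forward and try_pickle are hypothetical stand-ins written for illustration, not Apex's actual code:

```python
import pickle


class TinyModel:
    """Hypothetical stand-in for a model; not a real torch.nn.Module."""

    def __init__(self):
        self.weight = 2.0

    def forward(self, x):
        return self.weight * x


def patch_forward(obj):
    """Replace obj.forward with a nested function, as the error message hints Apex does."""
    old_fwd = obj.forward

    def new_fwd(*args, **kwargs):
        return old_fwd(*args, **kwargs)

    # The instance now stores a function that only exists inside patch_forward
    obj.forward = new_fwd
    return obj


def try_pickle(obj):
    """Return "ok" if obj can be pickled, otherwise a short description of the error."""
    try:
        pickle.dumps(obj)
        return "ok"
    except (AttributeError, pickle.PicklingError) as exc:
        return f"{type(exc).__name__}: {exc}"


model = patch_forward(TinyModel())
print(try_pickle(model))                     # fails: the instance carries a local function
print(try_pickle({"weight": model.weight}))  # plain parameter data pickles fine
```

The patched object still works for inference, but it can no longer be pickled as a whole, while its parameters alone can.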

Others have already reported this problem, but it does not appear to have been fixed: link. The only solution I found comes from a Chinese blog: link. It recommends saving only the model parameters:

torch.save(net.state_dict(), "model.pth")
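A file saved this way holds only the parameters, so loading it requires rebuilding the model in code first. Below is a minimal sketch of the round trip, assuming a small nn.Linear as a stand-in for the real network. build_net, save_params and load_params are hypothetical helpers, and Apex is left out so the example runs on CPU:

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_net():
    # Hypothetical stand-in for the article's `net`; use your real model class here
    return nn.Linear(4, 2)


def save_params(net, path):
    # Save only the parameters, not the module object itself
    torch.save(net.state_dict(), path)


def load_params(path):
    # Rebuild the architecture, then copy the saved parameters into it
    net = build_net()
    net.load_state_dict(torch.load(path, map_location="cpu"))
    net.eval()
    return net


net = build_net()
path = os.path.join(tempfile.mkdtemp(), "model.pth")
save_params(net, path)
restored = load_params(path)
```

The trade-off is that the loading side needs access to the model definition, whereas torch.save(net, ...) tries to pickle the whole object and fails on the patched forward.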

June 3, 2022 - 0:16 RobinDong machine learning
PyTorch