Model saving error when using Apex
source link: https://donghao.org/2022/06/03/model-saving-error-when-using-apex/
Apex is NVIDIA's PyTorch extension for mixed-precision training. A typical training step using its amp module looks like this:
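import torch
import apex.amp as amp

# Wrap model and optimizer for mixed precision; "O2" keeps FP32 master weights
net, optimizer = amp.initialize(net, optimizer, opt_level="O2")
# forward
outputs = net(inputs)
loss = criterion(outputs, targets)
optimizer.zero_grad()
# float16 backward: amp scales the loss to avoid gradient underflow
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
...
# Save the whole model object (this is the call that fails)
torch.save(net, "model.pth")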
After I changed my code to use Apex, saving the model with torch.save(net, "model.pth") failed with this error:
AttributeError: Can't pickle local object '_initialize.<locals>.patch_forward.<locals>.new_fwd'
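The traceback suggests the cause: amp.initialize replaces net.forward with new_fwd, a function defined locally inside patch_forward, and that function is stored on the module. torch.save(net, ...) pickles the whole module object, and pickle saves functions by their importable qualified name, which a nested function does not have. Below is a minimal, torch-free sketch of that mechanism; Net, patch_forward and try_pickle are hypothetical stand-ins, not Apex code.

```python
import pickle


class Net:
    """Hypothetical stand-in for an nn.Module: parameters plus a forward method."""

    def __init__(self):
        self.weight = [0.5, -1.0]

    def forward(self, x):
        return [w * x for w in self.weight]

    def state_dict(self):
        # Plain data only, analogous to the tensors in nn.Module.state_dict()
        return {"weight": list(self.weight)}


def patch_forward(fwd):
    # Mimics the wrapper named in the traceback: new_fwd is a *local* function
    def new_fwd(*args, **kwargs):
        return fwd(*args, **kwargs)
    return new_fwd


def try_pickle(obj):
    """Return None if obj pickles, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return None
    except (AttributeError, pickle.PicklingError) as exc:
        return str(exc)


net = Net()
print(try_pickle(net))  # None: the unpatched object pickles fine

net.forward = patch_forward(net.forward)  # store the closure on the instance
print(try_pickle(net))  # an error message; exact wording varies by Python version
print(try_pickle(net.state_dict()))  # None: the parameters alone still pickle
```

The exception type differs between Python versions, which is why the sketch catches both AttributeError and pickle.PicklingError.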
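Someone has already reported this problem, but it seems no one wants to fix it: link. The only workaround I found comes from a Chinese blog (link). It recommends saving just the model parameters; a state_dict contains only tensors, so it can be pickled without serializing the module object that amp.initialize modified: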
torch.save(net.state_dict(), "model.pth")
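Saving only the state_dict means the model has to be rebuilt from code before the weights can be loaded back. The following is a minimal round-trip sketch in plain PyTorch on CPU, without Apex; build_model and its layer sizes are hypothetical, and an in-memory buffer stands in for "model.pth".

```python
import io

import torch
import torch.nn as nn


def build_model():
    # Hypothetical architecture; use the same constructor you trained with
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def save_params(model, f):
    # Save only parameters and buffers, which are plain tensors and pickle fine
    torch.save(model.state_dict(), f)


def load_params(f):
    # Rebuild the module from code, then copy the saved tensors into it
    model = build_model()
    model.load_state_dict(torch.load(f, map_location="cpu"))
    return model


torch.manual_seed(0)
net = build_model()

buf = io.BytesIO()  # in-memory stand-in for "model.pth"
save_params(net, buf)
buf.seek(0)
restored = load_params(buf)

x = torch.randn(3, 4)
print(torch.allclose(net(x), restored(x)))  # True
```

When training with Apex, the checkpointing recipe in the Apex README goes further: it also saves amp.state_dict() (the loss-scaler state), and on restore it calls amp.initialize on the rebuilt model and optimizer before loading the model, optimizer and amp state dicts, ideally with the same opt_level.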
June 3, 2022 - 0:16
RobinDong
machine learning
PyTorch