5

Pytorch入门教程15-Pytorch中模型的保存和加载

 2 years ago
source link: https://mathpretty.com/12577.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

摘要当我们完成了模型的训练之后, 我们会希望将其保存下来, 之后可以进行使用. 或是在训练过程中, 我们需要定时对模型进行保存. 所以这一篇, 我们会介绍Pytorch中模型的加载和保存.

这一篇介绍Pytorch中模型的加载和保存. 关于模型的保存和加载, 主要分为下面的几种方法:

  • 整个模型的保存和加载;
  • 模型的参数的保存和加载;
  • 模型的上下文保存, 除了保存模型的权重外, 还会保持此时的学习率, 优化器的参数, 方便恢复.

同时会在最后介绍CPU和GPU情况下的一些特殊情况.

Pytorch中变量的保存

在讲模型的保存之前, 我们首先来简单说明一下Pytorch中tensor的保存.

基础的tensor保存

首先是最基础的对tensor的保存, 可以直接使用torch.save进行保存, torch.load来进行读取. 下面看一个例子.

  1. x = torch.arange(4)
  2. # 进行保存
  3. torch.save(x, 'x-file')
  4. # 进行读取
  5. x2 = torch.load("x-file")
  6. # 比较是否一样
  7. x == x2
  8. tensor([True, True, True, True])

使用List的方式保存

除此之后, 我们还可以对tensor按照List的方式进行保存和读取.

  1. x = torch.arange(4)
  2. y = torch.zeros(4)
  3. torch.save([x, y],'x-files')
  4. x2, y2 = torch.load('x-files')
  5. (x2, y2)
  6. (tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

按照字典类型存储

最后, 我们可以按照dict的方式进行存储, 这个就很类似下面模型参数的存储了.

  1. x = torch.arange(4)
  2. y = torch.zeros(4)
  3. mydict = {'x': x, 'y': y}
  4. torch.save(mydict, 'mydict')
  5. mydict2 = torch.load('mydict')
  6. mydict2
  7. {'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

模型的保存与加载

整个模型的保存与加载

我们可以将整个模型进保存, 使用torch.save即可.

  1. # model是模型, file是要保存的文件名
  2. torch.save(model, file)

接着我们可以进行模型的加载(load), 导入之前我们需要保证网络定义的类是存在的.

  1. # 注意load之前, 我们需要先定义模型
  2. loaded_model = torch.load(file)

我们可以比较一下导入前后两个模型的参数是否有改变.

  1. print("保存前:")
  2. for param in model.parameters():
  3.     print(param)
  4. print("=====================================")
  5. print("保存后:")
  6. for param in loaded_model.parameters():
  7.     print(param)

模型参数的保存和加载

有的时候, 保存整个模型会显得比较麻烦, 存储的文件会比较大. 所以在实际使用的时候, 我们会通常在训练的过程中, 只保存模型的参数.

我们可以使用model.state_dict()将模型参数转为字典对象, 于是模型的参数保持可以使用下面的方式来进行保存.

  1. torch.save(model.state_dict(), file)

在加载参数的时候, 我们需要提前初始化模型, 接着传入保存的参数.

  1. dicts = torch.load('params.pkl')
  2. model_object.load_state_dict(dicts)

模型的上下文保存

上面我们介绍了模型的参数的保存. 但是有的时候, 我们还需要保存此时的学习率, 保存优化器的系数等. 这样一旦模型的训练终止, 我们可以从终止的地方继续开始训练我们的模型.

例如, 现在不仅有模型, 还有优化器. 那么我们保存的时候就会将模型参数, 优化器参数, 迭代次数都封装进入一个字典类型的数据.

  1. checkpoint = {
  2.     "epoch": 90,
  3.     "model_state": model.state_dict(),
  4.     "optim_state": optimizer.state_dict()

接着还是使用torch.save来保存上面字典类型的数据.

  1. FILE = "checkpoint.pth"
  2. torch.save(checkpoint, FILE)

关于这些系数的加载, 我们也是使用torch.load来进行参数的加载. 下面看一个简单的例子.

  1. checkpoint = torch.load(FILE)
  2. # 加载的文件是一个字典,根据key值,将其加载到模型、优化器、迭次次数中
  3. model.load_state_dict(checkpoint['model_state'])
  4. optimizer.load_state_dict(checkpoint['optim_state'])
  5. epoch = checkpoint['epoch']

GPU与CPU

由于GPU和CPU的训练模型方式不同, 因此保存下来的模型也存在不同. 为此, 面对不同环境下训练出来的模型, 我们的加载方式也存在细微的差别.

保存模型在GPU, 运行在CPU

此时在load_state_dict的时候, 需要指定map_location.

  1. device = torch.device('cpu')
  2. model.load_state_dict(torch.load(PATH, map_location=device))

保存模型在CPU, 运行在GPU

同样, 如果运行在GPU上面, 我们也是要通过map_location来进行指定的. 例如下面这个简单的例子.

  1. model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
  2. model.to(device)

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK