3

轻量级PyTorch通用训练模板pytorch-accelerated解析:4 -- 其他API

 1 year ago
source link: https://www.qixinbo.info/2022/06/12/pytorch-accelerated_4/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

这一章将研究pytorch-accelerated的其他API,包括追踪Tracking、运行配置Run Config、微调Fine tuning

Tracking

RunHistory抽象基类

RunHistory抽象基类定义了Trainer运行历史的API

class pytorch_accelerated.tracking.RunHistory


(1)获得当前epoch的数值:

def current_epoch(self) -> int


(2)获得指定指标最近的记录值:

def get_latest_metric(self, metric_name)


(3)获得在追踪的所有指标的名字:

def get_metric_names(self) -> Iterable


返回的是一个集合,所以名称都是独一无二的。
(4)获得指定指标的所有值:

def get_metric_values(self, metric_name) -> Iterable


(5)重置:

def reset(self)


重置RunHistory的状态。
(6)记录指定指标的值:

def update_metric(self, metric_name, metric_value)

上面的RunHistory是抽象基类,其定义的方法都没有具体的实现。
InMemoryRunHistory是对RunHistory的一个具体实现。

class InMemoryRunHistory(RunHistory):
"""
An implementation of :class:`RunHistory` which stores all recorded values in memory.
"""

def __init__(self):
self._current_epoch = 1
self._metrics = defaultdict(list)

def get_metric_names(self):
return set(self._metrics.keys())

def get_metric_values(self, metric_name):
return self._metrics[metric_name]

def get_latest_metric(self, metric_name):
if len(self._metrics[metric_name]) > 0:
return self._metrics[metric_name][-1]
else:
raise ValueError(
f"No values have been recorded for the metric {metric_name}"
)

def update_metric(self, metric_name, metric_value):
self._metrics[metric_name].append(metric_value)

@property
def current_epoch(self):
return self._current_epoch

def _increment_epoch(self):
self._current_epoch += 1

def reset(self):
self._current_epoch = 1
self._metrics = defaultdict(list)

Run Config

TrainerRunConfig是一个不可变的数据类,包含训练器Trainer当前状态的数值。

@dataclass(frozen=True)
class TrainerRunConfig:


其属性有:

  • num_epochs:当前训练的迭代次数
  • train_per_device_batch_size:训练时每个设备上的批大小
  • train_dl_kwargs:创建训练集数据加载器的所需参数
  • eval_per_device_batch_size:评估时每个设备上的批大小
  • eval_dl_kwargs:创建验证集数据加载器的所需参数
  • gradient_accumulation_steps:训练时梯度累加的步数
  • gradient_clip_value:模型参数的梯度修剪的阈值
  • train_total_batch_size:训练时总批大小
  • eval_total_batch_size:评估时总批大小
  • num_update_steps_per_epoch:训练时当模型参数更新时的步数
  • max_num_train_steps:训练的总步数,如果指定的话,会覆盖掉num_epochs参数
  • is_local_process_zero:如果当前进程是当前节点上的主进程,则为True;否则为False
  • is_world_process_zero:如果当前进程是横跨所有节点的主进程,则为True,否则为False
  • is_distributed:如果trainer是分布式训练,则为True,否则为False
  • mixed_precision:包含所用的混合精度类型的字符串,否则为no

Fine tuning

ModelFreezer是一个用来冻结和解冻一个模型的不同部分的类,其用来简化迁移学习中微调的操作。

class pytorch_accelerated.finetuning.ModelFreezer(model, freeze_batch_norms=False)


该类使用以下的抽象定义:

  • Layer:是一个深度为1的torch.nn.Module的子类,即这个特定的module不是嵌套的
  • LayerGroup:是模型类的属性,可以是网络层layers或嵌套的modules

举个例子,如下的模型:

from torch import nn

class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.input = nn.Linear(100, 100)
self.block_1 = nn.Sequential(
nn.Linear(100, 100),
nn.BatchNorm1d(100),
nn.Sequential(
nn.Linear(100, 100),
nn.BatchNorm1d(100),
nn.ReLU(),
),
)
self.output = nn.Linear(100, 10)

def forward(self, x):
x = self.input(x)
x = self.block_1(x)
out = self.output(x)
return out


该模型的LayerGroup就是[input, block_1, output],而Layers则是一个有序的、压平的LinearBatchNormReLU模块列表。

主要方法有:
(1)冻结指定索引的LayerGroup

freeze(from_index=0, to_index=-2, set_modules_as_eval=False)


默认情况下,这将冻结所有层组,除了最后一个层组。
参数有:

  • from_index:第一个被冻结的LayerGroup的索引
  • to_index:最后一个被冻结的LayerGroup的索引
  • set_modules_as_eval:若为True,这些冻结的模块也会被置为eval模式。默认是False

(2)返回模型的所有LayerGroups

get_layer_groups() → List[LayerGroup]


(3)返回模型的所有Layer

get_layers() → List[Layer]


(4)返回所有未被冻结的模型参数:

get_trainable_parameters()


这些参数将在训练中被更新。
(5)解冻指定索引的LayerGroup

unfreeze(from_index=-1, to_index=0, set_modules_as_training=True)


默认情况下,这将解冻所有LayerGroups。对于一个LayerGroup,任何已经解冻的参数都会被返回,这样如果需要的话,它们可以被添加到一个优化器中。
参数有:

  • from_index:第一个被解冻的LayerGroup的索引
  • to_index:最后一个被解冻的LayerGroup的索引
  • set_modules_as_training:若为True,这些解冻的模块也会被置为train模式。默认是True

返回值是:包含每一个解冻的layer Group的参数的字典。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK