

轻量级PyTorch通用训练模板pytorch-accelerated解析:4 -- 其他API
source link: https://www.qixinbo.info/2022/06/12/pytorch-accelerated_4/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

这一章将研究pytorch-accelerated
的其他API
,包括追踪Tracking
、运行配置Run Config
、微调Fine tuning
。
Tracking
RunHistory抽象基类
RunHistory
抽象基类定义了Trainer
运行历史的API
:
class pytorch_accelerated.tracking.RunHistory
(1)获得当前epoch
的数值:
def current_epoch(self) -> int
(2)获得指定指标最近的记录值:
def get_latest_metric(self, metric_name)
(3)获得在追踪的所有指标的名字:
def get_metric_names(self) -> Iterable
返回的是一个集合,所以名称都是独一无二的。
(4)获得指定指标的所有值:
def get_metric_values(self, metric_name) -> Iterable
(5)重置:
def reset(self)
重置RunHistory
的状态。
(6)记录指定指标的值:
def update_metric(self, metric_name, metric_value)
上面的RunHistory
是抽象基类,其定义的方法都没有具体的实现。InMemoryRunHistory
是对RunHistory
的一个具体实现。
class InMemoryRunHistory(RunHistory):
"""
An implementation of :class:`RunHistory` which stores all recorded values in memory.
"""
def __init__(self):
self._current_epoch = 1
self._metrics = defaultdict(list)
def get_metric_names(self):
return set(self._metrics.keys())
def get_metric_values(self, metric_name):
return self._metrics[metric_name]
def get_latest_metric(self, metric_name):
if len(self._metrics[metric_name]) > 0:
return self._metrics[metric_name][-1]
else:
raise ValueError(
f"No values have been recorded for the metric {metric_name}"
)
def update_metric(self, metric_name, metric_value):
self._metrics[metric_name].append(metric_value)
@property
def current_epoch(self):
return self._current_epoch
def _increment_epoch(self):
self._current_epoch += 1
def reset(self):
self._current_epoch = 1
self._metrics = defaultdict(list)
Run Config
TrainerRunConfig
是一个不可变的数据类,包含训练器Trainer
当前状态的数值。
@dataclass(frozen=True)
class TrainerRunConfig:
其属性有:
num_epochs
:当前训练的迭代次数train_per_device_batch_size
:训练时每个设备上的批大小train_dl_kwargs
:创建训练集数据加载器的所需参数eval_per_device_batch_size
:评估时每个设备上的批大小eval_dl_kwargs
:创建验证集数据加载器的所需参数gradient_accumulation_steps
:训练时梯度累加的步数gradient_clip_value
:模型参数的梯度修剪的阈值train_total_batch_size
:训练时总批大小eval_total_batch_size
:评估时总批大小num_update_steps_per_epoch
:训练时当模型参数更新时的步数max_num_train_steps
:训练的总步数,如果指定的话,会覆盖掉num_epochs
参数is_local_process_zero
:如果当前进程是当前节点上的主进程,则为True
;否则为False
is_world_process_zero
:如果当前进程是横跨所有节点的主进程,则为True
,否则为False
is_distributed
:如果trainer
是分布式训练,则为True
,否则为False
mixed_precision
:包含所用的混合精度类型的字符串,否则为no
Fine tuning
ModelFreezer
是一个用来冻结和解冻一个模型的不同部分的类,其用来简化迁移学习中微调的操作。
class pytorch_accelerated.finetuning.ModelFreezer(model, freeze_batch_norms=False)
该类使用以下的抽象定义:
Layer
:是一个深度为1的torch.nn.Module
的子类,即这个特定的module
不是嵌套的LayerGroup
:是模型类的属性,可以是网络层layers
或嵌套的modules
举个例子,如下的模型:
from torch import nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.input = nn.Linear(100, 100)
self.block_1 = nn.Sequential(
nn.Linear(100, 100),
nn.BatchNorm1d(100),
nn.Sequential(
nn.Linear(100, 100),
nn.BatchNorm1d(100),
nn.ReLU(),
),
)
self.output = nn.Linear(100, 10)
def forward(self, x):
x = self.input(x)
x = self.block_1(x)
out = self.output(x)
return out
该模型的LayerGroup
就是[input, block_1, output]
,而Layers
则是一个有序的、压平的Linear
、BatchNorm
、ReLU
模块列表。
主要方法有:
(1)冻结指定索引的LayerGroup
:
freeze(from_index=0, to_index=-2, set_modules_as_eval=False)
默认情况下,这将冻结所有层组,除了最后一个层组。
参数有:
from_index
:第一个被冻结的LayerGroup
的索引to_index
:最后一个被冻结的LayerGroup
的索引set_modules_as_eval
:若为True
,这些冻结的模块也会被置为eval
模式。默认是False
(2)返回模型的所有LayerGroups
:
get_layer_groups() → List[LayerGroup]
(3)返回模型的所有Layer
:
get_layers() → List[Layer]
(4)返回所有未被冻结的模型参数:
get_trainable_parameters()
这些参数将在训练中被更新。
(5)解冻指定索引的LayerGroup
:
unfreeze(from_index=-1, to_index=0, set_modules_as_training=True)
默认情况下,这将解冻所有LayerGroups
。对于一个LayerGroup
,任何已经解冻的参数都会被返回,这样如果需要的话,它们可以被添加到一个优化器中。
参数有:
from_index
:第一个被解冻的LayerGroup
的索引to_index
:最后一个被解冻的LayerGroup
的索引set_modules_as_training
:若为True
,这些解冻的模块也会被置为train
模式。默认是True
。
返回值是:包含每一个解冻的layer Group
的参数的字典。
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK