PyTorch_Practice

source link: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/finetune_resnet18.py

# -*- coding: utf-8 -*-
"""
Model finetuning example
"""
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from lesson2.rmb_classification.tools.my_dataset import AntsDataset
from common_tools import set_seed
import torchvision.models as models
import enviroments

BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use device: {}".format(device))

set_seed(1)  # set the random seed
label_name = {"ants": 0, "bees": 1}

# Hyperparameters
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10
val_interval = 1
classes = 2
start_epoch = -1
lr_decay_step = 7

# ============================ step 1/5: data ============================
data_dir = enviroments.hymenoptera_data_dir
train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build Dataset instances
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)

# Build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
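# Sketch (not part of the original script): a quick shape check on one batch.
# Each training batch should be a (BATCH_SIZE, 3, 224, 224) float tensor of
# images and a (BATCH_SIZE,) tensor of integer labels (0 = ants, 1 = bees).
# imgs, lbls = next(iter(train_loader))
# print(imgs.shape, lbls.shape)  # torch.Size([16, 3, 224, 224]) torch.Size([16])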

# ============================ step 2/5: model ============================

# 1/3 build the model
resnet18_ft = models.resnet18()

# 2/3 load the pretrained parameters
# flag = 0
flag = 1
if flag:
    path_pretrained_model = enviroments.resnet18_path
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)
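# Equivalent alternative (a sketch, assuming network access and the older
# torchvision API): let torchvision download the ImageNet weights instead of
# loading a local checkpoint.
# resnet18_ft = models.resnet18(pretrained=True)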

# Method 1: freeze the convolutional layers
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    # print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))
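# Verification sketch (not in the original): with the backbone frozen, no
# parameter requires gradients until the new fc layer is attached below.
# n_trainable = sum(p.numel() for p in resnet18_ft.parameters() if p.requires_grad)
# print("trainable parameters: {}".format(n_trainable))  # 0 when flag_m1 is set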

# 3/3 replace the fc layer
# First get the number of input features of the fc layer
num_ftrs = resnet18_ft.fc.in_features
# Then build a new fc layer to replace the original one
resnet18_ft.fc = nn.Linear(num_ftrs, classes)
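# Sanity-check sketch (not in the original): the new head maps the 512-dim
# ResNet-18 feature vector to `classes` logits, and its freshly created
# parameters require gradients even when the backbone was frozen above.
# with torch.no_grad():
#     out = resnet18_ft(torch.randn(1, 3, 224, 224))
# print(out.shape)  # torch.Size([1, 2])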

resnet18_ft.to(device)

# ============================ step 3/5: loss function ============================
criterion = nn.CrossEntropyLoss()  # choose the loss function

# ============================ step 4/5: optimizer ============================
# Method 2: smaller learning rate for the conv layers
flag = 0
# flag = 1
if flag:
    # First collect the ids (memory addresses) of the fc layer's parameters
    fc_params_id = list(map(id, resnet18_ft.fc.parameters()))
    # Then filter out the fc parameters, keeping only the conv-layer parameters
    base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())
    # Per-group learning rates: pass a list of 2 dicts, one per parameter group
    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR * 0.1},
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)
else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)  # choose the optimizer
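# Inspection sketch (not in the original): with the grouped optimizer there
# are two param_groups, each carrying its own learning rate.
# for idx, group in enumerate(optimizer.param_groups):
#     print("group {}: lr = {}".format(idx, group['lr']))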

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)  # learning-rate decay policy
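# StepLR multiplies every group's learning rate by gamma each step_size epochs,
# i.e. lr(epoch) = base_lr * 0.1 ** (epoch // 7) here. A sketch for inspecting
# the current rates:
# print(scheduler.get_last_lr())  # e.g. [0.001] before the first decay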

# ============================ step 5/5: training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    resnet18_ft.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet18_ft(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().cpu().sum().numpy()

        # print training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    # if flag_m1:
    #     print("epoch:{} conv1.weights[0, 0, ...]:\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        resnet18_ft.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().cpu().sum().numpy()

                loss_val += loss.item()

        loss_val_mean = loss_val / len(valid_loader)
        valid_curve.append(loss_val_mean)
        print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
            epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val_mean, correct_val / total_val))
        resnet18_ft.train()

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
# valid_curve records one mean loss per epoch, so convert the record points to iterations
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
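# Follow-up sketch (not in the original script): persist the finetuned weights
# so they can be reloaded later; the filename here is a placeholder.
# path_state_dict = os.path.join(BASEDIR, "resnet18_finetuned.pkl")
# torch.save(resnet18_ft.state_dict(), path_state_dict)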

