
MNIST Image Classification with PyTorch

Source link: https://xujinzh.github.io/2023/05/31/pytorch-mnist-classify/


Published 2023-05-31 | Updated 2023-05-31 | technology, python

Image classification with PyTorch, using the built-in MNIST dataset.

Loading the dataset

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import visdom
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

torch.manual_seed(33)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
device(type='cuda', index=0)
train_ds = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

test_ds = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/",
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
batch_size = 64
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size)
imgs, labels = next(iter(train_dl))
imgs.shape
torch.Size([64, 1, 28, 28])
labels.shape
torch.Size([64])
plt.figure(figsize=(batch_size, 1))
for i, img in enumerate(imgs):
    img_np = img.numpy().squeeze()  # drop the channel dim: (1, 28, 28) -> (28, 28) for imshow
    plt.subplot(1, batch_size, i + 1)
    plt.imshow(img_np, cmap="gray")
    plt.axis("off")
labels.data
tensor([3, 0, 5, 2, 3, 4, 5, 9, 1, 7, 4, 7, 8, 4, 2, 1, 7, 9, 8, 3, 4, 9, 7, 5,
        0, 2, 4, 2, 5, 7, 6, 4, 2, 8, 8, 5, 6, 0, 6, 4, 9, 5, 9, 9, 9, 4, 9, 8,
        8, 6, 9, 3, 2, 2, 2, 5, 0, 4, 9, 3, 0, 8, 3, 2])
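The ToTensor transform only scales pixel values into [0, 1]. If you also want zero-mean inputs, a common variant (not used in this post) is to compose it with Normalize using the commonly quoted MNIST statistics; a minimal sketch, with the hypothetical names mnist_transform and train_ds_norm and the same dataset path:

# Optional variant: normalize with the commonly quoted MNIST mean/std (0.1307, 0.3081).
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),                      # uint8 [0, 255] -> float [0.0, 1.0]
        transforms.Normalize((0.1307,), (0.3081,)), # per-channel mean/std for MNIST
    ]
)
train_ds_norm = torchvision.datasets.MNIST(
    "/workspace/disk1/datasets/",
    train=True,
    transform=mnist_transform,
    download=True,
)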
class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.linear1 = nn.Linear(28 * 28, 128)
        self.linear2 = nn.Linear(128, 64)
        self.linear3 = nn.Linear(64, 10)

    def forward(self, inputs):
        x = inputs.view(-1, 1 * 28 * 28)
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        logits = self.linear3(x)
        return logits
model = MLPModel().to(device)
model
MLPModel(
  (linear1): Linear(in_features=784, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=64, bias=True)
  (linear3): Linear(in_features=64, out_features=10, bias=True)
)
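From the printed layer sizes, this MLP has 784·128+128 + 128·64+64 + 64·10+10 = 109,386 trainable parameters. A quick way to confirm that directly from the model object:

# Count trainable parameters from the instantiated model.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)  # 109386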
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
optimizer
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 0
)
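Note that forward returns raw logits without a softmax: nn.CrossEntropyLoss applies log-softmax internally and then takes the negative log-likelihood, so no activation is needed on the output layer. A small self-contained check (not part of the original post) showing the equivalence:

# CrossEntropyLoss(logits, target) == NLLLoss(log_softmax(logits), target)
logits = torch.tensor([[2.0, 0.5, -1.0]])  # one sample, three classes
target = torch.tensor([0])
ce = torch.nn.CrossEntropyLoss()(logits, target)
nll = torch.nn.NLLLoss()(torch.log_softmax(logits, dim=1), target)
print(ce.item(), nll.item())  # identical values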
def train(dl, model, loss_fn, optimizer):
    size = len(dl.dataset)
    num_batches = len(dl)

    train_loss, correct = 0, 0

    for x, y in dl:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()

    correct /= size
    train_loss /= num_batches
    return correct, train_loss
def test(dl, model, loss_fn):
    size = len(dl.dataset)
    num_batches = len(dl)

    test_loss, correct = 0, 0

    with torch.no_grad():
        for x, y in dl:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            test_loss += loss.item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    correct /= size
    test_loss /= num_batches
    return correct, test_loss
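This MLP contains no dropout or batch-norm layers, so the two functions above work as written; if such layers were added, switching the module mode explicitly would matter. A minimal pattern, with hypothetical wrapper names added only for illustration:

# Hypothetical wrappers: only relevant once layers such as nn.Dropout or
# nn.BatchNorm1d are added to the model.
def train_one_epoch(dl, model, loss_fn, optimizer):
    model.train()  # enable dropout / use batch statistics
    return train(dl, model, loss_fn, optimizer)

def evaluate(dl, model, loss_fn):
    model.eval()   # disable dropout / use running statistics
    return test(dl, model, loss_fn)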
# Use the visdom visualization module to plot how training loss and accuracy
# evolve (a visdom server must already be running at the address below).
viz = visdom.Visdom(
    server="http://localhost",
    port=8097,
    base_url="/visdom",
    username="jinzhongxu",
    password="123123",
)
win = "mnist"
opts = dict(
    title="MNIST",
    xlabel="epoch",
    ylabel="loss and acc",
    markers=True,
    legend=["train_loss", "train_acc", "test_loss", "test_acc"],
)
viz.line(
    [[0.0, 0.0, 0.0, 0.0]],
    [0.0],
    win=win,
    opts=opts,
)
Setting up a new session...

'mnist'
epochs = 50

train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_acc, epoch_loss = train(
        dl=train_dl, model=model, loss_fn=loss_fn, optimizer=optimizer
    )
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    epoch_test_acc, epoch_test_loss = test(dl=test_dl, model=model, loss_fn=loss_fn)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    print(
        f"epoch={epoch:2d}, train_loss={epoch_loss:.5f}, train_acc={epoch_acc:.5f}, test_loss={epoch_test_loss:.5f}, test_acc={epoch_test_acc:.5f}"
    )
    viz.line(
        [[epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc]],
        [epoch],
        win=win,
        update="append",
    )

print("done!")
epoch= 0, train_loss=0.87553, train_acc=0.78492, test_loss=0.37957, test_acc=0.89600
epoch= 1, train_loss=0.34057, train_acc=0.90542, test_loss=0.29167, test_acc=0.91640
epoch= 2, train_loss=0.28640, train_acc=0.91793, test_loss=0.26019, test_acc=0.92450
epoch= 3, train_loss=0.25685, train_acc=0.92650, test_loss=0.24007, test_acc=0.92930
epoch= 4, train_loss=0.23372, train_acc=0.93333, test_loss=0.21881, test_acc=0.93560
epoch= 5, train_loss=0.21370, train_acc=0.93912, test_loss=0.20176, test_acc=0.93990
epoch= 6, train_loss=0.19669, train_acc=0.94297, test_loss=0.18557, test_acc=0.94430
epoch= 7, train_loss=0.18122, train_acc=0.94835, test_loss=0.17462, test_acc=0.94620
epoch= 8, train_loss=0.16796, train_acc=0.95127, test_loss=0.16615, test_acc=0.94940
epoch= 9, train_loss=0.15605, train_acc=0.95473, test_loss=0.15155, test_acc=0.95460
epoch=10, train_loss=0.14516, train_acc=0.95775, test_loss=0.14506, test_acc=0.95670
epoch=11, train_loss=0.13557, train_acc=0.96103, test_loss=0.13445, test_acc=0.95970
epoch=12, train_loss=0.12738, train_acc=0.96342, test_loss=0.13094, test_acc=0.96010
epoch=13, train_loss=0.11912, train_acc=0.96610, test_loss=0.12227, test_acc=0.96210
epoch=14, train_loss=0.11219, train_acc=0.96753, test_loss=0.11731, test_acc=0.96430
epoch=15, train_loss=0.10571, train_acc=0.96947, test_loss=0.11181, test_acc=0.96500
epoch=16, train_loss=0.09996, train_acc=0.97147, test_loss=0.10745, test_acc=0.96670
epoch=17, train_loss=0.09438, train_acc=0.97308, test_loss=0.10555, test_acc=0.96800
epoch=18, train_loss=0.08965, train_acc=0.97438, test_loss=0.10191, test_acc=0.96840
epoch=19, train_loss=0.08477, train_acc=0.97557, test_loss=0.09853, test_acc=0.96930
epoch=20, train_loss=0.08065, train_acc=0.97690, test_loss=0.09546, test_acc=0.96970
epoch=21, train_loss=0.07642, train_acc=0.97827, test_loss=0.09460, test_acc=0.97060
epoch=22, train_loss=0.07243, train_acc=0.97918, test_loss=0.09040, test_acc=0.97170
epoch=23, train_loss=0.06898, train_acc=0.98013, test_loss=0.08840, test_acc=0.97270
epoch=24, train_loss=0.06559, train_acc=0.98123, test_loss=0.08831, test_acc=0.97240
epoch=25, train_loss=0.06238, train_acc=0.98242, test_loss=0.08451, test_acc=0.97450
epoch=26, train_loss=0.05947, train_acc=0.98308, test_loss=0.08525, test_acc=0.97340
epoch=27, train_loss=0.05665, train_acc=0.98370, test_loss=0.08331, test_acc=0.97420
epoch=28, train_loss=0.05389, train_acc=0.98510, test_loss=0.08325, test_acc=0.97480
epoch=29, train_loss=0.05153, train_acc=0.98535, test_loss=0.08162, test_acc=0.97450
epoch=30, train_loss=0.04908, train_acc=0.98628, test_loss=0.07992, test_acc=0.97540
epoch=31, train_loss=0.04710, train_acc=0.98658, test_loss=0.07741, test_acc=0.97600
epoch=32, train_loss=0.04476, train_acc=0.98773, test_loss=0.07945, test_acc=0.97460
epoch=33, train_loss=0.04273, train_acc=0.98813, test_loss=0.07803, test_acc=0.97500
epoch=34, train_loss=0.04049, train_acc=0.98873, test_loss=0.07625, test_acc=0.97520
epoch=35, train_loss=0.03883, train_acc=0.98968, test_loss=0.07546, test_acc=0.97660
epoch=36, train_loss=0.03686, train_acc=0.99037, test_loss=0.07731, test_acc=0.97510
epoch=37, train_loss=0.03529, train_acc=0.99060, test_loss=0.07601, test_acc=0.97570
epoch=38, train_loss=0.03339, train_acc=0.99118, test_loss=0.07800, test_acc=0.97490
epoch=39, train_loss=0.03212, train_acc=0.99150, test_loss=0.07530, test_acc=0.97650
epoch=40, train_loss=0.03038, train_acc=0.99222, test_loss=0.07336, test_acc=0.97610
epoch=41, train_loss=0.02889, train_acc=0.99262, test_loss=0.07662, test_acc=0.97680
epoch=42, train_loss=0.02742, train_acc=0.99350, test_loss=0.07404, test_acc=0.97700
epoch=43, train_loss=0.02625, train_acc=0.99347, test_loss=0.07493, test_acc=0.97660
epoch=44, train_loss=0.02480, train_acc=0.99420, test_loss=0.07400, test_acc=0.97710
epoch=45, train_loss=0.02360, train_acc=0.99417, test_loss=0.07704, test_acc=0.97700
epoch=46, train_loss=0.02243, train_acc=0.99490, test_loss=0.07595, test_acc=0.97790
epoch=47, train_loss=0.02117, train_acc=0.99525, test_loss=0.07470, test_acc=0.97700
epoch=48, train_loss=0.02004, train_acc=0.99557, test_loss=0.07563, test_acc=0.97650
epoch=49, train_loss=0.01895, train_acc=0.99598, test_loss=0.07576, test_acc=0.97750
done!
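After training, the model can be used for prediction on test images. A minimal inference sketch (not in the original post; the names x, y, and pred here are for illustration only):

# Predict the classes of the first test batch and compare with the true labels.
model.eval()
with torch.no_grad():
    x, y = next(iter(test_dl))
    pred = model(x.to(device)).argmax(1).cpu()
print(pred[:10], y[:10])  # predicted vs. true labels for the first ten images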

Loss and test-accuracy curves:

[figure: training/test loss and accuracy vs. epoch]
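If a visdom server is not available, the same curves can be drawn with matplotlib from the lists collected during training (a sketch, not part of the original post):

# Plot the recorded per-epoch metrics with matplotlib instead of visdom.
plt.figure(figsize=(8, 4))
plt.plot(range(epochs), train_loss, label="train_loss")
plt.plot(range(epochs), test_loss, label="test_loss")
plt.plot(range(epochs), train_acc, label="train_acc")
plt.plot(range(epochs), test_acc, label="test_acc")
plt.xlabel("epoch")
plt.legend()
plt.show()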
