2

PyTorch 逻辑回归

 1 year ago
source link: https://xujinzh.github.io/2023/05/30/pytorch-logistic-regression/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

PyTorch 逻辑回归

发表于2023-05-30|更新于2023-05-30|technologypython
字数总计:2.2k|阅读时长:13分钟|阅读量:7

PyTorch 逻辑回归,数据集:UCI Iris Data Set

导入依赖包

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import visdom
torch.__version__, torch.manual_seed(33)
('1.12.1+cu102', <torch._C.Generator at 0x7f753ca60e90>)

加载数据集

path = "/workspace/disk1/datasets/scalar/iris.data.csv"
data = pd.read_csv(path)
data.head(3)
5.1 3.5 1.4 0.2 Iris-setosa
0 4.9 3.0 1.4 0.2 Iris-setosa
1 4.7 3.2 1.3 0.2 Iris-setosa
2 4.6 3.1 1.5 0.2 Iris-setosa
len(data)
set(data["Iris-setosa"].values)
{'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'}
data_setosa = data[data["Iris-setosa"] == "Iris-setosa"]
data_versicolor = data[data["Iris-setosa"] == "Iris-versicolor"]
data_virginica = data[data["Iris-setosa"] == "Iris-virginica"]
len(data_setosa), len(data_versicolor), len(data_virginica)
(49, 50, 50)
data_using = data[data["Iris-setosa"] != "Iris-setosa"]
X = data_using.iloc[:, :-1]
Y = data_using.iloc[:, -1].replace("Iris-versicolor", 0).replace("Iris-virginica", 1)
Y.unique()
array([0, 1])
X = torch.from_numpy(X.values).type(torch.float32)
X.shape
torch.Size([100, 4])
Y = torch.from_numpy(Y.values.reshape(-1, 1)).type(torch.float32)
Y.shape
torch.Size([100, 1])
train_set_ratio = 0.8
train_set_num = int(train_set_ratio * X.shape[0])
X_train = X[:train_set_num]
Y_train = Y[:train_set_num]
X_test = X[train_set_num:]
Y_test = Y[train_set_num:]
model = nn.Sequential(
nn.Linear(in_features=X.shape[1], out_features=Y.shape[1]),
nn.Sigmoid(),
)
model
Sequential(
  (0): Linear(in_features=4, out_features=1, bias=True)
  (1): Sigmoid()
)
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch = 20
num_of_batch = X_train.shape[0] // batch
epochs = 20000
# 使用 pytorch 可视化模块 visdom 可视化训练损失变化情况
viz = visdom.Visdom(
server="http://localhost",
port=8097,
base_url="/visdom",
username="jinzhongxu",
password="123123",
)
win = "logistic regression loss"
opts = dict(
title="train_losses",
xlabel="epoch",
ylabel="loss",
markers=True,
legend=["loss", "acc"],
)
viz.line(
[
[
0.0,
0.0,
]
],
[0.0],
win=win,
opts=opts,
)
Setting up a new session...





'logistic regression loss'
for epoch in range(epochs):
losses = []
for n in range(num_of_batch):
start = n * batch
end = (n + 1) * batch
x = X_train[start:end]
y = Y_train[start:end]
y_pred = model(x)
loss = loss_fn(y_pred, y)
losses.append(loss.data)
optimizer.zero_grad()
loss.backward()
optimizer.step()

if epoch % 100 == 0:
acc = (
(model(X_test).data.numpy() > 0.5).astype(np.int8) == Y_test.numpy()
).mean()
print(
f"epoch: {str(epoch).rjust(len(str(epochs)), '0')}, loss:{np.mean(losses):.5f}, acc: {acc:.3f}"
)
viz.line(
[
[
np.mean(losses),
acc,
]
],
[epoch],
win=win,
update="append",
)
epoch: 00000, loss:0.81945, acc: 0.000
epoch: 00100, loss:0.72729, acc: 0.000
epoch: 00200, loss:0.71364, acc: 0.000
epoch: 00300, loss:0.70016, acc: 0.000
epoch: 00400, loss:0.68695, acc: 0.000
epoch: 00500, loss:0.67402, acc: 0.000
epoch: 00600, loss:0.66138, acc: 0.000
epoch: 00700, loss:0.64902, acc: 0.000
epoch: 00800, loss:0.63694, acc: 0.000
epoch: 00900, loss:0.62514, acc: 0.000
epoch: 01000, loss:0.61361, acc: 0.000
epoch: 01100, loss:0.60233, acc: 0.000
epoch: 01200, loss:0.59132, acc: 0.100
epoch: 01300, loss:0.58056, acc: 0.250
epoch: 01400, loss:0.57005, acc: 0.350
epoch: 01500, loss:0.55979, acc: 0.400
epoch: 01600, loss:0.54976, acc: 0.450
epoch: 01700, loss:0.53998, acc: 0.450
epoch: 01800, loss:0.53042, acc: 0.500
epoch: 01900, loss:0.52110, acc: 0.600
epoch: 02000, loss:0.51199, acc: 0.650
epoch: 02100, loss:0.50311, acc: 0.650
epoch: 02200, loss:0.49444, acc: 0.700
epoch: 02300, loss:0.48598, acc: 0.700
epoch: 02400, loss:0.47773, acc: 0.700
epoch: 02500, loss:0.46968, acc: 0.700
epoch: 02600, loss:0.46183, acc: 0.700
epoch: 02700, loss:0.45417, acc: 0.700
epoch: 02800, loss:0.44670, acc: 0.700
epoch: 02900, loss:0.43942, acc: 0.700
epoch: 03000, loss:0.43232, acc: 0.700
epoch: 03100, loss:0.42539, acc: 0.700
epoch: 03200, loss:0.41863, acc: 0.700
epoch: 03300, loss:0.41204, acc: 0.700
epoch: 03400, loss:0.40561, acc: 0.750
epoch: 03500, loss:0.39935, acc: 0.800
epoch: 03600, loss:0.39324, acc: 0.800
epoch: 03700, loss:0.38728, acc: 0.800
epoch: 03800, loss:0.38146, acc: 0.800
epoch: 03900, loss:0.37579, acc: 0.800
epoch: 04000, loss:0.37026, acc: 0.800
epoch: 04100, loss:0.36487, acc: 0.800
epoch: 04200, loss:0.35961, acc: 0.800
epoch: 04300, loss:0.35447, acc: 0.800
epoch: 04400, loss:0.34947, acc: 0.800
epoch: 04500, loss:0.34458, acc: 0.800
epoch: 04600, loss:0.33982, acc: 0.800
epoch: 04700, loss:0.33516, acc: 0.800
epoch: 04800, loss:0.33063, acc: 0.800
epoch: 04900, loss:0.32620, acc: 0.850
epoch: 05000, loss:0.32187, acc: 0.850
epoch: 05100, loss:0.31765, acc: 0.850
epoch: 05200, loss:0.31353, acc: 0.850
epoch: 05300, loss:0.30951, acc: 0.850
epoch: 05400, loss:0.30559, acc: 0.850
epoch: 05500, loss:0.30175, acc: 0.850
epoch: 05600, loss:0.29801, acc: 0.850
epoch: 05700, loss:0.29435, acc: 0.850
epoch: 05800, loss:0.29078, acc: 0.850
epoch: 05900, loss:0.28729, acc: 0.850
epoch: 06000, loss:0.28388, acc: 0.850
epoch: 06100, loss:0.28055, acc: 0.850
epoch: 06200, loss:0.27730, acc: 0.850
epoch: 06300, loss:0.27412, acc: 0.850
epoch: 06400, loss:0.27101, acc: 0.850
epoch: 06500, loss:0.26797, acc: 0.850
epoch: 06600, loss:0.26500, acc: 0.850
epoch: 06700, loss:0.26209, acc: 0.850
epoch: 06800, loss:0.25925, acc: 0.850
epoch: 06900, loss:0.25647, acc: 0.850
epoch: 07000, loss:0.25376, acc: 0.850
epoch: 07100, loss:0.25110, acc: 0.850
epoch: 07200, loss:0.24850, acc: 0.850
epoch: 07300, loss:0.24596, acc: 0.850
epoch: 07400, loss:0.24347, acc: 0.900
epoch: 07500, loss:0.24103, acc: 0.900
epoch: 07600, loss:0.23865, acc: 0.900
epoch: 07700, loss:0.23631, acc: 0.900
epoch: 07800, loss:0.23403, acc: 0.900
epoch: 07900, loss:0.23179, acc: 0.900
epoch: 08000, loss:0.22960, acc: 0.900
epoch: 08100, loss:0.22746, acc: 0.950
epoch: 08200, loss:0.22535, acc: 0.950
epoch: 08300, loss:0.22330, acc: 0.950
epoch: 08400, loss:0.22128, acc: 0.950
epoch: 08500, loss:0.21930, acc: 0.950
epoch: 08600, loss:0.21737, acc: 0.950
epoch: 08700, loss:0.21547, acc: 0.950
epoch: 08800, loss:0.21361, acc: 0.950
epoch: 08900, loss:0.21179, acc: 0.950
epoch: 09000, loss:0.21000, acc: 0.950
epoch: 09100, loss:0.20825, acc: 0.950
epoch: 09200, loss:0.20653, acc: 0.950
epoch: 09300, loss:0.20484, acc: 0.950
epoch: 09400, loss:0.20319, acc: 0.950
epoch: 09500, loss:0.20157, acc: 0.950
epoch: 09600, loss:0.19998, acc: 0.950
epoch: 09700, loss:0.19842, acc: 0.950
epoch: 09800, loss:0.19689, acc: 0.950
epoch: 09900, loss:0.19538, acc: 0.950
epoch: 10000, loss:0.19391, acc: 0.950
epoch: 10100, loss:0.19246, acc: 0.950
epoch: 10200, loss:0.19104, acc: 0.950
epoch: 10300, loss:0.18964, acc: 0.950
epoch: 10400, loss:0.18827, acc: 0.950
epoch: 10500, loss:0.18693, acc: 0.950
epoch: 10600, loss:0.18561, acc: 0.950
epoch: 10700, loss:0.18431, acc: 0.950
epoch: 10800, loss:0.18303, acc: 0.950
epoch: 10900, loss:0.18178, acc: 0.950
epoch: 11000, loss:0.18055, acc: 0.950
epoch: 11100, loss:0.17934, acc: 0.950
epoch: 11200, loss:0.17815, acc: 0.950
epoch: 11300, loss:0.17698, acc: 0.950
epoch: 11400, loss:0.17583, acc: 0.950
epoch: 11500, loss:0.17471, acc: 0.950
epoch: 11600, loss:0.17360, acc: 0.950
epoch: 11700, loss:0.17250, acc: 0.950
epoch: 11800, loss:0.17143, acc: 0.950
epoch: 11900, loss:0.17038, acc: 0.950
epoch: 12000, loss:0.16934, acc: 0.950
epoch: 12100, loss:0.16832, acc: 0.950
epoch: 12200, loss:0.16731, acc: 0.950
epoch: 12300, loss:0.16632, acc: 0.950
epoch: 12400, loss:0.16535, acc: 0.950
epoch: 12500, loss:0.16440, acc: 0.950
epoch: 12600, loss:0.16345, acc: 0.950
epoch: 12700, loss:0.16253, acc: 0.950
epoch: 12800, loss:0.16162, acc: 0.950
epoch: 12900, loss:0.16072, acc: 0.950
epoch: 13000, loss:0.15983, acc: 0.950
epoch: 13100, loss:0.15896, acc: 0.950
epoch: 13200, loss:0.15811, acc: 0.950
epoch: 13300, loss:0.15726, acc: 0.950
epoch: 13400, loss:0.15643, acc: 0.950
epoch: 13500, loss:0.15561, acc: 0.950
epoch: 13600, loss:0.15481, acc: 0.950
epoch: 13700, loss:0.15401, acc: 0.950
epoch: 13800, loss:0.15323, acc: 0.950
epoch: 13900, loss:0.15246, acc: 0.950
epoch: 14000, loss:0.15170, acc: 0.950
epoch: 14100, loss:0.15095, acc: 0.950
epoch: 14200, loss:0.15022, acc: 0.950
epoch: 14300, loss:0.14949, acc: 0.950
epoch: 14400, loss:0.14877, acc: 0.950
epoch: 14500, loss:0.14807, acc: 0.950
epoch: 14600, loss:0.14737, acc: 0.950
epoch: 14700, loss:0.14669, acc: 0.950
epoch: 14800, loss:0.14601, acc: 0.950
epoch: 14900, loss:0.14534, acc: 0.950
epoch: 15000, loss:0.14468, acc: 0.950
epoch: 15100, loss:0.14404, acc: 0.950
epoch: 15200, loss:0.14340, acc: 0.950
epoch: 15300, loss:0.14276, acc: 0.950
epoch: 15400, loss:0.14214, acc: 0.950
epoch: 15500, loss:0.14153, acc: 0.950
epoch: 15600, loss:0.14092, acc: 0.950
epoch: 15700, loss:0.14032, acc: 0.950
epoch: 15800, loss:0.13973, acc: 0.950
epoch: 15900, loss:0.13915, acc: 0.950
epoch: 16000, loss:0.13858, acc: 0.950
epoch: 16100, loss:0.13801, acc: 0.950
epoch: 16200, loss:0.13745, acc: 0.950
epoch: 16300, loss:0.13690, acc: 0.950
epoch: 16400, loss:0.13635, acc: 0.950
epoch: 16500, loss:0.13581, acc: 0.950
epoch: 16600, loss:0.13528, acc: 0.950
epoch: 16700, loss:0.13476, acc: 0.950
epoch: 16800, loss:0.13424, acc: 0.950
epoch: 16900, loss:0.13373, acc: 0.950
epoch: 17000, loss:0.13322, acc: 0.950
epoch: 17100, loss:0.13272, acc: 0.950
epoch: 17200, loss:0.13223, acc: 0.950
epoch: 17300, loss:0.13174, acc: 0.950
epoch: 17400, loss:0.13126, acc: 0.950
epoch: 17500, loss:0.13079, acc: 0.950
epoch: 17600, loss:0.13032, acc: 0.950
epoch: 17700, loss:0.12985, acc: 0.950
epoch: 17800, loss:0.12940, acc: 0.950
epoch: 17900, loss:0.12894, acc: 0.950
epoch: 18000, loss:0.12850, acc: 0.950
epoch: 18100, loss:0.12805, acc: 0.950
epoch: 18200, loss:0.12762, acc: 0.950
epoch: 18300, loss:0.12719, acc: 0.950
epoch: 18400, loss:0.12676, acc: 0.950
epoch: 18500, loss:0.12634, acc: 0.950
epoch: 18600, loss:0.12592, acc: 0.950
epoch: 18700, loss:0.12551, acc: 0.950
epoch: 18800, loss:0.12510, acc: 0.950
epoch: 18900, loss:0.12470, acc: 0.950
epoch: 19000, loss:0.12430, acc: 0.950
epoch: 19100, loss:0.12390, acc: 0.950
epoch: 19200, loss:0.12351, acc: 0.950
epoch: 19300, loss:0.12313, acc: 0.950
epoch: 19400, loss:0.12275, acc: 0.950
epoch: 19500, loss:0.12237, acc: 0.950
epoch: 19600, loss:0.12200, acc: 0.950
epoch: 19700, loss:0.12163, acc: 0.950
epoch: 19800, loss:0.12126, acc: 0.950
epoch: 19900, loss:0.12090, acc: 0.950
model.state_dict()
OrderedDict([('0.weight', tensor([[-1.4480, -2.8220,  2.5209,  5.8656]])),
             ('0.bias', tensor([-5.3296]))])

训练损失和测试集预测准确率曲线:

png

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK