2
PyTorch 逻辑回归
source link: https://xujinzh.github.io/2023/05/30/pytorch-logistic-regression/
Go to the source link to read the original article, which has the images, the latest updates, and better formatting. If the link is broken, use the snapshot button below to view the archived copy.
PyTorch 逻辑回归
发表于 2023-05-30 | 更新于 2023-05-30 | technology, python
字数总计:2.2k|阅读时长:13分钟|阅读量:7
PyTorch 逻辑回归,数据集:UCI Iris Data Set(仅使用 Iris-versicolor 与 Iris-virginica 两类,做二分类)
导入依赖包
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import visdom
# Display the installed PyTorch version and seed PyTorch's global RNG so that
# layer weight initialization and any later torch randomness is reproducible.
torch.__version__, torch.manual_seed(33)
('1.12.1+cu102', <torch._C.Generator at 0x7f753ca60e90>)
加载数据集
path = "/workspace/disk1/datasets/scalar/iris.data.csv"
# The file has no header line: its first line is already a sample
# (5.1, 3.5, 1.4, 0.2, Iris-setosa). Read it with header=None so that sample is
# kept as data instead of being consumed as column names. The species column
# (position 4) is named "Iris-setosa" because the cells below select the label
# via data["Iris-setosa"]; the four feature columns are named 0..3.
data = pd.read_csv(path, header=None).rename(columns={4: "Iris-setosa"})
data.head(3)
5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa | |
---|---|---|---|---|---|
0 | 4.9 | 3.0 | 1.4 | 0.2 | Iris-setosa |
1 | 4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
2 | 4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
# Number of samples, then the distinct species labels present in the label column.
data.shape[0]
set(data["Iris-setosa"])
{'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'}
# One sub-frame per species, selected on the label column data["Iris-setosa"].
data_setosa, data_versicolor, data_virginica = (
    data[data["Iris-setosa"] == species]
    for species in ("Iris-setosa", "Iris-versicolor", "Iris-virginica")
)
# Per-species sample counts.
tuple(len(subset) for subset in (data_setosa, data_versicolor, data_virginica))
(49, 50, 50)
# Keep only versicolor and virginica to get a binary classification problem.
data_using = data[data["Iris-setosa"] != "Iris-setosa"]
# All columns but the last are features; the last column is the species label.
X = data_using.iloc[:, :-1]
# Encode labels explicitly with map(); chained Series.replace(str -> int) relies
# on implicit dtype downcasting that pandas has deprecated.
Y = data_using.iloc[:, -1].map({"Iris-versicolor": 0, "Iris-virginica": 1})
# Any label outside the mapping would become NaN; fail loudly instead of
# letting NaN targets reach the loss function.
if Y.isna().any():
    raise ValueError(
        f"unexpected labels: {sorted(set(data_using.iloc[:, -1][Y.isna()]))}"
    )
Y.unique()
array([0, 1])
# Convert the feature DataFrame into a float32 tensor, one row per sample.
X = torch.tensor(X.to_numpy(), dtype=torch.float32)
X.shape
torch.Size([100, 4])
# Convert the 0/1 labels into a float32 column tensor of shape (N, 1).
Y = torch.tensor(Y.to_numpy(), dtype=torch.float32).reshape(-1, 1)
Y.shape
torch.Size([100, 1])
train_set_ratio = 0.8
train_set_num = int(train_set_ratio * X.shape[0])
# The rows arrive grouped by class (all versicolor, then all virginica), so a
# plain head/tail split would leave only one class in the test set and make the
# test accuracy meaningless. Shuffle the row order once before splitting;
# torch.randperm draws from PyTorch's global RNG, so it is reproducible after
# torch.manual_seed has been called.
perm = torch.randperm(X.shape[0])
train_idx, test_idx = perm[:train_set_num], perm[train_set_num:]
X_train, Y_train = X[train_idx], Y[train_idx]
X_test, Y_test = X[test_idx], Y[test_idx]

# Logistic regression: a single linear layer followed by a sigmoid, mapping
# X.shape[1] features to Y.shape[1] probability output(s).
model = nn.Sequential(
    nn.Linear(in_features=X.shape[1], out_features=Y.shape[1]),
    nn.Sigmoid(),
)
model
Sequential(
(0): Linear(in_features=4, out_features=1, bias=True)
(1): Sigmoid()
)
# Binary cross-entropy on the sigmoid probabilities produced by the model.
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch = 20
# Round up so a trailing partial batch is still trained on when the training
# set size is not a multiple of `batch` (tensor slicing clamps an end index
# that runs past the last row).
num_of_batch = (X_train.shape[0] + batch - 1) // batch
epochs = 20000
# Use the visdom visualization server to plot training loss and test accuracy.
# NOTE(review): the credentials below are hard-coded; consider loading them from
# environment variables or a config file instead of keeping them in source.
viz = visdom.Visdom(
    server="http://localhost",
    port=8097,
    base_url="/visdom",
    username="jinzhongxu",
    password="123123",
)
win = "logistic regression loss"
opts = dict(
    title="train_losses",
    xlabel="epoch",
    ylabel="loss",
    markers=True,
    legend=["loss", "acc"],
)
# Create the two-series window with an initial (0, 0) point; the training loop
# appends to it with update="append".
viz.line(
    [
        [
            0.0,
            0.0,
        ]
    ],
    [0.0],
    win=win,
    opts=opts,
)
Setting up a new session...
'logistic regression loss'
for epoch in range(epochs):
    losses = []
    # Mini-batch gradient descent over the training split, in fixed order.
    for n in range(num_of_batch):
        start = n * batch
        end = (n + 1) * batch
        x = X_train[start:end]
        y = Y_train[start:end]
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        # .item() stores a plain Python float, detached from the autograd graph.
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch % 100 == 0:
        mean_loss = np.mean(losses)
        # Evaluation needs no gradients; no_grad avoids building an autograd graph.
        with torch.no_grad():
            test_prob = model(X_test).numpy()
        # Threshold probabilities at 0.5 and compare with the 0/1 targets.
        acc = ((test_prob > 0.5).astype(np.int8) == Y_test.numpy()).mean()
        print(
            f"epoch: {str(epoch).rjust(len(str(epochs)), '0')}, loss:{mean_loss:.5f}, acc: {acc:.3f}"
        )
        viz.line(
            [
                [
                    mean_loss,
                    acc,
                ]
            ],
            [epoch],
            win=win,
            update="append",
        )
epoch: 00000, loss:0.81945, acc: 0.000
epoch: 00100, loss:0.72729, acc: 0.000
epoch: 00200, loss:0.71364, acc: 0.000
epoch: 00300, loss:0.70016, acc: 0.000
epoch: 00400, loss:0.68695, acc: 0.000
epoch: 00500, loss:0.67402, acc: 0.000
epoch: 00600, loss:0.66138, acc: 0.000
epoch: 00700, loss:0.64902, acc: 0.000
epoch: 00800, loss:0.63694, acc: 0.000
epoch: 00900, loss:0.62514, acc: 0.000
epoch: 01000, loss:0.61361, acc: 0.000
epoch: 01100, loss:0.60233, acc: 0.000
epoch: 01200, loss:0.59132, acc: 0.100
epoch: 01300, loss:0.58056, acc: 0.250
epoch: 01400, loss:0.57005, acc: 0.350
epoch: 01500, loss:0.55979, acc: 0.400
epoch: 01600, loss:0.54976, acc: 0.450
epoch: 01700, loss:0.53998, acc: 0.450
epoch: 01800, loss:0.53042, acc: 0.500
epoch: 01900, loss:0.52110, acc: 0.600
epoch: 02000, loss:0.51199, acc: 0.650
epoch: 02100, loss:0.50311, acc: 0.650
epoch: 02200, loss:0.49444, acc: 0.700
epoch: 02300, loss:0.48598, acc: 0.700
epoch: 02400, loss:0.47773, acc: 0.700
epoch: 02500, loss:0.46968, acc: 0.700
epoch: 02600, loss:0.46183, acc: 0.700
epoch: 02700, loss:0.45417, acc: 0.700
epoch: 02800, loss:0.44670, acc: 0.700
epoch: 02900, loss:0.43942, acc: 0.700
epoch: 03000, loss:0.43232, acc: 0.700
epoch: 03100, loss:0.42539, acc: 0.700
epoch: 03200, loss:0.41863, acc: 0.700
epoch: 03300, loss:0.41204, acc: 0.700
epoch: 03400, loss:0.40561, acc: 0.750
epoch: 03500, loss:0.39935, acc: 0.800
epoch: 03600, loss:0.39324, acc: 0.800
epoch: 03700, loss:0.38728, acc: 0.800
epoch: 03800, loss:0.38146, acc: 0.800
epoch: 03900, loss:0.37579, acc: 0.800
epoch: 04000, loss:0.37026, acc: 0.800
epoch: 04100, loss:0.36487, acc: 0.800
epoch: 04200, loss:0.35961, acc: 0.800
epoch: 04300, loss:0.35447, acc: 0.800
epoch: 04400, loss:0.34947, acc: 0.800
epoch: 04500, loss:0.34458, acc: 0.800
epoch: 04600, loss:0.33982, acc: 0.800
epoch: 04700, loss:0.33516, acc: 0.800
epoch: 04800, loss:0.33063, acc: 0.800
epoch: 04900, loss:0.32620, acc: 0.850
epoch: 05000, loss:0.32187, acc: 0.850
epoch: 05100, loss:0.31765, acc: 0.850
epoch: 05200, loss:0.31353, acc: 0.850
epoch: 05300, loss:0.30951, acc: 0.850
epoch: 05400, loss:0.30559, acc: 0.850
epoch: 05500, loss:0.30175, acc: 0.850
epoch: 05600, loss:0.29801, acc: 0.850
epoch: 05700, loss:0.29435, acc: 0.850
epoch: 05800, loss:0.29078, acc: 0.850
epoch: 05900, loss:0.28729, acc: 0.850
epoch: 06000, loss:0.28388, acc: 0.850
epoch: 06100, loss:0.28055, acc: 0.850
epoch: 06200, loss:0.27730, acc: 0.850
epoch: 06300, loss:0.27412, acc: 0.850
epoch: 06400, loss:0.27101, acc: 0.850
epoch: 06500, loss:0.26797, acc: 0.850
epoch: 06600, loss:0.26500, acc: 0.850
epoch: 06700, loss:0.26209, acc: 0.850
epoch: 06800, loss:0.25925, acc: 0.850
epoch: 06900, loss:0.25647, acc: 0.850
epoch: 07000, loss:0.25376, acc: 0.850
epoch: 07100, loss:0.25110, acc: 0.850
epoch: 07200, loss:0.24850, acc: 0.850
epoch: 07300, loss:0.24596, acc: 0.850
epoch: 07400, loss:0.24347, acc: 0.900
epoch: 07500, loss:0.24103, acc: 0.900
epoch: 07600, loss:0.23865, acc: 0.900
epoch: 07700, loss:0.23631, acc: 0.900
epoch: 07800, loss:0.23403, acc: 0.900
epoch: 07900, loss:0.23179, acc: 0.900
epoch: 08000, loss:0.22960, acc: 0.900
epoch: 08100, loss:0.22746, acc: 0.950
epoch: 08200, loss:0.22535, acc: 0.950
epoch: 08300, loss:0.22330, acc: 0.950
epoch: 08400, loss:0.22128, acc: 0.950
epoch: 08500, loss:0.21930, acc: 0.950
epoch: 08600, loss:0.21737, acc: 0.950
epoch: 08700, loss:0.21547, acc: 0.950
epoch: 08800, loss:0.21361, acc: 0.950
epoch: 08900, loss:0.21179, acc: 0.950
epoch: 09000, loss:0.21000, acc: 0.950
epoch: 09100, loss:0.20825, acc: 0.950
epoch: 09200, loss:0.20653, acc: 0.950
epoch: 09300, loss:0.20484, acc: 0.950
epoch: 09400, loss:0.20319, acc: 0.950
epoch: 09500, loss:0.20157, acc: 0.950
epoch: 09600, loss:0.19998, acc: 0.950
epoch: 09700, loss:0.19842, acc: 0.950
epoch: 09800, loss:0.19689, acc: 0.950
epoch: 09900, loss:0.19538, acc: 0.950
epoch: 10000, loss:0.19391, acc: 0.950
epoch: 10100, loss:0.19246, acc: 0.950
epoch: 10200, loss:0.19104, acc: 0.950
epoch: 10300, loss:0.18964, acc: 0.950
epoch: 10400, loss:0.18827, acc: 0.950
epoch: 10500, loss:0.18693, acc: 0.950
epoch: 10600, loss:0.18561, acc: 0.950
epoch: 10700, loss:0.18431, acc: 0.950
epoch: 10800, loss:0.18303, acc: 0.950
epoch: 10900, loss:0.18178, acc: 0.950
epoch: 11000, loss:0.18055, acc: 0.950
epoch: 11100, loss:0.17934, acc: 0.950
epoch: 11200, loss:0.17815, acc: 0.950
epoch: 11300, loss:0.17698, acc: 0.950
epoch: 11400, loss:0.17583, acc: 0.950
epoch: 11500, loss:0.17471, acc: 0.950
epoch: 11600, loss:0.17360, acc: 0.950
epoch: 11700, loss:0.17250, acc: 0.950
epoch: 11800, loss:0.17143, acc: 0.950
epoch: 11900, loss:0.17038, acc: 0.950
epoch: 12000, loss:0.16934, acc: 0.950
epoch: 12100, loss:0.16832, acc: 0.950
epoch: 12200, loss:0.16731, acc: 0.950
epoch: 12300, loss:0.16632, acc: 0.950
epoch: 12400, loss:0.16535, acc: 0.950
epoch: 12500, loss:0.16440, acc: 0.950
epoch: 12600, loss:0.16345, acc: 0.950
epoch: 12700, loss:0.16253, acc: 0.950
epoch: 12800, loss:0.16162, acc: 0.950
epoch: 12900, loss:0.16072, acc: 0.950
epoch: 13000, loss:0.15983, acc: 0.950
epoch: 13100, loss:0.15896, acc: 0.950
epoch: 13200, loss:0.15811, acc: 0.950
epoch: 13300, loss:0.15726, acc: 0.950
epoch: 13400, loss:0.15643, acc: 0.950
epoch: 13500, loss:0.15561, acc: 0.950
epoch: 13600, loss:0.15481, acc: 0.950
epoch: 13700, loss:0.15401, acc: 0.950
epoch: 13800, loss:0.15323, acc: 0.950
epoch: 13900, loss:0.15246, acc: 0.950
epoch: 14000, loss:0.15170, acc: 0.950
epoch: 14100, loss:0.15095, acc: 0.950
epoch: 14200, loss:0.15022, acc: 0.950
epoch: 14300, loss:0.14949, acc: 0.950
epoch: 14400, loss:0.14877, acc: 0.950
epoch: 14500, loss:0.14807, acc: 0.950
epoch: 14600, loss:0.14737, acc: 0.950
epoch: 14700, loss:0.14669, acc: 0.950
epoch: 14800, loss:0.14601, acc: 0.950
epoch: 14900, loss:0.14534, acc: 0.950
epoch: 15000, loss:0.14468, acc: 0.950
epoch: 15100, loss:0.14404, acc: 0.950
epoch: 15200, loss:0.14340, acc: 0.950
epoch: 15300, loss:0.14276, acc: 0.950
epoch: 15400, loss:0.14214, acc: 0.950
epoch: 15500, loss:0.14153, acc: 0.950
epoch: 15600, loss:0.14092, acc: 0.950
epoch: 15700, loss:0.14032, acc: 0.950
epoch: 15800, loss:0.13973, acc: 0.950
epoch: 15900, loss:0.13915, acc: 0.950
epoch: 16000, loss:0.13858, acc: 0.950
epoch: 16100, loss:0.13801, acc: 0.950
epoch: 16200, loss:0.13745, acc: 0.950
epoch: 16300, loss:0.13690, acc: 0.950
epoch: 16400, loss:0.13635, acc: 0.950
epoch: 16500, loss:0.13581, acc: 0.950
epoch: 16600, loss:0.13528, acc: 0.950
epoch: 16700, loss:0.13476, acc: 0.950
epoch: 16800, loss:0.13424, acc: 0.950
epoch: 16900, loss:0.13373, acc: 0.950
epoch: 17000, loss:0.13322, acc: 0.950
epoch: 17100, loss:0.13272, acc: 0.950
epoch: 17200, loss:0.13223, acc: 0.950
epoch: 17300, loss:0.13174, acc: 0.950
epoch: 17400, loss:0.13126, acc: 0.950
epoch: 17500, loss:0.13079, acc: 0.950
epoch: 17600, loss:0.13032, acc: 0.950
epoch: 17700, loss:0.12985, acc: 0.950
epoch: 17800, loss:0.12940, acc: 0.950
epoch: 17900, loss:0.12894, acc: 0.950
epoch: 18000, loss:0.12850, acc: 0.950
epoch: 18100, loss:0.12805, acc: 0.950
epoch: 18200, loss:0.12762, acc: 0.950
epoch: 18300, loss:0.12719, acc: 0.950
epoch: 18400, loss:0.12676, acc: 0.950
epoch: 18500, loss:0.12634, acc: 0.950
epoch: 18600, loss:0.12592, acc: 0.950
epoch: 18700, loss:0.12551, acc: 0.950
epoch: 18800, loss:0.12510, acc: 0.950
epoch: 18900, loss:0.12470, acc: 0.950
epoch: 19000, loss:0.12430, acc: 0.950
epoch: 19100, loss:0.12390, acc: 0.950
epoch: 19200, loss:0.12351, acc: 0.950
epoch: 19300, loss:0.12313, acc: 0.950
epoch: 19400, loss:0.12275, acc: 0.950
epoch: 19500, loss:0.12237, acc: 0.950
epoch: 19600, loss:0.12200, acc: 0.950
epoch: 19700, loss:0.12163, acc: 0.950
epoch: 19800, loss:0.12126, acc: 0.950
epoch: 19900, loss:0.12090, acc: 0.950
# Inspect the learned parameters: the Linear layer's weight (one coefficient per
# input feature) and its bias.
model.state_dict()
OrderedDict([('0.weight', tensor([[-1.4480, -2.8220, 2.5209, 5.8656]])),
('0.bias', tensor([-5.3296]))])
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK