1
PyTorch 多层感知机
source link: https://xujinzh.github.io/2023/05/30/pytorch-mlp/
Visit the source link to read the original article, which includes the images, the latest updates, and better typesetting. If the link is broken, click the button below to view the snapshot taken at the time.
PyTorch 多层感知机
发表于 2023-05-30 | 更新于 2023-05-30 | technology | python
字数总计:1.5k|阅读时长:9分钟|阅读量:1
PyTorch 多层感知机,数据集来自:nivedithabhandary/HR-Analytics
数据预处理
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
torch.manual_seed(33)
<torch._C.Generator at 0x7fabe84e8e90>
path = "/workspace/disk1/datasets/scalar/HR_comma_sep.csv"
data = pd.read_csv(path)
data.head(3)
satisfaction_level | last_evaluation | number_project | average_montly_hours | time_spend_company | Work_accident | left | promotion_last_5years | sales | salary | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0.38 | 0.53 | 2 | 157 | 3 | 0 | 1 | 0 | sales | low |
1 | 0.80 | 0.86 | 5 | 262 | 6 | 0 | 1 | 0 | sales | medium |
2 | 0.11 | 0.88 | 7 | 272 | 4 | 0 | 1 | 0 | sales | medium |
data = data.rename(columns={"sales": "part"})
# data.head(3)
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14999 entries, 0 to 14998
Data columns (total 10 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 satisfaction_level 14999 non-null float64
1 last_evaluation 14999 non-null float64
2 number_project 14999 non-null int64
3 average_montly_hours 14999 non-null int64
4 time_spend_company 14999 non-null int64
5 Work_accident 14999 non-null int64
6 left 14999 non-null int64
7 promotion_last_5years 14999 non-null int64
8 part 14999 non-null object
9 salary 14999 non-null object
dtypes: float64(2), int64(6), object(2)
memory usage: 1.1+ MB
data.part.unique()
array(['sales', 'accounting', 'hr', 'technical', 'support', 'management',
'IT', 'product_mng', 'marketing', 'RandD'], dtype=object)
data.salary.unique()
array(['low', 'medium', 'high'], dtype=object)
data.groupby(["salary", "part"]).size()
salary part
high IT 83
RandD 51
accounting 74
hr 45
management 225
marketing 80
product_mng 68
sales 269
support 141
technical 201
low IT 609
RandD 364
accounting 358
hr 335
management 180
marketing 402
product_mng 451
sales 2099
support 1146
technical 1372
medium IT 535
RandD 372
accounting 335
hr 359
management 225
marketing 376
product_mng 383
sales 1772
support 942
technical 1147
dtype: int64
# One-hot encode the two categorical columns and drop the originals.
# dtype=int keeps the indicator columns as numeric 0/1: pandas >= 2.0 would
# otherwise emit bool columns, and a frame mixing bool with numeric columns
# converts to an object array that torch.from_numpy cannot consume.
data = data.join(pd.get_dummies(data.salary, dtype=int))
del data["salary"]
data = data.join(pd.get_dummies(data.part, dtype=int))
del data["part"]
data.head(3)
satisfaction_level | last_evaluation | number_project | average_montly_hours | time_spend_company | Work_accident | left | promotion_last_5years | high | low | ... | IT | RandD | accounting | hr | management | marketing | product_mng | sales | support | technical | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.38 | 0.53 | 2 | 157 | 3 | 0 | 1 | 0 | 0 | 1 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
1 | 0.80 | 0.86 | 5 | 262 | 6 | 0 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
2 | 0.11 | 0.88 | 7 | 272 | 4 | 0 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
3 rows × 21 columns
data.columns
Index(['satisfaction_level', 'last_evaluation', 'number_project',
'average_montly_hours', 'time_spend_company', 'Work_accident', 'left',
'promotion_last_5years', 'high', 'low', 'medium', 'IT', 'RandD',
'accounting', 'hr', 'management', 'marketing', 'product_mng', 'sales',
'support', 'technical'],
dtype='object')
data.shape
(14999, 21)
data.left.unique()
array([1, 0])
# The classes are imbalanced
data.left.value_counts()
0 11428
1 3571
Name: left, dtype: int64
data.left.value_counts() / len(data)
0 0.761917
1 0.238083
Name: left, dtype: float64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# "left" is the binary target; every other column is a feature.
# Converting to float32 explicitly guarantees a homogeneous numeric array even
# if some columns are boolean (e.g. one-hot indicators on pandas >= 2.0),
# which would otherwise yield an object array that torch.from_numpy rejects.
Y_data = data.left.to_numpy(dtype=np.float32).reshape(-1, 1)
X_data = data[[c for c in data.columns if c != "left"]].to_numpy(dtype=np.float32)

# Train / validation split: 80% of the rows are used for training.
X_train, X_test, Y_train, Y_test = train_test_split(
    X_data, Y_data, train_size=0.8, random_state=33
)

# Move every split onto the selected device as float32 tensors.
X_train_ = torch.from_numpy(X_train).float().to(device)
Y_train_ = torch.from_numpy(Y_train).float().to(device)
X_test_ = torch.from_numpy(X_test).float().to(device)
Y_test_ = torch.from_numpy(Y_test).float().to(device)

# Only the training split is wrapped in a dataset for mini-batching.
HRdataset = TensorDataset(X_train_, Y_train_)
import torch.nn.functional as F
import visdom
from torch import nn, optim
from torch.utils.data import DataLoader
class MLPModel(nn.Module):
    """Three-layer perceptron for binary classification.

    Two ReLU-activated hidden layers are followed by a single sigmoid output
    unit, so the forward pass yields one value in [0, 1] per input row.

    Args:
        in_features: Number of input features per sample. Defaults to 20,
            the number of non-target columns after one-hot encoding.
        hidden_features: Width of both hidden layers. Defaults to 64.
    """

    def __init__(self, in_features: int = 20, hidden_features: int = 64) -> None:
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # initialisation stays reproducible with the default sizes.
        self.linear_1 = nn.Linear(in_features, hidden_features)
        self.linear_2 = nn.Linear(hidden_features, hidden_features)
        self.linear_3 = nn.Linear(hidden_features, 1)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Map ``inputs`` of shape ``(..., in_features)`` to ``(..., 1)`` probabilities."""
        hidden = F.relu(self.linear_1(inputs))
        hidden = F.relu(self.linear_2(hidden))
        # Sigmoid output pairs with nn.BCELoss, which expects probabilities.
        return torch.sigmoid(self.linear_3(hidden))
model = MLPModel()
model.to(device)
model
MLPModel(
(linear_1): Linear(in_features=20, out_features=64, bias=True)
(linear_2): Linear(in_features=64, out_features=64, bias=True)
(linear_3): Linear(in_features=64, out_features=1, bias=True)
)
# Binary cross-entropy on probabilities (the model already applies a sigmoid),
# optimised with Adam at a small learning rate.
loss_fn = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
# Use visdom to visualise how the loss changes during training.
# NOTE(review): the server credentials below are hard-coded in the source;
# consider loading them from the environment before sharing this notebook.
win = "Multilayer Perceptron loss"
opts = {
    "title": "train_losses",
    "xlabel": "epoch",
    "ylabel": "loss",
    "markers": True,
    "legend": ["loss"],
}
viz = visdom.Visdom(
    server="http://localhost",
    port=8097,
    base_url="/visdom",
    username="jinzhongxu",
    password="123123",
)
# Seed the window with a single (0, 0) point so later calls can append to it.
viz.line([[0.0]], [0.0], win=win, opts=opts)
Setting up a new session...
'Multilayer Perceptron loss'
batch = 64
epochs = 1000
HRdataloader = DataLoader(HRdataset, batch_size=batch, shuffle=True)

for epoch in range(epochs):
    # One pass of mini-batch gradient descent over the training set.
    for x, y in HRdataloader:
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Every 10th epoch, report and plot the loss on the held-out split.
    if epoch % 10 == 0:
        # Switch to eval mode for the measurement; a no-op for this model but
        # keeps the loop correct if dropout / batch-norm layers are added.
        model.eval()
        with torch.no_grad():
            val_loss = loss_fn(model(X_test_), Y_test_).item()
        model.train()
        print(f"epoch: {epoch}, loss:{val_loss}")
        viz.line([[val_loss]], [epoch], win=win, update="append")
epoch: 0, loss:0.5702827572822571
epoch: 10, loss:0.5380138754844666
epoch: 20, loss:0.48325881361961365
epoch: 30, loss:0.43280646204948425
epoch: 40, loss:0.40077832341194153
epoch: 50, loss:0.3730575442314148
epoch: 60, loss:0.3515182137489319
epoch: 70, loss:0.333321213722229
epoch: 80, loss:0.3213912844657898
epoch: 90, loss:0.3078565299510956
epoch: 100, loss:0.3066919445991516
epoch: 110, loss:0.28822797536849976
epoch: 120, loss:0.2822719216346741
epoch: 130, loss:0.2698499262332916
epoch: 140, loss:0.2673708200454712
epoch: 150, loss:0.2626533508300781
epoch: 160, loss:0.25525325536727905
epoch: 170, loss:0.25129374861717224
epoch: 180, loss:0.2505483329296112
epoch: 190, loss:0.23713809251785278
epoch: 200, loss:0.24216558039188385
epoch: 210, loss:0.22742848098278046
epoch: 220, loss:0.2271704226732254
epoch: 230, loss:0.2252519577741623
epoch: 240, loss:0.2193305790424347
epoch: 250, loss:0.21628950536251068
epoch: 260, loss:0.2177349030971527
epoch: 270, loss:0.2127998024225235
epoch: 280, loss:0.21056963503360748
epoch: 290, loss:0.2106262445449829
epoch: 300, loss:0.2119056135416031
epoch: 310, loss:0.21744157373905182
epoch: 320, loss:0.2134646326303482
epoch: 330, loss:0.20614351332187653
epoch: 340, loss:0.21047110855579376
epoch: 350, loss:0.20926110446453094
epoch: 360, loss:0.21647755801677704
epoch: 370, loss:0.21862877905368805
epoch: 380, loss:0.20591312646865845
epoch: 390, loss:0.2010645568370819
epoch: 400, loss:0.2050144225358963
epoch: 410, loss:0.1988362967967987
epoch: 420, loss:0.2004001885652542
epoch: 430, loss:0.1973736584186554
epoch: 440, loss:0.22496183216571808
epoch: 450, loss:0.19971159100532532
epoch: 460, loss:0.19692614674568176
epoch: 470, loss:0.19598537683486938
epoch: 480, loss:0.1948203295469284
epoch: 490, loss:0.19280420243740082
epoch: 500, loss:0.19459198415279388
epoch: 510, loss:0.19615624845027924
epoch: 520, loss:0.19460928440093994
epoch: 530, loss:0.19126252830028534
epoch: 540, loss:0.1938866376876831
epoch: 550, loss:0.19190427660942078
epoch: 560, loss:0.18747270107269287
epoch: 570, loss:0.18757325410842896
epoch: 580, loss:0.1988757699728012
epoch: 590, loss:0.19326147437095642
epoch: 600, loss:0.20158174633979797
epoch: 610, loss:0.18407107889652252
epoch: 620, loss:0.1833522766828537
epoch: 630, loss:0.18116755783557892
epoch: 640, loss:0.1835246980190277
epoch: 650, loss:0.18026278913021088
epoch: 660, loss:0.1771889179944992
epoch: 670, loss:0.1763589233160019
epoch: 680, loss:0.17852289974689484
epoch: 690, loss:0.17640434205532074
epoch: 700, loss:0.17527474462985992
epoch: 710, loss:0.18816079199314117
epoch: 720, loss:0.17597466707229614
epoch: 730, loss:0.17061764001846313
epoch: 740, loss:0.1716785877943039
epoch: 750, loss:0.16623708605766296
epoch: 760, loss:0.16821976006031036
epoch: 770, loss:0.163313090801239
epoch: 780, loss:0.1642996072769165
epoch: 790, loss:0.1681404858827591
epoch: 800, loss:0.16040602326393127
epoch: 810, loss:0.1622639298439026
epoch: 820, loss:0.16376478970050812
epoch: 830, loss:0.15491259098052979
epoch: 840, loss:0.15919749438762665
epoch: 850, loss:0.15590602159500122
epoch: 860, loss:0.15480470657348633
epoch: 870, loss:0.1528528481721878
epoch: 880, loss:0.1514061689376831
epoch: 890, loss:0.15681587159633636
epoch: 900, loss:0.1545121818780899
epoch: 910, loss:0.14895710349082947
epoch: 920, loss:0.14539872109889984
epoch: 930, loss:0.1600637137889862
epoch: 940, loss:0.1435800939798355
epoch: 950, loss:0.14283594489097595
epoch: 960, loss:0.14213287830352783
epoch: 970, loss:0.1424899399280548
epoch: 980, loss:0.14058111608028412
epoch: 990, loss:0.1395450234413147
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK