

PyTorch_Practice/rnn_demo.py at master · zhangxiann/PyTorch_Practice · GitHub
source link: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/rnn_demo.py
Go to the source link to view the article, where you can see the pictures, updated content, and better typesetting. If the link is broken, please click the button below to view the snapshot taken at that time.

PyTorch_Practice/rnn_demo.py at master · zhangxiann/PyTorch_Practice · GitHub — Permalink
234 lines (169 sloc) 6.14 KB
# -*- coding: utf-8 -*- """ RNN 实现人名分类 """ from io import open import glob import unicodedata import string import math import os import time import torch.nn as nn import torch import random import matplotlib.pyplot as plt import torch.utils.data from common_tools import set_seed import enviroments
set_seed(1) # 设置随机种子 BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cpu")
# Read a file and split into lines
def readLines(filename):
    """Read a UTF-8 text file and return its lines converted by unicodeToAscii.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list with one converted string per line of the file. Leading and
        trailing whitespace of the whole file is stripped before splitting.
    """
    # The context manager closes the handle even if reading fails; the
    # previous version left the file open until garbage collection.
    with open(filename, encoding='utf-8') as f:
        content = f.read()
    return [unicodeToAscii(line) for line in content.strip().split('\n')]
def unicodeToAscii(s):
    """Return `s` reduced to the characters listed in `all_letters`.

    The string is NFD-normalized first, so accented characters are split
    into a base character plus combining marks. Combining marks (Unicode
    category 'Mn') are dropped, as is every character that is not part of
    the module-level `all_letters` alphabet.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = []
    for ch in decomposed:
        if unicodedata.category(ch) == 'Mn':
            continue
        if ch in all_letters:
            kept.append(ch)
    return ''.join(kept)
# Find letter index from all_letters, e.g. "a" = 0
def letterToIndex(letter):
    """Return the position of `letter` in `all_letters` (-1 when absent)."""
    position = all_letters.find(letter)
    return position
# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
def letterToTensor(letter):
    """Return a one-hot <1 x n_letters> tensor for `letter`.

    Args:
        letter: A single character.

    Returns:
        A float tensor of zeros with a 1 at the letter's index. A letter
        that is not part of the alphabet yields an all-zero tensor.
    """
    tensor = torch.zeros(1, n_letters)
    index = letterToIndex(letter)
    # letterToIndex returns -1 for unknown letters; indexing with -1 would
    # silently mark the LAST alphabet entry, so unknown letters are skipped.
    if index >= 0:
        tensor[0][index] = 1
    return tensor
# Turn a line into a <line_length x 1 x n_letters>,
# or an array of one-hot letter vectors
def lineToTensor(line):
    """Encode `line` as a stack of one-hot letter vectors.

    Args:
        line: The string to encode.

    Returns:
        A float tensor of shape <len(line) x 1 x n_letters>. Row `i` has a 1
        at the index of `line[i]`; characters that are not part of the
        alphabet produce an all-zero row.
    """
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        index = letterToIndex(letter)
        # letterToIndex returns -1 for unknown characters; indexing with -1
        # would wrongly encode them as the last alphabet entry.
        if index >= 0:
            tensor[li][0][index] = 1
    return tensor
def categoryFromOutput(output):
    """Map a model output to its highest-scoring category.

    Returns:
        A `(category_name, category_index)` tuple, where the index is the
        top-1 entry of `output` and the name is looked up in
        `all_categories`.
    """
    _, best_indices = output.topk(1)
    index = best_indices[0].item()
    label = all_categories[index]
    return label, index
def randomChoice(l):
    """Return one element of the sequence `l`, picked uniformly at random."""
    # random.randint is inclusive on both ends, hence the `- 1`.
    last = len(l) - 1
    pick = random.randint(0, last)
    return l[pick]
def randomTrainingExample():
    """Draw one random (category, name) training pair.

    Returns:
        A tuple `(category, line, category_tensor, line_tensor)`:
        the category name, the sampled line, a long tensor holding the
        category's index in `all_categories`, and the one-hot encoding of
        the line produced by `lineToTensor`.
    """
    # Pick a category first, then a sample belonging to that category.
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])

    label_index = all_categories.index(category)
    category_tensor = torch.tensor([label_index], dtype=torch.long)

    # str -> one-hot
    line_tensor = lineToTensor(line)
    return category, line, category_tensor, line_tensor
def timeSince(since):
    """Format the time elapsed since `since` as 'Xm Ys'.

    Args:
        since: A timestamp as returned by `time.time()`.
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    seconds = elapsed - minutes * 60
    return '%dm %ds' % (minutes, seconds)
# Just return an output given a line
def evaluate(line_tensor):
    """Run the module-level `rnn` over an encoded line.

    Args:
        line_tensor: Tensor whose first dimension indexes the steps of the
            sequence (as produced by `lineToTensor`).

    Returns:
        The network output after the last step.

    Raises:
        ValueError: If `line_tensor` contains no steps.
    """
    if line_tensor.size()[0] == 0:
        # Without this guard the loop never runs and `output` is unbound.
        raise ValueError("evaluate() requires a non-empty line tensor")

    # Mirror train(): keep the inputs on the same device as the model.
    line_tensor = line_tensor.to(device)
    hidden = rnn.initHidden().to(device)

    for step_input in line_tensor:
        output, hidden = rnn(step_input, hidden)

    return output
def predict(input_line, n_predictions=3):
    """Print the `n_predictions` best-scoring categories for `input_line`.

    Each result is printed as '(score) category', where the score is the
    value taken from the network output.
    """
    print('\n> %s' % input_line)
    with torch.no_grad():
        output = evaluate(lineToTensor(input_line))

        # Get top N categories
        top_values, top_indices = output.topk(n_predictions, dim=1, largest=True)

        for value, category_index in zip(top_values[0], top_indices[0]):
            print('(%.2f) %s' % (value.item(), all_categories[category_index.item()]))
def get_lr(iter, learning_rate):
    """Return the learning rate to use at iteration `iter`.

    The rate is unchanged while `iter` is below the module-level `n_iters`
    and is multiplied by 0.1 from then on.
    """
    if iter < n_iters:
        return learning_rate
    return learning_rate*0.1
class RNN(nn.Module):
    """Hand-written single-step recurrent cell.

    Three linear layers are used: `u` maps the input to the hidden space,
    `w` maps the previous hidden state to the hidden space, and `v` maps
    the new hidden state to the output space.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.hidden_size = hidden_size

        # Layer creation order (u, w, v) is kept so that seeded parameter
        # initialization stays the same.
        self.u = nn.Linear(input_size, hidden_size)
        self.w = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, output_size)

        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, hidden):
        """Advance the cell by one step.

        Returns:
            `(output, hidden)`: log-softmax (over dim 1) of `v(hidden)` and
            the new hidden state `tanh(w(hidden_prev) + u(inputs))`.
        """
        input_term = self.u(inputs)
        recurrent_term = self.w(hidden)

        new_hidden = self.tanh(recurrent_term + input_term)
        log_probs = self.softmax(self.v(new_hidden))

        return log_probs, new_hidden

    def initHidden(self):
        """Return an all-zero initial hidden state of shape <1 x hidden_size>."""
        return torch.zeros(1, self.hidden_size)
def train(category_tensor, line_tensor):
    """Run one manual SGD step on a single training example.

    Uses the module-level `rnn`, `criterion`, `learning_rate` and `device`.

    Args:
        category_tensor: Tensor holding the target category index.
        line_tensor: Encoded input line; its first dimension indexes the
            steps of the sequence.

    Returns:
        `(output, loss)`: the network output after the last step and the
        loss value as a Python float.
    """
    hidden = rnn.initHidden()

    rnn.zero_grad()

    line_tensor = line_tensor.to(device)
    hidden = hidden.to(device)
    category_tensor = category_tensor.to(device)

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()

    # Add parameters' gradients to their values, multiplied by learning rate.
    # The positional `add_(scalar, tensor)` overload is deprecated in PyTorch;
    # the `alpha` keyword is the supported way to scale the added tensor.
    # no_grad keeps the in-place update out of the autograd graph.
    with torch.no_grad():
        for p in rnn.parameters():
            p.add_(p.grad, alpha=-learning_rate)

    return output, loss.item()
if __name__ == "__main__":
    # config
    path_txt = os.path.join(enviroments.names, "*.txt")
    all_letters = string.ascii_letters + " .,;'"
    n_letters = len(all_letters)  # 52 + 5: total number of characters
    print_every = 5000
    plot_every = 5000
    learning_rate = 0.005
    n_iters = 200000

    # step 1 data
    # Build the category_lines dictionary, a list of names per language
    category_lines = {}
    all_categories = []
    # glob.glob returns files in arbitrary order; sorting keeps the
    # category -> index mapping stable across runs and machines, which a
    # saved state dict depends on.
    for filename in sorted(glob.glob(path_txt)):
        category = os.path.splitext(os.path.basename(filename))[0]
        all_categories.append(category)
        lines = readLines(filename)
        category_lines[category] = lines

    n_categories = len(all_categories)
    if n_categories == 0:
        # Fail early with a clear message instead of an obscure error from
        # random sampling on an empty category list.
        raise FileNotFoundError("No name files found matching %s" % path_txt)

    # step 2 model
    n_hidden = 128
    rnn = RNN(n_letters, n_hidden, n_categories)

    rnn.to(device)

    # step 3 loss
    criterion = nn.NLLLoss()

    # step 4 optimize by hand (the parameter update is done inside train())

    # step 5 iteration
    current_loss = 0
    all_losses = []
    start = time.time()
    for iteration in range(1, n_iters + 1):
        # sample
        category, line, category_tensor, line_tensor = randomTrainingExample()

        # training
        output, loss = train(category_tensor, line_tensor)

        current_loss += loss

        # Print iter number, loss, name and guess
        if iteration % print_every == 0:
            guess, _ = categoryFromOutput(output)
            correct = '✓' if guess == category else '✗ (%s)' % category
            print('Iter: {:<7} time: {:>8s} loss: {:.4f} name: {:>10s} pred: {:>8s} label: {:>8s}'.format(
                iteration, timeSince(start), loss, line, guess, correct))

        # Add current loss avg to list of losses
        if iteration % plot_every == 0:
            all_losses.append(current_loss / plot_every)
            current_loss = 0

    # Persist the trained weights next to this script, then plot the
    # averaged training loss.
    path_model = os.path.join(BASE_DIR, "rnn_state_dict.pkl")
    torch.save(rnn.state_dict(), path_model)
    plt.plot(all_losses)
    plt.show()

    predict('Yue Tingsong')
    predict('Yue tingsong')
    predict('yutingsong')

    predict('test your name')
Recommend
-
48
RNN is a class of artificial neural network where connections between nodes form a directed graph along a sequence. It...
-
32
image credit —
-
29
字幕组双语原文: 告别RNN,迎接TCN 英语原文:
-
7
README.md 这是我用于学习 Skip-gram 的笔记。 文中会有一些公式,如果 github 出现公式乱码问题,请通过我的博客查看:h...
-
9
[PyTorch 学习笔记] 8.4 手动实现 RNN哈尔滨工程大学 计算机硕士在读 本章代码:
-
12
An attempt at implementing char-rnn with PyTorch Hello! I spent a bunch of time in the last couple of weeks implementing a version of char-rnn...
-
15
PyTorch Android Examples A list of Android demo apps built on the powerful PyTorch Mobile platform. HelloWorld
-
4
tf-rnn/sequence_example.ipynb at master · dennybritz/tf-rnn · GitHub
-
7
Figure 1 GNN Demo Using PyTorch Lightning an...
-
8
【pytorch】RNN 2022年06月14日 Author:Guofei 文章归类: 0x26_torch 文章编号: 262 ...
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK