Table of Contents
- Preface
- 1. Data Loading and Preprocessing
  - 1.1 Code Implementation
  - 1.2 Functionality Breakdown
- 2. Introduction to LSTM
  - 2.1 LSTM Principles
  - 2.2 Model Definition
    - Code Walkthrough
- 3. Training and Prediction
  - 3.1 Training Logic
    - Code Walkthrough
  - 3.2 Visualization Utilities
    - Functionality Breakdown
    - Results
- Summary
Preface
Recurrent neural networks (RNNs) and their variant, the long short-term memory (LSTM) network, excel at processing sequential data such as text and time series. In this post we will walk through a complete PyTorch implementation and learn, from scratch, how to use an LSTM for text generation. Working with H.G. Wells' *The Time Machine* as the dataset, we will go step by step through data preprocessing, model definition, training, and prediction. By pairing the code with explanations, the post aims to give you a solid understanding of how an LSTM is implemented and how it is applied in natural language processing.
The code in this post is organized into four main parts:
- Data loading and preprocessing (utils_for_data.py)
- LSTM model definition (the model section of the Jupyter notebook)
- Training and prediction logic (utils_for_train.py)
- Visualization utilities (utils_for_huitu.py)
The detailed implementation and analysis follow.
1. Data Loading and Preprocessing
First, we need to load *The Time Machine* dataset and preprocess it. Below is the complete code of utils_for_data.py, together with a description of what it does.
1.1 Code Implementation
import random
import re
import torch
from collections import Counter
def read_time_machine():
"""将时间机器数据集加载到文本行的列表中"""
with open('timemachine.txt', 'r') as f:
lines = f.readlines()
return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]
def tokenize(lines, token='word'):
"""将文本行拆分为单词或字符词元"""
if token == 'word':
return [line.split() for line in lines]
elif token == 'char':
return [list(line) for line in lines]
else:
        print(f'Error: unknown token type: {token}')
def count_corpus(tokens):
"""统计词元的频率"""
if not tokens:
return Counter()
if isinstance(tokens[0], list):
flattened_tokens = [token for sublist in tokens for token in sublist]
else:
flattened_tokens = tokens
return Counter(flattened_tokens)
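# Illustrative usage sketch (added for this post, not part of the original
# utils_for_data.py): how read_time_machine, tokenize and count_corpus fit
# together. The helper name _demo_token_counts is hypothetical.
def _demo_token_counts():
    lines = read_time_machine()              # cleaned, lower-cased text lines
    tokens = tokenize(lines, token='word')   # one list of words per line
    counter = count_corpus(tokens)           # Counter over all words in the book
    print(counter.most_common(5))            # five most frequent words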
class Vocab:
"""文本词表类,用于管理词元及其索引的映射关系"""
def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
self.tokens = tokens if tokens is not None else []
self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
counter = self._count_corpus(self.tokens)
self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
self.idx_to_token = ['<unk>'] + self.reserved_tokens
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}
for token, freq in self._token_freqs:
if freq < min_freq:
break
if token not in self.token_to_idx:
self.idx_to_token.append(token)
self.token_to_idx[token] = len(self.idx_to_token) - 1
@staticmethod
def _count_corpus(tokens):
if not tokens:
return Counter()
if isinstance(tokens[0], list):
tokens = [token for sublist in tokens for token in sublist]
return Counter(tokens)
def __len__(self):
return len(self.idx_to_token)
def __getitem__(self, tokens):
if not isinstance(tokens, (list, tuple)):
return self.token_to_idx.get(tokens, self.unk)
return [self[token] for token in tokens]
def to_tokens(self, indices):
if not isinstance(indices, (list, tuple)):
return self.idx_to_token[indices]
return [self.idx_to_token[index] for index in indices]
@property
def unk(self):
return 0
@property
def token_freqs(self):
return self._token_freqs
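# Illustrative usage sketch (added for this post, not part of the original file);
# it assumes timemachine.txt is available locally. The helper name _demo_vocab
# is hypothetical.
def _demo_vocab():
    tokens = tokenize(read_time_machine(), token='char')
    vocab = Vocab(tokens)
    print(len(vocab))                   # vocabulary size: characters plus '<unk>'
    print(vocab['t'])                   # index assigned to the character 't'
    print(vocab.to_tokens([1, 2, 3]))   # map indices back to characters
    print(vocab.token_freqs[:3])        # three most frequent (token, count) pairs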
def load_corpus_time_machine(max_tokens=-1):
    """Return the character-level token index corpus and the vocabulary of The Time Machine."""
lines = read_time_machine()
tokens = tokenize(lines, 'char')
vocab = Vocab(tokens)
corpus = [vocab[token] for line in tokens for token in line]
if max_tokens > 0:
corpus = corpus[:max_tokens]
return corpus, vocab
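# Illustrative sketch (added for this post, not in the original file):
# load_corpus_time_machine flattens the book into one list of character indices,
# so corpus[i] is the index of the i-th character. _demo_corpus is a hypothetical name.
def _demo_corpus():
    corpus, vocab = load_corpus_time_machine(max_tokens=10000)
    print(len(corpus), len(vocab))         # corpus length and vocabulary size
    print(vocab.to_tokens(corpus[:20]))    # the first 20 characters of the text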
def seq_data_iter_random(corpus, batch_size, num_steps):
    """Generate minibatches of subsequences using random sampling."""
    # Start at a random offset so different epochs see different subsequences
    offset = random.randint(0, num_steps - 1)
corpus = corpus[offset:]
num_subseqs = (len(corpus) - 1) // num_steps
initial_indices = list(range(0, num_subseqs * num_steps, num_steps))
random.shuffle(initial_indices)
def data(pos):
return corpus[pos:pos + num_steps]
num_batches = num_subseqs // batch_size
for i in range(0, batch_size * num_batches, batch_size):
initial_indices_per_batch = initial_indices[i:i + batch_size]
X = [data(j) for j in initial_indices_per_batch]
Y = [data(j + 1) for j in initial_indices_per_batch]
yield torch.tensor(X), torch.tensor(Y)
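# Illustrative sketch (added for this post, not in the original file): with random
# sampling, each X has shape (batch_size, num_steps) and Y is X shifted one position
# to the right, i.e. the next-character prediction target. _demo_random_batches is
# a hypothetical name.
def _demo_random_batches():
    corpus, vocab = load_corpus_time_machine(max_tokens=10000)
    for X, Y in seq_data_iter_random(corpus, batch_size=2, num_steps=5):
        print(X.shape, Y.shape)   # torch.Size([2, 5]) torch.Size([2, 5])
        print(X[0], Y[0])         # Y[0] follows X[0] by exactly one character
        break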
def seq_data_iter_sequential(corpus, batch_size, num_steps):
    """Generate minibatches of subsequences using sequential partitioning."""
    # Random starting offset; adjacent minibatches stay adjacent in the corpus
    offset = random.randint(0, num_steps)
num_tokens = ((len(corpus) - offset - 1) // batch_size) *