从代码学习深度学习 - LSTM PyTorch版

news2025/4/9 1:52:42

文章目录

  • 前言
  • 一、数据加载与预处理
    • 1.1 代码实现
    • 1.2 功能解析
  • 二、LSTM介绍
    • 2.1 LSTM原理
    • 2.2 模型定义
      • 代码解析
  • 三、训练与预测
    • 3.1 训练逻辑
      • 代码解析
    • 3.2 可视化工具
      • 功能解析
      • 功能结果
  • 总结


前言

深度学习中的循环神经网络(RNN)及其变种长短期记忆网络(LSTM)在处理序列数据(如文本、时间序列等)方面表现出色。本篇博客将通过一个完整的PyTorch实现,带你从零开始学习如何使用LSTM进行文本生成任务。我们将基于H.G. Wells的《时间机器》数据集,逐步展示数据预处理、模型定义、训练与预测的全过程。通过代码和文字的结合,帮助你深入理解LSTM的实现细节及其在自然语言处理中的应用。

本文的代码分为四个主要部分:

  1. 数据加载与预处理(utils_for_data.py
  2. LSTM模型定义(Jupyter Notebook中的模型部分)
  3. 训练与预测逻辑(utils_for_train.py
  4. 可视化工具(utils_for_huitu.py

以下是详细的实现与解析。


一、数据加载与预处理

首先,我们需要加载《时间机器》数据集并进行预处理。以下是utils_for_data.py中的完整代码及其功能说明。

1.1 代码实现

import random
import re
import torch
from collections import Counter

def read_time_machine():
    """将时间机器数据集加载到文本行的列表中"""
    with open('timemachine.txt', 'r') as f:
        lines = f.readlines()
    return [re.sub('[^A-Za-z]+', ' ', line).strip().lower() for line in lines]

def tokenize(lines, token='word'):
    """将文本行拆分为单词或字符词元"""
    if token == 'word':
        return [line.split() for line in lines]
    elif token == 'char':
        return [list(line) for line in lines]
    else:
        print(f'错误:未知词元类型:{
     token}')

def count_corpus(tokens):
    """统计词元的频率"""
    if not tokens:
        return Counter()
    if isinstance(tokens[0], list):
        flattened_tokens = [token for sublist in tokens for token in sublist]
    else:
        flattened_tokens = tokens
    return Counter(flattened_tokens)

class Vocab:
    """文本词表类,用于管理词元及其索引的映射关系"""
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        self.tokens = tokens if tokens is not None else []
        self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
        counter = self._count_corpus(self.tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        self.idx_to_token = ['<unk>'] + self.reserved_tokens
        self.token_to_idx = {
   token: idx for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    @staticmethod
    def _count_corpus(tokens):
        if not tokens:
            return Counter()
        if isinstance(tokens[0], list):
            tokens = [token for sublist in tokens for token in sublist]
        return Counter(tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self[token] for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self):
        return 0

    @property
    def token_freqs(self):
        return self._token_freqs

def load_corpus_time_machine(max_tokens=-1):
    lines = read_time_machine()
    tokens = tokenize(lines, 'char')
    vocab = Vocab(tokens)
    corpus = [vocab[token] for line in tokens for token in line]
    if max_tokens > 0:
        corpus = corpus[:max_tokens]
    return corpus, vocab

def seq_data_iter_random(corpus, batch_size, num_steps):
    offset = random.randint(0, num_steps - 1)
    corpus = corpus[offset:]
    num_subseqs = (len(corpus) - 1) // num_steps
    initial_indices = list(range(0, num_subseqs * num_steps, num_steps))
    random.shuffle(initial_indices)

    def data(pos):
        return corpus[pos:pos + num_steps]

    num_batches = num_subseqs // batch_size
    for i in range(0, batch_size * num_batches, batch_size):
        initial_indices_per_batch = initial_indices[i:i + batch_size]
        X = [data(j) for j in initial_indices_per_batch]
        Y = [data(j + 1) for j in initial_indices_per_batch]
        yield torch.tensor(X), torch.tensor(Y)

def seq_data_iter_sequential(corpus, batch_size, num_steps):
    offset = random.randint(0, num_steps)
    num_tokens = ((len(corpus) - offset - 1) // batch_size) *

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2329560.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据技术发展与应用趋势分析

大数据技术发展与应用趋势分析 文章目录 大数据技术发展与应用趋势分析1. 大数据概述2 大数据技术架构2.1 数据采集层2.2 数据存储层2.3 数据处理层2.4 数据分析层 3 大数据发展趋势3.1 AI驱动的分析与自动化3.2 隐私保护分析技术3.3 混合云架构的普及3.4 数据网格架构3.5 量子…

与Linux操作系统相关的引导和服务

目录 一.Linux操作系统引导过程 1.1引导过程总览 1.2系统初始化进程 1.2.1init进程 1.2.2sysmted 1.3systemd单元类型 二.排除启动类故障 2.1MBR扇区故障 2.1.1故障原因 2.1.2故障现象 2.1.3解决办法 2.1.4模拟修复MBR扇区故障 1)添加新的硬盘 2&#xff09;进行…

STM32单片机入门学习——第16节: [6-4] PWM驱动LED呼吸灯PWM驱动舵机PWM驱动直流电机

写这个文章是用来学习的,记录一下我的学习过程。希望我能一直坚持下去,我只是一个小白,只是想好好学习,我知道这会很难&#xff0c;但我还是想去做&#xff01; 本文写于&#xff1a;2025.04.05 STM32开发板学习——第16节: [6-4] PWM驱动LED呼吸灯&PWM驱动舵机&PWM驱…

基础框架系列分享:一个通用的Excel报表生成管理框架

由于我们系统经常要生成大量的Excel报表&#xff08;Word&#xff0c;PDF报表也有&#xff0c;另行分享&#xff09;&#xff0c;最初始他们的方案是&#xff0c;设计一个表&#xff0c;和Excel完全对应&#xff0c;然后读表&#xff0c;把数据填进去&#xff0c;这显然是非常不…

Ansible(4)—— Playbook

目录 一、Ansible Playbook : 1、Play &#xff1a; 2、Playbook&#xff1a; 二、Ansible Playbook 格式&#xff1a; 1、空格&#xff1a; 2、破折号&#xff08; - &#xff09;&#xff1a; 3、Play 格式&#xff1a; 三、查找用于任务的模块&#xff1a; 1、模块…

自学-C语言-基础-数组、函数、指针、结构体和共同体、文件

这里写自定义目录标题 代码环境&#xff1a;&#xff1f;问题思考&#xff1a;一、数组二、函数三、指针四、结构体和共同体五、文件问题答案&#xff1a; 代码环境&#xff1a; Dev C &#xff1f;问题思考&#xff1a; 把上门的字母与下面相同的字母相连&#xff0c;线不能…

蓝桥云客--团队赛

2.团队赛【算法赛】 - 蓝桥云课 问题描述 蓝桥杯最近推出了一项团队赛模式&#xff0c;要求三人组队参赛&#xff0c;并规定其中一人必须担任队长。队长的资格很简单&#xff1a;其程序设计能力值必须严格大于其他两名队友程序设计能力值的总和。 小蓝、小桥和小杯正在考虑报名…

C-S模式之实现一对一聊天

天天开心&#xff01;&#xff01;&#xff01; 文章目录 一、如何实现一对一聊天&#xff1f;1. 服务器设计2. 客户端设计3. 服务端代码实现4. 客户端代码实现5. 实现说明6.实验结果 二、改进常见的服务器高并发方案1. 多线程/多进程模型2. I/O多路复用3. 异步I/O&#xff08;…

[Deep-ML]Transpose of a Matrix(矩阵的转置)

Transpose of a Matrix&#xff08;矩阵的转置&#xff09; 题目链接&#xff1a; Transpose of a Matrix&#xff08;矩阵的转置&#xff09;https://www.deep-ml.com/problems/2 题目描述&#xff1a; 难度&#xff1a; easy&#xff08;简单&#xff09;。 分类&#…

智慧节能双突破 强力巨彩谷亚VK系列刷新LED屏使用体验

当前全球节能减排趋势明显&#xff0c;LED节能屏作为显示技术的佼佼者&#xff0c;正逐渐成为市场的新宠。强力巨彩谷亚万境VK系列节能智慧屏凭借三重技术保障、四大智能设计以及大师臻彩画质&#xff0c;在实现节能效果的同时&#xff0c;更在智慧显示领域树立新的标杆。   …

html 给文本两端加虚线自适应

效果图&#xff1a; <div class"separator">文本 </div>.separator {width: 40%;border-style: dashed;display: flex;align-items: center;color: #e2e2e2;font-size: 14px;line-height: 20px;border-color: #e2e2e2;border-width: 0; }.separator::bef…

leetcode4.寻找两个正序数组中的中位数

思路源于 LeetCode004-两个有序数组的中位数-最优算法代码讲解 基本思路是将两个数组看成一个数组&#xff0c;然后划分为两个部分&#xff0c;若为奇数左边部分个数多1&#xff0c;若为偶数左边部分等于右边部分个数。i表示数组1划分位置&#xff08;i为4是索引4也表示i的左半…

0101安装matplotlib_numpy_pandas-报错-python

文章目录 1 前言2 报错报错1&#xff1a;ModuleNotFoundError: No module named distutils报错2&#xff1a;ERROR:root:code for hash blake2b was not found.报错3&#xff1a;**ModuleNotFoundError: No module named _tkinter**报错4&#xff1a;UserWarning: Glyph 39044 …

OSCP - Proving Grounds- SoSimple

主要知识点 wordpress 插件RCE漏洞sudo -l shell劫持 具体步骤 依旧是nmap 起手&#xff0c;只发现了22和80端口&#xff0c;但80端口只能看到一张图 Nmap scan report for 192.168.214.78 Host is up (0.46s latency). Not shown: 65533 closed tcp ports (reset) PORT …

C语言求3到100之间的素数

一、代码展示 二、运行结果 三、感悟思考 注意: 这个题思路他是一个试除法的一个思路 先进入一个for循环 遍历3到100之间的数字 第二个for循环则是 判断他不是素数 那么就直接退出 这里用break 是素数就打印出来 在第一个for循环内 第二个for循环外

【2025】物联网发展趋势介绍

目录 物联网四层架构感知识别层网络构建层管理服务层——**边缘存储**边缘计算关键技术&#xff1a;综合应用层——信息应用 物联网四层架构 综合应用层&#xff1a;信息应用 利用获取的信息和知识&#xff0c;支持各类应用系统的运转 管理服务层&#xff1a;信息处理 对数据进…

如何查看 MySQL 的磁盘空间使用情况:从表级到数据库级的分析

在日常数据库管理中&#xff0c;了解每张表和每个数据库占用了多少磁盘空间是非常关键的。这不仅有助于我们监控数据增长&#xff0c;还能为性能优化提供依据。 Google Gemini中国版调用Google Gemini API&#xff0c;中国大陆优化&#xff0c;完全免费&#xff01;https://ge…

汇编学习之《移位指令》

这章节学习前需要回顾之前的标志寄存器的内容&#xff1a; 汇编学习之《标志寄存器》 算数移位指令 SAL (Shift Arithmetic Left)算数移位指令 : 左移一次&#xff0c;最低位用0补位&#xff0c;最高位放入EFL标志寄存器的CF位&#xff08;进位标志&#xff09; OllyDbg查看…

Nature Communications上交、西湖大学、复旦大学研发面向机器人多模式运动的去电子化刚弹耦合高频自振荡驱动单元

近年来&#xff0c;轻型仿生机器人因其卓越的运动灵活性与环境适应性受到国际机器人领域的广泛关注。然而&#xff0c;现有气动驱动器普遍受限于低模量粘弹性材料的回弹滞后效应与能量耗散特性&#xff0c;加之其"非刚即柔"的二元结构设计范式&#xff0c;难以同时满…

对备忘录模式的理解

对备忘录模式的理解 一、场景1、题目【[来源](https://kamacoder.com/problempage.php?pid1095)】1.1 题目描述1.2 输入描述1.3 输出描述1.4 输入示例1.5 输出示例 2、理解需求 二、不采用备忘录设计模式1、代码2、问题3、错误的备忘录模式 三、采用备忘录设计模式1、代码1.1 …