深度学习-pytorch_lightning框架使用实例

news2024/11/15 9:05:59

下面是我写过的一个pytorch_lightning项目的代码框架。关键代码已经省略。

模型构建

import pytorch_lightning as pl
from pytorch_lightning.plugins.io import TorchCheckpointIO as tcio
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pathlib import Path
import random

class MyModel(pl.LightningModule):
	# 参数省略了
    def __init__(self, ... learning_rate=2e-5):
        super(MyModel, self).__init__()
		...
        encoder_layer = nn.TransformerEncoderLayer()
        self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
        ...
        self.apply(self.init_weights)

    def init_weights(self, layer):
        if isinstance(layer, (nn.Linear, nn.Embedding)):
            if isinstance(layer.weight, torch.Tensor):
                layer.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(layer, nn.LayerNorm):
            layer._epsilon = 1e-12 # 一个小的常数,用于避免除以零误差。默认值为 1e-5
	
	# args是自定义的参数,这里略去
    def forward(self, args):
        # 层构建也略去
        return F.relu(dense_outputs).squeeze(dim=1).exp()

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        predicts = self(**inputs)
        loss = F.l1_loss(predicts, targets, reduction="mean")
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        predicts = self(**inputs)
        loss = F.l1_loss(predicts, targets, reduction="mean")
        self.log("val_loss", loss)
        for i in range(targets.size(0)):
            predict = predicts[i].item()
            target = targets[i].item()
            print(f"{target} => {predict} ")
        print(f"loss = {loss.item()}")

    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        predicts = self(**inputs)
        loss = F.l1_loss(predicts, targets, reduction="mean")
        for i in range(targets.size(0)):
            predict = predicts[i].item()
            target = targets[i].item()
            print(f"{target} => {predict} ")
        print(f"loss = {loss.item()}")

    #只要在training_step返回了loss,就会自动反向传播,调用lr_scheduler
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=10,
                                               eps=1e-9, verbose=True),
                "interval": "epoch",
                "monitor": "val_loss",
            },
        }

数据导入

from torch.utils.data import Dataset, DataLoader

def process_data(input_path: Path) -> List[List]:

class DnaDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


def collate_fn(batch):
    xxxx, days = zip(*batch)
    inputs = dict()
    inputs['xxxx'] = torch.tensor(base_ids, dtype=torch.long)
    ...
    # 构造一个字典
    targets = torch.tensor(days).float()
    return inputs, targets

-- use in main --

    train_data = process_data(work_path / 'train_data.txt')
    val_data = process_data(work_path / 'val_data.txt')

    train_dataset = DnaDataset(train_data)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn, drop_last=True, num_workers=64)

    val_dataset = DnaDataset(val_data)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            collate_fn=collate_fn,drop_last=True, num_workers=64)

训练

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor

checkpoint_callback = ModelCheckpoint(
    filename="{epoch:04d}-{val_loss:.6f}",
    monitor="val_loss", # 监控的变量
    verbose=True, # 冗余模式
    save_last=True, # 覆盖保存
    save_top_k=10, # 保存k个最好的模型
    mode="min", # val_loss越小越好
)
# 监控validation_step中的变量,如果不能更优则提前停止训练
early_stop_callback = EarlyStopping(
    patience=10, # 持续几个epoch没有提升,则停止训练
    min_delta=0.01, # 0.01 最小的该变量,即当监控的量的绝对值变量小于该值,则认为没有新提升
    monitor='val_loss',
    mode='min',
    verbose=True)

lr_monitor = LearningRateMonitor(logging_interval='step')

lr = 2e-04
model = MyModel(hidden_size=HIDDEN_SIZE, learning_rate=lr)
ckpt_file = "weights_last.ckpt"
ckpt_path = work_path / ckpt_file
trainer = pl.Trainer(gpus=2, accelerator='dp',
                     # resume_from_checkpoint=str(ckpt_path),
                     gradient_clip_val=2.0,
                     accumulate_grad_batches=64,
                     sync_batchnorm=True,
                     #min_epochs=40,
                     #max_epochs=2000,
                     callbacks=[checkpoint_callback, early_stop_callback, lr_monitor])

trainer.fit(model, train_loader, val_loader)

其中,在训练阶段,test_step是不会被调用的。只有训练得差不多了,我们才手动调用。

测试

import pytorch_lightning as pl
from pytorch_lightning.plugins.io import TorchCheckpointIO as tcio
import torch.multiprocessing
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import random
from model_training import MyModel, DnaDataset, collate_fn, process_data

model = MyModel(hidden_size=HIDDEN_SIZE, learning_rate=lr)
ckpt_path = "xxxx.ckpt"
trainer = pl.Trainer(resume_from_checkpoint=ckpt_path)
# tc = tcio()
# ckpt_dict = tc.load_checkpoint(path=ckpt_path)
# ckpt_dict是一个字典,里面有各种权重数据。如果想看可以看。

test_data = process_data(work_path / 'test_data.txt')
test_dataset = DnaDataset(test_data)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True,
                              collate_fn=collate_fn, drop_last=True, num_workers=64)

trainer.test(model, dataloaders=test_dataloader)

转化为onnx格式

这里的代码我跑通了,但是打出来的模型跑测试贼慢。我怀疑是模型太复杂了🚬还没想好解决办法,之后解决了再更新。

model = MyModel(hidden_size=HIDDEN_SIZE, learning_rate=lr)
ckpt_path = "last.ckpt"
trainer = pl.Trainer(resume_from_checkpoint=ckpt_path)
tc = tcio()
ckpt_dict = tc.load_checkpoint(path=ckpt_path)
model.load_state_dict(ckpt_dict["state_dict"])
model.eval()

val_data = process_data(work_path / 'val_data.txt')
val_dataset = DnaDataset(val_data)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        collate_fn=collate_fn,
                        drop_last=True, num_workers=64)
                        
# 随便生成1个数据就好。torch.onnx.export实际进行了一次推理
for d in val_loader:
    a, b = d
    break
import torch.onnx
torch.onnx.export(model, (a["x1"], a["x2"], a["x3"]),
                  "last.onnx", # 打包后的文件名
                  # 如果有多个输入,可以通过input_names指明输入变量的名称
                  input_names=["x1", "x2", "x3"],
                  )

查看模型结构

使用 https://netron.app/,我下载了win10版,打开一个onnx格式的文件,可以看到这样的:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/638108.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

师生相逢,潇潇洒洒出品

师生相逢2023年6月10日潇潇洒洒出品骑行肩并肩 相望曾相识①遥忆多年前 青春勃发时豪情梦摘星 清纯玉壶冰感叹岁月老 友情弥久新寄情山水间 天涯不负卿①骑行路上,偶遇30年前的学生,现在是骑行群的骑友,共同的兴趣爱好使我们30年后再相逢&…

【服务器数据恢复】热备盘激活失败导致raid5瘫痪的数据恢复案例

服务器数据恢复环境: 一台EMC存储中数块磁盘组建了一组raid5磁盘阵列,阵列中有2块热备盘;上层采用ZFS文件系统,划分了一个lun,供sun小机使用。 服务器故障&检测: 存储在正常运行中突然崩溃无法使用&…

STM32单片机(五)第一节:EXTI外部中断

❤️ 专栏简介:本专栏记录了从零学习单片机的过程,其中包括51单片机和STM32单片机两部分;建议先学习51单片机,其是STM32等高级单片机的基础;这样再学习STM32时才能融会贯通。 ☀️ 专栏适用人群 :适用于想要…

AST反混淆js代码——猿人学竞赛第二题

猿人学JS比赛第二题解混淆 分析代码解混淆代码结果 前一段时间参加猿人学js比赛,今天把第二题的还原做一下笔记。 分析代码 首先,我们需要对混淆js代码进行分析,确定还原的思路,才能书写解混淆代码进行还原。代码是静态的&#x…

KYOCERA Programming Contest 2023(AtCoder Beginner Contest 305)(A、B、C、D)[施工中]

文章目录 A - Water Station(模拟)B - ABCDEFG(模拟)C - Snuke the Cookie Picker(模拟、暴力)D - Sleep Log(二分,前缀) A - Water Station(模拟) 题意:在[0,100]所有 x % 5 0的地方设置一个水站&#x…

由源码生成Python可调用的dll

1. 不带参数的函数与调用 blog.csdn.net/qq 40833391/article/details/128000638python编程(python调用dll程序)_python 调用dll_嵌入式-老费的博客-CSDN博客思路很简单,只需要在Visual Studio中设置输出类型即可 1.1. 创建项目 首先创建一…

【HashMap集合】存储学生对象并遍历

HashMap集合存储学生对象并遍历 1.键是String,值是Student 需求:创建一个HashMap集合,键是学号(String),值是学生对象(Student)。存储三个键值对元素,并遍历 思路: 定义学生类 创建HashMap集合对象 创建…

uni.navigateBack()返回上一页携带参数的实现

a页面跳转b页面,再b页面点击确定返回到a页面刷新列表 //b页面confirm(){let pages getCurrentPages();//当前页面let prevPage pages[pages.length - 2];//上一个页面prevPage.setData({//直接给上一个页面赋值isRefresh: true})uni.navigateBack();} 返回到a页面…

适用于 Linux 的 Windows 子系统wsl文档

参考链接:https://learn.microsoft.com/zh-cn/windows/wsl/ 鸟哥的Linux私房菜:http://cn.linux.vbird.org/ http://cn.linux.vbird.org/linux_basic/linux_basic.php http://cn.linux.vbird.org/linux_server/ 目录 安装列出可用的 Linux 发行版列出已…

【Protobuf】Map类型的使用

文章目录 2.4 map 类型一、 基本认识二、map相关函数4.3 contact2.4的改写 2.4 map 类型 本系列文章将通过对通讯录项目的不断完善,带大家由浅入深的学习Protobuf的使用。这是Contacts的2.4版本,在这篇文章中将带大家学习Protobuf的 map 语法&#xff0c…

STM32F407软件模拟I2C实现MPU6050通讯(CUBEIDE)

STM32F407软件模拟I2C实现MPU6050通讯(CUBEIDE) 文章目录 STM32F407软件模拟I2C实现MPU6050通讯(CUBEIDE)模拟I2C读写的实现mpu6050_iic.cmpu6050_iic.h代码分析 复位,读取温度,角度等函数封装mpu6050.cmpu…

HTTPS 原理浅析及其在 Android 中的使用

1.HTTP协议的不足 HTTP1.x在传输数据时,所有传输的内容都是明文,客户端和服务器端都无法验证对方的身份,存在的问题如下: 通信使用明文(不加密),内容可能会被窃听;不验证通信方的身份,有可能遭…

Hbase-- 03

4.原理加强 4.1数据存储 4.1.1行式存储 传统的行式数据库将一个个完整的数据行存储在数据页中 4.1.2列式存储 列式数据库是将同一个数据列的各个值存放在一起 传统行式数据库的特性如下: ①数据是按行存储的。 ②没有索引的查询使用大量I/O。比如一般的数据库表…

vue3 element-plus后台管理系统实现登录与记住密码功能

一、效果 二、代码部分 1、勾选记住密码布局代码 2、判断是否勾选,勾选则保存账号密码,否则不保存账号密码,由于是demo,故并没有做加密,如果是生成最好是对密码做加密处理。 3、页面挂载的时候需要背叛的是否保存密码,…

JDK8 ConcurrentHashMap 怎么放弃 Lock 使用 synchronized 了

synchronized 之前一直都是重量级锁,但是 JDK6 中官方是对他进行过升级,引入了偏向锁,轻量级锁,重量级锁,现在采用的是锁升级的方式去做的。针对synchronized 获取锁的方式,JVM 使用了锁升级的优化方式&…

十行代码,就能真正让你理解DMA(CPU的秘书)

下面的代码是单片机串口发送数据的程序. char a0xAA;//定义变量a,值为0xAA; TXREG a;//把数据由内存转移到串口外设;那我们定义的变量a的值存储在哪里了呢?可以看下单片机的逻辑框图。 变量其实都是存在一个叫SRAM的存储器中,它…

Playwright 和 Selenium 的区别是什么?

前言 最近有不少同学问到 Playwright 和 Selenium 的区别是什么?有同学可能之前学过 selenium 了,再学一个 playwright 感觉有些多余,可能之前有项目已经是 selenium 写的了,换成 playwright 需要时间成本,并且可能有…

【支付系统】核心支付流程

支付在产品中常见的用处为购买和充值.这两种功能操作大相径庭,其中购买相对充值多了很多步骤,它需要锁商品或者库存,还需要超时未支付取消订单等操作.在这篇文章中主要探讨支付部分,属于购买和充值公共部分. 下面是绘制的简易支付时序图 以上时序图并非完整,其实核心步骤就是, …

商城购买会员打折满减优惠券商品

文章目录 前言一、代码结构二、UML图三、代码实现3.1.domain3.2.enums3.3.strategy3.4.service3.5.config 四、单元测试五、模式应用六、问题及优化思路6.1.问题6.2.优化 总结 前言 使用策略模式、工厂方法模式、单例模式实现一些购买策略,需求:商城商品…

服装库存管理系统 Mybatis+Layui+MVC+JSP【完整功能介绍+实现详情+源码】

完整源码资料 地址直达:http://t.csdn.cn/RWsGw 前言 这是大二时候写的第一个Java项目,框架基本上都没有用到、而且用到的技术很老很老。只简单使用了一个Mybatis简化数据库的操作。前端框架用的还是Layui,贼难用。闲的无聊,对这…