【深度学习】实验12 使用PyTorch训练模型

news2025/1/7 5:59:50

文章目录

  • 使用PyTorch训练模型
    • 1. 线性回归类
    • 2. 创建数据集
    • 3. 训练模型
    • 4. 测试模型
  • 附:系列文章

使用PyTorch训练模型

PyTorch是一个基于Python的科学计算库,它是一个开源的机器学习框架,由Facebook公司于2016年开源。它提供了构建动态计算图的功能,可以更自然地使用Python语言编写深度神经网络的程序,具有易于使用、灵活、高效等特点,被广泛应用于深度学习任务中。

PyTorch的核心是动态计算图(Dynamic Computational Graph),这意味着计算图是在运行时动态生成的,而不是预先编译好的。这个特点使得PyTorch具有高度的灵活性,可以更加轻松地进行实验和调试。同时,它也有一个静态计算图模块,可以用于生产环境中,提高计算效率。

另外,PyTorch的另一个特点是它的张量计算。张量是PyTorch中的核心数据结构,类似于NumPy中的数组。PyTorch支持GPU加速,可以使用GPU进行张量计算,大大提高了计算效率。同时,它也支持自动求导功能,可以自动计算张量的梯度,使得深度学习的模型训练更加便捷。

PyTorch还提供了丰富的模型库,包括经典的深度学习模型,如卷积神经网络(CNN)、循环神经网络(RNN)和生成对抗网络(GAN),以及各种领域的预训练模型,如自然语言处理(NLP)和计算机视觉(CV),可以快速搭建和训练模型。

PyTorch也具有良好的社区支持。它的文档详细且易于理解,社区提供了大量的示例和教程,可以帮助用户更好地学习和使用PyTorch。同时,PyTorch还有一个活跃的开发团队,定期发布新的版本,修复bug和增加新的特性,保证了PyTorch的稳定性和可用性。

总的来说,PyTorch是一个强大、灵活、易于使用的机器学习框架,具有良好的社区支持和广泛的应用领域,能够满足不同用户的需求。随着人工智能的不断发展,PyTorch的应用将会更加广泛。

1. 线性回归类

import torch
import numpy as np
import matplotlib.pyplot as plt
class LinearRegression(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        self.loss_function = torch.nn.MSELoss()
    
    def forward(self, x):
        out = self.linear(x)
        return out
   
    def train(self, data, model_save_path='model.path'):
        x = data["x"]
        y = data["y"]
        for epoch in range(10000):
            prediction = self.forward(x)
            loss = self.loss_function(prediction, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if epoch % 100 == 0:
                print("epoch:{}, loss is:{}".format(epoch, loss.item()))
        torch.save(self.state_dict(), "linear.pth")
    def test(self, x, model_path="linear.pth"):
        x = data["x"]
        y = data["y"]
        self.load_state_dict(torch.load(model_path))
        prediction = self.forward(x)
        plt.scatter(x.numpy(), y.numpy(), c=x.numpy())
        plt.plot(x.numpy(), prediction.detach().numpy(), color="r")
        plt.show()

该Python代码实现了一个简单的线性回归模型,并进行了训练和测试。

首先,导入了PyTorch、NumPy和Matplotlib.pyplot库。

接下来,定义了一个名为LinearRegression的类,它是一个继承自torch.nn.Module的类,因此可以利用PyTorch的自动求导和优化功能。在该类的初始化方法中,定义了一个torch.nn.Linear对象,它表示一个全连接层,输入大小为1,输出大小为1;并定义了一个torch.optim.SGD对象,它表示随机梯度下降法的优化器,学习率为0.01;以及一个torch.nn.MSELoss对象,它表示均方误差损失函数。

接下来,定义了一个名为forward的方法,它表示前向传递过程,即对输入进行线性变换,得到输出。

然后,定义了一个名为train的方法,它接受一个数据字典和一个模型保存路径作为输入。该方法首先从数据字典中获取输入数据x和输出数据y,然后进行10000次迭代训练。在每次迭代中,先将输入数据x送入模型中得到预测输出prediction,然后计算预测输出和真实输出之间的均方误差损失loss,并进行反向传播和参数优化。每100次迭代打印一次损失值。最后将模型参数保存到指定的文件路径中。

最后,定义了一个名为test的方法,它接受一个输入数据x和一个模型保存路径作为输入。该方法首先从文件中加载训练好的模型参数,然后将输入数据x送入模型中得到预测输出prediction,并将预测输出和真实输出以及输入数据可视化展示出来。

总之,这段代码实现了一个简单的线性回归模型,并可以通过train方法进行训练,通过test方法进行测试和可视化展示。

2. 创建数据集

def create_linear_data(nums_data, if_plot=False):
    x = torch.linspace(0, 1, nums_data)
    x = torch.unsqueeze(x, dim = 1)
    k = 2
    y = k * x + torch.rand(x.size())
    if if_plot:
        plt.scatter(x.numpy(), y.numpy(), c=x.numpy())
        plt.show()
    data = {"x":x, "y":y}
    return data
data = create_linear_data(300, if_plot=True)

1

3. 训练模型

model = LinearRegression()
model.train(data)
   epoch:0, loss is:3.8653182983398438
   epoch:100, loss is:0.31251025199890137
   epoch:200, loss is:0.2438090741634369
   epoch:300, loss is:0.20671892166137695
   epoch:400, loss is:0.17835141718387604
   epoch:500, loss is:0.15658551454544067
   epoch:600, loss is:0.13988454639911652
   epoch:700, loss is:0.12706983089447021
   epoch:800, loss is:0.11723710596561432
   epoch:900, loss is:0.10969242453575134
   epoch:1000, loss is:0.10390334576368332
   epoch:1100, loss is:0.09946136921644211
   epoch:1200, loss is:0.09605306386947632
   epoch:1300, loss is:0.09343785047531128
   epoch:1400, loss is:0.09143117070198059
   epoch:1500, loss is:0.0898914709687233
   epoch:1600, loss is:0.08871004730463028
   epoch:1700, loss is:0.08780352771282196
   epoch:1800, loss is:0.08710794895887375
   epoch:1900, loss is:0.08657423406839371
   epoch:2000, loss is:0.08616471290588379
   epoch:2100, loss is:0.08585048466920853
   epoch:2200, loss is:0.08560937643051147
   epoch:2300, loss is:0.08542437106370926
   epoch:2400, loss is:0.08528240770101547
   epoch:2500, loss is:0.08517350256443024
   epoch:2600, loss is:0.08508992940187454
   epoch:2700, loss is:0.08502580225467682
   epoch:2800, loss is:0.08497659116983414
   epoch:2900, loss is:0.08493883907794952
   epoch:3000, loss is:0.08490986377000809
   epoch:3100, loss is:0.08488764613866806
   epoch:3200, loss is:0.08487057685852051
   epoch:3300, loss is:0.08485749363899231
   epoch:3400, loss is:0.08484745025634766
   epoch:3500, loss is:0.08483975380659103
   epoch:3600, loss is:0.08483383059501648
   epoch:3700, loss is:0.08482930809259415
   epoch:3800, loss is:0.08482582122087479
   epoch:3900, loss is:0.08482315391302109
   epoch:4000, loss is:0.08482109755277634
   epoch:4100, loss is:0.08481952548027039
   epoch:4200, loss is:0.08481831848621368
   epoch:4300, loss is:0.08481740206480026
   epoch:4400, loss is:0.08481667935848236
   epoch:4500, loss is:0.08481614291667938
   epoch:4600, loss is:0.08481571823358536
   epoch:4700, loss is:0.08481539785861969
   epoch:4800, loss is:0.08481515198945999
   epoch:4900, loss is:0.08481497317552567
   epoch:5000, loss is:0.08481481671333313
   epoch:5100, loss is:0.08481471240520477
   epoch:5200, loss is:0.08481462299823761
   epoch:5300, loss is:0.08481455594301224
   epoch:5400, loss is:0.08481451123952866
   epoch:5500, loss is:0.08481448143720627
   epoch:5600, loss is:0.08481443673372269
   epoch:5700, loss is:0.08481442183256149
   epoch:5800, loss is:0.0848143994808197
   epoch:5900, loss is:0.0848143920302391
   epoch:6000, loss is:0.08481437712907791
   epoch:6100, loss is:0.08481436222791672
   epoch:6200, loss is:0.08481435477733612
   epoch:6300, loss is:0.08481435477733612
   epoch:6400, loss is:0.08481435477733612
   epoch:6500, loss is:0.08481435477733612
   epoch:6600, loss is:0.08481435477733612
   epoch:6700, loss is:0.08481435477733612
   epoch:6800, loss is:0.08481434732675552
   epoch:6900, loss is:0.08481435477733612
   epoch:7000, loss is:0.08481433987617493
   epoch:7100, loss is:0.08481435477733612
   epoch:7200, loss is:0.08481433987617493
   epoch:7300, loss is:0.08481433987617493
   epoch:7400, loss is:0.08481434732675552
   epoch:7500, loss is:0.08481434732675552
   epoch:7600, loss is:0.08481434732675552
   epoch:7700, loss is:0.08481434732675552
   epoch:7800, loss is:0.08481434732675552
   epoch:7900, loss is:0.08481434732675552
   epoch:8000, loss is:0.08481434732675552
   epoch:8100, loss is:0.08481434732675552
   epoch:8200, loss is:0.08481434732675552
   epoch:8300, loss is:0.08481434732675552
   epoch:8400, loss is:0.08481434732675552
   epoch:8500, loss is:0.08481434732675552
   epoch:8600, loss is:0.08481434732675552
   epoch:8700, loss is:0.08481434732675552
   epoch:8800, loss is:0.08481434732675552
   epoch:8900, loss is:0.08481434732675552
   epoch:9000, loss is:0.08481434732675552
   epoch:9100, loss is:0.08481434732675552
   epoch:9200, loss is:0.08481434732675552
   epoch:9300, loss is:0.08481434732675552
   epoch:9400, loss is:0.08481434732675552
   epoch:9500, loss is:0.08481434732675552
   epoch:9600, loss is:0.08481434732675552
   epoch:9700, loss is:0.08481434732675552
   epoch:9800, loss is:0.08481434732675552
   epoch:9900, loss is:0.08481434732675552
model.test(data)

4. 测试模型

2

附:系列文章

序号文章目录直达链接
1波士顿房价预测https://want595.blog.csdn.net/article/details/132181950
2鸢尾花数据集分析https://want595.blog.csdn.net/article/details/132182057
3特征处理https://want595.blog.csdn.net/article/details/132182165
4交叉验证https://want595.blog.csdn.net/article/details/132182238
5构造神经网络示例https://want595.blog.csdn.net/article/details/132182341
6使用TensorFlow完成线性回归https://want595.blog.csdn.net/article/details/132182417
7使用TensorFlow完成逻辑回归https://want595.blog.csdn.net/article/details/132182496
8TensorBoard案例https://want595.blog.csdn.net/article/details/132182584
9使用Keras完成线性回归https://want595.blog.csdn.net/article/details/132182723
10使用Keras完成逻辑回归https://want595.blog.csdn.net/article/details/132182795
11使用Keras预训练模型完成猫狗识别https://want595.blog.csdn.net/article/details/132243928
12使用PyTorch训练模型https://want595.blog.csdn.net/article/details/132243989
13使用Dropout抑制过拟合https://want595.blog.csdn.net/article/details/132244111
14使用CNN完成MNIST手写体识别(TensorFlow)https://want595.blog.csdn.net/article/details/132244499
15使用CNN完成MNIST手写体识别(Keras)https://want595.blog.csdn.net/article/details/132244552
16使用CNN完成MNIST手写体识别(PyTorch)https://want595.blog.csdn.net/article/details/132244641
17使用GAN生成手写数字样本https://want595.blog.csdn.net/article/details/132244764
18自然语言处理https://want595.blog.csdn.net/article/details/132276591

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1030891.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Spatial-Temporal Action Localization(七)】论文阅读2022年

文章目录 1. TubeR: Tubelet Transformer for Video Action Detection摘要和结论引言:针对痛点和贡献模型框架TubeR Encoder:TubeR Decoder:Task-Specific Heads: 2. Holistic Interaction Transformer Network for Action Detect…

少儿编程 2023年5月中国电子学会图形化编程等级考试Scratch编程三级真题解析(判断题)

2023年5月scratch编程等级考试三级真题 判断题(共10题,每题2分,共20分) 26、运行下列程序后,变量c的值是6 答案:错 考点分析:考查积木综合使用,重点考查变量积木的使用 最后一步c设为a+b,所以c=1+2=3,答案错误 27、变量a与变量b的初始值都是1,a+b等于2。运行下列…

【2023华为杯B题】DFT类矩阵的整数分解逼近(思路及代码下载)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

ETHERNET IP站转CCLKIE协议网关

产品介绍 JM-EIP-CCLKIE是自主研发的一款 ETHERNET/IP 从站功能的通讯网关。该产品主要功能是实现 CCLINK IEFB 总线和 ETHERNET/IP 网络的数据互通。 本网关连接到 ETHERNET/IP 总线中做为从站使用,连接到 CCLINK IEFB 总线中做为从站使用。 产品参数 技术参数 …

A : DS顺序表--类实现

Description 实现顺序表的用C语言和类实现顺序表 属性包括&#xff1a;数组、实际长度、最大长度&#xff08;设定为1000&#xff09; 操作包括&#xff1a;创建、插入、删除、查找 类定义参考 #include<iostream> using namespace std; #define ok 0 #define error…

Unity实现角色受到攻击后屏幕抖动的效果

文章目录 实现效果摄像机抖动脚本 玩家受伤其他文章 实现效果 首先看一下实现效果。 摄像机 我们要使用屏幕抖动&#xff0c;使用的是CinemachineVirtualCamera这个组件&#xff0c;这个组件需要在包管理器中进行导入。 导入这个组件之后&#xff0c;创建一个Chinemachine-…

学习记忆——宫殿篇——记忆宫殿——记忆桩——单间+客厅+厨房+厕所+书房+院子

文章目录 单间客厅厨房厕所书房院子 单间 水壶 水龙头 香皂 果汁机 电视 门空间 花 红酒 葡萄 不锈钢 白毛沙发 彩色垫子 吉他 皮椅 挂画 风扇 糖抱枕 盒子 花土 水晶腿 衣柜 笔 三环相框 水壶 壁挂 台灯 被 网球拍 足球 抽屉 闹钟 蝴蝶 心 斑马 三轮车 音响 椅子 碗 玩偶 烟灰…

Android 12 源码分析 —— 应用层 六(StatusBar的UI创建和初始化)

Android 12 源码分析 —— 应用层 六&#xff08;StatusBar的UI创建和初始化) 在前面的文章中,我们分别介绍了Layout整体布局,以及StatusBar类的初始化.前者介绍了整体上面的布局,后者介绍了三大窗口的创建的入口处,以及需要做的准备工作.现在我们分别来细化三大窗口的UI创建和…

苹果手机怎么录屏?1分钟轻松搞定

虽然一直使用苹果手机&#xff0c;但是对它的录屏功能还不是很会使用。苹果手机怎么录屏&#xff1f;录屏可以录制声音吗&#xff1f;麻烦大家教教我&#xff01; 苹果手机为用户提供了十分便捷的内置录屏功能&#xff0c;可以让您随时随地录制手机上的内容。但是很多小伙伴在第…

六角形锌饼的尺寸及其允许偏差

声明 本文是学习GB-T 3610-2010 电池锌饼. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了电池锌饼的产品分类、要求、试验方法、检验规则及标志、包装、运输、贮存、质量证明 书和合同(或订货单)等内容。 本标准适用于制造锌-…

如何把.mhd和.raw文件转换为DICOM文件

之前拿到体渲染的人头数据Manix&#xff0c;格式为mhd和raw格式的需要转换为DICOM ResearchGate上的一个帖子帮了大忙&#xff08;链接如下&#xff09;&#xff0c;有人说用ImageJ&#xff0c;有的说用XMedCon。我试了半天也没用ImageJ弄成功&#xff0c;但是XMedCon一下就好…

【有关mysql的实操记录】

一. 导入导出数据 1. 导出mysql的数据库作为备份文件 mysqldump -u 用户名 -p 数据库名 >导出文件路径.sql 回车之后&#xff0c;提示输入密码. 2. 导入mysql之前备份的数据库文件 mysql -u 用户名 -p 数据库名 <导入文件路径.sql 回车之后&#xff0c;提示输入密码 …

总结分析 | 基于phpmyadmin的渗透测试

一、什么是phpmyadmin&#xff1f; phpMyAdmin 是一个以PHP为基础&#xff0c;以Web-Base方式架构在网站主机上的MySQL的数据库管理工具&#xff0c;让管理者可用Web接口管理MySQL数据库。借由此Web接口可以成为一个简易方式输入繁杂SQL语法的较佳途径&#xff0c;尤其要处理大…

CG-78静力水准仪采用压力传感器测量液体的压差

CG-78静力水准仪采用压力传感器测量液体的压差产品概述 静力水准仪是测量两点间或多点间相对高程变化的仪器。由储液器、高精度芯体和特别定制电路模块、保护罩等部件组成。沉降系统由多个同型号传感器组成&#xff0c;储液罐之间由通气管和通液管相连通&#xff0c;基准点置于…

循环神经网络——下篇【深度学习】【PyTorch】【d2l】

文章目录 6、循环神经网络6.7、深度循环神经网络6.7.1、理论部分6.7.2、代码实现 6.8、双向循环神经网络6.8.1、理论部分6.8.2、代码实现 6.9、机器翻译6.9.1、理论部分 6.10、编码器解码器架构6.10.1、理论部分 6、循环神经网络 6.7、深度循环神经网络 6.7.1、理论部分 设计…

瑞慈医疗:H1体检业务同比上涨101.2%,因何领跑医疗健康行业?

悄然间&#xff0c;医疗健康行业碰上历史性变革。水面之上&#xff0c;医院体检医院体检人潮涌动&#xff0c;愈来愈多的医院迈上扩建体检中心的步伐&#xff0c;赛道激增 20%为所有科室之首。水面之下&#xff0c;依靠信息技术使体检数字化、智能化的转型浪潮&#xff0c;也在…

TypeError: res.data.map is not a function微信小程序报错

从数据库查&#xff1a; 调用的是&#xff1a; 访问的springboot后端是这个&#xff1a; 打印出来如下&#xff1a; 看到是json格式的数据 [Users [id3, name刘雨昕, phone18094637788, admintrue, actionsJsonadmin, createAtSat Sep 16 10:11:20 CST 2023, tokentest], User…

小节9:Python之numpy

numpy全称为Numerical Python&#xff0c;是很多数据或科学相关Python包的基础。 1、numpy数组&#xff08;ND array N维数组&#xff09; numpy数组是更适合数据分析的列表。 numpy的数组和Python的内置列表有相似之处&#xff0c;也有不同之处。 相似之处&#xff1a;我们…

面向对象进阶

文章目录 面向对象进阶一.static1.静态变量2.静态方法3.static的注意事项 二.继承1.概述2.特点3.子类可以继承父类中的内容4.继承中成员变量的访问特点5.继承中成员方法的访问特点6.继承中构造方法的访问特点7.this和super使用总结 三.多态1.认识多态2.多态中调用成员的特点3.多…

简单的手机电脑无线传输方案@固定android生成ftp的IP地址(android@windows)

文章目录 abstractwindows浏览android文件环境准备客户端软件无线网络链接步骤其他方法 手机浏览电脑文件公网局域网everythingpython http.server 高级:固定android设备IP准备检查模块是否生效 windows 访问ftp服务器快捷方式命令行方式双击启动方式普通快捷方式映射新的网络位…