0基础学习PyTorch——最小Demo

news2024/9/23 5:35:07

大纲

  • 环境准备
    • 安装依赖
  • 训练和推理
    • 训练
      • 生成数据
      • 加载数据
        • TensorDataset
        • DataLoader
      • 定义神经网络
      • 定义损失函数和优化器
      • 训练模型
    • 推理
  • 参考代码

PyTorch以其简洁直观的API、动态计算图和强大的社区支持,在学术界和工业界都享有极高的声誉,成为许多深度学习爱好者和专业人士的首选工具。本系列将更多从工程实践角度探索PyTorch的使用,而不是算法公式的讨论。

环境准备

使用《管理Python虚拟环境的脚本》中的脚本初始化一个虚拟环境。

source env.sh init

在这里插入图片描述

然后进入虚拟环境

source env.sh enter

在这里插入图片描述

安装依赖

source env.sh install pyyaml

在这里插入图片描述

source env.sh install torch

这个过程比较漫长,需要下载一个多G文件。
在这里插入图片描述

source env.sh install numpy

在这里插入图片描述

训练和推理

训练就是模型训练。我们可以认为知道系统的输入和输出(目标),猜测系统中的算法。在这里插入图片描述

推理则是使用模型,计算出对应的输出。

在这里插入图片描述
举一个例子,也是我们后面代码的例子。假如我们使用f(x)=2x+1计算一批随机数(输入)得到一批计算结果(目标),然后我们将这些数据交给一个模型训练器,可以得到一个模型。这个模型的计算结果(输出)应该非常近似于f(x)=2x+1。

训练

生成数据

数据生成不是必须的,因为我们从其他地方获取数据。为了让这个例子没有太多依赖项,我们就自己生成数据。
input_data 是100个随机数数组,它是模型训练的“输入”数据;target_data是对input_data使用f(x)=2x+1得到的一个数组,它是模型训练的“目标”数据。即我们需要模型要将“输入”尽量转换成接近的“目标”。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 生成一些随机数据
torch.manual_seed(0)
input_data = torch.randn(100, 1)  # 100个样本,每个样本有2个特征
target_data = 2 * input_data + 1  # 简单的线性关系

加载数据

# 创建数据加载器
dataset = TensorDataset(input_data, target_data)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
TensorDataset

TensorDataset 是一个简单的数据集封装类,用于将多个张量(tensors)包装在一起。它的主要作用是将特征和标签数据对齐,以便于后续的数据加载和处理。

主要功能:

  • 将多个张量封装成一个数据集。
  • 使得数据集可以通过索引访问。
DataLoader

DataLoader 是一个数据加载器类,用于将数据集分批次加载到模型中进行训练或评估。它的主要作用是提供一个迭代器,能够高效地加载数据,并支持多线程并行加载。

主要功能:

  • 将数据集分批次加载。
  • 支持多线程并行加载数据。
  • 支持数据的随机打乱(shuffle)。
  • 提供一个迭代器,方便在训练循环中使用。

定义神经网络

定义神经网络是深度学习模型开发的核心步骤之一。一个良好定义的神经网络可以有效地学习和泛化数据,从而在各种任务中取得优异的表现。
本文不过度讨论神经网络,只是抛砖引玉,让大家知道结构长什么样子。

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = nn.Linear(1, 1)  # 输入和输出都是1维

    def forward(self, x):
        return self.linear(x)

定义损失函数和优化器

# 实例化模型、损失函数和优化器
model = SimpleNet()
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器

损失函数用于衡量模型预测值与真实值之间的差异。它是模型优化的目标,模型训练的目的是最小化损失函数的值

优化器用于更新模型参数,以最小化损失函数。

训练模型

对于有限的数据,我们可以通过增加训练次数来优化模型。所以下面代码,我们对一个数据集进行了20次训练。

num_epochs = 20
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()  # 清空梯度
        outputs = model(inputs)  # 前向传播
        loss = criterion(outputs, targets)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

后面会单独开一篇文章来将训练过程。

推理

我们再生成一个输入数组,并计算期望的值。

# 生成一组测试输入数据
test_input = torch.randn(10, 1)  # 生成10个随机样本,每个样本1个特征
test_output = 2 * test_input + 1  # 期望的输出

然后用模型算出结果,并进行比较

# 推理
model.eval()
with torch.no_grad():
    output = model(test_input)
    for i in range(len(test_input)):
        print(f'''Test Input: {test_input[i].item()}, 
    Test Output: {output[i].item()}, 
    Actual Output: {test_output[i].item()}, 
    Diff: {output[i].item() - test_output[i].item()}, 
    Loss: {abs(output[i].item() - test_output[i].item()) / abs(test_output[i].item())* 100:.2f}%\n''')

在这里插入图片描述
在这里插入图片描述
我们看到,模型最后推理出的结果和我们的期望值误差在2%以内。

参考代码

https://github.com/f304646673/deeplearning/tree/main/mvp

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2156711.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

yum 集中式安装 LNMP

目录 安装 nginx 安装 mysql 安装 php 配置lnmp 配置 nginx 支持 PHP 解析 安装 nginx 修改yum源 将原本的yum源备份 vim /etc/yum.repos.d/nginx.repo [nginx-stable] namenginx stable repo baseurlhttp://nginx.org/packages/centos/7/$basearch/ gpgcheck0 enable…

黎巴嫩BP机爆炸事件启示录:我国应加快供应链安全立法

据报道,当地时间9月17日下午,黎巴嫩首都贝鲁特以及黎巴嫩东南部和东北部多地都发生了BP机爆炸事件。当时的统计数据显示,爆炸造成9人死亡,约2800人受伤。9月18日,死亡人数上升到11人,受伤人数超过4000。 目…

14年数据结构

第一题 解析: 求时间复杂度就是看程序执行了多少次。 假设最外层执行了k次,我们看终止条件是kn,则: 有, 内层是一个j1到jn的循环,显然执行了n次。 总的时间复杂度是内层外层 答案选C。 第二题 解析: 一步一…

车辆行人转向意图状态检测系统源码分享

车辆行人转向意图状态检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of …

【Python】Maya:为人类打造的 Python 日期时间库

不知道少了什么,总感觉没有以前快乐。 在编程中处理日期和时间总是一个挑战,尤其是当涉及到时间和时区的转换时。Maya 是一个由 Kenneth Reitz 开发的 Python 库,旨在简化日期时间的处理,使其对人类开发者更加友好。本文将介绍 M…

如何在jupyter notebook中使用虚拟环境

一:在cmd中打开已经创建好的虚拟环境 二:安装ipykernel conda install ipykernel 三:安装牛逼conda conda install -c conda-forge nb_conda 四:运行jupyter notebook,选择虚拟环境

linux强制关闭再启动后zookeeper无法启动

1、若开启了zkserver就先关闭zkserver 查看zkserver是否启动 sh zkServer.sh status关闭zkServer sh zkServer.sh stop2、更改conf/zoo.cfg 将这里的启动端口改为2183 3、启动zkServer sh zkServer.sh start4、以2183端口启动zkCli zkCli.sh -server 127.0.0.1:2183这样启…

传知代码-基于多尺度动态卷积的图像分类

代码以及视频讲解 本文所涉及所有资源均在传知代码平台可获取 概述 在计算机视觉领域,图像分类是非常重要的任务之一。近年来,深度学习的兴起极大提升了图像分类的精度和效率。本文将介绍一种基于动态卷积网络(Dynamic Convolutional Netw…

机器人机构、制造

简单整理一下,在学习了一些运动学和动力学之类的东西,简单的整合了一些常用的机械结构和图片。 1.电机: 市面上的电机有:直流电机,交流电机,舵机,步进电机,电缸,无刷电…

【无人机设计与控制】 基于matlab的蚁群算法优化无人机uav巡检

摘要 本文使用蚁群算法(ACO)优化无人机(UAV)巡检路径。无人机巡检任务要求高效覆盖特定区域,以最小化能源消耗和时间。本研究提出的算法通过仿生蚁群算法优化巡检路径,在全局搜索和局部搜索中平衡探索与开…

【软件工程】成本效益分析

一、成本分析目的 二、成本估算方法 三、成本效益分析方法 课堂小结 例题 选择题

深度之眼(三十)——pytorch(一)--深入浅出pytorch(附安装流程)

文章目录 一、前言一、pytoch二、六个部分三、如何学习四、学习路径(重要)五、安装pytorch5.1 坑15.2 坑2 一、前言 我看了下目录 第一章和第二章都是本科学的数字图像处理。 也就是这一专栏:数字图像实验。 所以就不准备学习前两章了,直接…

一文详解大语言模型Transformer结构

目录 1. 什么是Transformer 2. Transformer结构 2.1 总体结构 2.2 Encoder层结构 2.3 Decoder层结构 2.4 动态流程图 3. Transformer为什么需要进行Multi-head Attention 4. Transformer相比于RNN/LSTM,有什么优势?为什么? 5. 为什么说Transf…

MySQL --数据类型

文章目录 1.数据类型分类2.数值类型2.1 tinyint类型2.2 bit类型2.3小数类型2.31float2.32decimal 3.字符串类型3.1 char3.2varchar3.3 char和varchar比较 4.日期和时间类型5.enum和set 1.数据类型分类 2.数值类型 2.1 tinyint类型 数值越界测试: create table tt1…

C++ Qt 之 QPushButton 好看的样式效果实践

文章目录 1.前序2.效果演示3.代码如下 1.前序 启发于 edge 更新 web 页面,觉得人家做的体验挺好 决定在Qt实现,方便以后使用 2.效果演示 特性介绍: 默认蓝色鼠标移入 渐变色,鼠标变为小手鼠标移出 恢复蓝色,鼠标恢…

计算机毕业设计之:基于uni-app的校园活动信息共享系统设计与实现(三端开发,安卓前端+网站前端+网站后端)

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…

006——队列

队列: 一种受限的线性表(线性逻辑结构),只允许在一段进行添加操作,在另一端只允许进行删除操作,中间位置不可操作,入队的一端被称为队尾,出队的一端被称为队头,在而我们…

作业报告┭┮﹏┭┮(Android反调试)

一:Android反调试 主要是用来防止IDA进行附加的,主要的方法思路就是,判断自身是否有父进程,判断是否端口被监听,然后通过调用so文件中的线程进行监视,这个线程开启一般JNI_OnLoad中进行开启的。但是这个是…

Java语言程序设计基础篇_编程练习题**18.31 (替换单词)

目录 题目:**18.31 (替换单词) 习题思路 代码示例 运行结果 替换前 替换后 题目:**18.31 (替换单词) 编写一个程序,递归地用一个新单词替换某个目录下的所有文件中出现的某个单词。从命令行如下传递参数: java Exercise18…

C++标准库双向链表 list 中的insert函数实现。

CPrimer中文版(第五版): //运行时错误:迭代器表示要拷贝的范围,不能指向与目的位置相同的容器 slist.insert(slist.begin(),slist.begin(),slist.end()); 如果我们传递给insert一对迭代器,它们不能…