PyTorch构建神经网络

news2025/1/12 12:09:33

【图书推荐】《PyTorch深度学习与企业级项目实战》-CSDN博客

《PyTorch深度学习与企业级项目实战(人工智能技术丛书)》(宋立桓,宋立林)【摘要 书评 试读】- 京东图书 (jd.com)

训练一个神经网络通常需要提供大量的数据,我们称之为数据集。数据集一般被分为三类,即训练集(Training Set)、验证集(Validation Set)和测试集(Test Set)。

一个Epoch就等于使用训练集中的全部样本训练一次的过程。所谓训练一次,指的是进行一次正向传播(Forward Pass)和反向传播(Backward Pass),如图5-2所示。 

图5-2

当一个Epoch的样本(也就是训练集)数量太过庞大的时候,进行一次训练可能会消耗过多的时间,并且每次训练都使用训练集的全部数据是不必要的。因此,我们需要把整个训练集分成多个小块,也就是分成多个Batch来进行训练。一个Epoch由一个或多个Batch构成,Batch为训练集的一部分,每次训练的过程只使用一部分数据,即一个Batch。我们称训练一个Batch的过程为一个Iteration。

PyTorch提供了torch.nn库,它是用于构建神经网络的工具库。torch.nn库依赖于autograd库来定义和计算梯度。nn.Module包含神经网络的层以及返回输出的forward(input)方法。

以下是一个简单的神经网络的构建示例:

#######pytorch-nn-demo1.py###############
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # 输入图像channel:1,输出channel:6,5×5卷积核
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # 全连接层
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 使用2×2窗口进行最大池化
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # 如果窗口是方的,只需要指定一个维度
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)

        x = x.view(-1, self.num_flat_features(x))

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # 获取除batch维度外的其他维度
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

net = Net()
print(net)

结果输出如下:

Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

以上就是一个简单的神经网络的构建方法。首先定义了一个Net类,这个类继承自nn.Module。然后在__init__方法中定义了网络的结构,在forward方法中定义了数据的流向。在网络的构建过程中,我们可以使用任何张量操作。需要注意的是,backward函数(用于计算梯度)会被autograd自动创建和实现。你只需要在nn.Module的子类中定义forward函数。PyTorch的一个重要功能就是autograd,它是为方便用户使用,而专门开发的一套自动求导引擎,能够根据输入和前向传播过程自动构建计算图,并执行反向传播。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2081470.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

特殊类设计(5个)与类型转换

引子:在生活中我们经常有不同类的需求,因此我们有了特殊类的设计(有很多种模式等)。由于类型需求不同我们有了类型转换。今天我们就来略讲略讲一下这方面的知识。 特殊类设计(5个) 注意:关键字…

uniapp发布包app.json文件配置及发包上传注意事项

一、分包(提示主包大小不小于1.5M) 我的分包代码 二、未开启js压缩 操作:【必须】在工具「详情」-「本地设置」中开启「上传代码时自动压缩脚本文件」的设置 三、未开启组件懒注入(按需注入) 只需将代码"lazyCodeLoading&qu…

深度剖析:黑神化悟空的防御机制,为何成为破解难题?

深度剖析:黑神化悟空的防御机制,为何成为破解难题? 黑神话悟空还没发售就被破解? #国产游戏 #悟空 #西游记 推荐阅读: 全红婵抖音魅力无限,粉丝数量历史性突破1000万大关! 奥运激情日&#…

【日记】今天实在太累了(436 字)

正文 今天的工作强度跟之前完全不是一个级别。能不能不要给我找这么多事做,我只想摸鱼摆烂。以后到下一个单位就说自己啥都不会好了,省得一天天全来找我。 忙碌程度上升了一个数量级,一天结束之后完全不想说话。 好想睡觉。 昨晚尝试完成年度…

利用ADB命令截屏,并发送至指定邮箱

需求分析: 为满足对Android设备远程监控、故障排查或自动化报告生成等应用场景需求,本指南将指导你如何利用ADB(Android Debug Bridge)工具实现Android设备的截屏功能,并介绍如何将截屏结果通过远程通知的形式发送至指定邮箱。 关键词:ADB、截屏、远程通知、发送邮箱 准备…

程序设计—气象数据共享平台设计与实现 项目源码30172

摘 要 当前,气象数据的及时获取和共享对于许多行业和个人具有重要意义。然而,存在着数据获取不便、共享不畅、数据可视化展示不足等问题。为了解决这些问题,本研究旨在设计和开发一个基于C语言的气象数据共享平台,结合React框架实…

天津国芯SP下载工具 加个防呆 避免选了OTA升级的固件(后缀带有SIG.BIN)

V2.1 20240828 天津国芯SP下载工具 加个防呆 避免选了OTA升级的固件(后缀带有SIG.BIN) 兆讯的芯片1902首次下载必须先下载key,再下载加密固件。 天津国芯没有这个限制,固件是明文的。 自测使用的版本信息: 本地最新…

count格式的数据转换(count to FPKM,count to TPM) 【GEO数据库】

在正式分析之前,对于数据的处理是至关重要的,这种重要性是体现在很多方面,其中有一点是要求分析者采用正确的数据类型。 对于芯片数据,原始数据进行log2处理之后可以进行很多常见的分析,比如差异分析、热图、箱线图、…

linux下一切皆文件,如何理解?

linux下一切皆文件,不管你有没有学过linux,都应该听过这句话,就像java的一切皆对象一样。 今天就来看看它的真面目。 你记住了,只要一个竞争退出它的PCB要被释放文件名,客服表也要被释放。那么,指向这个文件…

基于大数据的电信诈骗行为可视化系统含预测研究【lightGBM,XGBoost,随机森林】

文章目录 有需要本项目的代码或文档以及全部资源,或者部署调试可以私信博主项目介绍 电信诈骗预测与分析系统项目概述系统架构详细功能描述1. 数据预处理2. 数据可视化与分析3. 机器学习预测4. 系统集成与用户界面 技术亮点应用价值未来展望lightGBMXGBoost随机森林…

猫头虎分享:什么是信创体系?

猫头虎分享:什么是信创体系? 猫头虎技术团队:深入解析信创体系 引言:为什么信创体系是未来发展的关键? 大家好,我是猫头虎,今天我们来聊一聊科技领域的热议话题——信创体系。随着国内外信息技术产业的迅…

分布式云扩展 AI 边缘算力,助力用户智能化创新

近期,AI 创新圈再次发布重磅产品更新。OpenAI 全新旗舰版多模态模型 GPT-4o 横空出世,其打通文本、图像、视频的富媒体理解能力以及敏捷的智能化对话,将 AI 助手的人性化表达效果,提升至更高水平。 ​ 从技术源头来看&#xff0c…

栈OJ题——有效的括号

文章目录 一、题目链接二、解题思路三、解题代码 一、题目链接 有效的括号 题目描述:给定一个只包括 ‘(’,‘)’,‘{’,‘}’,‘[’,‘]’ 的字符串 s ,判断字符串是否有效。括号匹配。 二、…

《大模型应用开发极简入门》学习成为善用 AI 的人!看完懂得90%的大模型!{含pdf版电子书}

📖《大模型应用开发极简入门:基于GPT-4与ChatGPT》 真心建议学习大模型的朋友都去看看这本书,作为一本应用开发入门书,在豆瓣评分好评不断,其中知识点有不少值得深入研究的领域,适合小白初学者阅读学习的&…

【Google Maps JavaScript API】详解地图本地化(Localizing the Map)

文章目录 一、地图本地化概述1. 什么是地图本地化?2. 为什么需要地图本地化? 二、如何实现地图本地化?1. 准备工作2. 编写 HTML 文件3. 初始化地图 三、详细代码解析1. HTML 部分2. JavaScript 部分 四、如何在本地运行示例代码?五…

Spring Boot如何压缩Json并写入redis?

1.为什么需要压缩json? 由于业务需要,存入redis中的缓存数据过大,占用了10G的内存,内存作为重要资源,需要优化一下大对象缓存,采用gzip压缩存储,可以将 redis 的 kv 对大小缩小大约 7-8 倍&…

Jmeter录制脚本(不推荐,因为有大量冗余)

1、以百度举例 2、选择“Requests Filtering”,在“包含模式”中填入“.(baidu\.com).”用以过滤非http://baidu.com的请求; 同时在“排除模式”中填入“(?i).*\.(bmp|css|js|gif|ico|jpe?g|png|swf|woff|woff2|htm|html).”用以过滤js、图片、html等…

postman请求设置

postman请求设置 1、请求参数,只能是none、for-data、x-www...、raw等中的一个,不能多个。2、请求头类型3、案例4、测压 1、请求参数,只能是none、for-data、x-www…、raw等中的一个,不能多个。 2、请求头类型 根据请求头&#x…

用Python分析定性变量之间的相关性_对应分析模板

对应分析是一种多元统计分析方法,主要用于分析定性变量构成的列联表,揭示变量之间的关系。它通过将列联表中的数据转换为点的形式,在低维空间中表示出来,从而实现数据的可视化。这种方法特别适用于有多个类别的定性变量分析&#…

如何将开发工具设置成滚动鼠标改变字体大小

就在刚刚与温州那边技术开会,温州那边技术提出:字体太小,代码看不清,需要将字体放大。然后让我将IDE设置成按住键盘的Ctrl滚动鼠标,可以放大字体大小。。。顿时间的小小尴尬。下面我来记录一下究竟是怎么操作的&#x…