PyTorch之完整的神经网络模型训练

news2025/1/16 3:35:06

简单的示例:

在PyTorch中,可以使用nn.Module类来定义神经网络模型。以下是一个示例的神经网络模型定义的代码:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        
        # 定义神经网络的层和参数
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        
        return x

在上面的示例中,定义了一个名为MyModel的神经网络模型,继承自nn.Module类。在__init__方法中,我们定义了模型的层和参数。具体来说:

  • 代码定义了一个卷积层,输入通道数为1,输出通道数为32,卷积核大小为3x3,步长为1,填充为1。
  • 定义了一个ReLU激活函数,用于在卷积层之后引入非线性性质。
  • 定义了一个最大池化层,池化核大小为2x2,步长为2。
  • 定义了一个全连接层,输入大小为32x14x14(经过卷积和池化后的特征图大小),输出大小为128。
  • 定义了另一个全连接层,输入大小为128,输出大小为10。
  • 定义了一个softmax函数,用于将模型的输出转换为概率分布。

forward方法中,定义了模型的前向传播过程。具体来说:

  • x = self.conv1(x): 将输入张量传递给卷积层进行卷积操作。
  • x = self.relu(x): 将卷积层的输出通过ReLU激活函数进行非线性变换。
  • x = self.maxpool(x): 将ReLU激活后的特征图进行最大池化操作。
  • x = x.view(x.size(0), -1): 将池化后的特征图展平为一维,以适应全连接层的输入要求。
  • x = self.fc1(x): 将展平后的特征向量传递给第一个全连接层。
  • x = self.relu(x): 将第一个全连接层的输出通过ReLU激活函数进行非线性变换。
  • x = self.fc2(x): 将第一个全连接层的输出传递给第二个全连接层。
  • x = self.softmax(x): 将第二个全连接层的输出通过softmax函数进行归一化,得到每个类别的概率分布。

这个示例展示了一个简单的卷积神经网络模型,适用于处理单通道的图像数据,并输出10个类别的分类结果。可以根据自己的需求和数据特点来定义和修改神经网络模型。

接下来将用于实际的数据集进行训练:

以下是基于CIFAR10数据集的神经网络训练模型:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from nn_mode import *

#准备数据集
train_data=torchvision.datasets.CIFAR10(root='../chap4_Dataset_transforms/dataset',train=True,transform=torchvision.transforms.ToTensor())
test_data=torchvision.datasets.CIFAR10(root='../chap4_Dataset_transforms/dataset',train=False,transform=torchvision.transforms.ToTensor())
#输出数据集的长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print(train_data_size)
print(test_data_size)
#加载数据集
train_loader=DataLoader(dataset=train_data,batch_size=64)
test_loader=DataLoader(dataset=test_data,batch_size=64)
#创建神经网络
sjnet=Sjnet()

#损失函数
loss_fn=nn.CrossEntropyLoss()
#优化器
learn_lr=0.01#便于修改
YHQ=torch.optim.SGD(sjnet.parameters(),lr=learn_lr)

#设置训练网络的参数
train_step=0#训练次数
test_step=0#测试次数
epoch=10#训练轮数

writer=SummaryWriter('wanzheng_logs')

for i in range(epoch):
    print("第{}轮训练".format(i+1))
    #开始训练
    for data in train_loader:
        imgs,targets=data
        outputs=sjnet(imgs)
        loss=loss_fn(outputs,targets)
        #优化器
        YHQ.zero_grad()  # 将神经网络的梯度置零,以准备进行反向传播
        loss.backward()  # 执行反向传播,计算神经网络中各个参数的梯度
        YHQ.step()  # 调用优化器的step()方法,根据计算得到的梯度更新神经网络的参数,完成一次参数更新
        train_step =train_step+1
        if train_step%100==0:
            print('训练次数为:{},loss为:{}'.format(train_step,loss))
            writer.add_scalar('train_loss',loss,train_step)

    #开始测试
    total_loss=0
    with torch.no_grad():#上下文管理器,用于指示在接下来的代码块中不计算梯度。
        for data in test_loader:
            imgs,targets=data
            outputs = sjnet(imgs)
            loss = loss_fn(outputs, targets)#使用损失函数 loss_fn 计算预测输出与目标之间的损失。
            total_loss=total_loss+loss#将当前样本的损失加到总损失上,用于累积所有样本的损失。
    print('整体测试集上的loss:{}'.format(total_loss))
    writer.add_scalar('test_loss', total_loss, test_step)
    test_step = test_step+1

    torch.save(sjnet,'sjnet_{}.pth'.format(i))
    print("模型已保存!")

writer.close()

 其神经网络训练以及测试时的损失值使用TensorBoard进行展示,如图所示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1513692.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

sqllab第九关通关笔记

知识点: 时间盲注:利用休眠时间进行判断是否注入成功利用bp时注意把timeout时间修改一下 首先判断注入类型 构造id1/0 返回正常信息,应该是字符型注入 构造id1 返回正常信息,欸,这就怪了 构造id1 正常显示内容&am…

Lombok简介、使用、工作原理、优缺点

学习目标: 目标 Lombok简介、使用、工作原理、优缺点 学习内容: 内容 定义 Lombok 是一个 Java 库,它提供了很多有用的注解来简化 Java 代码的编写。 作用: 使用 Lombok 可以用一些简洁的注解来自动生成大量常见的 Java 代码&a…

EI会议----2024年控制技术、自动化与绿色城市规划国际学术会议(TAUGCP 2024)

EI会议----2024年控制技术、自动化与绿色城市规划国际学术会议(TAUGCP 2024) 2024 International Conference on Control Technology, Automation and Green Urban Planning(TAUGCP 2024) 大会主题:(主题包括但不限于这些,更多详细主题请咨询苏老师&am…

缓存的使用

文章目录 1.为什么要有缓存?2.缓存使用场景3.缓存分类4.缓存使用模式5.淘汰策略6.缓存的崩溃与修复7.缓存最佳实践参考文献 1.为什么要有缓存? 数据访问具有局部性,符合二八定律:80% 的数据访问集中在 20% 的数据上,这…

Nuxt3: useFetch使用过程常见一种报错

一、问题描述 先看一段代码&#xff1a; <script setup> const fetchData async () > {const { data, error } await useFetch(https://api.publicapis.org/entries);const { data: data2, error: error2 } await useFetch(https://api.publicapis.org/entries);…

展会邀约 | 加速科技将携重磅产品亮相SEMICON China 2024

SEMICON China 2024将于3月20日-3月22日在上海新国际博览中心隆重举行。展会期间&#xff0c;加速科技将携重磅产品高性能数模混合信号测试机ST2500EX、LCD Driver测试机Flex10K-L、高密度数模混合信号测试系统ST2500E、高性能数模混合信号测试系统ST2500A亮相此次行业盛会&…

[Java、Android面试]_02_HashMap的原理

本人今年参加了很多面试&#xff0c;也有幸拿到了一些大厂的offer&#xff0c;整理了众多面试资料&#xff0c;后续还会分享众多面试资料&#xff0c;感兴趣的朋友可收藏关注。由于时间有限&#xff0c;只能每天整理一点&#xff0c;分享一点儿&#xff01; 现分享如下&#xf…

vue3/vue2若依框架对比,点击新增编辑跳转到新页面(新增编辑共用代码)

vue2若依框架&#xff1a; router里面定义好&#xff0c;编辑里面添加一个id {path: /filmManagement,component: Layout,hidden: true,redirect: noredirect,children: [{path: editFilmDetail,component: () > import(/views/filmManagement/editFilmDetail),name: editFi…

【Mac】鼠标控制\移动\调整窗口大小BBT|边缘触发调整音量\切换桌面

一直在 win 习惯了通过鼠标的侧键来控制窗口的位置、大小&#xff0c;现在找到心的解决方案了&#xff0c;通过 BBT 设置侧键按下\抬起几颗。 以下解决方案的截图&#xff0c;其中还包括了其他操作优化方案&#xff1b; 滚轮配合 cmd 键调节页面大小&#xff1b;配合 option 键…

探索C++中的动态数组:实现自己的Vector容器

&#x1f389;个人名片&#xff1a; &#x1f43c;作者简介&#xff1a;一名乐于分享在学习道路上收获的大二在校生 &#x1f648;个人主页&#x1f389;&#xff1a;GOTXX &#x1f43c;个人WeChat&#xff1a;ILXOXVJE &#x1f43c;本文由GOTXX原创&#xff0c;首发CSDN&…

算法思想总结:双指针算法

一、移动零 . - 力扣&#xff08;LeetCode&#xff09; 移动零 该题重要信息&#xff1a;1、保持非0元素的相对位置。2、原地对数组进行操作 思路&#xff1a;双指针算法 class Solution { public:void moveZeroes(vector<int>& nums){int nnums.size();for(int cur…

Elasticsearch:在本地使用 Gemma LLM 对私人数据进行问答

在本笔记本中&#xff0c;我们的目标是利用 Google 的 Gemma 模型开发 RAG 系统。 我们将使用 Elastic 的 ELSER 模型生成向量并将其存储在 Elasticsearch 中。 此外&#xff0c;我们将探索语义检索技术&#xff0c;并将最热门的搜索结果作为 Gemma 模型的上下文窗口呈现。 此外…

人工智能迷惑行为大赏!

目录 人工智能迷惑行为大赏 一&#xff1a;人工智能的“幽默”瞬间 1. 图像识别出现AI的极限 2. 小批量梯度下降优化器 3. 智能聊天机器人的冰雹问题 4. 大语言模型-3经典语录 二&#xff1a;技术原理探究 1. 深度学习 2. 机器学习 3. 自然语言处理 4. 计算机视觉 三…

java八股文 笔记(持续更新中~)

1 Redis 2Mysql 3JVM 4java基础底层 5 spring 6 微服务 7.......(持续更新) One:Redis篇 1.穿透 2&#xff1a;击穿 3&#xff1a;雪崩 3 33 4:双写一致 5.持久化 2 JVM: 2&#xff1a; 3&#xff1a; 4&#xff1a; 5&#xff1a; 6&#xff1a; 7&#xff…

学生时期学习资源同步-1 第一学期结业考试题1

原创作者&#xff1a;田超凡&#xff08;程序员田宝宝&#xff09; 版权所有&#xff0c;引用请注明原作者&#xff0c;严禁复制转载

深度学习--离线数据增强

最近做项目遇见数据集背景非常单一&#xff0c;为了增加模型的返回能里&#xff0c;只能自己做一些数据增强来增加背景的多样性。代码如下&#xff1a; import numpy as np import cv2def create_mask(box, height, width):"""创建一个全零的掩码图像&#xff…

Prompt进阶2:LangGPT(构建高性能Prompt策略和技巧)--最佳实践指南

Prompt进阶2:LangGPT(构建高性能Prompt策略和技巧)–最佳实践指南 0.前言 左图右图 prompt 基本是一样的&#xff0c;差别只在提示工程这个词是否用中英文表达。我们看到&#xff0c;一词之差&#xff0c;回答质量天壤之别。为了获得理想的模型结果&#xff0c;我们需要调整设…

uniapp开发DAPP钱包应用(二) Vue + Java

上一节我们讲了如何通过vue uniapp还有web3以及需要准备的相关组件&#xff0c;来搭建了DAPP开发的环境。 这一节&#xff0c;我们来说说如何用代码来实现DAPP相关接口。 1. ethers实现类 导入组件 import { ethers , providers , utils } from "ethers"; impor…

跟着GPT学设计模式之桥接模式

说明 桥接模式&#xff0c;也叫作桥梁模式&#xff0c;英文是 Bridge Design Pattern。在 GoF 的《设计模式》一书中&#xff0c;桥接模式是这么定义的&#xff1a;“Decouple an abstraction from its implementation so that the two can vary independently。”翻译成中文就…

我真是服了!你们刚开始学习的时候也是造火箭吗?能不能有一个简单的纯纯纯html模板给我学学,真的看不懂好嘛!

做一个个人博客第一步该怎么做&#xff1f; 好多零基础的同学们不知道怎么迈出第一步。 那么&#xff0c;就找一个现成的模板学一学呗&#xff0c;毕竟我们是高贵的Ctrl c v 工程师。 但是这样也有个问题&#xff0c;那就是&#xff0c;那些模板都&#xff0c;太&#xff01;…