【深度学习】神经网络实战分类与回归任务

news2025/1/24 8:14:45

第一步 读取数据

①导入torch

import torch

②使用魔法命令,使它使得生成的图形直接嵌入到 Notebook 的单元格输出中,而不是弹出新的窗口来显示图形

%matplotlib inline

③读取文件

from pathlib import Path
import requests

DATA_PATH=Path("data")
PATH = DATA_PATH/"mnist"
PATH.mkdir(parents=True,exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH/FILENAME).exists():
    content = requests.get(URL+FILENAME).content
    (PATH/FILENAME).open("wb").write(content)

④使用 gzippickle 模块加载一个压缩的 pickle 文件 (mnist.pkl.gz)

(PATH / FILENAME).as_posix():将 Path 对象转换为 POSIX 路径字符串,适用于跨平台环境。

import pickle
import gzip

with gzip.open((PATH/FILENAME).as_posix(),"rb") as f:
    ((x_train,y_train),(x_valid,y_valid),_) = pickle.load(f,encoding="latin-1")

第二步 主体部分

①自定义神经网络模型

import torch.nn.functional as F
from torch import nn

class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784,128)
        self.hidden2 = nn.Linear(128,256)
        self.out = nn.Linear(256,10)
        
    def forward(self,x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)
        return x

②定义获取数据的方法

shuffle代表洗牌

def get_data(train_ds,valid_ds,bs):
    return (
        DataLoader(train_ds,batch_size=bs,shuffle=True),
        DataLoader(valid_ds,batch_size=bs*2)
    )

③定义获取模型的方法

torch.optim 是 PyTorch 中用于定义各种优化算法的模块

from torch import optim

def get_model():
    model = Mnist_NN()
    return model,optim.Adam(model.parameters(),lr=0.001)

④定义损失函数

注1:调用model(xb)时会自动进行前向计算(forward pass)

这是因为PyTorch的nn.Module类(即所有神经网络模型的基类)内部实现了对__call__方法的重载。当通过实例化一个继承自nn.Module的类来创建对象时,并调用该对象(如果model(xb)),实际上是调用了这个对象的__call__方法。而__cacll__方法负责调用forward方法

注2:F.entropy是PyTorch中用于计算交叉熵损失的函数,位于torch.nn.functional模块中

它结合了log_softmax和nll_loss(负对数似然损失),使得在分类任务中可以直接使用,而无需显示地应用log_softmax

F提供了很多用于构建神经网络的方法,包括激活函数、损失函数、卷积操作、池化操作等

注3

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    代码段解析

这一段用于优化模型。它包含了反向传播(计算梯度)和参数更新的过程,是模型训练的核心步骤。

loss.backward()

        反向传播:调用backward()方法会根据损失函数对模型参数进行自动求导,计算每个参数的梯度,这些梯度将被存储在对应的参数张量的.grad属性中

反向传播是基于链式法则自动计算所有参数相对于损失的偏导数的过程

这一步骤对于更新模型参数至关重要,因为它提供了调整参数所需的方向信息

opt.step()

        参数更新:调用step()方法会使用之前计算的梯度来更新模型参数。具体的更新规则取决于所使用的优化算法(如SGD、Adam等),并且可能涉及到学习率、动量等超参数

opt.zero_grad()

        清除梯度:调用zero_grad()方法会将所有参数的梯度重置为零。这是必要的,因为PyTorch默认会累积梯度,而不是每次前向传播后自动清除它们。

如果不重置梯度,旧的梯度将会与新的梯度相加,导致不正确的梯度值,进而影响参数更新的效果。

通常在每次迭代结束时调用此方法以确保下一次前向传播时梯度是从零开始计算的。

loss_func = F.cross_entropy

def loss_batch(model,loss_func,xb,yb,opt=None):
    loss = loss_func(model(xb),yb)
    
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()
        
    return loss.item(),len(xb)

⑤定义训练函数

import numpy as np

def fit(steps,model,loss_func,opt,train_dl,valid_dl):
    for step in range(steps):
        model.train()
        for xb,yb in train_dl:
            loss_batch(model,loss_func,xb,yb,opt)
            
        model.eval()
        with torch.no_grad():
            losses,nums = zip(
                *[loss_batch(model,loss_func,xb,yb) for xb,yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses,nums))/np.sum(nums)
        print("当前step:"+str(step)+",验证集损失:"+str(val_loss))

第三步 运行

①使用 Python 的内置 map() 函数,结合 PyTorch 的 torch.tensor 方法,将 x_train, y_train, x_valid, 和 y_valid 转换为 PyTorch 张量。这一步骤是数据预处理的一部分,确保所有数据都以张量的形式存储,从而可以直接用于 PyTorch 模型的训练和评估。

②加载数据集和数据

③加载模型

④训练,评估

x_train,y_train,x_valid,y_valid = map(torch.tensor,(x_train,y_train,x_valid,y_valid))

train_ds = TensorDataset(x_train,y_train)
valid_ds = TensorDataset(x_valid,y_valid)
bs=64

train_dl,valid_dl = get_data(train_ds,valid_ds,bs)
model,opt = get_model()
fit(25,model,loss_func,opt,train_dl,valid_dl)

运行结果:

测试训练精度:

correct = 0
total = 0
for xb,yb in valid_dl:
    outputs = model(xb)
    _,predicted = torch.max(outputs.data,1)
    total += yb.size(0)
    correct += (predicted==yb).sum().item()
    
print("准确率为: %d %%" % (100*correct/total))


至此,该实战完成!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2281306.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

翻译:How do I reset my FPGA?

文章目录 背景翻译:How do I reset my FPGA?1、Understanding the flip-flop reset behavior2、Reset methodology3、Use appropriate resets to maximize utilization4、Many options5、About the author 背景 在写博客《复位信号的同步与释放(同步复…

Linux调试器-gdb的使用简介

1、背景 程序的发布方式有两种,debug模式(给程序员用的)和release模式(给用户用的)Linux gcc/g出来的二进制程序,默认是release模式要使用gdb调试,必须在源代码生成二进制程序的时候,加上 -g 选项 注:debug模式产生的…

通过 Visual Studio Code 启动 IPython

在Visual Studio Code 中,你可以使用内置的终端来启动 ipython,当然首先要安装好ipython。 安装ipython的方法是在cmd里面输入以下命令安装: pip install ipython 启动ipython的步骤如下: 打开 VSCode 终端: 在 VSCo…

019:什么是 Resnet50 神经网络

本文为合集收录,欢迎查看合集/专栏链接进行全部合集的系统学习。 合集完整版请查看这里。 在上一节中,使用了一个简单的神经网络进行识别数字。 这个网络结构非常简单,一是因为层数少,二是因为结构是顺序的,没有其他…

微信小程序获取位置服务

wx.getLocation({type: gcj02,success(res) {wx.log(定位成功);},fail(err) {wx.log(定位失败, err);wx.showModal({content: 请打开手机和小程序中的定位服务,success: (modRes) > {if (modRes.confirm) {wx.openSetting({success(setRes) {if (setRes.authSetting[scope.u…

煤矿场景下拖链检测数据集VOC+YOLO格式21407张1类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):21407 标注数量(xml文件个数):21407 标注数量(txt文件个数):2140…

Charles 4.6.7 浏览器网络调试指南:HTTPS抓包(三)

概述 在现代互联网应用中,网络请求和响应是服务交互的核心。对于开发者和测试人员来说,能够准确捕获并分析这些请求,是保证系统稳定性和性能的关键。Charles作为一个强大的网络调试工具,不仅可以捕获普通的HTTP请求,还…

第五天 Labview数据记录(5.1 INI配置文件读写)

5.1 INI配置文件读写 INI配置文件是一种简单的文本文件,通常用于存储软件的配置信息。它具有以下作用: 存储软件配置参数方便软件的维护和更新提高软件的灵活性和可扩展性便于用户修改和共享配置 5.1.1 前面板 1)新建项目SaveData_Exampl…

1905电影网中国地区电影数据分析(一) - 数据采集、清洗与存储

文章目录 前言一、数据采集步骤及python库使用版本1. python库使用版本2. 数据采集步骤 二、数据采集网页分析1. 分析采集的字段和URL1.1 分析要爬取的数据字段1.2 分析每部电影的URL1.2 分析每页的URL 2. 字段元素标签定位 三、数据采集代码实现1. 爬取1905电影网分类信息2. 爬…

qml Dialog详解

1、概述 Dialog是QML(Qt Modeling Language)中用于显示对话框的组件,它提供了一个模态窗口,通常用于与用户进行重要交互,如确认操作、输入信息或显示警告等。Dialog组件具有灵活的布局和样式选项,可以轻松…

开关电源设计(1)--电感和伏秒平衡

电感(Inductor)是电子电路中用于存储磁场能量的被动元件,其核心特性是阻碍电流的变化。当电流通过导线时,周围会产生磁场,电感是衡量导线(或线圈)存储磁场能量能力的物理量。 先认识几个公式 …

Blazo-Blazor Web App项目结构

让我们还是从创建项目开始,来一起了解下Blazor Web App的项目情况 创建项目 呈现方式 这里我们可以看到需要选择项目的呈现方式,有以上四种呈现方式 ● WebAssembly ● Server ● Auto(Server and WebAssembly) ● None 纯静态界面静态SSR呈现方式 WebAs…

数据表中的数据查询

文章目录 一、概述二、简单查询1.列出表中所有字段2.“*”符号表示所有字段3.查询指定字段数据4.DISTINCT查询 三、IN查询四、BETWEEN ADN查询1.符合范围的数据记录查询2.不符合范围的数据记录查询 五、LIKE模糊查询六、对查询结果排序七、简单分组查询1.统计数量2.统计计算平均…

System slimming and Quicker action

今天介绍2款提升工作效率的软件,一款用于系统瘦身,当你的各个盘快满的时候,你又不知道该删除哪些文件的时候,就可以用这个插件,进行系统瘦身;另外一款是可以快捷做很多操作以节省时间,比如有很多…

2025年华为云一键快速部署饥荒联机服务器教程

饥荒是一款动作冒险类求生游戏,自行部署专属游戏联机服务器,可以确保游戏的流畅性和稳定性,获得更好的游戏体验。为了方便玩家搭建专属游戏联机服务器,华为云推出了云游戏专场,无需专业技术,新手小白也能一…

OSCP - Proving Grounds - Quackerjack

主要知识点 端口转发 具体步骤 执行nmap扫描,开了好多端口,我先试验80和8081,看起来8081比较有趣 Nmap scan report for 192.168.51.57 Host is up (0.0011s latency). Not shown: 65527 filtered tcp ports (no-response) PORT STATE SERVICE …

Go 切片:用法和本质

要想更好的了解一个知识点,实战是最好的经历。 题目 我这里放一道题目: package mainimport "fmt"func SliceRise(s []int) {s append(s, 0)for i : range s {s[i]}fmt.Println(s) }func SlicePrint() {s1 : []int{1, 2}s2 : s1s2 append…

零售业革命:改变行业的顶级物联网用例

mpro5 产品负责人Ruby Whipp表示,技术进步持续重塑零售业,其中物联网(IoT)正引领这一变革潮流。 研究表明,零售商们正在采用物联网解决方案,以提升运营效率并改善顾客体验。这些技术能够监控运营的各个方面…

安卓动态设置Unity图形API

命令行方式 Unity图像api设置为自动,安卓动态设置Vulkan、OpenGLES Unity设置 安卓设置 创建自定义活动并将其设置为应用程序入口点。 在自定义活动中,覆盖字符串UnityPlayerActivity。updateunitycommandlineararguments (String cmdLine)方法。 在该方法中,将cmdLine…

Go学习:iota枚举

iota注意事项: iota:常量自动生成器,每隔一行,自动累加iota给常量赋值使用iota 遇到 const,重置为 0可以只写一个iotaiota如果是同一行,值都一样 简单代码: package mainimport "fmt&qu…