20241008深度学习动手篇

news2025/1/12 1:45:16

文章目录

    • 1.如何写一个神经网络进行训练?
      • 1.1创建一个子类,搭建你需要的神经网络结构
      • 1.2 加载数据集
      • 1.3 自定义一些指标评估函数
      • 1.4训练
      • 1.5 结果展示
    • 2.参考文献

在这里插入图片描述

1.如何写一个神经网络进行训练?

1.1创建一个子类,搭建你需要的神经网络结构

# @File: 241008LeNet.py
# @Author: chen_song
# @Time: 2024/10/8 上午8:31

import torch
from torch import nn
from d2l import torch as d2l

net = nn.Sequential(
 # 进行卷积操作以后,
 nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
 nn.AvgPool2d(2,stride=2),
 nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
 nn.AvgPool2d(2,stride=2),
 nn.Flatten(),
 nn.Linear(16*5*5,120),nn.Sigmoid(),
 nn.Linear(120,84),nn.Sigmoid(),
 nn.Linear(84,10)
)
print(net)
print("===============================")
X = torch.rand(size=(1,1,28,28),dtype=torch.float32)
Y  = X.copy_(X)
for layer in net:
 X = layer(X)
 print(layer.__class__.__name__,X.shape)

print("============================")
# 输入给定以后,会进行一系列张量乘法计算
A = net(Y)
# print the last result
print(A)

result below:

Sequential( (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1),
padding=(2, 2)) (1): Sigmoid() (2): AvgPool2d(kernel_size=2,
stride=2, padding=0) (3): Conv2d(6, 16, kernel_size=(5, 5),
stride=(1, 1)) (4): Sigmoid() (5): AvgPool2d(kernel_size=2,
stride=2, padding=0) (6): Flatten(start_dim=1, end_dim=-1) (7):
Linear(in_features=400, out_features=120, bias=True) (8): Sigmoid()
(9): Linear(in_features=120, out_features=84, bias=True) (10):
Sigmoid() (11): Linear(in_features=84, out_features=10, bias=True) )
=============================== Conv2d torch.Size([1, 6, 28, 28]) Sigmoid torch.Size([1, 6, 28, 28]) AvgPool2d torch.Size([1, 6, 14,
14]) Conv2d torch.Size([1, 16, 10, 10]) Sigmoid torch.Size([1, 16, 10,
10]) AvgPool2d torch.Size([1, 16, 5, 5]) Flatten torch.Size([1, 400])
Linear torch.Size([1, 120]) Sigmoid torch.Size([1, 120]) Linear
torch.Size([1, 84]) Sigmoid torch.Size([1, 84]) Linear torch.Size([1,
10])
============================ tensor([[-0.2278, -0.5057, -0.6303, 0.1526, -0.1510, -0.1933, -0.3120, -0.7823,
0.4070, -0.0717]], grad_fn=)

1.2 加载数据集

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

打断点调试:
在这里插入图片描述在这里插入图片描述
你会发现:
train_iter和test_iter都是一个torch.utils.dataLoader对象,里面包含几个成员变量,住关键的是dataset对象以及sample对象,仔细研究你就会发现,为啥需要数据加载器了,因为你用神经网络进行训练,数据格式总得对吧,再就是要给个label吧,也就是目标值target吧,所以有余力朋友可以自己设计一个数据加载器…

1.3 自定义一些指标评估函数

def evaluate_accuracy_gpu(net, data_iter, device=None):  # @save
 """使用GPU计算模型在数据集上的精度"""
 if isinstance(net, nn.Module):
  net.eval()  # 设置为评估模式
  if not device:
   device = next(iter(net.parameters())).device
 # 正确预测的数量,总预测的数量
 metric = d2l.Accumulator(2)
 with torch.no_grad():
  for X, y in data_iter:
   if isinstance(X, list):
    # BERT微调所需的(之后将介绍)=== 自然语言处理
    X = [x.to(device) for x in X]
   else:
    X = X.to(device)
   y = y.to(device)
   metric.add(d2l.accuracy(net(X), y), y.numel())
 return metric[0] / metric[1]

注意一下里面net.eval()和net.train()

1.4训练

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    print('training on', device)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        # 训练损失之和,训练准确率之和,样本数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
d2l.plt.show()

1.5 结果展示

在这里插入图片描述

2.参考文献

[1]王辉,张帆,刘晓凤,等.基于DarkNet-53和YOLOv3的水果图像识别[J].东北师大学报(自然科学版),2020,52(04):60-65.DOI:10.16163/j.cnki.22-1123/n.2020.04.010.
[2]王治国,曹爽,管海燕,等.基于改进SSD的城市地下排水管道缺陷识别算法[J].测绘工程,2024,33(05):7-13.DOI:10.19349/j.cnki.issn1006-7949.2024.05.002.
[3]杨继雯.基于深度学习的监控视频中人员异常行为识别技术[D].西安工业大学,2024.DOI:10.27391/d.cnki.gxagu.2024.000829.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2199753.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RTSP RTP RTCP SDP基础知识

理论 流(Streaming ) 是近年在 Internet 上出现的新概念,其定义非常广泛,主要是指通过网络传输多媒体数据的技术总称。 流式传输分为两种 顺序流式传输 (Progressive Streaming) 实时流式传输 (Real time Streaming) ​​​​​…

李强总理签署国务院令 公布《网络数据安全管理条例》

中华人民共和国国务院令 第790号 《网络数据安全管理条例》已经2024年8月30日国务院第40次常务会议通过,现予公布,自2025年1月1日起施行。 总理 李强 2024年9月24日 网络数据安全管理条例 第一章 总则 第一条 为了规范网络数据处理活动,保…

SpringBoot日常:redission的接入使用和源码解析

文章目录 一、简介二、集成redissionpom文件redission 配置文件application.yml文件启动类 三、JAVA 操作案例字符串操作哈希操作列表操作集合操作有序集合操作布隆过滤器操作分布式锁操作 四、源码解析 一、简介 Redisson 是一个在 Redis 的基础上实现的 Java 驻内存数据网格…

基于java+springboot的旅游信息网站、旅游景区门票管理系统设计与实现

该系统是基于javaspringboot开发的旅游景区门票管理系统。是给师弟开发的大四实习作品。学习过程中,遇到问题可以咨询github作者。 演示地址 前台地址: http://travel.gitapp.cn 后台地址: http://travel.gitapp.cn/admin 后台管理帐号&am…

开发一个ftp上传客户端

文章目录 需求分析Tkinter基本用法多窗口切换FTP上传 程序打包源码 需求 项目中有个小功能模块 ,需要win下实现ftp上传功能,编写一个DEMO测试 要求 界面简单选择本地文件 上传ftp服务器显示进度条显示状态上传完成后显示URL分享地址 分析 Tkinter Tkint…

【读书笔记·VLSI电路设计方法解密】问题6:超大规模集成电路(VLSI)设计实现的主要方法是什么

现代芯片设计实践的主要方法包括: 定制设计现场可编程门阵列 (FPGA)基于标准单元的设计 (ASIC)平台/结构化ASIC在定制设计方法中,每个晶体管都是手动设计和布局的。这种方法的主要优势在于电路可以高度优化以提高速度、减少面积或降低功耗。然而,由于涉及大量手工工作,这种…

什么是物联网nb水表?

物联网NB水表是一种利用NB-IoT(窄带物联网)技术实现远程数据传输的智能水表。这种水表不仅能够精确计量用户的用水量,还能通过无线通信技术实现数据的远程传输和管理。下面我们来详细介绍物联网NB水表的主要特点和功能。 一、基本概念 -定义:物联网NB水…

nVisual集成项目交付模式升级方案

集 成 项 目 的 普 遍 现 状 1 集成项目的普遍现状 设计、工程和运维各部门使用不同的软件工具,缺乏有效的协同,工程数据无法有效积累并转化为运维数据; 传统项目验收交付模式已经无法满足用户的需求,需要项目交付后协助用…

分治算法(6)_归并排序_交易逆序对的总数

个人主页:C忠实粉丝 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 C忠实粉丝 原创 分治算法(6)_归并排序_交易逆序对的总数 收录于专栏【经典算法练习】 本专栏旨在分享学习算法的一点学习笔记,欢迎大家在评论区交流讨论&#x1f48…

怎么去掉图片上的文字不留痕迹?学会这5种P图方法轻松解决

图片编辑已成为我们日常生活和工作中不可或缺的一部分。但有时候,图片上的一些文字却成了我们分享或使用的障碍。如何无痕去除图片上的文字呢?今天,我将为大家介绍5种高效工具,让你轻松P图,一起来学习下吧。 工具一&am…

ESP32利用WebServer进行设备配置

目标需求 利用esp32的WebServer功能&#xff0c;展示一个网页&#xff0c;对里面的参数进行配置&#xff0c;并以json文本格式保存到flash里面。 1、定义HTML const char index_html[] PROGMEM R"rawliteral( <!DOCTYPE html> <html lang"en"> …

前沿论文 M5Product 组会 PPT

对比学习&#xff08;Contrast learning&#xff09;&#xff1a;对比学习是一种自监督学习方法&#xff0c;用于在没有标签的情况下&#xff0c;通过让模型学习哪些数据点相似或不同来学习数据集的一般特征。假设一个试图理解世界的新生婴儿。在家里&#xff0c;假设有两只猫和…

PPT在线画SWOT分析图!这2个在线软件堪称办公必备!

swot分析ppt怎么做&#xff1f; swot分析是一个非常常用的战略分析框架&#xff0c;经常会在ppt中使用。想在ppt中绘制swot分析图&#xff0c;使用自带的形状工具可以制作出来&#xff0c;但绘制效率不够高&#xff0c;在需要大批量制作的场景下&#xff0c;会让人非常心累………

【WebGis开发 - Cesium】三维可视化项目教程---初始化场景

系列文章目录 【WebGis开发 - Cesium】三维可视化项目教程—视点管理 目录 系列文章目录引言一、Cesium引入项目1.1 下载资源1.2 项目引入Cesium 二、初始化地球2.1 创建基础文件2.1.1 创建Cesium工具方法文件2.1.2 创建主页面 2.2 看下效果 三、总结 引言 本教程主要是围绕Ce…

现场直击!2023望繁信科技产品发布会精彩回顾

2023望繁信科技产品发布会圆满结束。 感谢200余名企业代表、合作伙伴、媒体到场参会&#xff0c;感谢3万多位关注望繁信科技和流程挖掘的朋友在线观看直播。 在会上&#xff0c;我们正式分享了望繁信科技多年深耕流程挖掘领域的思考、积累和部署&#xff0c;发布了过去一年在…

Pyppeteer:如何在 Python 中使用 Puppeteer 和 Browserless?

Python 中的 Pyppeteer 是什么&#xff1f; Pyppeteer 是流行的 Node.js 库 Puppeteer 的 Python 移植版本&#xff0c;用于以编程方式控制无头 Chrome 或 Chromium 浏览器。 本质上&#xff0c;Pyppeteer 允许 Python 开发人员在 Web 浏览器中自动执行任务&#xff0c;例如抓…

webm格式怎么转换成mp4?值得给你推荐的几种简单方法

webm格式怎么转换成mp4&#xff1f;MP4支持多种音频和视频编解码器&#xff0c;如H.264和AAC&#xff0c;用户可以根据需要调整视频和音频质量&#xff0c;以满足不同需求。同时&#xff0c;许多视频编辑软件广泛支持MP4格式&#xff0c;使得剪辑、合成和特效处理变得更加便捷。…

人工智能、人机交互和机器人国际学术会议

第三届人工智能、人机交互和机器人国际学术会议 &#xff08;AIHCIR 2024&#xff09;组委会热忱地邀请您参与本届大会。本届大会旨在聚集领先的科学家、研究人员和学者&#xff0c;共同交流和分享在人工智能、人机交互和机器人各个方面的经验和研究成果&#xff0c;为研究人员…

【C++】模板(初识):函数模板、类模板

本篇主要介绍C中的模板初阶的一些知识。模板分为函数模板和类模板&#xff0c;我们一个一个来看。 1.函数模板 1.1函数模板概念 函数模板代表了一个函数家族&#xff0c;该函数模板与类型无关&#xff0c;在使用时被参数化&#xff0c;根据实际的参数类型产生函数特定版本。…

LSTM时间序列模型实战——预测上证指数走势

LSTM时间序列模型实战——预测上证指数走势 关于作者 作者&#xff1a;小白熊 作者简介&#xff1a;精通python、matlab、c#语言&#xff0c;擅长机器学习&#xff0c;深度学习&#xff0c;机器视觉&#xff0c;目标检测&#xff0c;图像分类&#xff0c;姿态识别&#xff0c;…