pytorch学习笔记(十一)

news2025/1/13 8:08:36

优化器学习

把搭建好的模型拿来训练,得到最优的参数。

import torch.optim
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
                                        download=True)
dataloader = DataLoader(dataset, batch_size=1)
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self, x):
        x = self.model1(x)
        return x
#定义loss
loss = nn.CrossEntropyLoss()
tudui = Tudui()
#一开始时采用比较大的学习速率学习,后面用比较小的学习速率学习
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    #在每一轮学习之前都把loss设置成0
    #在每一轮的学习过程中计算的loss都加上去
    #这个数据是表示,在每一轮的学习的过程中在这一轮的整体的loss的求和,整体误差总和
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        #得到每一个可调参数的梯度
        result_loss.backward()
        optim.step()
        #损失函数没有已知在变化,原因是只有单个循环下,只看了一次数据,这一次看到的数据对你下一次看到的数据预测的影响不大
        # print(result_loss)
        running_loss = running_loss + result_loss
    print(running_loss)

在debug的过程中选择最后三行,观察梯度变化

其中optim.step()会把每一步更新的梯度用于数据的更新

现有模型的使用和修改

参数:root (string) - ImageNet数据集的根目录。

split (string,可选)-数据集分割,支持train或val。

transform(可调用的,可选的)-一个函数/转换,接收PIL图像并返回转换后的版本。例如,变换。RandomCrop

target_transform (callable, optional) -一个函数/transform,接收目标并对其进行变换。

loader -加载给定路径的图像的函数。

这边看看VGG16,因为它的预训练数据集太大了,不好下载,这边采用CIFAR10代替ImageNet的方法。

然后发现他的线性层输出的特征是1000,也是分1000个类,而CIFAR10只有10个类,这需要对网络模型进行修改,两种思路进行修改。

(1)直接修改最后一个线性层(6),将输出特征改为10

(2)加个线性层(7),输入设置为1000,而输出设置为10

模型的保存和模型的加载

官方推荐的保存下来文件比较小

方式2输出的是一个字典形式,要恢复成网络结构,要新建这个模型,然后还要通过字典的形式重建。

另外要注意用方式1(陷阱)保存的时候要在加载的部分引入你定义的结构否则会报错

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1411532.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Python】采用OpenCV和Flask来进行网络图像推流的低延迟高刷FPS方法(项目模板)

【Python】采用OpenCV和Flask来进行网络图像推流的低延迟高刷FPS方法(项目模板) gitee项目模板: 网络图像推流项目模板(采用OpenCV和Flask来进行网络图像推流的低延迟高刷FPS方法) 前文: 【最简改进】基于…

深入浅出 diffusion(2):pytorch 实现 diffusion 加噪过程

我在上篇博客深入浅出 diffusion(1):白话 diffusion 原理(无公式)中介绍了 diffusion 的一些基本原理,其中谈到了 diffusion 的加噪过程,本文用pytorch 实现下到底是怎么加噪的。 import torch…

最小覆盖子串(Leetcode76)

例题: 分析: 比如现在有字符串(s),s "ADOBECODEBANC", 给出目标字符串 t "ABC", 题目就是要从原始字符串(s)中找到一个子串(res)可以覆盖目标字符串 t &…

UE使用C++添加FGameplayTag(游戏标签)

首先Ue会有一个UGameplayTagsManager类型的对象 游戏标签管理器(全局中就有一个) 我们直接通过 UGameplayTagsManager::Get()静态函数拿到 全局唯一的游戏标签管理器的实例 返回的是个左值引用 之后通过调用 AddNativeGameplayTag()函数就可添加游戏标签了 就这么简单 第…

Java+Spring Cloud +Vue+UniApp微服务智慧工地云平台源码

目录 智慧工地云平台功能 【劳务工种】所属工种有哪些? 1.管理人员 2.信息采集 3.证件管理 4.考勤管理 5.考勤明细 6.工资管理 7.现场统计 8.WIFI教育 9.课程库管理 10.工种管理 11.分包商管理 12.班组管理 13.项目管理 智慧工地管理平台是以物联网、…

算法题 — 删除排序数组中的重复项

问题:一个有序数组 nums,原地删除重复出现的元素,使每个元素只出现一次,返回删除后数组的新长度。 注:不能使用额外的数组空间,必须在原地修改输入数组并在使用 O(1) 额外空间的条件下完成。 例&#xff…

【考研结束了,不管上不上岸,我建议你先....】

*** 考研结束,一定要做这几件事! 又一年考研季的落幕,经历了漫长考研岁月的学子们,终于迎来了期盼已久的解脱。参加考研的同学们必须都顺利上岸。 然而对于技术类专业的考生而言,新的征程与机遇才刚刚启航。 此时此刻…

专业144总分410+华南理工大学811信号与系统考研经验华工电子信息与通信

今年专业811信号与系统144(二战,感谢信息通信Jenny老师专业课对我的巨大提高,第一年自己复习只考了90,主要栽专业课和数学)总分410含泪(二战的同学都知道苦,成功来之不易)考上华南理…

InterSystem IRIS BS BP BO配置

应用:根据请求的BS,通过BP,到BO的处理,集成平台BO获取数据并推送给指定第三方 操作步骤: 一、事前准备: 创建交互服务前提前将SQL网关创建和连接好。需记录网关连接名称,配置在BS设置的DSN处…

HMI-Board以太网数据监视器(二)MQTT和LVGL

E ∫ d E ∫ k d q r 2 k L ∫ d q r 2 E \int dE \int \frac{kdq}{r^2} \frac{k}{L} \int \frac{dq}{r^2} E∫dE∫r2kdq​Lk​∫r2dq​ E Q 2 π ϵ L 2 E \frac{Q}{2\pi\epsilon L^2} E2πϵL2Q​ Γ ( n ) ( n − 1 ) ! ∀ n ∈ N \Gamma(n) (n-1)!\quad\forall n…

实习日志5

活字格图片上传功能(批量) 这个报错真的恶心,又看不了他服务器源码,接口文档又是错的 活字格V9获取图片失败bug,报错404-CSDN博客 代码BUG记录: 问题:上传多个文件的base64编码被最后一个文…

eclipse启动Java服务及注意事项

1、导入项目 选择file——》import…——》Generate——》Exiting Projects into Workspace——》选择要导入的项目 2、添加tomcat 1)点击Serves——》No servers are available. Click this link to create a new server… 2)点击“Add…” 3&…

【Servlet】如何编写第一个Servlet程序

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【Servlet】 本专栏旨在分享学习Servlet的一点学习心得,欢迎大家在评论区交流讨论💌 Servlet是Java编写的服务器端…

STM32第一节——初识STM32

1 硬件介绍 1.1 硬件平台 配套硬件:以野火的STM32 F1霸道开发板为平台,若用的是别的开发板,可自己进行移植。 1.2 什么是STM32 STM32是由意法半导体(STMicroelectronics)公司推出的一系列32位的ARM Cortex-M微控制…

CMIP6数据驱动WRF和WRF-Chem模式、WRF-Chem的未来大气污染变化模拟

目录 一、CMIP6数据及运行平台建设 二、CMIP6数据驱动WRF和WRF-Chem模式 三、WRF-Chem的未来情景模拟 更多应用 对模式比较计划的全球气候预估数据进行动力降尺度,结合预估的未来气候变化,运用区域气候模式和气候-化学耦合模式,实现对未来…

【数据结构】数据结构初识

前言: 数据结构是计算存储,组织数据的方式。数据结构是指相互间存在一种或多种特定关系的数据元素的集合。通常情况下,精心选择的数据结构可以带来更高的运行或者存储效率。数据结构往往同高效的检索算法和索引技术有关。 Data Structure Vi…

day2 C++

封装一个矩形类(Rect)&#xff0c;拥有私有属性:宽度(width)、高度(height)&#xff0c; 定义公有成员函数: 初始化函数:void init(int w, int h) 更改宽度的函数:set_w(int w) 更改高度的函数:set_h(int h) 输出该矩形的周长和面积函数:void show() #include <iostre…

笔记--写代码好习惯

原文&#xff1a;写代码有这16个好习惯&#xff0c;可以减少80%非业务的bug

uniapp vuecli项目融合[小记]:将多个项目融合,打包成一个小程序/App,拆分多个H5应用

前言&#xff1a; 目前两个uniapp vuecli开发的项目【A、B】&#xff0c;新规划的项目C&#xff1a;需要融合项目B 80%的功能模块&#xff0c;同时也需要涵盖项目A的所有功能模块。 应用需求&#xff1a; 1、新项目C【小程序】可支持切换到应用A/C界面【内部通过初始化、路由跳…

对于小微企业而言,数字化转型的重要性是什么?

数字化转型对于小微企业至关重要&#xff0c;原因如下&#xff1a; 1.效率和生产力&#xff1a;采用数字工具和技术可以简化业务流程、自动化重复性任务并提高整体效率。这使得小型企业能够用更少的资源完成更多的工作&#xff0c;最终提高生产力。 2.节省成本&#xff1a;数…