pytorch --反向传播和优化器

news2024/9/25 23:20:26

1. 反向传播

计算当前张量的梯度

Tensor.backward(gradient=None, retain_graph=None, create_graph=False, inputs=None)

计算当前张量相对于图中叶子节点的梯度。

使用反向传播,每个节点的梯度,根据梯度进行参数优化,最后使得损失最小化

代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)



class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        # 另一种写法
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10)
        )

    def forward(self,x):
        # sequential方式
        x = self.model1(x)
        return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:
    imgs,target = data
    outputs= tudui(imgs)
    result_loss = loss(outputs,target)
    result_loss.backward() # 梯度
    print(result_loss)

2.优化器 (以随机梯度下降算法为例)
将上一步的梯度清零
params ,lr(学习率)
随机梯度下降SGD
torch.optim.SGD(params,
lr=,
momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)

代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=1)



class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        # 另一种写法
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10)
        )

    def forward(self,x):
        # sequential方式
        x = self.model1(x)
        return x
loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(),lr=0.01) # params,lr
for epoch in range(5):# 在整个数据集上训练5次
    running_loss = 0
    #对数据进行一轮学习
    for data in dataloader:
        imgs,target = data
        outputs= tudui(imgs)
        result_loss = loss(outputs,target)
        optim.zero_grad() # 将上一步的梯度清零
        result_loss.backward() # 梯度
        optim.step() # 根据梯度修改参数
        # print(result_loss)
        running_loss = running_loss + result_loss
    print(running_loss)

输出
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1483047.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Dynamo幕墙探究系列(一)

一直想写个系列教程,但是没有那么多时间整理资料,这次呢,先弄个小系列吧,还是和之前差不多的幕墙测试,我们分几节课,一步一步深入研究。 今天先开个小头儿,要弄的,就是下面这么个模型…

小程序图形:echarts-weixin 入门使用

去官网下载整个项目: https://github.com/ecomfe/echarts-for-weixin 拷贝ec-canvs文件夹到小程序里面 index.js里面的写法 import * as echarts from "../../components/ec-canvas/echarts" const app getApp(); function initChart(canvas, width, h…

【IDEA】2023版IDEA安装破解教程

2023版IDEA安装破解教程 第一步:IDEA的卸载 这里以Windows11系统为例,首先我们打开控制面板,点击程序,找到自己的IDEA,双击卸载。(或者可以直接找到idea所在文件位置,直接delete文件夹&#x…

化肥工业5G智能制造工厂数字孪生可视化平台,推进化肥行业数字化转型

化肥工业5G智能制造工厂数字孪生可视化平台,推进化肥行业数字化转型。随着科技的不断发展,数字化转型已经成为各行各业发展的必然趋势。在化肥工业领域,5G智能制造工厂数字孪生可视化平台的应用正在逐渐普及,为行业数字化转型提供…

微信小程序云开发教程——墨刀原型工具入门(页面交互+交互案例教程)

引言 作为一个小白,小北要怎么在短时间内快速学会微信小程序原型设计? “时间紧,任务重”,这意味着学习时必须把握微信小程序原型设计中的重点、难点,而非面面俱到。 要在短时间内理解、掌握一个工具的使用&#xf…

AI时代编程新宠!如何让孩子成为未来的编程大师?

文章目录 一、了解编程的基础概念二、选择适合的编程工具三、激发孩子的兴趣四、注重基础能力的培养五、提供实践机会六、鼓励孩子与他人合作七、持续支持与鼓励《信息学奥赛一本通关》本书定位内容简介作者简介目录 随着科技的迅猛发展,编程已经从一种专业技能转变…

C++三级专项 Hermite多项式

用递归的方法求Hermite多项式的值。 对给定的x和正整数n,求多项式的值,并保留两位小数。 输入 给定的n和正整数x. 输出 多项式的值。 输入样例 1 2 输出样例 4.00 解析:按照题目给出的要求:递归来解决就行,上…

『京墨』1.7.0 发布,开源的诗文(名句)、歇后语、成语、绕口令、节日等的阅读 APP

1.7.0 更新日志 优化 UI 显示;优化数据同步,尤其是诗文同步;【诗文名句】【成语】【歇后语】模块添加收藏功能;添加“滑动翻页”功能。 介绍 『京墨』开源的古诗词文(名句)、歇后语、成语、绕口令、节日…

C++_map与set

目录 一、set 1、set的用法 2、multiset 二、map 1、map的用法 2、map的operator[] 3、multimap 结语 前言: C中的map和set容器属于关联式容器,与序列式容器不同的地方在于(序列式容器即vector、list,其底层是由线性数据…

牛客禁用题:求阶乘

思路&#xff1a;在新类中使用全局变量进行运算&#xff0c;在主类中定义新类数组&#xff0c;通过构造函数的调用次数返回阶乘 #include <type_traits> class add{public:static int count;static int tmp;add(){countcounttmp;tmp;} }; int add::count0; int add::t…

Flink:动态表 / 时态表 / 版本表 / 普通表 概念区别澄清

博主历时三年精心创作的《大数据平台架构与原型实现&#xff1a;数据中台建设实战》一书现已由知名IT图书品牌电子工业出版社博文视点出版发行&#xff0c;点击《重磅推荐&#xff1a;建大数据平台太难了&#xff01;给我发个工程原型吧&#xff01;》了解图书详情&#xff0c;…

在golang中使用protoc

【Golang】proto生成go的相关文件 推荐个人主页&#xff1a;席万里的个人空间 文章目录 【Golang】proto生成go的相关文件1、查看proto的版本号2、安装protoc-gen-go和protoc-gen-go-grpc3、生成protobuff以及grpc的文件 1、查看proto的版本号 protoc --version2、安装protoc-…

Java 打包 SpringBoot 项目报错

Java 打包 SpringBoot 项目报错 问题重现 Please refer to xxxx for the individual test results. Please refer to dump files (if any exist) [date].dump, [date]-jvmRun[N].dump and [date].dumpstream. 解决问题 在 pom.xml 的 <properties> 中添加项目代码 <s…

AIGC下一步:如何用AI再度重构或优化媒体处理?

让媒资中“沉默的大多数”再次焕发光彩。 邹娟&#xff5c;演讲者 编者按 AIGC时代下&#xff0c;媒体内容生产领域随着AI的出现也涌现出更多的变化与挑战。面对AI的巨大冲击&#xff0c;如何优化或重构媒体内容生产技术架构&#xff1f;在多样的应用场景中媒体内容生产技术又…

【半监督医学图像分割 2021 IEEE】DU-GAN

【半监督医学图像分割 2021 IEEE】DU-GAN 论文题目&#xff1a;DU-GAN: Generative Adversarial Networks with Dual-Domain U-Net Based Discriminators for Low-Dose CT Denoising 中文题目&#xff1a;基于双域U-Net鉴别器的生成对抗网络用于低剂量CT去噪 论文链接&#xff…

云时代【5】—— LXC 与 容器

云时代【5】—— LXC 与 容器 三、LXC&#xff08;一&#xff09;基本介绍&#xff08;二&#xff09;相关 Linux 指令实战&#xff1a;使用 LXC 操作容器 四、Docker&#xff08;一&#xff09;删除、安装、配置&#xff08;二&#xff09;镜像仓库1. 分类2. 相关指令&#xf…

C---输入5个字符串,找出最长字符串并输出

从键盘输入5个字符串&#xff0c;找出最长的字符串并输出该字符串 例&#xff1a; 输入&#xff1a;123 1234 werere1234 12 123 输出&#xff1a;werere1234 #include <stdio.h> #include <string.h>int main() {char strings[5][80]; // 存储5个字符串&#x…

界面控件DevExpress .NET MAUI v23.2新版亮点 - 拥有全新的彩色主题

DevExpress拥有.NET开发需要的所有平台控件&#xff0c;包含600多个UI控件、报表平台、DevExpress Dashboard eXpressApp 框架、适用于 Visual Studio的CodeRush等一系列辅助工具。屡获大奖的软件开发平台DevExpress 今年第一个重要版本v23.1正式发布&#xff0c;该版本拥有众多…

记录SSM项目集成Spring Security 4.X版本 之 加密验证和记住我功能

目录 前言 一、用户登录密码加密认证 二、记住我功能 前言 本次笔记的记录是接SSM项目集成Spring Security 4.X版本 之 加入DWZ,J-UI框架实现登录和主页菜单显示-CSDN博客https://blog.csdn.net/u011529483/article/details/136255768?spm1001.2014.3001.5502 文章之后补…

腾讯云幻兽帕鲁服务器使用Linux和Windows操作系统,对用户的技术要求有何不同?

腾讯云幻兽帕鲁服务器使用Linux和Windows操作系统对用户的技术要求有何不同&#xff1f; 首先&#xff0c;从操作界面的角度来看&#xff0c;Windows操作系统相对简单易操作&#xff0c;适合那些偏好使用图形化界面操作的用户。而Linux操作系统则需要通过命令行完成&#xff0…