优化器的使用

news2025/1/17 21:57:34

代码示例:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 加载数据集转化为Tensor数据类型
dataset = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor()
                                       , download=True)
# 使用dataloader加载数据集
dataloader = DataLoader(dataset, batch_size=1)


class Kun(nn.Module):
    def __init__(self):
        super(Kun, self).__init__()
        self.model1 = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),  # 将数据进行展平 64*4*4 =1024
            Linear(in_features=1024, out_features=64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
kun = Kun()

# 设置优化器
optim = torch.optim.SGD(kun.parameters(), lr=0.01)
# 相当于一轮学习
    for data in dataloader:
        imgs, target = data
        outputs = kun(imgs)
        result = loss(outputs, target)

        optim.zero_grad()  # 将所有参数梯度调整为0
        result.backward()  # 调用损失函数的反向传播求出每个梯度
        optim.step()  # 循环调优

增加训练次数:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 加载数据集转化为Tensor数据类型
dataset = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor()
                                       , download=True)
# 使用dataloader加载数据集
dataloader = DataLoader(dataset, batch_size=1)


class Kun(nn.Module):
    def __init__(self):
        super(Kun, self).__init__()
        self.model1 = Sequential(
            Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),  # 将数据进行展平 64*4*4 =1024
            Linear(in_features=1024, out_features=64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
kun = Kun()

# 设置优化器
optim = torch.optim.SGD(kun.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0      # 记录每轮学习损失的总和
    # 相当于一轮学习
    for data in dataloader:
        imgs, target = data
        outputs = kun(imgs)
        result = loss(outputs, target)

        optim.zero_grad()  # 将所有参数梯度调整为0
        result.backward()  # 调用损失函数的反向传播求出每个梯度
        optim.step()  # 循环调优
        running_loss += result
    print(running_loss)

结果示例:每轮的损失参数不断减小

在这里插入图片描述

造成损失参数不降反升,是lr设置过大

调整lr=0.001

optim = torch.optim.SGD(kun.parameters(), lr=0.001)

结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1016117.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

腾讯mini项目-【指标监控服务重构】2023-08-22

今日已办 50字项目价值和重难点 项目价值 通过将指标监控组件接入项目,对比包括其配套工具在功能、性能上的差异、优劣,给出监控服务瘦身的建议 top3难点 减少监控服务资源成本,考虑性能优化如何证明我们在监控服务差异、优劣方面的断言…

ubuntu 22.04运行opencv4的c++程序遇到的问题

摘要:本文介绍一下在ubuntu系统中,运行一个最简单的opencv4程序都出问题的解决方法,并对其基本原理作简单阐述。解决问题的方法有很多,本文只提供其中一种。 opencv版本是4.2.0,ubuntu版本是20.04 查询opencv版本的指…

Aztec.nr:Aztec的隐私智能合约框架——用Noir扩展智能合约功能

1. 引言 前序博客有: Aztec的隐私抽象:在尊重EVM合约开发习惯的情况下实现智能合约隐私 Aztec.nr,为: 面向Aztec应用的,新的,强大的智能合约框架使得开发者可直观管理私有状态基于Noir构建,…

写一篇nginx配置指南

nginx.conf配置 找到Nginx的安装目录下的nginx.conf文件,该文件负责Nginx的基础功能配置。 配置文件概述 Nginx的主配置文件(conf/nginx.conf)按以下结构组织: 配置块功能描述全局块与Nginx运行相关的全局设置events块与网络连接有关的设置http块代理…

AIGC专栏6——通过阿里云与AutoDL快速拉起Stable Diffusion和EasyPhoto

AIGC专栏6——通过阿里云与AutoDL快速拉起Stable Diffusion和EasyPhoto 学习前言Aliyun DSW快速拉起(新用户有三个月免费时间)1、拉起DSW2、运行Notebook3、一些小bug AutoDL快速拉起1、拉起AutoDL2、运行Notebook 学习前言 快速拉起AIGC服务 对 用户体…

CAN Driver

CAN Driver 前言:CAN驱动针对的是微控制器内部的CAN控制器,它可以实现以下功能: 对CAN控制器进行初始化; 发送和接收报文; 对报文的数据和功能进行通知(对接收报文的指示、对发送报文的确认&#xff09…

基于SSM+Vue的人力资源管理系统

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用Vue技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

交叉编译工具链-Ubuntu 安装说明

交叉编译工具链-Ubuntu 安装说明 【实验目的】 了解交叉编译工具链的安装方法与使用方法 【实验环境】 1、 ubuntu 14.04 发行版 【注意事项】 1、实验步骤中以“$”开头的命令表示在 ubuntu 环境下执行 【实验步骤】 1、安装交叉编译工具链 在 ubuntu 下打开一个终端并进入到家…

yamot:一款功能强大的基于Web的服务器安全监控工具

关于yamot yamot是一款功能强大的基于Web的服务器安全监控工具,专为只有少量服务器的小型环境构建。yamot只会占用非常少的资源,并且几乎可以在任何设备上运行。该工具适用于Linux或BSD,当前版本暂不支持Windows平台。 比如说,广…

elasticsearch11-实战搜索和分页

个人名片: 博主:酒徒ᝰ. 个人简介:沉醉在酒中,借着一股酒劲,去拼搏一个未来。 本篇励志:三人行,必有我师焉。 本项目基于B站黑马程序员Java《SpringCloud微服务技术栈》,SpringCloud…

Prometheus+Grafana可视化监控【ElasticSearch状态】

文章目录 一、安装Docker二、安装ElasticSearch(Docker容器方式)三、安装Prometheus四、安装Grafana五、Pronetheus和Grafana相关联六、安装elasticsearch_exporter七、Grafana添加ElasticSearch监控模板 一、安装Docker 注意:我这里使用之前写好脚本进行安装Docke…

Linux学习之平均负载的概念和查看方法

先理解一下平均负载的含义: 平均负载是指单位时间内,系统处于可运行状态和不可中断状态的进程数,也可以看成平均活跃进程数。 可运行状态的进程: 正在使用CPU或者正在等待CPU处理的进程,ps 命令看到的,处于…

AO天鹰优化算法|含源码(元启发式算法)

-----------------------往期目录------------------ 1、灰狼优化算法 文章目录 天鹰优化器一、第一种搜索方法二、第二种搜素方法三、第三种搜素方法四、第四种搜索方法 代码实现 天鹰优化器 Aquila Optimizer(AO),灵感来自Aquila在捕捉猎物…

Mysql001:(库和表)操作SQL语句

目录: 》SQL通用规则说明 SQL分类: 》DDL(数据定义:用于操作数据库、表、字段) 》DML(数据编辑:用于对表中的数据进行增删改) 》DQL(数据查询:用于对表中的数…

开源日报 0825 | 简化开发过程,提升Swift应用性能的扩展工具库

OpenZeppelin/openzeppelin-contracts Stars: 22.8k License: MIT OpenZeppelin Contracts 是一个用于安全智能合约开发的库。它建立在社区验证过的代码基础上,具有以下主要功能: 实现了 ERC20 和 ERC721 等标准。灵活的基于角色的权限控制方案。可重…

数据结构-leetcode-环形链表

解题图解: 代码如下: bool hasCycle(struct ListNode *head) {struct ListNode * fasthead;//在这里fast是快指针//head作为low指针//因为这个题不需要做修改也只需返回true或false//就少开辟一个空间while(fast!NULL&&fast->next!NULL){hea…

Css实现右上角飘带效果

效果图&#xff1a; 源码&#xff1a; <!DOCTYPE html> <html><head><meta charset"utf-8"><title></title><style type"text/css">*{margin: 0 auto;padding: 0;}.wrap {/* 设置宽高 */width: 350px;height: …

多分类中混淆矩阵的TP,TN,FN,FP计算

关于混淆矩阵&#xff0c;各位可以在这里了解&#xff1a;混淆矩阵细致理解_夏天是冰红茶的博客-CSDN博客 上一篇中我们了解了混淆矩阵&#xff0c;并且进行了类定义&#xff0c;那么在这一节中我们将要对其进行扩展&#xff0c;在多分类中&#xff0c;如何去计算TP&#xff0…

1_图神经网络GNN基础知识学习

文章目录 安装PyTorch Geometric安装工具包 在KarateClub数据集上使用图卷积网络 (GCN) 进行节点分类两个画图函数Graph Neural Networks数据集&#xff1a;Zacharys karate club network.PyTorch Geometric数据集介绍 edge_index使用networkx可视化展示 Graph Neural Networks…

cenos自动启动tomcat

首先创建一个脚本 关闭tomcat 等待2分钟 启动tomcat 并且把日志输出在 /usr/local/tomcat/tomcatchognqi.log #!/bin/bashexport JAVA_HOME/usr/local/jdk/jdk1.8.0_211 export JRE_HOME$JAVA_HOME/jre# 日志文件路径和文件名 LOG_FILE"/usr/local/tomcat/tomcatchognqi.…