PyTorch-优化器以及网络模型的修改

news2024/11/24 21:02:44

目的:优化器可以将神经网络中的参数根据损失函数和反向传播来进行优化,以得到最佳的参数值,让模型预测的更准确。

1. SGD

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=1)

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

loss = nn.CrossEntropyLoss()
myModule1 = MyModule()
# 1.定义一个优化器
optim = torch.optim.SGD(myModule1.parameters(), lr=0.01)

for data in dataloader:
    imgs, targets = data
    outputs = myModule1(imgs)
    res_loss = loss(outputs, targets)
    # 2.梯度清零
    optim.zero_grad()
    # 3.反向传播,求出每个参数的梯度
    res_loss.backward()
    # 4.对参数进行调优
    optim.step()

a. 梯度清零

b. 反向传播(此时grad中就有数了)

c. 优化参数(在这里变化有些小,但是有变化)

nn_loss_network.py

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=1)

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

loss = nn.CrossEntropyLoss()
myModule1 = MyModule()
# 1.定义一个优化器
optim = torch.optim.SGD(myModule1.parameters(), lr=0.01)

for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = myModule1(imgs)
        res_loss = loss(outputs, targets)
        # 2.梯度清零
        optim.zero_grad()
        # 3.反向传播,求出每个参数的梯度
        res_loss.backward()
        # 4.对参数进行调优
        optim.step()
        # 在这一轮学习中,整体的loss求和
        running_loss = running_loss + res_loss
    print(running_loss)

tensor(18636.9590, grad_fn=<AddBackward0>)
tensor(16148.1221, grad_fn=<AddBackward0>)
tensor(15489.6064, grad_fn=<AddBackward0>)

...

2. VGG

weights的使用:

weights=None 表示随机ImageNet初始化;不为None则为具体给定的权值。

model_pretrained.py

import torchvision

vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights=True)
print(vgg16_true)  # 1000个分类

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Process finished with exit code 0

如果用Cifar10(只有10分类的话)+VGG模型,该怎么去修改呢?

方法一:添加一层

# 方法一: 加一层
# vgg16_true.add_module('add_linear', nn.Linear(1000, 10))
# 若想在classifier里加,则:
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)

   (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
    (add_linear): Linear(in_features=1000, out_features=10, bias=True)
  )
)

方法二:修改

vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

 (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/585048.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Security 笔记

在Spring Security 5.7.0-M2&#xff0c;我们弃用了 WebSecurityConfigurerAdapter &#xff0c;因为我们鼓励用户转向使用基于组件的安全配置。 为了帮助大家熟悉这种新的配置风格&#xff0c;我们编制了一份常见用例表和推荐的新写法。 配置HttpSecurity Configuration pu…

重磅发布!面向装备制造业服务化转型白皮书

《面向装备制造业服务化转型白皮书》 关于白皮书 《面向装备制造业服务化转型白皮书》通过调研160余家装备制造企业的服务化路径及模式&#xff0c;研讨支持企业开展服务型制造的系统化方案&#xff0c;希望为装备制造业服务化转型&#xff0c;探索切实有效的路径以供参考。 …

Web 自动化测试案例——关闭某视频网站弹出广告以及打开登录框输入内容

文章目录 &#x1f4cb;前言&#x1f3af;自动化测试&#x1f9e9;环境的搭建 &#x1f3af;案例介绍&#x1f4dd;最后 &#x1f4cb;前言 人生苦短&#xff0c;我用Python。许久没写博客了&#xff0c;今天又是久违的参与话题的讨论&#xff0c;话题的内容是&#xff1a;如何…

4.文件系统

组成 Linux&#xff1a;一切皆文件 索引节点&#xff08;I-node&#xff09; I-node&#xff08;Index Node&#xff09;&#xff1a;文件系统的内部数据结构&#xff0c;用于管理文件的元数据和数据块。 文件的元数据&#xff1a;包括文件的权限、拥有者、大小、时间戳、索引…

VM增加磁盘并挂载到根目录

1、虚拟机增加磁盘 首先要关闭虚拟机&#xff0c;否则增加按钮不可见。 9 vm添加磁盘完毕。 2、登录虚拟机挂盘 1、lsblk查看硬盘挂载情况&#xff0c;sdb为新挂载的磁盘。 [rootlocalhost ~]# lsblk NAME MAJ:MIN RM SIZE RO TYPE MOUNTPOINT sda …

通过python封装接口采集1688店铺所有商品数据接口,1688店铺所有商品接口,1688API接口

采集1688店铺所有商品数据需要进行以下步骤&#xff1a; 获取店铺ID 要获取店铺ID&#xff0c;您可以通过访问店铺首页来获取&#xff0c;例如&#xff1a;https://1688455341.1688.com/ 店铺ID就是链接中的“1688455341”。 获取店铺所有商品列表页 通过向1688店铺的搜索…

关于【SD-WEBUI】的LoRA模型训练:怎样才算训练好了?

文章目录 &#xff08;零&#xff09;前言&#xff08;一&#xff09;模型(LoRA)训练&#xff08;1.1&#xff09;数据准备&#xff08;1.1.1&#xff09;筛选照片&#xff08;1.1.2&#xff09;预处理照片&#xff08;1.1.3&#xff09;提示词(tags)处理&#xff08;1.1.4&…

部署微信小程序-shopro

部署微信小程序 开始之前 注意不要运行模式下的代码提交小程序审核&#xff0c;第一包体积太大&#xff0c;第二性能太差请下载 小程序开发工具正式小程序无法正常使用&#xff0c;而开发版正常&#xff0c;请确保域名都添加到小程序后台&#xff0c;并且配置好了 IP 白名单&a…

Openai+Deeplearning.AI: ChatGPT Prompt Engineering(五)

想和大家分享一下最近学习的Deeplearning.AI和openai联合打造ChatGPT Prompt Engineering在线课程.以下是我写的关于该课程的前四篇博客&#xff1a; ChatGPT Prompt Engineering(一)ChatGPT Prompt Engineering(二)ChatGPT Prompt Engineering(三)ChatGPT Prompt Engineering…

微星笔记本618大促至高直降5000元,泰坦GP78 HX爆款配置10999拿下

在万众玩家的期待下&#xff0c;微星笔记本618大促如约而至&#xff01;不仅覆盖今年全新13代酷睿HX RTX40系显卡的高能游戏本&#xff0c;还特别在618同步推出新品&#xff1a;泰坦GP78 HX&#xff0c;承袭“泰坦系列”旗舰的满血基因极致性能体验外&#xff0c;更有i9-13980…

自学web前端能找到工作吗?是否有必要参加前端培训?

是的&#xff0c;自学前端可以帮助您找到工作&#xff0c;参加培训是根据个人学习能力和经济实力来自己决定的。前端开发是一个相对容易入门的领域&#xff0c;并且许多人通过自学成功地找到了前端开发的工作。以下是好程序员的一些建议&#xff0c;可以帮助您在自学前端时提高…

头顶“米链代工厂”标签,德尔玛上市之后怎么走?

截至5月29日上午收盘&#xff0c;德尔玛股价当前为14.10、成交量55272手、成交额为7820.32万&#xff0c;总市值65.08亿元&#xff0c;总股本为4.62亿。 曲折的股价走势背后&#xff0c;德尔玛未来的增长潜力成疑。德尔玛表示&#xff0c;此次上市将有助于公司在创新家电市场保…

诚迈科技携智达诚远出席高通汽车技术与合作峰会

5月25日至26日&#xff0c;诚迈科技及旗下的智能汽车操作系统及中间件产品提供商智达诚远作为高通生态伙伴&#xff0c;亮相首届“高通汽车技术与合作峰会”&#xff0c;通过产品展示和主题演讲呈现了基于高通骁龙数字底盘的最新智能座舱技术成果&#xff0c;共同展望智能网联汽…

Java代码命名规范是真优雅呀!代码如诗

Java 命名规范 一、Java总体命名规范 1、项目名全部小写. 2、包名全部小写. 3、类名首字母大写,其余组成词首字母依次大写. 4、变量名,方法名首字母小写,如果名称由多个单词组成,除首字母外的每个单词首字母都要大写. 5、常量名全部大写. 6、所有命名规则必须遵循以下规则 : …

Java - ThreadLocal数据存储和传递方式的演变之路

Java - ThreadLocal数据存储和传递方式的演变之路 前言一. InheritableThreadLocal - 父子线程数据传递1.1 父子线程知识预热和 InheritableThreadLocal 实现原理1.2 InheritableThreadLocal 的诟病 二. TransmittableThreadLocal (TTL) 横空出世2.1 跨线程变量传递测试案例2.2…

代码随想录二刷 day06 | 哈希表之 242.有效的字母异位词 349. 两个数组的交集 202. 快乐数 1. 两数之和

day06 242.有效的字母异位词349. 两个数组的交集202. 快乐数1. 两数之和 哈希表能解决什么问题呢&#xff1f;一般哈希表都是用来快速判断一个元素是否出现集合里。 242.有效的字母异位词 题目链接 解题思路&#xff1a; 题目的意思就是 判断两个字符串是否由相同字母组成。 字…

【Java|基础篇】内部类

文章目录 1.什么是内部类?2.实例内部类3.静态内部类4.局部内部类5.匿名内部类6.结语 1.什么是内部类? 内部类就是在一个类中再定义一个类,内部类也是封装的体现.它可以被声明为 public、protected、private 或默认访问控制符。内部类可以访问外部类的所有成员变量和方法&…

【WebRTC】音视频通信

WebRTC对等体还需要查找并交换本地和远程音频和视频媒体信息&#xff0c;例如分辨率和编解码器功能。 交换媒体配置信息的信令通过使用被称为SDP的会话描述协议格式来交换&#xff0c;被称为提议和应答的元数据块 WebRTC 音视频通信基本流程 一方发起调用 getUserMedia 打开本…

线程池在业务中的实践-美团技术团队分享

原文地址&#xff1a;Java线程池实现原理及其在美团业务中的实践 场景1&#xff1a;快速响应用户请求 描述&#xff1a;用户发起的实时请求&#xff0c;服务追求响应时间。比如说用户要查看一个商品的信息&#xff0c;那么我们需要将商品维度的一系列信息如商品的价格、优惠、…

从小白到大神之路之学习运维第31天

第二阶段基础 时 间&#xff1a;2023年5月29日 参加人&#xff1a;全班人员 内 容&#xff1a; Rsync服务 目录 一、基本信息 二、rsync命令 三、rsyncinotfy实时同步 一、基本信息 &#xff08;一&#xff09;概述 rsync是linux 下一个远程数据同步工具 他可通过…