优化器与现有网络模型的修改

news2024/9/22 15:46:47

一、优化器

optimizer = optim.SGD(model.parameters(), lr=0.01(学习速率), momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

一般,学习率的设置,先从大的设置,逐渐变小。

神经网络可以参见上篇文章,接上篇文章的神经网络做模型优化:

sun = SUN() #调用神经网络
# 设置优化器
# 随机梯度下降
optim = torch.optim.SGD(sun.parameters(), lr=0.01)
# 循环学习20次
for epoch in range(20):
    # 整体误差的总和
    running_loss = 0.0
    # 只循环了一次,学习了一次,在外面再次加一个循环
    for data in dataloader:
        imgs, targets = data
        outputs = sun(imgs)
        result_loss = loss(outputs, targets)
        # 将模型中每一个可以调节参数对应的梯度调为零
        optim.zero_grad()
        # 得到可以调节参数的梯度
        result_loss.backward()
        optim.step()
        # 总体误差的总和
        running_loss = running_loss + result_loss
    print(running_loss)

输出结果为:

由输出结果可见,模型优化器,使得误差总和在不断的变小。

二、模型的使用与修改

import torchvision

下载的数据集存放位置:

 在下载ImageNet数据集时,会出现报错,数据集有143G,已经不支持下载。设置语句:

vgg16_false = torchvision.models.vgg16(pretrained=False)

当为False时,只是加载网络模型(也就是像之前的网络模型那样,只是加载模型,含有卷积,池化等,其中的参数都是默认的)

vgg16_true = torchvision.models.vgg16(pretrained=True)

当为True时,不仅加载模型,还要加载对应的参数。

VGG16将数据集分成1000个类。

print(vgg16_true)

输出结果:

之前使用的数据集CIFAR10数据集输出的为10个分类,于是,我们也可以根据现有的网络进行修改:

import torchvision

# train_data = torchvision.datasets.ImageNet("../data_ImageNet", split='train', download= True,
#                                            transform=torchvision.transforms.ToTensor())
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

print(vgg16_true)

# vgg16_true.add_module('add_linear', nn.Linear(1000,10))
# print(vgg16_true) 加在末尾

# 加在开头
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000,10))
print(vgg16_true)

print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

二、网络模型的保存与读取

出现的问题:

AttributeError: Can't get attribute 'SUN' on <module '__main__' from 'D:/test pytorch/learningplan1/models_load.py'>

解决的办法是需要将网络模型放在读取的上面,不需要引用,但是需要将其放置在其上方

当然,引用别人的模型时,就非常的不便,所以使用:

from models_save import *

就能够解决此问题。

三、模型的保存

import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained = False)
# 保存方式1(结构与参数均保存)
torch.save(vgg16, 'vgg16_method1.pth')

# 保存方式2(参数保存为字典形式)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")


# 陷阱
class SUN(nn.Module):
    def __init__(self):
        super(SUN, self).__init__()
        self.convv = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.convv(x)
        return x


sun = SUN()
torch.save(sun, "sun_method1.pth")

四、模型的使用

import torch

# 方式1 保存后加载
# model = torch.load("vgg16_method1.pth")
# print(model)

# 方式2 保存后加载
import torchvision
from torch import nn
from models_save import *

vgg16 = torchvision.models.vgg16(pretrained = False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model = torch.load("vgg16_method2.pth")
# print(vgg16)

# 加载网络模型
# class SUN(nn.Module):
#     def __init__(self):
#         super(SUN, self).__init__()
#         self.convv = nn.Conv2d(3, 64, kernel_size=3)
#
#     def forward(self, x):
#         x = self.convv(x)
#         return x

model = torch.load("sun_method1.pth")
print(model)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2107959.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据库】MySQL-基础篇-函数

专栏文章索引&#xff1a;数据库 有问题可私聊&#xff1a;QQ&#xff1a;3375119339 目录 一、简介 二、字符串函数 三、数值函数 四、日期函数 五、流程函数 一、简介 函数 是指一段可以直接被另一段程序调用的程序或代码。 也就意味着&#xff0c;这一段程序或代码在 M…

【2024国赛C题】高教杯全国大学生数学建模国赛建模过程+完整代码论文全解全析

完整内容在文章末尾阅读全文获取&#xff01; 问题 1是针对不同情况下&#xff0c;该乡村未来几年农作物的最优种植方案的研究。 为解决这个数学建模问题&#xff0c;我们需要构建一个优化模型&#xff0c;考虑到各种限制条件和目标函数。以下是解决问题的步骤&#xff1a; 问…

有源低通/高通滤波器(一阶滤波器+Sallen-Key滤波器+高下降率滤波器)+有源带通滤波器(级联+多重反馈+状态可变)

2024-9-5&#xff0c;星期四&#xff0c;20:40&#xff0c;天气&#xff1a;晴&#xff0c;心情&#xff1a;晴。明天终于又要放假啦&#xff01;继续学习。、 今天继续学习第九章&#xff0c;主要学习内容为&#xff1a;有源低通/高通滤波器(一阶滤波器Sallen-Key滤波器高下降…

极速体验媲美GPT4V的国产开源视觉大模型CogVLM2(赠书)

大家好&#xff0c;我是每天分享AI应用的萤火君&#xff01; 文末赠书 CogVLM2是一款视觉语言模型&#xff08;Visual Language Model&#xff09;&#xff0c;由智谱AI和清华KEG潜心打磨。这款模型是CogVLM的升级版本&#xff0c;支持高达 1344 * 1344 的图像分辨率&#xf…

matter的Commissioning(入网过程)整体流程、加密方式、通信信息结构

在Matter协议中&#xff0c;**控制器负责将新设备加入网络&#xff08;commissioning&#xff09;**的整个流程&#xff0c;这一过程包括设备的发现、验证、授权、加入Fabric&#xff0c;以及最终建立数据通信的步骤。配网完成后的数据通信过程同样遵循严格的加密方式&#xff…

冠军品质!凯伦股份又一产品荣获省级制造业单项冠军

近日&#xff0c;唐山凯伦新材料科技有限公司获得河北省工业和信息化厅颁发的“河北省制造业单项冠军”证书&#xff0c;公司生产的“抗流挂聚氨酯防水涂料”获得该项省级荣誉。 据了解&#xff0c;省级制造业单项冠军代表着河北省细分行业最高的发展水平、最强的市场实力&…

HarmonyOS开发实战( Beta5版)Stack组件实现滚动吸顶效果实现案例

介绍 本示例介绍运用Stack组件以构建多层次堆叠的视觉效果。通过绑定Scroll组件的onScroll滚动事件回调函数&#xff0c;精准捕获滚动动作的发生。当滚动时&#xff0c;实时地调节组件的透明度、高度等属性&#xff0c;从而成功实现了嵌套滚动效果、透明度动态变化以及平滑的组…

kubesphere缩短node notready后pod驱逐时长(pod-eviction-timeout无效)

本文在测试k8s高可用时会关闭某个node节点&#xff0c;然后看某些pod节点是否主动漂移到其他node节点&#xff0c;测试确实可以&#xff0c;但是时长为5分钟&#xff0c;这个时间长度项目上是不能接受的&#xff0c;比如尝试缩短这个时长&#xff0c;搜到更多的配置都是pod-evi…

Docker打包镜像

Docker打包镜像 前置工作 1.虚拟机中配置好docker环境&#xff0c;并导入nginx&#xff0c;mysql&#xff0c;jdk的镜像 2.下载docker for windows 用idea打包镜像和创建容器需要这个东西支持 下载安装包后执行&#xff0c;无脑回车即可 3.idea中配置docker连接 完成配置后&…

One-Shot Imitation Learning

发表时间&#xff1a;NIPS2017 论文链接&#xff1a;https://readpaper.com/pdf-annotate/note?pdfId4557560538297540609&noteId2424799047081637376 作者单位&#xff1a;Berkeley AI Research Lab, Work done while at OpenAI Yan Duan† , Marcin Andrychowicz ‡ ,…

上门家政系统小程序开发产品类目分析

在当今数字化时代&#xff0c;上门家政服务系统作为连接用户与家政服务供应商的重要桥梁&#xff0c;正逐步渗透到人们的日常生活中&#xff0c;为繁忙的现代人提供了极大的便利。作为一名程序员&#xff0c;我将从产品类目、技术实现及市场影响等角度&#xff0c;对上门家政系…

AI写的论文查重率高吗?分享6款实测AI论文生成免费网站

在当今学术研究和论文写作领域&#xff0c;AI技术的迅猛发展为研究人员提供了极大的便利。特别是AI论文自动生成助手&#xff0c;它们不仅能够提高写作效率&#xff0c;还能帮助生成高质量的论文内容。以下是六款经过实测且免费的AI论文生成网站推荐&#xff1a; 一、千笔-AIP…

linux离线安装nacos

1、打开 Nacos-GitHub &#xff0c;点击 Release 可以看到 Nacos 的各版本跟新信息和安装包之类的 点击下载nacos-server-2.4.1.tar.gz&#xff0c;在linux创建nacos文件夹&#xff0c;把下载好的文件上传到nacos文件夹&#xff0c;并通过命令解压:tar -zxvf nacos-server-2.4.…

CUDA统一内存:简化GPU编程的内存管理

CUDA统一内存&#xff1a;简化GPU编程的内存管理 在现代GPU编程中&#xff0c;内存管理一直是开发者面临的一个重要挑战。特别是在使用NVIDIA CUDA进行高性能计算时&#xff0c;如何在CPU和GPU之间高效地传输数据、以及如何管理这些数据的生命周期&#xff0c;都是影响程序性能…

ABAP 调试宏DEFINE

文章目录 调试过程完整程序 调试过程 完整程序 REPORT Z_TEST_DEFINE.TYPES: BEGIN OF GTY_DATA,NAME TYPE STRING,AGE TYPE I,END OF GTY_DATA. DATA: GS_DATA TYPE GTY_DATA,GT_DATA TYPE TABLE OF GTY_DATA. DEFINE D_TEST.GS_DATA-NAME &1.GS_DATA-AGE &2.APPE…

Linux基础网络编程-Socket通信

本文使用C语言&#xff0c;在Centos实现Socket两种通信类型(TCP和UDP) 文章目录 一、安装gcc二、使用TCP协议&#xff0c;实现Socket(SOCKE_STREAM)流式通信1. 编写TCP_server.c函数和参数解释 2.编写TCP_client.c函数和参数解释 3. 编译并运行上述两个文件3.1 编译3.2 运行(启…

TVS汽车级 二极管SZESD9B5.0ST5G你了解多少?专为汽车电子系统设计的瞬态电压抑制二极管

SZESD9B5.0ST5G功能特性分析&#xff1a; SZESD9B5.0ST5G用于保护电压敏感型ESD组件。优异的关断能力&#xff0c;低泄漏&#xff0c;快速响应时间为以下设计提供一流的 ESD 保护。由于体积小&#xff0c;适合在手机、MP3播放器、数码相机和许多其他便携式设备板空间非常宝贵的…

2024高教社杯数学建模竞赛解题思路

高教社杯数学建模竞赛解题思路&#xff1a;独家出版&#xff0c;思路解析模型代码结果可视化。 A题思路及程序链接&#xff1a;https://mbd.pub/o/bread/ZpqblJZs B题思路及程序链接&#xff1a;https://mbd.pub/o/bread/ZpqblJZx D题思路及程序链接&#xff1a;https://mbd.pu…

制造业中工艺路线(工序)与产线(工作中心)关系

一.工艺路线与生产线是数字孪生中的虚实关系&#xff1a; 1.工艺路线为虚&#xff0c;生产线体为实&#xff1b; 2.工艺路线指导生产线的生产组织&#xff0c;生产线承载工艺路线的能力&#xff0c;把虚拟的生产信息流变成真实的产流。 二.工艺路线与生产线是数字孪生中互为“…

nginx中如何设置gzip

前言 Nginx通过配置gzip压缩可以提升网站整体速度 Nginx的gzip功能是用于压缩HTTP响应内容的功能。当启用gzip时&#xff0c;在发送给客户端之前&#xff0c;Nginx会将响应内容压缩以减小其大小。这样可以减少数据传输的带宽消耗和响应时间&#xff0c;提高网站的性能和速度。…