Neural Network学习笔记3

news2024/11/18 5:56:59

损失函数和反向传播网络

在进行损失函数计算后,再进行.backward()反向传播。

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("dataset_transform",train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=1)


class Zrf(nn.Module):
    def __init__(self):
        super(Zrf, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

# 分类问题适合用交叉熵
loss = nn.CrossEntropyLoss()
zrf = Zrf()
for data in dataloader:
    imgs, targets = data
    outputs = zrf(imgs)
    # print(outputs)
    # print(targets)
    result_loss = loss(outputs, targets)
    result_loss.backward()

优化器

  以Adadelta为例,torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)

params: 模型的参数,让优化器知道我们的模型长什么样子。

lr: Learning rate, 学习率

其他的参数可以采用默认,并且优化算法不同,参数也会有很大不同。

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("dataset_transform",train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=1)


class Zrf(nn.Module):
    def __init__(self):
        super(Zrf, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

# 分类问题适合用交叉熵
loss = nn.CrossEntropyLoss()
zrf = Zrf()
# 设置优化器
# SGD 随机梯度下降
# lr的设不可以太大也不可以太小,一般情况下我们采用训练开始时lr大,之后的训练中lr小的方式
optim = torch.optim.SGD(zrf.parameters(), lr=0.01)
for epoch in range(20) :
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = zrf(imgs)
        result_loss = loss(outputs, targets)
        # 在进行反向传播来计算梯度时,要先将梯度置为0,防止之前计算出来的梯度的影响
        optim.zero_grad()
        result_loss.backward()
        # 根据梯度对卷积核参数进行调优
        optim.step()
        running_loss = running_loss + result_loss
    print(running_loss)

现有网络模型的使用及修改

torchvision.modles中的VGG为例。VGG常用VGG16和VGG19。

weights: 可选,要使用的预训练权重。默认情况下,不使用预先训练的权重。

progress: true时,会展示一个进度条

此外,pytorch在下载模型时会把模型下载到C盘,下面语句可以修改下载位置:

os.environ['TORCH_HOME'] = '/path/to/torch_home'

import torchvision
from torch import nn
from torchvision.models import VGG16_Weights
import os


# train_data = torchvision.datasets.ImageNet(root="data_image_net", split="train", download=True,
#                                            transform=torchvision.transforms.ToTensor())

# 最新版本默认是没有预训练的,需要使用预训练设置weights='DEFAULT'

os.environ['TORCH_HOME'] = '/path/to/torch_home'

vgg16_noPre = torchvision.models.vgg16()
vgg16_pre = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)
print(vgg16_pre)

# 微调网络模型
train_data = torchvision.datasets.CIFAR10("dataset_transform",train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 添加一层
# 在vgg整体层面上加
vgg16_pre.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_pre)
# 只在某一部分加(classifier部分)
vgg16_pre.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_pre)

# 修改
vgg16_noPre.classifier[6] = nn.Linear(4096, 10)
print(vgg16_noPre)

网络模型的保存与读取

保存方法演示:model_save.py

import torch
import torchvision
from torch import nn

# 使用未经过训练的,初始化的参数
vgg16 = torchvision.models.vgg16()
# 保存方式1
# 这样不仅保存了网络模型的结构,也保存了网络模型的参数
torch.save(vgg16, "vgg16_method1.pth")

# 保存方式2
# 不保存网络的结构,只是把网络的参数保存成数据字典,也就是保存了网络的状态
# (官方推荐!!!)占用空间更小
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# 陷阱1
class Zrf(nn.Module):
    def __init__(self):
        super(Zrf, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

zrf = Zrf()
torch.save(zrf, "zrf_method1.pth")

读取方法演示:model_load.py

import torch
import torchvision
from torch import nn

# 方式1 ----> 对应保存方式1,来加载模型
model = torch.load("vgg16_method1.pth")
# print(model)

# 方式2 ---> 对应保存方式2
vgg16 = torchvision.models.vgg16() # 新建网络模型结构
model = torch.load("vgg16_method2.pth")
# print(model)
vgg16.load_state_dict(model) # 加载网络模型的状态
print(vgg16)

# 陷阱1
# 在使用自己创建的网络时,注意要有网络的这个类在本文件(程序可以访问到网络),只是不需要zrf = Zrf()再创建网络了
# 方法1
# class Zrf(nn.Module):
#     def __init__(self):
#         super(Zrf, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
#
#     def forward(self, x):
#         x = self.conv1(x)
#         return x
# 方法2
from model_save import *

model = torch.load("zrf_method1.pth")
print(model)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/611509.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

外贸人注意!这件事不能再对客户承诺了!

你还在配合客户低开发票吗? 本文目录: 什么是低开发票? 低开发票有什么风险? 哪些国家客户喜欢低开发票? 哪些国家低开发票会被抓? 很多人认为客户索要低开发票偷税漏税是人之常情。为了加强合作关系&a…

2015 年一月联考逻辑真题

2015 年一月联考逻辑真题 真题(2015-26) 26.晴朗的夜晚我们可以看到满天星斗,其中有些是自身发光的恒星,有些是自身不发光但可以反射附近恒星光的行星。恒星尽管遥远,但是有些可以被现有的光学望远镜“看到”。和恒星不…

YOLOv8训练参数详解

全部参数表 首先罗列一下官网提供的全部参数。 1. model ✰✰✰✰✰ model: 模型文件的路径。这个参数指定了所使用的模型文件的位置,例如 yolov8n.pt 或 yolov8n.yaml。 选择.pt和.yaml的区别 若我们选择 yolov8n.pt这种.pt类型的文件,其实里面是包…

从0到1实现IOC

一、什么是 IOC 我们先来看看spring官网对IOC的定义: IoC is also known as dependency injection (DI). It is a process whereby objects define their dependencies, that is, the other objects they work with, only through constructor arguments, argumen…

阿里工程师手打的MySQL学习笔记,轻松拿捏MySQL

我们都知道阿里经历过几次重大的技术变革,其中就包括放弃Oracle和Hadoop,全面拥抱MySQL。 讲道理其实靠OracleHadoop也能撑一撑,为啥偏得变。当然肯定不是因为阿里爸爸没钱,而是即便再花个几千万还是不能彻底解决问题&#xff0c…

压力测试遭遇大量TIME_WITE之后(这样解决)

前言:http协议是互联网中最常使用的应用层协议,它的绝大多数实现是基于TCP协议的。 目录 一 问题描述 二 问题跟踪 三 跟进分析 四 解决方法 一、问题描述 某天,在对一个提供http接口的后台服务进行压力测试过程中,我们设定了…

科班出身又如何?这类人连外包都不要...

在软件测试这个领域,多数人对于外包公司是有戴有色眼镜看待的,外包测试员往往会处于一个比较尴尬的局面。主要是由于雇主公司比较核心或者底层的东西是不会让外包人员作的。外包人员一般做的都是“边角料”。而这些活往往对于技术要求不高,所…

python接口自动化 —— 接口测试工具介绍(详解)

简介 “工欲善其事必先利其器”,通过前边几篇文章的介绍,大家大致对接口有了进一步的认识。那么接下来让我们看看接口测试的工具有哪些。 目前,市场上有很多支持接口测试的工具。利用工具进行接口测试,能够提供测试效率。例如&…

运维小白必学篇之基础篇第十集:系统启动流程实验

系统启动流程实验 实验作业: 1、现在有一台服务器因为长时间不使用,管理员密码已经丢失,现在想要启动该服务器,如何操作 第一步:开启系统,在GRUB界面按E进行编辑 在linux16行中centos/swap后添加 rd.break参…

【AI】InsCode AI 创作助手 --使用心得

CSDN AI写作助手上线了!InsCode AI 创作助手不仅能够帮助用户高效创作文章,而且能够作为对话式AI回答你想知道的问题。成倍提高生产力!欢迎大家使用新功能后分享自己的使用心得与建议! 文章目录 一、你平时会使用这类AI工具吗&am…

navicat与SQLyog的区别

在之前的学习中由于先学的SQL Server,后来才学的MySQL,导致我刚学习的时候冥冥之中感觉到那有点不对劲,但是又说不出来。通过进行深入的学习解除到了Navicat Premium和SQLyog这两个工具,才让我明白了MySQL与之前学习的内容是有所出…

影响代理ip纯净度的原因及目标网站如何识别代理ip

网络上代理ip很多,但真正可以为我们所用的大部分都是付费ip,那为什么免费ip不能为我们所用呢?下面我们就纯净度和目标网站是如何识别代理ip来分析一下。 一、纯净度 ip纯净度是什么意思呢?简单一点开始就是指使用这个ip的人少&…

教你 5 分钟快速部署开源网关

最近在研究开源网关,找了一圈,发现这个叫 Apinto 的开源网关符合我的需求,下面我将演示如何部署这样一个开源网关。 Apinto功能架构图 开始部署 部署资源 设备推荐配置设备数量部署对象4核8G,250G磁盘空间,2.5GHz1控制…

SQL-基础

SQL-小基础 1 SQL简介 英文:Structured Query Language,简称 SQL结构化查询语言,一门操作关系型数据库的编程语言定义操作所有关系型数据库的统一标准对于同一个需求,每一种数据库操作的方式可能会存在一些不一样的地方&#xff…

RPC接口测试技术-Tcp 协议的接口测试

【摘要】 首先明确 Tcp 的概念,针对 Tcp 协议进行接口测试,是指基于 Tcp 协议的上层协议比如 Http ,串口,网口, Socket 等。这些协议与 Http 测试方法类似(具体查看接口自动化测试章节)&#xf…

Ingress Controller高可用部署

Ingress-controller 高可用解说 Ingress Controller 是集群流量的接入层,对它做高可用非常重要,可以基于 keepalive 实现 nginx-ingress-controller 高可用,具体实现如下: Ingress-controller 根据 Deployment nodeSeletorpod 反…

揭秘广告投放的9大关键环节,了解真相让你成为广告投放高手!

正式开始本章的内容之前,先来简单复习一下上一章的主要内容: 核心要点1:广告投放的意义主要有三点:传播品牌、宣传产品、促成转化; 核心要点2:广告投放的主要流程有这样 9 个阶段: 本章我们以…

OpenCV(图像处理)-基于Python-图像的基本变换-平移-翻转-仿射变换-透视变换

1. 概述2. 接口介绍resize()flip()rotate()仿射变换warpAffine()getRotationMatrix2D()-变换矩阵1getAffineTransform()-变换矩阵2 透视变换warpPerspective()getPerspectiveTransform() 1. 概述 为了方便开发人员的操作,OpenCV还提供了一些图像变换的API&#xff…

Qt 去除标题栏不同方法不同平台差异探究

Qt 版本:Qt 6.5.0 Windows 11 当窗体为QWidget时 setWindowFlags(Qt::FramelessWindowHint);// 窗口不能缩放setWindowFlags(Qt::CustomizeWindowHint);// 窗口支持缩放,且窗体四角为圆角CustomizeWindowHintFramelessWindowHint 当窗体为QMainWindow时…

什么是第三方付费模式?用“尤伯罗斯模式”让你的商品由别人买单

什么是第三方付费模式?用“尤伯罗斯模式”让你的商品由别人买单 微三云营销策划胡总监给大家介绍一下,什么是第三方付费模式? 当同质化产品日趋严重的时候,改变客户接受产品价值及服务的模式创新就是商业模式的创新,以…