深度学习快速入门----Pytorch 系列2

news2025/1/16 15:03:41

注:参考B站‘小土堆’视频教程

视频链接:【PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

上一篇:深度学习快速入门----Pytorch 1


文章目录

    • 八、神经网络--非线性激活
    • 九、神经网络--线性层及其他层介绍
    • 十、神经网络--全连接层Sequential
    • 十一、损失函数与反向传播
    • 十二、优化器
    • 十三、现有网络模型的使用及修改
    • 十四、网络模型的保存与读取

八、神经网络–非线性激活

1、ReLU
在这里插入图片描述

在这里插入图片描述
2、Sigmoid
在这里插入图片描述

在这里插入图片描述
使用sigmoid函数:

import torch
import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

input = torch.tensor([[1, -0.5],
                      [-1, 3]])

input = torch.reshape(input, (-1, 1, 2, 2))
print(input.shape)

dataset = torchvision.datasets.CIFAR10("dataset", train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())

dataloader = DataLoader(dataset, batch_size=64)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.relu1 = ReLU()
        self.sigmoid1 = Sigmoid()

    def forward(self, input):
        output = self.sigmoid1(input)
        return output

tudui = Tudui()

writer = SummaryWriter("logs_relu")
step = 0
for data in dataloader:
    imgs, targets = data
    writer.add_images("input", imgs, global_step=step)
    output = tudui(imgs)
    writer.add_images("output", output, step)
    step += 1

writer.close()

运行结果:
在这里插入图片描述
在这里插入图片描述

九、神经网络–线性层及其他层介绍

import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=64)

# torch.Size([64,3,32,32]) --> torch.Size([1,1,1,196608]) --> torch.Size([1,1,1,10])

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.linear1 = Linear(196608, 10)

    def forward(self, input):
        output = self.linear1(input)
        return output

tudui = Tudui()

for data in dataloader:
    imgs, targets = data
    print(imgs.shape)
    # output = torch.reshape(imgs,(1,1,1,-1))
    # 将图片展平
    output = torch.flatten(imgs)
    print(output.shape)
    output = tudui(output)
    print(output.shape)

运行结果:

torch.Size([64,3,32,32]) --> torch.Size([1,1,1,196608]) --> torch.Size([1,1,1,10])

在这里插入图片描述

十、神经网络–全连接层Sequential

Pytorch官方文档—Conv2d
在这里插入图片描述
在这里插入图片描述

***cifar10 model structure***

在这里插入图片描述

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            # in_channels=3, out_channels=32, kernel_size=5, padding需要根据公式计算
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            # 展平
            Flatten(),
            # 全连接
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

tudui = Tudui()
print(tudui)
input = torch.ones((64, 3, 32, 32))
output = tudui(input)
print(output.shape)

writer = SummaryWriter("logs_seq")
writer.add_graph(tudui, input)
writer.close()

运行结果:

在这里插入图片描述
可视化结果:
在这里插入图片描述

十一、损失函数与反向传播

L1Loss & MSELoss & CrossEntropyLoss

import torch
from torch.nn import L1Loss
from torch import nn

inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)

inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))

loss = L1Loss(reduction='sum')
result = loss(inputs, targets)

# 平方差
loss_mse = nn.MSELoss()
result_mse = loss_mse(inputs, targets)

print(result)
print(result_mse)


# 计算交叉熵
x = torch.tensor([0.1, 0.2, 0.3])
y = torch.tensor([1])
x = torch.reshape(x, (1, 3))
loss_cross = nn.CrossEntropyLoss()
result_cross = loss_cross(x, y)
print(result_cross)

运行结果:

在这里插入图片描述

损失函数作用: 1、计算实际输出与目标之间的差距 2、为我们更新输出提供一定的依据(反向传播)

import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    # print(outputs)
    # print(targets)
    result_loss = loss(outputs, targets)
    print(result_loss)

outputs 与 targets 输出:

在这里插入图片描述
result_loss:

在这里插入图片描述

十二、优化器

import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)

dataloader = DataLoader(dataset, batch_size=1)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


loss = nn.CrossEntropyLoss()
tudui = Tudui()
# 优化器
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)

for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        # 将梯度清0
        optim.zero_grad()
        result_loss.backward()
        # 对网络进行调优
        optim.step()
        running_loss = running_loss + result_loss
    print(running_loss)

在这里插入图片描述

十三、现有网络模型的使用及修改

VGG16输出有1000个类别

在这里插入图片描述

VGG网络用ImageNet数据集来训练,但是该数据集太大;改成用cifar10数据集来进行,于是需要改动VGG网络结构

import torchvision

# ImageNet数据集太大
# train_data = torchvision.datasets.ImageNet("../data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())
from torch import nn
# pretrained=False表示使用初始化的参数,没有经过数据集训练
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

# print(vgg16_true)

train_data = torchvision.datasets.CIFAR10('dataset', train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)

# 方法1、修改网络模型
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
# print(vgg16_true)

# 方法2、直接改为输出10类
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

方法1、修改网络模型

在这里插入图片描述
方法2、直接改为输出10类

在这里插入图片描述

十四、网络模型的保存与读取

# model_save.py
import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

# 保存方式2,模型参数(官方推荐)将vgg16的状态保存为字典形式
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# 陷阱
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

tudui = Tudui()
torch.save(tudui, "tudui_method1.pth")
# model_load.py
import torch
from model_save import *
# 方式1-》保存方式1,加载模型
import torchvision
from torch import nn

model = torch.load("vgg16_method1.pth")
# print(model)

# 方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# model = torch.load("vgg16_method2.pth")
# print(model)

# 陷阱1
# class Tudui(nn.Module):
#     def __init__(self):
#         super(Tudui, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
#
#     def forward(self, x):
#         x = self.conv1(x)
#         return x

# 采用方法1,需要使程序能够访问到自定义的模型,不然会报错
model = torch.load('tudui_method1.pth')
print(model)

保存方式1:模型结构+模型参数

在这里插入图片描述

保存方式2:模型参数(官方推荐)将vgg16的状态保存为字典形式

在这里插入图片描述

自定义的网络模型:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/44836.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

作为IT行业过来人,我有3个重要建议给后辈程序员!

见字如面,我是军哥!作为一名 40 岁的 IT 老兵,我在年轻时踩了不少坑,我总结了其中最重要的 3 个并一次性分享给你,文章不长,你一定要看完哈~1、重视基础还不够,还要注重技术广度和深…

第2-4-8章 规则引擎Drools实战(1)-个人所得税计算器

文章目录9. Drools实战9.1 个人所得税计算器9.1.1 名词解释9.1.2 计算规则9.1.2.1 新税制主要有哪些变化?9.1.2.2 资较高人员本次个税较少,可能到年底扣税增加?9.1.2.3 关于年度汇算清缴9.1.2.4 个人所得税预扣率表(居民个人工资、…

科教导刊杂志科教导刊杂志社科教导刊编辑部2022年第27期目录

前沿视角《科教导刊》投稿:cn7kantougao163.com 新时代研究生教育质量评价指标体系的框架构建 李军伟;赵永克;杨丹; 1-3 基于现代学徒制的“多主体、双标准、五维度”人才培养质量评价体系构建 汪帆;刘严; 4-6 高教论坛 新工科背景下地方性院校第二课堂…

【云原生】Docker容器服务更新与发现之consul

内容预知 1.consul的相关知识 1.1 什么是注册与发现 1.2 什么是consul 1.3 zookeeper和consul的区别 2. consul 部署 2.1 部署consul服务器 2.2 registrator服务器 3.consul-template 的引入 3.1 consul-template的作用 3.2 consul-template的具体部署运用 &…

微信开发者工具C盘占用大的问题

将User Data 下的文件迁移到其他盘,比如 D盘,E盘,F盘 步骤如下: 1.找到微信开发者工具C盘所在的缓存目录,一般为 C:\Users\ 你的用户名\AppData\Local\微信开发者工具\User Data 将里面的内容全部剪切到其它盘符&…

从鹅厂实例出发!分析Go Channel底层原理

本文是基于Go1.18.1源码的学习笔记。Channel的底层源码从Go1.14到现在的Go1.19之间几乎没有变化,这也是Go最早引入的组件之一,体现了Go并发思想:Do not communicate by sharing memory; instead, share memory by communicating.不要通过共享…

Playwright 简明入门教程:录制自动化测试用例,结合 Docker 使用

本篇文章聊聊如何使用 Playwright 进行测试用例的录制生成,以及如何在Docker 容器运行测试用例,或许是网上最简单的入门教程。 写在前面 Playwright 是微软出品的 Web 自动化测试工具和框架,和 Google Puppeteer 有着千丝万缕的关系。前一阵…

[附源码]计算机毕业设计springboot儿童早教课程管理系统论文2022

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

如何治理 Electron 版本淘宝直播应用崩溃?

经过几个月的努力,基于Electron框架开发的新版淘宝直播推流软件终于上线了。随之而来的就是线上用户反馈的各种问题,其中最影响用户体验的当属应用崩溃问题了。当应用程序出现未 catch 的异常时就会发生崩溃,本文介绍了客户端应用崩溃的处理流…

javaSE - Arrays - 数组的定义与使用

一、数组基本用法 1.1、什么是数组 数组本质上就是让我们能 “批量” 创建相同类型的变量 也可以说是存储一组相同数据类型的数据的集合 如: 如果需要表示两个数据, 那么直接创建两个变量即可 int a; int b 如果需要表示五个数据, 那么可以创建五个变量 int a1; int a2; int …

uni-app入门:自定义组件实现父子组件参数传递

1.属性绑定:父组件传递参数到子组件 2.事件绑定:子组件传递参数到父组件 3.获取组件对象实例:父组件获取子组件实例对象进行参数传递 1.属性绑定:父组件传递参数到子组件首先交代一下基本的项目信息:主页面为index.wxml,创建test子组件,文件目录:component/test/test…

惠柏新材创业板IPO过会:上半年营收9.3亿 拟募资3.4亿

雷递网 雷建平 11月28日惠柏新材料科技(上海)股份有限公司(简称:“惠柏新材”)日前IPO过会,准备在深交所创业板上市。惠柏新材计划募资3.42亿元,其中,1.8亿元用于上海帝福3.7万吨纤维…

xss-labs/level9

这一关界面感觉跟上一关很像 所以我们注入上一关的为编码的答案 javascript:alert(xss) 没能弹窗 查看源代码 他说我输入的链接不合法 我压根没有输入链接 我觉得后台应该是做了一个条件的判断 应该是要有链接才会在第二处输出点回显我们的输入 根据上面的猜测 我们构造如下…

网络的根基

hi 大家好,上个周末带小伙伴,一起复习了一遍网络协议,对网络协议的核心知识进行梳理,希望大家早日掌握这些核心知识,打造自己坚实的基础,为自己目标慢慢积累,等到自己春天的到来。详细点击查看…

设计模式学习笔记

文章目录23种设计模式学习笔记1.创建型模式1 单例模式2 工厂模式3 抽象工厂模式4 建造者模式5 原型模式2.结构型模式6 代理模式7 适配器模式8 桥接模式9 装饰模式10 外观模式11 组合模式12 享元模式3.行为型模式13 策略模式14 观察者模式15 责任链模式16 模板模式17 状态模式18…

Maven程序 tomcat插件安装与web工程启动

第一步:在mvnrepository库中找到tomcat插件 1.打开mvnrepository官网,搜索“tomcat maven”向下滑动找到“org.apache.tomcat.maven ”点进去 2.在这里点第一个“Apache Tomcat Maven Plugin :: Tomcat 7.x” 3.在这里选择2.1版本相对来说比较稳定 4.复…

jsp393学生宿舍管理系统mysql

两个权限 管理员和 学生 1. 学生信息管理 添加学生信息(学生号,姓名 院系 班级入学日期 )修改学生信息 学生退宿舍(可以删除指定的学生也可以成批删除) 2. 宿舍信息管理 宿舍的基本信息(公寓数 宿舍…

DSP-数字滤波器的结构

目录 基本结构块: 例: 一些特殊结构: 无延时回路问题: 规范和非规范结构: 等效结构: FIR滤波器的基本结构 : 直接型: 级联型: 多相型: 线性相位FIR结构: 基本IIR滤波器结…

模拟可执行的四旋翼模型——在未知环境下运动规划应用研究(Matlab代码实现)

1 概述 无人机现在利用最佳搜索策略,使用PRISM模型检查器生成,以寻找目标。本文设计并编写了一种对抗性模式搜索算法来比较性能。 四旋翼无人机由于具有可悬停,可垂直起降,在设计速度范围内向任意方向飞行的运动特点,以及结构简单,构造容易,成本低廉等…

解析异常SAXParseExceptionis如何处理

1.问题背景 今天一位同事找我寻求帮助,售后向他反馈的问题不知道如何排查,他尝试分析服务器端日志文件, 但是日志文件中并没有报错信息,查询源码时候发现,报错信息被try...catch处理 2.排查过程 顺便提一句&#xff…