【pytorch框架】对模型知识的基本了解

news2025/1/23 15:04:47

文章目录

    • TensorBoard的使用
      • 1、TensorBoard启动:
      • 2、使用TensorBoard查看一张图片
      • 3、transforms的使用
    • pytorch框架基础知识
      • 1 nn.module的使用
      • 2 nn.conv2d的使用
      • 3、池化(MaxPool2d)
      • 4 非线性激活
      • 5 线性层
      • 6 Sequential的使用
      • 7 损失函数与反向传播
      • 8 优化器
      • 9 对现有网络的使用和修改
      • 10 网络模型的保存与读取

TensorBoard的使用

1、TensorBoard启动:

在Terminal终端命令中输入:

tensorboard --logdir=logs #logs为创建的文件名

2、使用TensorBoard查看一张图片

writer=SummaryWriter("../logs")
image_path=r'F:\image\1.jpg'
img_PIL=Image.open(image_path)
image_array=np.array(img_PIL)
writer.add_image('test',image_array,1,dataformats='HWC')
writer.close()

3、transforms的使用

作用:使PIL Image 或者np ——》tensor

imgae_path=r'F:\image\1.jpg'
img=Image.open(img_path)
tensor_trans=transsforms.ToTensor()  #相当于创建一个工具
tensor_img=tensor_trans(img) #img转化成tensor模式

同理,ToPILIMage是为了tensor 或者 ndarray =》Image

pytorch框架基础知识

1 nn.module的使用

目的:给所有网络提供基本骨架
在这里插入图片描述

from torch import nn
class aiy(nn.Module):

    def  __init__(self):
        super().__init__()
        
    def forward(sel,input):
        output=input+1
        return output

aiy=aiy()
# x=torch.tensor(1.0)
x=1
output=aiy(x)
print(output)
'''
2
'''

2 nn.conv2d的使用

参数代码解释如下:
在这里插入图片描述
示例:输入一个5x5的矩阵,和一个3x3的卷积核做卷积操作

import torch
import torch.nn.functional as F
input=torch.tensor([[1,2,0,3,1],
                    [0,1,2,3,1],
                    [1,2,1,0,0],
                    [5,2,3,1,1],
                    [2,1,0,1,1]])
input=input.reshape([1,1,5,5])

kears=torch.tensor([[1,2,1],
                    [0,1,0],
                    [2,1,0]])
kears=kears.reshape([1,1,3,3])

output=F.conv2d(input,kears,stride=1)

print(output)
print(output.shape)
'''
tensor([[[[10, 12, 12], 
          [18, 16, 16],
          [13,  9,  3]]]])
torch.Size([1, 1, 3, 3])
'''

若是输入的卷积核的数量有两个,则得到的output也是两个
在这里插入图片描述
示例:借用CIFAR10数据集,用自定义的网络模型做一次卷积操作,然后用tensorboard查看卷积之后的结果。
这里需要注意的是,经过卷积得到的大小是[64,6,30,30],而图片的通道一般都是3通道的,6通道的图片不知道怎么显示,需要使用reshpae重新改变矩阵的大小。

output=output.reshape([-1,3,30,30]) #-1自动计算剩余的值,后面[3,30,30]改成指定大小

示例代码:

import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
from torch.nn import Conv2d
from torch.utils.tensorboard import SummaryWriter

#数据准备
dataset=torchvision.datasets.CIFAR10("./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)

dataloader=DataLoader(dataset,batch_size=64)

#自定义网络模型
class aiy(nn.Module):
    def __init__(self):
        super(aiy, self).__init__()
        #卷积运算
        self.conv1=Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)

    def forward(self,x):
        x=self.conv1(x)
        return x

aiy=aiy()
# print(aiy)
step=0
writer=SummaryWriter("../log")

for data in dataloader:
    img,targets=data
    output=aiy(img)
    # print(img.shape)
    #torch.Size([64, 3, 32, 32]
    # print(output.shape)
    #torch.Size([64, 6, 30, 30])
    #因为图片的通道是3,需要改变矩阵的大小
    # output=output.reshape([-1,3,30,30])
    writer.add_images("input",img,step)

    output=torch.reshape(output,(-1,3,30,30))
    writer.add_images("output", output, step)
    # print(output.shape)
    step=step+1


print(step)
writer.close()

在这里插入图片描述

3、池化(MaxPool2d)

目的:降采样,大幅减少网络的参数量,同时保留图像数据的特征。
需要注意的是: 池化不改变通道数

池化参数
在这里插入图片描述
数组演示示例:
在这里插入图片描述

input=torch.tensor([[1,2,0,3,1],
                    [0,1,2,3,1],
                    [1,2,1,0,0],
                    [5,2,3,1,1],
                    [2,1,0,1,1]],dtype=float)
input=torch.reshape(input,(-1,1,5,5))

output=aiy(input)
print(output.shape)
'''
ceil_mode=True:
tensor([[[[2., 3.],
          [5., 1.]]]], dtype=torch.float64)

ceil_mode=False:
tensor([[[[2.]]]], dtype=torch.float64)
'''

示例:同样,借用CIFAR10数据集,用自定义的网络模型做一次池化操作,然后用tensorboard查看卷积之后的结果。

# -*- coding: utf-8 -*-
# Auter:我菜就爱学

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d

#带入数组
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

class aiy(nn.Module):
    def __init__(self):
        super(aiy, self).__init__()
        self.maxpool1=MaxPool2d(kernel_size=3,ceil_mode=False)

    def forward(self,input):
        output=self.maxpool1(input)
        return output

aiy=aiy()

#将池化层用数据集测试
dataset=torchvision.datasets.CIFAR10('./data',train=False,transform=torchvision.transforms.ToTensor(),download=True)

dataloader=DataLoader(dataset,batch_size=64)

step=0
writer=SummaryWriter("../logmaxpool")

for data in dataloader:
    img,target=data
    writer.add_images("input",img,step)
    output=aiy(img)
    writer.add_images("output",output,step)
    step=step+1

有点像打马赛克
在这里插入图片描述

4 非线性激活

作用:提高泛化能力,引入非线性特征
在这里插入图片描述

ReLu(input,inplace=True)  
=>表示原input替换input
out=ReLu(input,inplace=False)
=>表示原input被out替换

5 线性层

在这里插入图片描述

6 Sequential的使用

作用:可以简化自己搭建的网络模型
示例:参考CIFAR10的网络模型结构,创建一个网络。

在这里插入图片描述

# -*- coding: utf-8 -*-
# Auter:我菜就爱学

import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.tensorboard import SummaryWriter


class Aiy(nn.Module):
    def __init__(self):
        super(Aiy, self).__init__()
        self.model=Sequential(
            Conv2d(3,32,5,padding=2,stride=1),
            MaxPool2d(kernel_size=2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )

    def forward(self,input):
        output=self.model(input)
        return output

aiy=Aiy()

# print(aiy)

input=torch.ones((64,3,32,32))

output=aiy(input)

print(output.shape)

使用tensorboard中的命令可以查看网络模型结构

writer=SummaryWriter('../logmodel')

writer.add_graph(aiy,input)

writer.close()

在这里插入图片描述

7 损失函数与反向传播

作用:

  • 计算处实际输出与目标之间的差距
  • 更新输出提供一定的依据

通过小土堆举的示例可以很好的理解损失函数
在这里插入图片描述

说明:假设一张试卷满分是100分,其中选择30,填空20,解答50.第一次我们得到的结果是:选择10,填空10,解答20.第一次损失值是60.
然后通过不断的训练,让选择提高到20,填空提高20,解答提高到40,这个时候与满分差距20,损失值也就越来越小。

# -*- coding: utf-8 -*-
# Auter:我菜就爱学
import torch
from torch.nn import L1Loss

input=torch.tensor([1,2,3],dtype=torch.float32)
input=torch.reshape(input,(1,1,1,3))
target=torch.tensor([1,2,5],dtype=torch.float32)
target=torch.reshape(target,(1,1,1,3))

#设置一个损失函数
loss=L1Loss(reduction='sum')
output=loss(input,target)
print(output)
'''
tensor(2.)
'''

8 优化器

优化器参数解释:
在这里插入图片描述

for epoch in range (20):
    sum_loss=0.0
    for data in dataloader:
        imgs,targets=data
        output=aiy(imgs)
        result_loss=loss(output,targets)
        optim.zero_grad()
        result_loss.backward() #反向传播,更新对应的梯度
        optim.step()  #调整更新的参数
        sum_loss=sum_loss+result_loss
    print(sum_loss)

下面是对优化器中的交叉熵的解释:
在这里插入图片描述

9 对现有网络的使用和修改

  • 下载现有网络,并使用数据集更新好的参数
vgg16_True=torchvision models vgg16(pretrained=True)

一般下载好的模型保存路径:==C:\user.cache\torch\hub\checkpoints

  • 在已有的网络模型中新添自己需要的层
vgg16_True.classifier.add_module("7",nn.Linear(1000,10))

10 网络模型的保存与读取

方法一:直接把模型和参数保存下来
注意: 有 一个陷阱,自定义的模型在下载的时候运行会报错,得需要复制下载原模型。只能导入专门经典的模型

#保存
torch.save(vgg16_true,"vgg16_method1.pth")

#下载
model=torch.load("vgg16_method1.pth")

方法二:保存模型的参数,一般使用这个。内存比较小,节省空间;以字典的形式保存。

#保存
torch.save(vgg16_true.state_dict(),"vgg16_method1.pth")

#下载
vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_false.load_state_dict(torch.long("wgg166_method2.pth"))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/346747.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flink X Hologres 构建企业级 Streaming Warehouse

摘要:本文整理自阿里云资深技术专家,阿里云 Hologres 负责人姜伟华(果贝),在 FFA 实时湖仓专场的分享。本篇内容主要分为四个部分:实时数仓分层的技术需求阿里云一站式实时数仓 Hologres 介绍Flink x Holog…

30个题型+代码(冲刺2023蓝桥杯)

愿意的可以跟我一起刷,每个类型做1~5题 ,4月前还可以回来系统复习 2月13日 ~ 3月28日,一共32天 一个月时间,0基础省三 --> 省二;基础好点的,省二 --> 省一 目录 🌼前言 &#x1f33c…

1.1配置单区域OSPF

实验1:配置单区域OSPF[1] 1.实验目的 实现单区域OSPF的配置描述OSPF在多路访问网络中邻居关系建立的过程2.实验拓扑 单区域的OSPF实验拓扑如图1-2所示。 图1-2 配置单区域OSPF 3.实验步骤 IP地址的配置[2] R1的配置

Framebuffer驱动程序框架

Framebuffer驱动程序框架 文章目录Framebuffer驱动程序框架一、 怎么编写字符设备驱动程序二、 Framebuffer驱动程序框架三、 怎么编写Framebuffer驱动程序致谢一、 怎么编写字符设备驱动程序 驱动主设备号构造file_operations结构体,填充open/read/write等成员函数…

获取本机的IP地址,看似简单的获取,实则蕴含非常多的操作

这篇文章讲述了PowerJob获取本地IP离奇曲折的经过,以及开放了诸多的可配置参数,打开了我新世界的大窗户。求个关注,求个点赞,求一个评论。 获取地址的操作,本来不应该作为什么重点,但是因为一点小小的意外&…

再创荣誉 | Softing工业荣获CAIMRS 2023 数字化创新奖

在刚刚结束的中国工控-第二十一届“自动化及数字化”年度评选(CAIMRS 2023)中,Softing凭借edgeAggregator产品荣获“数字化创新奖”! 经层层筛选,Softing edgeAggregator边缘聚合服务器从中脱颖而出,摘得C…

隐马尔科夫模型基础

一、定义是一种生成模型,是隐藏的马尔科夫链随机生成不可观测的状态序列,再由各个状态生成观测序列的过程二、符号含义其中:Q是所有可能的状态集合V是所有的可能的观测集合N是可能的状态数,M是可能的观测数其中:I是长度…

想搞钱,先培养商业思维!

昨天谈借助 ChatGPT 挣点房贷钱的时候,看评论区大家留言的时候,发现很多人不知道这个东西可以赚钱,或者说知道这个东西且也做了功课,无奈太忙最后也没搞到钱。可以看到,大家的问题,归根于自己有没有商业思维…

小白系列Vite-Vue3-TypeScript:010-封装svg

上一篇我们介绍了ViteVue3TypeScript项目中mockjs的安装和配置i。本篇我们来介绍封装SVG图标组件。svg特征Preloading所有图标都是在项目运行时生成的,只需要操作一次dom即可。高性能内置缓存,仅在文件被修改时才会重新生成。安装插件vite-plugin-svg-ic…

QHash源码解读

QT版本 v5.12.10 元素 // 重点说明QHashData的函数,QHashData是QHash的基础 struct QHashData {struct Node {Node *next;uint h;};Node *fakeNext; // 永为nullNode **buckets; // Node *数组QtPrivate::RefCount ref;int size; // node个数int nodeSize; /…

koa2结合MySQL实现简单的考试系统

项目需求:1. 数据库采用mysql实现, 后台服务Koa2框架, 通过postman调试所有接口2. 接口功能&#xff1a;(1)实现对科目表的增删改查(2)实现对试题表的增删改(3)实现对试题表的查询操作&#xff0c;要求&#xff1a;<1>显示科目名称和类型名称 <2> 可以按照科目名称…

第一章 - 对数据库和SQL的简单了解

第一章 - 对数据库和SQL的简单了解1 了解数据库&#xff1a;2 什么是数据库&#xff1a;3 什么是SQL&#xff1a;4 SQL的优点&#xff1a;5 数据库的一些常用术语&#xff1a;6 什么是MySQL&#xff1a;1 了解数据库&#xff1a; 其实你一直都在使用数据库&#xff0c;只是你并…

【观察】从消费级SSD AM6A1,看忆联的优势与胜势

毫无疑问&#xff0c;目前SSD&#xff08;固态硬盘&#xff09;已取代HDD&#xff08;机械硬盘&#xff09;成为电脑中常见的存储设备&#xff0c;特别是在技术创新的持续推动下&#xff0c;如今SSD的速度和效率都在不断地提高&#xff0c;从SATA2 3GB发展到SATA3 6GB&#xff…

四、常用样式讲解二

文章目录一、常用样式讲解二1.1 元素隐藏1.2 二级菜单1.3 相对定位和绝对定位1.4 定位的特殊情况1.5 表格1.6 表格的css属性1.7 表格中新增的标签一、常用样式讲解二 1.1 元素隐藏 如何让一个元素隐藏 1、不定义颜色 占用空间 2、display: none 不占用空间 3、visibility: hi…

在Linux和Windows上安装Nacos-2.1.1

记录&#xff1a;377场景&#xff1a;在CentOS 7.9操作系统安装Nacos-2.1.1。在Windows操作系统上安装Nacos-2.1.1。Nacos&#xff1a;Nacos: Dynamic Naming and Configuration Service。Nacos提供动态配置服务、服务发现及管理、动态DNS服务功能。版本&#xff1a;JDK 1.8 Na…

dva + antd 报错

学习 dva 》 按照dva指南学习、安装 dva-cli、引入antd的报错问题解决 1、在执行命令 npm install antd babel-plugin-import --save时报错 报错类似“A complete log of this run can be fund in : … " 解决&#xff1a;换成cnpm 或者 yarn 进行安装 举例在安装history的…

Java常见问题总结三

一、ArrayList 和 LinkedList的区别 1. 底层数据结构不同。ArrayList底层是基于数组实现的&#xff0c;LinkedList底层是基于链表文现的 2. 由于底层数缺结构不同&#xff0c;他们所适电的场景也不同&#xff0c;Araylist史适合随机查战&#xff0c;LinkedList史适合期余和添…

自动化测试工程师的发展前景怎么样?

根据各大网络招聘平台的数据显示&#xff0c;越来越多的企业在招聘测试工程师的时候&#xff0c;都开始重视自动化测试这一重要技能。早在四年前&#xff0c;自动化测试的人才需求和薪资待遇就开始一路上涨。如果你问&#xff1a;自动化测试工程师的发展前景怎么样&#xff1f;…

基于redis实现分布式锁

前言 我们的系统都是分布式部署的&#xff0c;日常开发中&#xff0c;秒杀下单、抢购商品等等业务场景&#xff0c;为了防⽌库存超卖&#xff0c;都需要用到分布式锁。 分布式锁其实就是&#xff0c;控制分布式系统不同进程共同访问共享资源的一种锁的实现。如果不同的系统或…

ubuntu重启、关机命令

// // // //之前用linux系统&#xff0c; 一键解决也是可以的&#xff0c;反正我每次用命令&#xff08;泪目…&#xff09;&#xff0c;中间崩了好几次&#xff0c;换回win&#xff0c;此篇也做记录 // // // 重启命令 以下所有命令在root根目录下输入&#xff08;普通用户&…