数据分析-深度学习 Pytorch Day7

news2025/1/12 20:07:53

图像识别:CIFAR10图形识别

1.CIFAR10数据集共有60000张彩色图像,这些图像式32*32*3,分为10个类,每个类6000张

2.这里面有50000张用于训练,构成5个训练批,每一批10000张图;另外10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。

3.一个训练批中的各类图像并不一定数量相同,总的来看训练集,每一类都有5000张图片

# 导入包
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
 # 设置transforms
transform = transforms.Compose([
    transforms.ToTensor(), # numpy -> Tensor
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))  # 归一化 ,范围[-1,1]
])
​​# 下载训练数据集
# 训练集
trainset = datasets.CIFAR10(root='./CIFAR10',train=True,download=True,transform=transform)
# 测试集
testset = datasets.CIFAR10(root='./CIFAR10',train=False,download=True,transform=transform)

出现如下图结果数据集下载成功

​​# 批量获取数据
from torch.utils.data.dataloader import DataLoader

BATCH_SIZE = 32

train_loader = DataLoader(trainset,batch_size=BATCH_SIZE,shuffle=True,num_workers=8,pin_memory=True)

test_loader = DataLoader(testset,batch_size=BATCH_SIZE,shuffle=True,num_workers=8,pin_memory=True)

注意:其中BATCH_SIZE = 32 中的32 可以根据自己电脑配置来定,配置高可以定128 低可以定16

# 可视化显示
import matplotlib.pyplot as plt
import numpy as np

# 十个类别
classes = ('plane','car','bird','cat','deer',
          'dog','frog','horse','ship','truck')

def imshow(img):
    img = img / 2 + 0.5 # 逆正则化
    np_img = img.numpy()  # tensor --> numpy
    plt.imshow(np.transpose(np_img,(1,2,0)))  # 改变通道顺序
    plt.show()
    
# 随机获取一批数据
imgs,labs = next(iter(train_loader))


print(imgs.shape)
print(labs.shape)
    
#调用方法
imshow(torchvision.utils.make_grid(imgs))

# 输出这批图片对应的标签
print(' '.join('%5s' % classes[labs[i]] for i in range(BATCH_SIZE)))    

结果如下:

其中

torch.Size([32, 3, 32, 32])
torch.Size([32])

中32代表32张图片,3代表3个通道,32代表像素

​​# 定义网络模型
import torch.nn as nn
import torch.nn.functional as F

'''
知识点:
1.特征图尺寸的计算公式为:[(原图片尺寸 = 卷积核尺寸) / 步长] + 1
'''
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        # 卷积层1.输入是32*32*3,计算(32-5)/ 1 + 1 = 28,那么通过conv1输出的结果是28*28*6
        self.conv1 = nn.Conv2d(3,6,5)  # imput:3 output:6, kernel:5
        # 池化层, 输入时28*28*6, 窗口2*2,计算28 / 2 = 14,那么通过max_poll层输出的结果是14*14*6
        self.pool = nn.MaxPool2d(2,2) # kernel:2 stride:2
        # 卷积层2, 输入是14*14*6,计算(14-5)/ 1 + 1 = 10,那么通过conv2输出的结果是10*10*16
        self.conv2 = nn.Conv2d(6,16,5) # imput:6 output:16, kernel:5
        # 全连接层1
        self.fc1 = nn.Linear(16*5*5, 120)  # input:16*5*5,output:120
        # 全连接层2
        self.fc2 = nn.Linear(120, 84)  # input:120,output:84
        # 全连接层3
        self.fc3 = nn.Linear(84, 10)  # input:84,output:10
        
    def forward(self,x):
        # 卷积1
        '''
        32x32x3 --> 28x28x6 -->14x14x6
        '''
        x = self.pool(F.relu(self.conv1(x)))
        # 卷积2
        '''
        14x14x6 --> 10x10x16 --> 5x5x16
        '''
        x = self.pool(F.relu(self.conv2(x)))
        # 改变shape
        x = x.view(-1,16*5*5)
        # 全连接层1
        x = F.relu(self.fc1(x))
        # 全连接层2
        x = F.relu(self.fc2(x))
        # 全连接层3
        
        x = self.fc3(x)
        return x 

注意:__init__这一块下划线要注意,按理说只要将模型定义到__init__()里就ok了,但是大家容易少打一个下划线会报错,将下划线_改为__即可解决问题。

# 创建模型
net = Net().to('cuda')​​

电脑有GPU的话这一步是部署到CUDA上运行调用GPU, 这一步容易出现下图问题:

这时候多运行几次,代码是没有问题的,应为JUPYTER是在网页上运行,需要时间反应,多运行几次

如果出现以下问题:

注意Linear中的L要大写

# 定义优化器和损失函数
import torch.optim as optim

criterion = nn.CrossEntropyLoss()  # 交叉式损失函数

optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)  # 优化器

# 定义函数
EPOCHS = 200

for epoch in range(EPOCHS):
    
    train_loss = 0.0
    for i, (datas,labels) in enumerate(train_loader):
        datas,labels = datas.to('cuda'),labels.to('cuda')
        # 梯度置零
        optimizer.zero_grad()
        # 训练
        outputs = net(datas)
        # 计算损失
        loss = criterion(outputs,labels)
        # 反向传播
        loss.backward()
        # 参数更新
        optimizer.step()
        # 累计损失
        train_loss += loss.item()
    # 打印信息
    print(epoch+1, i+1,train_loss/len(train_loader.dataset))

循环次数可以自己设置,这里设置为200轮,for循环读取训练集

输出结果如下:

​​# 测试
correct = 0
total = 0
# flag=True
with torch.no_grad():
    for i , (datas,labels) in enumerate(test_loader):
        # 输出
        outputs = model(datas) # outputs.data,shape --> torch.Size([128,10])
        _, predicted = torch.max(outputs.data, dim=1)  # 第一个是值得张量,第二个是序号张量
        # 累计数据值
        total += labels.size(0)  # labels.size() --> torch.Size([128]) , labels.size(0) --> 128
        # 比较有多少个预测正确
        correct += (predicted == labels).sum()  # 相同为1,不同为0,利用sum()总求和
    print("在1000张测试集图片上的准确率:{:.3f}%".format(torch.true_divide(correct,total))

# 显示每一类预测的概率
class_correct = list(0. for i in range(10))
total = list(0. for i in range(10))

with torch.no_grad():
    for (images,labels) in test_loader:
        outputs = model(images)  # 输出
        _,predicted = torch.max(outputs,dim=1)  # 获取到每一行的最大索引
        c = (predicted ==labels).squeeze()  # squeeze() 去掉0维【默认】,unsqueeze() 增加一维
        if labels.shape[0]  == 128:
            for i in range(BATCH_SIZE):
                label = labels[i] # 获取每一个label
                class_correct[label] += c[i].item()  # 累计维True都个数,注意:1 + True = 2,1 + False = 1
                total[label]  += 1 # 该类总的个数
                
# 输出正确率
for i in range(10):
    print("正确率 : %5s : $2d %%" % (classes[i],100 * class_correct[i] / total[i])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/169816.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vhdx中的win10进行大版本系统升级

文章目录前言普通的win10大版本iso升级方式vhdx中的win10大版本升级方式难点分析 - 无法在虚拟驱动器上安装windows解决方案 - HyperV升级vhdx win10过程效果图hyperV虚机创建mbr引导启动项hyperV虚机设置在hyperV中升级过程图问题集锦问题一:hyverV虚机中升级报错&…

力扣刷题记录——561. 数组拆分、566. 重塑矩阵、575. 分糖果

本专栏主要记录力扣的刷题记录,备战蓝桥杯,供复盘和优化算法使用,也希望给大家带来帮助,博主是算法小白,希望各位大佬不要见笑,今天要分享的是——《力扣刷题记录——561. 数组拆分、566. 重塑矩阵、575. 分…

IDEA远程调试

1 概述 原理:本机和远程主机的两个 VM 之间使用 Debug 协议通过 Socket 通信,传递调试指令和调试信息。 被调试程序的远程虚拟机:作为 Debug 服务端,监听 Debug 调试指令。jdwp是Java Debug Wire Protocol的缩写。 调试程序的本…

初识redis

1.初识Redis Redis是一种键值型的NoSql数据库,这里有两个关键字: 键值型 NoSql 其中键值型,是指Redis中存储的数据都是以key、value对的形式存储,而value的形式多种多样,可以是字符串、数值、甚至json:…

HTTPS一定可靠吗?

HTTPS一定可靠吗?中间人伪装服务器首先我们先看看客户端是如何验证证书的?数字证书签发和验证流程客户端校验服务端数字证书的过程如何出现中间人伪装服务器成服务器的情况?避免该情况中间人伪装服务器 客户端向服务端发起HTTPS建立连接请求时…

你知道吗?python lxml 库也能用于操作 svg 图片

在大多数场景中,我们都用 lxml 库解析网页源码,但你是否知道,lxml 库也是可以操作 svg 图片的。我们可以使用 lxml 中的 etree 模块来解析 SVG 文件,然后使用 SVG 中的各种元素和属性来进行操作。 python lxml 库操作 svg 图片lxm…

传输层协议:TCP协议(上)——协议结构、主要特点以及应用场景

简介 传输控制协议(英语:Transmission Control Protocol,缩写:TCP)是一种面向连接的、可靠的、基于字节流的传输层通信协议,由IETF的RFC 793定义。在简化的计算机网络OSI模型中,它完成第四层传…

xubuntu系统偶发自动登出

项目场景: 系统:xubuntu-16.04.3-desktop 问题描述 使用xubuntu系统期间,在root用户下进行相关开发,突然系统会回到普通用户登录界面,需要输入密码进入到普通用户下   它会终止所有打开的应用程序和进程&#xff0…

【Vue组件通信方式】

文章目录前言一、父子组件通信1、父传子①使用props接收父组件传递的属性② 使用$attrs接收父组件未在 props 和 emits 中定义的属性和事件③使用 $parent获取父组件的信息2、子传父① 使用 $emit传递信息给父组件② 使用$refs获取子组件的属性和事件二、自定义事件&#xff1a…

独家丨DeepMind科学家、AlphaTensor一作解读背后的故事与实现细节

一直以来,DeepMind的Alpha系列工作,AlphaGo、AlphaStar等致力于棋类和游戏应用中战胜人类,而两个月前发布的AlphaTensor则把目标指向了科学计算领域,意在为矩阵乘法等基本计算任务自动设计更高效的经典算法,这一工作一…

Burpsuite超详细安装教程(附安装包)

写在开头 Burp Suite 是用于攻击web 应用程序的集成平台,包含了许多工具。Burp Suite为这些工具设计了许多接口,以加快攻击应用程序的过程。所有工具都共享一个请求,并能处理对应的HTTP 消息、持久性、认证、代理、日志、警报。 接下来我来…

软件测试面试经 | 双非院校,从外包到外企涨薪85%,他的涨薪秘籍全公开

本文为霍格沃兹测试开发学社优秀学员跳槽笔记,测试开发进阶学习文末加群。 本身是一所不入流的院校毕业的一名建工类专业的瓜娃子,至今记得当初是因为找工作被培训公司忽悠才加入到这个行业的,抱着做着试试的想法这一干在深圳就是6年&#xf…

excel替换技巧:如何将手机号码的部分数字变成星号

每个销售员经常会接触大量客户,会用小本本记下众多客户的信息,而手机号码就是其中重要的一项。为了保护客户隐私,在公开的信息里销售员需要把客户手机号码的部分数字变成星号。比如说,把客户A的手机号码15867852976修改成158****2…

SpringMvc源码分析(三) 请求执行过程之获取MethodHandler

1.请求是如何关联到DispatcherServlet的 DispatcherServlet是Servlet的实现,遵循Servlet生命周期的规则。 Servlet的生命周期即其出生到死亡的过程中分别会调用Servlet里的以下方法: 加载和实例化:可以参考SpringMvc源码分析一 init方法…

【JavaEE】博客前端

目录 一、列表页 1.1导航条 1.2主题区域 1.2.1个人信息框 1.2.2 内容区 二、登录页 三、详情页 一、列表页 整体布局如下: 1.1导航条 导航条分为三块,整体都设置id为导航栏,然后左右分为导航栏左和导航栏右。左边靠左,右边靠…

计算机视觉Computer Vision课程学习笔记四之Region and Edge Descriptions

第四章讲了区域和边界的描述 包括最佳区域评估方法,多物体识别,标签算法,斑点标记 以及矩评估的方法和优劣 Region Description Simple measurements on binary images • Use for recognition, etc. • Generate region descriptions whic…

Win10+CMake+VS2017编译OpenCV4.5.5

第一步:准备工作1 下载opencv4.5.5下载OpenCV4.5.5,并解压到自己新建文件夹opencv下。2 下载opencv_contrib4.5.5下载opencv_contrib4.5.5,解压到上面的opencv文件夹中,并在opencv文件夹中新建一个build文件夹,用来存放…

第一天总结 之 用户管理界面的实现 之 添加操作 的实现

添加操作的实现 明确页面的跳转 找到 admin_adduser.jsp中 form表单 前端的添加页面展示 在表单中输入 信息 点击注册跳转到 from表单对应的 action地址 UserAddServlet 创建UserAddServlet 从前端的form表单中获取值 然后在service层 进行 业务操作 即将这些属性存放在 Ob…

私有部署与SaaS模式网站有什么区别

什么是SaaS SaaS 是 Software-as-a-Service 的简称,它是一种通过互联网提供软件的模式。 以官网为例,SaaS订阅的网站通常统一部署在SaaS提供商的云服务器上。用户通过自己的实际需求订购对应的网站系统服务,按订购的系统功能、使用流量/存储…

Word处理控件Aspose.Words功能演示:用Java从Word文档中提取文本

Aspose.Words For .NET是一种高级Word文档处理API,用于执行各种文档管理和操作任务。API支持生成,修改,转换,呈现和打印文档,而无需在跨平台应用程序中直接使用Microsoft Word。此外,API支持所有流行的Word…