刘二大人《Pytorch深度学习实践》第九讲多分类问题

news2024/11/24 11:53:20

文章目录

  • 多分类问题
  • 损失函数
  • 课上代码
  • transforms的使用方法
  • view()函数
  • dim维度的理解
  • 为什么要使用item()

多分类问题

在这里插入图片描述
 把原来只有一个输出,加到10个 每个输出对应一个数字,这样可以得到每个数字对应的概率值,这里每个输出做的都是sigmoid二分类(即是非1即0),所以只要有一项输出为1时,其他非1的输出都规定为0,以此来判断。
 但是这种情况下出现一个问题,每个sigmoid的输出都是独立的,当一个类别的输出概率较高时,其他类别的概率仍然会高,也就是说在输出了1的概率后,2输出的概率不会因为1的出现而受影响,这点说明了所有输出的概率值之和大于1。
&emsp’所以我们希望输出的概括要满足分布性质的要求:
在这里插入图片描述

 如何来实现上面的两个要求呢,这就用到了Softmax:
在这里插入图片描述
 根据公式可以看出来:softmax层接受上一层的输出,分母为上一层每个神经元输出的指数再求和,计算每一个概率分子则为该类的输出指数;指数确保了P(y=i)≥0的条件,该公式能够满足概率和为1

损失函数

 使用Cross Entropy Loss Function(交叉熵损失函数):
在这里插入图片描述

课上代码

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
batch_size = 64
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]
)

train_dataset = datasets.MNIST('./dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader (train_dataset, shuffle = True, batch_size = batch_size)
test_dataset = datasets.MNIST ('./dataset/mnist/', train=True, download=True, transform=transform)
test_loader = DataLoader (test_dataset, shuffle = False, batch_size = batch_size)


class Net (torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.l1 = torch.nn.Linear (784, 512)
    self.l2 = torch.nn.Linear (512, 256)
    self.l3 = torch.nn.Linear (256, 128)
    self.l4 = torch.nn.Linear (128, 64)
    self.l5 = torch.nn.Linear (64, 10)
  
  def forward (self, x):
    x = x.view (-1, 784)
    x = F.relu (self.l1(x))
    x = F.relu (self.l2(x))
    x = F.relu (self.l3(x))
    x = F.relu (self.l4(x))
    return self.l5(x)

model = Net()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

def train (epoch):
  running_loss = 0.0
  for batch_idx, data in enumerate (train_loader,0):
    inputs, target = data
    optimizer.zero_grad()

    outputs = model (inputs)
    loss = criterion (outputs, target)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

    if batch_idx % 300 == 299:
      print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
      running_loss = 0.0

def test ():
  correct = 0
  total = 0
  with torch.no_grad():
    for data in test_loader:
      images, labels = data
      outputs = model (images)
      _,predicted = torch.max (outputs.data, dim = 1)
      total += labels.size (0)
      correct += (predicted == labels).sum().item()
  print('accuracy on test set: %d %% ' % (100*correct/total))



if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

transforms的使用方法

 采用transforms.Compose(),将一系列的transforms有序组合,实现时按照这些方法依次对图像操作。

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 缩放
    transforms.RandomCrop(32, padding=4),  # 随机裁剪
    transforms.ToTensor(),  # 图片转张量,同时归一化0-255 ---》 0-1
    transforms.Normalize(norm_mean, norm_std),  # 标准化均值为0标准差为1
])

还有很多其他的transforms处理方法,总结有四大类:

  1. 裁剪-Crop
中心裁剪:transforms.CenterCrop

随机裁剪:transforms.RandomCrop

随机长宽比裁剪:transforms.RandomResizedCrop

上下左右中心裁剪:transforms.FiveCrop

上下左右中心裁剪后翻转,transforms.TenCrop
  1. 翻转和旋转——Flip and Rotation
依概率p水平翻转:transforms.RandomHorizontalFlip(p=0.5)

依概率p垂直翻转:transforms.RandomVerticalFlip(p=0.5)

随机旋转:transforms.RandomRotation
  1. 图像变换
resize:transforms.Resize
标准化:transforms.Normalize
转为tensor,并归一化至[0-1]:transforms.ToTensor
填充:transforms.Pad
修改亮度、对比度和饱和度:transforms.ColorJitter
转灰度图:transforms.Grayscale
线性变换:transforms.LinearTransformation()
仿射变换:transforms.RandomAffine
依概率p转为灰度图:transforms.RandomGrayscale
将数据转换为PILImage:transforms.ToPILImage
transforms.Lambda:Apply a user-defined lambda as a transform.
  1. 对transforms操作,使数据增强更灵活
transforms.RandomChoice(transforms), 从给定的一系列transforms中选一个进行操作

transforms.RandomApply(transforms, p=0.5),给一个transform加上概率,依概率进行操作

transforms.RandomOrder,将transforms中的操作随机打乱

view()函数

 view()的作用相当于numpy中的reshape,重新定义矩阵的形状。
1.普通用法

#其中v1为1*16大小的张量,包含16个元素。v2为4*4大小的张量,同样包含16个元素。
import torch
v1 = torch.range(1, 16) 
v2 = v1.view(-1, 4)  
#-1的意思是这个维度的数量是不确定的,对于确定的第二维度4,整个元素一共有16个,因此最后的v4是4*4大小的张量
#如果4换成8的话,第一维度会调整为2
v3 = torch.range(1, 16) 
v4 = v1.view(-1, 4)  

dim维度的理解

 维度为0时,即tensor(张量)为标量。
 维度为1时,即tensor(张量)为一维向量。
 维度为2时,即tensor(张量)为二维矩阵。
 维度为3时,即tensor(张量)为三维空间矩阵。
pytorch中的tensor维度可以通过第一个数前面的中括号数量来判断,有几个中括号维度就是多少。拿到一个维度很高的向量,将最外层的中括号去掉,数最外层逗号的个数,逗号个数加一就是最高维度的维数,如此循环,直到全部解析完毕。
dim的使用:只有dim指定的维度是可变的,其他都是固定不变的

  1. torch.argmax(): 得到最大值的序号索引
a = torch.rand((3,4))
print(a)
b = torch.argmax(a, dim=1) ##指定列,也就是行不变,列之间的比较
print(b)
>>tensor([[0.8338, 0.6953, 0.7558, 0.5803],
        [0.2105, 0.7638, 0.0912, 0.3341],
        [0.5585, 0.8019, 0.6590, 0.2268]])
>>tensor([0, 1, 1])

说明:dim=1,指定列,也就是行不变,列之间的比较,所以原来的a有三行,最后argmax()出来的应该也是三个值,第一行的时候,同一列之间比较,最大值是0.8338,索引是0,同理,第二行,最大值的索引是1……
2. sum(): 求和

a = t.arange(0,6).view(2,3)
a
>>tensor([[0, 1, 2],
        [3, 4, 5]])
a.sum()
a.sum(dim=0) #指定行,行是可变的,列是不变的
a.sum(dim=1)
>>tensor(15.)
>>tensor([3., 5., 7.])
>>tensor([ 3., 12.])

说明:dim=0,指定行,行是可变的,列是不变,所以就是同一列中,每一个行的比较,所以a.sum(dim = 0),第一列的和就是3,第二列的和就是5,第三列的和就是7.
同理,a.sum(dim=1),指定列,列是可变的,行是不变的,所以就是同一列之间的比较或者操作,所以第一行的求和是3,第二行的求和是12

为什么要使用item()

 item()的作用是取出单元素张量的元素值并返回该值,保持该元素类型不变。在训练时统计loss变化时,直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大,那么消耗的显存也就越来越大,loss.item()能够防止tensor无线叠加导致的显存爆炸

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/417983.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Netty实战与调优

Netty实战与调优 聊天室业务介绍 代码参考 /*** 用户管理接口*/ public interface UserService {/*** 登录* param username 用户名* param password 密码* return 登录成功返回 true, 否则返回 false*/boolean login(String username, String password); }/*** 会话管理接口…

如何快速上手Vue框架?

编译软件:IntelliJ IDEA 2019.2.4 x64 运行环境:Google浏览器 Vue框架版本:Vue.js v2.7.14 目录一. 框架是什么?二. 怎么写一个Vue程序(以IDEA举例)?三. 什么是声明式渲染?3.1 声明式3.2 渲染四…

docker安装oracle_11g -- 命还长时,自己搞的小玩具!!!

前言: 如果不是嫌命长, 建议不这么玩, 因为装到最后你会很崩溃, 感觉毫无意义, 就是个玩具, 哎~~~就是玩!!! 参考文档 1.https://blog.51cto.com/u_12946336/5722259 2.https://www.muzhuangnet.com/show/118178.html 3.https://blog.csdn.net/qq_42957435/article/details/1…

spring security+jwt实现认证和授权

最近正在研究前后端分离的开发模式&#xff0c;做做小项目练练手&#xff0c;正好用到了spring security的认证和授权&#xff0c;就总结一波。 首先&#xff0c;引入相关的依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId&g…

支付系统设计:收银台设计一

文章目录前言1. 收银台前端页面1. 1 收银台的业务场景1. 2 同应用不同支付场景下的收银台2. 商户平台配置管理2.1 配置流程2.2 支付工具列表配置2.3 支付配置2.3 支付银行配置3. 系统处理流程3.1 下单流程3.1 拉起收银台流程总结前言 收银台即用户日常付款前选择支付方式的页面…

革新设计,小巧强大,水库保卫无忧!

水库安全运行事关广大人民群众生命财产安全&#xff0c;为规范水库管理&#xff0c;落实水库预报、预警、预演、预案措施&#xff0c;提升水库信息化管理水平&#xff0c;保障水库安全运行。水库大坝是重要的国民基础设施&#xff0c;承担着防洪抗旱&#xff0c;节流发电的重要…

新规拉开中国生成式AI“百团大战”序幕?

AI将走向何方&#xff1f; ChatGPT在全球范围掀起的AI热潮正在引发越来越多的讨论&#xff0c;AI该如何管理&#xff1f;AI该如何发展&#xff1f;一系列问题都成为人们热议的焦点。此前&#xff0c;马斯克等海外名人就在网络上呼吁OpenAI暂停ChatGPT的模型训练和迭代&#xf…

SGAT丨单基因分析工具SingleGeneAnalysisTool

Single Gene Analysis Tool 简介&#xff1a;SGAT是一个免费开源的单基因分析工具&#xff0c;基于Linux系统实现自动化批量处理&#xff0c;能够快速准确的完成单基因和表型的关联分析&#xff0c;只需要输入基因型和表型原始数据&#xff0c;即可计算出显著关联的SNP位点&…

学习大数据需要什么语言基础

Python易学&#xff0c;人人都可以掌握&#xff0c;如果零基础入门数据开发行业的小伙伴&#xff0c;可以从Python语言入手。 Python语言简单易懂&#xff0c;适合零基础入门&#xff0c;在编程语言排名上升最快&#xff0c;能完成数据挖掘、机器学习、实时计算在内的各种大数…

测试名词介绍

测试名词介绍一&#xff1a;敏捷测试1. 定义&#xff1a;2. 敏捷测试的核心&#xff1a;3. 敏捷测试的8大原则和传统测试的区别二&#xff1a;测试名词介绍瀑布模型回归测试Alpha测试Beta测试性能测试白盒测试黑盒测试灰盒测试三&#xff1a;测试流程单元测试 (unit test)集成测…

Java RSA加解密算法学习

一、前言 1.1 问题思考 为什么需要加密 / 解密&#xff1f;信息泄露可能造成什么影响&#xff1f; 二、 基础回顾 2.1 加密技术 加密技术是最常用的安全保密手段&#xff0c;利用技术手段把重要的数据变为乱码&#xff08;加密&#xff09;传送&#xff0c;到达目的地后再…

nginx的前端部署方式

1. 什么是nginx Nginx是一款高性能的http 服务器/反向代理服务器及电子邮件&#xff08;IMAP/POP3&#xff09;代理服务器。 由俄罗斯的程序设计师Igor Sysoev所开发&#xff0c;官方测试nginx能够支支撑5万并发链接&#xff0c; 并且cpu、内存等资源消耗却非常低&#xff0…

javascript 数组详解

1.数组是可变的 数组内元素可以是不同的类型&#xff1a; 字符串一旦创建就不可变&#xff0c;但数组是可变的&#xff0c;且操作起来十分随意&#xff0c;例如&#xff1a; 直接修改数组长度&#xff0c;若新赋予长度小于原数组长度&#xff0c;会直接舍弃多余元素: 若新赋予…

【AI绘画】Midjourney和Stable Diffusion教程

之前我向大家介绍了这两个AI绘画网站&#xff1a; Stable Diffusion介绍&#xff1a; https://mp.csdn.net/mp_blog/creation/editor/130059509 Midjourney介绍: https://mp.csdn.net/mp_blog/creation/editor/130003233 前言 这里是新星计划本周最后一篇&#xff0c;主要…

python 连接oracle

前提&#xff0c;navicate成功连接oracle 1、下载cx_oracle,根据python版本下载whl&#xff0c;或者通过 ​pip install cx_Oracle -i http://pypi.douban.com/simple/ 下载地址&#xff1a; cx-Oracle PyPIhttps://pypi.org/project/cx-Oracle/#files2、navicate下instant…

​Auction Design in the Auto-bidding World系列一:面向异质目标函数广告主的拍卖机制设计...

导读&#xff1a; 传统拍卖机制不存在了&#xff01;出价产品智能化成为行业发展趋势&#xff0c;自动出价&#xff08;Auto-bidding&#xff09;已成为互联网广告主营销的主流&#xff0c;经典效用最大化模型&#xff08;Utility Maximizer&#xff09;的假设已经不再能良好地…

使用 LXCFS 文件系统实现容器资源可见性

使用 LXCFS 文件系统实现容器资源可见性一、基本介绍二、LXCFS 安装与使用1.安装 LXCFS 文件系统2.基于 Docker 实现容器资源可见性3.基于 Kubernetes 实现容器资源可见性前言&#xff1a;Linux 利用 Cgroup 实现了对容器资源的限制&#xff0c;但是当在容器内运行 top 命令时就…

《金阁寺》金阁美之于幻想,我用摧毁它来成就其美

《金阁寺》金阁美之于幻想&#xff0c;我用摧毁它来成就其美 三岛由纪夫&#xff08;1925-1970&#xff09;,日本当代小说家、剧作家、记者、电影制作人和电影演员&#xff0c;右翼分子。主要作品有《金阁寺》《鹿鸣馆》《丰饶之海》等。曾3次获诺贝尔文学奖提名&#xff0c;属…

基于Sketch Up软件校园建模案例分享

Acknowledgements&#xff1a; 由衷感谢覃婉柔、赵泽昊同学在本次课程实习中做出的巨大贡献&#xff0c;感谢本团队成员一起努力奋斗的岁月。 一、建模地点&#xff08;中国地质大学&#xff08;武汉&#xff09;未来城校区图书馆周边&#xff09; 中国地质大学&#xff08;武汉…

关于ChatGPT的一些随笔

大家好&#xff0c;我是老三&#xff0c;最近几个月关于ChatGPT的信息可以说是铺天盖地。 “王炸&#xff0c;ChatGPT……” “xxx震撼发布……” “真的要失业了&#xff0c;xxx来袭……” “普通如何利用ChatGPT……” …… 不过老三前一阵比较忙&#xff0c;对ChatGPT…