Pytorch深度学习笔记(十)多分类问题

news2025/1/11 22:51:58

课程推荐:09.多分类问题_哔哩哔哩_bilibili 

目录

1. 多分类模型

2. softmax函数模型

3. Loss损失函数

4.实战MNIST Dataset


之前,在逻辑斯蒂回归中我们提到了二分类任务,现在我们讨论多分类问题。

1. 多分类模型

与二分类不同的是多分类有多个输出概率,由softmax层完成。

softmax用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类。

softmax层的前一层是线性层,也就是说softmax之前的一层不需要再做Sigmoid,Sigmoid在输入softmax之前已经做过了。线性层输出的值是一般值,还不是概率值。经过sotfmax层之后才能变成概率值

2. softmax函数模型

softmax函数的作用:softmax两个作用,如果在进行softmax前的input有负数,通过指数变换,得到正数。所有分类的概率求和为1。

softmax函数模型: p(y=i)=\frac{e^{z_{i}}}{\sum_{j=0}^{K-1}e^{z_{i}}},i\in\{0......i\}

softmax实现过程:Exponent为求指数,sum为求和,Divide为求除数。最后得到的概率值为\hat{y}

代码实现: 

import numpy as np
y = np.array([1, 0, 0])
z = np.array([0.2, 0.1, -0.1])
y_pred = np.exp(z) / np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum()
print(loss)

3. Loss损失函数

Loss损失函数计算公式:Loss(\hat{Y},Y)=-Ylog\hat{Y}

交叉熵损失CrossEntropyLoss <==> LogSoftmax + NLLLoss。LogSoftmax用于得到预测概率值\hat{y},NLLLoss用于 \hat{y}和y损失值计算。NLLLoss可单独使用,可以思考一下CrossEntropyLoss与NLLLoss区别,方便日后灵活应用。

代码实现:

import torch
# 长整形的张量,LongTensor[0]是指索引为0 的(就是第一个元素)为1,其余为0
y = torch.LongTensor([0])
z = torch.Tensor([[0.2, 0.1, -0.1]])
# 损失函数
criterion = torch.nn.CrossEntropyLoss()
loss = criterion(z, y)
print(loss)

4.实战MNIST Dataset

将MNIST Dataset中的手写图像映射到一个28*28的矩阵中,颜色越深数值越大,矩阵中的数值在[0,1]之间。把像素值0-255转化为图像张量0-1

将PIL图像转化为张量:

transforms.ToTensor()将PIL图像由单通道转变为多通道再取值[1,0]之间,单变多28*28 =>1*28*28,W*H => C*W*H , C:channel通道,W:宽,H:高,C*W*H。

标椎化处理:

Normalize计算公式

Normalize函数将转化后的张量映射到[0,1]之间

# 把像素值0-255转化为图像张量0-1
transform = transforms.Compose([
    # transforms.ToTensor()转化张量,Normalize映射到[0,1]之间
    transforms.ToTensor(),
    # (均值,标准差)
    transforms.Normalize((0.1307, ), (0.381, ))
])

完整代码:

这是一种全连接的神经网络。

1.把像素值0-255转化为图像张量0-1

2.准备数据:view()函数,改变张量形状,参数为-1,根据后一个数,自动调整张量的形状和大小。 

3.模型结构:线性层(降维) => 激活层(relu激活) => …… => 线性层(维度为10)

view()参考博客

老四步:

1.数据准备

2.设计模型

3.构造损失函数和优化器

4.训练周期(前馈—>反馈—>更新)

import torch
# 用于图像映射到矩阵中
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

batch_size = 64

# 把像素值0-255转化为图像张量0-1
transform = transforms.Compose([
    # transforms.ToTensor()转化张量,Normalize映射到[0,1]之间
    transforms.ToTensor(),
    # (均值,标准差)
    transforms.Normalize((0.1307, ), (0.381, ))
])

# 训练集
train_dataset = datasets.MNIST(root="../dataset/mnist",
                           train=True,
                           download=True,
                           transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_dataset = datasets.MNIST(root="../dataset/mnist",
                           train=False,
                           download=True,
                           transform=transform)

test_loader = DataLoader(test_dataset,
                          batch_size=batch_size,
                          shuffle=False)


#…2.设计模型………………………………………………………………………………………………………………………………………#
# 继承torch.nn.Module,定义自己的计算模块,neural network
class Net(torch.nn.Module):
    # 构造函数
    def __init__(self):
        # 调用父类构造
        super(Net, self).__init__()
        # 从784维降到10维
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    # 前馈函数
    def forward(self, x):
        # 改变张量形状,784表示确定的列数,自动调整行数
        x = x.view(-1, 784)
        # 激活
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        # 最后传入一个线性层
        return self.l5(x)

#……3.构造损失函数和优化器………………………………………………………………………………………………………#
model = Net()
# 实例化损失函数,返回损失值
criterion = torch.nn.CrossEntropyLoss()
# 优化器,momentum冲量
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

#……4.训练和测试……………………………………………………………………………………………………………………………#
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        # 1.准备数据
        inputs, labels = data
        optimizer.zero_grad()
        # 2.正向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # 3.反向传播
        loss.backward()
        # 4.更新权重w
        optimizer.step()
        # 损失求和
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

def test():
    correct = 0
    total = 0
    # with torch.no.grad():内部代码不会再计算梯度
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            # dim沿着第一个纬度(行)找最大值,返回(最大值,最大值下标)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            # 预测值与标签对比,正确则相加
            correct += (predicted == labels).sum().item()
    # 输出精确率
    print('Accuracy on test set: %d %%' % (100 * correct / total))

if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

_, predicted = torch.max(outputs.data, 1):返回与样本对比后,相似度最大的样本下标

_, predicted = torch.max(outputs.data, 1)的理解

训练结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/457493.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于C#asp.net心里咨询服务网站系统

功能模块&#xff1a; 主要分为管理员和注册用户&#xff0c;注册用户可以查看所有人发布的心里文章&#xff0c;情感在线问答&#xff0c;查询相似问题&#xff0c;以及进入论坛进行交流&#xff08;发帖跟帖评论收藏等&#xff09;后台管理主要是针对个人信息修改 管理员对注…

SpringBoot自动装配机制的原理

自动装配 简单来说&#xff0c;就是自动去把第三方组件的Bean装载到IOC容器里面。不需要开发人员再去写Bean相关的一个配置。 在SpringBoot应用里面&#xff0c;只需要在启动类上去加上SpringBootApplication注解就可以实现自动配置&#xff0c;SpringBootApplication注解它是…

DSPC174 3BSE005461R1码垛工业机器人安装调试的13个步骤

​ DSPC174 3BSE005461R1码垛工业机器人安装调试的13个步骤 ABB码垛工业机器人安装调试的13个步骤 最近工业机器人市场上&#xff0c;调试工作比较火爆&#xff0c;单个项目动辄几十台机器人同时调试&#xff0c;开出的日薪达到1500-2000元。拥有如此庞大的市场需求和丰厚收入…

Hudi数据湖技术之核心概念

目录 1 基本概念1.1 时间轴Timeline1.2 文件管理1.3 索引Index 2 存储类型2.1 计算模型2.1.1 批式模型&#xff08;Batch&#xff09;2.1.2 流式模型&#xff08;Stream&#xff09;2.1.3 增量模型&#xff08;Incremental&#xff09; 2.2 查询类型&#xff08;Query Type&…

力扣:通过《84.柱状图中最大的矩形》求解《85. 最大矩形》

84. 柱状图中最大的矩形 85. 最大矩形 84.柱状图中最大的矩形&#xff1a; 单调栈求解问题范围&#xff1a; 输出每个数左边第一个比它小的数 单调栈例题&#xff1a; Acwing 830. 单调栈 #include <iostream>using namespace std;const int N 100010; int stk[N],tt …

再多猜一次就爆炸(小黑子误入)

目录 猜数字游戏 游戏设计思路 1.电脑随机生成一个数 2.猜数字 3.输入我是ikun&#xff0c;泰裤辣! 否则电脑将在一分钟后关机 游戏运行效果 源码 代码分析 代码实现关键语句 strcmp() rand()与srand() 时间戳time() 寄语 猜数字游戏 游戏设计思路 1.电脑随机生…

C语言_Printf函数返回值

目录 1. 嵌套结构 2. Printf 函数返回值 在了解Printf 函数的返回值之前&#xff0c;先来了解下什么叫嵌套结构。 1. 嵌套结构 这里直接举个例子进行介绍&#xff1a; strlen 函数计算字符串长度&#xff0c;显然打印的结果是 3 但是如果采用嵌套结构&#xff08;简单来说就…

【深度学习】基于华为MindSpore的手写体图像识别实验

1 实验介绍 1.1 简介 Mnist手写体图像识别实验是深度学习入门经典实验。Mnist数据集包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心&#xff0c;图像是固定大小(28x28像素)&#xff0c;其值为0到255。为简单起见&#xff0c;每…

看完这篇文章你就彻底懂啦{保姆级讲解}-----(LeetCode刷题142环形链表II) 2023.4.24

目录 前言算法题&#xff08;LeetCode刷题142环形链表II&#xff09;—&#xff08;保姆级别讲解&#xff09;分析题目&#xff1a;算法思想环形链表II代码&#xff1a;补充 结束语 前言 本文章一部分内容参考于《代码随想录》----如有侵权请联系作者删除即可&#xff0c;撰写…

ESP32设备驱动-LIS3MDL磁场传感器驱动

LIS3MDL磁场传感器驱动 文章目录 LIS3MDL磁场传感器驱动1、LIS3MDL介绍2、硬件准备3、软件准备4、驱动实现1、LIS3MDL介绍 LIS3MDL 具有4/8/12/16 高斯的用户可选满量程。自检功能允许用户在最终应用中检查传感器的功能。该设备可以被配置为生成用于磁场检测的中断信号。 LIS…

Vue 3 第十四章:组件五(内置组件-transitiontransition-group)

文章目录 1. transition组件1.1. 基本用法1.2. css过渡class介绍1.3. 过渡效果命名1.3.1. 基本用法 1.4. 配合自定义动画&#xff08;animation&#xff09;使用1.5. 自定义过渡 class1.6. <Transition>组件生命周期1.7. transition 常用场景 2. transition-group组件2.1…

Java基础(十一)日期时间API

1 JDK8之前&#xff1a;日期时间API 1.1 java.lang.System类的方法 System类提供的public static long currentTimeMillis()&#xff1a;用来返回当前时间与1970年1月1日0时0分0秒之间以毫秒为单位的时间差。 此方法适用于计算时间差。 计算世界时间的主要标准有&#xff1a;…

SCAU 统计学 实验6

要确定不同培训方式对产品组装时间是否有显著影响&#xff0c;我们可以使用单因素方差分析&#xff08;One-way ANOVA&#xff09;。我们将使用以下数据&#xff1a; 培训方式 A 的样本数据 培训方式 B 的样本数据 培训方式 C 的样本数据 显著性水平&#xff08;α&#xff09…

windows下springboot集成ELK

ELK ElasticSearch Logstash Kibana的集合。ELK主要用于日志的集中管理、快速查询和分析。主要是通过 Logstash 将应用系统的日志通过 input 收集&#xff0c;然后通过内部整理&#xff0c;通过 output 输出到 Elasticsearch 中&#xff0c;其实就是建立了一个 index&#x…

【利刃出鞘】链式思维利用ChatGPT,让其成为工作中的利剑?附带初学者扫盲SpringBoot

【利刃出鞘】链式思维利用ChatGPT&#xff0c;让其成为工作中的利剑 一、一点思考二、技术学习——链式思维2.1 springboot注册bean的几种方式&#xff1f;2.2 springboot Component 注册的原理&#xff1f;2.3 springboot引用注册的Bean原理&#xff1f;2.4 private final MyB…

26-第一个Servlet项目

目录 1.Servlet是什么&#xff1f; 2.第一个Servlet项目 2.1.创建Maven项目 2.2.引入Servlet依赖&#xff08;将Maven项目改为Servlet项目(尚不完整)&#xff09; 2.3.完善Servlet项目目录——源代码目录&单元测试目录&#xff08;非必须&#xff09; 2.4.编写代码 …

4月24日作业

作业1 #include <iostream> using namespace std; template <typename T> class Node { private: T* p; //指针指向栈的首地址 int maxsize; //栈最大容量 int top-1; //栈顶 public: Node(){} //无参构造 Node(int max):maxsize(max)//有参构造 填最大容…

2022 ICPC Gran Premio de Mexico Repechaje 题解

目录 A. Average Walk&#xff08;签到&#xff09; 题意&#xff1a; 思路&#xff1a; 代码&#xff1a; C. Company Layoffs&#xff08;签到&#xff09; 题意&#xff1a; 思路&#xff1a; 代码&#xff1a; D. Denji1&#xff08;模拟/二分&#xff09; 思路&am…

Bsah shell的操作环境

文章目录 Bsah shell的操作环境路径与命令查找顺序使用案例 bash的登录与欢迎信息&#xff1a;/etc/issue、/etc/motdbash的环境配置文件如下login与non-login shell/etc/profile(login shell 才会读)~/.bash_profile(login shell 才会读)source&#xff1a;读入环境配置文件的…

简单介绍一下什么是“工作内存”和“主内存”(JMM中的概念)

在学习Java多线程编程里&#xff0c; volatile 关键字保证内存可见性的要点时&#xff0c;看到网上有些资料是这么说的&#xff1a;线程修改一个变量&#xff0c;会把这个变量先从主内存读取到工作内存&#xff1b;然后修改工作内存中的值&#xff0c;最后再写回到主内存。 对…