深度学习(五)softmax 回归之:分类算法介绍,如何加载 Fashion-MINIST 数据集

news2024/12/28 18:54:59

Softmax 回归

基本原理

回归和分类,是两种深度学习常用方法。回归是对连续的预测(比如我预测根据过去开奖列表下次双色球号),分类是预测离散的类别(手写语音识别,图片识别)。

1699720169075

现在我们已经对回归的处理有一定的理解了,如何过渡到分类呢?

假设我们有 n 类,首先我们要编码这些类让他们变成数据。所有类变成一个列向量。

y = [ y 1 , y 2 , . . . y n ] T y=[y_1,y_2,...y_n]^T y=[y1,y2,...yn]T

有一个数据属于第 i 类,那么他的列向量就是:

y = [ 0 , 0 , . . . , 1 , . . . , 0 , 0 ] T y=[0,0,...,1,...,0,0]^T y=[0,0,...,1,...,0,0]T

也就是只有他所在的那个类的元素=1.

可以用均方损失训练,通过概率判断最终选用哪一个。

Softmax 回归就是一种分类方式(回归问题在多分类上的推广)。首先确定输入特征数和输出类别数。比如上图中我们有4个特征和3个可能的类别,那么计算各自概率的公式包括3个线性回归:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可以看出 Softmax 是全连接的单层神经网络。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们让所有输出结果归一化后,从中选择出最大可能的,置信度最高的分类结果。

image-20231112100423488

采用 e 的指数可以让值全变为非负。

用真实的概率向量-我们预测得到的概率向量就是损失。真实值就是只有一个1的列向量。

交叉熵损失:

image-20231112101259670

可见**分类问题,我们不关心对非正确的预测值,只关心正确预测值是否足够大。**因为正确值是只有一个元素为1的列向量。

常用的损失函数

L2 Loss:均方损失。

image-20231112101555142

L1 Loss:绝对值损失。

image-20231112101829868

L2 梯度是一条倾斜直线,对于梯度下降算法等更为合适;L1 是一个跳变,梯度要么 -1 要么 1. 如图是 L1 L2 的梯度。

image-20231112102551104

我们可以结合两者,得到一个新的损失函数(鲁棒损失 Huber Robust):

KaTeX parse error: {equation} can be used only in display mode.

image-20231112102721527

图像分类数据集

MINIST 是一个常用图像分类数据集,但是过于简单。后来的 upgrade 版叫 Fashion-MINIST(服装分类).

首先,我们研究研究怎么加载训练数据集,以便后面测试算法用。

# 导包
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

d2l.use_svg_display()

# 下载数据集并读取到内存
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)		# 训练数据集
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)	# 测试数据集用于评估性能

# 定义函数用于返回对应索引的标签
def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

# 图像可视化,让结果看着更直观,比如下面那个绿色图的样子
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

# 我们先读一点数据集看看啥样的
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

1699980345931

# 通过内置数据加载器读取一批量数据,自动随机打乱读取,不需要我们自己定义
batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

测量以上用时基本2-3s。

总结整合以上数据读取过程,代码如下:

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

加载图像还可以调整其大小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1221354.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Kafka学习笔记(三)

目录 第5章 Kafka监控(Kafka Eagle)5.2 修改kafka启动命令5.2 上传压缩包5.3 解压到本地5.4 进入刚才解压的目录5.5 将kafka-eagle-web-1.3.7-bin.tar.gz解压至/opt/module5.6 修改名称5.7 给启动文件执行权限5.8 修改配置文件5.9 添加环境变量5.10 启动…

【Proteus仿真】【51单片机】公交车报站系统

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真51单片机控制器,使用LCD12864显示模块、DS18B20温度传感器、DS1302时钟模块、按键、LED蜂鸣器、ULN2003、28BYJ48步进电机模块等。 主要功能: 系统运行后&…

Linux基础知识(1)——目录结构,绝对/相对路径,指令等(配图)

目录 1.目录结构 2.路径 3.认识指令 文章内容并不聚焦于Linux命令,而是针对Linux的基础知识进行讲解,相信这部分知识更能帮助大家了解Linux系统。 本文只是带大家简单了解一下Linux的入门知识,在之后的文章中,我们将讲解Linux中…

【图数据库实战】HugeGraph架构

一、概述 作为一款通用的图数据库产品,HugeGraph需具备图数据的基本功能,如下图所示。HugeGraph包括三个层次的功能,分别是存储层、计算层和用户接口层。 HugeGraph支持OLTP和OLAP两种图计算类型,其中OLTP实现了Apache TinkerPop3…

C语言指针详解(1)(能看懂字就能明白系列)文章超长,慢慢品尝

目录 1、内存和地址 2、指针简介 与指针相关的运算符: 取地址操作符(&) 解引用操作符(间接操作符)(*) ​编辑 指针变量的声明 指针变量类型的意义 指针的基本操作 1、指针与整数相加…

解决 uniapp 开发微信小程序 不能使用本地图片作为背景图 问题

参考博文:uniapp微信小程序无法使用本地静态资源图片(背景图在真机不显示)的解决方法_javascript技巧_脚本之家 问题:uniapp 开发微信小程序,当使用本地图片作为 background-image 时,真机无法显示 解决: 方法一&am…

在线预览excel,luckysheet在vue项目中的使用

一. 需求 需要在内网项目中在线预览excel文档,并可以下载 二.在项目中下载并引入luckysheet 1.打开项目根目录,npm i luckyexcel 安装 npm i luckyexcel2.在项目的index.html文件中引入依赖 外网项目中的引入(CDN引入)&#…

Cesium:绘制地质剖面

作者:CSDN @ _乐多_ 本文记录了根据地质剖面的三角网的点、索引和颜色数组,绘制地质剖面的框架和部分代码。 效果如下图所示, 文章目录 一、算法逻辑二、代码一、算法逻辑 将三角网的点、索引和颜色数组按规则排列好,输入到第二节的代码中,可以绘制一个面。将这个绘制函…

如何确保消息不会丢失

本篇文章大家还可以通过浏览我的博客阅读。如何确保消息不会丢失 - 胤凯 (oyto.github.io)很多人刚开始接触消息队列的时候&#xff0c;最经常遇到的一个问题就是丢消息了。<!--more-->对于大部分业务来说&#xff0c;丢消息意味着丢数据&#xff0c;是完全无法接受的。 …

骨传导耳机的优缺点是什么?有什么值得入手的骨传导耳机吗?

骨传导耳机的优点还是挺多的&#xff0c;比如说&#xff1a;佩戴舒适、避免听力损伤、使用更安全灯&#xff0c;在详细了解骨传导耳机有什么优点和缺点之前&#xff0c;先来认识一下什么是骨传导耳机。 骨传导耳机是一种通过人体骨骼来传递声音的耳机&#xff0c;与传统的耳机相…

23111710[含文档+PPT+源码等]计算机毕业设计基于SpringBoot的体育馆场地预约赛事管理系统的设计

文章目录 **软件开发环境及开发工具&#xff1a;****功能介绍&#xff1a;****论文截图&#xff1a;****数据库&#xff1a;****实现&#xff1a;****代码片段&#xff1a;** 编程技术交流、源码分享、模板分享、网课教程 &#x1f427;裙&#xff1a;776871563 软件开发环境及…

嵌入式酒精壁炉:时尚生活的新宠

在这个注重品质与生活方式的时代&#xff0c;我们对于家居生活的要求早已不仅仅停留在实用性上。越来越多的人希望能够在家中营造一种时尚、温馨的氛围&#xff0c;而酒精壁炉恰好成为了这个潮流生活的代表。 如今&#xff0c;品质生活已经成为时尚的代名词。酒精壁炉以其精湛的…

图像分类系列(二) VGGNet学习详细记录

经典神经网络论文超详细解读&#xff08;二&#xff09;——VGGNet学习笔记&#xff08;翻译&#xff0b;精读&#xff09; 前言 上一篇我们介绍了经典神经网络的开山力作——AlexNet&#xff1a;经典神经网络论文超详细解读&#xff08;一&#xff09;——AlexNet学习笔记&a…

解密.locked1勒索病毒:专家级策略保护您的数据免受勒索攻击

导言&#xff1a; 在当今数字化的世界中&#xff0c;勒索病毒的威胁日益严峻。.locked1 勒索病毒作为其中的一种&#xff0c;采用高级的加密算法对用户文件进行加密&#xff0c;要求支付赎金以获取解密密钥。本文91数据恢复将介绍如何面对.locked1 勒索病毒&#xff0c;有效恢…

Python 3.12 版本有什么变化?

在前不久&#xff0c;python 3.12 正式发布了&#xff0c;那到底更新了哪些内容呢&#xff1f;一起来看看。 # 改善报错信息 来自官方标准库的模块现在可以在报NameError时提示问题原因&#xff0c;比如 >>> sys.version_info Traceback (most recent call last):Fi…

SpringBoot2—基础篇

目录 快速上手SpringBoot • SpringBoot入门程序开发 基于Idea创建SpringBoot工程&#xff08;一&#xff09; 基于官网创建SpringBoot工程&#xff08;二&#xff09; 基于阿里云创建SpringBoot工程&#xff08;三&#xff09; 手工创建Maven工程修改为SpringBoot工程&…

线程状态及线程之间通信

线程状态概述 当线程被创建并启动以后&#xff0c;它既不是一启动就进入了执行状态&#xff0c;也不是一直处于执行状态。在线程的生命周期中&#xff0c; 有几种状态呢&#xff1f;在 java.lang.Thread.State 这个枚举中给出了六种线程状态&#xff1a; 线程状态 导致状态发生…

Shopee活动名称怎么填写好?Shopee活动名称设置注意事项——站斧浏览器

虾皮活动名称的设定不仅是一个技巧性的问题&#xff0c;更是一门艺术。通过合理的活动名称设计&#xff0c;可以吸引更多的消费者参与活动&#xff0c;增加活动的曝光度和影响力。 shopee活动名称怎么填写好 简洁明了&#xff1a;活动名称应该尽量简洁明了&#xff0c;能够一…

北邮22级信通院数电:Verilog-FPGA(10)第十周实验 实现移位寄存器74LS595

北邮22信通一枚~ 跟随课程进度更新北邮信通院数字系统设计的笔记、代码和文章 持续关注作者 迎接数电实验学习~ 获取更多文章&#xff0c;请访问专栏&#xff1a; 北邮22级信通院数电实验_青山如墨雨如画的博客-CSDN博客 目录 一.代码部分 二.管脚分配 三.实现过程讲解及效…

上机练习 8: DataFrame 综合练习

记录一下做的练习题 目录 1)自定义一个 Series 并命名为 s1&#xff0c;自定义索引值&#xff0c;采用随机数作为其中数据尝试使用 s1.sum(计算其中所有数据的和,使用 s.mean(计算其中所有数据的平均值。 2)创建一个形状为4*6的 DataFrame 并命名为 df1,并指定行索引为[“a”…