3.Softmax回归

news2024/11/15 18:02:35

回归和分类

回归估计一个连续值

分类预测一个离散类别

Softmax回归实际是一个分类问题

在这里插入图片描述

从回归到多类分类

对类别进行一位有效编码

y = [ y 1 , y 2 , ⋯   , y n ] T y=[y_1,y_2,\cdots,y_n]^T y=[y1,y2,,yn]T,如果是第i类,则值为1,否则为0

使用均方损失训练,最大值预测为(即softmax函数)
y ^ = a r g m a x i   o i \hat y = argmax_i\ o_i y^=argmaxi oi
需要更置信的识别正确类(大余量):

o y − o i ≥ Δ ( y , i ) o_y -o_i\ge \Delta(y,i) oyoiΔ(y,i)

校验比例

输出匹配概率(非负,和为1)
y ^ = s o f t m a x ( o ) y ^ i = e x p ( o i ) ∑ k e x p ( o k ) \hat y = softmax(o)\\ \hat y_i =\frac{exp(o_i)}{\sum_k exp(o_k)} y^=softmax(o)y^i=kexp(ok)exp(oi)
概率 y y y y ^ \hat y y^的区别作为损失

交叉熵损失

交叉熵用来衡量两个概率的区别 H ( p , q ) = ∑ i − p i l o g ( q i ) H(p,q)=\sum_i - p_ilog(q_i) H(p,q)=ipilog(qi)

将它作为损失函数:
l ( y , y ^ ) = − ∑ i y i l o g y ^ i = − l o g y ^ y (假设是第 y 类) l(y,\hat y)=-\sum_i y_ilog\hat y_i = -log \hat y_y (假设是第y类) l(y,y^)=iyilogy^i=logy^y(假设是第y类)
​ 关心正确类的预测值

其梯度是真实概率和预测概率的区别
∂ o i l ( y , y ^ ) = s o f t m a x ( o ) i − y i \partial_{o_i}l(y,\hat y) =softmax(o)_i -y_i oil(y,y^)=softmax(o)iyi

损失函数

均方损失(L2 Loss)


l ( y , y ′ ) = 1 2 ( y − y ′ ) 2 l(y,y')=\frac 12 (y-y')^2 l(y,y)=21(yy)2
​ 在梯度下降时,预测值与真实值相差较远时,梯度会较大,但在离原点比较远时,可能并不希望有较大的梯度,这种情况下可以使用L1 Loss。

绝对值损失(L1 Loss)

l ( y , y ′ ) = ∣ y − y ′ ∣ l(y,y')=|y-y'| l(y,y)=yy

​ 好处就是,无论离原点多远,梯度下降时的导数都是正负1,但在比较接近时,可能就出现振荡了

Huber’s Robust Loss

​ 结合两种的好处
KaTeX parse error: Unknown column alignment: * at position 32: … \begin{array}{*̲*lr**} |y-y'|-\…

读取多类分类的数据集

图像分类数据集

​ 使用Fashion-MNIST数据集

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

# 看一下图片的形状

def get_fashion_mnist_labels(labels):
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = [
        't-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'
    ]

    return [text_labels[int(i)] for i in labels]


def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    "画图"
    figsize = (num_cols * scale, num_rows * scale)
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 是图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    d2l.plt.show()  # 加上show图片才会显示
    return axes


def get_dataloader_workers():
    '''使用4个进程来读取数据'''
    return 4


def load_data_fashion_mnist(batch_size, resize=None):  #resize可以改变图片的大小
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]# 将图片转换成tensor
# 将图片下载,train表示是训练数集,transform表示是tensor而不是图片,download表示从网上下载
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    # 将图片下载,train表示是训练数集,transform表示是tensor而不是图片,download表示从网上下载
    mnist_train = torchvision.datasets.FashionMNIST(
        root="./data", train=True, transform=trans, download=True)
    # 训练数据集的下载,则train是False
    mnist_test = torchvision.datasets.FashionMNIST(
        root="./data", train=False, transform=trans, download=True)
    print(len(mnist_train))
    print(len(mnist_test))
    print(mnist_train[0][0].shape)  # 黑白图片,所以channel为1,train[0]表示取第一个元素,第二个[0]表示是取图片,[1]表示取标签
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

d2l.use_svg_display()  # 使用svg来显示图片
# 通过ToTenseor实例将图像数据从PIL类型变换成32位浮点数格式
# 并除以255使得所有像素的值均在0到1之间

# 将数据集放进dataloader里面,指定一个batch_size,我们就可以得到一个批次的数据
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
# show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))

batch_size = 256

train_iter = data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=get_dataloader_workers())
timer = d2l.Timer()
for X, y in train_iter:
    continue

print(f'{timer.stop():.2f} seconds')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1920497.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

摸鱼大数据——Kafka——Kafka的shell命令使用

Kafka本质上就是一个消息队列的中间件的产品,主要负责消息数据的传递。也就说学习Kafka 也就是学习如何使用Kafka生产数据,以及如何使用Kafka来消费数据 topics操作 注意: 创建topic不指定分区数和副本数,默认都是1个 分区数可以后期通过alter增大,但是…

k8s集群离线部署

K8s离线部署 环境 目标 k8s离线部署 步骤 部署docker 详情见文章:《离线安装docker及后端项目离线打包》 https://blog.csdn.net/qq_45371023/article/details/140279746?spm1001.2014.3001.5501 所用到的所有文件在: 链接:https://pan…

摸鱼大数据——Kafka——Kafka的集群搭建

1、软件安装 搭建Kafka集群 1、下载安装 安装包下载地址:https://kafka.apache.org/download 2、将Kafka的安装包上传到虚拟机,并解压 cd /export/software/ tar -xzvf kafka_2.12-2.4.1.tgz -C ../server/ 配置软连接: cd /export/server ln -s kaf…

Debezium日常分享系列之:Debezium 3.0.0.Alpha1 Released

Debezium日常分享系列之:Debezium 3.0.0.Alpha1 Released 一、重大改变Java 和 Maven 要求已更改 二、新的特征和提高MongoDB 三、更多内容 Debezium 3 的第一个预发布版本 3.0.0.Alpha1。这个版本虽然比正常的预版本要小,但高度关注几个关键点&#xff…

【漏洞复现】Splunk Enterprise for Windows 任意文件读取漏洞 CVE-2024-36991

声明:本文档或演示材料仅用于教育和教学目的。如果任何个人或组织利用本文档中的信息进行非法活动,将与本文档的作者或发布者无关。 一、漏洞描述 Splunk Enterprise 是一款强大的机器数据管理和分析平台,广泛应用于企业中,用于实…

【单片机毕业设计选题24058】-基于嵌入式的智慧酒店管理系统设计与实现

系统功能: 系统分为主机端和从机端,主机端主动向从机端发送信息和命令,从机端 收到主机端的信息后回复温湿度和光照强度信息。 从机端操作: 从机端上电后显示“欢迎使用智慧酒店系统请稍后”两秒后进入正常显示界面。 第一行显示系统状态…

文心快码——百度研发编码助手

介绍 刚从中国互联网大会中回来,感受颇深吧。百度的展商亮相了文心快码,展商人员细致的讲解让我们一行了解到该模型的一些优点。首先,先来简单介绍一下文心快码吧。 文心快码(ERNIE Code)是百度公司推出的一个预训练…

Go语言---并发编程之channel(双channel,单channel)以及应用实例(生产者消费者、打印机模型)

Channel goroutine 运行在相同的地址空间,因此访问共享内存必须做好同步。goroutine 通过通信来共享内存,而不是其享内存来通信。 引用类型 channel 是CSP 模式的具体实现,用于多个 goroutine 通讯。其内部实现了同步,确保并发安全。 chan…

【Linux】磁盘性能压测-FIO工具

一、FIO工具介绍 fio(Flexible I/O Tester)是一个用于评估计算机系统中 I/O 性能的强大工具。 官网:fio - fio - Flexible IO Tester 注意事项! 1、不要指定文件系统名称(如/dev/mapper/centos-root),避…

vue + echart 饼形图

图表配置: import { EChartsOption, graphic } from echarts import rightCircle from /assets/imgs/index/right_circle.png export const pieOption: EChartsOption {title: {text: 100%,subtext: 游客加量,left: 19%,top: 42%,textStyle: {fontSize: 24,color:…

如何评估媒体邀约宣传的效果

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 评估媒体邀约宣传的效果是一个系统而全面的过程,它涉及多个维度的考量和分析。 一、受邀媒体的出席率: 1.受邀媒体出席率直观反映了媒体邀约的效果; …

「C++系列」一篇文章说透【存储类】

文章目录 一、C 存储类1. 类的定义2. 对象的创建3. 对象在内存中的布局4. 对象的存储位置 二、auto 存储类1. auto的基本用法2. auto与存储类的关系1) 自动存储类(最常见的)2) 静态存储类3) 动态存储类(通过new) 三、register 存储…

自学第十五天----深入理解函数上

1. 函数是什么? 维基百科中对函数的定义: 子程序 在计算机科学中,子程序(英语:Subroutine, procedure, function, routine, method, subprogram, callable unit),是一个大型程序中的某部分代码…

【2-1:RPC设计】

RPC 1. 基础1.1 定义&特点1.2 具体实现框架1.3 应用场景2. RPC的关键技术点&一次调用rpc流程2.1 RPC流程流程两个网络模块如何连接的呢?其它特性RPC优势2.2 序列化技术序列化方式PRC如何选择序列化框架考虑因素2.3 应用层的通信协议-http2.3.1 基础概念大多数RPC大多自…

windows上修改redis端口号

概况 redis是一个开源的内存数据结构存储系统,常用做数据库、缓存和消息代理。默认的端口号为6379 更改redis端口号步骤如下 先停止redis服务 redis-cli shutdowm 打开redis配置文件 在redis安装目录下,即redis.windows.conf文件。 port 6396 然后…

插片式远程 I/O模块:热电阻温度采集模块与PLC配置案例

XD系列成套系统主要由耦合器、各种功能I/O模块、电源辅助模块以及终端模块组成。有多种通讯协议总线的耦合器,例如Profinet、EtherCAT、Ethernet/IP、Cclink IE以及modbus/TCP等。I/O 模块可分为多通道数字量输入模块、数字量输出模块、模拟量输入模块、模拟量输出模…

js前端隐藏列 并且获取值,列表复选框

列表框 <div class"block" id"psi_wh_allocation_m"><table id"result" class"list auto hover fixed" style"width:100%;border-collapse:collapse"><thead><tr><%--<th></th>--%&…

人类大脑的计算与机器的类脑计算

人类大脑的计算基本原理涉及到神经元的基本工作方式、神经网络的结构和连接模式、信息传递的方式、学习和记忆的机制等多个层面的复杂互动&#xff0c;这些原理的深入理解不仅有助于神经科学的发展&#xff0c;还为人工智能领域的发展提供了重要的启示和指导。人类大脑计算基本…

【JavaScript 报错】未捕获的加载错误:Uncaught LoadError

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 一、错误原因分析1. 资源路径错误2. 资源不存在3. 网络问题 二、解决方案1. 检查资源路径2. 确保资源存在3. 处理网络问题 三、实例讲解四、总结 在JavaScript应用程序中&#xff0c;未捕获的加载错误&#xff08;Uncaught …

电脑录音如何操作?电脑麦克风声音一起录制,分享7款录音软件

电脑录音已经成为我们日常生活和工作中不可或缺的一部分。无论是录制会议、教学、音乐、网络直播、音源采集还是其他声音&#xff0c;电脑录音软件都为我们提供了极大的便利。本文将为大家介绍如何操作电脑录音&#xff0c;并分享七款录音软件&#xff0c;包括是否收费、具体操…