线性神经网络——softmax 回归随笔【深度学习】【PyTorch】【d2l】

news2024/11/24 18:37:02

文章目录

    • 3.2、softmax 回归
      • 3.2.1、softmax运算
      • 3.2.2、交叉熵损失函数
      • 3.2.3、PyTorch 从零实现 softmax 回归
      • 3.2.4、简单实现 softmax 回归

在这里插入图片描述

3.2、softmax 回归

3.2.1、softmax运算

在这里插入图片描述

softmax 函数是一种常用的激活函数,用于将实数向量转换为概率分布向量。它在多类别分类问题中起到重要的作用,并与交叉熵损失函数结合使用。

y ^ = s o f t m a x ( o )      其中     y ^ i = e x p ( o j ) ∑ k e x p ( o k ) \hat{y} = softmax(o) \ \ \ \ \ 其中\ \ \ \ \hat{y}_i = \frac{exp(o_j)}{\sum_{k}exp(o_k)} y^=softmax(o)     其中    y^i=kexp(ok)exp(oj)

其中,O为小批量的未规范化的预测, Y ^ \hat{Y} Y^w为输出概率,是一个正确的概率分布【 ∑ y i = 1 \sum{y_i} =1 yi=1

3.2.2、交叉熵损失函数

通过测量给定模型编码的比特位,来衡量两概率分布之间的差异,是分类问题中常用的 loss 函数。

H ( P , Q ) = − Σ P ( x ) ∗ l o g ( Q ( x ) ) H(P, Q) = -Σ P(x) * log(Q(x)) H(P,Q)=ΣP(x)log(Q(x))

真实概率分布是从哪里得知的?

真实标签的概率分布是由数据集中的标签信息提供的,通常使用单热编码表示。

softmax() 如何与交叉熵函数搭配的?

softmax 函数与交叉熵损失函数常用于多分类任务中。softmax 函数用于将模型输出转化为概率分布形式,交叉熵损失函数用于衡量模型输出概率分布与真实标签的差异,并通过优化算法来最小化损失函数,从而训练出更准确的分类模型。

3.2.3、PyTorch 从零实现 softmax 回归

(非完整代码)

#在 Notebook 中内嵌绘图
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

#,将图形显示格式设置为 SVG 格式,以在 Notebook 中以矢量图形的形式显示图像。这有助于提高图像的清晰度和可缩放性。
d2l .use_svg_display()

在线下载数据集 Fashion-MNIST

#将图像数据转换为张量形式
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data"
                                                ,train=True,transform=trans,download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data"
                                               ,train=False,transform =trans,download=True)

len(mnist_train),len(mnist_test)

绘图(略)

读取小批量数据集

batch_size = 256

def get_dataloader_workers():
    """使用4进程读取"""
    return 4
    
train_iter = data.DataLoader(mnist_train,batch_size,shuffle=True,
                            num_workers=get_dataloader_workers())
timer = d2l.Timer()
for X,y in train_iter:
    continue
print(f'{timer.stop():.2f}sec')

定义softmax操作

def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

定义损失函数

def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

cross_entropy(y_hat, y)

分类精度

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

评估

def evaluate_accuracy(net, data_iter):  #@save
    """计算在指定数据集上模型的精度"""
    if isinstance(net, torch.nn.Module):
        net.eval()  # 将模型设置为评估模式
    metric = Accumulator(2)  # 正确预测数、预测总数
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]
class Accumulator:  #@save
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

3.2.4、简单实现 softmax 回归

导入前面已下载数据集 Fashion-MNIST

import torch 
from torch import nn
from d2l import torch as d2l

batch_size =256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型

#nn.Flatten() 层的作用是将输入数据展平,将二维输入(如图像)转换为一维向量。因为线性层(nn.Linear)通常期望接收一维输入。
#nn.Linear(784,10) 将输入特征从 784 维降低到 10 维,用于图像分类问题中的 10 个类别的预测   784维向量->10维向量
net = nn.Sequential(nn.Flatten(),nn.Linear(784,10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight,std=0.01)
        
net.apply(init_weights);
#计算交叉熵损失函数,用于衡量模型预测与真实标签之间的差异。参数 reduction 控制了损失的计算方式。
#reduction='none' 表示不进行损失的降维或聚合操作,即返回每个样本的独立损失值。
loss = nn.CrossEntropyLoss(reduction='none')

优化算法

trainer = torch.optim.SGD(net.parameters(),lr=0.1)

训练

num_epochs = 10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/785493.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于STM32F10x的独立按键测试

本人基于野火指南者开发板,使用FreeRTOS,创建按键任务。 在按键任务里面每隔20毫秒扫描一下按键。包括独立按键,矩阵按键(由于本人没有矩阵按键,故没有测试)。 按键40毫秒以上为短按、1秒以上则为长按、3秒以上则为一直按。且长按10秒以上则…

网络安全(黑客)自学基础到高阶路线

01 什么是网络安全 网络安全可以基于攻击和防御视角来分类,我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术,而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 无论网络、Web、移动、桌面、云等哪个领域,都有攻与防两面…

Docker + MYSQL 启动nacos

Docker启动nacos默认用的是内存数据库,重启docker容器以后,nacos配置会丢失,非常不方便。所以需要修改为使用Mysql作为nacos的存储。 1.数据库 创建mysql数据库,过程省略,将nacos 的mysql脚本在数据库中进行导入。 m…

北航投资已投企业四象科技成功发射三颗卫星

1箭4星!2023年7月23日10时50分,我国在太原卫星发射中心使用长征二号丁运载火箭,成功将四象科技“矿大南湖号”SAR遥感卫星、“虹口复兴号”光学遥感卫星、“中电农创号”热红外遥感卫星以及银河航天灵犀03星共4颗卫星发射升空,卫星…

当机器人变硬核:探索深度学习中的时间序列预测

收藏自:Wed, 15 Sep 2021 10:32:56 UTC 摘要:时间序列预测是机器学习和深度学习领域的一个重要应用,它可以用于预测未来趋势、分析数据模式和做出决策。本文将介绍一些基本概念和常用方法,并结合具体的案例,展示如何使…

7D透明屏的市场应用广泛,在智能家居中有哪些应用表现?

7D透明屏是一种新型的显示技术,它能够实现透明度高达70%以上的显示效果。这种屏幕可以应用于各种领域,如商业广告、展览展示、智能家居等,具有广阔的市场前景。 7D透明屏的工作原理是利用光学投影技术,将图像通过透明屏幕投射出来…

VMware 创建Centos7虚拟机后nat模式无法联网

1. 网卡改为on,dhcp模式,重启网卡,如果还是无法联网 2.修改 /etc/resolv.conf,增加DNS 223.5.5.5,保存后即可ping 通百度,联网。在此记录一下

SolVES模型安装教程

前文是关于SolVES模型扫盲,熟悉SolVES模型的伙伴可直接跳到下面的安装教程。 目前生态系统服务评估主要集中于经济价值,相关的评估方法也较多,如价值当量法、InVEST模型法、市场价格法等,而随着生态系统服务的社会价值得到越多越…

firefox笔记-Centos7离线安装firefox

目前(2023-03-22 16:41:35)Centos7自带的firefox已经很新了是2020年的。主要原因是有个web项目,用2020年的firefox打不开。 发到互联网上是2023-07-24。 报错是js有问题,估计是搞前端的只做了chrome适应,没做firefox…

将数据转二进制流文件,用PostMan发送二进制流请求

一、将byte数组转二进制流文件,并保存到本地 byte [] oneshotBytesnew byte[]{78,-29,51,-125,86,-105,56,82,-94,-115,-22,-105,0,-45,-48,-114,27,13,38,45,-24,-15,-13,46,88,-90,-66,-29,52,-23,40,-2,116,2,-115,17,36,15,-84,88,-72,22,-86,41,-90,-19,-58,19…

Docker 的数据管 与 Dockerfile

目录 Docker 的数据管理容器互联(使用centos镜像)Docker 镜像的创建1.基于现有镜像创建2.基于本地模板创建3.基于Dockerfile 创建镜像加载原理 Dockerfile 操作常用的指令(1)FROM 镜像&#xff…

平台使用篇 | RflySim平台Simulink-PSP工具箱使用简介

导读 Pixhawk Pilot Support Package (PSP,自驾仪支持包)工具箱是Mathworks公司官方为Pixhawk推出的一个工具箱。本篇围绕RflySim平台Simulink-PSP工具箱使用进行详解。 RflySim平台Simulink-PSP工具箱使用简介 PSP工具箱 Pixhawk Pilot Support Package (PSP,自驾仪支持包)工…

显卡水洗充新、冒牌作坊彻底凉凉,1500万销量团伙被一锅端

上一轮显卡挖矿潮自 2020 年底开始,直到 2022 年底尾声,历时两年左右。 在这一波矿潮冲击过后,大量水洗、二手矿卡横行,显卡市场可谓一片混乱。 翻新、杂牌显卡厂商表示:就挺突然的,感觉人生到达了巅峰&a…

某奇艺缺陷书写规范及缺陷严重程度划分

目录 一、最基本的要求: 二、Bug标题 三、复现步骤 四、描述 五、期望结果 六、实际结果 七、附件 八、备注 九、Bug定级(优先级) Bug书写规范: 一、最基本的要求: 1、Bug内所有的文字表述要通顺&#xff0c…

最新!王中林院士再获全球大奖:“开创让西方跟随的研究领域”

最新!王中林院士再获全球大奖:“开创让西方跟随的研究领域” 北京时间7月6日下午17:00,2023年度“全球能源奖”(Global Energy Prize)揭晓,中国科学院北京纳米能源与系统研究所首席科学家王中林院士因发明摩…

Debug Stable Diffusion webui

文章目录 SD前期预备一些惊喜TorchHijackForUnet Txt2Img 搭配 Lora 使用单独运行 txt2img.py获取所有资源代码地址参数sd model 主程序代码地址参数(同上)模型InferenceLORA应用重构并使用LORA模型用Lora重构后的网络 做 sampler后处理 以下内容是最近的学习笔记,如…

Microsoft发布用于 AutoML 算法和训练的 NNI v1.3

将传统的机器学习方法应用于现实世界的问题可能非常耗时。自动化机器学习 (AutoML) 旨在改变这种状况——通过对原始数据运行系统流程并选择从数据中提取最相关信息的模型,使构建和使用 ML 模型变得更加容易。 为了帮助用户以高效和自动的方…

Redis 九种数据类型的基本操作

一、redis9种数据类型的基本操作 ①key操作 #查找所有的key 127.0.0.1:6379> keys * 1) "pop" 2) "mylist" 3) "lpl" 4) "myset" #设置key的过期时间 返回1表示执行成功,0表示失败,出现问题 127.0.0.1:6379…

【Spring Boot Admin】介绍以及使用

介绍 概述 Spring Boot Admin是一个监控工具,旨在以一种漂亮且易于访问的方式可视化Spring Boot Actuators提供的信息。 主要功能点 显示应用程序的监控状态应用程序上下线监控查看 JVM,线程信息可视化的查看日志以及下载日志文件动态切换日志级别Http…

【C++】入门基础2

引用 概念 引用不是新定义一个变量,而是给已存在变量取了一个别名,编译器不会为引用变量开辟内存空 间,它和它引用的变量共用同一块内存空间 类型& 引用变量名(对象名) 引用实体; 注意:引用类型必须和引用实体是…