Pytorch手动实现softmax回归

news2024/11/24 7:43:19

参考代码:https://blog.csdn.net/ccyyll1/article/details/126020585

softmax回归梯度计算方式,特别是i=j和i!= j时的计算问题,请看如下帖子中的描述,这个问题是反向传播梯度计算中的一个核心问题:反向传播梯度计算中的一个核心问题

直接上代码:

import torch  
import torchvision  
import torchvision.transforms as transforms  
import numpy as np  
#(2)下载并装载Fashion MNIST 数据集

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True,   
download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False,   
download=True, transform=transforms.ToTensor())  
#(3)构建迭代器

batch_size = 256  
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False) 
#(4)初始化学习参数

#初始化学习参数  
num_inputs = 784  
num_outputs = 10  
w = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)  
b = torch.zeros(num_outputs, dtype=torch.float)  
w.requires_grad_(requires_grad=True)  
b.requires_grad_(requires_grad=True) 
#(5)定义相关函数

#定义Softmax决策函数  
def softmax(x,w,b):  
    y = torch.mm(x.view(-1, num_inputs), w) + b  
    y_exp = y.exp()  
    y_sum = y_exp.sum(dim=1, keepdim=True)  
    return y_exp / y_sum  
#定义交叉熵损失函数  
def cross_entropy(y_hat, y):  
    #其中gather()就相当于是维度级高级的矩阵索引;并且真实值向量y中其他类别都为0所以不用考虑  
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))  
#定义梯度下降优化函数  
def sgd(params, lr, batch_size):  
    for param in params:  
        param.data -= lr * param.grad / batch_size  
#定义分类准确率  
def accuracy(y_hat, y):  
    return (y_hat.argmax(dim=1) == y).float().mean().item()  
#模型未训练前的准确率  
def evaluate_accuracy(data_iter, net):  
    acc_sum, n = 0.0, 0  
    for X, y in data_iter:  
        acc_sum += (net(X , w ,b).argmax(dim=1) == y).float().sum().item()  
        n += y.shape[0]  
    return acc_sum / n  
#(6)开始训练并计算每轮损失

#开始训练并计算每轮损失  
lr = 0.01  
num_epochs = 20  
net = softmax  
loss = cross_entropy  
for epoch in range(num_epochs):  
    train_l_sum, train_acc_sum, n = 0.0, 0.0, 0  
    for X, Y in train_iter:  
        l = loss(net(X, w, b), Y).sum()  
        l.backward()  
        sgd([w, b], lr, batch_size)  
        w.grad.data.zero_()  
        b.grad.data.zero_()  
        train_l_sum += l.item()  
        train_acc_sum += (net(X, w, b).argmax(dim=1) == Y).sum().item()  
        n += Y.shape[0]  
    test_acc = evaluate_accuracy(test_iter, net)  
    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'  
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))  

执行结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/777422.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

哈工大计算机网络课程局域网详解之:MAC地址与ARP协议

哈工大计算机网络课程局域网详解之:MAC地址与ARP协议 文章目录 哈工大计算机网络课程局域网详解之:MAC地址与ARP协议MAC地址ARP:地址解析协议寻址:从一个LAN路由至另一个LAN MAC地址 在介绍MAC地址前,首先回顾一下之前…

SAP ABAP 实现数据库表行项目和程序加解锁功能

1.SAP ABAP 实现数据库表行项目加解锁功能 实现效果: 当一个数据库表以某字段为关键字段的数据被锁定时,同一时间其他程序无法修改改表内被锁定的数据,除非被解锁或退出程序。 1.事务代码:SE11 创建锁对象。PS:命名…

【计组】不同进制数之间的相互转换

前言 1、推荐在线进制转换器:(都还不错) 在线进制转换 | 进制转换器 — 在线工具 (sojson.com) 在线进制转换器 | 菜鸟工具 (runoob.com) 在线进制转换 - 码工具 (matools.com) 2、进位计数法 (1)二进制&#xf…

JavaScript字符串和模板字面量

● 上节课我们说明,号可以当作字符串连接符号使用,例如 const firstName "Sun"; const job "技术分享博主"; const birthYear 1991; const year 2023;const sun "我叫" firstName ",是一个" (year - bi…

线性结构:队列

文章目录 队列定义队列应用热土豆问题打印任务 队列定义 队尾进,队头出 队列是一种有次序的数据集合,其特征是新数据项的添加总发生在一端(通常称为“尾rear”端)而现存数据项的移除总发生在另一端(通常称为“首front”端&#x…

刷题记录-2最短路径

考点&#xff1a; 图论-最短路-Dijkstra 解题&#xff1a; c #include <iostream> #include <vector> #include <queue> using namespace std; const long long inf 0x3f3f3f3f3f3f3f3fLL; const int num 3e52; struct edge {int from,to;long long w;e…

算法竞赛入门【码蹄集新手村600题】(MT1001-1020)

算法竞赛入门【码蹄集新手村600题】(MT1001-1020&#xff09; 目录MT1001 程序设计入门MT1002 输入和输出整型数据MT1003 整数运算MT1004 求余MT1005 输入和输出实型数据MT1006 实型数运算MT1007 平均分MT1008 圆球等的相关运算MT1009 公式计算MT1010 输入和输出字符型数据MT10…

【Visual Studio】Qt 在其他 cpp 文件中调用操作 ui 界面控件

知识不是单独的&#xff0c;一定是成体系的。更多我的个人总结和相关经验可查阅这个专栏&#xff1a;Visual Studio。 还整了一个如何相互之间调用函数的文章&#xff0c;感兴趣可以看&#xff1a;【Visual Studio】Qt 在其他 cpp 文件中调用主工程下文件中的函数。 文章目录 …

第四章:包围体

第四章&#xff1a;包围体 引言-包围体&#xff08;1&#xff09;包围体测试和几何体测试&#xff08;2&#xff09;包围体测试的代价和作用&#xff08;3&#xff09;相交测试的优化&#xff08;4&#xff09;包围体相关章节和主旨 一、BV 期望特征1.1 有效的包围体1.2 包围体…

docker 网络配置详解

目录 1、docker网络模式 2、容器和容器之间是如何互通 3、容器之间互通 --link 3、自定义网络 4、不通网段的容器进行网络互通 1、docker网络模式 docker 网络模式采用的是桥接模式&#xff0c;当我们创建了一个容器后docker网络就会帮我们创建一个虚拟网卡&#xff0c;这…

Electron 学习_在进程之间通信

1.问题&#xff1a;Electron的主进程和渲染进程有着清楚的分工&#xff0c;并且不可互换。从渲染进程直接访问Node.js 接口&#xff0c;亦或者 从主进程访问HTML文档对象模型(DOM)都是不可能的 2.解决方法&#xff1a;使用进程间通信 (IPC) 可以使用 Electron 的ipcMain 模块和…

Redisson限流器RRateLimiter使用及源码分析

一、使用 使用很简单、如下 // 1、 声明一个限流器 RRateLimiter rateLimiter redissonClient.getRateLimiter(key);// 2、 设置速率&#xff0c;5秒中产生3个令牌 rateLimiter.trySetRate(RateType.OVERALL, 3, 5, RateIntervalUnit.SECONDS);// 3、试图获取一个令牌&#…

TCP首部格式【TCP原理(笔记五)】

文章目录 TCP首部格式源端口号&#xff08;Source Port&#xff09;目标端口号&#xff08;Destination Port&#xff09;序列号&#xff08;Sequence Number&#xff09;确认应答号&#xff08;Acknowledgement Number&#xff09;数据偏移&#xff08;Data Offset&#xff09…

Oracle 普通视图 (Oracle Standard Views)

视图&#xff08;views&#xff09;是一种基于表的"逻辑抽象"对象&#xff0c;由于它是从表衍生出来的&#xff0c;因此和表有许多相同点&#xff0c;我们可以和对待表一样对其进行查询/更新操作。但视图本身并不存储数据&#xff0c;也不分配存储空间。 本文只讨论普…

Linux下搭建pyqt5开发环境—基于Pycharm

防踩坑Tips&#xff1a; 1、不能学windows那样直接用pip安装PyQt5Designer和pyqt5-tools。这两个模块最根本的是用的windows的程序&#xff0c;linux上是运行不了的&#xff0c;特别是PyQt5Designer&#xff0c;会提示安装失败。 2、推荐在python环境安装同系统版本一致的pyq…

2023.7.16 第五十九次周报

目录 前言 文献阅读:跨多个时空尺度进行预测的时空 LSTM 模型 背景 本文思路 本文解决的问题 方法论 SPATIAL 自动机器学习模型 数据处理 模型性能 代码 用Python编写的LSTM多变量预测模型 总结 前言 This week, I studied an article that uses LSTM to solve p…

数据分析系统中的六边形战士——奥威BI系统

数据分析软件可以对收集的数据进行分析和报告&#xff0c;帮助企业获得更深入的数据洞察力&#xff0c;从而推动企业数字化运营决策&#xff0c;提高决策效率与质量。进入大数据时代&#xff0c;企业对数据分析软件的要求也在水涨船高&#xff0c;传统的数据分析软件显然已不能…

数据结构 单向链表(不循环)的基础知识和基础操作

头定义&#xff1a; typedef int datatype; typedef struct Node {//数据域存储数据datatype data;//指针域存储下一个地址struct Node *next; }*Linkelist; 创建节点 Linkelist create_node()//创建新节点 {Linkelist node(Linkelist)malloc(sizeof(struct Node));if(nodeN…

Elasticsearch 源码探究 001——故障探测和恢复机制

1、Elasticsearch 故障探测及熔断背景 探究Elasticsearch7.10.2 节点之间的故障探测以及熔断故障是怎么做的&#xff0c;思考生产上的最佳实践。 服务端故障场景&#xff1a; 单个master挂掉 除了断点断网&#xff0c;状态同步异常&#xff0c;主master也会认为自己已经失败&am…

ASPICE V模型之软件需求

ASPICE V模型之软件需求 了解ASPICE认识软件需求软件需求分解软件需求工作流程 了解ASPICE ASPICE全称是“Automotive Software Process Improvement and Capacity Determination”汽车软件过程改进及能力评定&#xff0c;是汽车行业用于评价软件开发团队的研发能力水平的模型框…