softmax从零开始实现

news2025/2/24 20:03:28

softmax从零开始实现

  • 代码
  • 结果

代码

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils import data

# H,W,C -> C,H,W
mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, download=True,
                                                transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, download=True,
                                               transform=transforms.ToTensor())
batch_size = 256
# 随机读取⼩批量
train_loader = data.DataLoader(mnist_train, batch_size, shuffle=True)
test_loader = data.DataLoader(mnist_test, batch_size, shuffle=True)

# feature, label = mnist_train[0]
# print(feature.shape, label) # torch.Size([1, 28, 28]) 9

num_inputs = 784
num_outputs = 10


def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(dim=1, keepdim=True)  # 按行
    return X_exp / partition  # 这⾥应⽤了⼴播机制


def net(X):
    return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)


def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))


def sgd(params, lr, batch_size):
    for param in params:
        param.data -= lr * param.grad / batch_size  # 注意这⾥更改param时⽤的param.data


def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()

W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)
W.requires_grad_()
b.requires_grad_()

num_epochs, lr = 10, 0.1
loss = cross_entropy
optimizer = sgd
for epoch in range(1, 1 + num_epochs):
    total_loss = 0.0
    train_sample = 0.0
    train_acc_sum = 0
    for x, y in train_loader:
        y_hat = net(x)
        l = loss(y_hat, y) # 256,1
        # 梯度清零
        l.sum().backward()
        sgd([W, b], lr, batch_size)  # 使用参数的梯度更新参数

        W.grad.data.zero_()
        b.grad.data.zero_()

        total_loss += l.sum().item()
        train_sample += y.shape[0]
        train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()

    print('epoch %d, loss %.4f, train acc %.3f' % (epoch, total_loss / train_sample, train_acc_sum / train_sample,))

with torch.no_grad():
    total_loss = 0.0
    test_sample = 0.0
    test_acc_sum = 0
    for x, y in test_loader:
        y_hat = net(x)
        l = loss(y_hat, y)  # 256,1

        total_loss += l.sum().item()
        test_sample += y.shape[0]
        test_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
    print('loss %.4f, test acc %.3f' % (total_loss / test_sample, test_acc_sum / test_sample,))

结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1889655.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《石化技术 》是什么级别的期刊?是正规期刊吗?能评职称吗?

问题解答 问:《石化技术 》是不是核心期刊? 答:不是,是知网收录的第一批认定学术期刊。 问:《石化技术 》级别? 答:国家级。主管单位:中国石油化工集团有限公司 主办单位&…

虚拟机下基于海思移植QT(一)——虚拟机下安装QT

0.参考资料 1.海思Hi3516DV300 移植Qt 运行并在HDMI显示器上显示 2.搭建海思3559A-Qt4.8.7Openssl开发环境 1.报错解决 通过下面命令查询 strings /lib/x86_64-linux-gnu/libc.so.6 | grep GLIBC_通过命令行没有解决: sudo apt install libc6-dev libc6参考解决…

内容监管与自由表达:Facebook的平衡之道

在当今数字化信息社会中,社交媒体平台不仅是人们交流和获取信息的主要渠道,也是自由表达的重要舞台。Facebook,作为全球最大的社交网络平台,连接了数十亿用户,形成了一个丰富多样的信息生态。然而,如何在维…

【小白教学】-- 安装Ubuntu-20.04系统

下载 Ubuntu-20.04 镜像 具体如何下载镜像,请移驾我上一篇文章。使用清华大学开源镜像站下载。https://zhuanlan.zhihu.com/p/706444837 制作 Ubuntu-20.04 系统盘 安装软件 UltralSO 开始制作系统盘 第一步,插入一个 u 盘,启动软件&#x…

博客的部署方法论

博客写完后,当然是要发布到网络上的。如果想要部署到服务器上,则需编译构建成静态文件,然后将其上传到服务器上的路径(该路径由我们自己决定),然后在 web 服务器(Nginx 等)上配置访问…

1688商品点击率和转化上不去,这9个细节你肯定没优化!

一、点击率上不去的原因 1、主图设计不佳 ✘问题:主图不够吸引人,无法迅速抓住顾客的注意力。 ✔解决方案: 通过店雷达主图分析,查询市场上目前在此关键词下的所有商品主图,一键对比多个关键销售指标数据&#xff…

go-redis源码解析:cluster模式如何选择节点

1. 如何选择节点 1.1. 确定slot 1.1.1. 通过cmdSlot方法确定在哪个槽上, 这一步只是本地计算 首先入口方法_process,先通过cmdSlot方法用key计算此次应该落在哪个槽上 通过crc16sum算法计算key应该属于哪个槽,slotNumber为16384 func Slot(key strin…

【设计模式】策略模式(定义 | 特点 | Demo入门讲解)

文章目录 定义策略模式的结构 QuickStart | DemoStep1 | 策略接口Step2 | 策略实现Step3 | 上下文服务类Step4 | 客户端 策略模式的特点优点缺点 定义 策略模式Strategy是一种行为模式,它能定义一系列算法,并将每种算法分别放入到独立的类中&#xff0c…

Unity Shader技巧:实现带投影机效果,有效避免边缘拉伸问题

这个是原始的projector 投影组件,边缘会有拉伸 经过修改shader 后边缘就没有拉伸了 (实现代码在文章最后) 这个着色器通过检查每个像素的UV坐标是否在定义的边界内,来确定是否应用黑色边框。如果UV坐标处于边缘区域,那么像素颜色会被强制设为黑色,从而在投影图像周围形成一…

技术成神之路:设计模式(三)原型模式

1. 定义 原型模式(Prototype Pattern)是一种创建型设计模式,旨在通过复制现有对象来创建新对象,而不是通过实例化类的方式。这个模式可以提高对象创建的效率,尤其是在创建对象的过程非常复杂或代价高昂时。 2. 结构 原…

三分钟看懂SMD封装与COB封装的差异

全彩LED显示屏领域中,COB封装于SMD封装是比较常见的两种封装方式,SMD封装产品主要有常规小间距以及室内、户外型产品,COB封装产品主要集中在小间距以及微间距系列产品中,今天跟随COB显示屏厂家中品瑞一起快速看懂SMD封装与COB封装…

Let‘s Encrypt 申请免费 SSL 证书(60天以后自动更新证书)

文章目录 官网文档简介安装 Nginxacme.sh生成证书智能化生成证书 安装证书查看已安装证书更新证书 官网 https://letsencrypt.org/zh-cn/ 文档 https://letsencrypt.org/zh-cn/docs/ 简介 Let’s Encrypt 是一个非营利组织提供的免费SSL/TLS证书颁发机构,旨在促…

滚珠花键促进汽车产业整体升级与发展!

滚珠花键能够实现高效的传动和连接,确保物体在运动过程中的精确位置和稳定性,被广泛应用于机械制造、航空航天、工业自动化、工业汽车、工业机器人、高速铁路等领域。为各个行业的发展提供了重要支持,尤其是在工业汽车领域中,为我…

Spring系统学习-什么是AOP?为啥使用AOP?

问题思考 我们为啥要使用AOP? 来看一个案例: 声明计算器接口Calculator,包含加减乘除的抽象方法 public interface Calculator {int add(int i, int j);int sub(int i, int j);int mul(int i, int j);int div(int i, int j); }public class Calculat…

HNU电子测试平台与工具2_《计算机串口使用与测量》

(这个有留word哈哈) 4.1 4.2 Linux 操作系统平台 一、实验目的 了解 Linux 系统文件系统的基本组织了解 Linux 基本的多用户权限系统熟练使用 ls、cd、cat、more、sudo、gcc、vim 等基本命令会使用 ls 和 chmod 命令查看和修改文件权限 二、实…

使用机器学习,通过文本分析,轻松实现原本复杂的情感分析

01、案例说明 本期分享案例是:文字分析-情感分析,内容是关于某部电影评论好坏的分析,使用大量的已知数据,通过监督学习的方法,可以对于未知的评论进行判断其为正面还是负面的评价。 对于数据分析,原来都是…

contrastive loss and triplet loss

contrastive loss 想要positive pair 的距离接近0 而triplet loss只需要postive pair的距离 和 negtaive pair 的距离相差m以上, postive pair的距离不一定为0, 所以条件更松驰, 为网络训练更友好一点 reference: https://www.youtube.com/watch?vBcF6FfZHDqA

图书管理系统(含登录验证码操作)

文章目录 登录需求分析登录界面注册功能:登录功能:忘记密码:验证码规则: 图书管理系统需求Book包Book类BookList类 IOperation包IOperation接口查找图书新增图书删除图书显示图书借阅图书归还图书退出系统 User包user类Users类adm…

【Java】如果让你设计一个分布式链路跟踪系统?你怎么做?

一、问题解析 分布式链路跟踪服务属于写多读少的服务,是我们线上排查问题的重要支撑。我经历过的一个系统,同时支持着多条业务线,实际用上的服务器有两百台左右,这种量级的系统想排查故障,难度可想而知。 因此&#…

大数据------JavaWeb------JSP(完整知识点汇总)

JSP 定义 JSP(Java Server Pages),即Java服务端页面。它是一种动态的网页技术,其中可以定义HTML、CSS、JS等静态内容,还可以定义Java代码的动态内容JSP HTML Java 说白了JSP就是一个页面,它既可以写HTML标…