交叉熵Loss多分类问题实战(手写数字)

news2025/1/10 3:01:37

1、import所需要的torch库和包
在这里插入图片描述
2、加载mnist手写数字数据集,划分训练集和测试集,转化数据格式,batch_size设置为200在这里插入图片描述
3、定义三层线性网络参数w,b,设置求导信息
在这里插入图片描述
4、初始化参数,这一步比较关键,是否初始化影响到数据质量以及后续网络学习效果
在这里插入图片描述
5、自定义三层线性网络
在这里插入图片描述
6、选定优化器激活函数和loss函数
在这里插入图片描述
7、训练及测试,并记录每轮训练的loss变化和在测试集上的效果。第一轮就达到了98的准确度,判断是初始化效果较好,在前几次测试中根据初始化的情况不同,初始准确率为50%-85%不等
在这里插入图片描述
完整代码:

import torch
import torchvision
import torch.nn.functional as F

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                          transform=torchvision.transforms.Compose([
                              torchvision.transforms.ToTensor(),
                              torchvision.transforms.Normalize(
                                  (0.1307, ), (0.3081, ))
                              ])
                          ),
    batch_size=200, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                          transform=torchvision.transforms.Compose([
                              torchvision.transforms.ToTensor(),
                              torchvision.transforms.Normalize(
                                  (0.1307, ), (0.3081, ))
                              ])
                          ),
    batch_size=200, shuffle=True)

w1 = torch.randn(200, 784, requires_grad=True)
b1 = torch.randn(200, requires_grad=True)
w2 = torch.randn(200, 200, requires_grad=True)
b2 = torch.randn(200, requires_grad=True)
w3 = torch.randn(10, 200, requires_grad=True)
b3 = torch.randn(10, requires_grad=True)

torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)

def forward(x):
    x = x@w1.t() +b1
    x = F.relu(x)
    x = x@w2.t() +b2
    x = F.relu(x)
    x = x@w3.t() +b3
    x = F.relu(x)
    
    return x
    
optimizer = torch.optim.Adam([w1, b1, w2, b2, w3, b3], lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(10):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        logits = forward(data)
        loss = criterion(logits, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (batch_idx+1) % 150 == 0:
            print('Train Epoch:{} [{}/{}({:.0f}%)]\tLoss:{:.6f}'.format(
                epoch, (batch_idx+1) * len(data), len(train_loader.dataset),
                100. * (batch_idx+1) / len(train_loader), loss.item())
            )
            
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28*28)
        logits = forward(data)
        test_loss += criterion(logits, target).item()
        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()
    
    test_loss /= len(test_loader)
    print('\nTest Set:Average Loss:{:.4f}, Accuracy:{}/{}({:.0f}%)\n'.format(
         test_loss, correct, len(test_loader.dataset),
         100. * correct / len(test_loader.dataset))
    )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1082746.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

bootz启动 Linux内核涉及do_bootm_linux 函数

一. bootz启动Linux uboot 启动Linux内核使用bootz命令。当然还有其它的启动命令,例如,bootm命令等等。 本文只分析 bootz命令启动 Linux内核的过程中涉及的几个重要函数。具体分析 do_bootm_linux函数执行过程。 本文继上一篇文章,地址…

vue3 + element Plus实现表格根据关键字合并行,并实现行的增删改操作

根据关键字合并表格 1.实现初始化表格2.实现添加班级与学生的功能3.添加的弹窗4.删除班级5.删除学生 首先看最终实现的效果 1.实现初始化表格 <template><div class"main-page"><div class"flex-end"><div class"public-search…

Service Weaver:以单体形式编码,以微服务形式部署

分布式应用的主流架构模式演化为微服务架构已经有些年头了。微服务、DevOps、持续交付和容器技术(k8s)是构成最初云原生概念[1]的核心要素。它们相生相拌&#xff0c;共同演进&#xff0c;并推动了云计算全面进入云原生时代。 云原生应用普遍采用微服务架构&#xff0c;遗留的单…

阿里云r7服务器内存型CPU采用

阿里云服务器ECS内存型r7实例是第七代内存型实例规格族&#xff0c;CPU采用第三代Intel Xeon可扩展处理器&#xff08;Ice Lake&#xff09;&#xff0c;基频2.7 GHz&#xff0c;全核睿频3.5 GHz&#xff0c;计算性能稳定&#xff0c;CPU内存比1:8&#xff0c;2核16G起步&#…

HBase 表如何按照某表字段排序后顺序存储的方法?

首先需要明白HBase表的排序规则&#xff1a; &#xff08;1&#xff09;rowkey排序&#xff08;字典排序&#xff09;——升序 &#xff08;2&#xff09;Column排序&#xff08;字典排序&#xff09;——升序 &#xff08;3&#xff09;时间戳排序——降序 rowkey 字典序排序…

DeskHIL桌面级仿真测试平台

硬件在环测试系统&#xff0c;一直是汽车电子功能测试的重要工具。随着测试工程师对硬件在环系统的要求越来越高&#xff0c;对其高集成性、便携性的需求也愈发强烈。由于硬件在环系统设计复杂、定制化程度高等因素&#xff0c;成本一直居高不下&#xff0c;因此&#xff0c;市…

C++基础入门学习笔记

问题1&#xff1a;什么是 C 中的多态&#xff1f;如何实现多态&#xff1f; 回答1&#xff1a;C 中的多态是指同一种类型的实体&#xff0c;可以在不同的情况下表现出不同的行为。实现多态的方式有两种&#xff1a;虚函数和模板函数。虚函数是在基类中声明为虚函数的函数&…

c++视觉处理---仿射变换和二维旋转变换矩阵的函数

仿射变换cv::warpAffine cv::warpAffine 是OpenCV中用于执行仿射变换的函数。仿射变换是一种线性变换&#xff0c;可用于执行平移、旋转、缩放和剪切等操作。下面是 cv::warpAffine 函数的基本用法&#xff1a; cv::warpAffine(src, dst, M, dsize, flags, borderMode, borde…

嵌入式Linux裸机开发(七)UART串口、IIC、SPI通信

系列文章目录 文章目录 系列文章目录前言UART串口通信介绍UART配置 IIC介绍I.MX6U 的 I2C SPI介绍I.MX6U ECSPI 结语 前言 大概学完这三种通信后&#xff0c;之后就先去学系统移植&#xff0c;其他的先暂时放下 UART串口通信 介绍 串口全称叫做串行接口&#xff0c;通常也叫…

Maven创建父子工程详解

引言 在微服务盛行的当下&#xff0c;我们创建的工程基本都是父子工程&#xff0c;我们通过父工程来引入jar&#xff0c;定义统一的版本号等&#xff0c;这样我们在子工程中就可以直接引用后使用了&#xff0c;而不需要去重复的声明版本号等&#xff0c;这样会更方便对整个项目…

Linux中Locate命令查找不全

Locate locate(locate) 命令用来查找文件或目录。 locate命令要比find -name快得多&#xff0c;原因在于它不搜索具体目录&#xff0c;而是搜索一个数据库/var/lib/mlocate/mlocate.db 。这个数据库中含有本地所有文件信息。Linux系统会自动创建这个数据库&#xff0c;并且每天…

【ElasticSearch】深入探索 DSL 查询语法,实现对文档不同程度的检索,以及对搜索结果的排序、分页和高亮操作

文章目录 前言一、Elasticsearch DSL Query 的分类二、全文检索查询2.1 match 查询2.2 multi_match 查询 三、精确查询3.1 term 查询3.2 range 查询 四、地理坐标查询4.1 geo_bounding_box 查询4.2 geo_distance 查询 五、复合查询5.1 function score 查询5.2 boolean 查询 六、…

2023年中国溶瘤病毒药物上市产品、研发现状及行业市场规模前景[图]

溶瘤病毒&#xff08;Oncolyticvirus&#xff09;&#xff0c;是一类具有复制能力的肿瘤杀伤型病毒&#xff0c;溶瘤病毒根据所采用的毒株类型可以被分为天然病毒株&#xff08;野生型病毒株&#xff09;和基因改造病毒株两类&#xff0c;溶瘤病毒的种类也从最初的疱疹病毒发展…

使用c++视觉处理----canny 边缘检测、sobel边缘检测、scharr 滤波边缘检测

使用c视觉处理canny 边缘检测、sobel边缘检测、scharr 滤波边缘检测 #include <opencv2/opencv.hpp>int main() {// 读取图像cv::Mat image cv::imread("1.jpg", cv::IMREAD_GRAYSCALE); // 转为灰度图像if (image.empty()) {std::cerr << "无法加…

网络-网络状态网络速度

文章目录 前言一、网络状态二、网络速度 前言 本文主要记录如何监听网络状态和网络速度。 一、网络状态 获取当前网络状态: navigator.onLine // true:在线 false:离线监听事件&#xff1a;online&#xff08;联网&#xff09; 和 offline&#xff08;断网&#xff09; windo…

全国A级旅游景区清单数据(2023年更新)

全国A级旅游景区清单数据&#xff08;2023年更新&#xff09; 1.样本量&#xff1a;14847条 2.来源&#xff1a;政府公布资料 3.指标&#xff1a;景区名称、等级、所属省份、所属城市、所属区县、地址、当前等级评定时间、相关文件发布时间、坐标(GCJ02)Lng、坐标(GCJ02)Lat…

【Redis实战】分布式锁

分布式锁 synchronized只能保证单个JVM内部的线程互斥&#xff0c;不能保证集群模式下的多个JVM的线程互斥。 分布式锁原理 每个JVM内部都有自己的锁监视器&#xff0c;但是跨JVM&#xff0c;就会有多个锁监视器&#xff0c;就会有多个线程获取到锁&#xff0c;不能实现多JV…

分享一份关于 Rust 编程的学习指南

Rust是一种现代的系统级编程语言&#xff0c;以其注重内存安全、性能和并发性而闻名。学习Rust可以是一段有回报的旅程&#xff0c;为您打开构建强大高效应用的机会。无论您是经验丰富的开发者还是完全的初学者&#xff0c;本指南将通过精选的资源和技巧帮助您踏上Rust编程之旅…

【angular】实现简单的angular国际化(i18n)

文章目录 目标过程运行参考 目标 实现简单的angular国际化。本博客实现中文版和法语版。 将Hello i18n!变为中文版&#xff1a;你好 i18n!或法语版:Bonjour l’i18n !。 过程 创建一个项目&#xff1a; ng new i18nDemo在集成终端中打开。 添加本地化包&#xff1a; ng a…

景联文科技:3D点云标注应用场景和专业平台

3D点云技术之所以得到广泛发展和应用&#xff0c;主要是因为它能够以一种直观、真实和全面的方式来表示和获取现实世界中的三维信息。 3D点云的优势&#xff1a; 真实感和立体感&#xff1a;3D点云数据能够呈现物体的真实感和立体感&#xff0c;使观察者能够更直观地理解物体的…