pytorch-简单回归问题-手写数字识别

news2024/12/28 22:01:08

pytorch-简单回归问题-手写数字识别

    • 线性回归添加噪声
    • 简单例子
    • 分类问题引入-手写数字识别
      • 数据集
    • 训练推导
    • 手写数字识别1
      • 加载数据集
      • 编写网络
      • 训练网络
      • 计算正确率

线性回归添加噪声

在这里插入图片描述

使用均方差损失函数来衡量损失

简单例子

在这里插入图片描述

通过最小化损失函数,求解出参数w b

下图表示搜索最小的Loss

在这里插入图片描述

给出一系列的样本方程,然后训练出一个模型参数w b使得可以预测
在这里插入图片描述

分类问题引入-手写数字识别

数据集

7000张照片 6000张训练 1000张测试
在这里插入图片描述

训练推导

首先将一张28 * 28的照片展平 784,然后插入一个维度表示[1,784]

关于推导过程

在这里插入图片描述

使用one-hot编码对输出的结果进行编码

在这里插入图片描述

计算loss

这里的Loss计算很简单,直接使用输出的H3向量和标签向量做减法 然后求平方

在这里插入图片描述

也就是优化预测值和真实值的欧氏距离
在这里插入图片描述

ReLU函数的非线性增强
在这里插入图片描述

输出的预测值,是一个一维向量,里面包含每一种类别的预测值,然后去除概率最大的索引
在这里插入图片描述

手写数字识别1

加载数据集

from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt


#  加载数据集  batch_size表示每次取出512张图片
batch_size = 512
#  torchvision.transforms.Normalize((0.1307,),(0.3081,)) 表示归一化操作
# torchvision.transforms.ToTensor() 表示将numpy张量 转换为tensor
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('minst_data',
                                                                      train = True,
                                                                      download=True,
                                                                      transform=torchvision.transforms.
                                                                      Compose([torchvision.transforms.ToTensor(),
                                                                               torchvision.transforms.Normalize((0.1307,),(0.3081,))])),
                                                                               batch_size=batch_size,shuffle = True)


#  加载测试数据集
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('minst_data',
                                                                      train = False,
                                                                      download=True,
                                                                      transform=torchvision.transforms.
                                                                      Compose([torchvision.transforms.ToTensor(),
                                                                               torchvision.transforms.Normalize((0.1307,),(0.3081,))])),
                                                                               batch_size=batch_size,shuffle = False)

编写网络

#  编写网络
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()

        # xw + b
        self.fc1 = nn.Linear(28 * 28,256)
        self.fc2 = nn.Linear(256,64)
        self.fc3 == nn.Linear(64,10)

    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x
    

训练网络

net = Net()
# 定义优化器
optimizer = optim.SGD(net.parameters(),lr = 0.01,momentum=0.9)

# 保存训练损失
train_loss = []
for epoch in range(3):
    for batch_idx,(x,y) in enumerate(train_loader):
        #  将 [b,1,28,28] 转换成 [b.feature] 二维的tensor

        x = x.view(x.size(0),28 * 28) # 第一个参数表示图片的batch_size  

        # 最后的out形状是 [b,10] 表示每一张图片有 十个类别的概率
        out = net(x)

        # 转换为独热编码
        y_onehot = one_hot(y)

        # 计算损失
        loss = F.mse_loss(out,y_onehot)

        # 梯度清零
        optimizer.zero_grad()

        # 计算梯度
        loss.backward()

        # 更新优化
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 ==0:
            print("第{}次迭代的损失是{}".format(epoch,loss.item()))

在这里插入图片描述

计算正确率

total_correct = 0

# 计算正确率
for x,y in test_loader:
    x = x.view(x.size(0),28 * 28)
    out = net(x)
    pred = out.argmax(dim = 1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct

total_num  = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:',acc)  # 测试集的正确率 0.8807

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/596676.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

封装的函数停发/启动CAN报文,以及报文接收检测,高可用

🍅 我是蚂蚁小兵,专注于车载诊断领域,尤其擅长于对CANoe工具的使用🍅 寻找组织 ,答疑解惑,摸鱼聊天,博客源码,点击加入👉【相亲相爱一家人】🍅 玩转CANoe,博客目录大全,点击跳转👉 📘前言 🍅 在测试过程中,我们可能需要可控的停/发某些报文,今天博主给…

chatgpt赋能python:Python主页面的SEO分析及优化建议

Python主页面的SEO分析及优化建议 Python是一种高级编程语言,广泛应用于人工智能、数据分析、Web开发等领域。Python官方网站是Python社区的一个重要门户,为全球学习Python的开发者提供了全面、权威、可靠的信息。在这篇文章中,我们将分析Py…

Text to image论文精读SeedSelect: 使用SeedSelect微调扩散模型It’s all about where you start

随着文本到图像扩散模型的发展,很多模型已经可以合成各种新的概念和场景。然而,它们仍然难以生成结构化、不常见的概念、组合图像。今年4月巴伊兰大学和OriginAI发表《It’s all about where you start: Text-to-image generation with seed selection》…

软件外包开发项目原型图工具

项目原型图工具有非常重要的作用,尤其是在APP项目开发中,对于整体需求的表达是必不可少的工具。相比于传统的文档需求,图形文字的表达可以更清楚的表达需求,让客户清楚的明白软件功能有哪些,最后的界面是怎样的&#x…

微信海量数据查询如何从1000ms降到100ms?

👉腾小云导读 微信的多维指标监控平台,具备自定义维度、指标的监控能力,主要服务于用户自定义监控。作为框架级监控的补充,它承载着聚合前 45亿/min、4万亿/天的数据量。当前,针对数据层的查询请求也达到了峰值 40万/m…

RL - 强化学习 上置信界算法 (UCB) 和 汤普森采样算法 (TS)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/130983835 上置信界算法和汤普森采样算法是两种解决多臂老虎机问题的经典方法。多臂老虎机问题是一种探索与利用的平衡问题,…

Java easypoi 导出excel 并合并相关列

在项目开发中经常会使用到合并列&#xff0c;格式如下&#xff1a; 1.引入easypoi <dependency><groupId>cn.afterturn</groupId><artifactId>easypoi-annotation</artifactId></dependency><dependency><groupId>cn.aftertur…

设计模式详解之策略模式

作者&#xff1a;刘文慧 策略模式是一种应用广泛的行为型模式&#xff0c;核心思想是对算法进行封装&#xff0c;委派给不同对象来管理&#xff0c;本文将着眼于策略模式进行分享。 一、概述 我们在进行软件开发时要想实现可维护、可扩展&#xff0c;就需要尽量复用代码&#x…

chatgpt赋能python:Python什么情况下用类

Python什么情况下用类 在Python编程中&#xff0c;类是一种重要的数据结构&#xff0c;它是面向对象编程的核心。类可定义数据类型&#xff0c;并把数据与操作数据的函数组合在一起。因此&#xff0c;通过使用类&#xff0c;我们可以将数据、函数和其他方法组合在一起&#xf…

OLAP系列:四、clickhouse分布式表使用指南

一、背景 ClickHouse中最强大的表引擎当属MergeTree&#xff08;合并树&#xff09;引擎及该系列&#xff08;*MergeTree&#xff09;中的其他引擎&#xff0c;支持索引和分区&#xff0c;地位可以相当于innodb之于Mysql。 而且基于MergeTree&#xff0c;还衍生出了很多小弟&a…

【量化分析】绘制指标线EWM和MACD(1)

目录 一、说明 二、使用mplfinance的前提 2.1 mplfinance生态圈 2.1 安装mplfinance 三、mplfinance绘图 3.1 单变量图 3.2 将用户自己生成的曲线添加到 mplfinance plot() 四、显示EWM和MACD 一、说明 在做量化分析的时候&#xff0c;需要有能力计算种种曲线&#xff…

ShowMeBug 持续升级,提供高信效度支撑的技术招聘方案

去年年底&#xff0c;全新升级版的 ShowMeBug ——一款支持实战编程的技术能力评估平台&#xff0c;首次揭开了它神秘的面纱。 而近日&#xff0c;ShowMeBug 再次迎来一系列产品更新&#xff0c;它将以全新的面貌&#xff0c;提供高信效度支撑的技术招聘方案&#xff0c;持续助…

chatgpt赋能python:Python人脸登录:这项技术将颠覆传统的登录方式

Python人脸登录&#xff1a;这项技术将颠覆传统的登录方式 简介 在互联网时代&#xff0c;登录是每个人使用网站或软件的第一步&#xff0c;但是传统的用户名和密码登录已经不能满足用户的需求。不断的爆出各种账户泄露事件、密码猜测和密码被盗等问题&#xff0c;导致用户的…

cleanmymac要不要下载装机?好不好用

当我们收到一台崭新的mac电脑&#xff0c;第一步肯定是找到一款帮助我们管理电脑运行的“电脑管家”&#xff0c;监控内存运行、智能清理系统垃圾、清理Mac大文件旧文件、消除恶意软件、快速卸载更新软件、隐私保护、监控系统运行状况等。基本在上mac电脑防护一款CleanMyMac就够…

生成程序片段(程序依赖图PDG)

生成程序片段(程序依赖图PDG) 生成程序片段 标准方法是&#xff1a; 基于依赖性分析的切片。 使用程序依赖图表示依赖。 从中生成切片。 我们将专注于这种方法。但是&#xff0c;还有其他选择。 程序依赖图 The Program Dependence Graph (PDG) 表示数据和控制依赖项&#xf…

Servlet的常用Api—HttpServletResponse

Servlet的常用Api—HttpServletResponse &#x1f50e;核心方法setContentType && setCharacterEncodingsendRedirect关于Keep-Alive关于状态码 && Body &#x1f50e;结尾 &#x1f50e;核心方法 方法描述(void) setStatus(int sc)为该响应设置状态码(void) s…

2023年4月和5月随笔

1. 回头看 为了不耽误学系列更新&#xff0c;4月随笔合并到5月。 日更坚持了151天&#xff0c;精读完《SQL进阶教程》&#xff0c;学系统集成项目管理工程师&#xff08;中项&#xff09;系列更新完成。 4月和5月两月码字114991字&#xff0c;日均码字数1885字&#xff0c;累…

python的AutoGui库(1)获取鼠标实时位置

1.安装AutoGui库,与库的导入 PyAutoGUI是一个纯Python的GUI自动化工具&#xff0c;其目的是可以用程序自动控制鼠标和键盘操作&#xff0c;多平台支持&#xff08;Windows&#xff0c;OS X&#xff0c;Linux&#xff09;。可以用pip安装&#xff0c;Github上有源码。 使用命令…

Ceph应用

//存储类型 块存储 一对一&#xff0c;只能被一个主机挂载使用&#xff0c;数据以块为单位进行存储&#xff0c;典型代表: 硬盘 文件存储 一对多&#xff0c;能被多个主机同时挂载使用&#xff0c;数据以文件的形式存储的(元数据和实际数据是分开存储的)&#xff0c;并且有…

Python学习笔记 - 探索33个保留关键字

Python编程语言中有33个保留关键字&#xff0c;这些关键字在Python语法中有特殊含义&#xff0c;不能用作变量名、函数名或其他标识符。 33个保留字&#xff08;关键字&#xff09; 不能冲突的关键词 33 个 来看看都有哪些关键字。 import keyword print("&#xff0c;…