从0开始深度学习(9)——softmax回归的逐步实现

news2024/11/28 21:45:25

文章使用Fashion-MNIST数据集,做一次分类识别任务
Fashion-MNIST中包含的10个类别,分别为:
t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)
sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)、ankle boot(短靴)

0 图像数据

0.1 读取展示数据

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt

# 下载数据集 ,60,000 个训练样本和 10,000 个测试样本,每个样本包含一张28*28的灰度图和一个标签

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="D/DL_Data/Fashion-MNIST", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="D/DL_Data/Fashion-MNIST", train=False, transform=trans, download=True)
    
print("test:",len(mnist_test))
print("train:",len(mnist_train))

# 获取第一个样本的图像和标签
image, label = mnist_train[0]
print("图像的形状:", image.shape)
print("标签:", label)

在这里插入图片描述

0.2 可视化图像

# 可视化
def show_img():
    class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
    ]
    # 可视化前5张图片
    fig, axes = plt.subplots(1, 10, figsize=(15, 3))

    for i in range(10):
        # 获取第 i 个样本的图像和标签
        image, label = mnist_train[i]
    
        # 将图像从 Tensor 转换回 numpy 数组,并移除通道维度
        image_np = image.squeeze().numpy()
    
        # 在子图中显示图像
        axes[i].imshow(image_np, cmap='gray')
        axes[i].set_title(f'Label: {class_names[label]}')
        axes[i].axis('off')  # 关闭坐标轴

    plt.tight_layout()
    plt.show()
    
show_img()

在这里插入图片描述

0.3 整合为数据加载模块

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
                            
train_iter, test_iter = load_data_fashion_mnist(256, resize=28)

1 初始化参数模型

我们选择把 28 ∗ 28 28*28 2828的图片展开成 1 ∗ 784 1*784 1784的向量,认为每个像素位置都是一个特征,所以输入是784维,输出是10个类别标签,所以输出是10维

因为softmaxhi回归类似于线性回归,所以权重 w w w应该是 784 ∗ 10 784*10 78410 的矩阵,偏置是 1 ∗ 10 1*10 110 的行向量,接下来如同线性回归中一样,使用正太分布初始化权重,偏置初始化为0:

num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

2 定义softmax操作

回顾一下softmax的公式:
在这里插入图片描述
由三个步骤组成:

  1. 对每个项目求幂
  2. 将每一行求和(小批量样本中,每个样本是一行),得到每个样本的规范化常数。
  3. 将每一行除以其规范化常数,确保结果的和为1。
# 定义softmax操作
def softmax(x):
    x_exp=torch.exp(x)
    x_exp_sum=x_exp.sum(1,keepdim=True)
    return x_exp/x_exp_sum

3 定义模型

# 定义模型
def net(x):
    x = x.reshape(-1, w.shape[0])  # 将图片重塑为 [batch_size, 784]
    temp = torch.matmul(x, w)
    temp = temp + b
    return softmax(temp)

4 定义损失函数

使用从0开始深度学习(8)——softmax回归提到的交叉熵损失函数

# 定义损失函数
def cross_entropy(y_hat, y): # 预测值、真实值
    return - torch.log(y_hat[range(len(y_hat)), y]) # 计算负对数似然

cross_entropy(y_hat, y)

5 分类精度

分类精度即正确预测数量与总预测数量之比。

def compute_accuracy(y_hat, y):  # 预测值、真实值
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)  # 找到一个样本中,对应的最大概率的类别
    cmp = y_hat.type(y.dtype) == y  # 将预测值 y_hat 与真实标签 y 进行比较,生成一个布尔张量 cmp
    return float(cmp.type(y.dtype).sum())

# 计算在指定数据集上模型的准确率
def evaluate_accuracy(net, data_iter):  
    if isinstance(net, torch.nn.Module):
        net.eval()  # 将模型设置为评估模式
    metric = Accumulator(2)  # 累加多个变量的总和。这里初始化了一个包含两个元素的累加器,分别用来存储正确预测的数量和总的预测数量。
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(compute_accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

class Accumulator:  #@save
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
        
# 评估模型
accuracy = evaluate_accuracy(net, test_iter)
print(f"Test Accuracy: {accuracy:.4f}")

在这里插入图片描述

6 定义优化器

# 定义优化器
def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

7 训练

# 训练模型
def train_epoch(net, train_iter, loss, updater):
    if isinstance(net, torch.nn.Module):
        net.train()  # 将模型设置为训练模式
    metric = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y).mean()
        if isinstance(updater, torch.optim.Optimizer):
            updater.zero_grad()
            l.backward()
            updater.step()
        else:
            l.backward()
            updater([w, b], lr, batch_size)
        metric.add(float(l) * y.numel(), compute_accuracy(y_hat, y), y.numel())
    return metric[0] / metric[2], metric[1] / metric[2]

def train(net, train_iter, test_iter, loss, num_epochs, updater):
    for epoch in range(num_epochs):
        train_metrics = train_epoch(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        print(f'Epoch {epoch + 1}: Train Loss {train_metrics[0]:.3f}, Train Acc {train_metrics[1]:.3f}, Test Acc {test_acc:.3f}')

# 训练模型
updater = lambda params, lr, batch_size: sgd(params, lr, batch_size)
train(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

在这里插入图片描述

8 预测

# 定义 Fashion-MNIST 标签的文本描述
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

# 预测并显示结果
def predict(net, test_iter, n=6):
    for X, y in test_iter:
        break  # 只取一个批次的数据
    trues = get_fashion_mnist_labels(y)
    preds = get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    n = min(n, X.shape[0])
    fig, axs = plt.subplots(1, n, figsize=(12, 3))
    for i in range(n):
        axs[i].imshow(X[i].permute(1, 2, 0).squeeze().numpy(), cmap='gray')
        axs[i].set_title(titles[i])
        axs[i].axis('off')
    plt.show()

# 调用预测函数
predict(net, test_iter, n=6)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2206421.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SSD |(二)SSD主控

文章目录 📚控制器架构🐇PCIe和NVMe控制器前端子系统🐇NAND闪存控制器后端子系统🐇内存子系统🐇安全子系统🐇CPU计算子系统 📚控制器架构 控制器作为一个片上系统,处理来自用户端的…

Linux下的Makefile基本操作

1.Makefile与 make介绍 在Linux中, Makefile 是⼀个⽂件, 令会在当前⽬录下找 make 是⼀个指令,当使⽤ Makefile ⽂件从⽽执⾏内部的内容 2.创建第一个 Makefile并使用make ⾸先,在当前⽬录下创建⼀个makefile文件 接下来在同级…

【小工具分享】下载保存指定网页的所有图片

一、保存百度首页所有的图片 先看一下保存的图片情况 二、思路 1、打开网页 2、获取所有图片 3、依次下载保存图片到指定路径 三、完整代码 from selenium import webdriver from selenium.webdriver.common.by import By b webdriver.Firefox() import urllib.request…

企业如何借力AI,提升人力资源管理的效率完成组织提效变革

大家好,我是Shelly,一个专注于输出AI工具和科技前沿内容的AI应用教练,体验过300款以上的AI应用工具。关注科技及大模型领域对社会的影响10年。关注我一起驾驭AI工具,拥抱AI时代的到来。 企业面临的压力: 在当今这个充…

LeetCode|70.爬楼梯

这道题很像斐波那契数列,但是初始值不同,也有动态规划的解法,但是一开始我想到的是递归写法。现在我们站在第n阶台阶,那么,我们上一步就有两种可能:1、我们从第n-1阶台阶走一步上来的;2、我们从…

商家转账到零钱接口开通

商家想要开通“商家转账到零钱”功能,需要遵循一系列详细步骤和条件,以确保顺利通过审核。以下是开通办法的详解: 申请流程: 主体资格确认:确保申请主体为公司性质(有限公司),个体工…

ScribbleDiff:使用涂鸦引导扩散,实现无需训练的文本到图像生成

ScribbleDiff可以通过简单的涂鸦帮助计算机生成图像。比如你在纸上随意画了一些线条,表示你想要的图像的轮廓。ScribbleDiff会利用这些线条来指导图像生成的过程。 首先,它会分析这些涂鸦,确保生成的图像中的对象朝着你画的方向。比如&#…

品民俗、看展演、逛非遗市集……在海淀,重阳节还可以这样过

秋菊溢彩、叠翠鎏金。由北京市海淀区文化和旅游局主办,北京市海淀区文化馆承办,海淀区上庄镇文化活动中心支持的品鉴民俗 巧手绘梦——2024年海淀区重阳节非遗主题文化活动于10月11日在上庄镇市民活动中心顺利举办。海淀非遗传承人以非遗为媒,与地区群众度过了一个温馨、热闹、…

第四次论文问题知识点及问题

1、NP-hard问题 NP-hard,指所有NP问题都能在多项式时间复杂度内归约到的问题。 2、启发式算法 ‌‌启发式算法(heuristic algorithm)是相对于最优化算法提出的。它是一种基于直观或经验构造的算法,旨在以可接受的花费给出待解决…

Android 如何实现远程网页控制售卖机出商品:RabbitMQ的对接,如何使用?如何断网重连?连接不上后台的MQ有哪些方面的原因

目录 一、如何实现远程网页控制售卖机出商品? 比如,我们想实现,通过一个网页去控制自动售卖机(自动售卖机装有Android系统,装有App)出商品,也就是我们熟知的远程控制,不用你人到现场…

搭建电商商城系统各项功能时需要用到的电商API数据采集接口

在搭建电商商城系统时,选择合适的电商API接口至关重要。以下是一些常用的电商API接口提供商及其功能: 常用电商API接口提供商 淘宝开放平台:提供淘宝、天猫、1688等阿里巴巴集团旗下的电商平台接口,用于商品检索、订单管理、物流…

如何把pdf转换成jpg图片?在线pdf转图片,这6种方法很简单!

“如何把pdf转换成jpg图片?”相信很多小伙伴们都有这个疑问。pdf格式是如今在商业和其他正式场合中使用最广泛的文档类型,因为它能以安全且方便的方式共享信息。然而,查看pdf文件通常需要使用一些专业的pdf阅读器,这可能给一些用户…

服务端给客户端push消息的demo的实现流程

摘要: 本示例演示了一个基本的服务端5分钟定时向客户端app推送消息的WebSocket机制。服务端使用WebSocket协议接受客户端的订阅和取消订阅请求,并根据客户端的订阅状态发送实时消息。服务端记录并打印带有时间戳的日志,以监控订阅活动。客户…

python画图|二维动态柱状图输出

【1】引言 在前面的学习过程中,已经探索过二维柱状图和三维柱状图的绘制教程,包括且不限于的文章链接有: python画图|水平直方图绘制_绘制水平直方图-CSDN博客 python画图|3D bar进阶探索_ax.bar3d-CSDN博客 此外也学习了动态的直线输出和…

调用AI 通过相机识别地标

https://www.youtube.com/watch?vViRfnLAR_Uc&listPLQkwcJG4YTCRJxkPPDBcKqDWrfF5qanQs&index3学习视频 TensorFlow Hub 机器学习模型的代码库 找到地标模型 如何在Android上使用ts模型 https://blog.tensorflow.org/2018/03/using-tensorflow-lite-on-android.html…

10.11每日作业

数据表 #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this);//想要添加某个数据库if(!db.contains("stu.db")){//如果当前对象中没有包含所需数据库&…

dowhy中反驳实验怎么做?

首先,我们打开最新的dowhy版本网站。 https://www.pywhy.org/dowhy/v0.11.1/index.html 我们主要看标题栏的User Guide和Examples就可以了,如果在User Guide 里找不到使用方法,就去Examples里找例子,里面的数据读取修改为自己的数…

HI6338 (DIP-8内置75W方案)

Hi6338 combines a dedicated current mode PWM controller with integrated high voltage power MOSFET.Vcc low startup current and low operating current contribute to a reliable power on startup design with Hi6338. the IC operates in Extended ‘burst mode’ to …

前端的全栈混合之路Meteor篇:分布式数据协议DDP深度剖析

本文属于进阶篇,并不是太适合新人阅读,但纯粹的学习还是可以的,因为后续会实现很多个ddp的版本用于web端、nodejs端、安卓端和ios端,提前预习和复习下。ddp协议是一个C/S架构的协议,但是客户端也同时可以是服务端。 什…

Java程序打包成jar包

步骤1 打开项目结构 步骤2 配置工件 选择你要打包的模块选择主类(程序的主入口main类)提取到目标会把库文件的jar包打包到目标,一般选择这个,更方便在不同电脑上运行 步骤3 构建并生成jar包 最后,在对应的out/artifacts文件夹中找到jar包,在终端输入java -jar xxxx.jar就可以正…