李沐动手学深度学习:softmax回归的从零开始实现

news2024/9/25 17:15:48
import torch
from IPython import display
from d2l import torch as d2l

batach_size=256
train_iter,test_iter = d2l.load_data_fashion_mnist(batach_size)
num_input = 784 #图片的尺寸:28*28
num_output = 10 #10个类别
W = torch.normal(0,0.01,size=(num_input,num_output),requires_grad=True)
b = torch.zeros(num_output,requires_grad=True)

def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1,keepdim=True)
    return X_exp/partition

#定义模型
def net(X):
    return softmax(torch.matmul(X.reshape((-1,W.shape[0])),W)+b)
    
#交叉熵损失
def cross_entropy(y_hat,y):
    return -torch.log(y_hat[range(len(y_hat)),y])

#使⽤argmax获得每⾏中最⼤元素的索引来获得预测类别
def accuracy(y_hat,y):
    #计算预测正确的数量
    if len(y_hat.shape)>1 and y_hat.shape[1]>1:
        y_hat = y_hat.argmax(axis=1)
    #由于等式运算符“==”对数据类型很敏感,因此我们将y_hat的数据类型转换为与y的数据类型⼀致
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

class Accumulator:
    def __init__(self,n):
        self.data = [0.0] * n
    
    def add(self, *args):
        self.data = [a+float(b) for a,b in zip(self.data, args)]
        
    def reset(self):
        self.data = [0.0] * len(self.data)
        
    def __getitem__(self, idx):
        return self.data[idx]

def evaluate_accuracy(net,data_iter):
    if isinstance(net,torch.nn.Module):
        net.eval() #将模型设置为评估模式
    
    metric = Accumulator(2)  #正确预测数、预测总数
    with torch.no_grad():
        for X,y in data_iter:
            metric.add(accuracy(net(X),y),y.numel())
    return metric[0]/metric[1]

#训练
def train_epoch(net,train_iter,loss,updater):
    if isinstance(net, torch.nn.Module):
        net.train()
    #训练损失总合、训练准确度总和、样本数
    metric = Accumulator(3)
    for X,y in train_iter:
        y_hat = net(X)
        l = loss(y_hat,y)
        if isinstance(updater, torch.optim.Optimizer):
            # 使⽤PyTorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()
            updater.setp()
        else:
            # 使⽤定制的优化器和损失函数
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]
    
class Animator: #@save
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
        ylim=None, xscale='linear', yscale='linear',
        fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
        figsize=(3.5, 2.5)):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使⽤lambda函数捕获参数
        self.config_axes = lambda: d2l.set_axes(
        self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts
    def add(self, x, y):
        # 向图表中添加多个数据点
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

def train(net,train_iter,test_iter,loss,num_epoch,updater):
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
legend=['train loss', 'train acc', 'test acc'])
    
    for epoch in range(num_epochs):
        train_metrics = train_epoch(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

lr = 0.1 
def updater(batch_size):
    return d2l.sgd([W,b],lr,batch_size)

num_epochs = 10
train(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

#预测
def predict(net,test_iter,n=6):
    for X,y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true+'\n'+pred for true,pred in zip(trues,preds)]
    d2l.show_images(
        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])

predict(net,test_iter)
   

训练图像
**加粗样式**

测试
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/731556.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Docker尝鲜

一、Docker的安装 卸载系统自带的旧版本 sudo apt-get remove docker docker-engine docker.io containerd runc 获取软件最新源 sudo apt-get update 安装apt依赖包 sudo apt-get -y install apt-transport-https ca-certificates curl software-properties-common 安装…

Feign实现远程接口的调用

Feign实现远程接口的调用 前言一、Fegin是什么&#xff1f;二、Feign使用步骤准备工作被调用的远程服务必须上传到nacos/eureka服务中心进行管理配置&#xff0c;比如我调用media-api服务(媒资管理服务)&#xff0c;那么media-api必须被nacos/eureka所管理。如图&#xff0c;都…

CentOS7和主机Win11不能复制粘贴解决之道及CentOS7最小安装版 VMware Tools安装

入世心法&#xff1a; 金字塔 结构化 系统化 底层认知 认知升级 来认识下你老祖宗 底层逻辑 分别从: ---->道|法|术|器|势<--- 这五个层次去打通............................翻新思维底层认知&#xff…

【CSS】跳动文字

文章目录 效果展示代码实现 效果展示 代码实现 <!DOCTYPE html> <html><head><meta charset"utf-8" /><title>一颗不甘坠落的流星</title></head><style type"text/css">/* 遮罩盒子样式 */#mask {/* 设…

基于单片机的蓝牙音乐喷泉的设计与实现

功能介绍 以51单片机作为主控系统&#xff1b;通过HM-18蓝牙音频模块进行无线传输&#xff1b; 通过LM386功放模块对音频信号进行放大&#xff1b;手机端可以直接控制音频播放&#xff0c;并且最远距离可达20米&#xff1b;手机端可以进行任意音乐切换&#xff0c;播报、暂停&a…

Python 图像文件压缩,使用PIL库对图像进行等比例压缩

题目 图像文件压缩。使用PIL库对图像进行等比例压缩&#xff0c;无论压缩前文件大小如何&#xff0c;压缩后文件大小小于10KB。 代码 from PIL import Image import os from tkinter import filedialog import tkinterf_path filedialog.askopenfilename() image Image.op…

第三章:L2JMobius学习 – 使用eclipse2023创建java工程

在前两个章节中&#xff0c;我们已经安装了mariadb数据库和jdk&#xff0c;本章节我们安装eclipse2023。eclipse作为老牌的java开发工具&#xff0c;真的是不错的。官方下载地址为&#xff1a; https://www.eclipse.org/downloads/download.php?file/technology/epp/download…

STM32的ADC模式及其应用例程介绍

STM32的ADC模式及其应用例程介绍 &#x1f4cd;ST官方相关应用笔记介绍资料&#xff1a;https://www.stmcu.com.cn/Designresource/detail/application_note/705947&#x1f4cc;相关例程资源包&#xff1a;STSW-STM32028&#xff1a;https://www.st.com/zh/embedded-software/…

MySQL---表数据高效率查询(简述)

目录 前言 一、聚合查询 &#x1f496;聚合函数 &#x1f496;GROUP BY子句 &#x1f496;HAVING 二、联合查询 &#x1f496;内连接 &#x1f496;外连接 &#x1f496;自连接 &#x1f496;子查询 &#x1f496;合并查询 &#x1f381;博主介绍&#xff1a;博客名…

NXP文档AN13000解读-基于S32K116的无感BLDC六步换相控制策略(预定位/开环启动/反电动势过零点检测)

目录 六步换相控制 单极性PWM 反电动势过零点检测技术 反电动势的测量 总线电流的测量 电机状态切换 Alignment Start-up Run 算法用到的各模块 各模块间的连接 ADC触发顺序 芯片的初始化 时钟配置与电源管理 FTM Trigger MUX Control (TRGMUX) PDB ADC LPS…

【Git原理与使用】-- 多人协作

目录 多人协作一&#xff08;多人同一分支&#xff09; 开发者一&#xff08;Linux&#xff09; 开发者二&#xff08;Windous&#xff09; master合并 远端上的合并 本地上的合并 总结 多人协作一&#xff08;多人多分支&#xff09; 开发者一&#xff08;Linux&…

SQL Server数据库忘记了sa密码怎么办 亲测真的可用

我们在安装SQL Server数据库时&#xff0c;往往选择混合模式安装&#xff0c;在这里我们可以设置sa密码。混合模式安装后&#xff0c;我们可以通过Windows身份验证和SQL Server身份验证两种方式登陆。 如果sa密码忘记了&#xff0c;我们就无法用SQL Server身份验证登陆了。 那么…

【高性能、高并发】页面静态化解决方案-OpenResty

OpenResty OpenResty介绍 OpenResty是一个基于 Nginx 与 Lua 的高性能 Web 平台&#xff0c;其内部集成了大量精良的 Lua 库、第三方模块以及大多数的依赖项 用于方便地搭建能够处理超高并发、扩展性极高的动态 Web 应用、Web 服务和动态网关 OpenResty通过汇聚各种设计精良的…

Python应用实例(一)外星人入侵(五)

外星人入侵&#xff08;五&#xff09; 1.项目回顾2.创建第一个外星人2.1 创建Alien类2.2 创建Alien实例 3.创建一群外星人3.1 确定一行可容纳多少个外星人3.2 创建一行外星人3.3 重构_create_fleet()3.4 添加行 在游戏《外星人入侵》中添加外星人。我们将首先在屏幕上边缘附近…

RHCE中级项目

一、项目需求 1、在 bbs.example.com 主机上创建 Discuz 论坛&#xff0c;数据库服务器使用 db.example.com 主机的 bbs 数据库实例&#xff0c;该实例由 MySQL数据库软件提供服务。 2、在 ntp.example.com 主机上创建 NTP 服务&#xff0c;该服务由 Chronyd软件提供服务&…

GoLang网络编程:HTTP服务端之底层原理与源码分析——http.HandleFunc()、http.ListenAndServe()

一、启动 http 服务 import ("net/http" ) func main() {http.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {w.Write([]byte("ping...ping..."))})http.ListenAndServe(":8999", nil) }在 Golang只需要几行代…

第三章 SSD存储介质:闪存 3.4

3.4 闪存数据完整性 可采用以下数据完整性的技术确保用户数据不丢失&#xff1a; &#xff08;1&#xff09;ECC纠错&#xff1b; &#xff08;2&#xff09;RAID数据恢复&#xff1b; &#xff08;3&#xff09;重读&#xff08;Read Retry&#xff09;&#xff1b; &#xff…

C/C++指针从0到99(详解)

目录 一&#xff0c;指针的基础理解 二&#xff0c;指针的基本使用 三&#xff0c;为什么要用指针 四&#xff0c;指针与数组的联系 五&#xff0c;指针的拓展使用 1&#xff09;指针数组 2)数组指针 3&#xff09;函数指针 结构&#xff1a;返回类型 &#xff08;*p)…

我国新能源汽车存量已突破1620万辆,登记数量创历史新高

根据公安部发布的最新数据&#xff0c;截至2023年6月底&#xff0c;全国的机动车数量达到4.26亿辆&#xff0c;其中汽车数量为3.28亿辆&#xff0c;新能源汽车数量为1620万辆。与此同时&#xff0c;机动车驾驶人口达到5.13亿人&#xff0c;其中汽车驾驶人口为4.75亿人。在2023年…

Zabbix监控软件 Linux外多平台监控【Windows JAVA SNMP】

在之前的博客中&#xff0c;已经介绍了zabbix的安装&#xff0c;配置&#xff0c;以及如何用zabbix监控Linux服务器。这篇博客则介绍zabbix监控的其他几种方式&#xff08;Windows服务器 Java应用 SNMP&#xff09;。 -------------------- Zabbix 监控 Windows 系统 ---------…