深度学习_10_softmax_实战

news2024/11/27 6:23:00

由于网上代码的画图功能是基于jupyter记事本,而我用的是pycham,这导致画图代码不兼容pycharm,所以删去部分代码,以便能更好的在pycharm上运行

完整代码:

import torch
from d2l import torch as d2l

"创建训练集&创建检测集合"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

"创建模型w, b"
num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

"softmax"
def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

"输出,即传入图片输出"
def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

"交叉熵损失"
def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

"显示预测与估计相对应下标数量"
def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: # 确定长宽高都大于1
        y_hat = y_hat.argmax(axis=1) # 取出每行中最大值
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum()) # 返回对应下标数量


"利用优化后的模型计算精度"
def evaluate_accuracy(net, data_iter):  #@save

    if isinstance(net, torch.nn.Module):
        net.eval()  # 将模型设置为评估模式
    metric = Accumulator(2)  # 正确预测数、预测总数
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel()) # 下标相同数量 / 总下标
    return metric[0] / metric[1]


"加法器"
class Accumulator:  #@save

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

"训练更新模型&返回训练损失与精度函数"
def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """训练模型一个迭代周期(定义见第3章)"""
    # 将模型设置为训练模式
    if isinstance(net, torch.nn.Module):
        net.train()
    # 训练损失总和、训练准确度总和、样本数
    metric = Accumulator(3)
    for X, y in train_iter:
        # 计算梯度并更新参数
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # 使用PyTorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            # 使用定制的优化器和损失函数
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]

lr = 0.1

"更新模型"
def updater(batch_size):
    return d2l.sgd([W, b], lr, batch_size)

if __name__ == '__main__':
    num_epochs = 10
    cnt = 1
    for i in range(num_epochs):
        X, Y = train_epoch_ch3(net, train_iter, cross_entropy, updater)
        print("训练次数: " + str(cnt))
        cnt += 1
        print("训练损失: {:.4f}".format(X))
        print("训练精度: {:.4f}".format(Y))
        print(".................................")
#        print(W)
#        print(b)

效果:

在这里插入图片描述

训练效果还是和网上一样的,就是缺了画图功能,将就着吧

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1182430.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STM32 TIM定时器,配置,详解(1)

计数器寄存器(TIMx_CNT)、预分频器寄存器(TIMx_PSC)、自动重载寄存器(TIMx_ARR)。 PSC预分频器,顾名思义,先预备一下分频,有时候频率过高,后面的定时器承受不住,就先用PSC先分频一下。如何分频的?将每接受到…

PTE作文练习(一)

目录 65分备考建议 WE模版 范文 Supporting ideas: SWT 65分备考建议 RA重在多听标准的正确的示范,RS重在抓大放小,WFD重在整理错题,以及反反复复的车轮战,FIBRW重在“以对代记” 就是直接看答案,节约时间&#…

Python编程的四个关键点——都知道吗?快来查漏补缺!

文章目录 前言一、Python 中的类型提示二、Python 虚拟环境和包管理三、新的 Python 语法四、Python 测试关于Python技术储备一、Python所有方向的学习路线二、Python基础学习视频三、精品Python学习书籍四、Python工具包项目源码合集①Python工具包②Python实战案例③Python小…

响应式新闻博客资讯网站模板源码带后台

模板信息: 模板编号:29779 模板编码:UTF8 模板分类:博客、文章、资讯、其他 适合行业:博客类企业 模板介绍: 本模板自带eyoucms内核,无需再下载eyou系统,原创设计、手工书写DIVCSS&a…

uniapp原生插件之安卓腾讯Bugly专业版原生插件

插件介绍 Bugly专业版是TDS腾讯端服务(Tencent Device-oriented Service)旗下的端质量监控平台,通过采集、监控、定位、告警等核心能力,提供专业的质量监控服务,帮助开发者及时发现并解决质量问题,打造高质…

SQL注入漏洞:CMS布尔盲注python脚本编写

SQL注入漏洞:CMS布尔盲注python脚本编写 文章目录 SQL注入漏洞:CMS布尔盲注python脚本编写库名爆破爆破表名用户名密码爆破 库名爆破 import requests #库名 database"" x0 while requests.get(urlf"http://10.9.47.77/cms/show.php?id33%20and%20length(data…

动态规划实例——01 背包详解

题目描述 有 n 件物品,每件物品有一个重量和一个价值,分别记为 w1,w2,…,wn 和 c1,c2,…,cn。现在有一个背包,其容量为 wk,要从 n 件物品种任取若干件。要求…

Python---capitalize() 方法---把字符串的首字母大写,其他字符全部小写,title()方法--把字符串中的所有单词的首字母大写,组成大驼峰

capitalize 英 /ˈkpɪtəlaɪz/ v. 用大写字母书写(或印刷),把……首字母大写;为(开办或发展企业)提供资金;(将资产或股票)变现,使资本化;&…

Window10安装Docker

文章目录 Window10安装Docker前提条件Hyper -VWSL 2.0 安装包下载执行安装包更新 Window10安装Docker 前提条件 Hyper -V 如何启用 WSL 2.0 安装包下载 官网地址 下载后: 执行安装包 wsl --update等得有点久 重新打开 拉取一个helloworld镜像 说明已经…

[LeetCode] 4.寻找两个正序数组的中位数

一、题目描述 给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。 算法的时间复杂度应该为 O(log (mn)) 。 示例 1: 输入:nums1 [1,3], nums2 [2] 输出&#xff1a…

软件测试-根据状态迁移图设计测试用例

测试用例状态迁移图 许多需求用状态机的方式来描述,状态机的测试主要关注状态转移是否正确。对于一个有限状态机,通过测试验证其在给定的条件内是否能够产生需要的状态变化,有没有不可达的状态和非法的状态,是否可能产生非法的状…

【Spring】使用注解装配bean

目录 使用注解的两个必要步骤 正文 Cat Dog Animal beans.xml 测试 Qualifier 使用注解的两个必要步骤 1.导入约束 <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:…

RT-DETR 应用 BiFPN 结构 | 加权双向特征金字塔网络

模型效率在计算机视觉中变得越来越重要。在本文中,我们系统地研究了目标检测中的神经网络架构设计选择,并提出了几种关键的优化方法来提高效率。首先,我们提出了一种加权双向特征金字塔网络(BiFPN),它可以实现简单快速的多尺度特征融合;其次,我们提出了一种复合缩放方法…

C盘清理指南(二)——盘符划分操作

今天的内容是C盘清理系列的第二期——盘符划分操作。 1. 点击“我的电脑——左上角的管理” 2.进入后点击磁盘管理 3.右键单击某个想修改盘符&#xff0c;可进行扩展、压缩、删除三种操作 其中压缩卷是进行“分解反应”&#xff0c;即原盘过大要进行拆分。此处注意拆分的上限为…

数据结构与算法—双链表

前言 前面有很详细的讲过线性表(顺序表和链表)&#xff0c;当时讲的链表以单链表为主&#xff0c;但在实际应用中双链表有很多应用场景&#xff0c;例如大家熟知的LinkedList。 双链表与单链表区别 单链表和双链表都是线性表的链式实现&#xff0c;它们的主要区别在于节点结构…

矢量图设计软件层出不穷,CorelDRAW为何无人能替?

设计工作经验丰富的人一定对比过多种设计软件&#xff0c;在对众多矢量图设计软件进行对比之后&#xff0c;多数资深设计师认为CorelDRAW的专业性、便捷性以及兼容性的综合表现更好&#xff0c;而且软件还配置了海量艺术笔&#xff0c;这让工作成果更为出众&#xff0c;因此更愿…

【14】c++11新特性 —>共享智能指针

在C中没有垃圾回收机制&#xff0c;必须自己释放分配的内存&#xff0c;否则就会造成内存泄露。解决这个问题最有效的方法是使用智能指针&#xff08;smart pointer&#xff09;。智能指针是存储指向动态分配&#xff08;堆&#xff09;对象指针的类&#xff0c;用于生存期的控…

【Python】Python爬虫使用代理IP的实现

前言 在爬虫的过程中&#xff0c;我们经常会遇到需要使用代理IP的情况。比如&#xff0c;针对目标网站的反爬机制&#xff0c;需要通过使用代理IP来规避风险。因此&#xff0c;本文主要介绍如何在Python爬虫中使用代理IP。 一、代理IP的作用 代理IP&#xff0c;顾名思义&…

2011年408计网

第33题 TCP/IP 参考模型的网络层提供的是&#xff08;&#xff09;A. 无连接不可靠的数据报服务B. 无连接可靠的数据报服务C. 有连接不可靠的虚电路服务D. 有连接可靠的虚电路服务 本题考查TCP/IP 参考模型的网络层 若网络层提供的是虚电路服务&#xff0c;则必须建立网络层的…

531X304IBDASG1 F31X303MCPA002/00 发电用分布式控制系统

531X304IBDASG1 F31X303MCPA002/00 发电用分布式控制系统 2021年4月20日&#xff0c;马萨诸塞州戴德姆。-新的ARC咨询小组关于全球的研究发电用分布式控制系统(DCS)市场显示&#xff0c;全球燃煤发电能力的减少继续阻碍增长。老化的燃煤电厂越来越多地被淘汰&#xff0c;而不是…