《PyTorch深度学习实践》第九讲多分类问题

news2024/11/19 4:46:37

一、

1、softmax的输入不需要再做非线性变换,也就是说softmax之前不再需要激活函数。softmax两个作用,如果在进行softmax前的input有负数,通过指数变换,得到正数。所有类的概率求和为1。

2、y的标签编码方式是one-hot。one-hot是只有一位是1,其他位为0。(但是标签的one-hot编码是算法完成的,算法的输入仍为原始标签)

损失函数loss为交叉熵损失函数。

3、多分类问题,标签y的类型是LongTensor。比如说0-9分类问题,如果y = torch.LongTensor([3]),对应的one-hot是[0,0,0,1,0,0,0,0,0,0].(这里要注意,如果使用了one-hot,标签y的类型是LongTensor,糖尿病数据集中的target的类型是FloatTensor)

 4、CrossEntropyLoss <==> LogSoftmax + NLLLoss。也就是说使用CrossEntropyLoss最后一层(线性层)是不需要做其他变化的;使用NLLLoss之前,需要对最后一层(线性层)先进行SoftMax处理,再进行log操作。

二、

1、第8讲 from torch.utils.data import Dataset,第9讲 from torchvision import datasets。该datasets里面init,getitem,len魔法函数已实现。

                  2、torch.max的返回值有两个,第一个是每一行的最大值是多少,第二个是每一行最大值的下标(索引)是多少。

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

# prepare dataset

batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])  # 归一化,均值和方差

train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


# design model using class


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)  # -1其实就是自动获取mini_batch
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)  # 最后一层不做激活,不进行非线性变换


model = Net()

# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# training cycle forward, backward, update


def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        # 获得一个批次的数据和标签
        inputs, target = data
        optimizer.zero_grad()
        # 获得模型预测结果(64, 10)
        outputs = model(inputs)
        # 交叉熵代价函数outputs(64,10),target(64)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0


def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度
            total += labels.size(0)
            correct += (predicted == labels).sum().item()  # 张量之间的比较运算
    print('accuracy on test set: %d %% ' % (100 * correct / total))


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1478819.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java爬取深圳新房备案价

Java爬取深圳新房备案价 这是我做好效果,一共分3个页面 1、列表;2、统计;3、房源表 列表 价格分析页面 房源页面 一、如何爬取 第一步:获取深圳新房备案价 链接是:http://zjj.sz.gov.cn/ris/bol/szfdc/index.aspx 第二步:通过楼盘名查询获取明细 链接:http://z…

就业班 2401--2.27 Linux Day6--管道和重定向

管道与重定向 只有在开水里&#xff0c;茶叶才能展开生命浓郁的香气. 一、重定向 标准输入、标准正确输出、标准错误输出 进程在运行的过程中根据需要会打开多个文件&#xff0c;每打开一个文件会有一个数字标识。这个标识叫文件描述符。 进程使用文件描述符来管理打开的文件…

Android PDFView 提示401 pom

背景 在开发安卓app&#xff0c;使用PDF组件来解析URL地址 &#xff0c;从github找到一个开源组件 AndroidPdfViewer 遇到一个大坑&#xff0c;一直提示下载依赖401 pom 打开控制台链接弹出需要登录jitpack 原因分析&#xff1a; 这个组件项目依赖库链接到了需要鉴权的…

【airtest】自动化入门教程(一)AirtestIDE

目录 一、下载与安装 1、下载 2、安装 3、打开软件 二、web自动化配置 1、配置chrome浏览器 2、窗口勾选selenium window 三、新建项目&#xff08;web&#xff09; 1、新建一个Airtest项目 2、初始化代码 3、打开一个网页 四、恢复默认布局 五、新建项目&#xf…

服务器数据恢复-服务器RAID5上层XFS文件系统分区数据恢复案例

服务器数据恢复环境&#xff1a; MD1200磁盘柜中的磁盘通过RAID卡创建了一组RAID5阵列&#xff0c;分配了一个LUN。在Linux操作系统层面对该LUN进行了分区&#xff0c;划分sdc1和sdc2两个分区&#xff0c;通过LVM扩容的方式将sdc1分区加入到了root_lv中&#xff1b;sdc2分区格式…

VuePress + GitHub 搭建个人博客踩坑记录

最近想给我教练搭个网站,本来选的是 VuePress 框架,也折腾完了,起码是搭建出来了,踩的坑也都总结好了 但是最近发现了一个更简洁的模板: VuePress-theme-hope ,所以最终网站使用的样式是这个 不过我觉得这里面踩坑的记录应该还是有些价值的,分享出来,看看能不能帮到一些小伙伴~…

手机AI摄影时代开启,传音引领行业标准化建设

今年春节&#xff0c;AI摄影可谓大出风头。人们在社交平台晒出自己在龙年的AI写真&#xff0c;极大地增添了节日的氛围感&#xff0c;也让我们看到了“AI摄影”的价值。新年伊始&#xff0c;手机巨头们纷纷布局该赛道&#xff0c;基于AI大模型实现的影像功能成为业界关注焦点。…

CAPL编程学习笔记--关于on 事件的详细解释

CAPL编程是比较有特色的一种面向通讯的编程语言。 1&#xff1a;on XXX类型&#xff08;即事件类型&#xff09; 维克多的官方文档对CAPL的描述是一门类C语言&#xff0c;说白了它也是用C写出来的。我们看on&#xff08;注意都是小写&#xff09;事件的代码结构 on * { }&…

算法day03_ 59.螺旋矩阵II

推荐阅读 算法day01_ 27. 移除元素、977.有序数组的平方 算法day02_209.长度最小的子数组 目录 推荐阅读59.螺旋矩阵 II题目思路解法 59.螺旋矩阵 II 题目 给你一个正整数 n &#xff0c;生成一个包含 1 到 n2 所有元素&#xff0c;且元素按顺时针顺序螺旋排列的 n x n 正方形…

2023年全国职业院校技能大赛中职组大数据应用与服务赛项题库参考答案陆续更新中,敬请期待…

2023年全国职业院校技能大赛中职组大数据应用与服务赛项题库参考答案陆续更新中&#xff0c;敬请期待… 武汉唯众智创科技有限公司 2024 年 2 月 联系人&#xff1a;辜渝傧13037102709 题号&#xff1a;试题01 模块三&#xff1a;业务分析与可视化 &#xff08;一&#xff0…

【Web安全靶场】sqli-labs-master 38-53 Stacked-Injections

sqli-labs-master 38-53 Stacked-Injections 其他关卡和靶场看专栏… 文章目录 sqli-labs-master 38-53 Stacked-Injections第三十八关-报错注入第三十九关-报错注入第四十关-盲注第四十一关-盲注第四十二关-联合报错双查询注入第四十三关-报错注入第四十四关-盲注第四十五关-…

Facebook的元宇宙实践:数字化社交的新前景

近年来&#xff0c;元宇宙&#xff08;Metaverse&#xff09;这一概念备受瞩目&#xff0c;被认为是数字化社交的未来趋势之一。而在众多科技巨头中&#xff0c;Facebook&#xff08;现更名为Meta&#xff09;一直处于元宇宙发展的前沿。在本文中&#xff0c;我们将深入探讨Fac…

linux系统Jenkins工具web配置

Jenkins工具配置 插件配置系统配置系统工具配置 插件配置 下载 Maven Integration Pipeline Maven lntegration gitlab Generic webhook Trigger nodejs Blue ocean系统配置 系统配置结束系统工具配置

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的水果质量识别系统(Python+PySide6界面+训练代码)

摘要&#xff1a;本篇博客详尽介绍了一套基于深度学习的水果质量识别系统及其实现代码。系统采用了尖端的YOLOv8算法&#xff0c;并与YOLOv7、YOLOv6、YOLOv5等前代算法进行了详细的性能对比分析&#xff0c;提供在识别图像、视频、实时视频流和批量文件中水果方面的高效准确性…

32单片机基础:TIM输出比较

这个输出比较功能是非常重要的&#xff0c;它主要是用来输出PWM波形,PWM波形又是驱动电机的必要条件&#xff0c;所以你如果想用STM32做一些有电机的项目&#xff0c;比如智能车&#xff0c;机器人等。 IC: Input Capture 输入捕获 CC:Capture/Compare一般表示输入捕获和输出…

【Leetcode每日一刷】哈希表|纲领、242.有效的字母异位词、349. 两个数组的交集

纲领 &#x1f517;代码随想录理论部分 关于哈希表这个数据结构就不再重复讲了&#xff0c;下面对几个关键点记录一下&#xff1a; 哈希碰撞 解决方法1&#xff1a;拉链法 解决方法2&#xff1a;线性探测法 下面针对做题要用到的三种结构讲一下&#xff08;也是重复造轮子了…

解释一下前端框架中的虚拟DOM(virtual DOM)和实际DOM(real DOM)之间的关系。

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 欢迎来到前端入门之旅&#xff01;感兴趣的可以订阅本专栏哦&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

windows server mysql 数据库停止 备份 恢复全流程操作方法

一,mysql备份 mysql最好是原工程文件备份.不需要sql查询的方式备份.安全高效. 比如,安装php与mysql组合后,我的mysql文件保存在: D:\phpstudy_pro\Extensions\MySQL5.7.26\data\dux 我只需要复制一份,保存起来就行. 二,mysql恢复 怎么恢复呢.我们一般是只恢复其中一个表,则找…

华为---RSTP(四)---RSTP的保护功能简介和示例配置

目录 1. 技术背景 2. RSTP的保护功能 3. BPDU保护机制原理和配置命令 3.1 BPDU保护机制原理 3.2 BPDU保护机制配置命令 3.3 BPDU保护机制配置步骤 4. 根保护机制原理和配置命令 4.1 根保护机制原理 4.2 根保护机制配置命令 4.3 根保护机制配置步骤 5. 环路保护机…

thefour--Love is like a tide

最后一部分了&#xff0c;要开始进行我们的训练了。 先上代码&#xff1a; import os import numpy as np from tqdm import tqdm import tensorflow as tf from thetwo import NeuralStyleTransferModel import theone import thethree #创建模型 modelNeuralStyleTransferM…