【Datawhale X 李宏毅苹果书 AI夏令营】Task2笔记

news2024/9/21 2:33:32

第三章:深度学习基础

本章前部分的内容见:【Datawhale X 李宏毅苹果书 AI夏令营】Task1笔记-CSDN博客

3.6 分类

分类与回归的关系

假设三个类本身没有特定的关系,类 1 是 1,类 2 是 2 类 3 是 3。这种情况需要引入独热(one-hot)向量来表示类。

通常,我们使用逻辑回归(而不是线性回归)处理分类问题。

激活函数:softmax

按照上述的设定,分类实际过程是:输入 x,乘上 W,加上 b,通过激活函数 σ,乘上W′,再加上 b′ ,得到向量 yˆ。但实际做分类的时候,往往会把 yˆ 通过 softmax 函数得到 y′,才去计算 y′ 跟 yˆ 之间的距离。

softmax是一种激活函数,激活函数用于对输入信号进行非线性变换。常见的激活函数还有Sigmoid,ReLU等等。

激活函数必须满足以下条件:

  • 可微,优化方法是基于梯度。

  • 单调,保证单层网络是凸函数。

  • 输出值范围,有限则梯度优化更稳定,无限则训练更高效(学习率需要更小)。

softmax常用于多分类问题,计算公式如下:

使用softmax,一言以蔽之是为了归一化。另外,由于使用了 e 的幂函数,softmax 可以使正样本(正数)的结果趋近于 1,使负样本(负数)的结果趋近于 0;且样本的绝对值越大,两极化越明显。

参考:

https://blog.csdn.net/qq_41750911/article/details/124078768

https://blog.csdn.net/qq_43799400/article/details/131202148

分类损失:交叉熵

在分类问题下,使用的是交叉熵cross entropy,而非线性回归的均方误差(MSE)。最小化交叉熵其实就是最大化似然(maximize likelihood)。

为什么使用交叉熵作为损失函数L(f)而不是均方误差?可看下图的推导:假设训练集中 1 表示类 1,0 表示类 2,然后计算L(f)的微分(step3)。假设第n个数据是类 1,预测出的数据也是类 1,这意味着微分为0,是合理的。然而假如我预测出来的是类 2,但计算出微分仍然是0,这是不对的,因为离我的目标(希望output是1而不是0)还很遥远。

换句话说,当距离目标很近或者很远的时候,逻辑回归的均方误差会导致步长很小,但是交叉熵可以让步长在举例很远的时候仍然很大。

参考:

李宏毅机器学习【2017】https://www.bilibili.com/video/BV1SA411n7ou/?p=10

附注:分类与回归的对比

实践:HW3(CNN)卷积神经网络-图像分类

通过利用卷积神经网络架构,通过一个较小的10种食物的图像的数据集训练一个模型完成图像分类的任务。

关键流程的代码说明

  1. 图像预处理/变换

调整PIL图像大小并转换为Tensor

神经网络需要的输入数据类型一般是 FloatTensor 类型,且需要进行标准化,这个过程常常使用 transforms.ToTensor() 方法来实现。

test_tfm = transforms.Compose([ transforms.Resize((128, 128)), transforms.ToTensor(), ])
  1. 数据集加载

训练集、验证集和数据加载器
# 构建训练数据集 train_set = FoodDataset("./hw3_data/train", tfm=train_tfm) train_loader = 
DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True) 
# 构建验证数据集 valid_set = FoodDataset("./hw3_data/valid", tfm=test_tfm) valid_loader = 
DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
  1. 模型定义

该分类器通过一系列卷积层(Conv2d)、批归一化层(BatchNorm2d)、激活函数(ReLU)和池化层(MaxPool2d)构建卷积神经网络,用于提取图像特征。随后,这些特征被输入到全连接层进行分类,最终输出11个类别的概率,用于图像分类任务。

构建卷积神经网络的结构:卷积层、批归一化层、激活函数和池化层
    def __init__(self):
        super(Classifier, self).__init__()
        # 定义卷积神经网络的序列结构
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),  # 输入通道3,输出通道64,卷积核大小3,步长1,填充1
            nn.BatchNorm2d(64),        # 批归一化,作用于64个通道
            nn.ReLU(),                 # ReLU激活函数
            nn.MaxPool2d(2, 2, 0),      # 最大池化,池化窗口大小2,步长2,填充0
            
            nn.Conv2d(64, 128, 3, 1, 1), # 输入通道64,输出通道128,卷积核大小3,步长1,填充1
            nn.BatchNorm2d(128),        # 批归一化,作用于128个通道
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # 最大池化,池化窗口大小2,步长2,填充0
            
            nn.Conv2d(128, 256, 3, 1, 1), # 输入通道128,输出通道256,卷积核大小3,步长1,填充1
            nn.BatchNorm2d(256),        # 批归一化,作用于256个通道
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),      # 最大池化,池化窗口大小2,步长2,填充0
            
            nn.Conv2d(256, 512, 3, 1, 1), # 输入通道256,输出通道512,卷积核大小3,步长1,填充1
            nn.BatchNorm2d(512),        # 批归一化,作用于512个通道
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),       # 最大池化,池化窗口大小2,步长2,填充0
            
            nn.Conv2d(512, 512, 3, 1, 1), # 输入通道512,输出通道512,卷积核大小3,步长1,填充1
            nn.BatchNorm2d(512),        # 批归一化,作用于512个通道
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),       # 最大池化,池化窗口大小2,步长2,填充0
        )
        # 定义全连接神经网络的序列结构
        self.fc = nn.Sequential(
            nn.Linear(512*4*4, 1024),    # 输入大小512*4*4,输出大小1024
            nn.ReLU(),
            nn.Linear(1024, 512),        # 输入大小1024,输出大小512
            nn.ReLU(),
            nn.Linear(512, 11)           # 输入大小512,输出大小11,最终输出11个类别的概率
        )
前向传播
    def forward(self, x):
        """
        前向传播函数,对输入进行处理。
        
        参数:
        x -- 输入的图像数据,形状为(batch_size, 3, 128, 128)
        
        返回:
        输出的分类结果,形状为(batch_size, 11)
        """
        out = self.cnn(x)               # 通过卷积神经网络处理输入
        out = out.view(out.size()[0], -1)  # 展平输出,以适配全连接层的输入要求
        return self.fc(out)             # 通过全连接神经网络得到最终输出
  1. 定义损失函数

分类任务使用交叉熵作为性能衡量标准。

criterion = nn.CrossEntropyLoss()
  1. 训练模型

通过多轮训练(epochs)逐步优化模型的参数,以提高其在验证集上的性能,并保存效果最好的模型。

注意optimizer.zero_grad()是每次都需要做的操作。

# 初始化追踪器,这些不是参数,不应该被更改
stale = 0
best_acc = 0

for epoch in range(n_epochs):
    # ---------- 训练阶段 ----------
    # 确保模型处于训练模式
    model.train()

    # 这些用于记录训练过程中的信息
    train_loss = []
    train_accs = []

    for batch in tqdm(train_loader):
        # 每个批次包含图像数据及其对应的标签
        imgs, labels = batch
        # imgs = imgs.half()
        # print(imgs.shape,labels.shape)

        # 前向传播数据。(确保数据和模型位于同一设备上)
        logits = model(imgs.to(device))

        # 计算交叉熵损失。
        # 在计算交叉熵之前不需要应用softmax,因为它会自动完成。
        loss = criterion(logits, labels.to(device))

        # 清除上一步中参数中存储的梯度
        optimizer.zero_grad()

        # 计算参数的梯度
        loss.backward()

        # 为了稳定训练,限制梯度范数
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)

        # 使用计算出的梯度更新参数
        optimizer.step()

        # 计算当前批次的准确率
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # 记录损失和准确率
        train_loss.append(loss.item())
        train_accs.append(acc)

    train_loss = sum(train_loss) / len(train_loss)
    train_acc = sum(train_accs) / len(train_accs)
  1. 评估模型

在验证集上进行。

# ---------- 验证阶段 ----------
    # 确保模型处于评估模式,以便某些模块如dropout能够正常工作
    model.eval()

    # 这些用于记录验证过程中的信息
    valid_loss = []
    valid_accs = []

    # 按批次迭代验证集
    for batch in tqdm(valid_loader):
        # 每个批次包含图像数据及其对应的标签
        imgs, labels = batch
        # imgs = imgs.half()

        # 我们在验证阶段不需要梯度。
        # 使用 torch.no_grad() 加速前向传播过程。
        with torch.no_grad():
            logits = model(imgs.to(device))

        # 我们仍然可以计算损失(但不计算梯度)。
        loss = criterion(logits, labels.to(device))

        # 计算当前批次的准确率
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        # 记录损失和准确率
        valid_loss.append(loss.item())
        valid_accs.append(acc)
        # break

    # 整个验证集的平均损失和准确率是所记录值的平均
    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_acc = sum(valid_accs) / len(valid_accs)


    # 保存模型
    if valid_acc > best_acc:
        print(f"在第 {epoch} 轮找到最佳模型,正在保存模型")
        torch.save(model.state_dict(), f"{_exp_name}_best.ckpt")  # 只保存最佳模型以防止输出内存超出错误
        best_acc = valid_acc
        stale = 0
    else:
        stale += 1
        if stale > patience:
            print(f"连续 {patience} 轮没有改进,提前停止")
            break
  1. 预测

使用跑出来的模型在测试集上预测。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2082334.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

kubernetes培训

基本概念 Node 节点可以是物理机或虚拟机,每个节点上都运行着容器运行时环境; Pod Pod是k8s中的最小调度单元,一个Pod可以包含一个或多个容器,同一Pod内的容器共享存储卷和网络空间。容器则是轻量级、可移植的执行单元&#xf…

四、LogicFlow 自定义左侧菜单Menu

目录 前文LogicFlow 介绍实现基础界面框架实现左侧菜单组件将左侧菜单引入到demo组件中最后 前文 这篇相对来讲就稍微平凡了一点,只要有前端的一些基础就能够轻松完成上图中左侧的菜单,但是为了能够让前后文章能够连贯起来,所以还是要厚着脸…

一文彻底搞懂ZooKeeper选举机制

原文阅读:【巨人肩膀社区博客分享】一文彻底搞懂ZooKeeper选举机制 1. ZooKeeper 集群 ZooKeeper 是一个高性能分布式的开源协调服务,用于构建分布式应用程序和服务。 一个 ZooKeeper 集群通常由多个 ZooKeeper 服务器组成,这些服务器分布在…

C语言基础(二十三)

在C语言中,修改链表中的数据涉及遍历链表以找到要修改的元素,然后更新该元素的值。链表是一种动态数据结构,它由一系列节点组成,每个节点包含数据部分和指向列表中下一个节点的指针(双向链表,还会有指向前一…

<Rust>egui学习之小部件(四):如何在窗口中添加滑动条部件?

前言 本专栏是关于Rust的GUI库egui的部件讲解及应用实例分析,主要讲解egui的源代码、部件属性、如何应用。 环境配置 系统:windows 平台:visual studio code 语言:rust 库:egui、eframe 概述 本文是本专栏的第四篇博…

Ubuntu 24.04 安装 intel 编译器

目录 1.采用用户界面 GUI 安装英特尔基本工具包 Intel oneAPI Base Toolkit 1.1 下载离线英特尔基本工具包 1.2 安装英特尔基本工具包 1.3 英特尔基本工具包 Intel oneAPI Base Toolkit 环境设置 2.安装英特尔基本工具包 Intel HPC Toolkit 2.1 下载离线英特尔高性能计算…

智能座舱高通8155摄像头方案

高通汽车开发平台 (ADP)基于8155的多媒体硬件框图如下所示:有4个4路CSI摄像头处理通路,2个4路DSI屏幕处理通路,1个DisplayPort。 基于摄像头的详细方案如下:可以处理4路MAX9296解串后信号。 再深入细化基于…

Java10 集合

集合 集合集合接口等级:Collection:单例集合接口,将数据一个一个存储,存储的是值。ArrayList类:泛型集合Linkedlist集合:Vector集合:Stack集合:Vetor的子类 Set接口:存储是无序的&am…

【使用 Python 进行截图的两种方法】

在 Python 中,可以使用 pyautogui 和 Pillow 进行截图 使用 pyautogui 进行截图时,其提供了方便的函数。例如,使用 pyautogui.screenshot() 函数可以获取整个屏幕的截图,该函数返回一个包含屏幕截图的图像对象。如果不想截取整个…

最大噪音值甚至受法规限制,如何基于LBM算法有效控制风扇气动噪音

风扇的气动噪声 在工业设备行业,最大噪音值受法规限制。在很多使用风扇冷却的设备上,风扇噪声通常是这些设备工作噪声的最大贡献量。而在家电民用行业,例如空调、空气净化器、油烟机等,其噪音大小直接关系到用户的体验感受&#x…

从零开始掌握容器技术:Docker的奇妙世界

容器技术在当今的云计算和软件开发领域中扮演着越来越重要的角色。如果你是一名计算机专业的学生或从事IT行业的从业者,可能已经听说过Docker这个词。它在软件开发、部署、运维等环节中大放异彩,但对于刚接触这个概念的朋友来说,可能还是有些…

【乐企】有关乐企能力测试接口对接(详细)

1、申请密钥 2、验证本地服务器与乐企服务器的连通性 乐企服务器生产和测试域名均为:https://lqpt.chinatax.gov.cn:8443。开发者可以在“能力中心”查看基础公用能力详情,按照能力接入和开发指引完成接口对接,验证服务器连通性和证书配置正确…

给一个web网站,如何开展测试?

前言 Web测试是指针对Web应用程序(网站或基于Web的系统)进行的测试活动,以确保其质量、性能、安全性、可用性和兼容性等方面符合预期标准。Web测试涵盖了从前端用户界面(UI)到后端逻辑和数据库的各个方面,确保Web应用程序在不同环境和条件下都能正常运行…

参会投稿 | 第三届先进传感与智能制造国际学术会议(ASIM 2024)

第三届先进传感与智能制造国际会议(The 3rd International Conference on Advanced Sensing, Intelligent Manufacturing),由江汉大学、西安交通大学和山东大学主办,由江西省机械工程学会、东华理工大学机械与电子工程学院等联合协…

Hibernate 批量插入速度慢的原因和解决方法

由于业务需要一次性连续写入超过10k条以上的新数据,当对象超过10个成员变量以后,整个写入过程居然需要长达35秒,这个速度是不能接受的,故此研究了一下怎么开启Hibernate批量写入的功能。 我这边使用的是Hibernate 5.6.15 在网上…

推动光模块技术发展:从400G、800G到1.6T

随着数据通信领域的持续发展,对于更快、更高传输速率的需求也在不断增长。作为现代数据传输的基石,光模块技术不断进步以满足这一需求。其中一项重大进展是网络速率从400G提升到800G,并将向1.6T继续发展。让我们深入探讨这些技术的演变&#…

Java语言程序设计基础篇_编程练习题***17.9 (地址簿)

目录 题目:***17.9 (地址簿) 习题思路 代码示例 结果展示 题目:***17.9 (地址簿) 编写程序用于存储、返回、增加,以及更新如图 17-20 所示的地址薄。使用固定长度的字符串来存储地址中的每个属性。使用随机访问文件来读取和写人一个地址…

刚刚认证!网络主播成为国家新职业,易播易赚打造打造职业入门全新模式

近期,人力资源和社会保障部会同国家市场监督管理总局、国家统计局日前增设网络主播为国家新职业,这标志着网络主播的职业身份在“国家确定职业分类”上首次得以确立。 据人社部此前印发的《关于加强新职业培训工作的通知》表示,新职业从业者可…

代码随想录算法训练营第二十三天| 39. 组合总和 40.组合总和II 131.分割回文串

目录 一、LeetCode 39. 组合总和思路:C代码 二、LeetCode 40.组合总和II思路C代码 三、LeetCode 131.分割回文串思路C代码 总结 一、LeetCode 39. 组合总和 题目链接:LeetCode 39. 组合总和 文章讲解:代码随想录 视频讲解:带你学…

直播平台直播API集成之快手篇

前言: 本篇我们来介绍如何使用快手 的直播API创建直播。 准备工作: 1、你首先得有个快手账号; 2、创建快手应用,填写应用审核信息,等待应用创建审核通过,应用成功创建后在开发与上线前还要提前做好API权限申请,如果你只需要获取用户基本信息,以及得到直播API的访问权限…