李宏毅2023机器学习作业HW03解析和代码分享

news2024/11/23 15:53:14

ML2023Spring - HW3 相关信息:
课程主页
课程视频
Kaggle link
Sample code
HW03 视频
HW03 PDF
个人完整代码分享: GitHub | Gitee | GitCode

P.S. 即便 kaggle 上的时间已经截止,你仍然可以在上面提交和查看分数。但需要注意的是:在 kaggle 截止日期前你应该选择两个结果进行最后的Private评分。
每年的数据集size和feature并不完全相同,但基本一致,过去的代码仍可用于新一年的 Homework。

代码仓库中关于HW03的代码暂时没有boss,仅为0.84666,最近繁琐事情太多,还在比赛,所以先上传分享思路

文章目录

  • 任务目标(图像分类)
  • 性能指标(Metric)
  • 数据解析
    • 数据下载(kaggle)
  • Gradescope (Report)
    • Q1. Augmentation Implementation
    • Q2. Visual Representations Implementation
  • Baselines
    • Simple baseline (0.637)
    • Medium baseline (0.700)
    • Strong baseline (0.814)
    • Boss baseline (0.874)
  • 小坑
  • 参考链接

任务目标(图像分类)

使用 CNN 进行图像分类

性能指标(Metric)

在测试集上的分类精度:
A c c = p r e d = = l a b e l l e n ( d a t a ) ∗ 100 % Acc = \frac{pred==label}{len(data)} * 100\% \nonumber Acc=len(data)pred==label100%

数据解析

  • ./train (Training set): 图像命名的格式为 “x_y.png”,其中 x 是类别,含有 10,000 张被标记的图像
  • ./valid (Valid set): 图像命名的格式为 “x_y.png”,其中 x 是类别,含有 3,643 张被标记的图像
  • ./test (Testing set): 图像命名的格式为 “n.png”,n 是 id,含有 3,000 张未标记的图像

数据来源于 food-11 数据集,共有 11 类。

数据下载(kaggle)

To use the Kaggle API, sign up for a Kaggle account at https://www.kaggle.com. Then go to the ‘Account’ tab of your user profile (https://www.kaggle.com/<username>/account) and select ‘Create API Token’. This will trigger the download of kaggle.json, a file containing your API credentials. Place this file in the location ~/.kaggle/kaggle.json (on Windows in the location C:\Users\<Windows-username>\.kaggle\kaggle.json - you can check the exact location, sans drive, with echo %HOMEPATH%). You can define a shell environment variable KAGGLE_CONFIG_DIR to change this location to $KAGGLE_CONFIG_DIR/kaggle.json (on Windows it will be %KAGGLE_CONFIG_DIR%\kaggle.json).

-- Official Kaggle API

gdown 的链接如果挂了或者太慢,可以考虑使用 kaggleapi,流程非常简单,替换<username>为你自己的用户名,https://www.kaggle.com/<username>/account,然后点击 Create New API Token,将下载下来的文件放去应该放的位置:

  • Mac 和 Linux 放在 ~/.kaggle
  • Windows 放在 C:\Users\<Windows-username>\.kaggle
pip install kaggle
# 你需要先在 Kaggle -> Account -> Create New API Token 中下载 kaggle.json
# mv kaggle.json ~/.kaggle/kaggle.json
kaggle competitions download -c ml2023spring-hw3
unzip ml2023spring-hw3

Gradescope (Report)

from PIL import image

什么是 PIL?

PIL (Python Image Library) 是 python 的第三方图像处理库,支持图像存储,显示和处理,能够处理几乎所有的图片格式。

PIL.Image 模块在 sample code 中用于加载图像。

Q1. Augmentation Implementation

需要完成至少 5 种 transform,这一步能让你熟悉 Data Augmentation 到底是在做什么。

直接看代码部分,调用了 transforms 中的函数。

image-20230331210628055

往回追溯:

image-20230331211403647

可以看到 transforms 其实就是 torchvision.transforms。

torchvision.transforms 是 pytorch 中的图像预处理包,提供了常用的图像变换方式,可以通过 Compose 将多个变换步骤整合到一起,你可以查看这篇文章:torchvision.transforms 常用方法解析(含图例代码以及参数解释)进一步了解,最好是自行组合 5 个跑几次实验之后再偷懒。

下面的代码可以让你看到 train_tfm 究竟做了什么变换。

# I want to show you an example code of Q1. Augmentation Implementation that visualizes the effects of different image transformations.
import matplotlib.pyplot as plt

plt.rcParams["savefig.bbox"] = 'tight'

# You can change the file path to match your image
orig_img = Image.open('Q1/assets/astronaut.jpg')


def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

# Create a list of five transformed images from the original image using the train_tfm function
demo = [train_tfm(orig_img) for i in range(5)]

# Convert the transformed images from tensors to PIL images
pil_img_demo = [Image.fromarray(np.moveaxis(img.numpy()*255, 0, -1).astype(np.uint8)) for img in demo]

# Plot the transformed images using the plot function
plot(pil_img_demo) 

train_tfm

Q2. Visual Representations Implementation

下图是 Top/Mid/Bottom 的定义,你可以在 sample code 的最下面找到完成这个问题的代码。

 CNN architecture

根据你的模型修改其中的 index。

Baselines

Simple baseline (0.637)

  • 运行所给的 sample code

Medium baseline (0.700)

  • 做数据增强

    RandomChoice 很好用,另外,lamda x:x 可以返回原图。

  • 训练更长时间

    根据 PDF 给出的参考训练时间,simple 是 0.5h,medium 是 1.5h,那么在这里我选择的是简单的将原来的 epoch *= 3,也就是 24 个 epoch 来进行最终的训练

Strong baseline (0.814)

  • 使用预训练模型
    这里你可能有疑惑:不是说不能使用预训练模型吗?
    是的,你只能使用预训练模型的架构,不能使用预训练的权重,下面是不使用权重的参数设置。

    • Torchvision 版本 < 0.13 -> pretrained=False
    • > 0.13 -> weights=None

    模型对比 (160 epoch, 10 patience, ReduceLROnPlateau,使用了相当于原数据20倍的transforms) :

    • 初始模型:0.80000
    • resnet50: 0.732
    • vgg16: 0.64733
    • densenet121: 0.76533
    • alexnet: 0.61866
    • squeezenet: 0.64200

    我觉得这一项的主要目的在于让你认识这些预训练模型的架构,因为可以看到,不使用预训练参数的情况下,实验结果并没有变得更好(使用预训练参数的话,以resnet50为例,仅使用预训练模型就可以轻松到达strong baseline,你可以试试,但不要用它来当作你的kaggle结果)。
    image-20230406145717393

    但既然PDF中的hint仅仅只是使用预训练模型,我相信一定有什么地方可以调优,使得仅使用预训练模型架构就可以达到 strong baseline,简单对比了使用参数和不使用参数的情况下 acc 的提升情况,发现同样的 lr,使用预训练参数的时候上升幅度更大,所以我想了下:

    1. 有没有可能是我的 lr 太小了?调大试试
    2. 会不会是我的transforms不够,因为在我的代码中,5%的可能性不进行transforms,也就是说,20倍的数据增强。50倍试试
    3. Medium baseline的工作没做好,加TTA(Test Time Augmentation),将train_tfm用到测试集上试试

    但上述方法都没有得到好的效果,最终我直接用最开始的CNN模型跑了200多个epoch完成了该strong baseline,这个坑以后来填,再耗在这更新来不及了 : )

    image-20230408113446125

Boss baseline (0.874)

  • Cross validation 交叉验证

  • Ensemble 模型集合
    相关视频: ML Lecture 22: Ensemble ,如果没有科学上网,这里是两个相同视频的链接地址:bilibili,学校官网。
    这两项确实有很大的提升,差不多有6个点,再修改一下原来的架构就行了。

小坑

  1. 注意你的 lr,我在做 cross validation 的时候,不小心将 lr 设置的过大,导致一开始学习的很差,还以为是数据集划分的索引问题,折腾了半天。
  2. 如果你将train文件夹和valid文件夹下的内容合并成一个新的文件夹(为了做 cross validation),那么在做 K-fold 的时候,序号一定要 shuffle 去打乱,你只要默认打乱了,就不需要考虑太多,否则就会出现一种情况:验证集的标签有可能在训练集中不存在,那就意味着,你的模型可能几乎没见过验证集里面的 label,如果完全没见过,那 acc 甚至有可能是 0。下面是我当时疏忽导致的 bug:[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QdG5JLvc-1681389223934)(/Users/home/Library/Application%20Support/typora-user-images/image-20230407205208218.png)]

参考链接

Image Module - Pillow (PIL Fork) 9.4.0 documentation

TRANSFORMING AND AUGMENTING IMAGES

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/422340.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringAOP入门基础银行转账实例------------事务处理

SpringAOP入门基础银行转账实例------------事务处理 AOP为Aspect Oriented Programming 的缩写&#xff0c;意思为面向切面编程&#xff0c;是通过编译方式和运行期动态代理实现程序功能的统一维护的一种技术。 AOP编程思想 AOP面向切面是一种编程思想&#xff0c;是oop的延…

Python 字符串format()格式化 / 索引

前言 嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! 为了确保字符串按预期显示&#xff0c;我们可以使用 format() 方法对结果进行格式化。 字符串 format() format() 方法允许您格式化字符串的选定部分。 有时文本的一部分是你无法控制的&#xff0c;也许它们来自数据库或…

更深度了解getchar和putchar现象

目录 前言&#xff1a; 1.getchar和putchar 1.1基本使用 1.2一些特殊打印 1.3putchar打印空格 2.深度了解现象 前言&#xff1a; 经过学习&#xff0c;总结getchar()函数和putchar()函数在搭配使用while循环的时候&#xff0c;控制台窗口光标位置的出现位置的由来。 1.…

JavaSE学习进阶day04_03 包装类

第五章 包装类&#xff08;重点&#xff09; 5.1 概述 Java提供了两个类型系统&#xff0c;基本类型与引用类型&#xff0c;使用基本类型在于效率&#xff0c;然而很多情况&#xff0c;会创建对象使用&#xff0c;因为对象可以做更多的功能&#xff0c;如果想要我们的基本类型…

蓝桥杯15单片机--超声波模块

目录 一、超声波工作原理 二、超声波电路图 三、程序设计 1-设计思路 2-具体实现 四、程序源码 一、超声波工作原理 超声波时间差测距原理超声波发射器向某一方向发射超声波&#xff0c;在发射时刻的同时开始计时&#xff0c;超声波在空气中传播&#xff0c;途中碰到障碍…

计算属性,watch和watchEffect

计算属性-computed 什么是计算属性&#xff1a; computed函数&#xff0c;是用来定义计算属性的&#xff0c;计算属性不能修改。 模板内的表达式非常便利&#xff0c;但是设计它们的初衷是用于简单运算的。在模板中放入太多的逻辑会让模板过重且难以维护。 计算属性还可以依…

【目标检测论文阅读笔记】Extended Feature Pyramid Network for Small Object Detection

&#xff08;未找到代码&#xff0c;只有yaml文件&#xff09; Abstract. 小目标检测仍然是一个未解决的挑战&#xff0c;因为很难提取只有几个像素的小物体的信息。虽然特征金字塔网络中的尺度级对应检测缓解了这个问题&#xff0c;但我们发现各种尺度的特征耦合仍然会损害小…

百度飞桨paddlespeech实现小程序实时语音流识别

前言&#xff1a; 哈哈&#xff0c;这是我2023年4月份的公司作业。如果仅仅是简单的语音识别倒也没什么难度&#xff0c;wav文件直接走模型输出结果的事。可是注意标题&#xff0c;流式识别、实时&#xff01; 那么不得不说一下流式的优点了。 1、解决内存溢出的烦恼。 2、…

《论文阅读》Unified Named Entity Recognition as Word-Word Relation Classification

总结 将NER视作是word-word间的 Relation Classification。 这个word-word 间的工作就很像是TPlinker那个工作&#xff0c;那篇工作是使用token间的 link。推荐指数&#xff1a;★★★☆☆值得学习的点&#xff1a; &#xff08;1&#xff09;用关系抽取的方法做NER抽取 &…

佳明手表APP开发系列01——简单汉化英文版

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录前言一、佳明手表APP开发过程简介二、做个简单的个性化——在英文版写几个汉字1.MonkeyC 图形处理2.获得汉字点阵字模数据3.MonkeyC 汉字输出函数总结前言 佳明手表…

蓝海创意云应邀参与苏州市元宇宙生态大会

4月14日&#xff0c;苏州市软件行业协会元宇宙专委会成立大会暨元宇宙生态大会在苏成功举办。此次大会由苏州市工业和信息化局指导&#xff0c;苏州高新区&#xff08;虎丘区&#xff09;经济发展委员会、苏州市软件行业协会主办&#xff0c;蓝海彤翔集团作为协办单位参与此次大…

IDEA集成Git、GitHub、Gitee

一、IDEA 集成 Git 1.1、配置 Git 忽略文件 为什么要忽略他们&#xff1f; 与项目的实际功能无关&#xff0c;不参与服务器上部署运行。把它们忽略掉能够屏蔽 IDE 工具之间的差异。 怎么忽略&#xff1f; 创建忽略规则文件 xxxx.ignore&#xff08;前缀名随便起&#xff0c…

创建Google play开发者账号,并验证身份通过

一、注册前准备 最好准备一台没有怎么用过Google的电脑和&#x1fa9c;准备一个没有注册过Google的手机号准备一张信用卡或者借记卡&#xff08;需要支付$25&#xff09;&#xff0c;支持的类型如下图 这里还需注意&#xff1a;最后账号注册成功还需要验证身份也就是实名认证&…

关于Python爬虫的一些总结

作为一名资深的爬虫工程师来说&#xff0c;把别人公开的一些合法数据通过爬虫手段实现汇总收集是一件很有成就的事情&#xff0c;其实这只是一种技术。 初始爬虫 问题&#xff1a; 什么是爬虫&#xff1f; 网络爬虫是一种按照一定的规则&#xff0c;自动地抓取网络信息的程…

动态规划算法OJ刷题(3)

CC19 分割回文串-ii 问题描述 给出一个字符串s&#xff0c;分割s使得分割出的每一个子串都是回文串。计算将字符串s分割成回文串的最小切割数。例如:给定字符串s“aab”&#xff0c;返回1&#xff0c;因为回文分割结果[“aa”,“b”]是切割一次生成的。 解题思路 方法1&…

计算机操作系统(第四版)第四章存储器管理—课后习题答案

1.为什么要配置层次存储器&#xff1f; &#xff08;1&#xff09;设置多个存储器可以使存储器两端的硬件能并行工作。 &#xff08;2&#xff09;采用多级存储系统,特别是Cache技术,这是一种减轻存储器带宽对系统性能影响的最佳结构方案。 &#xff08;3&#xff09;在微处理机…

《Java8实战》第5章 使用流

上一章已经体验到流让你从外部迭代转向内部迭代。 5.1 筛选 看如何选择流中的元素&#xff1a;用谓词筛选&#xff0c;筛选出各不相同的元素。 5.1.1 用谓词筛选 filter 方法&#xff0c;该操作会接受一个谓词&#xff08;一个返回boolean 的函数&#xff09;作为参数&am…

MySQL数据库:聚合函数、分组查询、约束、默认值设置、自增属性

一、聚合函数 1.聚合函数 在MySQL数据库中预定义好的一些数据统计函数。 2.count(*) 功能&#xff1a;统计结果条数。 3.sum(字段名) 功能&#xff1a;对指定字段的数据求和。 4.avg(字段名) 功能&#xff1a;对指定字段的数据求平均值。 5.max(字段名) 和 min(字段名) …

正则化的基本认识

正则化(一) 拟合与欠拟合(二) 正则化的目的(三) 惩罚项&#xff08;3.1&#xff09;常用的惩罚项&#xff1a;&#xff08;3.2&#xff09;L-P范数&#xff1a;&#xff08;3.3&#xff09;L1与L2的选择&#xff1a;(一) 拟合与欠拟合 欠拟合&#xff1a; 是指测试级与训练集都…

docker目录映射

docker 常用命令 docker ps // 查看所有正在运行容器 docker stop containerId // containerId 是容器的ID docker ps -a // 查看所有容器 $ docker ps -a -q // 查看所有容器ID docker stop $(docker ps -a -q) // stop停止所有容器 docker rm $(docker ps -a -q) // remove删…