李宏毅2023机器学习作业HW05解析和代码分享

news2025/1/19 23:02:26

ML2023Spring - HW5 相关信息:
课程主页
课程视频
Sample code
HW05 视频
HW05 PDF

个人完整代码分享: GitHub | Gitee | GitCode
运行日志记录: wandb

P.S. HW05/06 是在 Judgeboi 上提交的,完全遵循 hint 就可以达到预期效果。

因为无法在 Judgeboi 上提交,所以 HW05/06 代码仓库中展示的是在验证集上的分数。

每年的数据集 size 和 feature 并不完全相同,但基本一致,过去的代码仍可用于新一年的 Homework。

仓库中 HW05 的代码分成了英文 EN 和中文 ZH 两个版本。

(碎碎念:翻译比较麻烦,所以之后的 Homework 代码暂只有英文版本)

任务目标(seq2seq)

  • Machine translation 机器翻译,英译中

性能指标(BLEU)

参考链接:

BLEU: a Method for Automatic Evaluation of Machine Translation

Foundations of NLP Explained — Bleu Score and WER Metrics

BLEU(Bilingual Evaluation Understudy) 双语评估替换

公式:
BLEU = B P ⋅ exp ⁡ ( ∑ n = 1 N w n l o g   p n ) 1 N \text{BLEU} = BP \cdot \exp\left( \sum_{n=1}^{N} w_n log\ p_n\right)^{\frac{1}{N}} BLEU=BPexp(n=1Nwnlog pn)N1
首先要明确两个概念

  1. N-gram
    用来描述句子中的一组 n 个连续的单词。比如,“Thank you so much” 中的 n-grams:

    • 1-gram: “Thank”, “you”, “so”, “much”
    • 2-gram: “Thank you”, “you so”, “so much”
    • 3-gram: “Thank you so”, “you so much”
    • 4-gram: “Thank you so much”

    需要注意的一点是,n-gram 中的单词是按顺序排列的,所以 “so much Thank you” 不是一个有效的 4-gram。

  2. 精确度(Precision)
    精确度是 Candidate text 中与 Reference text 相同的单词数占总单词数的比例。 具体公式如下:
    $ \text{Precision} = \frac{\text{Number of overlapping words}}{\text{Total number of words in candidate text}} $
    比如:
    Candidate: Thank you so much, Chris
    Reference: Thank you so much, my brother
    这里相同的单词数为4,总单词数为5,所以 Precision = 4 5 \text{Precision} = \frac{{4}}{{5}} Precision=54
    但存在一个问题:

    • Repetition 重复

      Candidate: Thank Thank Thank
      Reference: Thank you so much, my brother

      此时的 Precision = 3 3 \text{Precision} = \frac{{3}}{{3}} Precision=33

解决方法:Modified Precision

很简单的思想,就是匹配过的不再进行匹配。

Candidate: Thank Thank Thank
Reference: Thank you so much, my brother

Precision 1 = 1 3 \text{Precision}_1 = \frac{{1}}{{3}} Precision1=31

  • 具体计算如下:

    C o u n t c l i p = min ⁡ ( C o u n t ,   M a x _ R e f _ C o u n t ) = min ⁡ ( 3 ,   1 ) = 1 Count_{clip} = \min(Count,\ Max\_Ref\_Count)=\min(3,\ 1)=1 Countclip=min(Count, Max_Ref_Count)=min(3, 1)=1
    $ p_n = \frac{\sum_{\text{n-gram}} Count_{clip}}{\sum_{\text{n-gram}} Count} = \frac{1}{3}$

现在还存在一个问题:译文过短

Candidate: Thank you
Reference: Thank you so much, my brother

p 1 = 2 2 = 1 p_1 = \frac{{2}}{{2}} = 1 p1=22=1

这里引出了 brevity penalty,这是一个惩罚因子,公式如下:

B P = { 1 if  c > r e 1 − r c if  c ≤ r BP = \begin{cases} 1& \text{if}\ c>r\\ e^{1-\frac{r}{c}}& \text{if}\ c \leq r \end{cases} BP={1e1crif c>rif cr

其中 c 是 candidate 的长度,r 是 reference 的长度。

当候选译文的长度 c 等于参考译文的长度 r 的时候,BP = 1,当候选翻译的文本长度较短的时候,用 e 1 − r c e^{1-\frac{r}{c}} e1cr 作为 BP 值。

回到原来的公式:$ \text{BLEU} = BP \cdot \exp\left( \sum_{n=1}^{N} w_n log\ p_n\right)^{\frac{1}{N}}$,汇总一下符号定义:

  • B P BP BP 文本长度的惩罚因子
  • N N N n-gram 中 n 的最大值,作业中设置为 4。
  • w n w_n wn 权重
  • p n p_n pn n-gram 的精度 (precision)

数据解析

  • Paired data
    • TED2020: 演讲
      • Raw: 400,726 (sentences)
      • Processed: 394, 052 (sentences)
    • 英文和中文两个版本
  • Monolingual data
    • 只有中文版本的 TED 演讲数据

Baselines

这里存在一个问题,就是HW05是在 Judgeboi 上进行提交的,所以没办法获取最终的分数,所以简单的使用 simple baseline 对应的 validate BLEU 来做个映射。

因为有 EN / ZH 两个版本,对于每个 hint 我会给出代码的修改位置方便大家索引。

Simple baseline (15.05)

  • 运行所给的 sample code

Medium baseline (18.44)

  • 增加学习率的调度 (Optimizer: Adam + lr scheduling / 优化器: Adam + 学习率调度)
  • 训练得更久 (Configuration for experiments / 实验配置)
    这里根据预估的时间,可以简单的将 epoch 设置为原来的两倍。

Strong baseline (23.57)

  • 将模型架构转变为 Transformer (Model Initialization / 模型初始化)
  • 调整超参数 (Architecture Related Configuration / 架构相关配置)
    这里需要参考 Attention is all you need 论文中 table 3 的 transformer-base 超参数设置。
    image-20231115135033382

你可以仅遵循 sample code 的注释,将 encoder_layer 和 decoder_layer 改为 4(简单的将这一个改动称之为 transformer_4layer),此时模型的参数数量会和之前的 RNN 差不多,在 max_epoch =30 的情况下,Bleu 可以达到 23.59。

代码仓库中分享的 Strong 代码完全遵循了 transformer-base 的超参数设置,此时的模型参数将约为之前 RNN 的 5 倍,每一轮训练的时间约为 transform_4layer 的三倍,所以我将 max_epoch 设置为了 10,让其能够匹配上预估的时间,此时的 Bleu 为 24.91。如果将 max_epoch 设置为 30,最终的 Bleu 可以达到 27.48。

下面是二者实验对比。

Boss baseline (30.08)

  • 应用 back-translation (TODO)

    这里我们需要交换实验配置 config 中的 source_lang 和 target_lang,并修改 savedir,训练一个 back-translation 模型后再修改回原来的 config。

    然后你需要将 TODO 的部分完善,修改并复用之前的函数就可以达到目的。

    (为了与预估时间匹配,这里将 max_epoch 设置为 30 进行实验。)

代码仓库中分享的 Boss 代码展示的是最终训练的结果,完整的运行流程是:

  1. 实验配置中 / Configuration for experimentsBACK_TRANSLATION 设置为 True 运行
    训练一个 back-translation 模型,并处理好对应的语料。
  2. 实验配置 / Configuration for experiments 中的 BACK_TRANSLATION 设置为 False 运行
    结合 ted2020 和 mono (back-translation) 的语料进行训练。

Gradescope

Visualize Positional Embedding

你可以直接在 确定用于生成 submission 的模型权重 / Confirm model weights used to generate submission 后进行处理,在仓库的代码中我已经提前注释掉了 训练循环 / Training loop 中的训练部分,如果在之前,模型没有训练,直接运行代码会报错。

image-20231119122408389

添加的处理代码如下(可以复制下面的处理代码放到你的 submission 模块之后):

推荐阅读:All Pairs Cosine Similarity in PyTorch

pos_emb = model.decoder.embed_positions.weights.cpu().detach()

# 计算余弦相似度矩阵
def get_cosine_similarity_matrix(x):
    x = x / x.norm(dim=1, keepdim=True)
    sim = torch.mm(x, x.t())
    return sim
    
sim = get_cosine_similarity_matrix(pos_emb)
#sim = F.cosine_similarity(pos_emb.unsqueeze(1), pos_emb.unsqueeze(0), dim=2) # 一样的

# 绘制位置向量的余弦相似度矩阵的热力图
plt.imshow(sim, cmap="hot", vmin=0, vmax=1)
plt.colorbar()

plt.show()

Clipping Gradient Norm

只需要将 config.wandb 设置为 True 即可,此时可以在 wandb 上查看。

image-20231119183413555

或者直接在 train_one_epoch 添加一下处理代码,记录 gnorm。

image-20231119183535765

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1228117.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows上安装MySQL Server.

进入官网 MySQL 找到 下载,并点进入。 往下翻,找到社区下载,进入页面 选择 Mysql community Server 选择系统,下载 之后解压。 将解压文件夹下的bin路径添加到变量值中 配置初始化的my.ini文件 [mysqld] # 设置3306端口 port330…

YOLOv5 学习记录

文章目录 整体概况数据增强与前处理自适应Anchor的计算Lettorbox 架构SiLU激活函数YOLOv5改进点SSPF 模块 正负样本匹配损失函数 整体概况 YOLOv5 是一个基于 Anchor 的单阶段目标检测,其主要分为以下 5 个阶段: 1、输入端:Mosaic 数据增强、…

Parity Game——种类并查集、权值并查集、离散化

题目描述 思路 怎么得到这个序列中每一段的关系? 我们可以把这个只包含0和1的序列看作一个数组,0表示当前位置为0,1表示当前位置为1,利用前缀和的性质可以知道某一段中所包含的1的数量sum1 a[r] - a[l-1] 如果sum1为偶数&…

【0到1学习Unity脚本编程】第一人称视角的角色控制器

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:【0…

深信服AC流量管理技术

拓扑图 一.保证通道针对修仙部,访问网站,邮件,DNS,IM,办工 OA,微博论坛网上银行等常见应用保证带宽最低 50%,最高 100% 1. 先新建线路带宽 2.新增流量管理通道(保证关键应用&#x…

吾爱破解置顶的“太极”,太好用了吧!

日常工作和娱乐,都需要用到不同类型的软件,哪怕软件体积不大,也必须安装,否则到用时找不到就非常麻烦了。 其实,很多软件不一定一样不剩地全部安装一遍,一方面原因是用的不多,另一方面多少有点…

安装第三方包报错 error: Microsoft Visual C++ 14.0 or greater is required——解决办法

1、问题描述 手动安装第三方软件时,可以使用setup.py,来安装已经下载的第三方包。一般文件下会存在setup,在所要安装库的目录下的cmd执行:python setup.py install报错:error: Microsoft Visual C 14.0 or greater i…

【LeetCode】二叉树OJ

目录 一、根据二叉树创建字符串 二、二叉树的层序遍历 三、二叉树的层序遍历 II 四、二叉树的最近公共祖先 五、二叉搜索树与双向链表 六、从前序与中序遍历序列构造二叉树 七、从中序与后序遍历序列构造二叉树 一、根据二叉树创建字符串 606. 根据二叉树创建字符串 - …

window上Clion配置C++版本的opencv

window上Clion配置opencv 注意版本一定要对的上,否则可能会出错,亲测 widnows 11mingw 8.1.0opencv 4.5.5 mingw8.1下载地址https://sourceforge.net/projects/mingw/ 配置环境变量 cmake下载 安装完添加环境变量 来到官网,下载 windows 对…

【华为HCIP | 华为数通工程师】刷题日记1116(一个字惨)

个人名片: 🐼作者简介:一名大三在校生,喜欢AI编程🎋 🐻‍❄️个人主页🥇:落798. 🐼个人WeChat:hmmwx53 🕊️系列专栏:🖼️…

真心建议看看这个盈亏平衡点计算方法及要点解析!

说实话,进行产品动态盈亏平衡计算是非常考验人的,因为不是人人都具备评估不同产品组合的盈利能力和掌握风险的方法。 当然最简单的方式就是套用诸如单产品动态盈亏平衡表之类的现成模板进行测算,可以实现以下三点基本需求: 弹性输…

AI实践与学习1_Milvus向量数据库实践与原理分析

前言 随着NLP预训练模型(大模型)以及多模态研究领域的发展,向量数据库被使用的越来越多。 在XOP亿级题库业务背景下,对于试题召回搜索单单靠着ES集群已经出现性能瓶颈,因此需要预研其他技术方案提高试题搜索召回率。…

OpenAI 董事会与 Sam Altman 讨论重返 CEO 岗位事宜

The Verge 援引多位知情人士消息称,OpenAI 董事会正在与 Sam Altman 讨论他重新担任首席执行官的可能性。 有一位知情人士表示,Altman 对于回归公司一事的态度暧昧,尤其是在他没有任何提前通知的情况下被解雇后。他希望对公司的治理模式进行重…

系列三、GC垃圾回收算法和垃圾收集器的关系?分别是什么请你谈谈

一、关系 GC算法(引用计数法、复制算法、标记清除算法、标记整理算法)是方法论,垃圾收集器是算法的落地实现。 二、4种主要垃圾收集器 4.1、串行垃圾收集器(Serial) 它为单线程环境设计,并且只使用一个线程…

【论文阅读】基于隐蔽带宽的汽车控制网络鲁棒认证(二)

文章目录 第三章 识别CAN中的隐藏带宽信道3.1 隐蔽带宽vs.隐藏带宽3.1.1 隐蔽通道3.1.2 隐藏带宽通道 3.2 通道属性3.3 CAN隐藏带宽信道3.3.1 CAN帧ID字段3.3.2 CAN帧数据字段3.3.3 帧错误检测领域3.3.4 时间通道3.3.5 混合通道 3.4 构建信道带宽公式3.5通道矩阵3.6 结论 第四章…

基于Vue+SpringBoot的大病保险管理系统 开源项目

项目编号: S 031 ,文末获取源码。 \color{red}{项目编号:S031,文末获取源码。} 项目编号:S031,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 系统配置维护2.2 系统参保管理2.3 大…

腾讯云轻量数据库性能如何?轻量数据库租用配置价格表

腾讯云轻量数据库测评,轻量数据库100%兼容MySQL 5.7和8.0,腾讯云提供1C1G20GB、1C1G40GB、1C2G80GB、2C4G120GB、2C8G240GB五种规格轻量数据库,腾讯云百科txybk.com分享腾讯云轻量数据库测评、轻量数据库详细介绍、特性、配置价格和常见问题解…

六、文件上传漏洞

下面内容部分:参考 一、文件上传漏洞解释 解释:文件上传漏洞一般指的就是用户能够绕过服务器的规则设置将自己的木马程序放置于服务器实现远程shell(例如使用蚁剑远程连接),常见的木马有一句话木马(php) 无需启用sho…

Gin框架源码解析

概要 目录 Gin路由详解 Gin框架路由之Radix Tree 一、路由树节点 二、请求方法树 三、路由注册以及匹配 中间件含义 Gin框架中的中间件 主要讲述Gin框架路由和中间件的详细解释。本文章将从Radix树(基数树或者压缩前缀树)、请求处理、路由方法树…