Kaggle Feedback Prize 3比赛总结:两种模型设计思路

news2024/9/21 14:53:58

比赛的目标:本次竞赛的目标是评估8-12年级英语学习者(ELLs)的语言能力。利用英语学习者所写的论文数据集开发出能更好地支持所有学生的能力模型,帮助ELL学生在语言发展方面得到更准确的反馈,并加快教师的评分周期。

方法简介

本次比赛是NLP的回归任务。模型的输入是8-12年级英语学习者的文本,输出的是针对cohesion, syntax, vocabulary, phraseology, grammar, conventions六个方面的打分。分数范围从1.0到5.0,增量为0.5。数据集的结构如下:
在这里插入图片描述

策略一:一般NLP任务的策略

本次比赛中最常见的策略是使用比如deberta等大规模预训练模型,对文本进行特征的提取,再提取特征的基础上进行回归任务。流程如下图所示,其中linear layers可以根据实验自行定义多层并加入不同的激活函数等
假设batch size 为8.在这里插入图片描述

策略二:SVR回归

通过预训练模型得到文本特征的embedding,再针对embedding训练SVR模型进行回归任务。使用SVR回归,可以利用不同的预训练模型得到不同的embedding,再利用这些embedding来做回归任务。这个策略在之前我没有怎么使用过,所以这里会提供一些代码。整体策略图如下:

在这里插入图片描述

代码如下:

# 首先得到word embeddings
for batch in tqdm(embed_dataloader_tr,total=len(embed_dataloader_tr)):
		input_ids = batch["input_ids"].to(DEVICE)
		attention_mask = batch["attention_mask"].to(DEVICE)
		with torch.no_grad():
			model_output = model(input_ids=input_ids,attention_mask=attention_mask)
		sentence_embeddings = mean_pooling(model_output, attention_mask.detach().cpu())
		# Normalize the embeddings
		sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
		sentence_embeddings =  sentence_embeddings.squeeze(0).detach().cpu().numpy()
		all_train_text_feats.extend(sentence_embeddings)
	all_train_text_feats = np.array(all_train_text_feats)
	print('Train embeddings shape',all_train_text_feats.shape)

输出为:
Train embeddings shape (3911, 768)

这是由deberta base作为示例,这个输出是每一个样本最后的代表整个句子的embedding. 这样我们可以得到不同的模型输出的sentence embedding,将它们concatenate到一起,比如用2个不同的embedding size 均为768的模型。最后的输出则为 (3911, 768 * 2)

接下来训练SVR模型

from cuml.svm import SVR
import cuml
from sklearn.multioutput import MultiOutputRegressor

# 首先获取当前fold的训练数据和验证数据
train_folds = folds[folds['fold'] != fold].reset_index(drop=True)  
valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)

#得到对应的embeddings
train_folds_feats = all_train_text_feats[list(train_folds.index),:]
valid_folds_feats = all_train_text_feats[list(valid_folds.index),:]

clf = SVR(C=1)
"""
这里 C = 1是惩罚因子
表示错误项的惩罚系数C越大,即对分错样本的惩罚程度越大,因此在训练样本中准确率越高,但是泛化能力降低;相反,减小C的话,容许训练样本中有一些误分类错误样本,泛化能力强。对于训练样本带有噪声的情况,一般采用后者,把训练样本集中错误分类的样本作为噪声。
"""
# 依次训练针对不同Target的SVR
for i,t in enumerate(target_cols):
        print(t,', ',end='')
        clf = SVR(C=1)
        clf.fit(train_folds_feats, dftr_[t].values)
        ev_preds[:,i] = clf.predict(valid_folds_feats)

#############################################################################
# 或者定义multilabel_regressor,它可以直接训练出针对多输出的模型。
multilabel_regressor = MultiOutputRegressor(clf, n_jobs=-1)  
multilabel_regressor.fit(train_folds_feats, train_folds[[t for t in CFG.target_cols]].values)  
prediction = multilabel_regressor.predict(valid_folds_feats)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/58121.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RestTemplate使用InputStreamResource上传文件

背景 1. 我们应用服务是Spring boot项目,预览服务是我们另一个团队提供的用.net写的,最终使用的是office online来实现文件预览的功能。 2. 我们文件在阿里云OSS存储,我们需要预览文件需要将文件上传至预览服务器。 3. 计划使用RestTemplate…

线程池自查注意点

文章目录线程池自查注意点1、线程池的标准创建方式2、线程池的任务调度流程3、避免使用Executors快捷创建线程池3.1、newSingleThreadExecutor()3.2、newCachedThreadPool()3.3、ScheduledThreadPool()4、避免在方法中创建线程池5、不要盲目使用同步队列6、使用线程池&#xff…

MySQL库的操作

文章目录MySQL库的操作创建数据库创建数据库案例字符集和校验规则查看系统默认字符集以及校验规则查看数据库支持的字符集查看数据库支持的字符集校验规则校验规则对数据库的影响操纵数据库查看数据库显示创建语句修改数据库删除数据库备份和恢复数据库的备份和恢复表的备份和恢…

Cracking the Safes之Linux系统下gdb调试

Cracking Safe是什么 挑战是找出四个保险箱中每个保险箱预期的正确的5个输入集。在运行二进制安全程序时,您需要一次输入一个猜测,如下所示: 其实,就是输入5次,程序会对输入内容进行判断,只有符合程序要求才能成功,任务就是逆向找到正确的字符串!!! 解题思路 反汇…

mac pro M1(ARM)安装:centos8.0虚拟机

0.引言 mac发布了m1芯片,其强悍的性能收到很多开发者的追捧,但是也因为其架构的更换,导致很多软件或环境的安装成了问题,之前我们讲解了如何安装centos7。这次我们接着来看如何在mac m1环境下安装centos8 1.下载 1.1 安装VMwar…

Java基于springboot+vue的五金用品销售购物商城系统 前后端分离

五金用品是当前很多家庭和维修人员必备的工具,他们可以让维修变的更加简单,甚至有很多维修必须有配套的专业工具才能够完成,但是很多时候人们在五金店购买这些五金用品的时候不是价格昂贵就是缺少一些想要的工具,这个是通过开发一…

Guava 对 Map的操作

Guava是google公司开发的一款Java类库扩展工具包,内含了丰富的API,涵盖了集合、缓存、并发、I/O等多个方面。使用这些API一方面可以简化我们代码,使代码更为优雅,另一方面它补充了很多jdk中没有的功能,能让我们开发中更…

C语言刷题(2)

🐒博客名:平凡的小苏 📚学习格言:别人可以拷贝我的模式,但不能拷贝我不断往前的激情 文件拷贝 问题描述: 小蓝正在拷贝一份文件,他现在已经拷贝了 t 秒时间,已经拷贝了 c 字节&#…

解决eclipse导入svn项目报 403Forbidden

解决eclipse导入svn项目报 403Forbidden问题; 首先,产生这个问题的原因:①导入的svn项目没有权限;②上次导入的svn项目在身份验证的时候保存了用户名以及密码;(我遇到这个情况的原因是因为②) …

个人网页制作 个人网页设计作业 HTML CSS个人网页模板 大学生个人介绍网站毕业设计 DW个人主题网页模板下载 个人网页成品代码 个人网页作品下载

🎉精彩专栏推荐👇🏻👇🏻👇🏻 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业…

编码器的电路介绍

编码器的结构特点以及以及使用 对于8线到三线的编码器,一定是八线输入,三线输出,有十一条线 但是74HC148是一个16引脚的芯片 有十一线上述的信号,还有电源线以及地线,此时我们就有了13条线 另外的线则是归于控制信…

kubernetes深入理解之Service

版权声明:本文为CSDN博主「开着拖拉机回家」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 主页地址:开着拖拉机回家的博客_CSDN博客-Linux,Java基础学习,MySql数据库领域博主 目录 一、概述 1.1 Serv…

【salesforce平台基础】-想到啥写点啥

【salesforce基础】-想到啥写点啥1.salesforce架构2.学习过程中常见的几个“公司”🤭3.术语4.平台的用途(举例说明)5.AppExchange(软件应用商店)6.sandbox7.平台入门1.salesforce架构 salesforce是一家云公司&#xf…

7.关于线性回归模型的QA

为什么使用平方损失而不是绝对差值呢? 答: 二者区别不大,但是绝对差值是一个不可导的函数,在零点的时候,绝对差值的导数会有点难求。 损失为什么要求平均? 答:求平均的话,梯度是在…

原语科技宣布完成千万级天使+轮融资,致力于打造隐私计算标准化产品

原语科技 开放隐私计算 开放隐私计算 开放隐私计算OpenMPC是国内第一个且影响力最大的隐私计算开放社区。社区秉承开放共享的精神,专注于隐私计算行业的研究与布道。社区致力于隐私计算技术的传播,愿成为中国 “隐私计算最后一公里的服务区”。 180篇…

【基础算法】多项式三大运算 C++实现

●多项式计算 一维多项式就是包含一个变量的多项式,一个一维多项式示例如下: 一维多项式求值就是对于上述多项式,计算在指定的x处的函数值。一个通用的计算多项式值的算法可以采用递推的方式,可以将上述多项式变为如下的等价形式…

位运算 离散化 区间和算法

目录一、位运算1.1 思路1.1 例题:二进制中1的个数二、离散化2.1 概念2.2 例题:区间和三、合并区间3.1 概念3.2 例题:合并区间一、位运算 1.1 思路 首先知道一个概念:一个正整数的负数等于其按位取反后1 -x ~x 1 举个例子&…

干货——生产型企业的供应商管理系统模板

供应商管理主要是是通过提高供货产品和服务质量及交付能力,缩短企业采购周期和生产成本,从而提升产品核心竞争力。随着如今信息技术的发展,采用先进的信息化手段更能够提升供应商管控能力,实现资源的有效整合,从而加强…

[附源码]计算机毕业设计疫苗药品批量扫码识别追溯系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

测试服务器的udping值

测试服务器的udping值参考下载工具步骤一:在服务器上启动UDP Echo服务(必须)启动**UDP Echo服务**步骤二:在客户端下载UDPing工具步骤三:在客户端测试UDPing值参考 https://help.aliyun.com/document_detail/158771.html UDPing项目地址: h…