A Self-Attentive model for Knowledge Tracing论文笔记和代码解析

news2024/11/26 14:36:12

原文链接和代码链接A Self-Attentive model for Knowledge Tracing | Papers With Code

motivation:传统方法面临着处理稀疏数据时不能很好地泛化的问题。

本文提出了一种基于自注意力机制的知识追踪模型 Self Attentive Knowledge Tracing (SAKT)。其本质是用 Transformer 的 encoder 部分来做序列任务。具体从学生过去的活动中识别出与给定的KC相关的KC,并根据所选KC相对较少的KC预测他/她的掌握情况。由于预测是基于相对较少的过去活动,它比基于RNN的方法更好地处理数据稀疏性问题。

模型结构

 

输入编码

交互信息 x_{t}=\left (e_{t},r_{t} \right ) 通过公式y_{t}=e_{t}+r_{t}\times E 转变成一个数字,总量为 2E。

我们 用Interaction embedding matrix 训练一个交互嵌入矩阵,M\in R^{2E \times d}
被用来为序列中的每个元素s_{i}

 Exercise 编码 利用 exercise embedding matrix训练练习嵌入矩阵,E\in R^{E \times d},每行代表一个题目ei 

Position Encoding

自动学习P \in R^{E \times d},n 是序列长度。

最终编码层的输出如下

        
        self.qa_embedding = nn.Embedding(
            2 * n_skill + 2, self.qa_embed_dim, padding_idx=2 * n_skill + 1
        )
        self.pos_embedding = nn.Embedding(self.max_len, self.pos_embed_dim)

#定义

#计算
qa = self.qa_embedding(qa)
pos_id = torch.arange(qa.size(1)).unsqueeze(0).to(self.device)
pos_x = self.pos_embedding(pos_id)
qa = qa + pos_x

 注意力机制

Self-attention layer采用scaled dotproduct attention mechanism。

Self-attention的query、key和value分别为:

self.multi_attention = nn.MultiheadAttention(
            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout
        )
attention_out, _ = self.multi_attention(q, qa, qa, attn_mask=attention_mask)

 

 Causality:因果关系也是mask 避免未来交互对现在的

Feed Forward layer

用一个简单的前向传播网络将self-attention的输出进行前向传播。

class FFN(nn.Module):
    def __init__(self, state_size=200, dropout=0.2):
        super(FFN, self).__init__()
        self.state_size = state_size
        self.dropout = dropout
        self.lr1 = nn.Linear(self.state_size, self.state_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(self.state_size, self.state_size)
        self.dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        x = self.lr1(x)
        x = self.relu(x)
        x = self.lr2(x)
        return self.dropout(x)

剩余连接:剩余连接[2]用于将底层特征传播到高层。因此,如果低层特征对于预测很重要,那么剩余连接将有助于将它们传播到执行预测的最终层。在KT的背景下,学生尝试练习属于某个特定概念的练习来强化这个概念。因此,剩余连接有助于将最近解决的练习的嵌入传播到最终层,使模型更容易利用低层信息。在自我注意层和前馈层之后应用剩余连接。

层标准化:在[1]中,研究表明,规范化特征输入有助于稳定和加速神经网络。我们在我们的架构中使用了层规范化目的层在自我注意层和前馈层也应用了归一化。

       attention_out = self.layer_norm1(attention_out + q)# Residual connection ; added excercise embd as residual because previous ex may have imp info, suggested in paper.
        attention_out = attention_out.permute(1, 0, 2)

        x = self.ffn(attention_out)
        x = self.dropout_layer(x)
        x = self.layer_norm2(x + attention_out)# Layer norm and Residual connection

Prediction layer

self-attention的输出经过前向传播后得到矩阵F,预测层是一个全连接层,最后经过sigmod激活函数,输出每个question的概率

 模型的目标是预测用户答题的对错情况,利用cross entropy loss计算(y_true, y_pred)

实验

# -*- coding:utf-8 -*-
"""
    Reference: A Self-Attentive model for Knowledge Tracing (https://arxiv.org/abs/1907.06837)
"""

import torch
import torch.nn as nn
import deepkt.utils
import deepkt.layer

def future_mask(seq_length):
    mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype("bool")
    return torch.from_numpy(mask)


class FFN(nn.Module):
    def __init__(self, state_size=200, dropout=0.2):
        super(FFN, self).__init__()
        self.state_size = state_size
        self.dropout = dropout
        self.lr1 = nn.Linear(self.state_size, self.state_size)
        self.relu = nn.ReLU()
        self.lr2 = nn.Linear(self.state_size, self.state_size)
        self.dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        x = self.lr1(x)
        x = self.relu(x)
        x = self.lr2(x)
        return self.dropout(x)

class SAKTModel(nn.Module):
    def __init__(
        self, n_skill, embed_dim, dropout, num_heads=4, max_len=64, device="cpu"
    ):
        super(SAKTModel, self).__init__()
        self.n_skill = n_skill
        self.q_embed_dim = embed_dim
        self.qa_embed_dim = embed_dim
        self.pos_embed_dim = embed_dim
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.num_heads = num_heads
        self.max_len = max_len
        self.device = device

        self.q_embedding = nn.Embedding(
            n_skill + 1, self.q_embed_dim, padding_idx=n_skill
        )
        self.qa_embedding = nn.Embedding(
            2 * n_skill + 2, self.qa_embed_dim, padding_idx=2 * n_skill + 1
        )
        self.pos_embedding = nn.Embedding(self.max_len, self.pos_embed_dim)

        self.multi_attention = nn.MultiheadAttention(
            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout
        )

        self.key_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.value_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.query_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim)
        self.dropout_layer = nn.Dropout(self.dropout)
        self.ffn = FFN(self.embed_dim)
        self.pred = nn.Linear(self.embed_dim, 1, bias=True)

    def forward(self, q, qa):
        qa = self.qa_embedding(qa)
        pos_id = torch.arange(qa.size(1)).unsqueeze(0).to(self.device)
        pos_x = self.pos_embedding(pos_id)
        qa = qa + pos_x
        q = self.q_embedding(q)

        q = q.permute(1, 0, 2)
        qa = qa.permute(1, 0, 2)

        attention_mask = future_mask(q.size(0)).to(self.device)
        attention_out, _ = self.multi_attention(q, qa, qa, attn_mask=attention_mask)
        attention_out = self.layer_norm1(attention_out + q)# Residual connection ; added excercise embd as residual because previous ex may have imp info, suggested in paper.
        attention_out = attention_out.permute(1, 0, 2)

        x = self.ffn(attention_out)
        x = self.dropout_layer(x)
        x = self.layer_norm2(x + attention_out)# Layer norm and Residual connection
        x = self.pred(x)

        return x.squeeze(-1), None

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/48200.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【博学谷学习记录】超强总结,用心分享|架构师-Spring核心组件介绍

文章目录一、Bean组件二、Context组件一、Bean组件 Bean组件定义在Spring的org.springframework.beans包下,解决了以下几个问题: 这个包下的所有类主要解决了三件事: Bean的定义 Bean的创建 Bean的解析 Spring Bean的创建是典型的工厂模式…

centos7安装字体和中文字体

文章目录1.查看自己的操作系统2. 安装字体库3.安装更新字体命令4.查看中文字体5.新建目录6.拷贝 fonts.scale 和windows上的字体到chinese文件夹中.将字体文件放在chinese目录7.授权,该目录及其下所有文件需要有执行权限8.重新建立字体索引、更新缓存9.查看字体是否…

信号包络及其提取方法(Matlab)

信号包络及其提取方法 介绍信号包络,以及信号包络的提取方法。 一、信号包络 直观地从时域来讲,信号包络就是信号波形的轮廓。 本质上,信号包络是带通信号的基带部分。 一个实带通信号记为x(t),将它频谱的中心频点搬移到零频…

数据结构初阶--栈和队列(讲解+类模板实现)

栈的概念和结构 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶,另一端称为栈底。栈中的数据元素遵守后进先出LIFO(Last In First Out)加粗样式的原则。 入…

debug - 用Procmon记录目标程序启动后的操作

文章目录debug - 用Procmon记录目标程序启动后的操作概述笔记备注ENDdebug - 用Procmon记录目标程序启动后的操作 概述 想看看 D:\Cadence\SPB_17.4\tools\bin\Capture.exe 开始页中的recent projects 从哪里读的. 想用Procmon记录Capture.exe启动后的动作, 再记录成文本日志…

【Spring】一文带你吃透AOP面向切面编程技术(上篇)

个人主页: 几分醉意的CSDN博客_传送门 文章目录💖AOP概念✨AOP作用✨AOP术语✨什么时候需要用AOP💖Aspectj框架介绍✨Aspectj的5个通知注解✨Aspectj切入点表达式✨前置通知Before💖投票传送门(欢迎伙伴们投票&#xf…

Nginx加载Lua脚本lua_shared_dict缓存

1、介绍 lua_shared_dict缓存是nginx为lua提供的一个多进程共享空间,为了避免多进程修改造成脏数据,lua_shared_dict修改数据是用锁来实现的。这样就会有qps访问瓶颈变小的问题。这是性能缺点。 2、使用 1)首先在nginx.conf里申请一块共享…

数据分享|PYTHON用决策树分类预测糖尿病和可视化实例

全文下载链接:http://tecdat.cn/?p23848在本文中,决策树是对例子进行分类的一种简单表示。它是一种有监督的机器学习技术,数据根据某个参数被连续分割。决策树分析可以帮助解决分类和回归问题(点击文末“阅读原文”获取完整代码数…

大一学生WEB前端静态网页——唯品会1页 包含hover效果

⛵ 源码获取 文末联系 ✈ Web前端开发技术 描述 网页设计题材,DIVCSS 布局制作,HTMLCSS网页设计期末课程大作业 | 在线商城购物 | 水果商城 | 商城系统建设 | 多平台移动商城 | H5微商城购物商城项目 | HTML期末大学生网页设计作业,Web大学生网页 HTML&a…

SpringCloud:使用Nacos作为配置中心

目录 一、nacos配置中心简介 二、nacos配置实时更新及同一个微服务不同环境的差异化配置 准备工作 针对商品微服务实现实时更新(以商品微服务为例) 三、nacos同一个微服务不同环境的共享配置 同一个微服务修改配置才能访问不同环境 四、nacos不同微…

【JavaEE】MyBatis

文章目录1.MyBatis介绍2.MyBatis快速入门3.Mapper代理开发4.MyBatis核心配置文件5.配置文件完成增删改查5.1 查询5.2 添加/修改5.3 删除6.MyBatis参数传递7.注解完成增删改查1.MyBatis介绍 1.什么是MyBatis? MyBatis是一款优秀的 持久层框架,用于简化JDBC开发MyBat…

STC 51单片机46——看门狗测试

#include <reg52.h> sfr WDT_CONTR 0xE1; //声明WDT_CONTR void delay(void){ //改变延时长度&#xff0c;可以观测是否触发看门狗 unsigned char i,j,k; for(i0;i<255;i) for(j0;j<255;j) for(k0;k<255;k); } void…

图神经网络

前言 图与图的表示 图是由一些点和一些线构成的&#xff0c;能表示一些实体之间的关系&#xff0c;图中的点就是实体&#xff0c;线就是实体间的关系。如下图&#xff0c;v就是顶点&#xff0c;e是边&#xff0c;u是整张图。attrinbutes是信息的意思&#xff0c;每个点、每条…

MFC界面控件BCGControlBar v33.3 - 升级Ribbon Bar自定义功能

BCGControlBar库拥有500多个经过全面设计、测试和充分记录的MFC扩展类。 我们的组件可以轻松地集成到您的应用程序中&#xff0c;并为您节省数百个开发和调试时间。 该版本包含了增强的Ribbon自定义、新的日期/时间数字指示器、带有文本对齐的组控件、多行支持以及其他一些新功…

第二证券|12月A股投资方向来了!这些板块已先涨为敬

日前&#xff0c;我国银河、信达证券、中泰证券、安全证券等多家券商连续发布12月A股月度出资组合。全体上券商对后市持活跃情绪&#xff0c;以为当时商场处于震动磨底装备区间&#xff0c;商场动摇并不影响“暖冬行情”的延续&#xff0c;一些活跃的券商以为后市有望走出季度级…

R语言rcurl抓取问财财经搜索网页股票数据

问财财经搜索是同花顺旗下的服务之一,主要针对上市公司的公告、研报、即时新闻等提供搜索及参考资料。相对于其他股票软件来说&#xff0c;一个强大之处在于用自然语言就可以按你指定的条件进行筛选。而大部分现有的行情软件支持的都不是很好&#xff0c;写起来就费尽心思&…

Nginx加载Lua脚本链接mysql

1、nginx加载lua脚本方法可参我的这篇文章 Nginx安装Openresty加载Lua代码_IT东东歌的博客-CSDN博客 2、测试代码 官网 https://github.com/openresty/lua-resty-mysql local mysql require "resty.mysql" local db, err mysql:new() if not db then ngx.sa…

Django 第四章 模版系统详解(ORM数据模型-使用mysql数据库增删改查)

djiango模版系统&#xff1a; 用于自动渲染一个文本文件&#xff0c;一般用于HTML页面&#xff0c;模版引擎渲染的最终HTML内容返回给客户端浏览器 模版系统分成两部分 静态部分&#xff1a; 例如html css .js 动态部分 djiango 模版语言&#xff0c;类似于jinja语法变量定义&…

SpringCloud 组件Gateway服务网关【全局过滤器】

目录 1&#xff0c;全局过滤器 1.1&#xff1a;全局过滤器作用 1.2&#xff1a;自定义全局过滤器 1.3&#xff1a;过滤器执行顺序、 2&#xff1a;跨域问题 2.1&#xff1a;什么是跨域问题 2.2&#xff1a;示例跨域问题 2.3&#xff1a;解决跨域问题 1&#xff0c;全局…

python将CSV文件(excel文件)按固定行数拆分成小文件

最近接到一个需求&#xff0c;就是把非常大的CSV文件&#xff0c;电脑根本打不开&#xff08;或者打开也不能完全展现所有的数据&#xff09;&#xff0c;以每 80万(不够80万行的也独自成为一个单独的文件) 行进行拆分成一个小文件&#xff0c;各位小伙伴在日常工作中有没有遇到…