前言:本文是以文本分类的迁移学习任务为例,对风险分析模型的整体框架流程做梳理。
目录
- 1. LearnRisk
- 1.1 motivatio
- 1.2 overall
- 2. LearnRisk-TC
- 2.1 构造风险特征
- 2.1.1 risk metric
- 2.1.2 risk feature
- 2.2 构建风险模型
- 2.3 训练风险模型
- 2.4 微调base model
1. LearnRisk
1.1 motivatio
- 传统的DNN结果有一定错误的风险
- 迁移学习目标域的标签数据难以获得,通常只有少量有标签样本
1.2 overall
风险分析整体分为三步:
- 构造风险特征
- 构建风险模型
- 训练风险模型
2. LearnRisk-TC
- 核心思路:在源域上训练好一个base model后,用目标域的少量有标签样本(如valid dataset)去训练风险模型,最后用无标签的test dataset重新微调base model。
- 主要流程:
(1)源数据集训练base model;
(2)有标签的目标域的验证数据集构建一批风险特征(决策树规则);
(3)构建每个类别的正态分布:对每个风险特征构建一个正态分布(u是先验, σ {\sigma} σ后验),风险特征加权和作为每个类别的正态分布;
(4)训练风险模型:损失函数的目标是实现正确的风险排序(风险由高到低);
(5)利用无标签的目标域的测试数据集进行base model的微调。
2.1 构造风险特征
2.1.1 risk metric
文章中将risk metric主要分为两类,statistics-based risk metrics和DNN-based risk metrics。对于每一个risk metric,都会生成一个长度为N的一维向量,N为总的类别数。假设目标域的测试数据集大小为Q,每一个文本都会有X个risk metric,最终共生成了Q*X个risk metric。
-
statistics-based risk metrics
文章中构建了一种新的统计特征,计算公式如下:
其中,p为超参
(1) C H I n e w = C H I ∗ α {CHI_{new} = CHI * {\alpha}} CHInew=CHI∗α, 各项解释如下:
(2) T F − I D F n e w = T F n e w ∗ I D F n e w ∗ β ∗ λ {TF-IDF_{new} = TF_{new} * IDF_{new} * {\beta} * {\lambda}} TF−IDFnew=TFnew∗IDFnew∗β∗λ,各项解释如下:
![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/58671c4563754de7a1ddcefc4b8d15e2.png)
- DNN-based risk metrics
文章采用了两种模型,bert和textcnn用于提取文档特征,然后使用knn和ccd两种方法计算。
2.1.2 risk feature
文章使用单边决策树来生成risk features,决策树生成的每条规则对应了一个risk feature,如下:
最终会生成一批决策树,得到一批规则,即一批risk feature。
2.2 构建风险模型
对于多分类问题,假设共N个类别,对于每个类别分别构建一个风险模型。
类别i的风险模型构建的主要流程:
(1)对每个风险特征分别建立一个正态分布
N
(
u
,
σ
2
)
{N(u, \sigma^2)}
N(u,σ2)。
u是先验知识:
u
=
n
/
m
{u=n/m}
u=n/m,n是风险分析的训练数据集(即目标域的验证数据集)中成功匹配该风险特征的文档数,m是训练数据集中属于该类别的总文档数。
σ
{\sigma}
σ是后验知识,待模型训练得到。
注意:不同类别对应的各个风险特征的正态分布并不一样。
(2)求所有风险特征的加权和作为类别i的正态分布。
所有的风险特征都是一条条规则,指向的是匹配某个类别,假设共5个风险特征,第2,3,5条风险特征指向的是匹配类别i,则类别i的特征向量
x
i
{x_i}
xi为(0,1,1,0,1)。类别i的权重向量为
w
i
w_i
wi则i的正态分布计算如下:
u
i
=
x
i
(
w
i
∗
u
f
)
{u_i = x_i (w_i * u_f)}
ui=xi(wi∗uf)
σ
i
2
=
x
i
(
w
i
∗
σ
f
2
)
{\sigma_i^2 = x_i (w_i * \sigma_f^2)}
σi2=xi(wi∗σf2)
其中
u
f
u_f
uf代表的是一个长度为m的一维向量,即每个风险特征的u,
σ
f
2
\sigma_f^2
σf2同理。
2.3 训练风险模型
风险模型的训练目标是排序,即能够让高风险的文档正确的排在低风险文档的前面,或者说能让分类错误的文档排在分类正确的文档前面。
损失构建如下:
2.4 微调base model
核心思想:用base model对目标域的测试数据集做预测,求每个文本的预测类别,然后用训练好的风险模型去计算该类别的风险值,对base model设计一个新的损失函数进行微调。
损失函数如下: