AI+生命科学第二课:入门RNA和特征学习 【Datawhale AI夏令营】

news2024/9/27 12:13:22

教程链接:Task2:深入理解赛题,入门RNN和特征工程

打卡;https://linklearner.com/activity/12/4/4

在大佬讲解的基础上,带上一些我自己的理解

分析训练流程

从原始特征到输入模型

初始数据转换为tensor后,将x通过forward传入嵌入层

  def forward(self, x):
        # 将输入序列传入嵌入层
        embedded = [self.embedding(seq) for seq in x]
        outputs = []
        
        # 对每个嵌入的序列进行处理
        for embed in embedded:
            x, _ = self.gru(embed)  # 传入GRU层
            x = self.dropout(x[:, -1, :])  # 取最后一个隐藏状态,并进行dropout处理
            outputs.append(x)
        
        # 将所有序列的输出拼接起来
        x = torch.cat(outputs, dim=1)
        # 传入全连接层
        x = self.fc(x)
        # 返回结果
        return x.squeeze()

x作为输入数据,究竟是什么形式?

x中的数据以唯一标识的形式存在

在下述图中,inputs和x是相同的,我们直接分析inputs就好

  • inputs作为一个三维张量

    • 关于什么是张量(之前学习时候的一个小例子张量)

    • 第0维表示元素的个数,由inputs[0].shape可知每个元素的形状为64*25

    • 第1维为64表示batch的大小,batch就是在一次迭代中用于训练模型的一组样本,这里的64就意味着有64组样本提供训练

    • 第2维是序列,代表每一组样本中的内容,25表示序列的数量

      • 在inputs[0][0]中可以得到每一组样本的内容,每一个序列都代表序列编码后每一位的唯一标识

        • 然而只有前七位是有效数,其余数是为了保证样本长度相同进行的填充,0填充无意义

唯一标识如何对应路径

想要标识路径,那么我们就需要对数据进行加工,将其进行编号,形成一个“字典”,达到一一映射的关系

# 创建词汇表
all_tokens = []
for col in columns:
    for seq in train_data[col]:
        if ' ' in seq:  # 修饰过的序列
            all_tokens.extend(seq.split())
        else:
            all_tokens.extend(tokenizer.tokenize(seq))
vocab = GenomicVocab.create(all_tokens, max_vocab=10000, min_freq=1)

拿到词汇表以后,下一步获得序列的长度

max_len = max(max(len(seq.split()) if ' ' in seq else len(tokenizer.tokenize(seq)) 
                    for seq in train_data[col]) for col in columns)

接着将字典和序列一 一对应,在lloader获取样本的时候把token转为索引

def __getitem__(self, idx):
    # 获取数据集中的第idx个样本
    row = self.df.iloc[idx]  # 获取第idx行数据
    
    # 对每一列进行分词和编码
    seqs = [self.tokenize_and_encode(row[col]) for col in self.columns]
    if self.is_test:
        # 仅返回编码后的序列(测试集模式)
        return seqs
    else:
        # 获取目标值并转换为张量(仅在非测试集模式下)
        target = torch.tensor(row['mRNA_remaining_pct'], dtype=torch.float)
        # 返回编码后的序列和目标值
        return seqs, target

def tokenize_and_encode(self, seq):
    if ' ' in seq:  # 修饰过的序列
        tokens = seq.split()  # 按空格分词
    else:  # 常规序列
        tokens = self.tokenizer.tokenize(seq)  # 使用分词器分词
    
    # 将token转换为索引,未知token使用0(<pad>)
    encoded = [self.vocab.stoi.get(token, 0) for token in tokens]
    # 将序列填充到最大长度
    padded = encoded + [0] * (self.max_len - len(encoded))
    # 返回张量格式的序列
    return torch.tensor(padded[:self.max_len], dtype=torch.long)

此时,对于某一行数据,其两个特征分别为AGCCUUAGCACA和u u g g u u Cf c,假设整个数据集对应token编码后序列的最大长度为10,那么AGCCUUAGCACA对应的特征就可能是

  • [25, 38, 25, 24, 0, 0, 0, 0, 0, 0]

而u u g g u u Cf c对应的特征为

  • [65, 65, 63, 63, 65, 65, 74, 50, 0, 0]

那么假设batch的大小为16,此时forword函数的x就会是两个列表,每个列表的tensor尺寸为16 * 10

最后得到一个张量

开始RNN模型训练

将上一阶段的张量输入模型,这里的“__init__”方法对模型进行了初始化,后续可以通过修改这些参数来调整模型

这里通过前向传播的方式进行参数的传入

什么是前向传播?是不是对应着还有反向传播?

  简单来说:

  • 前向传播负责将输入数据传递到输出层得到预测结果。

  • 反向传播负责根据预测结果与真实值之间的差异来调整权重和偏差,以逐斩提高模型的准确性和性能。

  • developer.baidu.com

传入后我们通过embedding映射将索引离散的符号映射到连续的向量空间中

  • 在这里,越相似的符号在向量空间中距离越近

在从GRU模型输出后,x = self.dropout(x[:, -1, :])使得输出变为了BatchSize * (hidden_dim * 2),此处取了序列最后一个位置的输出数据(注意RNN网络的记忆性),这里的2是因为bidirectional参数为True,随后x = torch.cat(outputs, dim=1)指定在第二个维度拼接后,通过全连接层再映射为标量,因此最后经过squeeze(去除维数为1的维度)后得到的张量尺寸为批大小,从而可以后续和target值进行loss计算,迭代模型。

class SiRNAModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=200, hidden_dim=256, n_layers=3, dropout=0.5):
        super(SiRNAModel, self).__init__()
        
        # 初始化嵌入层
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # 初始化GRU层
        self.gru = nn.GRU(embed_dim, hidden_dim, n_layers, bidirectional=True, batch_first=True, dropout=dropout)
        # 初始化全连接层
        self.fc = nn.Linear(hidden_dim * 4, 1)  # hidden_dim * 4 因为GRU是双向的,有n_layers层
        # 初始化Dropout层
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # 将输入序列传入嵌入层
        embedded = [self.embedding(seq) for seq in x]
        outputs = []
        
        # 对每个嵌入的序列进行处理
        for embed in embedded:
            x, _ = self.gru(embed)  # 传入GRU层
            x = self.dropout(x[:, -1, :])  # 取最后一个隐藏状态,并进行dropout处理
            outputs.append(x)
        
        # 将所有序列的输出拼接起来
        x = torch.cat(outputs, dim=1)
        # 传入全连接层
        x = self.fc(x)
        # 返回结果
        return x.squeeze()
将输出与测试集做loss

差值返回到模型中,通过这个步骤计算我们所求出的结果与正确结果的差异,将这个达到可以不断优化模型的目的

损失函数是什么?有什么用?有哪些损失函数?

一言以蔽之,损失函数(loss function)就是用来度量模型的预测值f(x)与真实值Y的差异程度的运算函数,它是一个非负实值函数,通常使用L(Y, f(x))来表示,损失函数越小,模型的鲁棒性就越好。

损失函数通过计算损失值(预测值与真实值之间的差值),然后将这个损失值的各个参数让模型通过反向传播更新各个参数,进而达到降低损失值的目的。

按照损失函数的作用将损失函数分为基于距离度量的损失函数和基于概率分布度量的损失函数。(损失函数(Loss Function)

BCE和CE:二者都是损失函数,BCE用于“是不是”问题,例如LR输出概率,明天下雨or不下雨的概率;CE用于“是哪个”问题,比如多分类问题。(BCE和CE的区别)

扩展一点:
什么是前向传播

前向传播(Forward Propagation)是神经网络中的一个重要概念,指的是数据在网络中从输入层经过隐藏层传递到输出层的过程。在这个过程中,输入数据被转换为输出预测,每层神经元都会对数据进行一定的线性和非线性变换。具体来说,前向传播包括以下几个步骤:

  1. 输入层:

    1. 输入数据进入网络。

  2. 隐藏层:

    1. 输入数据经过一系列的隐藏层处理,这些层通常包括线性变换(比如矩阵乘法),以及非线性激活函数的应用。

  3. 输出层:

    1. 最终的数据通过输出层产生预测结果。

前向传播的目的

  • 生成预测:前向传播主要用于生成给定输入数据的预测输出。

  • 计算损失:在训练过程中,前向传播产生的预测值会与实际的目标值进行比较,以此来计算损失函数的值。

  • 评估模型:在模型训练完成后,可以通过前向传播来评估模型在新数据上的表现。

什么是反向传播

反向传播(Backpropagation)是一种在神经网络训练过程中用于优化权重的方法。它基于链式法则来计算损失函数关于每个权重的梯度,从而使得我们能够使用梯度下降方法来调整权重,最小化损失函数。反向传播的过程主要包括以下几个步骤:

  1. 计算损失:

    1. 首先进行前向传播,得到预测值,并计算损失函数。

  2. 误差反向传播:

    1. 从输出层开始,计算损失函数相对于每一层输出的梯度。

    2. 利用链式法则,逐步向前计算每一层的梯度。

  3. 权重更新:

    1. 使用计算出的梯度来更新每一层的权重,以便在下一次迭代中减小损失。

前向传播与反向传播的关系
  • 前向传播是用于生成预测值并计算损失的步骤。

  • 反向传播是用于根据损失函数的梯度来更新权重的步骤。

在神经网络的训练过程中,这两个步骤通常是交替进行的:

  1. 前向传播生成预测值并计算损失。

  2. 反向传播根据损失计算梯度并更新权重。

  3. 重复上述步骤,直到达到预定的停止条件(例如最大迭代次数或者满足某个收敛标准)。

这样,随着训练的进行,网络的权重逐渐调整,使得网络的预测能力不断提高。

总结来说,前向传播和反向传播是神经网络训练中两个基本且互补的过程,前者用于生成预测并计算损失,后者则用于基于损失函数的梯度来更新网络的权重。

什么是RNN

RNN,全称为递归神经网络(Recurrent Neural Network),是一种人工智能模型,特别擅长处理序列数据。它和普通的神经网络不同,因为它能够记住以前的数据,并利用这些记忆来处理当前的数据。想象你在读一本书。你在阅读每一页时,不仅仅是单独理解这一页的内容,还会记住前面的情节和信息。这些记忆帮助你理解当前的情节并预测接下来的发展。这就是 RNN 的工作方式。假设你要预测一个句子中下一个单词是什么。例如,句子是:“我今天早上吃了一个”。RNN 会根据之前看到的单词(“我今天早上吃了一个”),预测下一个可能是“苹果”或“香蕉”等。它记住了之前的单词,并利用这些信息来做出预测。

  • RNN 在处理序列数据时具有一定的局限性:

    • 长期依赖问题:RNN 难以记住和利用很久以前的信息。这是因为在长序列中,随着时间步的增加,早期的信息会逐渐被后来的信息覆盖或淡化。

    • 梯度消失和爆炸问题:在反向传播过程中,RNN 的梯度可能会变得非常小(梯度消失)或非常大(梯度爆炸),这会导致训练过程变得困难。

  • LSTM 的改进

    • LSTM 通过引入一个复杂的单元结构来解决 RNN 的局限性。LSTM 单元包含三个门(输入门、遗忘门和输出门)和一个记忆单元(细胞状态),这些门和状态共同作用,使 LSTM 能够更好地捕捉长期依赖关系。

      • 输入门:决定当前输入的信息有多少会被写入记忆单元。

      • 遗忘门:决定记忆单元中有多少信息会被遗忘。

      • 输出门:决定记忆单元的哪些部分会作为输出。

    • 通过这些门的控制,LSTM 可以选择性地保留或遗忘信息,从而有效地解决长期依赖和梯度消失的问题。

  • GRU 的改进

    • GRU 是 LSTM 的一种简化版本,它通过合并一些门来简化结构,同时仍然保留了解决 RNN 局限性的能力。GRU 仅有两个门:更新门和重置门。

      • 更新门:决定前一个时刻的状态和当前输入信息的结合程度。

      • 重置门:决定忘记多少之前的信息。

    • GRU 的结构更简单,计算效率更高,同时在许多应用中表现出与 LSTM 类似的性能。

我们在pytorch的GRU文档中可以找到对应可选的参数信息,我们需要特别关注的参数如下,它们决定了模型的输入输出的张量维度

  • input_size(200)

  • hidden_size(256)

  • bidirectional(True)

假设输入的BatchSize为16,序列最大长度为10,即x尺寸为16 * 10 * 200,那么其输出的张量尺寸为 16 * 10 * (256 * 2)。

特征工程

由于我们前面提交的baseline确实分数不是很好看,所以我们引入了——数据的特征工程

找数据中“不一样”的部分,这个不一样的就是“特征”——大量数据中有用的数据,一样的没有学习的意义。

所谓特征工程(我的浅薄理解)

就是把一堆数据中找出来最有代表性、最有特征、最能区分出这大量数据的不同的一个工程

在dw大佬给的讲解中,介绍的是在表格数据上进行特征工程,这里我们存个疑:

为什么特征工程要在表格数据上进行?

下面就是这里特征工程的几种方法

  • 类别型变量

  • 可能的时间特征构造

  •  

     
    siRNA_duplex_id_values = df.siRNA_duplex_id.str.split("-|\.").str[1].astype("int")
    • 我们看到所有的数据中都有AD前缀,所以我们处理的时候就能把AD去掉

  • 根据序列模式提取特征

    • 假设siRNA的序列为ACGCA...,此时我们可以根据上一个task中提到的rna背景知识,对碱基的模式进行特征构造

  • 包含某些单词

    • 可能在单词中包含“Cells”或者其他的一些共性单词,这些数据就能归于一类

这里我们就能回答前面的疑问:

在上述的方法中我们可以发现进行特征工程的前提是在我们观测过大量数据后进行的一些操作,相较而言,表格数据更容易我们在直观上进行观测后进一步进行提取特征

最后我们使用基于lightgbm的baseline进行测试

模型参数如下所示

train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)

def print_validation_result(env):
    result = env.evaluation_result_list[-1]
    print(f"[{env.iteration}] {result[1]}'s {result[0]}: {result[2]}")

params = {
    "boosting_type": "gbdt",
    "objective": "regression",
    "metric": "root_mean_squared_error",
    "max_depth": 7,
    "learning_rate": 0.02,
    "verbose": 0,
}

gbm = lgb.train(
    params,
    train_data,
    num_boost_round=15000,
    valid_sets=[test_data],
    callbacks=[print_validation_result],
)
经过一些调试

对下面一些参数调整,可以一定程度提法哦模型的准确率:

  • max_depth:这是树的最大深度。增加这个值可能会让模型更复杂,从而有可能提高准确率,但也可能导致过拟合。你可以尝试不同的深度值,比如5、10或15,来看看哪个值效果最好。

  • learning_rate:这是学习率,它控制着每一步的更新幅度。较小的学习率需要更多的迭代次数来收敛,但可能有助于找到更精确的解。你可以尝试将学习率设置为0.01或0.05,并观察效果。

  • num_boost_round:这是迭代次数。增加迭代次数可能会提高准确率,但也会增加计算成本。你可以根据模型的训练情况和准确率来调整这个值。

depth取10,学习率取0.01,迭代次数取20000,得分0.7657

max_depth

learning_rate

num_boost_round

得分

7

0.02

15000

0.7661

10

0.01

20000

0.7657

7

0.01

20000

0.7600

10

0.005

15000

0.7569

10

0.02

15000

0.7692

10

0.02

18000

0.7707

经验:

保留最大深度10,学习率设置为0.005太小,每次移动的距离太短

下次目标:

保留最大深度和0.02的学习深度,找到合适的迭代次数(20000的数据确实有点大)

QA

新手听不懂

  • 哪里不懂学哪里

怎么调上高分

  • 学task3

路径名称要对应

训练时模型数据需要解压

unzip siRNA_0715.zipV

文中代码部分参考:

Task2:深入理解赛题,入门RNN和特征工程

代码提交官网:

上海科学智能研究院

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1975468.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C++】初识引用

目录 概念引用的五大特性引用在定义时必须初始化一个变量可以有多个引用一个引用可以继续有引用引用了一个实体就不能再引用另一个实体可以对任何类型做引用(包括指针) 引用使用的两种使用场景做参数交换两数单链表头结点的修改 做返回值优化传递返回值 常引用权限放大这时候进…

【前端学习笔记二】CSS基础二

一、颜色模型 1.颜色设置 颜色名称 https://www.w3schools.com/colors/colors_names.asp 这里是一些颜色的名称&#xff08;关键字&#xff09;&#xff0c;比如red、black、green等&#xff0c;可以直接指定名称来设置颜色。名称不区分大小写。 color:red;transparent tr…

OCC 网格化(三)-网格划分算法原理

目录 一、简介 二、基本原理 三、工作流程 四、BRepMesh模块与网格化流程 4.1 BRepMesh 主要组件 4.2 工作流程 4.3 网格生成示例 五、关键参数总结 一、简介 BRepMesh_IncrementalMesh 是一种基于迭代细分的网格划分算法,通过设置线性偏转和角偏转参数,可以生成高精…

利用Python爬虫实现数据收集与挖掘

Python爬虫通常使用requests、selenium等库来发送HTTP请求&#xff0c;获取网页内容&#xff0c;并使用BeautifulSoup、lxml等库来解析网页&#xff0c;提取所需的数据。 以下是一个简单的Python爬虫示例&#xff0c;用于从某个网页上抓取数据&#xff1a; import requests …

免费【2024】springboot 大学生志愿者管理系统的设计与实现

博主介绍&#xff1a;✌CSDN新星计划导师、Java领域优质创作者、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流✌ 技术范围&#xff1a;SpringBoot、Vue、SSM、HTML、Jsp、PHP、Nodejs、Python、爬虫、数据可视化…

Executing an update/delete query,解决Hibernate更新数据库报错

问题描述 在使用Hibernate更新数据库中一条记录时,发送如下错误: javax.persistence.TransactionRequiredException: Executing an update/delete query at org.hibernate.internal.AbstractSharedSessionContract.checkTransactionNeededForUpdateOperation(AbstractShare…

HCIA基础回顾

OSI参考模型 OSI&#xff08;Open System Interconnect&#xff09;参考模型&#xff0c;即为开放式系统互连参考模型。 应用层&#xff1a;人机交互&#xff0c;提供网络服务。 表示层&#xff1a;将逻辑语言转换为二进制语言&#xff0c;定义数据格式。 会话层&#xff1…

Linux 安装gradle

1.下载 下载地址&#xff1a; 下载地址&#xff1a; Gradle | ReleasesFind binaries and reference documentation for current and past versions of Gradle.https://gradle.org/releases/ 2. 解压 unzip gradle-7.6.2-all.zip 3.修改配置文件 #1.进入配置文件 vim /etc/…

【探索Linux】P.44(数据链路层 —— 以太网的帧格式 | MAC地址 | MTU | ARP协议)

阅读导航 引言一、认识以太网二、以太网的帧格式三、MAC地址四、MTU五、ARP协议温馨提示 引言 在深入探讨了网络层的IP协议之后&#xff0c;本文将带领读者进一步深入网络的底层——数据链路层。我们将详细解析以太网的帧格式&#xff0c;这是数据链路层传输数据的基本单元&am…

漏洞复现:Apache solr

目录 漏洞简述 环境搭建 漏洞复现 漏洞检测 漏洞修复 漏洞简述 Apache Solr是一个开源的搜索服务&#xff0c;使用Java编写、运行在Servlet容器的一个独立的全文搜索服务器&#xff0c;是Apache Lucene项目的开源企业搜索平台。 该漏洞是由于没有对输入的内容进行校验&…

深度体验:IntelliJ Idea自带AI Assistant,开启面向AI编程新纪元!

首发公众号&#xff1a; 赵侠客 引言 JetBrains AI Assistant 是 JetBrains 集成开发环境&#xff08;IDE&#xff09;中嵌入的一款智能开发助手工具&#xff0c;旨在通过人工智能技术来简化和提升软件开发过程&#xff0c;我深度体验了一下在IntelliJ IDEA 2024.2 Beta (Ulti…

JAVA项目基于SSM的学生成绩管理系统

目录 一、前言 二、技术介绍 三、项目实现流程 四、论文流程参考 五、核心代码截图 专注于大学生实战开发、讲解和毕业答疑等辅导&#xff0c;获取源码后台 一、前言 二、技术介绍 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端…

全球氢化双酚A (HBPA)市场规划预测:2030年市场规模将接近1330亿元,未来六年CAGR为2.7%

一、引言 随着全球化工行业的持续发展&#xff0c;氢化双酚A (HBPA)作为重要的化工原料&#xff0c;其市场重要性日益凸显。本文旨在探索HBPA行业的发展趋势、潜在商机及其未来展望。 二、市场趋势 全球HBPA市场的增长主要受全球化工行业增加、消费者对高性能化工产品要求提高…

vue3内置组件Suspense

给多个异步组件提供一个统一的状态管理 使用前&#xff0c;有两个loading... 使用后&#xff0c; 只有一个loading... Index.vue: <script setup lang"ts"> import { onMounted, ref, defineAsyncComponent } from vue import { useRouter } from vue-router…

CTF入门教程(非常详细)从零基础入门到竞赛,看这一篇就够了!

一、CTF简介 CTF&#xff08;Capture The Flag&#xff09;中文一般译作夺旗赛&#xff0c;在网络安全领域中指的是网络安全技术人员之间进行技术竞技的一种比赛形式。CTF起源于1996年DEFCON全球黑客大会&#xff0c;以代替之前黑客们通过互相发起真实攻击进行技术比拼的方式。…

什么是网络安全?一文了解网络安全究竟有多重要!

随着互联网的普及和数字化进程的加速&#xff0c;网络安全已经成为我们生活中不可或缺的一部分。然而&#xff0c;很多人对于网络安全的概念仍然模糊不清。 ​ 那么&#xff0c;什么是网络安全&#xff1f;它究竟有多重要呢&#xff1f; 一、网络安全的定义 网络安全是指通过…

【Java】/* JDK 新增语法 */

目录 一、yield 关键字 二、var 关键字 三、空指针异常 四、密封类 五、接口中的私有方法 六、instanceof 一、yield 关键字 yield关键字&#xff0c;从Java13开始引⼊。yield关键字⽤于从case的代码块中返回值。 原本的switch语句写法&#xff1a; public static void …

Notion爆红背后,笔记成了AI创业新共识?

在数字化时代&#xff0c;笔记软件已成为我们记录、整理和创造知识的得力助手。本文将带您深入了解Notion以及其他五个AI笔记产品&#xff0c;它们如何通过AI重塑笔记体验&#xff0c;满足我们快速记录、捕捉灵感、智能整理、情感陪伴和自动撰写文章的五大核心需求。 ———— …

NC 在两个长度相等的排序数组中找到上中位数

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站&#xff0c;这篇文章男女通用&#xff0c;看懂了就去分享给你的码吧。 描述 给定两个递增…

项目比赛项目负责人的汇报技巧:如何让每一次汇报都清晰有力

项目比赛项目负责人的汇报技巧&#xff1a;如何让每一次汇报都清晰有力 前言MECE原则&#xff1a;确保全面性与互斥性SCQA结构&#xff1a;讲一个引人入胜的故事逻辑树思维模型&#xff1a;深入挖掘问题根源STAR法则&#xff1a;展示你的行动与成果PREP模型&#xff1a;清晰表达…