Geneformer基于hugging face的transformers实现,具体模型是BertForSequenceClassification,本篇先熟悉该模型。
首先直观看Geneformer的模型架构,基于BERT构建一个文本分类模型,我们直接从预训练的Geneformer加载BERT,num_labels自动设置分类器的Linear为out_features=3:
from transformers import BertForSequenceClassification
# 加载模型
model = BertForSequenceClassification.from_pretrained("./Geneformer/gf-6L-30M-i2048/",
num_labels=3,
output_attentions = False,
output_hidden_states = False).to("cuda")
print(model)
BERT模型为:
BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(25426, 256, padding_idx=0)
(position_embeddings): Embedding(2048, 256)
(token_type_embeddings): Embedding(2, 256)
(LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.02, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-5): 6 x BertLayer(
(attention): BertAttention(
(self): BertSdpaSelfAttention(
(query): Linear(in_features=256, out_features=256, bias=True)
(key): Linear(in_features=256, out_features=256, bias=True)
(value): Linear(in_features=256, out_features=256, bias=True)
(dropout): Dropout(p=0.02, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=256, out_features=256, bias=True)
(LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.02, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=256, out_features=512, bias=True)
(intermediate_act_fn): ReLU()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=256, bias=True)
(LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.02, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=256, out_features=256, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.02, inplace=False)
(classifier): Linear(in_features=256, out_features=3, bias=True)
)
进一步,我们想知道在(classifier)前的输入是什么,为什么从序列变成了单个样本,比如,假设embedding维度为256,序列最大长度为527,但是(classifier)输入一定是(batch_size,1,256)。
这里插入一个注意事项,每个文本的seq_len是不同的,但是在深度学习中,每个批次中的文本序列都是统一处理的,所以同一批次中的seq_len都是被padding或着剪断到同一长度,但是不同批次的seq_len可以不同。
对于全体批次,还要指定max_seq_len,比如Geneformer中的2048,这对密集注意力来说已经非常大。可以打印model.config,发现BertConfig中有:"max_position_embeddings": 2048
,这里说明最大input_seq_len=2048,这是因为BERT 使用位置嵌入(position embeddings)来编码输入序列中单词的位置信息。max_position_embeddings 参数指定了模型能够处理的最大位置数,从而间接限制了输入序列的最大长度。在BERT原文中,max_seq_len=512。
对于last_hidden_state到pooled_output的变化,其实来自BertModel.BertPooler,我们可以查看其实现:
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
可以看到其实是取出了位于序列最开始的cls token的embedding,这里也说明了Geneformer论文中所说的cell embedding其实不是所有token embedding的平均, 而是直接取出了cls token。
我们可以通过Geneformer的数据集生成来发现,cls token是一直存在的,但是论文中没有体现它。
数据集生成主要是给每个基因文本对应到Geneformer的词汇表上,我们可以直接通过h5ad格式生成数据集,但是首先需要完成两个必须的事情,给adata一个 obs[“n_counts”] 和一个 var[“ensembl_id”] ,n_counts是因为Geneformer对adata.X的预处理需要,ensembl_id是为了对应到Geneformer的词汇表,虽然hugging face中说需要把gene name通过Biomart转为Ensembl ID,但其实词汇表中是有gene name的,所以不经过Biomart也可以:
import warnings
warnings.filterwarnings("ignore")
import json
import os
import datasets
import numpy as np
import scanpy as sc
import geneformer
import scipy.sparse as sp
from geneformer import DataCollatorForCellClassification, EmbExtractor, TranscriptomeTokenizer
import geneformer
print(geneformer.__file__)
from transformers import BertForSequenceClassification, Trainer
# adata = sc.read("./10xMultiome/10x-Multiome-Pbmc10k-RNA.h5ad")
adata = sc.read("./10xMultiome/liu.h5ad")
adata.X = sp.csr_matrix(adata.X)
print(adata)
# print(adata.obs_names)
print(adata.var.index)
# print(adata.var_names)
print(adata.X.max())
adata.var["ensembl_id"] = adata.var.index
adata.obs["n_counts"] = adata.X.sum(axis=1)
adata.obs["joinid"] = list(range(adata.n_obs))
# print(adata.obs["joinid"])
h5ad_dir = "./data/h5ad/"
if not os.path.exists(h5ad_dir):
os.makedirs(h5ad_dir)
adata.write(h5ad_dir + "pbmcs.h5ad")
然后进行token化,joinid是每个细胞的id:
# 使用Geneformer的tokenizer
token_dir = "./data/tokenized_data/"
if not os.path.exists(token_dir):
os.makedirs(token_dir)
tokenizer = TranscriptomeTokenizer(custom_attr_name_dict={"joinid": "joinid", "celltype": "celltype"})
tokenizer.tokenize_data(
data_directory=h5ad_dir,
output_directory=token_dir,
output_prefix="pbmc",
file_format="h5ad",
)
然后,我们查看生成的数据集内容:
import warnings
warnings.filterwarnings("ignore")
import json
import os
import datasets
import numpy as np
import scanpy as sc
import geneformer
from geneformer import DataCollatorForCellClassification, EmbExtractor, TranscriptomeTokenizer
from transformers import BertForSequenceClassification, Trainer
import pickle
token_dataset = datasets.load_from_disk("./data/tokenized_data/pbmc.dataset/")
print(type(token_dataset))
print(token_dataset)
print(token_dataset.shape)
print(token_dataset[1])
打印结果为:
<class 'datasets.arrow_dataset.Dataset'>
Dataset({
features: ['input_ids', 'joinid', 'celltype', 'length'],
num_rows: 206
})
(206, 4)
{'input_ids': [2, 10508, 839, 15750, 4619, 15240, 16485, 835, 9018, 16932, 7021, 8129, 1567, 441, 4093, 8426, 3397, 1094, 136, 16531, 85, 4741, 10961, 652, 312, 17726, 11116, 11300, 15458, 7558, 2404, 8678, 767, 7431, 329, 6664, 14331, 14940, 13271, 7131, 2146, 1107, 1224, 11762, 3344, 16594, 10775, 8184, 14134, 4209, 17811, 1606, 11397, 14794, 14029, 7629, 7192, 1390, 4829, 2598, 3205, 8149, 8603, 16904, 1469, 5976, 13117, 17197, 14844, 13367, 11275, 11548, 37, 4948, 3343, 3238, 2528, 9850, 6948, 4523, 7242, 6381, 9624, 4150, 6501, 15332, 11979, 1616, 15637, 87, 604, 3411, 6728, 10148, 9642, 1584, 10677, 4671, 16265, 15597, 5061, 15030, 802, 10452, 13033, 14732, 781, 8364, 4428, 3983, 14187, 16415, 965, 2897, 6294, 16033, 1343, 8509, 9823, 9931, 333, 12907, 5055, 1198, 5799, 7506, 555, 17978, 12804, 14626, 7467, 3330, 2564, 1915, 10662, 9194, 14510, 2037, 11922, 13228, 2303, 8968, 1657, 3713, 6233, 1350, 745, 4632, 5059, 8361, 2054, 6790, 12042, 9005, 15962, 7693, 12797, 3181, 1156, 7697, 2691, 7257, 8200, 4557, 8617, 4713, 13702, 6759, 2256, 413, 2889, 18080, 353, 7588, 9955, 18541, 923, 12866, 13443, 8597, 9815, 1082, 344, 7067, 6601, 6285, 2411, 3931, 1098, 13185, 7054, 3375, 6400, 14165, 6057, 12325, 11708, 10923, 7262, 8618, 5292, 6634, 1334, 3942, 2068, 15232, 7203, 12531, 16188, 1361, 2193, 6627, 4533, 8307, 9190, 11395, 16788, 11491, 4681, 14405, 575, 140, 3221, 4935, 18639, 11075, 810, 1863, 9321, 12038, 7056, 14011, 3879, 4690, 17436, 1479, 11020, 46, 2535, 4765, 11799, 5536, 1582, 515, 2701, 11208, 3519, 12854, 1674, 5999, 4026, 8251, 8811, 4729, 9373, 576, 12649, 4188, 3801, 12965, 12422, 14052, 8591, 1896, 6890, 731, 9507, 12007, 1878, 5672, 16754, 12005, 457, 220, 2688, 8887, 8744, 12491, 17808, 6507, 4141, 6475, 8177, 5376, 16955, 200, 5869, 5060, 4524, 8792, 14100, 5541, 8209, 5725, 3204, 9653, 10878, 4784, 3800, 11724, 4304, 5045, 15100, 11711, 12322, 8090, 10987, 1449, 12712, 1738, 13326, 2440, 7617, 1304, 11165, 805, 78, 4703, 6787, 1272, 4368, 7320, 7601, 14657, 11730, 552, 4481, 7682, 9720, 10096, 6329, 13866, 9137, 3625, 12187, 16818, 7079, 750, 4670, 1524, 3891, 14533, 1396, 6991, 9632, 13282, 11026, 8690, 7450, 4794, 3906, 6608, 12576, 11106, 7229, 5633, 10026, 1793, 14718, 868, 4167, 8348, 5329, 11286, 8544, 991, 19091, 7781, 3809, 16026, 8378, 3242, 2049, 7832, 4059, 6967, 13382, 9337, 8108, 4065, 12391, 7160, 6322, 14406, 14463, 11421, 8085, 10061, 2979, 16273, 7031, 8880, 12552, 8622, 16427, 16558, 15154, 488, 5063, 5315, 13105, 202, 5820, 13399, 16857, 1305, 10763, 12558, 1740, 16067, 5826, 2011, 1973, 16347, 3512, 9274, 5622, 5410, 12601, 14987, 3880, 16251, 1539, 11659, 17207, 17869, 15583, 233, 12493, 8131, 13119, 12502, 16134, 4386, 1142, 4420, 16620, 6365, 9128, 9247, 16845, 4246, 3223, 7279, 1813, 1605, 6154, 5209, 8944, 16561, 15987, 12910, 4480, 1796, 12676, 853, 3025, 15695, 16917, 4740, 7587, 16849, 10410, 10532, 15509, 12039, 8134, 8599, 1711, 9061, 14751, 6622, 7535, 2244, 16645, 11199, 1312, 7313, 3478, 9842, 8182, 3766, 16915, 13931, 11901, 5024, 8157, 8380, 8582, 16830, 511, 6391, 3067, 3638, 8491, 4009, 17362, 16923, 9236, 16822, 11717, 16868, 16946, 7435, 3], 'joinid': 1, 'celltype': 'HCT', 'length': 505}
原始的liu.h5ad本身就是206个细胞样本,由于表达为0的gene都不被考虑,所以每个细胞(句子)的seq_len都是不一样的。体现在token_dataset[cell idx]['length']
,cell idx是细胞的index。
对于词汇表,存储在Geneformer/geneformer/token_dictionary_gc95M.pkl
,字典形式,key是gene name,Ensembl ID,还有四个特殊的键,可以看到前四个键就是特殊的键:'<pad>', '<mask>', '<cls>', '<eos>'
,value对应每个词的ID,pad的值是0,mask是1,cls是2,eos是3,也就是提供给BERT的具体input数字化形式(输入BERT的是input_ids)。
可以看到每个细胞的input_ids以2开头,3结尾,在输入BERT计算时,由于需要pad到固定长度,所以这时候大部分是以很多0结尾。
我们不重新训练BERT,所以用不到mask,但是要注意,不管是微调训练还是推理,我们需要提供attention_mask,这样可以避免注意力在pad token上计算。所以还需要一个函数来处理数据集:
from collections import Counter
import datetime
import pickle
import subprocess
import seaborn as sns
sns.set()
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, f1_score
import transformers
print(transformers.__file__)
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification
from geneformer import TOKEN_DICTIONARY_FILE
import numpy as np
import torch
# 加载数据
evalset=load_from_disk("./data/tokenized_data/pbmc.dataset")
evalset = evalset.select([0,1,2])
label_name = "celltype"
def preprocess_classifier_batch(cell_batch, max_len, label_name):
if max_len is None:
max_len = max([len(i) for i in cell_batch["input_ids"]])
# load token dictionary (Ensembl IDs:token)
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
gene_token_dict = pickle.load(f)
def pad_label_example(example):
example[label_name] = np.pad(
example[label_name],
(0, max_len - len(example["input_ids"])),
mode="constant",
constant_values=-100,
)
example["input_ids"] = np.pad(
example["input_ids"],
(0, max_len - len(example["input_ids"])),
mode="constant",
constant_values=gene_token_dict.get("<pad>"),
)
example["attention_mask"] = (
example["input_ids"] != gene_token_dict.get("<pad>")
).astype(int)
return example
padded_batch = cell_batch.map(pad_label_example)
return padded_batch
max_evalset_len = max(evalset.select([i for i in range(len(evalset))])["length"])
print(max_evalset_len)
padded_batch = preprocess_classifier_batch(evalset, max_evalset_len, label_name)
input_data_batch = torch.tensor(padded_batch["input_ids"])
attn_msk_batch = torch.tensor(padded_batch["attention_mask"])
label_batch = padded_batch[label_name]
print(len(input_data_batch)) # 3
print(len(attn_msk_batch)) # 3
print(len(label_batch)) # 3
在这个例子中,取了三个样本构成batch:evalset = evalset.select([0,1,2])
,首先计算三个样本的最大长度:print(max_evalset_len)
为527,进一步,我们看第一个样本和第二个的input data:
print(input_data_batch[0])
print(input_data_batch[1])
print(len(input_data_batch[0]))
结果为:
tensor([ 2, 10508, 17269, 835, 9018, 1567, 15750, 16469, 441, 2897,
...
4009, 17362, 7435, 11717, 16946, 16868, 16822, 3, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0])
tensor([ 2, 10508, 839, 15750, 4619, 15240, 16485, 835, 9018, 16932,
...
11717, 16868, 16946, 7435, 3, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0])
527
再看attention_mask和label:
print(attn_msk_batch[0])
print(len(attn_msk_batch[0]))
print(label_batch[0])
结果为:
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
527
['HCT']
在输入给预训练模型计算时,只需要提供input data和attention mask:
# 加载模型
model = BertForSequenceClassification.from_pretrained("/Geneformer/gf-6L-30M-i2048/",
num_labels=3,
output_attentions = False,
output_hidden_states = False).to("cuda")
print(model.config)
outputs = model(input_ids=input_data_batch.to("cuda"), attention_mask=attn_msk_batch.to("cuda"))
print(outputs.logits.size())
print(outputs.logits.max())
print(outputs.logits.min())
我们也可以从头初始化一个BERT,前提是指定详细的参数:
from transformers import BertForSequenceClassification
import pickle
from transformers import BertConfig
# set model parameters
# model type
model_type = "bert"
# max input size
max_input_size = 2**11 # 2048
# number of layers
num_layers = 6
# number of attention heads
num_attn_heads = 4
# number of embedding dimensions
num_embed_dim = 256
# intermediate size
intermed_size = num_embed_dim * 2
# activation function
activ_fn = "relu"
# initializer range, layer norm, dropout
initializer_range = 0.02
layer_norm_eps = 1e-12
attention_probs_dropout_prob = 0.02
hidden_dropout_prob = 0.02
with open("/Geneformer/geneformer/token_dictionary_gc95M.pkl", "rb") as fp:
token_dictionary = pickle.load(fp)
print(token_dictionary.get("<pad>"))
# model configuration
config = {
"hidden_size": num_embed_dim,
"num_hidden_layers": num_layers,
"initializer_range": initializer_range,
"layer_norm_eps": layer_norm_eps,
"attention_probs_dropout_prob": attention_probs_dropout_prob,
"hidden_dropout_prob": hidden_dropout_prob,
"intermediate_size": intermed_size,
"hidden_act": activ_fn,
"max_position_embeddings": max_input_size,
"model_type": model_type,
"num_attention_heads": num_attn_heads,
"pad_token_id": token_dictionary.get("<pad>"),
"vocab_size": len(token_dictionary), # genes+2 for <mask> and <pad> tokens
"num_labels": 3,
"output_attentions": False,
"output_hidden_states": True
}
config = BertConfig(**config)
model = BertForSequenceClassification(config).to("cuda")
- 详细参考:https://github.com/huggingface/transformers