单细胞组学大模型(3)--- scGPT,有非常详细的学习文档和应用说明,且有多种训练数据权重!

news2024/9/20 16:44:24


–https://doi.org/10.1038/s41592-024-02201-0

代码来源:https://github.com/bowang-lab/scGPT

学习参考:https://scgpt.readthedocs.io/en/latest/introduction.html

scGPT: Towards Building a Foundation Model for Single-Cell Multi-omics Using Generative AI

留意更多内容,欢迎关注微信公众号:组学之心

研究团队和研究单位

:::: column
::: column-left 60%

:::
::: column-right

:::
::::

scGPT简介

数据:

多种组学:scRNA-seq, scATAC-seq, CITE-seq, Spatial transcriptomics;

33 Million 的单细胞RNA数据 / cross-tissue, human, non-disease;来源于CELLxGENE数据库

数据预处理:先用raw count矩阵,scRNA-seq是RNA分子的read count数目,scATAC-seq则是染色质可及性的peak region。都需要处理成cell-by-gene的矩阵形式。

scGPT的preprocess代码有基本的质控,如filter_gene_by_counts和filter_cell_by_counts

模型主要特点:

  • 表达值自适应分箱
  • input embedding融合多种信息
  • 生成式预训练中独特的注意力掩码机制
  • 实现的下游任务丰富

-----数据预处理-质控和分箱-----

Preprocessor自定义函数有一系列参数,如果为False则不执行:

class Preprocessor:
    """
    Prepare data into training, valid and test split. Normalize raw expression
    values, binning or using other transform into the preset model input format.
    """

    def __init__(
        self,
        use_key: Optional[str] = None,
        filter_gene_by_counts: Union[int, bool] = False,
        filter_cell_by_counts: Union[int, bool] = False,
        normalize_total: Union[float, bool] = 1e4,
        result_normed_key: Optional[str] = "X_normed",
        log1p: bool = False,
        result_log1p_key: str = "X_log1p",
        subset_hvg: Union[int, bool] = False,
        hvg_use_key: Optional[str] = None,
        hvg_flavor: str = "seurat_v3",
        binning: Optional[int] = None,
        result_binned_key: str = "X_binned",
    ):

1. 质控过滤

1.1 过滤基因

        if self.filter_gene_by_counts:
            logger.info("Filtering genes by counts ...")
            sc.pp.filter_genes(
                adata,
                min_counts=self.filter_gene_by_counts
                if isinstance(self.filter_gene_by_counts, int)
                else None,
            )

1.2 过滤细胞

        if (
            isinstance(self.filter_cell_by_counts, int)
            and self.filter_cell_by_counts > 0
        ):
            logger.info("Filtering cells by counts ...")
            sc.pp.filter_cells(
                adata,
                min_counts=self.filter_cell_by_counts
                if isinstance(self.filter_cell_by_counts, int)
                else None,
            )

1.3 normalize

        if self.normalize_total:
            logger.info("Normalizing total counts ...")
            normed_ = sc.pp.normalize_total(
                adata,
                target_sum=self.normalize_total
                if isinstance(self.normalize_total, float)
                else None,
                layer=key_to_process,
                inplace=False,
            )["X"]
            key_to_process = self.result_normed_key or key_to_process
            _set_obs_rep(adata, normed_, layer=key_to_process)

1.4 log1p

        if self.log1p:
            logger.info("Log1p transforming ...")
            if is_logged:
                logger.warning(
                    "The input data seems to be already log1p transformed. "
                    "Set `log1p=False` to avoid double log1p transform."
                )
            if self.result_log1p_key:
                _set_obs_rep(
                    adata,
                    _get_obs_rep(adata, layer=key_to_process),
                    layer=self.result_log1p_key,
                )
                key_to_process = self.result_log1p_key
            sc.pp.log1p(adata, layer=key_to_process)

这里key_to_process = self.result_log1p_key代码修改了原本key_to_process的值(一开始参数选定是adata.X,存储raw data的位置,经过log1p步骤之后变成了归一化且log1p之后的矩阵)。所以后续的分箱是对归一化log1p后的矩阵来进行。

1.5 hvg挑选

        if self.subset_hvg:
            logger.info("Subsetting highly variable genes ...")
            if batch_key is None:
                logger.warning(
                    "No batch_key is provided, will use all cells for HVG selection."
                )
            sc.pp.highly_variable_genes(
                adata,
                layer=self.hvg_use_key,
                n_top_genes=self.subset_hvg
                if isinstance(self.subset_hvg, int)
                else None,
                batch_key=batch_key,
                flavor=self.hvg_flavor,
                subset=True,
            )

下一步就是进行表达值分箱了,在下面embedding中一起提到。

–模型框架之Input Embeddings–

1. Input Embeddings

Input Embeddings主要由三种token组成:

  • Gene/peak token
  • 表达值token
  • 条件token。

输入的矩阵 X ∈ R N × G X∈ ℝ^{N×G} XRN×G 需要处理成cell-by-gene矩阵形式,其中 N N N 是细胞数量, G G G 是最初基因数量。

1.1 Gene tokens

把基因symbol名字作为tokens,并为每个基因 g j g_j gj 分配一个id( g j g_j gj)作为词汇表。词汇表还有特殊标记符,如 [cls]。所以每个细胞 i i i 的输入基因tokens由以下公式来表示(M是定义的最大输入长度,就是基因数量的最大值) t g ( i ) ∈ N M t^{(i)}_g ∈ ℕ^M tg(i)NM

之后把每个细胞的gene token通过GeneEncoder嵌入后,转换成 N M × D ℕ^{M×D} NM×D 维度:

class GeneEncoder(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings, embedding_dim, padding_idx=padding_idx
        )
        self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)  # (batch, seq_len, embsize)
        x = self.enc_norm(x)
        return x

1.2 表达值 tokens

用自适应值分箱方法,将每个细胞中的基因表达值进行离散化处理,并转换为相对值:

针对每个细胞中基因表达值非零部分的分布将它们划分为 B 个连续的 bin 区间 [ b k , b k + 1 ] [b_k, b_k+1] [bkbk+1],其中 k ∈ {1,2,…,B},这个B值是自定义的参数(代码在下面第二部分提到)。

大致意思是把非零基因表达值从小到大排列,按照bin的区间计算分位点,再计算分位点在值队伍中的位置,确定上界和下界,再把基因放进去。

分到相同区间的基因不管之前的表达值具体是多少,分完后都一样。这个分箱过程是针对每个细胞独立完成的,因此 b k b_k bk 会因细胞而异。细胞 i 的分箱基因值 x j ( i ) x^{(i)}_j xj(i) 定义为:

这里和上一章scBERT提到的bag-of-words思想一致,只是分箱的方法不一样。类似拔叔教我们怎么做人,对每样材料依据ta自己的特点进行解肢,然后定制一个个箱子,把材料放进去保存,使得材料能够刚好均匀分布在全部箱子中。

对于微调任务,在分箱步骤之前做了 log1p 转换和 HVG 选择。为了简化表示法,就在分箱之前使用 X i , j X_{i,j} Xi,j 来表示raw count矩阵或预处理数据矩阵。因此,细胞 i 的分箱基因表达值的最终输入向量表示为:

表达值就有3种value encoder,分别对应表达值是category, continuous, scaling的情况。模型默认使用continuous模式,对应着单细胞表达矩阵。作用也是把原本 1 × M 1×M 1×M 维的基因表达值的最终输入向量变成 M × D M×D M×D 维的embedding

CategoryValueEncoder(和GeneEncoder一样):

class CategoryValueEncoder(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings, embedding_dim, padding_idx=padding_idx
        )
        self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = x.long()
        x = self.embedding(x)  # (batch, seq_len, embsize)
        x = self.enc_norm(x)
        return x

ContinuousValueEncoder(2层神经网络):

class ContinuousValueEncoder(nn.Module):
    """
    Encode real number values to a vector using neural nets projection.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_value: int = 512):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(1, d_model)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.max_value = max_value

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len]
        """
        # TODO: test using actual embedding layer if input is categorical
        # expand last dimension
        x = x.unsqueeze(-1)
        # clip x to [-inf, max_value]
        x = torch.clamp(x, max=self.max_value)
        x = self.activation(self.linear1(x))
        x = self.linear2(x)
        x = self.norm(x)
        return self.dropout(x)

Scaling(src先通过GeneEncoder然后存储在cur_gene_token_embs中,values通过valueEncoder后,根据input_emb_style来确定src和values之间的运算关系,一般是相加):

def _encode(
        self,
        src: Tensor,
        values: Tensor,
        src_key_padding_mask: Tensor,
        batch_labels: Optional[Tensor] = None,  # (batch,)
    ) -> Tensor:
        self._check_batch_labels(batch_labels)

        src = self.encoder(src)  # (batch, seq_len, embsize)
        self.cur_gene_token_embs = src

        values = self.value_encoder(values)  # (batch, seq_len, embsize)
        if self.input_emb_style == "scaling":
            values = values.unsqueeze(2)
            total_embs = src * values
        else:
            total_embs = src + values

        if getattr(self, "dsbn", None) is not None:
            batch_label = int(batch_labels[0].item())
            total_embs = self.dsbn(total_embs.permute(0, 2, 1), batch_label).permute(
                0, 2, 1
            )  # the batch norm always works on dim 1
        elif getattr(self, "bn", None) is not None:
            total_embs = self.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)

        output = self.transformer_encoder(
            total_embs, src_key_padding_mask=src_key_padding_mask
        )
        return output  # (batch, seq_len, embsize)

1.3 条件 tokens

条件tokens包含与单个基因相关的各种meta信息,例如扰动实验改变(由扰动标记表示)。为了表示位置条件tokens,使用与输入基因相同维度的输入向量:

类似的有一个BatchLabelEncoder(和GeneEncoder一样):

class BatchLabelEncoder(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings, embedding_dim, padding_idx=padding_idx
        )
        self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)  # (batch, embsize)
        x = self.enc_norm(x)
        return x

1.4 Embedding 层

e m b g emb_g embg e m b c emb_c embc 分别表示 基因token 和 条件token。使用全连接层将低维离散的分箱基因表达值引入非线性表示,学习基因表达模式(表示为 e m b x emb_x embx),以增强表现力(就是上面提到的ContinuousValueEncoder)。这种选择可以对基因表达值的序数关系进行建模:

细胞 i i i 的 final embeddings — h ( i ) ∈ R M × D h^{(i)}∈ ℝ^{M×D} h(i)RM×D定义为三种 e m b emb emb 的逐元素相加:

2. 自适应分箱

2.1 执行分箱与变量初始化

        if self.binning:
            logger.info("Binning data ...")
            if not isinstance(self.binning, int):
                raise ValueError(
                    "Binning arg must be an integer, but got {}.".format(self.binning)
                )
            n_bins = self.binning  # NOTE: the first bin is always a spectial for zero
            binned_rows = []
            bin_edges = []
            layer_data = _get_obs_rep(adata, layer=key_to_process)
            layer_data = layer_data.A if issparse(layer_data) else layer_data
            if layer_data.min() < 0:
                raise ValueError(
                    f"Assuming non-negative data, but got min value {layer_data.min()}."
                )
  • binned_rows 用于存储所有行的分箱结果。
  • bin_edges 用于存储每个行的分箱边界。
  • layer_data 获取指定层的数据,并处理稀疏矩阵,检查数据是否包含负值。

2.2 逐行处理数据

            for row in layer_data:
                if row.max() == 0:
                    logger.warning(
                        "The input data contains all zero rows. Please make sure "
                        "this is expected. You can use the `filter_cell_by_counts` "
                        "arg to filter out all zero rows."
                    )
                    binned_rows.append(np.zeros_like(row, dtype=np.int64))
                    bin_edges.append(np.array([0] * n_bins))
                    continue
  • 代码遍历 layer_data 中的每一行(一个细胞的表达数据)。
  • 如果细胞基因表达全为零,则记录一个警告,并将这行标记为零箱,存储在 binned_rows 中,同时将所有边界设为零,存储在bin_edges中。所以预训练没有去掉无基因表达的细胞。

2.3 计算分箱边界和离散化数据

non_zero_ids = row.nonzero()
non_zero_row = row[non_zero_ids]
bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
# bins = np.sort(np.unique(bins))
# NOTE: comment this line for now, since this will make the each category
# has different relative meaning across datasets
non_zero_digits = _digitize(non_zero_row, bins)
assert non_zero_digits.min() >= 1
assert non_zero_digits.max() <= n_bins - 1
binned_row = np.zeros_like(row, dtype=np.int64)
binned_row[non_zero_ids] = non_zero_digits
binned_rows.append(binned_row)
bin_edges.append(np.concatenate([[0], bins]))
  • 首先提取出非零的基因及其索引。

  • 先通过np.linspace(0, 1, n_bins - 1) 生成 n_bins-1 个均匀间隔的分箱边界,然后np.quantile计算这些非零基因表达值的分位数,再用自定义的_digitize函数计算分箱区间,将数据分配到一组给定的区间中:

def _digitize(x: np.ndarray, bins: np.ndarray, side="both") -> np.ndarray:
    """
    Digitize the data into bins. This method spreads data uniformly when bins
    have same values.

    Args:

    x (:class:`np.ndarray`):
        The data to digitize.
    bins (:class:`np.ndarray`):
        The bins to use for digitization, in increasing order.
    side (:class:`str`, optional):
        The side to use for digitization. If "one", the left side is used. If
        "both", the left and right side are used. Default to "one".

    Returns:

    :class:`np.ndarray`:
        The digitized data.
    """
    assert x.ndim == 1 and bins.ndim == 1

    left_digits = np.digitize(x, bins)
    if side == "one":
        return left_digits

    right_difits = np.digitize(x, bins, right=True)

    rands = np.random.rand(len(x))  # uniform random numbers

    digits = rands * (right_difits - left_digits) + left_digits
    digits = np.ceil(digits).astype(np.int64)
    return digits
  • 将非零数据转换为对应的分箱(1 到 n_bins-1),零值保留为零。
  • 将结果存储到 binned_row(分箱后的数据)和 bin_edges(分箱边界)中,并添加到对应的列表中。

提一嘴:最后细胞i的输入embedding是:

3. 通过 scGPT transformer 对细胞和基因表达建模

用自注意力机制transformer对 h ( i ) h^{(i)} h(i) 进行encode得到 h n ( i ) ∈ R M , D h^{(i)}_n∈ ℝ^{M,D} hn(i)RMD

将scGPT-transformer得到的 h n ( i ) ∈ R M , D h^{(i)}_n∈ ℝ^{M,D} hn(i)RMD 直接用于基因水平的微调任务:基因表达预测(GEP)和扰动表达预测任务(perturb-GEP)。

而对于细胞水平任务,需要先将 h n ( i ) h^{(i)}_n hn(i) 整合到细胞embedding向量(3.1细胞表示)中,任务有细胞类型注释,即细胞嵌入通过分类器后预测细胞类型标签。

输入维度 M 可以达到数万个基因,超过了 NLG 中常用的常规transformer的输入长度。使用 FlashAttention的加速自我注意力机制实现,增强模型容量,并支持对大型输入维度的有效处理。其它任何高效的transformer也可以用,例如Linformer和内核自注意力(KSA)。

3.1 细胞表示

细胞i表示 h c ( i ) ∈ R D h^{(i)}_c∈ ℝ^D hc(i)RD 是在基因水平表示 h n ( i ) h^{(i)}_n hn(i)中得到的。各种池化操作,如逐元素均值池化或加权池化,都可以在这种情况下使用。[cls] token 是在 Transformer 模型中常用的一种特殊标记符,用来表示整个输入序列的整体信息。

在基因表达模型中,将 [cls] token 附加到input token的开头。这个 [cls] token 的嵌入向量最终将包含整个细胞的综合信息。经过 Transformer 层的处理后,细胞嵌入 h c ( i ) h^{(i)}_c hc(i) 可以由堆叠的 h n ( i ) h^{(i)}_n hn(i) [cls]中的相应行提取出来。

[cls] 操作其实就是检索 [cls] token位置索引处的行并提取出来。

    def _get_cell_emb_from_layer(
        self, layer_output: Tensor, weights: Tensor = None
    ) -> Tensor:
        """
        Args:
            layer_output(:obj:`Tensor`): shape (batch, seq_len, embsize)
            weights(:obj:`Tensor`): shape (batch, seq_len), optional and only used
                when :attr:`self.cell_emb_style` is "w-pool".

        Returns:
            :obj:`Tensor`: shape (batch, embsize)
        """
        if self.cell_emb_style == "cls":
            cell_emb = layer_output[:, 0, :]  # (batch, embsize)
        elif self.cell_emb_style == "avg-pool":
            cell_emb = torch.mean(layer_output, dim=1)
        elif self.cell_emb_style == "w-pool":
            if weights is None:
                raise ValueError("weights is required when cell_emb_style is w-pool")
            if weights.dim() != 2:
                raise ValueError("weights should be 2D")
            cell_emb = torch.sum(layer_output * weights.unsqueeze(2), dim=1)
            cell_emb = F.normalize(cell_emb, p=2, dim=1)  # (batch, embsize)

        return cell_emb

3.2 批次和模态的表示(类似Condition tokens)

使用标准嵌入层以类似的方式实现(比如上面提到的GeneEncoder和BatchLabelEncoder)。这一步主要是想缓解批次效应和整合多模态数据。

模态(modality)token t m ( i ) t^{(i)}_m tm(i) 与各个输入特征 g j g_j gj 相关联(例如,指示它是基因、region还是蛋白质)。批次(batch)token最初位于细胞水平,但也可以传播到单个细胞的所有特征。相同的批次token t b ( i ) t^{(i)}_b tb(i)可以重复,直到单个细胞 i i i 的输入特征长度达到 M M M

这些批次和模态token的嵌入不用作transformer块的输入。它们在进入特定的微调目标之前,在特征或细胞级别与transformer输出连接起来。

这是为了防止transformer放大相同模态特征内的注意力,而低估不同模态的特征。此外,了解模态和/或批次身份有助于在下游微调目标中进行基因表达建模。当模型学习预测以模态和/或批次身份为条件的表达值时,这种偏差会从基因和细胞表示本身中隐式消除。这是一种促进批次校正的技术。

在单细胞多组学整合任务中,将transformer输出与batch和modality embeddings的和连接起来( e m b b emb_b embb e m b m emb_m embm 先逐元素相加,再和 h n ( i ) h^{(i)}_n hn(i)拼接起来)得到 h n ′ ( i ) h^{'(i)}_n hn(i),用作表达建模的下游微调目标的输入:

而只有scRNA-seq一种模态,但是有不同的批次,面对scRNA-seq整合任务的时候,只用 h c ( i ) h^{(i)}_c hc(i) e m b b emb_b embb 拼接起来得到 h c ′ ( i ) h^{'(i)}_c hc(i)

上面的式子中, t b ( i ) t^{(i)}_b tb(i) 表示细胞 i i i 的批次同一性, h c ( i ) h^{(i)}_c hc(i) 是微调目标中的原始细胞表示。请注意,修改后的版本 h c ′ ( i ) h^{'(i)}_c hc(i) 仅与表达式建模目标相关,不适用于基于分类的目标。

–模型框架之Masked-attention–

1.一些背景

自注意力机制被广泛用于捕捉词汇之间的共现模式。在自然语言处理中,主要通过以下两种方式实现:

  • ①掩码token预测,主要用于如BERT和RoBERTa等Transformer编码器模型中。在这种方法中,输入序列中的一些token会被随机掩码,然后模型会在输出中预测这些标记。

  • ②自回归生成,通过因果Transformer解码器模型(如OpenAI GPT)进行顺序预测。

在OpenAI GPT-3和GPT-4中使用的生成预训练采用了统一框架,在该框架中,模型预测最有可能的下一个token,该token来自于已知输入token的“prompts”。这个框架提供了高度的灵活性,可以用于各种自然语言生成应用,并展示了诸如上下文感知等能力。scGPT研究认为,生成训练也可以在单细胞模型中以类似的方式发挥作用。

具体来说,scGPT研究对以下两个任务感兴趣:

  • 1.基于已知基因表达生成未知的基因表达值,即“基因prompts”的生成;

  • 2.在给定输入细胞类型条件下生成全基因组表达值,即“细胞prompts”的生成。

尽管token和prompts的概念相似,但由于数据的非顺序性,基因读取与自然语言的建模有本质区别。与句子中的词语不同,细胞中基因的顺序是可以互换的,因此不存在“下一个基因”来预测的概念。这使得将GPT模型的因果掩码公式不能直接应用于单细胞数据。因此,scGPT研究开发了一种专门的注意力掩码机制,定义了基于注意力分数的预测顺序。

2.注意力掩码

这部分代码较长,可以在github上的model.py中FlashTransformerEncoderLayer和FlashTransformerEncoderLayer中学习~

注意力掩码可以应用于 transformer 块中的自注意力矩阵,如对于 M M M 个基因的input token:

第(l + 1)个 transformer块对其输入 M M M 个token的 h l ( i ) ∈ R M , D h^{(i)}_l∈ ℝ^{M,D} hl(i)RMD 应用多头自我注意, h l ( i ) h^{(i)}_l hl(i)是细胞的final embedding h ( i ) h^{(i)} h(i) 经过transformer块的输出:

具体来说,每个self-attention操作的计算方式如下:

其中 Q , K , V ∈ R M × D Q, K, V ∈ ℝ^{M×D} Q,K,VRM×D W q , W k , W v ∈ R M × d Wq, Wk, Wv ∈ ℝ^{M×d} Wq,Wk,WvRM×d A m a s k ∈ [ 0 , − i n f ] M × M A_{mask} ∈ [0,− inf]^{M×M} Amask[0inf]M×M A m a s k A_{mask} Amask Q Q Q 是行, K K K 是列(需要符合上面公式中 Q K T QK^T QKT 的维度)

注意力掩码 A m a s k A_{mask} Amask 通过“灭活” Q Q Q K K K 之间的原始注意力权重来修改自我注意力的范围:将 − i n f − inf inf 添加到矩阵中的位置 ( i , j ) (i,j) ij 会使 softmax 之后的注意力权重无效,从而禁止第 i 个查询和第 j 个键之间的注意力。另一方面,添加 0 意味着注意力权重保持不变。

A m a s k A_{mask} Amask如上图A所示,scGPT注意力掩码的原则是仅允许在 “已知基因” 的 embedding 和基因 Q Q Q 之间进行注意力计算,这是通过 A m a s k A_{mask} Amask a i , j a_{i,j} ai,j来实现的,对未知基因(需要预测的基因)的 Q K T QK^T QKT注意力矩阵进行掩码(蓝色),但是从左上到右下的对角线上的区域不掩码:

每个输入嵌入向量 h l ( i ) h^{(i)}_l hl(i)中的token是保留的[cls]标记,用于细胞嵌入;具有已知基因symbol token嵌入和基因表达值嵌入;需要预测表达值的未知基因。

3.掩码后的计算

因果掩码设计:在每次生成迭代中,scGPT预测一组新基因的表达值,而这些基因在下次迭代中会成为“已知基因”,用于下一步的注意力计算,即通过逐步预测非顺序的单细胞数据来实现下一个token预测。在训练过程中随机选择一部分基因作为未知基因,因此它们的表达值在输入中被忽略。

scGPT模型通过堆叠的Transformer块与上述的掩码注意力图预测这些未知基因的表达。在进行细胞prompts生成时,scGPT根据特定的细胞类型生成全基因组表达值。训练好的细胞嵌入被输入,放置在代表细胞类型条件的第一个位置。整个基因表达生成过程分为 K K K 次迭代步骤(即上图b中K=3步)。

例如,在一次迭代中 𝑖∈[1,2,…,𝐾],注意力掩码机制允许与前 0≤𝑖−1次迭代中预测的所有基因进行注意力计算。在每次迭代中,scGPT选择未知集合中预测置信度最高的前 1/K 的基因,以作为下次迭代中的已知基因。

直观地,这种工作流通过自回归方式简化了基因表达的生成,其中预测置信度最高的基因表达值首先生成,并用于帮助后续轮次的生成。

基因prompts生成以类似的方式进行。不同之处在于它以一组具有观察到的表达值的已知基因开始,而不是细胞嵌入。scGPT注意力掩码统一了已知基因编码过程和未知基因生成过程,也成为首个在非顺序数据中执行自回归生成的Transformer架构之一。

4.预训练的学习目标

预训练中的学习目标是通过估计未知基因的表达值来优化模型。这通过使用多层感知器(MLP)来实现,MLP的任务是根据模型的输入预测基因表达值,并通过计算预测值与真实值之间的均方误差来衡量模型的性能,并计算均方误差损失 𝐿:

其中,𝑢_unk 表示未知基因的输出位置集合,𝑥_𝑗^{(𝑖)}是要预测的实际基因表达值。∣⋅∣操作获取集合中的元素数量。

在训练过程中,基因prompts和细胞prompts生成这两种模式连续进行。对于给定细胞的输入基因标记,一部分基因被选为“未知”基因,它们的表达值被忽略。

  • 在基因prompts步骤中,模型的输入包括[cls]token嵌入、已知基因嵌入和未知基因嵌入。损失使用模型的输出计算。

  • 在细胞prompts步骤中,使用先前步骤输出的细胞嵌入(即 ℎ_𝑐^{(𝑖)}在细胞表示中)替换[cls]位置的嵌入,其他计算保持相同。

最后,两种步骤的损失值相加,并用于计算梯度以优化模型参数。整个过程是连续进行的,通过这种方式,模型不仅可以学习预测未知基因的表达值,还能根据细胞的上下文信息进行更精准的表达预测。

–微调目标–

1.基因表达预测(GEP)

这个微调目标与预训练中的目标类似,但应用于掩码位置。具体来说,对于每个输入细胞,一部分基因标记及其对应的表达值 𝑥^{(𝑖)} 会被随机掩码。scGPT通过优化准确预测这些掩码位置的表达值来达到微调目的。这一微调目标有助于模型有效编码数据集中基因间的共表达关系。

该目标通过最小化掩码位置的均方误差来实现,表示为 𝑀_{mask}。GEP的工作原理如下:这里x̃ ^{(i)} ∈ℕ^M,表示细胞 i的表达估计的行。如果有批次或者模态条件信息,用h{'(i)}_n来代替h{(i)}_n

GEP是一种通用的自监督微调目标,旨在预测基因表达值。在某些下游任务中,如扰动预测,模型需要预测基因扰动后的表达值,而不是原始值。研究将这种变体称为扰动GEP。保留上面方程中的MLP估算器,但使用扰动后基因表达作为目标值 𝑥^{(𝑖)}_j。在扰动GEP中,模型的任务是预测所有输入基因的扰动后表达。

2.用于细胞建模的基因表达预测

这个微调目标与GEP类似,但基于细胞表示h^{(i)}_c来预测基因表达值,从而促进细胞表示学习。对于每个细胞 𝑖中的基因 𝑗,我们创建一个查询向量 𝑞_𝑗,并使用 𝑞_𝑗和细胞表示的 h^{(i)}_c 参数化内积作为预测的表达值

3.弹性细胞相似性

这个微调目标通过使用相似性学习损失来增强细胞表示:

其中 sim代表余弦相似度函数,𝑖和 𝑖^′指代小批量中的两个细胞。此外,𝛽表示一个预定阈值,而 ECS是弹性细胞相似性。这个方法背后的思想是通过提升相似性高于 𝛽的配对,使其更加相似;相反,不相似的配对则被鼓励进一步分离。

4.通过反向传播来实现领域自适应

细胞表示学习受限于批次效应的存在,这些效应是由于测序技术引入的非生物学批次差异所导致的。为了缓解这个问题,scGPT使用了一个独立的MLP分类器来预测与每个输入细胞 ℎ_𝑐^{(𝑖)}相关的测序批次,并通过在模型内反转梯度来修改反向传播过程。

5.细胞类型分类

这个微调目标旨在利用学习到的细胞表示来注释单个细胞。scGPT使用一个独立的MLP分类器根据细胞表示 ℎ_𝑐^{(𝑖)} 预测细胞类型。这个微调目标通过预测的细胞类型概率与真实标签之间的交叉熵损失来优化。

–实操应用–

scGPT需要Python >= 3.7.13 and R >=3.6.1,而且scGPT已经上传在PyPI中,而且在HuggingFace中提供了运行预训练工作流程的初步支持
https://github.com/bowang-lab/scGPT/tree/integrate-huggingface-model

pip install scgpt "flash-attn<1.0.5" 

pip install wandb

git clone this-repo-url
cd scGPT
poetry install

预训练模型的列表,可以获取训练好的权重文件:

作者还提供了微调模型的练习代码文件
https://github.com/bowang-lab/scGPT/blob/main/examples/finetune_integration.py

https://scgpt.readthedocs.io/en/latest/introduction.html

–下游任务应用–

1.细胞类型注释

2.基因扰动响应和逆扰动的预测

3.多批次多模态数据整合

4.基因token嵌入分析

5.基于注意力的基因互作分析

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2102261.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024.9.3

#include <iostream> #include <cstring> using namespace std;class Stack { private:int len;int count 0;int *stack; public:Stack():len(10) //无参构造{stack new int[len];stack[len] {0};}Stack(int len):len(len) //有参构造{stac…

一文搞懂微服务架构之限流

前置知识 限流是通过限制住流量大小来保护系统&#xff0c;能够解决异常突发流量打崩系统的问题。例如常见的某个攻击者在攻击你维护的系统&#xff0c;那么限流就是极大程度上保护住你的系统。 算法 限流算法也可以像负载均衡算法那样&#xff0c;划分成静态算法和动态算法…

【软件测试专栏】测试分类篇

博客主页&#xff1a;Duck Bro 博客主页系列专栏&#xff1a;软件测试专栏关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ 测试分类篇 关键词&#xff1a;测试方法的概念、测试类型、先后顺…

花生壳二级域名的绑定测试

1、花生壳客户端的登录 左下角显示的就是你的外部IP。 2、新建映射 点击新建映射&#xff0c;就会进入一个管理页面&#xff0c;如下图&#xff1a; 3、可以通过域名访问网站了 就可以二级域名直接访问&#xff0c;在192.168.1.11:8080 上建立的Tomcat网站了&#xff0c;非常…

uniapp写的一个年月日时分秒时间选择功能

代码: <template><view><picker mode"multiSelector" :value"multiIndex" :range"multiRange" change"onMultiChange"><view class"picker">当前选择&#xff1a;{{ formattedDateTime }}</vie…

各业务领域相关方案

电商 电商系统的简单架构 电商系统的简单架构_电商交易平台 系统架构-CSDN博客 订单系统 美团团购订单系统优化记 vivo 全球商城&#xff1a;订单中心架构设计与实践 库存系统 电商库存系统的防超卖和高并发扣减方案 vivo全球商城&#xff1a;库存系统架构设计与实践 资金…

开篇_____何谓安卓机型“工程固件” 与其他固件的区别 作用

此系列博文将分析安卓系列机型与一些车机 wifi板子等工程固件的一些常识。从早期安卓1.0起始到目前的安卓15&#xff0c;一些厂家发布新机型的常规流程都是从工程机到量产的过程。在其中就需要调试各种参数以便后续的量产参数可以固定到最佳&#xff0c;工程固件由此诞生。 后…

30道python自动化测试面试题与答案汇总

对于机器学习算法工程师而言,Python是不可或缺的语言,它的优美与简洁令人无法自拔,下面这篇文章主要给大家介绍了关于30道python自动化测试面试题与答案汇总的相关资料,需要的朋友可以参考下 1、什么项目适合做自动化测试&#xff1f; 关键字&#xff1a;不变的、重复的、规范…

【Agent】Agent Q: Advanced Reasoning and Learning for Autonomous AI Agents

1、问题背景 传统的训练Agent方法是在静态数据集上进行监督预训练&#xff0c;这种方式对于要求Agent能够自主的在动态环境中可进行复杂决策的能力存在不足。例如&#xff0c;要求Agent在web导航等动态设置中执行复杂决策。 现有的方式是用高质量数据进行微调来增强Agent在动…

专业文件搜索工具 | UltraSearch Pro v4.4.1.1015 绿色特别版

大家好&#xff0c;今天电脑天空给大家推荐一款非常实用的文件搜索软件——UltraSearch Pro。这款软件在文件搜索领域有着出色的表现。 UltraSearch Pro 是一款专业的文件搜索工具&#xff0c;以其快速、全面、精准的搜索能力赢得了用户的一致好评。无论是本地硬盘、网络驱动器…

【IPV6从入门到起飞】2-2 获取你的IPV6(Teredo隧道)

【IPV6从入门到起飞】2-2 获取你的IPV6&#xff08;Teredo隧道&#xff09; 1 打工人的忧伤2 Teredo介绍2.1 背景2.2 工作原理 3 Linux 服务器获取IPV63.1 安装3.2 设置开机自启动和启动3.3 开放防火墙 UDP 35443.4 查看IPV6以及ping包测试3.5 修改Teredo服务器3.6 重启服务3.7…

系统思考—关键决策

结‮影构‬响行为&#xff0c;精‮决准‬策创造价值&#xff01; 最‮身近‬边很多‮伙小‬伴找我“助力”&#xff0c;父‮也母‬经常发‮息信‬让我帮忙&#xff0c;忍‮住不‬研究了一下‮些这‬助力活动的“玩法”。说实话&#xff0c;这‮设种‬计从‮构结‬上真‮很的‬…

论文精读:Dirac半金属反常能斯特效应设计

摘要节选&#xff1a; 反常能思特效应&#xff08;ANE&#xff09;产生垂直于温度梯度的横向电压。与用于能量转换的纵向热电相比&#xff0c;它具有电子和热输运解耦、更高的灵活性和更简单的横向结构等优点。 在这项工作中&#xff0c;从理论上证明了一对Dirac节点在Zeeman…

用Leangoo领歌敏捷工具进行迭代管理的实践分享Sprint Backlog

在敏捷开发中&#xff0c;迭代管理是确保项目持续推进、不断优化的重要环节。有效的迭代管理能够帮助团队快速响应变化&#xff0c;持续交付高质量产品。 Leangoo是一款免费的敏捷项目管理工具&#xff0c;为团队提供了直观、高效的看板管理方式来管理迭代过程。本文将探讨如何…

python 正则表达式“.*”和“.*? ”的区别

“.*”和“.*? ”的区别 点号表示任意非换行符的字符&#xff0c;星号表示匹配它前面的字符0次或者任意多次。所以“.*”表示匹配一串任意长度的字符串任意次。这个时候必须在“.*”的前后加其他的符号来限定范围&#xff0c;否则得到的结果就是原来的整个字符串。 “.*? &…

嵌入式开发者必看:如何选择最适合的嵌入式开发数据管理工具?SVN/ClearCase VS. Perforce Helix Core(内含研讨会视频

嵌入式软件开发比以往任何时候都更为复杂&#xff0c;涉及的文件、团队和依赖项都在不断增加。 处理如此大规模的开发工作&#xff0c;需要一个强大的数据管理系统——该系统能够处理大量的大型文件和元数据&#xff0c;简化跨多个分支和用户的变更流程&#xff0c;并支持不断…

初级测评师能力要求

管理和技术的一致性、同一控制点不同层面提出的不同要求 1.恶意代码 &#xff08;1&#xff09;安全区域边界&#xff1a;恶意代码和垃圾邮件防范 a&#xff09; 应在关键网络节点处对恶意代码进行检测和清除&#xff0c;并维护恶意代码防护机制的升级和更新 b&#xff09; 应…

【C++ 面试 - STL】每日 3 题(五)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/fYaBd &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 C 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏&…

《JavaEE进阶》----10.<SpringMVC应用分层:【三层架构】>

本篇博客我们主要讲解 1.应用的分层&#xff1a;三层架构 2.Spring MVC和三层架构的区别和联系 3.软件设计原则&#xff1a;高内聚低耦合 4.应用分层的好处 5.通过应用分层后的代码示例 一、三层架构简介 阿里开发手册中,关于工程结构部分,定义了常见工程的应用分层结构: 上图…

echarts处理y轴最大小值根据数据动态处理、分割数和是否从0开始

方法 /*** 取整* param value 输入值* param scale 保留位数 正数是整数位&#xff0c;负数是小数位* param isUpper true 向上取整 false向下取整* returns*/ const numberFix (value, scale, isUpper) > {const moreThanZero value > 0value Math.abs(value)if (s…