文献: scGPT: toward building a foundation model for single-cell multi-omics using generative AI

时间:2024 Feb. (Published)

期刊:Nature Method

DOI:https://doi.org/10.1038/s41592-024-02201-0

Github:https://github.com/bowang-lab/scGPT

1. 简介

1.1 关于作者

通讯作者:Bo Wang,多伦多大学

image-20240928131111066

1.2 文献概述

  • 受启发于GPT等基于transformer的大模型,Wang团队基于大量单细胞转录组数据,构建了scGPT foundation model,可用于多种常见的单细胞下游分析任务,性能优于其它已有的分析工具。

  • 核心假设:

    • 在NLP领域中,一个句子由多个单词组成。经过预训练的自监督学习,foundation model可以提取输入句子中每个单词(或者句子整体)的高级表示(Embedding);
    • 而在scRNA-seq中,一个细胞可以认为由所有基因的不同程度表达定义。经过预训练的自监督学习,foundation model可以提取输入单细胞表达数据的每个基因(或者细胞整体)的高级表示(Embedding)。
  • 在Pre-trained foundation models基础上,将提取的基因/细胞Embedding在具体的单细胞下游分析任务中进行二次微调(fine-tune),发挥大数据生成的强大优势。

image-20240928131642668

2. Pretrain预训练

Pre-train foundation model的核心功能就是以一个单细胞的数据为输入,将提取的Gene/Cell Embedding为输出。

2.0 数据规模

  • 文章从CELLxGENE网站收集了33M个正常细胞的scRNA-seq数据用于训练Foundation model
  • 其中来自Brain、Blood等器官的细胞数最多

image-20240929091250859

2.1 初始化输入

  • scGPT的原始输入数据通常为N个细胞,G个基因的Count表达矩阵。此时,需要对每个细胞的基因数据进行初始化处理,作为scGPT的标准输入;
  • 概括来说,分别从每个基因的三个角度进行D维嵌入编码 (1×D),然后再进行矩阵加法(仍为1×D),最终将所有基因合并得到(M×D)的细胞特征矩阵。

因为细胞的基因会根据情况进行选择,并且考虑到<cls>等特殊词元,所以这里并不是G×D

image-20240928150405475

(1)Gene token

  • 所有基因(词元)构成的词表中,每个基因都有一个对应的整数标识符(Integer Identifier)
  • 此外还有其它特殊的词元,例如<cls>, <pad>等
    • <cls>词元通常放在细胞M个词元中的第一个,不表示特定基因,而用于cell representation
    • <pad>词元用于补长至固定长度M(e.g. 有些情况不考虑表达值为0的基因时)
  • 采用PyTorch embedding layer将投射成D维的Embedding

(2)Gene expression

  • 单细胞的count表达值在不同测序背景的条件下,不具有可比性。
  • scGPT采用bin分段处理,将单个细胞i的非零基因表达值,分为k个bins,分数分别是1..k。
    • 例如,若基因bin打分为k,表示其表达值处于最高的bin范围内;
    • 每个细胞分bin的阈值标准都会不尽相同。

image-20240928144527642

  • 采用MLP,转换为D维的Embedding

(3)Condition token

  • 记录基因额外的Condition token,例如是否为perturbated genes
  • 采用PyTorch embedding layer将投射成D维的Embedding

综上,一个细胞(i)的初始化输入h (M × D)的计算方式如下图所示。其中第一列通常表示为<cls>作为Cell representation;其它列(除<pad>等特殊字符外,均表示Gene Embedding。

image-20240928143902797

2.2 自监督训练

(1)Transformer块

  • Foundation model主要由 12 (l)个Transformer块组成,每个块则均采用了多头(8)自注意力机制;

image-20240928152927329

  • 由于这里的M通常较大,即一个细胞考虑数千上万个基因(可以理解为特别长的句子),scGPT采用了FlashAttention算法用以加速自注意力计算。

    此外,对于每个细胞,只有表达值非0的基因参与预训练过程,以提高速度。

  • 作者也推荐了其它高效计算的Transformer变体,包括linear complexity (Linformer),Kernelized Self-Attention (KSA)。

image-20240928151221844

(2)自监督任务

  • scGPT自监督任务的核心是预测掩码(masked)基因的表达值水平,采用MSE损失函数。
  • 具体分为如下两个子任务:
    • Gene-prompt:基于已知表达值的Gene Embedding,预测未知表达值Gene的expression value;
    • Cell-prompt:基于<cls> cell representation预测全基因的expression value。

image-20240928154530333

  • 然后,将上述两种模式的损失loss相加后,再计算梯度并更新模型参数。

(3)Masked Attention

  • 自注意计算过程为:将masked gene Embedding( without expression) 作为query,计算与其它known gene Embedding(也包括自己)的注意力权重后,再进行加权运算得到输出;
  • 与GPT模型所处理的文本句不同之处在于:the non-sequential nature of the genes in one cell. 为此作者设计了masked attention注意力计算方式;

image-20240928155607273

  • 如上公式,在不考虑Amask的情况下,为标准的自注意力计算方式。而Amask可参考如下公式,以及下图左A(行表示query,列表示key)理解。每一个query(i)计算与其他词元注意力(包括与它自己)时:
    • 若key (j)不是unknown expression gene,则为0;
    • 若i = j时,且j是unknown expression gene,则为0(自己与自己的注意力计算)
    • 其它情况下,则为负无穷(对应图中的深蓝色单元格)
    • 每一行(i)表示一个query词元与该细胞所有词元的注意力计算。
image-20240928160244274
  • a = 0表示不产生任何影响,a = -inf表示将query(i)对于key(j)的注意力置换为-inf,经softmax转换后则变为0。

The rule of thumb for scGPT attention masking is to only allow attention computation between embeddings of the ‘known genes’ and the query gene itself.

image-20240928160540771
  • 在一个cell的全部masked gene表达值预测的过程中,采用多轮(k)迭代预测的思路(如上图右,类比GPT的自回归训练)
    • 在第一轮中,对于所有masked token的预测,将其中1/k个high prediction confidence的token标注为known genes。
    • 然后在新一轮迭代中,重复上一步骤,直至预测出所有的masked token。
    • 值得注意的是,文章并没有说明prediction confidence是如何计算的。个人理解本质还是基于与真实表达值之间的误差。

In each generation iteration, scGPT predicts the gene expression values of a new set of genes, and these genes in turn become the ‘known genes’ in the next iteration for attention computation.

2.3 多批次与多模态表示

  • 在多批次、多组学合并下游任务中,需要额外的token Embedding以表示必要的批次和组学信息,供模型学习;
  • 但是scGPT在Transformer的pre-train过程中,并未加入相关信息。而是在foundation model的输出结果中再进行拼接操作,即模型在Pre-train后,Fine-tune前引入批次和模态信息,以希望在微调过程中学习到相关信息。
  • 这样做的主要原因是:如果批次和模态信息在输入阶段被引入,Transformer的自注意力机制可能会过度关注同一模态或批次内的特征,导致模型忽略跨模态或批次的重要关联。

This is to prevent the transformer from amplifying the attention within features of the same modalities while underestimating those of different modalities.

  • modality tokens (tm):表示词元对应的feature是Gene/Protein/Peak中的哪一种;
  • batch tokens (tb):表示token是否来自一个批次,通常对于细胞来说的。因此一个细胞的batch token都是一样的。

image-20240924205846107

  • 如下公式,表示同时存在多批次以及多组学的情况

image-20240924211913127

  • 如下公式,表示同种组学,存在多批次的情况

image-20240924212230354

3. Fine-tune微调

3.1 细胞类型注释

(1)Method

  • 对于Foundation model提取的每个细胞Cell Embedding (<cls>) 构建一个MLP分类器,用以预测细胞的类别,并使用交叉熵作为损失函数。

  • 首先使用一个标注细胞类型的数据集(reference set)进行Fine-tune,然后再使用一个Held-out数据集验证

  • 输入数据前的预处理:

    • common set between foundation model and the input data
    • Gene expression: normalization → log1 → bin
    • All gene as input (include zero expression)

(2)Result

  • 在Human pancreas (胰腺) dataset

    • 对于每种细胞的预测Precision都达到0.8以上 (预测为该细胞类型中,实际为该细胞类型的比例)
  • 以Human immune cells进行fine-tune,预测多发性硬化症(MS)的细胞类型

    • 平均准确率Accuracy可以达到0.85左右
  • 使用6种肿瘤细胞类型作为fine-tune,预测3种其它肿瘤的细胞类型

    • 在肿瘤微环境细胞类型预测同样表现良好
  • 与TOSICA、scBERT模型进行了比较,均表现出一定的优势

image-20240926100118451

此外,作者还提出了Reference mapping注释方法:在Zero/Fine-tune模型时,获得标签细胞类型的Cell Embedding。然后再计算出Query cells的Embedding。最后据此,计算出每个Query cell最相近的Reference cells,从而注释其细胞类型(KNN)。

3.2 基因扰动预测

(1)Method

  • 对于Foundation model提取的每个细胞的Gene Embedding,将一部分基因的表达值掩码,从而进行Fine-tune微调训练,并以MSE作为损失函数;
image-20240928201056193
  • 在Perturbation prediction task:
    • Input: Gene expression为Control cell expression, Condition token标注相应位置Gene是否被Perturbate
    • Output: Post-perturbation expression
  • 输入数据前的预处理:
    • only select HVGs for training
    • Log1p expression instead of binned values (output相同)

(2)Result

  • 分别基于三个单细胞CRISPR数据集,比较了scGPT与Linear、GEARS模型的性能表现
    • Adamson: 87 one-gene perturbations;
    • Replogle: 2823 one-gene perturbations;
    • Norman: 131 two-gene and 105 one-gene perturbations
  • 对于每个数据集,使用一部分的扰动数据进行fine-tune,再对其余unseen gene的扰动数据作为test
  • 评价指标为计算预测与真实的post-perturbation expression的相关性 (基于全部基因,或者是前20个影响最显著的基因)
    • 结果发现scGPT模型相比于其它两种模型,提高了5–20%

image-20240928203804558

(3)In silicon reverse perturbation prediction

  • 简单理解,根据perturbation expression的结果反向预测是最有可能哪个基因被扰动;
  • 参考作者在Github issue (https://github.com/bowang-lab/scGPT/issues/87)的解答:
    • 其Fine-tune步骤其实与上面Perturbation prediction task基本一致;
    • 使用数据集的一部分进行微调后,再预测所有相关基因扰动的表达结果;
    • 最后使用真实的Query perturbation上述的预测结果进行KNN关联分析。

image-20240928205214389

The model was fine-tuned in the same way as in the “forward” perturbation prediction. It used a subset of the dataset as we illustrated in Figure 3F. The reverse perturbation task utilized the model in a different way. To summarize, the result cell states of all possible perturbations were predicted by the fine-tuned model, and then an actual sequenced cell state can query all the predicted cell states in a nearest neighbor search manner, so that the retrieved neighbors indicate the possible origin perturbations.

3.3 多批次/组学整合

(1)Method(Multi-batch)

  • 核心目标:对于多批次的scRNA-seq数据,优化不同批次中每个细胞的cell representation

    • correct batch effects while preserving biological variance
  • 输入数据的预处理

    • Common set between foundation model and the input data

    • Gene expression: normalization–log1–bin

    • All genes as input (include zero expression)

    • 如2.3中所述,在微调前,还需要在Foundation model的输出结果中,补充细胞的批次信息。

  • 具体在Fine-tune训练中,设计了对多个目标函数进行损失计算,以共同用于模型优化

1)GEP:基因表达预测,参考3.2;

2)GEPC, Gene expression prediction for cell modeling,即为了优化Cell Embedding(hc),而预测基因表达。

image-20240928212510674

3)ECS, Elastic cell similarity:基于一个mini-batch的两个细胞的相似度计算损失函数。使得高于某个阈值的两个细胞更相似,低于的则更远离。

image-20240928213604728

4)Domain adaptation via reverse back propagation:主要作用是使得生成的cell representation表示经过MLP分类器无法预测其正确的batch信息(类似于GAN 对抗神经网络)。本质上还是建立MLP预测Batch的分类器,但是计算梯度后,进行反向更新参数。

(2)Result

  • 基于三个数据集,比较了scGPT与scVI,Seurat,Harmony三个工具去批次的效果

    • COVID-19: 18 batches

    • PBMC-10K: 2 batches

    • Perirhinal cortex: 2 batches

  • 评价指标:基于三个指标(NMI/ARI/ASW)的AvgBIO,值越高表明去批次效果越好。

    • scGPT均优于其它三种工具的效果。

image-20240926125712463

(3)Multi-omics integration

Method

  • 目标是使得相同细胞类型的不同组学数据能够有相似的Cell Embedding,从而在聚类时比较接近。
  • scGPT主要考虑了3种组学,分别是scRNA-seq、scATAC-seq,Single-cell proteomics。对此,有两种数据形式:
    • paired setting:一群细胞同时测多种组学
    • mosaic setting:一群细胞测一种组学,另一群测另一组。
  • 对于scRNA-seq可以直接继承第2节训练的Foundation model,对于其它两组需要重头训练。
  • 在Pre-train之后,如2.3小节需要拼接表示多模态的token(如果涉及多批次,也要添加batch token)
  • Fine-tune 目标函数包括GEP,GEPC(如果涉及多批次,还需要添加DAR)

Result

  • 对于Multiome PBMC (paired)数据集,与scGLUE,Seurat进行了比较,如下图所示

    • RNA-seq、ATAC
  • 对于Bone marrow mononuclear cells (paired)数据集,与Seurat进行了比较

    • RNA-seq、Protein
    • 9w个细胞,12个donor(multiple batches),48种细胞类型
  • 对于ASAP human PBMC (mosaic)数据集,与scMoMat进行了比较

    • 4个批次,3种组学

image-20240926132021275

3.4 基因调控网络构建

(1)Method

  • 本质上是基于Foundation model或者Fine-tune model提取的Gene Embedding计算两两基因间的相关性,用来构建gene similarity network。

    • Zero-shot setting模式下,直接使用Foundation model输出的Gene Embedding计算;

    • Fine-tuned setting模式下,使用特定数据集微调后输出的Gene Embedding计算

      使用integration task进行微调,以学习特定数据集相关的Gene Embedding (下同)

  • 基于注意力机制的target gene鉴定:Attention map可以反映出基因之间的相互影响。其中,每一列(column),表示这个gene(列名)对所有query gene的influence。

  • 使用perturbation datasets可以帮助推测perturbating gene的target,如下图所示

    • 首先,基于control cell得到的attention map;
    • 然后,基于perturbation expression得到的attention map;
    • 计算上述二者对的差值,可以推测特定基因干扰前后,影响最大的gene

image-20240926134630079

(2)Result

  • 文章首先对zero-shot模式进行了探索,然后在对一个human immune dataset进行了Fine-tune后构建GRN,均发现了具有生物学意义的调控网络以及Gene cluster (Leiden),并进行了通路富集分析等。

image-20240929081715151

  • 接下来,作者基于Pre-train scGPT blood model,使用了Adamson CRISPR数据集(87 CRISPR inference on leukemia cells)进行了微调。
  • 例如下图,通过比较DDIT3扰动前后的difference attention score, 鉴定了其影响最大的的Top20/100个基因。

image-20240929082629291

4. 模型影响因素

  • 综上,scGPT Foundation model在处理下游分析任务时,相比于其它已有的单细胞工具均表现出明显的优势;
  • 最后,文章讨论了两个可能影响Pre-train model性能表现的因素。

4.1 训练样本量

  • 首先,文章研究了训练样本量对于预训练模型的影响;
  • 如下图,结果发现数据量最多时(从30K到33M),模型在多种下游任务中的表现越好;
  • As larger and more diverse datasets become available, we can anticipate further improvements in model performance, advancing our understanding of cellular processes.

image-20240929085432250

4.2 训练样本类型

  • 文章进一步研究了细胞来源对于模型的影响;
  • 在COVID-19数据集的多批次合并任务中,结果发现:
    • 使用全部33M数据与仅使用Blood数据的Foundation model的性能比较接近;
    • Lung来源的细胞量尽管只有2.1M,远少于Brain细胞量(13.2M),但前者性能明显表现更优。
  • This emphasizes the importance of aligning the cellular context in pretraining with the target dataset for superior results in downstream tasks.

image-20240929085123045

后续将参考scGPT工具的教程手册及源代码,学习其构建细节与数据处理方式等。