通讯作者:Christina Theodoris (2021 Geneformer的一作,现在应该是独立的PI了)

https://gladstone.org/people/christina-theodoris

image-20250118123335727

背景知识

算法角度

  • Multi-task learning 多任务学习
    • 设置多个任务,每个任务会计算一个独立的损失函数,总损失函数则是所有任务损失的加权和。
    • 训练过程中,共享预训练模型参数,微调模型参数各不相同。
    • multi-task learning enables context-specific disease modeling that can yield contextual predictions of candidate therapeutic targets for human disease.
  • Continual learning 持续学习

性能角度- QLoRA: Quantization | Low-Rank Adapters

1. 第一次预训练

  • 数据量:~103M human scRNA-seq

    • 经质控筛选(e.g. 去除肿瘤细胞)后,保留95M
    • 词汇表:20275 (1) 20271 protein coding genes; (2) 4 special tokens (PAD, MASK, CLS, EOS )
  • 模型框架(gf-12L-95M-i4096)

    • 最大序列长度:4096,此外在有效序列的前后增加了CLS与EOS
    • 注意力层:12 | Embedding size: 512 | Heads: 8
    • Masked ratio: 15% | 最大学习率:5e-4 (warmup steps 5000) | Batch: 1 (梯度累计:4)
    • 44 hours, 8 H100 GPUs
    • 还有4个版本:GF-6L-30M-I2048, GF-12L-30M-I2048, GF-12L-95M-I2048, GF-20L95M-I4096
  • Zero-shot性能:基于上述预训练模型获得cell/gene的zero-shot Embedding(不更新模型参数),相比Geneformer v1在多个微调任务中表现具有优势。

    • 值得一提的是:这里的cell embedding是来自倒数第二层的CLS Embedding,后面的MTL以及ISP,都是来自最后一层的CLS Embedding
    • 在Geneformer V1版本中, Cell Embedding是所有基因Embedding的均值。
    • 微调性能结果是基于25次超参数寻参得到的(例如不同的学习率等)

image-20250118090649761

  • 模型量化:使用量化模型(e.g. 4-bit, rank16)可以使得微调任务性能影响不大的情况下,显著降低了成本(时间,内存)。
    • full fine-tuning:超参数包括被冻结的层数,top 0/7/14 layers
    • quantized fine-tuning:同上
    • zero-shot (指上面的结果)

image-20250118100823168

2. 多任务微调学习

  • 背景:疾病的发生过程可能与多个角度相关(例如cell type, tissue, developmental stage..)。基于单任务学习(例如Geneformer v1基于细胞类型)得到的cell/gene表示,缺少全面性,尤其是应用在ISP
  • 数据量:~43M annotated cells from CELLxGENE,
    • 71 cell types,38 tissues,68 diseases, 3 developmental stages
  • MTL设计:5个细胞分类微调任务的损失之和作为总损失更新模型
    • cell types/tissue/disease types/disease vs normal/developmental stage
    • 在每个任务的类别存在不平衡时,采用多种采样策略【详见method】
    • 最终发现在5个任务中均表现良好,此时得到的cell Embedding更具有代表性(用于后续的ISP)

image-20250118102441154

  • 基于微调后的MTL模型,采用量化技术得到的cell embedding信息没有丢失,同时推理成本较低。
    • 在ISP方面,量化模型的基因模拟敲除所发生的shift趋势相同,同样成本降低

3. 第二次预训练及对应微调

  • Domain-specific continual learning:肿瘤单细胞特征
  • 数据量: ~14M cells from cancer studies
    • 包括1%的第一阶段训练数据,以避免 “catastrophic forgetting”
  • 训练方式与第一阶段相同,采用15% masked gene的自监督学习。其中包括三种学习率的设置方式。
    • 如下图b,左侧绿线表示第一次预训练的学习率变;而右侧三条线便是3种学习率设置方案的continual learning
    • 结合图c,发现其中最优的一种方案(是与第一次的预训练最大学习率相同)
image-20250118105722186
  • 接下来继续对上述预训练模型,设计一个MTL细胞分类微调任务:
    • 在结直肠癌单细胞数据下的3个任务:MMR status/cell type/cell subtype
    • 发现在cell subtype存在上皮细胞的亚型注释时,二次预训练模型表现更好

4. 基于TME的ISP实验

基于上述的微调模型,进行了两次结直肠癌相关的ISP实验

TASK-1:

  • In silico overexpression
  • from epithelial cancer cells to normal epithelial (下图的e-f)

TASK-2:

  • In silico overexpression
  • from quiescent T cell to activated T cell state
  • 与已报道的CRISPR结果相一致

image-20250118110850234