Li's Bioinfo-Blog

📖 Bioinformatics Data Analysis: Pipelines, Toolkits, and More

Querying, Managing, and Saving PyTorch Model Modules and Parameters

1. Example Model

A two-layer MLP:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Example usage
input_size = 256    # Number of input features
hidden_size = 128   # Number of neurons in the hidden layer
output_size = 2     # Number of output classes
model = SimpleMLP(input_size, hidden_size, output_size)
```

2. Querying Component Modules

Recursively traverse all layers of the model, including submodules nested inside other layers:

```python
# torch.nn.Module class: model.modules()
for module in model.modules():
    print(f"Module: {module}")

for name, module in model.named_modules():
    print(f"{name}: {module}")
# fc1: Linear(in_features=256, out_features=128, bias=True)
# fc2: Linear(in_features=128, out_features=2, bias=True)
```

3. Querying Model Parameters

```python
# torch.nn.Parameter class
for param in model.parameters():
    print(f"{param.shape}")

for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")
# fc1.weight: torch.Size([128, 256])
# fc1.bias: torch.Size([128])
# fc2.weight: torch.Size([2, 128])
# fc2.bias: torch.Size([2])

for name, param in model.fc1.named_parameters():
    print(f"{name}: {param.shape}")
# weight: torch.Size([128, 256])
# bias: torch.Size([128])

# Total number of model parameters
total_parameters = sum(p.numel() for p in model.parameters())

# Inspect the parameters of a specific layer
param = next(iter(model.fc1.parameters()))
type(param)          # torch.nn.parameter.Parameter
param.shape          # torch.Size([128, 256])
param.numel()        # 32768
param.requires_grad  # True

# Freeze the parameter, i.e. exclude this module's parameter from updates
param.requires_grad = False
```

4. Saving and Loading Model (Parameters)

```python
type(model.state_dict())

# save
torch.save(model.state_dict(), 'model.pth')  # a .pt suffix also works

# load
model.load_state_dict(torch.load('model.pth'))
pretrained_params = torch.load(model_pt, map_location='cuda:2')  # model_pt: checkpoint path
```

A practical example of a loading function:

```python
def load_pretrained(
    model: torch.nn.Module,
    pretrained_params: dict = None,
    strict: bool = False,
    prefix: list = None,
    use_flash_attn=True,
    verbose: bool = True,
) -> torch.nn.Module:
    # Rename specific parameter keys
    if not use_flash_attn:
        pretrained_params = {
            k.replace("Wqkv.", "in_proj_"): v
            for k, v in pretrained_params.items()
        }

    # Only load parameters whose keys start with the given prefixes
    if prefix is not None and len(prefix) > 0:
        if isinstance(prefix, str):
            prefix = [prefix]
        pretrained_params = {
            k: v
            for k, v in pretrained_params.items()
            if any(k.startswith(p) for p in prefix)
        }

    model_dict = model.state_dict()

    # Strict loading: every parameter must match
    if strict:
        if verbose:
            for k, v in pretrained_params.items():
                print(f"Loading parameter {k} with shape {v.shape}")
        model_dict.update(pretrained_params)
        model.load_state_dict(model_dict)
    # Partial loading: only load the parameters that match in both key name and value shape
    else:
        if verbose:
            for k, v in pretrained_params.items():
                if k in model_dict and v.shape == model_dict[k].shape:
                    print(f"Loading parameter {k} with shape {v.shape}")
        pretrained_params = {
            k: v
            for k, v in pretrained_params.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        model_dict.update(pretrained_params)
        model.load_state_dict(model_dict)

    return model
```
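As a quick illustration, a hypothetical call to this helper might look like the following sketch; the checkpoint file and prefix are made up for the example:

```python
# Hypothetical usage of load_pretrained (file name and prefix are illustrative)
pretrained_params = torch.load("model.pth", map_location="cpu")
model = load_pretrained(
    model,
    pretrained_params=pretrained_params,
    strict=False,         # partial loading: skip mismatched keys/shapes
    prefix=["fc1"],       # only load parameters under fc1
    use_flash_attn=True,  # keep key names unchanged
)
```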

Create: 2024-12-14 | Update: 2024-12-14 | Words: 670 | 2 min | Lishensuo

Hugging Face (4): BERT Model, Collator, and Trainer

1. Collator Data Processing

Purpose: normalize and batch the raw samples of a dataset for the subsequent forward pass.

```python
# Start from a dataset (sequences may have different lengths)
Dataset({
    features: ['input_ids'],
    num_rows: 5
})

# End with an encoded batch input (BatchEncoding format)
{'input_ids': tensor([[350, 241, 345, 705, 695,   1, 427, 645,  99, 943,   0,   0,   0,   0],
                      [196, 464, 546, 626, 413,   1, 973,  98, 824,   1, 410,   0,   0,   0],
                      [475, 665,   1, 164, 306, 788,  53, 562, 232, 216, 252, 990,   0,   0],
                      [  1, 966, 734, 897, 171, 357, 217, 850, 529, 895, 728, 234, 799,   0],
                      [713,  76,   1, 428, 913, 890, 143, 992, 832, 963, 555,  18, 354, 455]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
                           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 'labels': tensor([[-100, -100, -100, -100, -100,  716, -100, -100, -100, -100, -100, -100, -100, -100],
                   [-100, -100, -100, -100, -100,  665, -100, -100, -100,  686, -100, -100, -100, -100],
                   [-100, -100,   56, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100],
                   [ 619,  966, -100, -100, -100,  357, -100, -100, -100, -100, -100, -100, -100, -100],
                   [-100, -100,  218, -100, -100, -100, -100, -100, -100,  963, -100, -100, -100, -100]])}
```

Common key fields include: ...
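The batch above matches what a masked-language-modeling collator produces (padding, attention_mask, labels of -100 everywhere except masked positions). A minimal sketch with the transformers API; the checkpoint name is an assumption:

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

# Assumed checkpoint for illustration; any BERT-style tokenizer works
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Pads variable-length sequences, builds attention_mask, and sets
# labels to -100 everywhere except the randomly masked positions
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

batch = collator([tokenizer("hello world"), tokenizer("a longer example sentence")])
print(batch["input_ids"].shape, batch["labels"].shape)
```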

Create: 2025-01-24 | Update: 2025-01-24 | Words: 3796 | 8 min | Lishensuo

Faiss Vector Database Similarity Search

Faiss (Facebook AI Similarity Search) is a library for efficient similarity search over vectors, and is particularly well suited to large-scale vector data. Its core function: given a vector database, find the Top-K most similar vectors for the current query vector(s). It also provides modules for k-means clustering, PCA dimensionality reduction, vector quantization, and more. Official documentation: https://github.com/facebookresearch/faiss/wiki

```bash
# Installation: the CPU and GPU versions cannot coexist in one environment
pip install faiss-cpu
pip install faiss-gpu
# The CPU version is used below
```

1. Top-K Similarity Search

Faiss works by building an Index object, a data structure that enables fast similarity search. ...
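As a concrete illustration of the Index workflow described in the excerpt, a minimal CPU sketch with random data (the dimensionality and sizes are arbitrary):

```python
import faiss
import numpy as np

d = 64                                                # vector dimensionality
xb = np.random.random((10_000, d)).astype("float32")  # database vectors
xq = np.random.random((5, d)).astype("float32")       # query vectors

index = faiss.IndexFlatL2(d)   # exact (brute-force) L2-distance index
index.add(xb)                  # add the database vectors
D, I = index.search(xq, 4)     # distances and indices of the top-4 neighbors
print(I.shape)                 # (5, 4)
```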

Create: 2025-04-18 | Update: 2025-04-18 | Words: 1420 | 3 min | Lishensuo

Graph Neural Networks with DGL (01): DGL Basics

homogeneous graph: one node type, one edge type. heterogeneous graph: multiple node types, multiple edge types. bipartite graph: two node types, one edge type. 1. DGLGraph Structure 1.1 Homogeneous graphs: the Cora citation graph, a homogeneous graph ...
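A minimal sketch of constructing these graph types with DGL's actual constructors; the node and edge choices are arbitrary:

```python
import dgl
import torch

# Homogeneous graph: one node type, one edge type (edges 0->1, 1->2, 2->3)
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
print(g.num_nodes(), g.num_edges())  # 4 3

# Two node types and one edge type, i.e. a bipartite heterogeneous graph
hg = dgl.heterograph({
    ("drug", "interacts", "gene"): (torch.tensor([0, 1]), torch.tensor([1, 2]))
})
print(hg.ntypes, hg.etypes)  # ['drug', 'gene'] ['interacts']
```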

Create: 2022-07-31 | Update: 2022-07-31 | Words: 1335 | 3 min | Lishensuo

Graph Neural Networks with DGL (02): Node Classification on Homogeneous Graphs

```python
import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
```

0. Prediction Task and Data

Predict which category each paper belongs to, i.e. a multi-class classification problem. ...
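For context, a minimal sketch of the kind of two-layer GCN such a node-classification setup typically uses, built on dglnn.GraphConv; the layer sizes are assumptions:

```python
# Two-layer GCN producing per-node class logits,
# assuming a homogeneous graph g with node feature matrix x
class GCN(nn.Module):
    def __init__(self, in_feats, hid_feats, num_classes):
        super().__init__()
        self.conv1 = dglnn.GraphConv(in_feats, hid_feats)
        self.conv2 = dglnn.GraphConv(hid_feats, num_classes)

    def forward(self, g, x):
        h = F.relu(self.conv1(g, x))
        return self.conv2(g, h)
```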

Create: 2022-08-28 | Update: 2022-08-28 | Words: 1690 | 4 min | Lishensuo

Graph Neural Networks with DGL (03): Edge Regression on Homogeneous Graphs

(1) Edge regression: quantitative (regression) or qualitative (classification) prediction for edges that already exist in the graph, using the edge embeddings learned during GNN training. (2) An edge embedding is usually computed from its two endpoint nodes, most often via dot product or concatenation; the node embeddings themselves are updated as before. ...
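A minimal sketch of the dot-product scoring described above, using DGL's built-in u_dot_v message function; the feature name "h" and the random graph are assumptions:

```python
import dgl
import dgl.function as fn
import torch

g = dgl.rand_graph(100, 1000)        # random graph: 100 nodes, 1000 edges
g.ndata["h"] = torch.randn(100, 16)  # node embeddings from some GNN layer

# Edge score = dot product of the two endpoint embeddings
g.apply_edges(fn.u_dot_v("h", "h", "score"))
print(g.edata["score"].shape)        # torch.Size([1000, 1])
```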

Create: 2022-08-28 | Update: 2022-08-28 | Words: 2076 | 5 min | Lishensuo

Graph Neural Networks with DGL (03): Edge Prediction on Homogeneous Graphs

Edge prediction asks whether an edge is likely to exist between two nodes, which can be framed as binary classification. During training, the node representations are first updated as usual, and edge features are then computed: edges known to exist in the graph serve as positive edges with label 1, while randomly sampled node pairs with no edge between them serve as negative edges with label 0. 0. Prediction Data and Task: suppose 1000 interactions are known among 100 drugs, and each drug node has 50 features. ...
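A minimal sketch of assembling positive and negative edges under the excerpt's setup, with random data; this naive sampler ignores the small chance of drawing an existing edge as a negative:

```python
import dgl
import torch

# Toy setup matching the excerpt: 100 drugs, 1000 known interactions, 50 features
g = dgl.rand_graph(100, 1000)
g.ndata["feat"] = torch.randn(100, 50)

# Positive edges: the edges already in the graph (label 1)
pos_src, pos_dst = g.edges()

# Negative edges: uniformly sampled node pairs assumed absent (label 0)
neg_src = torch.randint(0, g.num_nodes(), (g.num_edges(),))
neg_dst = torch.randint(0, g.num_nodes(), (g.num_edges(),))

labels = torch.cat([torch.ones(g.num_edges()), torch.zeros(g.num_edges())])
```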

Create: 2022-08-28 | Update: 2022-08-28 | Words: 2082 | 5 min | Lishensuo