PyTorch 模型：组成模块与参数的查询、管理与保存
# 1. Example model: a neural network with two fully connected layers (MLP)
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleMLP(nn.Module):
    """Two-layer perceptron: ``fc1`` -> ReLU -> ``fc2``.

    Args:
        input_size: number of input features (size of the input's last dimension).
        hidden_size: number of neurons in the hidden layer.
        output_size: number of outputs, e.g. output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the raw outputs of ``fc2`` (no final activation) for input ``x``."""
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Example usage
input_size = 256  # Number of input features
hidden_size = 128  # Number of neurons in the hidden layer
output_size = 2  # Number of output classes
model = SimpleMLP(input_size, hidden_size, output_size)

# 2. Querying the submodules (torch.nn.Module.modules / named_modules).
# Both walk the model recursively, including submodules nested inside other
# layers, and yield the model itself first. Use model.children() /
# model.named_children() to get only the direct children.
for module in model.modules():
    print(f"Module: {module}")

for name, module in model.named_modules():
    print(f"{name}: {module}")
# : SimpleMLP(...)   <- the root module; its name is the empty string
# fc1: Linear(in_features=256, out_features=128, bias=True)
# fc2: Linear(in_features=128, out_features=2, bias=True)

# 3. Querying the parameters (each one is a torch.nn.Parameter)
for param in model.parameters():
    print(f"{param.shape}")

for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")
# fc1.weight: torch.Size([128, 256])
# fc1.bias: torch.Size([128])
# fc2.weight: torch.Size([2, 128])
# fc2.bias: torch.Size([2])

# Names are relative to the module they are queried from
for name, param in model.fc1.named_parameters():
    print(f"{name}: {param.shape}")
# weight: torch.Size([128, 256])
# bias: torch.Size([128])

# Total number of parameters in the model
total_parameters = sum(p.numel() for p in model.parameters())
print(f"total parameters: {total_parameters}")  # 33154 = 256*128 + 128 + 128*2 + 2

# Inspect one specific parameter: the first parameter of fc1 is its weight
param = next(iter(model.fc1.parameters()))
print(type(param))  # <class 'torch.nn.parameter.Parameter'>
print(param.shape)  # torch.Size([128, 256])
print(param.numel())  # 32768
print(param.requires_grad)  # True

# Freezing: a tensor with requires_grad=False receives no gradient, so it is
# not updated during training. Setting the flag on a single Parameter freezes
# only that tensor -- here fc1.weight; fc1.bias would still be trainable.
param.requires_grad = False

# To freeze a whole module (i.e. update none of its parameters), freeze every
# parameter it owns.
for p in model.fc1.parameters():
    p.requires_grad_(False)

# Only fc2 is still trainable now: 128*2 + 2 = 258 parameters
trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable_parameters}")
# 4. Saving and loading the model (its parameters)
#
# model.state_dict() maps every parameter/buffer name to its tensor (a
# collections.OrderedDict). Saving the state_dict stores only these tensors,
# not the module's code; it can be loaded back into any model whose parameter
# names and shapes match.
#
# A practical example of a loading helper:


def load_pretrained(
    model: torch.nn.Module,
    pretrained_params: "dict[str, torch.Tensor] | None" = None,
    strict: bool = False,
    prefix: "str | list[str] | None" = None,
    use_flash_attn=True,
    verbose: bool = True,
) -> torch.nn.Module:
    """Copy tensors from ``pretrained_params`` into ``model`` (in place) and return it.

    Args:
        model: module to load the tensors into.
        pretrained_params: mapping of parameter/buffer name -> tensor, e.g. the
            result of ``torch.load``. ``None`` or empty: ``model`` is returned
            unchanged.
        strict: if True, every entry given (after renaming/prefix filtering)
            must exist in ``model`` with the same shape, otherwise
            ``load_state_dict`` raises ``RuntimeError``. Entries of ``model``
            that are absent from ``pretrained_params`` are not an error; they
            keep their current values. If False, entries with an unknown name
            or a mismatched shape are silently skipped.
        prefix: a string or a list of strings; if non-empty, only entries whose
            (renamed) key starts with one of them are loaded.
        use_flash_attn: if False, every ``"Wqkv."`` in a key is replaced with
            ``"in_proj_"`` before loading.
        verbose: print the name and shape of every entry that will be loaded.

    Returns:
        The same ``model`` object.

    Raises:
        RuntimeError: from ``model.load_state_dict`` in strict mode, on an
            unknown key or a shape mismatch.
    """
    if not pretrained_params:
        # Nothing to load (None or empty): leave the model untouched.
        return model

    # Rename keys, e.g. "attn.Wqkv.weight" -> "attn.in_proj_weight".
    # NOTE(review): presumably maps a fused-QKV (flash-attention style)
    # checkpoint onto torch.nn.MultiheadAttention's in_proj_weight/in_proj_bias
    # names -- confirm against the checkpoint producer.
    if not use_flash_attn:
        pretrained_params = {
            k.replace("Wqkv.", "in_proj_"): v for k, v in pretrained_params.items()
        }

    # Load only the entries whose key starts with one of the given prefixes.
    if prefix:
        prefixes = (prefix,) if isinstance(prefix, str) else tuple(prefix)
        pretrained_params = {
            k: v for k, v in pretrained_params.items() if k.startswith(prefixes)
        }

    model_dict = model.state_dict()

    if not strict:
        # Partial loading: keep only entries the model can accept (known name
        # AND identical shape); everything else is silently dropped.
        pretrained_params = {
            k: v
            for k, v in pretrained_params.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
    # In strict mode everything is passed through unfiltered, so an unknown
    # name or a wrong shape makes load_state_dict below raise.

    if verbose:
        for k, v in pretrained_params.items():
            print(f"Loading parameter {k} with shape {v.shape}")

    # Merge into the model's own state so every model key is present (missing
    # keys keep their current values), then load the merged dict.
    model_dict.update(pretrained_params)
    model.load_state_dict(model_dict)
    return model


if __name__ == "__main__":
    # `model` is the SimpleMLP instance created in section 1.
    print(type(model.state_dict()))

    # Save only the parameters; a ".pt" suffix works just as well.
    torch.save(model.state_dict(), "model.pth")

    # Load them back into a model with the same structure.
    # NOTE: torch.load unpickles the file -- only load files you trust, or pass
    # weights_only=True (PyTorch >= 1.13) to restrict it to tensors and plain
    # containers.
    model.load_state_dict(torch.load("model.pth"))

    # map_location controls the device the loaded tensors are placed on,
    # e.g. "cpu" or a specific GPU such as "cuda:2".
    pretrained_params = torch.load("model.pth", map_location="cpu")

    # Partially load only fc1's parameters with the helper above.
    model = load_pretrained(model, pretrained_params, prefix="fc1.")