1. Example model

  • A two-layer MLP neural network

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Example usage
input_size = 256   # Number of input features
hidden_size = 128   # Number of neurons in the hidden layer
output_size = 2   # Number of output classes

model = SimpleMLP(input_size, hidden_size, output_size)
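
A quick sanity check runs a random batch through the model; the batch size of 32 below is arbitrary:

# Forward a random batch to confirm the shapes line up
x = torch.randn(32, input_size)
logits = model(x)
print(logits.shape)  # torch.Size([32, 2])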

2. Querying constituent modules

  • model.modules() traverses all of the model's layers recursively, including submodules nested inside other layers

# modules() and named_modules() are inherited from torch.nn.Module

for module in model.modules():
    print(f"Module: {module}")

for name, module in model.named_modules():
    print(f"{name}: {module}")
# : SimpleMLP(...)   <- the root module itself comes first, with an empty name
# fc1: Linear(in_features=256, out_features=128, bias=True)
# fc2: Linear(in_features=128, out_features=2, bias=True)
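
In contrast, children() / named_children() iterate only over direct children and do not recurse into nested submodules. For this flat model the output is the same as named_modules() minus the root; a minimal sketch:

# named_children() yields only the direct children, without recursion
for name, module in model.named_children():
    print(f"{name}: {module}")
# fc1: Linear(in_features=256, out_features=128, bias=True)
# fc2: Linear(in_features=128, out_features=2, bias=True)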

3. Querying model parameters

# Each parameter is an instance of torch.nn.Parameter
for param in model.parameters():
    print(f"{param.shape}")

for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")
# fc1.weight: torch.Size([128, 256])
# fc1.bias: torch.Size([128])
# fc2.weight: torch.Size([2, 128])
# fc2.bias: torch.Size([2])

for name, param in model.fc1.named_parameters():
    print(f"{name}: {param.shape}")
# weight: torch.Size([128, 256])
# bias: torch.Size([128])


# Total number of parameters in the model
total_parameters = sum(p.numel() for p in model.parameters())

# Inspect the parameters of a specific layer
param = next(iter(model.fc1.parameters()))

type(param)
# torch.nn.parameter.Parameter

param.shape
# torch.Size([128, 256])
param.numel()
# 32768
param.requires_grad
# True

# Freeze the parameter, i.e. exclude it from gradient updates
param.requires_grad = False
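
To freeze a whole layer rather than a single tensor, a common pattern (a sketch, not part of the original) is to loop over that module's parameters and then count what remains trainable:

# Freeze all of fc1's parameters, then count the remaining trainable ones
for p in model.fc1.parameters():
    p.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # only fc2 remains trainable: 128*2 + 2 = 258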

4. Saving and loading the model (parameters)

type(model.state_dict())
# collections.OrderedDict

# save
torch.save(model.state_dict(), 'model.pth')  # a .pt suffix works too

# load
model.load_state_dict(torch.load('model.pth'))
# load onto a specific device; model_pt is the path to the checkpoint file
pretrained_params = torch.load(model_pt, map_location='cuda:2')
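
Beyond the state_dict, the whole model object can also be serialized. This is a pickle-based sketch: the class definition must be importable at load time, and recent PyTorch versions may require weights_only=False in torch.load:

# Save/load the entire model object (pickle-based)
torch.save(model, 'model_full.pth')
model2 = torch.load('model_full.pth')  # may need weights_only=False on newer PyTorch

# map_location can also remap a GPU checkpoint to CPU
state = torch.load('model.pth', map_location='cpu')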
  • A practical example of a loading function

# Load pretrained parameters into a model, with optional key renaming and prefix filtering
def load_pretrained(
    model: torch.nn.Module,
    pretrained_params: dict = None,
    strict: bool = False,
    prefix: list = None,
    use_flash_attn: bool = True,
    verbose: bool = True,
) -> torch.nn.Module:

    # Rename specific parameter keys (here: map flash attention's
    # "Wqkv." naming to the standard "in_proj_" naming)
    if not use_flash_attn:
        pretrained_params = {
            k.replace("Wqkv.", "in_proj_"): v for k, v in pretrained_params.items()
        }

    # Only load parameters whose keys start with one of the given prefixes
    if prefix is not None and len(prefix) > 0:
        if isinstance(prefix, str):
            prefix = [prefix]
        pretrained_params = {
            k: v
            for k, v in pretrained_params.items()
            if any(k.startswith(p) for p in prefix)
        }

    model_dict = model.state_dict()
    # Strict loading: all parameters must match
    if strict:
        if verbose:
            for k, v in pretrained_params.items():
                print(f"Loading parameter {k} with shape {v.shape}")
        model_dict.update(pretrained_params)
        model.load_state_dict(model_dict)
    # Partial loading: only load parameters that match in both key name and value shape
    else:
        if verbose:
            for k, v in pretrained_params.items():
                if k in model_dict and v.shape == model_dict[k].shape:
                    print(f"Loading parameter {k} with shape {v.shape}")
        pretrained_params = {
            k: v
            for k, v in pretrained_params.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        model_dict.update(pretrained_params)
        model.load_state_dict(model_dict)

    return model
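
A hypothetical usage example; 'pretrained.pth' stands in for a real checkpoint path, and the prefix filter keeps only fc1's parameters:

# Non-strictly load only fc1's weights from a checkpoint
pretrained_params = torch.load('pretrained.pth', map_location='cpu')
model = load_pretrained(model, pretrained_params, strict=False, prefix=['fc1'])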