import dgl
import dgl.dataloading
import torch
import torch.nn.functional as F


def train_batch(model, graph, feature_name, label_name, train_mask, test_mask,
                num_epochs, learning_rate, weight_decay, patience, batch_size,
                verbose=True):
    # One difference from the earlier full-graph version is the extra
    # batch_size parameter.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=weight_decay)
    val_loss_best = float('inf')
    trigger_times = 0
    node_ids = graph.nodes()
    # Choose the node sampler according to later needs; here a full-neighbor
    # sampler over 2 layers, matching a 2-layer model.
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    train_dataloader = dgl.dataloading.DataLoader(
        graph, node_ids[train_mask], sampler,
        batch_size=batch_size,
        shuffle=True)
    # Loaders that yield the whole training / test set as a single batch,
    # used for per-epoch evaluation.
    all_train_dataloader = dgl.dataloading.DataLoader(
        graph, node_ids[train_mask], sampler,
        batch_size=len(node_ids[train_mask]),
        shuffle=False)
    all_test_dataloader = dgl.dataloading.DataLoader(
        graph, node_ids[test_mask], sampler,
        batch_size=len(node_ids[test_mask]),
        shuffle=False)
    for epoch in range(num_epochs):
        model.train()
        for _, _, blocks in train_dataloader:
            x = blocks[0].srcdata[feature_name]  # input features of the first block
            y = blocks[-1].dstdata[label_name]   # labels of the last block's output nodes
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        with torch.no_grad():
            # Full-training-set loss and accuracy for this epoch.
            _, _, blocks = next(iter(all_train_dataloader))
            x = blocks[0].srcdata[feature_name]
            y = blocks[-1].dstdata[label_name]
            y_hat = model(blocks, x)
            epoch_loss = F.cross_entropy(y_hat, y)
            train_acc = evaluate(y_hat, y)
            # Full-test-set loss and accuracy for this epoch.
            _, _, blocks = next(iter(all_test_dataloader))
            x = blocks[0].srcdata[feature_name]
            y = blocks[-1].dstdata[label_name]
            y_hat = model(blocks, x)
            test_loss = F.cross_entropy(y_hat, y)
            test_acc = evaluate(y_hat, y)
        if verbose:
            print("Epoch {:03d} | Loss {:.4f} | Train Acc {:.4f} | "
                  "Test Loss {:.4f} | Test Acc {:.4f}".format(
                      epoch, epoch_loss.item(), train_acc,
                      test_loss.item(), test_acc))
        # Early stopping on the test loss (used here as a validation loss):
        # stop after `patience` consecutive epochs without improvement.
        if test_loss.item() > val_loss_best:
            trigger_times += 1
            if trigger_times >= patience:
                break
        else:
            trigger_times = 0
            val_loss_best = test_loss.item()
    return epoch_loss.item(), val_loss_best, train_acc, test_acc
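
The evaluate helper is assumed to be defined earlier in the document; train_batch only requires that it map logits and labels to an accuracy value. A minimal sketch of a compatible implementation:

def evaluate(y_hat, y):
    # Fraction of nodes whose highest-scoring class matches the label.
    predictions = y_hat.argmax(dim=1)
    return (predictions == y).float().mean().item()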
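A minimal sketch of how train_batch might be driven, assuming a two-layer GraphSAGE model written against DGL blocks; the model class, hyperparameter values, and the use of DGL's built-in Cora dataset are illustrative choices, not taken from the original:

import dgl
import dgl.nn as dglnn
import torch.nn as nn

class SAGE(nn.Module):
    # Two-layer GraphSAGE that consumes one block per layer,
    # matching the MultiLayerFullNeighborSampler(2) above.
    def __init__(self, in_feats, hidden_feats, num_classes):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(in_feats, hidden_feats, 'mean')
        self.conv2 = dglnn.SAGEConv(hidden_feats, num_classes, 'mean')

    def forward(self, blocks, x):
        h = F.relu(self.conv1(blocks[0], x))
        return self.conv2(blocks[1], h)

graph = dgl.data.CoraGraphDataset()[0]
model = SAGE(graph.ndata['feat'].shape[1], 16,
             graph.ndata['label'].max().item() + 1)
train_batch(model, graph, 'feat', 'label',
            graph.ndata['train_mask'], graph.ndata['test_mask'],
            num_epochs=100, learning_rate=0.01, weight_decay=5e-4,
            patience=5, batch_size=64)

As the comment on the sampler suggests, other node samplers can be dropped in: on large graphs the full-neighbor sampler is commonly replaced with dgl.dataloading.NeighborSampler([10, 10]), which caps the number of neighbors drawn at each of the two layers.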