1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
|
def _edge_batch_mse_loss(model, batch, node_feats, label_name):
    """Return the MSE loss of ``model`` on one batch from an edge-prediction DataLoader.

    ``batch`` is the ``(input_nodes, pair_graph, blocks)`` triple yielded by the
    DataLoader. The targets are ``pair_graph.edata[label_name]`` reshaped to a
    column ``(-1, 1)``, which the model output is expected to match.
    """
    input_nodes, pair_graph, blocks = batch
    predictions = model(blocks, node_feats[input_nodes], pair_graph)
    targets = pair_graph.edata[label_name].reshape((-1, 1))
    return F.mse_loss(predictions, targets)


def train_batch(model, graph, feature_name, label_name, train_mask, test_mask,
                num_epochs, learning_rate, weight_decay, patience, batch_size,
                verbose=True):
    """Train an edge-regression model with mini-batches and early stopping.

    Seed edges are sampled with a 2-layer full-neighbour sampler wrapped for
    edge prediction. Each seed edge and its reverse are excluded from the
    sampled neighbourhood. After every epoch, the MSE over the whole training
    set and over the whole test set is computed, each in a single full batch.
    Training stops once ``patience`` consecutive epochs (at least one) have
    had a test loss above the best seen so far.

    Args:
        model: callable as ``model(blocks, input_node_features, pair_graph)``.
            Its predictions are compared with the edge labels reshaped to
            ``(-1, 1)``.
        graph: DGL graph. Assumes edge ``i`` and edge ``i + E``
            (``E = number_of_edges() // 2``) are reverses of each other.
            TODO confirm against the code that builds the graph.
        feature_name: key of the input node features in ``graph.ndata``.
        label_name: key of the regression targets in the edge data.
        train_mask: mask (or index tensor) selecting the training edges from
            all edge IDs.
        test_mask: mask (or index tensor) selecting the test edges from all
            edge IDs.
        num_epochs: maximum number of epochs; must be at least 1.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        patience: consecutive non-improving epochs before stopping.
        batch_size: number of seed edges per training mini-batch.
        verbose: print the per-epoch full-train and test losses.

    Returns:
        ``(train_loss, test_loss)`` as floats, from the last epoch that ran.
        This is not necessarily the best epoch; the model keeps its
        last-epoch weights.

    Raises:
        ValueError: if ``num_epochs`` is less than 1.

    NOTE(review): the test split doubles as the early-stopping validation set,
    so the returned test loss is optimistically biased.
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=weight_decay)
    node_feats = graph.ndata[feature_name]

    # Map edge i <-> edge i + E so DGL can drop both a seed edge and its
    # reverse from the sampled message-passing graph.
    half_edges = graph.number_of_edges() // 2
    reverse_eids = torch.cat([torch.arange(half_edges, 2 * half_edges),
                              torch.arange(0, half_edges)])
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, exclude='reverse_id', reverse_eids=reverse_eids)

    all_eids = graph.edges(form="all")[2]
    train_eid = all_eids[train_mask]
    test_eid = all_eids[test_mask]
    train_dataloader = dgl.dataloading.DataLoader(
        graph, train_eid, sampler,
        batch_size=batch_size,
        shuffle=True)
    # Whole training / test set as a single batch, used to compute the
    # per-epoch train and test losses.
    all_train_dataloader = dgl.dataloading.DataLoader(
        graph, train_eid, sampler,
        batch_size=len(train_eid),
        shuffle=False)
    all_test_dataloader = dgl.dataloading.DataLoader(
        graph, test_eid, sampler,
        batch_size=len(test_eid),
        shuffle=False)

    # inf (not a large constant) so the first epoch always sets the baseline,
    # whatever the scale of the loss.
    best_test_loss = float("inf")
    bad_epochs = 0
    for epoch in range(num_epochs):
        model.train()
        for batch in train_dataloader:
            loss = _edge_batch_mse_loss(model, batch, node_feats, label_name)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        # Only the loss values are needed here. no_grad avoids building an
        # autograd graph over the full train and test batches.
        with torch.no_grad():
            train_loss = _edge_batch_mse_loss(
                model, next(iter(all_train_dataloader)), node_feats,
                label_name).item()
            test_loss = _edge_batch_mse_loss(
                model, next(iter(all_test_dataloader)), node_feats,
                label_name).item()
        if verbose:
            print("Epoch {:03d} | Loss {:.4f} | Test Loss {:.4f} ".format(
                epoch, train_loss, test_loss))

        # Early stopping on the test loss; an equal loss counts as improvement.
        if test_loss > best_test_loss:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
        else:
            bad_epochs = 0
            best_test_loss = test_loss
    return train_loss, test_loss
|