import dgl
import torch


def train_batch(model, graph,
                feature_name, train_mask, test_mask,
                num_epochs, learning_rate, weight_decay, patience,
                batch_size, verbose=True):
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=weight_decay)
    node_feats = graph.ndata[feature_name]
    val_loss_best = float("inf")
    trigger_times = 0
    # Edge IDs of the train/test splits.
    train_eids = graph.edges(form="all")[2][train_mask]
    test_eids = graph.edges(form="all")[2][test_mask]
    # Two-hop full-neighbor sampling, wrapped for edge prediction with one
    # uniformly sampled negative edge per positive edge.
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, negative_sampler=dgl.dataloading.negative_sampler.Uniform(1))
    train_dataloader = dgl.dataloading.DataLoader(
        graph, train_eids, sampler,
        batch_size=batch_size,
        shuffle=True)
    # Single-batch loaders, used only for per-epoch evaluation.
    all_train_dataloader = dgl.dataloading.DataLoader(
        graph, train_eids, sampler,
        batch_size=len(train_eids),
        shuffle=False)
    all_test_dataloader = dgl.dataloading.DataLoader(
        graph, test_eids, sampler,
        batch_size=len(test_eids),
        shuffle=False)
    for epoch in range(num_epochs):
        for input_nodes, positive_graph, negative_graph, blocks in train_dataloader:
            loss, auc = loss_evaluate_batch(
                model, blocks, positive_graph, negative_graph,
                input_nodes, node_feats, mode="train")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Training-set loss and AUC for this epoch.
        input_nodes, positive_graph, negative_graph, blocks = next(iter(all_train_dataloader))
        loss, auc = loss_evaluate_batch(
            model, blocks, positive_graph, negative_graph,
            input_nodes, node_feats, mode="eval")
        # Test-set loss and AUC for this epoch (the test split doubles as the
        # validation signal for early stopping).
        input_nodes, positive_graph, negative_graph, blocks = next(iter(all_test_dataloader))
        test_loss, test_auc = loss_evaluate_batch(
            model, blocks, positive_graph, negative_graph,
            input_nodes, node_feats, mode="eval")
        if verbose:
            print("Epoch {:03d} | Loss {:.4f} | Auc {:.4f} | Test Loss {:.4f} | Test Auc {:.4f}".format(
                epoch, loss.item(), auc, test_loss.item(), test_auc))
        # Early stopping: stop after `patience` consecutive epochs with no
        # improvement in test loss.
        if test_loss.item() > val_loss_best:
            trigger_times += 1
            if trigger_times >= patience:
                break
        else:
            trigger_times = 0
            val_loss_best = test_loss.item()
    return loss.item(), auc, test_loss.item(), test_auc
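
For context, here is a minimal sketch of how train_batch might be invoked. The random graph, the "feat" feature name, and the 80/20 edge split are illustrative assumptions, not part of the original setup; model is taken to be the two-layer GNN built earlier (matching the two-hop sampler above), and loss_evaluate_batch is the loss/AUC helper the function depends on.

import dgl
import torch

# Toy graph with random 16-dimensional node features (illustrative only).
num_nodes, num_edges = 1000, 5000
graph = dgl.rand_graph(num_nodes, num_edges)
graph.ndata["feat"] = torch.randn(num_nodes, 16)

# Random 80/20 train/test split over edge IDs.
perm = torch.randperm(num_edges)
train_mask = torch.zeros(num_edges, dtype=torch.bool)
train_mask[perm[:int(0.8 * num_edges)]] = True
test_mask = ~train_mask

# `model` is assumed to be the two-layer GNN defined elsewhere in the post.
loss, auc, test_loss, test_auc = train_batch(
    model, graph,
    feature_name="feat", train_mask=train_mask, test_mask=test_mask,
    num_epochs=100, learning_rate=1e-3, weight_decay=5e-4,
    patience=5, batch_size=1024)

Note that batch_size only affects the training loader; the two evaluation loaders always process their entire split in a single batch, which is convenient for small graphs but may need chunking on larger ones.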