downstream_task.py
import argparse
import os
import pickle
import random

import numpy as np
import pandas as pd
from torchsummary import summary

from prompt_graph.tasker import NodeTask, GraphTask
from prompt_graph.utils import seed_everything, print_model_parameters, get_args
from prompt_graph.data import load4node, load4graph, split_induced_graphs
def load_induced_graph(dataset_name, data, device):
    """Load the cached induced subgraphs for a node-level dataset, building them if absent."""
    folder_path = './Experiment/induced_graph/' + dataset_name
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

    file_path = folder_path + '/induced_graph_min100_max300.pkl'
    if os.path.exists(file_path):
        # Cache hit: reuse the previously split induced graphs.
        with open(file_path, 'rb') as f:
            print('loading induced graph...')
            graphs_list = pickle.load(f)
            print('Done!!!')
    else:
        # Cache miss: split the dataset into induced subgraphs and persist them.
        print('Begin split_induced_graphs.')
        split_induced_graphs(data, folder_path, device, smallest_size=100, largest_size=300)
        with open(file_path, 'rb') as f:
            graphs_list = pickle.load(f)

    graphs_list = [graph.to(device) for graph in graphs_list]
    return graphs_list
def get_downstream_task_delegate(args: argparse.Namespace):
    """Build the NodeTask or GraphTask tasker selected by args.downstream_task."""
    seed_everything(args.seed)

    if args.downstream_task == 'NodeTask':
        data, input_dim, output_dim = load4node(args.dataset_name)
        data = data.to(args.device)
        # Prompt methods that operate on per-node induced subgraphs need the graph list.
        if args.prompt_type in ['Gprompt', 'All-in-one', 'GPF', 'GPF-plus']:
            graphs_list = load_induced_graph(args.dataset_name, data, args.device)
        else:
            graphs_list = None
        tasker = NodeTask(pre_train_model_path=args.pre_train_model_path,
                          dataset_name=args.dataset_name, num_layer=args.num_layer,
                          gnn_type=args.gnn_type, hid_dim=args.hid_dim, prompt_type=args.prompt_type,
                          epochs=args.epochs, shot_num=args.shot_num, device=args.device,
                          lr=args.lr, wd=args.decay, batch_size=args.batch_size,
                          data=data, input_dim=input_dim, output_dim=output_dim,
                          graphs_list=graphs_list)

    elif args.downstream_task == 'GraphTask':
        input_dim, output_dim, dataset = load4graph(args.dataset_name)
        tasker = GraphTask(pre_train_model_path=args.pre_train_model_path,
                           dataset_name=args.dataset_name, num_layer=args.num_layer,
                           gnn_type=args.gnn_type, hid_dim=args.hid_dim, prompt_type=args.prompt_type,
                           epochs=args.epochs, shot_num=args.shot_num, device=args.device,
                           lr=args.lr, wd=args.decay, batch_size=args.batch_size,
                           dataset=dataset, input_dim=input_dim, output_dim=output_dim)

    else:
        raise ValueError(f"Unexpected args.downstream_task type {args.downstream_task}.")

    return tasker
if __name__ == "__main__":
args = get_args()
print('dataset_name', args.dataset_name)
tasker = get_downstream_task_delegate(args=args)
_, test_acc, std_test_acc, f1, std_f1, roc, std_roc, _, _= tasker.run()
print("Final Accuracy {:.4f}±{:.4f}(std)".format(test_acc, std_test_acc))
print("Final F1 {:.4f}±{:.4f}(std)".format(f1,std_f1))
print("Final AUROC {:.4f}±{:.4f}(std)".format(roc, std_roc))
pre_train_type = tasker.pre_train_type
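
# ---------------------------------------------------------------------------
# Example invocation (a sketch only): the actual flag names are defined by
# prompt_graph.utils.get_args, so the options shown below are assumptions
# inferred from the attributes accessed above, not a confirmed CLI.
#
#   python downstream_task.py \
#       --downstream_task NodeTask \
#       --dataset_name Cora \
#       --gnn_type GCN \
#       --prompt_type GPF \
#       --shot_num 5 \
#       --device cuda:0
#
# For NodeTask with an induced-graph prompt type (Gprompt, All-in-one, GPF,
# GPF-plus), the first run also builds and caches induced subgraphs under
# ./Experiment/induced_graph/<dataset_name>/.
# ---------------------------------------------------------------------------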