diff --git a/README.md b/README.md index 9a918331ebb1b8e42431a9ecbb472ba2eee442f1..c55d3b85cf6ffeda14b582f84c36377be1557bf8 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,41 @@ for i in range(adult_dataset.get_m()): print('WGDRO', i, loss) ``` -## 课题二:滑动窗口上的最优矩阵略图算子 +## 课题二 + +### 2.1 高吞吐率图流的动态图神经网络 + +该算法主要包含`数据集`,`DGNN算法实现`,`性能测试`三部分,相关代码参见目录: +- StreamLearn/Dataset/DTDGsDataset.py +- StreamLearn/Algorithm/DecoupledDGNN/DGNN.py +- StreamLearn/tests/test_DecoupledDGNN.py + +算子应用实例见 StreamLearn/tests/test_DecoupledDGNN.py 文件,使用方法为: +```bash +python StreamLearn/tests/test_DecoupledDGNN.py --data NAME_OF_DATA --checkpoint PATH_TO_CHECKPOINT +``` + +首先,获取算法所需数据集:Dataset文件夹中包含了数据下载和预处理的代码,用户可以根据需要修改其中的数据存储路径。 +```python +from StreamLearn.Dataset.DTDGsDataset import DTDGsDataset +dataset = DTDGsDataset(args.data) +``` + +其次,按照以下方式获取动态图节点的时序表示: +```python +from StreamLearn.Algorithm.DecoupledDGNN.graph_embs import GraphEmbs +gen_embs = GraphEmbs(dataset.path, args.data, args.rmax, args.alpha) +gen_embs.load_and_process_data() +``` + +最后,调用DGNN算法进行训练并测试: +```python +from StreamLearn.Algorithm.DecoupledDGNN.DGNN import DGNN +execute = DGNN(args) +execute.train() +``` + +### 2.2 滑动窗口上的最优矩阵略图算子 该算子主要包含`数据集`,`DSFD算子实现`,`性能测试`三部分,相关代码参见目录: - StreamLearn/Dataset/FDDataset.py @@ -134,6 +168,10 @@ swfd.fit(data) predict = swfd.predict(direction) ``` +### 2.3 基于采样的分布式环境下的元素估计算法 + +实现了基于采样的、分布式流数据环境下、亚线性通信量的元素估计(NDV)算法。详见[https://gitee.com/yinhanyan/ndv_-estimation_in_distributed_environment](https://gitee.com/yinhanyan/ndv_-estimation_in_distributed_environment)。 + ## 课题三 ### 3.1 流数据分布自适应学习算法 diff --git a/StreamLearn/Algorithm/DecoupledDGNN/DGNN.py b/StreamLearn/Algorithm/DecoupledDGNN/DGNN.py new file mode 100755 index 0000000000000000000000000000000000000000..c4475ec60d0b1f23bdd0c2557fa042f5ece40577 --- /dev/null +++ b/StreamLearn/Algorithm/DecoupledDGNN/DGNN.py @@ -0,0 +1,378 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.optim as optim +import
def edge_index_difference(edge_all, edge_except, num_nodes):
    """Return the edges of ``edge_all`` that do not occur in ``edge_except``.

    Each edge ``(i, j)`` is encoded as the single integer
    ``i * num_nodes + j`` so the set difference reduces to a 1-D membership
    test. The encoding is unique only while every ``j`` is below
    ``num_nodes``.

    Args:
        edge_all: Edge index whose row 0 holds sources and row 1 holds
            destinations.
        edge_except: Edge index, same layout, listing the edges to remove.
        num_nodes: Multiplier used to encode an edge as one integer.

    Returns:
        An ``int64`` array with two rows (sources, destinations) holding the
        surviving edges in their original order.
    """
    def _encode(edge_index):
        # Collapse (source, destination) into one comparable key.
        return edge_index[0] * num_nodes + edge_index[1]

    candidate_keys = _encode(edge_all)
    surviving_keys = candidate_keys[
        np.isin(candidate_keys, _encode(edge_except), invert=True)
    ]
    # divmod undoes the encoding: quotient -> source, remainder -> destination.
    rows, cols = np.divmod(surviving_keys, num_nodes)
    return np.vstack((rows, cols)).astype(np.int64)
def __init__(self, args):
    """Build the link-prediction model and cache the run settings.

    Args:
        args: Parsed configuration. Attributes read here: ``gpu``,
            ``batch_size``, ``data``, ``use_neg``, ``emb_size``,
            ``seq_model``, ``window_size``, ``inductive``, ``shuffle``,
            ``checkpt_path`` and ``patience`` (plus ``hidden_dim`` and
            ``dropout`` on the transformer path).

    Raises:
        ValueError: If ``args.seq_model`` is not 'lstm', 'gru' or
            'transformer'.
    """
    self.args = args
    # Use the requested GPU when CUDA is available, otherwise the CPU.
    device_string = 'cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu'
    self.device = torch.device(device_string)

    self.batch_size = args.batch_size
    self.data = args.data
    self.use_neg = args.use_neg
    if self.use_neg:
        # NOTE: this doubles emb_size on the caller's ``args`` object in
        # place; the model constructor below reads the doubled value.
        args.emb_size *= 2

    if args.seq_model in ('lstm', 'gru'):
        self.model = linkPredictor(args, self.device)
    elif args.seq_model == 'transformer':
        # NOTE(review): these positional arguments do not match an
        # ``(args, device)`` constructor -- confirm which linkPredictor
        # signature the transformer variant is meant to use.
        self.model = linkPredictor(self.device, args.emb_size * 2, args.hidden_dim, 4, 4, args.dropout)
    else:
        # Fail here with a clear message rather than with an AttributeError
        # on ``self.model`` a few lines later.
        raise ValueError('unsupported seq_model: %r' % (args.seq_model,))

    self.model = self.model.to(self.device)
    self.window_size = args.window_size
    self.inductive = args.inductive

    self.shuffle = args.shuffle
    self.checkpt_path = args.checkpt_path
    # Best-model checkpoint: <checkpt_path>/<data>_ws<window_size>_best.pt
    self.checkpt_file = self.checkpt_path + '/' + self.data + '_ws' + str(self.window_size) + '_best.pt'
    self.patience = args.patience
+ bad_counter = 0 + for epoch in range(args.epochs): + self.model.train() + correct = 0 + num_samples = 0 + all_loss = [] + + for time in train_data.unique_times: + if self.edge_helper.time_edge_dict[time]['idx'] < self.window_size: + continue + edges_snap = self.edge_helper.time_edge_dict[time]['edges'] + num_instance = edges_snap.shape[0] + num_batch = math.ceil(num_instance / self.batch_size) + for batch_idx in range(num_batch): + start_idx = batch_idx * self.batch_size + end_idx = min(num_instance, start_idx + self.batch_size) + sources_batch, destinations_batch = edges_snap[start_idx:end_idx, 0], edges_snap[start_idx:end_idx, 1] + size = len(sources_batch) + timestamps_batch = np.repeat(time, size) + + if args.seq_model in ['lstm', 'gru']: + ### get pos features (sources_batch, destinations_batch, timestamps_batch) + src_features, dst_features = self.edge_helper.get_edges_feats(sources_batch, destinations_batch, timestamps_batch, window_size=self.window_size, concat=False) + pos_preds = self.model.get_edges_embedding(src_features.to(self.device), dst_features.to(self.device)) + elif args.seq_model == 'transformer': + pos_features = self.edge_helper.get_edges_feats(sources_batch, destinations_batch, timestamps_batch, window_size=self.window_size, concat=True) + pos_preds = self.model(pos_features.to(self.device)) + pos_labels = torch.ones(pos_preds.shape[0], dtype=torch.float, device=self.device) + pos_loss = criterion(pos_preds.squeeze(dim=1), pos_labels) ##pos_loss = -torch.log(pos_preds[:, 1]) + + ### get negtive sample and features + neg_destinations_batch = np.random.randint(0, self.edge_helper.node_num, size) + if args.seq_model in ['lstm', 'gru']: + neg_src_features, neg_dst_features = self.edge_helper.get_edges_feats(sources_batch, neg_destinations_batch, timestamps_batch, window_size=self.window_size, concat=False) + neg_preds = self.model.get_edges_embedding(neg_src_features.to(self.device), neg_dst_features.to(self.device)) + elif args.seq_model == 
def valid(self, val_data, cal_mrr=False):
    """Score the model on every snapshot listed in ``val_data``.

    For each batch of observed edges, one negative edge per positive is
    drawn by pairing the same sources with uniformly random destination ids.
    Both sets are scored and compared against labels 1 and 0.
    ``self.edge_helper`` must already exist (``train`` creates it).

    Args:
        val_data: Object exposing ``unique_times``, the snapshot keys looked
            up in ``self.edge_helper.time_edge_dict``.
        cal_mrr: When true, also call ``eval_mrr_and_recall`` once per
            snapshot and report the mean MRR.

    Returns:
        ``(loss, accuracy, auc, ap)``, with the mean MRR appended as a fifth
        element when ``cal_mrr`` is true.

    Raises:
        ValueError: If ``self.args.seq_model`` is not a supported name.
    """
    criterion = nn.BCELoss()
    self.model.eval()
    helper = self.edge_helper
    # The configuration lives on the instance; there is no module-level
    # ``args`` to read from.
    seq_model = self.args.seq_model
    eval_batch_size = 2 * self.batch_size

    def _predict(src, dst, stamps):
        """Return edge probabilities for the (src, dst) pairs at ``stamps``."""
        if seq_model in ('lstm', 'gru'):
            # Keep only the two feature tensors; the helper may return
            # additional time-delta tensors after them.
            src_feats, dst_feats = helper.get_edges_feats(
                src, dst, stamps, window_size=self.window_size, concat=False)[:2]
            return self.model.get_edges_embedding(
                src_feats.to(self.device), dst_feats.to(self.device))
        if seq_model == 'transformer':
            feats = helper.get_edges_feats(
                src, dst, stamps, window_size=self.window_size, concat=True)
            if isinstance(feats, (tuple, list)):
                feats = feats[0]
            return self.model(feats.to(self.device))
        raise ValueError('unsupported seq_model: %r' % (seq_model,))

    correct = 0
    num_samples = 0
    all_loss = []
    y_true, y_pred = [], []
    mrr_hist = []

    # Evaluation only: skip building autograd graphs.
    with torch.no_grad():
        for time in val_data.unique_times:
            edges_snap = helper.time_edge_dict[time]['edges']
            num_instance = edges_snap.shape[0]
            # Step and slice with the same batch size so that every edge of
            # the snapshot is visited exactly once.
            for start_idx in range(0, num_instance, eval_batch_size):
                batch = edges_snap[start_idx:start_idx + eval_batch_size]
                sources_batch, destinations_batch = batch[:, 0], batch[:, 1]
                size = len(sources_batch)
                timestamps_batch = np.repeat(time, size)

                pos_preds = _predict(sources_batch, destinations_batch, timestamps_batch)
                pos_labels = torch.ones(pos_preds.shape[0], dtype=torch.float, device=self.device)
                pos_loss = criterion(pos_preds.squeeze(dim=1), pos_labels)

                # Negative samples: same sources, random destinations.
                neg_destinations_batch = np.random.randint(0, helper.node_num, size)
                neg_preds = _predict(sources_batch, neg_destinations_batch, timestamps_batch)
                neg_labels = torch.zeros(neg_preds.shape[0], dtype=torch.float, device=self.device)
                neg_loss = criterion(neg_preds.squeeze(dim=1), neg_labels)

                all_loss.append((pos_loss + neg_loss).item())

                hits = torch.sum(pos_preds >= 0.5) + torch.sum(neg_preds < 0.5)
                correct += hits.item()
                num_samples += pos_preds.shape[0] + neg_preds.shape[0]

                y_pred.extend(pos_preds.reshape(-1).tolist())
                y_true.extend([1] * pos_preds.shape[0])
                y_pred.extend(neg_preds.reshape(-1).tolist())
                y_true.extend([0] * neg_preds.shape[0])

            if cal_mrr:
                # One MRR value per snapshot, computed over all of its edges.
                mrr, _recall_at = self.eval_mrr_and_recall(
                    edges_snap, np.repeat(time, edges_snap.shape[0]), helper.node_num)
                mrr_hist.append(mrr)

    auc = self.auc_score(y_true, y_pred)
    ap = self.ap_score(y_true, y_pred)
    valid_loss = np.mean(all_loss)
    valid_acc = correct / num_samples
    if cal_mrr:
        return valid_loss, valid_acc, auc, ap, np.mean(mrr_hist)
    return valid_loss, valid_acc, auc, ap
+ num_users = len(src_lst) + + src_features, dst_features = self.edge_helper.get_edges_feats(eval_sources, eval_destinations, eval_timestamps, window_size=self.window_size, concat=False) + pos_preds = self.model.get_edges_embedding(src_features.to(self.device), dst_features.to(self.device)) + pos_preds = pos_preds.squeeze(dim=1) + pos_labels = torch.ones(pos_preds.shape[0], dtype=torch.float, device=self.device) + + # generate negtive samples + neg_sources, neg_destinations = gen_negative_edges(eval_sources, eval_destinations, num_nodes, num_neg_per_node) + neg_timestamps = np.resize(eval_timestamps, neg_sources.shape) + neg_src_features, neg_dst_features = self.edge_helper.get_edges_feats(neg_sources, neg_destinations, neg_timestamps, window_size=self.window_size, concat=False) + neg_preds = self.model.get_edges_embedding(neg_src_features.to(self.device), neg_dst_features.to(self.device)) + neg_preds = neg_preds.squeeze(dim=1) + neg_labels = torch.zeros(neg_preds.shape[0], dtype=torch.float, device=self.device) + + # The default setting, consider the rank of the most confident edge. + from torch_scatter import scatter_max + best_p_pos, _ = scatter_max(src=pos_preds, index=torch.from_numpy(eval_sources).to(self.device), dim_size=num_nodes) + # best_p_pos has shape (num_nodes), for nodes not in src_lst has value 0. + best_p_pos_by_user = best_p_pos[src_lst] + + uni, counts = np.unique(neg_sources,return_counts=True) + # find index of first occurrence of each src in neg_sources + first_occ_idx = np.cumsum(counts,axis=0) - counts + add = np.arange(num_neg_per_node) + # take the first $num_neg_per_node$ negative edges from each src. 
def test(self, test_data):
    """Load the best checkpoint and score it on the test snapshots.

    The checkpoint is read from ``self.checkpt_file``, the path ``train``
    saves the best model to. Scoring mirrors validation: one random-
    destination negative edge is drawn per positive edge, and an MRR value
    is computed once per snapshot. ``self.edge_helper`` must already exist
    (``train`` creates it).

    Args:
        test_data: Object exposing ``unique_times``, the snapshot keys
            looked up in ``self.edge_helper.time_edge_dict``.

    Returns:
        ``(loss, accuracy, auc, ap, mrr)``.

    Raises:
        ValueError: If ``self.args.seq_model`` is not a supported name.
    """
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    self.model.load_state_dict(torch.load(self.checkpt_file, map_location=self.device))
    self.model.eval()
    criterion = nn.BCELoss()
    helper = self.edge_helper
    # The configuration lives on the instance; there is no module-level
    # ``args`` to read from.
    seq_model = self.args.seq_model
    test_batchsize = 2 * self.batch_size

    def _predict(src, dst, stamps):
        """Return edge probabilities for the (src, dst) pairs at ``stamps``."""
        if seq_model in ('lstm', 'gru'):
            # Keep only the two feature tensors; the helper may return
            # additional time-delta tensors after them.
            src_feats, dst_feats = helper.get_edges_feats(
                src, dst, stamps, window_size=self.window_size, concat=False)[:2]
            return self.model.get_edges_embedding(
                src_feats.to(self.device), dst_feats.to(self.device))
        if seq_model == 'transformer':
            feats = helper.get_edges_feats(
                src, dst, stamps, window_size=self.window_size, concat=True)
            if isinstance(feats, (tuple, list)):
                feats = feats[0]
            return self.model(feats.to(self.device))
        raise ValueError('unsupported seq_model: %r' % (seq_model,))

    correct = 0
    num_samples = 0
    all_loss = []
    y_true, y_pred = [], []
    mrr_hist = []

    # Evaluation only: skip building autograd graphs.
    with torch.no_grad():
        for time in test_data.unique_times:
            edges_snap = helper.time_edge_dict[time]['edges']
            num_instance = edges_snap.shape[0]
            # Step and slice with the same batch size so that every edge of
            # the snapshot is visited exactly once.
            for start_idx in range(0, num_instance, test_batchsize):
                batch = edges_snap[start_idx:start_idx + test_batchsize]
                sources_batch, destinations_batch = batch[:, 0], batch[:, 1]
                size = len(sources_batch)
                timestamps_batch = np.repeat(time, size)

                pos_preds = _predict(sources_batch, destinations_batch, timestamps_batch)
                pos_labels = torch.ones(pos_preds.shape[0], dtype=torch.float, device=self.device)
                pos_loss = criterion(pos_preds.squeeze(dim=1), pos_labels)

                # Negative samples: same sources, random destinations.
                neg_destinations_batch = np.random.randint(0, helper.node_num, size)
                neg_preds = _predict(sources_batch, neg_destinations_batch, timestamps_batch)
                neg_labels = torch.zeros(neg_preds.shape[0], dtype=torch.float, device=self.device)
                neg_loss = criterion(neg_preds.squeeze(dim=1), neg_labels)

                all_loss.append((pos_loss + neg_loss).item())

                hits = torch.sum(pos_preds >= 0.5) + torch.sum(neg_preds < 0.5)
                correct += hits.item()
                num_samples += pos_preds.shape[0] + neg_preds.shape[0]

                y_pred.extend(pos_preds.reshape(-1).tolist())
                y_true.extend([1] * pos_preds.shape[0])
                y_pred.extend(neg_preds.reshape(-1).tolist())
                y_true.extend([0] * neg_preds.shape[0])

            # One MRR value per snapshot, computed over all of its edges.
            mrr, _recall_at = self.eval_mrr_and_recall(
                edges_snap, np.repeat(time, edges_snap.shape[0]), helper.node_num)
            mrr_hist.append(mrr)

    test_auc = self.auc_score(y_true, y_pred)
    test_ap = self.ap_score(y_true, y_pred)
    test_loss = np.mean(all_loss)
    test_acc = correct / num_samples
    test_mrr = np.mean(mrr_hist)
    return test_loss, test_acc, test_auc, test_ap, test_mrr
def load_data_init(self, datastr, rmax, alpha, randomize_features=False, neg=False):
    """Create the initial node features and the propagation object.

    Args:
        datastr: Name of the initial snapshot, e.g. ``'wikipedia_init'``.
            It selects the size passed to the propagation backend.
        rmax: Forwarded unchanged to ``InstantGNN.initial_operation``.
        alpha: Forwarded unchanged to ``InstantGNN.initial_operation``.
        randomize_features: When true, draw uniform random features;
            otherwise load them from
            ``<path>/<data>/<data>_node_features.pt``.
        neg: Unused; kept so existing callers keep working.

    Returns:
        ``(features, py_alg, n)``: the feature matrix as a NumPy array, the
        ``InstantGNN`` instance, and the size looked up for ``datastr``.

    Raises:
        ValueError: If ``datastr`` is not one of the known snapshot names.
    """
    # Per-dataset size; the original assigned the same value to m and n.
    sizes = {
        'wikipedia_init': 9227,
        'reddit_init': 10984,
        'CollegeMsg_init': 1899,
        'bitcoinotc_init': 5881,
        'bitcoinalpha_init': 3783,
        'GDELT_init': 16682,
        'MAG_init': 72508661,
    }
    try:
        n = sizes[datastr]
    except KeyError:
        # Fail with a clear message rather than with an unbound m/n later.
        raise ValueError('unknown dataset: %r' % (datastr,)) from None
    m = n

    print("Load %s!" % datastr)

    py_alg = InstantGNN()
    if randomize_features:
        # wikipedia and reddit use 172 feature columns, every other
        # dataset 128.
        feat_dim = 172 if self.data in ('wikipedia', 'reddit') else 128
        features = np.random.rand(n, feat_dim)
    else:
        features = torch.load('{0}/{1}/{1}_node_features.pt'.format(self.path, self.data))
        if features.dtype == torch.bool:
            features = features.type(torch.float64)
        features = features.numpy()

    print('features:', features)
    # NOTE(review): the caller later diffs ``features`` around
    # snapshot_operation, so the backend presumably updates this array in
    # place -- confirm against InstantGNN.
    py_alg.initial_operation(self.path, datastr, m, n, rmax, alpha, features)
    return features, py_alg, n
class GRU_Emb(nn.ModuleList):
    """Sequence encoder: linear projection, tanh, dropout, then a GRU.

    Args:
        batch_size: Stored on the instance; not used by ``forward``.
        hidden_dim: Width of the projection and of the GRU hidden state.
        lstm_layers: Number of stacked GRU layers.
        emb_size: Size of the last dimension of the input.
        dropout: Dropout probability applied after the projection.
        device: Stored on the instance for callers that read it;
            ``forward`` follows the device of its input instead.
    """

    def __init__(self, batch_size, hidden_dim, lstm_layers, emb_size, dropout, device):
        super(GRU_Emb, self).__init__()

        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.LSTM_layers = lstm_layers
        self.input_size = emb_size

        self.dropout = nn.Dropout(dropout)
        self.relu = torch.nn.ReLU()
        self.fc0 = nn.Linear(in_features=self.input_size, out_features=self.hidden_dim)
        self.gru = nn.GRU(input_size=self.hidden_dim, hidden_size=self.hidden_dim, num_layers=self.LSTM_layers, batch_first=True)
        self.device = device

    def forward(self, x):
        """Encode ``x`` and return the GRU output for every time step.

        ``x`` is fed to a ``batch_first`` GRU, so it is laid out as
        (batch, steps, emb_size).
        """
        # Allocate the initial hidden state next to the input rather than on
        # the device captured at construction time, which goes stale when
        # the module is moved (LSTM_Emb already does it this way).
        h = torch.zeros((self.LSTM_layers, x.size(0), self.hidden_dim), device=x.device)
        # The initial state is re-drawn at random on every call.
        torch.nn.init.xavier_normal_(h)

        out = torch.tanh(self.fc0(x))
        out = self.dropout(out)
        out, _hidden = self.gru(out, h)
        return out
class TimeEncodeMixer(torch.nn.Module):
    """Fixed cosine time encoding.

    A scalar time ``t`` is mapped to ``cos(t * w_k)`` for ``dim`` frozen
    frequencies ``w_k = 10 ** -e_k``, where ``e_k`` is evenly spaced between
    0 and 9. The bias is zero and nothing here is trainable.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.w = torch.nn.Linear(1, dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Overwrite the linear layer with the frozen frequency table."""
        exponents = np.linspace(0, 9, self.dim, dtype=np.float32)
        frequencies = torch.from_numpy(1 / 10 ** exponents).reshape(self.dim, -1)
        self.w.weight = torch.nn.Parameter(frequencies, requires_grad=False)
        self.w.bias = torch.nn.Parameter(torch.zeros(self.dim), requires_grad=False)

    @torch.no_grad()
    def forward(self, t):
        """Return the encoding of ``t`` with a new trailing axis of size ``dim``."""
        return torch.cos(self.w(t.unsqueeze(-1)))
def __init__(self, path, dataset_name, randomize=True, disperse=False, use_neg=False, inductive=False, split=False):
    """Load the per-timestamp edge map and the per-node feature sequences.

    Args:
        path: Root directory that contains one sub-directory per dataset.
        dataset_name: Dataset to load; used for the directory and the file
            name prefix.
        randomize: Selects the ``_randomize`` variant of the node-sequence
            file.
        disperse: Selects the ``_disperse`` variant of both files.
        use_neg: Stored as ``self.use_neg`` for the node-sequence loader.
        inductive: Selects the ``_inductive`` variant of both files.
        split: Load separate train/valid/test edge maps instead of one.
    """
    self.path = path
    self.dataset_name = dataset_name
    self.split = split
    self.use_neg = use_neg
    # The two loaders below read these flags from the instance, so they
    # have to be set before the loaders run.
    self.disperse = disperse
    self.randomize_features = randomize
    self.inductive = inductive
    self.time_edge_dict = dict()
    self.nodes_seq_lst = []
    self.node_num = 0
    self.get_time_edges(disperse, inductive)
    self.get_nodes_seq_lst(randomize, disperse, inductive)
class EdgeHelper():
    """Load a dataset's pickled snapshot/edge map and per-node feature sequences.

    Every file is looked up under ``<path>/<dataset_name>/``.  After
    construction the instance exposes:

    * ``time_edge_dict`` -- maps a snapshot key to ``{'idx': int, 'edges': ...}``.
    * ``time_lst`` -- (non-split mode only) array whose entry ``idx`` is the
      snapshot time with that index; entry 0 is ``-1``.
    * ``nodes_seq_lst`` -- one feature matrix per node, indexed by node id.
    * ``node_num`` -- ``len(nodes_seq_lst)``.

    Args:
        path: root data directory.
        dataset_name: dataset sub-directory and file-name prefix.
        randomize: select the ``_randomize`` variant of the feature file.
        disperse: select the ``_disperse`` variant of the files.
        use_neg: for CollegeMsg/bitcoinalpha/bitcoinotc, also load the
            ``_alpha-1`` feature file and stack it next to the regular one.
        inductive: select the ``_inductive`` variant of the files.
        split: load separate train/valid/test snapshot maps.
    """

    def __init__(self, path, dataset_name, randomize=True, disperse=False, use_neg=False, inductive=False, split=False):
        self.path = path
        self.dataset_name = dataset_name
        self.split = split
        self.use_neg = use_neg
        # The loaders used to read these two attributes although nothing ever
        # assigned them (AttributeError on construction).  They are stored for
        # any outside reader; the loaders themselves now use their arguments.
        self.disperse = disperse
        self.randomize_features = randomize
        self.time_edge_dict = dict()
        self.nodes_seq_lst = []
        self.node_num = 0
        self.get_time_edges(disperse, inductive)
        self.get_nodes_seq_lst(randomize, disperse, inductive)

    def _dataset_file(self, suffix):
        """Return ``<path>/<dataset_name>/<dataset_name><suffix>``."""
        return os.path.join(self.path, self.dataset_name, f"{self.dataset_name}{suffix}")

    @staticmethod
    def _load_pickle(file_name):
        """Unpickle *file_name*; pickle runs arbitrary code, so only load trusted files."""
        with open(file_name, 'rb') as f:
            return pickle.load(f)

    def get_time_edges(self,disperse=False, inductive=False):
        """Fill ``time_edge_dict`` (and ``time_lst`` in non-split mode) from disk.

        Args:
            disperse: read the ``_disperse`` file variant (non-split mode).
            inductive: read the ``_inductive`` file variant (non-split mode).
        """
        disperse_str = '_disperse' if disperse else ''
        if self.split:
            # Snapshot indices keep counting across the three files; keys are
            # the split name followed by the snapshot time.
            # NOTE(review): time_lst is not built in this mode, so
            # get_node_feats cannot be used with split=True -- confirm intent.
            history = 1
            for ss in ['train', 'valid', 'test']:
                time_edge_dict = self._load_pickle(self._dataset_file(f"_time_edge_map_{ss}.pkl"))
                for idx, time in enumerate(time_edge_dict):
                    self.time_edge_dict[ss+str(time)] = {'idx': idx+history, 'edges': time_edge_dict[time]}
                print('%s has %d time step'%(ss, len(time_edge_dict)))
                history += len(time_edge_dict)
        else:
            variant = '_inductive' if inductive else ''
            time_edge_dict = self._load_pickle(self._dataset_file(f"_time_edge_map{variant}{disperse_str}.pkl"))

            # Index 0 is a sentinel time of -1; real snapshots start at 1.
            self.time_lst = np.zeros(len(time_edge_dict) + 1)
            self.time_lst[0] = -1
            for idx, time in enumerate(time_edge_dict):
                self.time_edge_dict[time] = {'idx': idx+1, 'edges': time_edge_dict[time]}
                self.time_lst[idx+1] = time

    def get_nodes_seq_lst(self, randomize=True, disperse=False, inductive=False):
        """Load the per-node feature sequences into ``nodes_seq_lst``.

        Args:
            randomize: read the ``_randomize`` file variant.
            disperse: read the ``_disperse`` file variant.
            inductive: read the ``_inductive`` file variant (ignored when the
                two-file ``use_neg`` branch is taken).
        """
        rand_str = '_randomize' if randomize else ''
        disperse_str = '_disperse' if disperse else ''
        if self.dataset_name in ['CollegeMsg', 'bitcoinalpha', 'bitcoinotc'] and self.use_neg:
            # Imported here because the module header does not import scipy.
            import scipy.sparse

            aaalpha = self._load_pickle(self._dataset_file(f"_nodes_seq_lst{rand_str}_mul{disperse_str}.pkl")) # 1-alpha
            alphaaa = self._load_pickle(self._dataset_file(f"_nodes_seq_lst{rand_str}_mul{disperse_str}_alpha-1.pkl")) # alpha-1
            self.node_num = len(aaalpha)
            # Per node, place both feature matrices side by side.
            self.nodes_seq_lst = [
                scipy.sparse.hstack((feats, alphaaa[i]), format='csr')
                for i, feats in enumerate(aaalpha)
            ]
        else:
            variant = '_inductive' if inductive else ''
            self.nodes_seq_lst = self._load_pickle(
                self._dataset_file(f"_nodes_seq_lst{rand_str}_mul{variant}{disperse_str}.pkl"))
            self.node_num = len(self.nodes_seq_lst)

    def get_edges_feats(self, sources, destinations, timestamps, window_size=5, concat=True):
        """Build right-aligned, zero-padded feature windows for both endpoints of each edge.

        Args:
            sources: source node id per edge.
            destinations: destination node id per edge.
            timestamps: edge timestamps; each is rounded to 3 decimals and
                must then be a key of ``time_edge_dict``.
            window_size: number of most recent snapshots kept per node.
            concat: concatenate source and destination features on the last axis.

        Returns:
            ``(edges_feats, src_dts, dst_dts)`` when *concat* is true, else
            ``(src_feats, dst_feats, src_dts, dst_dts)``.  Feature tensors are
            ``(E, window_size, F)`` (``2F`` when concatenated) and the delta
            tensors ``(E, window_size)`` hold ``ts - snapshot_time``.

        Raises:
            KeyError: a rounded timestamp is not in ``time_edge_dict``.
        """
        n_edges = len(sources)
        feat_dim = self.nodes_seq_lst[0].shape[1]
        # Four independent buffers.  The former chained assignment
        # (``a = b = torch.zeros(...)``) bound source and destination to the
        # same tensor, so destination data overwrote source data.
        src_feat_lst = torch.zeros((n_edges, window_size, feat_dim))
        dst_feat_lst = torch.zeros((n_edges, window_size, feat_dim))
        src_dts_lst = torch.zeros((n_edges, window_size))
        dts_dts_lst = torch.zeros((n_edges, window_size))
        for i, (src, dst, ts) in enumerate(zip(sources, destinations, timestamps)):
            ts = round(ts, 3)
            ts_id = self.time_edge_dict[ts]['idx']
            for node, feat_out, dts_out in ((src, src_feat_lst, src_dts_lst),
                                            (dst, dst_feat_lst, dts_dts_lst)):
                feat, prev_ts = self.get_node_feats(node, ts_id, window_size)
                if len(feat) == 0:
                    # ``[-0:]`` would address the whole window and fail on the
                    # shape mismatch; a node without history stays all-zero.
                    continue
                feat_out[i][-len(feat):] = feat
                dts_out[i][-len(prev_ts):] = torch.tensor(ts - prev_ts, dtype=torch.float32)
        if concat:
            edges_feats = torch.cat((src_feat_lst, dst_feat_lst), dim=-1)
            return edges_feats, src_dts_lst, dts_dts_lst
        else:
            return src_feat_lst, dst_feat_lst, src_dts_lst, dts_dts_lst

    def get_node_feats(self, node, tsid, window_size):
        """Return the last *window_size* non-empty feature rows of *node* before snapshot *tsid*.

        Args:
            node: index into ``nodes_seq_lst``.
            tsid: snapshot index; only rows with a smaller index are used.
            window_size: maximum number of rows returned.

        Returns:
            ``(features, times)``: a float32 tensor of the selected rows and
            the matching entries of ``time_lst``.
        """
        node_seq = self.nodes_seq_lst[node]
        # Keep rows whose features are not all zero and that precede tsid.
        mask = (np.array(node_seq.sum(-1)) != 0).reshape(-1)
        mask &= np.arange(node_seq.shape[0]) < tsid
        node_feat = node_seq[mask][-window_size:, :]
        node_feat_tensor = torch.tensor(node_feat.toarray(), dtype=torch.float32)
        node_time = self.time_lst[mask][-window_size:]
        return node_feat_tensor, node_time
def get_data_transductive(path, dataset_name, shuffle=False,disperse=True):
    """Load an edge CSV and split it chronologically into train/validation/test.

    Reads ``<path>/<dataset_name>/ml_<dataset_name>[_disperse].csv`` and uses
    its ``u``, ``i``, ``idx`` and ``ts`` columns, plus ``ts_str`` when
    *disperse* is true and ``label`` when present (labels default to ones).

    Split boundaries:

    * bitcoinotc / bitcoinalpha: the unique timestamps at the 70% and 80%
      positions of the sorted unique-timestamp list.
    * CollegeMsg: the 71% position and the 80% position shifted by two.
    * any other dataset: the 0.70 and 0.85 quantiles of all edge timestamps.

    Training edges have ``ts <= val_time``, validation edges
    ``val_time < ts <= test_time`` and test edges ``ts > test_time``.

    Side effects: seeds the global ``random`` generator with 2020 and prints a
    summary of the split sizes.

    Args:
        path: root data directory.
        dataset_name: dataset sub-directory and file-name stem.
        shuffle: shuffle the training edges instead of sorting them by time.
        disperse: read the ``_disperse`` CSV and split on ``ts_str``.

    Returns:
        ``(full_data, train_data, val_data, test_data, new_node_val_data,
        new_node_test_data)``, all ``Data`` objects.  The two ``new_node``
        sets hold the validation/test edges with at least one endpoint that
        never occurs in the training edges.
    """
    # This is a module-level function: the flag comes from the argument.
    # (It used to read ``self.disperse`` and failed with NameError.)
    disperse_str = '_disperse' if disperse else ''
    ### Load data and train val test split
    graph_df = pd.read_csv('{0}/{1}/ml_{1}{2}.csv'.format(path, dataset_name, disperse_str))

    timestamps = graph_df.ts.values
    if disperse:
        timestamps = graph_df.ts_str.values

    if dataset_name in ['bitcoinotc', 'bitcoinalpha', 'CollegeMsg']:
        is_college = dataset_name == 'CollegeMsg'
        splits = [0.71, 0.80] if is_college else [0.70, 0.80]
        ts_lst = np.unique(timestamps)
        train_end = int(len(ts_lst) * splits[0])
        valid_end = int(len(ts_lst) * splits[1])
        if is_college:
            valid_end += 2  # dataset-specific shift of the test boundary
        val_time = ts_lst[train_end]
        test_time = ts_lst[valid_end]
    else:
        val_time, test_time = list(np.quantile(timestamps, [0.70, 0.85]))

    sources = graph_df.u.values
    destinations = graph_df.i.values
    edge_idxs = graph_df.idx.values
    if 'label' in graph_df.columns:
        labels = graph_df.label.values
    else:
        labels = np.ones(sources.shape)

    # Data reorders the arrays it is given in place, so from here on all five
    # arrays share the same (time-sorted) order.
    full_data = Data(sources, destinations, timestamps, edge_idxs, labels)

    random.seed(2020)
    node_set = set(sources) | set(destinations)

    train_mask = timestamps <= val_time  # transductive
    val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)
    test_mask = timestamps > test_time

    def subset(mask, **kwargs):
        """Build a Data object from the edges selected by *mask*."""
        return Data(sources[mask], destinations[mask], timestamps[mask],
                    edge_idxs[mask], labels[mask], **kwargs)

    train_data = subset(train_mask, shuffle=shuffle)

    # define the new nodes sets for testing inductiveness of the model
    train_node_set = set(train_data.sources).union(train_data.destinations)
    new_node_set = node_set - train_node_set

    edge_contains_new_node_mask = np.array(
        [(a in new_node_set or b in new_node_set) for a, b in zip(sources, destinations)])
    new_node_val_mask = np.logical_and(val_mask, edge_contains_new_node_mask)
    new_node_test_mask = np.logical_and(test_mask, edge_contains_new_node_mask)

    # validation and test with all edges
    val_data = subset(val_mask)
    test_data = subset(test_mask)

    # validation and test with edges that at least has one new node (not in training set)
    new_node_val_data = subset(new_node_val_mask)
    new_node_test_data = subset(new_node_test_mask)

    print("--------- Get {} data: Transductive ---------".format(dataset_name))
    print("The dataset has {} interactions, involving {} different nodes".format(full_data.n_interactions, full_data.n_unique_nodes))
    print("The training dataset has {} interactions, involving {} different nodes".format(train_data.n_interactions, train_data.n_unique_nodes))
    print("The validation dataset has {} interactions, involving {} different nodes".format(val_data.n_interactions, val_data.n_unique_nodes))
    print("The test dataset has {} interactions, involving {} different nodes".format(test_data.n_interactions, test_data.n_unique_nodes))
    return full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data
Path +import argparse +import pickle +import pdb +import time +import os +import subprocess + +from StreamLearn.Dataset.StreamDataset import StreamDataset + +def reindex(df): + new_df = df.copy() + + node_list = list(new_df.u.unique())+list(new_df.i.unique()) + node_list = list(set(node_list)) + node_list.sort() + if max(new_df.u.max(), new_df.i.max()) - 0 + 1 == len(node_list): ## The ids of u and i are consecutive and start from 0. + return new_df + min_id = min(new_df.u.min(),new_df.i.min()) + if min_id == 1 and max(new_df.u.max(), new_df.i.max()) - min_id + 1 == len(node_list): ## The ids of u and i are consecutive and start from 1. + new_df.u -= 1 + new_df.i -= 1 + return new_df + + node_id_map = {} + for idx, node in enumerate(node_list): + node_id_map[node] = idx + + from_list, to_list = [], [] + for u, i in zip(df.u, df.i): + from_id = node_id_map[u] + to_id = node_id_map[i] + from_list.append(from_id) + to_list.append(to_id) + new_df.u = from_list + new_df.i = to_list + + return new_df + +def download_and_process_files(data_name, root="../data/"): + # Ensure trailing slash for directory path + if not root.endswith('/'): + root += '/' + + # Define command lists for each dataset with root path included + commands_alpha = [ + 'wget "https://snap.stanford.edu/data/soc-sign-bitcoinalpha.csv.gz"', + 'gzip -d soc-sign-bitcoinalpha.csv.gz', + f'mv soc-sign-bitcoinalpha.csv {root}bitcoinalpha.csv' + ] + + commands_otc = [ + 'wget "https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz"', + 'gzip -d soc-sign-bitcoinotc.csv.gz', + f'mv soc-sign-bitcoinotc.csv {root}bitcoinotc.csv' + ] + + commands_college = [ + 'wget "http://snap.stanford.edu/data/CollegeMsg.txt.gz"', + 'gzip -d CollegeMsg.txt.gz', + f'mv CollegeMsg.txt {root}CollegeMsg.txt' + ] + + # Execute commands based on data_name + if data_name == 'bitcoinalpha': + commands = commands_alpha + elif data_name == 'bitcoinotc': + commands = commands_otc + elif data_name == 'CollegeMsg': + commands = 
def disperse_dataset(data, graph_df, snapshot_freq='S'):
    """Return a copy of *graph_df* with a ``ts_str`` snapshot-id column inserted at position 3.

    The snapshot id is the edge timestamp floor-divided by a fixed,
    dataset-specific window: 190080 for ``CollegeMsg`` and 1200000 for
    ``bitcoinotc`` / ``bitcoinalpha``.  The window and the number of distinct
    snapshots are printed.

    Args:
        data: dataset name; selects the window width.
        graph_df: edge table with a numeric ``ts`` column and at least three
            columns (required by the fixed insert position).
        snapshot_freq: only ``'S'`` is implemented.

    Returns:
        A new DataFrame; *graph_df* is not modified.

    Raises:
        ValueError: *snapshot_freq* is not ``'S'`` or *data* has no window
            defined.  (Both cases used to fail later with an unbound
            ``ts_str``.)
    """
    if snapshot_freq not in ['S']:
        raise ValueError(
            "Unsupported snapshot_freq %r: only 'S' is implemented." % (snapshot_freq,))
    if data == 'CollegeMsg':
        window = 190080
    elif data in ['bitcoinotc', 'bitcoinalpha']:
        window = 1200000
    else:
        raise ValueError(
            "No snapshot window defined for dataset %r. Please choose from "
            "'bitcoinalpha', 'bitcoinotc', or 'CollegeMsg'." % (data,))

    new_df = graph_df.copy()
    timestamps = graph_df.ts.values
    ts_str = timestamps // window
    snap_num = len(set(ts_str))
    print(f'snapshot_freq is {window}s, snap_num = ', snap_num)
    new_df.insert(3, 'ts_str', list(ts_str))
    return new_df
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with *seed*."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
def parameter_parser():
    """Parse the command-line options of the DecoupledDGNN test script.

    Reads ``sys.argv`` through ``argparse`` and returns the resulting
    namespace.  Besides the model/training options this defines ``--rmax``,
    which the entry script forwards to ``GraphEmbs``, and accepts
    ``--checkpoint`` as an alias of ``--checkpt_path``.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser(description="Tweet Classification")

    parser.add_argument("--data", default='bitcoinotc', help="Dataset name.")
    parser.add_argument("--epochs", type=int, default=50, help="Number of gradient descent iterations. Default is 50.")
    parser.add_argument("--learning_rate", type=float, default=0.01, help="Gradient descent learning rate. Default is 0.01.")
    parser.add_argument("--hidden_dim", type=int, default=128, help="Number of neurons by hidden layer. Default is 128.")
    parser.add_argument("--lstm_layers", type=int, default=2, help="Number of LSTM layers.")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout rate. Default is 0.1.")
    parser.add_argument("--max_len", type=int, default=20, help="Maximum sequence length per tweet.")
    parser.add_argument("--seq_model", default='lstm', help="Sequence model type.")
    parser.add_argument("--window_size", type=int, default=20, help="Window size.")
    parser.add_argument("--emb_size", type=int, default=256, help="Embedding size.")
    parser.add_argument("--seed", type=int, default=1024, help="Random seed.")
    parser.add_argument("--gpu", type=int, default=0, help="ID number of GPU.")
    parser.add_argument("--patience", type=int, default=3, help="Max number of bad count.")
    # The first long option fixes dest to ``checkpt_path``; ``--checkpoint``
    # is the spelling used in the README invocation.
    parser.add_argument("--checkpt_path", "--checkpoint", default='./check_point', help="Checkpoint path.")
    parser.add_argument('--inductive', action='store_true', help='Whether to use inductive setting.')
    parser.add_argument('--shuffle', action='store_true', help='Whether to shuffle training data.')
    parser.add_argument('--use_neg', action='store_true', help='Whether to use negative features.')
    parser.add_argument("--alpha", type=float, default=0.98, help="Parameter for FocalLoss - alpha.")
    # The entry script reads ``args.rmax``, which had no flag and raised
    # AttributeError.  NOTE(review): the default is a placeholder -- confirm
    # the value GraphEmbs expects.
    parser.add_argument("--rmax", type=float, default=1e-7, help="rmax value passed to GraphEmbs.")

    return parser.parse_args()
# Script entry point: prepare the dataset, generate the node embeddings, then
# train the model.
if __name__ == "__main__":
    args = parameter_parser()
    # NOTE(review): --seed is parsed and set_seed() is defined in this file,
    # but the seed is never applied here -- confirm whether DGNN seeds itself.

    dataset = DTDGsDataset(args.data)
    # NOTE(review): args.rmax must be supplied by the argument parser.
    gen_embs = GraphEmbs(dataset.path, args.data, args.rmax, args.alpha)
    # The bare attribute reference ``gen_embs.load_and_process_data`` was a
    # no-op expression; the step has to be called to run.
    # NOTE(review): assumes load_and_process_data is a method, not a property.
    gen_embs.load_and_process_data()

    execute = DGNN(args)
    execute.train()