1 Star 2 Fork 0

solaris / HRNR

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train_label_pred.py 4.64 KB
一键复制 编辑 原始数据 按行查看 历史
solaris 提交于 2020-09-03 17:07 . first commit
import pickle
from conf import beijing_label_hparams
from utils import *
from label_model import *
import torch
from torch import optim
import numpy as np
import random
import os
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
import argparse
def test_label_pred(model, test_set, test_label, hparams):
    """Evaluate a binary label-prediction model on a held-out set.

    The forward pass runs without gradient tracking and with the model
    temporarily switched to eval mode; the caller's train/eval mode is restored
    afterwards. Prints AUC, then accuracy / precision / recall / F1 for the
    positive class (label 1).

    Args:
        model: torch module mapping a LongTensor of ids to per-class logits.
            Column 1 of the softmaxed output is used as the positive-class
            score (assumes exactly 2 classes -- TODO confirm).
        test_set: inputs convertible to a LongTensor via ``torch.tensor``.
        test_label: 0/1 ground-truth labels aligned with ``test_set``.
        hparams: object exposing a ``device`` attribute.

    Returns:
        Tuple ``(precision, recall, f1, auc)``. When no positive sample is
        predicted correctly, precision/recall/F1 are 0.0, but the computed AUC
        is still returned (previously it was discarded as 0.0).
    """
    test_tensor = torch.tensor(test_set, dtype=torch.long, device=hparams.device)

    # Evaluation should neither build an autograd graph nor run dropout /
    # batch-norm in training mode; restore the caller's mode afterwards so the
    # training loop is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred = model(test_tensor)
    finally:
        model.train(was_training)

    pred_prob = F.softmax(pred, -1)
    pred_scores = pred_prob[:, 1]
    auc = roc_auc_score(np.array(test_label), pred_scores.cpu().numpy())
    print("auc:", auc)

    pred_loc = torch.argmax(pred, 1).tolist()
    right = 0       # correct predictions (any class)
    right_pos = 0   # true positives
    wrong_pos = 0   # positives predicted as negative (false negatives)
    wrong_neg = 0   # negatives predicted as positive (false positives)
    sum_num = 0
    for predicted, actual in zip(pred_loc, test_label):
        sum_num += 1
        if predicted == actual:
            right += 1
            if actual == 1:
                right_pos += 1
        elif actual == 1:
            wrong_pos += 1
        else:
            wrong_neg += 1

    # max(..., 1) avoids division by zero when there are no positives or no
    # positive predictions; the numerator is then 0 anyway.
    recall = float(right_pos) / max(right_pos + wrong_pos, 1)
    precision = float(right_pos) / max(right_pos + wrong_neg, 1)
    if recall == 0 or precision == 0:
        print("p/r/f:0/0/0")
        # Keep the AUC: it is meaningful even when F1 is zero.
        return 0.0, 0.0, 0.0, auc
    f1 = 2 * recall * precision / (precision + recall)
    print("label prediction @acc @p/r/f:", float(right) / sum_num, precision, recall, f1)
    return precision, recall, f1, auc
def train_label_pred(gpu_id):
    """Train the label-prediction model and report held-out metrics.

    Loads graph and label data via project helpers and builds a
    ``LabelPredModel`` on the selected device. It then runs
    ``hparams.label_epoch`` iterations. Each iteration draws data with
    ``get_label_train_data``, does one optimisation step (cross-entropy loss,
    gradient clipping at ``hparams.lp_clip``, Adam), evaluates with
    ``test_label_pred`` and prints the best AUC / F1 seen so far.

    Args:
        gpu_id: value written (via ``str``) to ``CUDA_VISIBLE_DEVICES`` before
            the device is selected. CUDA is used if available, else CPU.
    """
    hparams = dict_to_object(beijing_label_hparams)
    # Only effective if CUDA has not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    hparams.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    adj, features, struct_assign, fnc_assign = load_label_pred_data(hparams)
    ce_criterion = torch.nn.CrossEntropyLoss()

    # `adj` exposes .row/.col/.data/.shape (presumably a scipy COO matrix --
    # confirm against load_label_pred_data); indices must be shaped (2, nnz).
    adj_indices = torch.tensor(np.vstack((adj.row, adj.col)), dtype=torch.long)
    adj_values = torch.tensor(adj.data, dtype=torch.float)
    # torch.sparse.FloatTensor is a deprecated legacy constructor.
    adj_tensor = torch.sparse_coo_tensor(adj_indices, adj_values, adj.shape)

    # np.int was removed in NumPy 1.24; it was an alias of the builtin int.
    features = features.astype(int)
    lane_feature, type_feature, length_feature, node_feature = [
        torch.tensor(features[:, col], dtype=torch.long, device=hparams.device)
        for col in range(4)
    ]
    struct_assign = torch.tensor(struct_assign, dtype=torch.float, device=hparams.device)
    fnc_assign = torch.tensor(fnc_assign, dtype=torch.float, device=hparams.device)
    lp_model = LabelPredModel(
        hparams, lane_feature, type_feature, length_feature, node_feature,
        adj_tensor, struct_assign, fnc_assign,
    ).to(hparams.device)
    model_optimizer = optim.Adam(lp_model.parameters(), lr=hparams.lp_learning_rate)

    max_f1 = 0
    max_auc = 0
    for i in range(hparams.label_epoch):
        print("epoch", i)
        # NOTE(review): count is reset every epoch and incremented only once at
        # the end, so `count % 20 == 0` always holds and evaluation runs every
        # epoch. Kept as-is to preserve the logged output.
        count = 0
        train_set, train_label, test_set, test_label = get_label_train_data(hparams)
        model_optimizer.zero_grad()
        train_set = torch.tensor(train_set, dtype=torch.long, device=hparams.device)
        train_label = torch.tensor(train_label, dtype=torch.long, device=hparams.device)
        pred = lp_model(train_set)
        print(pred.shape, train_label.shape, train_set.shape)
        loss = ce_criterion(pred, train_label)
        # NOTE(review): retain_graph=True is kept from the original; it may be
        # needed if LabelPredModel reuses graph parts built once outside
        # forward -- TODO confirm, otherwise it only costs memory.
        loss.backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_(lp_model.parameters(), hparams.lp_clip)
        model_optimizer.step()
        if count % 20 == 0:
            _, _, f1, auc = test_label_pred(lp_model, test_set, test_label, hparams)
            if auc > max_auc:
                max_auc = auc
            if f1 > max_f1:
                max_f1 = f1
            print("max_auc:", max_auc)
            print("max_f1:", max_f1)
            print("step ", str(count))
            print(loss.item())
            # torch.save(lp_model.state_dict(), "/data/wuning/RN-GNN/beijing/model/label_pred.model_" + str(i))
        count += 1
def setup_seed(seed):
    """Make runs reproducible.

    Seeds the PyTorch (CPU and every CUDA device), NumPy and Python ``random``
    generators with ``seed`` and requests deterministic cuDNN kernels.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
    # Seed all RNGs before anything else so the run is reproducible.
    setup_seed(1)
    cli = argparse.ArgumentParser()
    cli.add_argument("--gpu_id", type=str, required=True, default=None, help="gpu id")
    options = cli.parse_args()
    # Train the label-prediction model on the requested GPU.
    train_label_pred(options.gpu_id)
1
https://gitee.com/solaris_wn/HRNR.git
git@gitee.com:solaris_wn/HRNR.git
solaris_wn
HRNR
HRNR
master

搜索帮助