import pickle
from conf import beijing_des_hparams
from utils import *
from des_model import *
import torch
from torch import optim
import numpy as np
import random
import os
import torch.nn.functional as F
def test_des_pred(model, test_loc_set, hparams):
right = 0
sum_num = 0
count = 0
right_5 = 0
for batch in test_loc_set:
batch = np.array(batch)
if len(batch[0]) > 10 or len(batch[0]) < 6:
continue
input_tra = batch[:, :5]
label = batch[:, -1]
if count == 0:
count += 1
input_tra = torch.tensor(input_tra, dtype=torch.long, device = hparams.device)
label_tensor = torch.tensor(label, dtype=torch.long, device = hparams.device)
pred = model(input_tra)
pred = pred.view(pred.shape[1], pred.shape[0], pred.shape[2])
pred_loc = torch.argmax(pred, 2).tolist()
pred_loc = np.array(pred_loc)[:, -1]
pred_5 = (-np.array(pred.tolist())).argsort()[:, :, :5]
for item1, item2, item3 in zip(pred_loc.tolist(), label.tolist(), pred_5[:, -1, :].tolist()):
if item1 == item2:
right += 1
if item2 in item3:
right_5 += 1
sum_num += 1
print("des prediction @acc:", float(right)/sum_num)
print("des prediction @acc5", float(right_5)/sum_num)
def train_des_pred():
    """Train LocPredModel for destination prediction and periodically evaluate it.

    Loads graph/feature tensors plus the pickled destination train/test sets,
    then runs `hparams.gae_epoch` epochs of cross-entropy training, reporting
    test accuracy every 200 steps.
    """
    hparams = dict_to_object(beijing_des_hparams)
    # hparams.device initially holds a GPU index; pin that GPU, then replace
    # the field with the torch.device used for every tensor below.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(hparams.device)
    hparams.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    adj, features, struct_assign, fnc_assign, train_loc_set = load_loc_pred_data(hparams)
    ce_criterion = torch.nn.CrossEntropyLoss()
    # NOTE(review): train_loc_set returned above is immediately replaced by the
    # pickled destination sets -- only adj/features/assignments are kept from it.
    with open("data/des_train_set", "rb") as f:  # close files instead of leaking handles
        train_loc_set = pickle.load(f)
    with open("data/des_test_set", "rb") as f:
        test_loc_set = pickle.load(f)
    # Build a sparse COO adjacency tensor from the scipy COO matrix.
    adj_indices = torch.tensor(
        np.concatenate([adj.row[:, np.newaxis], adj.col[:, np.newaxis]], 1),
        dtype=torch.long).t()
    adj_values = torch.tensor(adj.data, dtype=torch.float)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor.
    adj_tensor = torch.sparse_coo_tensor(adj_indices, adj_values, adj.shape)
    # np.int was removed in NumPy 1.24; int64 matches the torch.long columns below.
    features = features.astype(np.int64)
    lane_feature = torch.tensor(features[:, 0], dtype=torch.long, device=hparams.device)
    type_feature = torch.tensor(features[:, 1], dtype=torch.long, device=hparams.device)
    length_feature = torch.tensor(features[:, 2], dtype=torch.long, device=hparams.device)
    node_feature = torch.tensor(features[:, 3], dtype=torch.long, device=hparams.device)
    struct_assign = torch.tensor(struct_assign, dtype=torch.float, device=hparams.device)
    fnc_assign = torch.tensor(fnc_assign, dtype=torch.float, device=hparams.device)
    lp_model = LocPredModel(hparams, lane_feature, type_feature, length_feature,
                            node_feature, adj_tensor, struct_assign, fnc_assign).to(hparams.device)
    model_optimizer = optim.Adam(lp_model.parameters(), lr=hparams.lp_learning_rate)
    for i in range(hparams.gae_epoch):
        print("epoch", i)
        count = 0
        for batch in train_loc_set:
            model_optimizer.zero_grad()
            # Train only on trajectories of length 6..12 (evaluation uses 6..10).
            if len(batch[0]) > 12 or len(batch[0]) < 6:
                continue
            batch = np.array(batch)  # convert once instead of twice per step
            input_tra = torch.tensor(batch[:, :5], dtype=torch.long, device=hparams.device)
            pred_label = torch.tensor(batch[:, -1], dtype=torch.long, device=hparams.device)
            pred = lp_model(input_tra)
            # NOTE(review): .view() reinterprets memory rather than transposing;
            # this matches test_des_pred but assumes a particular output layout
            # from the model -- confirm against LocPredModel.
            pred = pred.view(pred.shape[1], pred.shape[0], pred.shape[2])
            pred = pred[:, -1, :]  # scores at the final step = destination logits
            loss = ce_criterion(pred.view(-1, hparams.node_num), pred_label.view(-1))
            # retain_graph=True kept from the original -- presumably parts of the
            # graph are shared across steps; verify before removing.
            loss.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(lp_model.parameters(), hparams.lp_clip)
            model_optimizer.step()
            if count % 200 == 0:
                test_des_pred(lp_model, test_loc_set, hparams)
                print("step ", str(count))
                print(loss.item())
            count += 1
def setup_seed(seed):
    """Seed every RNG source (python, numpy, torch CPU/GPU) for reproducibility
    and force deterministic cuDNN kernels."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
if __name__ == '__main__':
    setup_seed(42)  # fix all RNG seeds so runs are reproducible
    train_des_pred() # three stage model for des prediction