TeMP: Temporal Message Passing Network for Temporal Knowledge Graph Completion

PyTorch implementation of TeMP: Temporal Message Passing Network for Temporal Knowledge Graph Completion (EMNLP 2020)

Update: you can now download the trained models here. The hyperparameter configuration of each model is indicated by its folder name and detailed in the config.json inside each checkpoint folder.

Installation

First create a conda virtual environment (choose any name for <your_env_name>):

conda create --name <your_env_name> python=3.6.10
conda activate <your_env_name>

Assuming you are using CUDA 10.1, install the packages as follows:

conda install pytorch=1.3.0 cudatoolkit=10.1 -c pytorch
conda install -c dglteam dgl-cuda10.1==0.4.1
python -m pip install -U matplotlib
pip install -r requirements.txt

Training a model

The config files are stored in the grid folder. The structure of the folder looks like this:

grid
├── icews14
├── icews15
└── gdelt      

Each subfolder contains the following config files:

icews14
├── config_bigrrgcn.json # bidirectional GRU + RGCN
├── config_bisargcn.json # bidirectional Transformer + RGCN
├── config_grrgcn.json   # one-directional GRU + RGCN
├── config_sargcn.json   # one-directional Transformer + RGCN
└── config_srgcn.json    # RGCN only

The following command trains the bidirectional Transformer + RGCN model (BiSARGCN) with frequency-based gating. The config file given after -c provides a set of parameters that overrides the defaults.

python -u main.py -c configs/grid/icews15/config_bisargcn.json --rec-only-last-layer --use-time-embedding --post-ensemble
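The override behaviour of -c can be sketched as follows. This is a minimal illustration, not the actual merging code in main.py; the helper name `load_config` and the exact default keys are assumptions:

```python
import json

def load_config(path, defaults):
    """Merge a JSON config file over a dict of default parameters.

    Keys present in the file overwrite the defaults; any default not
    mentioned in the file is kept unchanged. (Hypothetical helper;
    the actual logic in main.py may differ.)
    """
    with open(path) as f:
        overrides = json.load(f)
    merged = dict(defaults)
    merged.update(overrides)
    return merged
```

So a config that only sets a few fields still inherits every other default, which is why the grid folder can contain small, model-specific JSON files.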

--n-gpu: indices of the GPUs to use, e.g. --n-gpu 0 1 2 to use the GPUs indexed 0, 1 and 2.

--module: model architecture:

  • baselines: Static for static KG embedding, SRGCN for static RGCN; DE, Hyte.
  • GRRGCN: GRU + RGCN; BiGRRGCN: BiGRU + RGCN
  • SARGCN: Transformer + RGCN; BiSARGCN: BiTransformer + RGCN

--dataset or -d: name of the dataset, icews14, icews05-15 or gdelt

--config: name of the config file.

--score-function: decoding function. Choose among TransE, distmult and complex. Default: complex

--negative-rate: number of negative samples per training instance. Note that we sample this number of negative entities for both the subject and the object.
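The two-sided sampling described above can be sketched as follows. The helper name `corrupt` is hypothetical and, unlike a production sampler, this sketch does not filter out corrupted triples that happen to be true facts:

```python
import random

def corrupt(triple, num_entities, negative_rate, rng=random):
    """Generate negatives for one (s, r, o) fact.

    Draws `negative_rate` random entities to replace the object and
    another `negative_rate` to replace the subject, so each positive
    fact yields 2 * negative_rate negatives. (Illustrative sketch
    only; TeMP's actual sampler may differ.)
    """
    s, r, o = triple
    negatives = []
    for _ in range(negative_rate):
        negatives.append((s, r, rng.randrange(num_entities)))  # corrupt object
    for _ in range(negative_rate):
        negatives.append((rng.randrange(num_entities), r, o))  # corrupt subject
    return negatives
```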

--max-nb-epochs: maximum number of training epochs

--patience: stop training once this number of epochs has passed without the model improving its best performance on the validation set
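The patience criterion amounts to standard early stopping. A minimal sketch (the helper name `should_stop` is an assumption; the repo likely relies on its training framework's built-in early stopping):

```python
def should_stop(val_metrics, patience):
    """Return True once the best validation score has not improved
    for `patience` consecutive epochs.

    `val_metrics` is the list of per-epoch validation scores,
    higher is better. (Illustrative sketch of the --patience flag.)
    """
    if not val_metrics:
        return False
    best_epoch = max(range(len(val_metrics)), key=val_metrics.__getitem__)
    epochs_since_best = len(val_metrics) - 1 - best_epoch
    return epochs_since_best >= patience
```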

--n_bases: number of blocks in each block-diagonal relation matrix. Used for RGCN representation

--num_pos_facts: number of sampled facts to construct the training graph at each time step

--train-seq-len: number of time steps preceding each time step t from which historical facts are sampled. A single-directional model uses this number of snapshots preceding the current time step; a bidirectional model (BiGRRGCN or BiSARGCN) uses this number of time steps both before and after the current time step.

--test-seq-len: same as --train-seq-len, except that it is used at test time.
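The window selection described by these two flags can be sketched as follows (the helper name `history_window` is hypothetical; the actual snapshot selection in the repo may clamp or sample differently):

```python
def history_window(t, seq_len, bidirectional, num_timestamps):
    """Return the time steps from which facts are drawn for target
    step t, per --train-seq-len / --test-seq-len.

    Unidirectional models look at the seq_len steps before t;
    bidirectional models (BiGRRGCN / BiSARGCN) additionally look at
    the seq_len steps after t. Windows are clipped to the valid
    range [0, num_timestamps). (Illustrative sketch.)
    """
    before = list(range(max(0, t - seq_len), t))
    if not bidirectional:
        return before
    after = list(range(t + 1, min(num_timestamps, t + seq_len + 1)))
    return before + after
```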

Flag arguments:

--post-ensemble: use frequency based gating (see paper)

--impute: use imputation (see paper)

--learnable-lambda: learn the temperature lambda as a learnable parameter, as described in the paper

--rec-only-last-layer: use recurrence only in the last RGCN layer. We find this to be the most effective setting and hence use it in the paper.

--random-dropout: randomly drop half of the edges in each historical and/or future time step

--debug: train on only 0.1 percent of the data, as a sanity check

--fast_dev_run: run one quick pass over the full pipeline to surface bugs

--type1: use our own implementation of the type-1 GRU cell as defined on the Wikipedia page

Testing and analysis

To test a model on the corresponding test set, run the following:

python -u test.py --checkpoint-path ${path-to-your-model-checkpoint}

To perform various link prediction analysis:

python link_prediction_analysis.py --checkpoint-path ${path-to-your-model-checkpoint}

To get the prediction of the TED classifier, run python greedy_classifier.py with the desired parameters.

Trained model checkpoints and computation specs

You can directly download the trained models here and use them for inference. If you would like to perform training on a cluster, please refer to the following bash scripts for the specified computation resource requirements:

launcher_baseline.sh: script for embedding models and SRGCN models

launcher_14.sh: script for all models on the ICEWS14 dataset

launcher_15.sh: script for all models on the ICEWS05-15 dataset (except SARGCN models)

launcher_15_sargcn.sh: script for SARGCN models on the ICEWS05-15 dataset

launcher_2gpu.sh: script for all models on the GDELT dataset
