# LPFormer **Repository Path**: yao9e/lpformer ## Basic Information - **Project Name**: LPFormer - **Description**: No description available - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2024-07-31 - **Last Updated**: 2024-07-31 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # LPFormer Official Implementation of the KDD'24 paper - "LPFormer: An Adaptive Graph Transformer for Link Prediction" ![Framework](https://raw.githubusercontent.com/HarryShomer/LPFormer/master/LPFormer-Framework.png) ## Abstract Link prediction is a common task on graph-structured data that has seen applications in a variety of domains. Classically, hand-crafted heuristics were used for this task. Heuristic measures are chosen such that they correlate well with the underlying factors related to link formation. In recent years, a new class of methods has emerged that combines the advantages of message-passing neural networks (MPNN) and heuristics methods. These methods perform predictions by using the output of an MPNN in conjunction with a "pairwise encoding" that captures the relationship between nodes in the candidate link. They have been shown to achieve strong performance on numerous datasets. However, current pairwise encodings often contain a strong inductive bias, using the same underlying factors to classify all links. This limits the ability of existing methods to learn how to properly classify a variety of different links that may form from different factors. To address this limitation, we propose a new method, LPFormer, which attempts to adaptively learn the pairwise encodings for each link. LPFormer models the link factors via an attention module that learns the pairwise encoding that exists between nodes by modeling multiple factors integral to link prediction. Extensive experiments demonstrate that LPFormer can achieve SOTA performance on numerous datasets while maintaining efficiency. ## Requirements All experiments were run using python 3.9.13. The required python packages can be installed via the `requirements.txt` file. ``` pip install -r requirements.txt ``` ## Data The data for Cora, Citeseer, and Pubmed can be downloaded from [here](https://github.com/Juanhui28/HeaRT#download-data). The data should correspondingly be placed in a directory called `dataset` in the root project directory. The data for the OGB datasets are downloaded automatically from the `ogb` package. ## Reproduce Results ### Reproduce the Paper Results The commands for reproducing the results on the existing setting in the paper are in the `scripts/replicate_existing.sh` file. For the HeaRT setting, they are in `scripts/replicate_heart.sh`. Please note that for the ogbl-citation2 and ogbl-ddi, over 32GB of GPU memory is required to train the model. ### Running Yourself 1. To add a new dataset, you'll need to add a custom function in `src/util/read_datasets.py`. Then add the option to call that function in `run_model` in `src/run.py`. 2. When computing the PPR matrix, the parameter `--eps` controls the approximation accuracy. If you'd like a better (or worse) approximation of the PPR scores, please adjust `--eps` accordingly. Please note that for larger datasets, a very lower epsilon may take a very long time to run and will result in a large file saved to the disk. 3. The list of hyperparameters can be found by looking at one of the sample commands in either `scripts/replicate_existing.sh` or `scripts/replicate_heart.sh`. ## Cite ``` @article{shomer2024adaptive, title={LPFormer: An Adaptive Graph Transformer for Link Prediction}, author={Harry Shomer and Yao Ma and Haitao Mao and Juanhui Li and Bo Wu and Jiliang Tang}, booktitle={Proceedings of the 30th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}, year={2024} } ``` 这段代码定义了一个 `sample` 方法,它是 `LocalSampler` 类或 `LocalSamplerNew` 类中的一部分。让我们逐行分析它的作用和功能: 1. **参数解析**: - `batch`:传入的参数,可以是一个张量 (`Tensor`) 或一个列表/数组。如果不是张量,则将其转换为张量。 ```python if not isinstance(batch, Tensor): batch = torch.tensor(batch) ``` 2. **初始化变量**: - `batch_size`:计算 `batch` 的长度,即批次的大小。 ```python batch_size: int = len(batch) ``` 3. **循环采样**: - `edge_index` 和 `edge_dist` 初始化为空列表。在循环中,对每个 `batch` 中的节点进行单独的采样操作,调用 `sample_one` 方法。 ```python edge_index, edge_dist = [], [] for i in range(len(batch)): out = self.sample_one(batch[i:i+1]) edge_index.append(out[0]) edge_dist.append(out[1]) ``` 4. **拼接结果**: - 将所有节点的采样结果拼接成一个张量 `edge_index` 和 `edge_dist`。这里假设 `sample_one` 方法返回的是两个张量,分别是边的索引和边的距离信息。 ```python edge_index = torch.cat(edge_index, dim=1) edge_dist = torch.cat(edge_dist, dim=1) ``` 5. **处理节点索引**: - `node_idx` 获取所有采样结果中的源节点,并确保包括目标节点。 - `node_idx_flag` 用于检查哪些节点不在 `batch` 中。 - 更新 `node_idx`,将 `batch` 和不在 `batch` 中的节点合并。 ```python node_idx = torch.unique(edge_index[0]) # 获取所有源节点,包括目标节点 node_idx_flag = torch.tensor([i not in batch for i in node_idx]) # 检查哪些节点不在 batch 中 node_idx = node_idx[node_idx_flag] # 筛选出不在 batch 中的节点 node_idx = torch.cat([batch, node_idx]) # 将 batch 和不在 batch 中的节点合并 ``` 6. **重新标记节点索引**: - 创建一个全零张量 `node_idx_all`,用于存储重新标记后的节点索引。 - 使用 `torch.arange(node_idx.size(0))` 将重新标记后的节点索引赋值给 `node_idx_all`。 ```python node_idx_all = torch.zeros(self.num_nodes, dtype=torch.long) # 创建全零张量,大小为 num_nodes node_idx_all[node_idx] = torch.arange(node_idx.size(0)) # 将重新标记后的节点索引赋值给 node_idx_all edge_index = node_idx_all[edge_index] # 根据重新标记后的节点索引更新 edge_index ``` 7. **返回结果**: - 返回经过处理后的 `edge_index` 和 `edge_dist`,以及更新后的 `node_idx` 和 `batch_size`。 ```python return torch.cat([edge_index, edge_dist], dim=0), node_idx, batch_size ``` ### 总结: 这段代码的主要作用是从输入的 `batch` 中每个节点进行单独的采样操作,获取边的索引和边的距离信息,并对节点索引进行重新标记,最终返回处理后的结果。这在图数据处理中非常常见,特别是在使用 PyTorch Geometric 进行图神经网络训练时,需要对图数据进行采样和处理。