335 Star 1.5K Fork 858

MindSpore / docs

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
spatio_temporal_graph_training_STGCN.md 6.70 KB
一键复制 编辑 原始数据 按行查看 历史
TingWang 提交于 2023-09-18 10:08 . update link logo

时空图训练网络

查看源文件   

概述

在本例中将展示如何利用时空图卷积网络进行交通预测。

时空图卷积网络(STGCN)可以解决交通域的时间序列预测问题。实验表明,STGCN通过建模多尺度交通网络,有效地捕获了综合的时空相关性。

METR-LA是一个大规模数据集,从洛杉矶乡村公路网的1500个交通环路探测器收集。此数据集包括速度、道路容量和占用数据,覆盖约3,420英里。将路网构建成图,输入到STGCN网络中,根据历史数据来预测下个时间段的路网信息。

一般图的节点特征形状为(节点数量, 特征维度),时空图中输入的特征形状通常至少为三维(节点数量, 特征维度, 时间步),邻居节点的特征融合处理会更加复杂。并且由于时间维度上进行卷积,时间步也会发生变化,计算loss时,需要提前计算好输出时间长度。

下载完整的样例STGCN代码。

STGCN原理

论文链接: A deep learning framework for traffic forecasting

图拉普拉斯归一化

将图的自环删除,对图进行归一化,得到新的边索引与边权重。 mindspore_gl.graph提供norm的API可以被用于拉普拉斯归一化。边缘索引和边缘权重归一化的代码如下所示:

mask = edge_index[0] != edge_index[1]
edge_index = edge_index[:, mask]
edge_attr = edge_attr[mask]

edge_index = ms.Tensor(edge_index, ms.int32)
edge_attr = ms.Tensor(edge_attr, ms.float32)
edge_index, edge_weight = norm(edge_index, node_num, edge_attr, args.normalization)

关于拉普拉斯归一化的更多细节,可以看mindspore_gl.graph.norm的API.

定义网络结构

mindspore_gl.nn提供了STConv的API可以直接调用。与一般的图卷积层不同,STConv的输入特征为四维,即(批次内图数量, 时间步, 节点数量, 特征维度)。输出特征的时间步需要根据1D卷积核尺寸、卷积次数来计算。

使用STConv实现一个两层的STGCN网络代码如下:

class STGcnNet(GNNCell):
    """ STGCN Net """
    def __init__(self,
                 num_nodes: int,
                 in_channels: int,
                 hidden_channels_1st: int,
                 out_channels_1st: int,
                 hidden_channels_2nd: int,
                 out_channels_2nd: int,
                 out_channels: int,
                 kernel_size: int,
                 k: int,
                 bias: bool = True):
        super().__init__()
        self.layer0 = STConv(num_nodes, in_channels,
                             hidden_channels_1st,
                             out_channels_1st,
                             kernel_size,
                             k, bias)
        self.layer1 = STConv(num_nodes, out_channels_1st,
                             hidden_channels_2nd,
                             out_channels_2nd,
                             kernel_size,
                             k, bias)
        self.relu = ms.nn.ReLU()
        self.fc = ms.nn.Dense(out_channels_2nd, out_channels)

    def construct(self, x, edge_weight, g: Graph):
        x = self.layer0(x, edge_weight, g)
        x = self.layer1(x, edge_weight, g)
        x = self.relu(x)
        x = self.fc(x)
        return x

STConv执行的更多细节可以看mindspore_gl.nn.temporal.STConv的API代码。

定义loss函数

由于本次任务为回归任务,可以采用最小均方差来作为损失函数。这里调用mindspore.nn.MSELoss实现最小均方差loss。

class LossNet(GNNCell):
    """ LossNet definition """
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss_fn = nn.loss.MSELoss()

    def construct(self, feat, edges, target, g: Graph):
        """STGCN Net with loss function"""
        predict = self.net(feat, edges, g)
        predict = ops.Squeeze()(predict)
        loss = self.loss_fn(predict, target)
        return ms.ops.ReduceMean()(loss)

构造数据集

输入特征为(批次内图数量, 时间步, 节点数量, 特征维度)。在时序卷积上时间序列的长度将会发生变化。因此,从数据集获取特征和标签时,输入和输出时间步有相应规范,否则会出现预测值与标签值形状不一致。

限制规范可以参考代码注释。

from mindspore_gl.dataset import MetrLa
metr = MetrLa(args.data_path)
# out_timestep setting
# out_timestep = in_timestep - ((kernel_size - 1) * 2 * layer_nums)
# such as: layer_nums = 2, kernel_size = 3, in_timestep = 12,
# out_timestep = 4
features, labels = metr.get_data(args.in_timestep, args.out_timestep)

其中MetrLa数据下载后,解压路径即为args.data_path。

网络训练和验证

设置环境变量

环境变量设置方法可以参考GCN

定义训练网络

实例化模型主体以及LossNet和优化器。 实现方法可以参考GCN

网络训练及验证

实现方法可以参考GCN

执行并查看结果

运行过程

运行程序后,翻译代码并开始训练。

执行结果

执行脚本trainval_metr.py启动训练。

cd model_zoo/stgcn
python trainval_metr.py --data-path={path} --fuse=True

其中{path}为数据集存放路径。

可以看到训练的结果如下:

...
Iteration/Epoch: 600:199 loss: 0.21488506
Iteration/Epoch: 700:199 loss: 0.21441595
Iteration/Epoch: 800:199 loss: 0.21243602
Time 13.162885904312134 Epoch loss 0.21053028
eval MSE: 0.2060675

MetrLa的MSE: 0.206

1
https://gitee.com/mindspore/docs.git
git@gitee.com:mindspore/docs.git
mindspore
docs
docs
master

搜索帮助