Reproducing T-GCN with pytorch_geometric (pyg)

Posted by GoncaloF on Wed, 12 Jan 2022 04:08:40 +0100


The previous article introduced pyg through the basic tools it provides. In practice, however, we usually use a third-party library as a building block for a relatively large model trained on our own dataset, rather than settling for the simple models and standard datasets in the demos. This article therefore reproduces T-GCN, based on its paper and official source code, and describes how to build a GNN-RNN model with pyg, including the construction of both the dataset and the model.

I ran into many pitfalls at the start of the reproduction. Some came from my unfamiliarity with pyg, and some from inconsistencies between the model described in the paper and the one in the source code. Where the two disagree, this reproduction follows the author's source code. You may then ask: the source code is available, so why bother "reproducing" it? Because the source code is not written with pyg; it implements GCN directly with the original formulation and plain matrix multiplications.

Preliminaries

Although I am not on the authors' team, I think it is worth briefly introducing the model and dataset used in the paper.

Introduction to T-GCN

The T-GCN discussed here is the Temporal Graph Convolutional Network proposed for traffic prediction. You may see the same abbreviation elsewhere, where it does not necessarily refer to this model.

The core of T-GCN is the combination of GCN and GRU: GCN first produces richer node features, and the features of each node are then fed to a GRU. In effect, GCN aggregates spatial features and GRU aggregates temporal features. The computation is given below. Note that each step takes the current input features and the GRU hidden state; after concatenation they form the input of the T-GCN cell (which can be thought of as one convolution layer inside T-GCN).

$$
\begin{aligned}
\mathrm{Conv}_t(X_t, h_t) &= L \cdot \mathrm{concat}(X_t, h_t) \\
u_t &= \sigma\left(W_u\, \mathrm{Conv}_t(X_t, h_t) + b_u\right) \\
r_t &= \sigma\left(W_r\, \mathrm{Conv}_t(X_t, h_t) + b_r\right) \\
c_t &= \tanh\left(W_c\, \mathrm{Conv}_t(X_t, r_t * h_t) + b_c\right) \\
h_{t+1} &= u_t * h_t + (1 - u_t) * c_t
\end{aligned}
$$
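To make the equations concrete, here is a minimal dense sketch of one T-GCN cell step in plain PyTorch. It assumes that L is the usual GCN propagation matrix D^{-1/2}(A + I)D^{-1/2} and that W_u and W_r are stacked into one linear layer; following a standard GRU (and the TGCNCell code later in this post), the candidate state uses the reset-gated hidden state r_t * h_t. The names `normalized_adjacency` and `DenseTGCNCell` are hypothetical, not taken from the paper or the author's code.

```python
import torch


def normalized_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """Assumed propagation matrix L = D^{-1/2} (A + I) D^{-1/2}."""
    a_hat = adj + torch.eye(adj.shape[0])        # add self-loops
    d_inv_sqrt = a_hat.sum(dim=1).pow(-0.5)      # degrees are >= 1 thanks to self-loops
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]


class DenseTGCNCell(torch.nn.Module):
    """Hypothetical dense T-GCN cell: graph convolution followed by GRU gating."""

    def __init__(self, in_dim: int, hid_dim: int):
        super().__init__()
        # W_u and W_r stacked into one layer (2 * hid_dim outputs); W_c separate
        self.gates = torch.nn.Linear(in_dim + hid_dim, 2 * hid_dim)
        self.cand = torch.nn.Linear(in_dim + hid_dim, hid_dim)

    def forward(self, L, x_t, h_t):
        # Conv_t(X_t, h_t) = L . concat(X_t, h_t), shape [num_nodes, in_dim + hid_dim]
        conv = L @ torch.cat([x_t, h_t], dim=1)
        # per-node split of the stacked gate outputs into u_t and r_t
        u, r = torch.sigmoid(self.gates(conv)).chunk(2, dim=1)
        # candidate state uses the reset-gated hidden state r_t * h_t
        conv_c = L @ torch.cat([x_t, r * h_t], dim=1)
        c = torch.tanh(self.cand(conv_c))
        # h_{t+1} = u_t * h_t + (1 - u_t) * c_t
        return u * h_t + (1 - u) * c
```

Unrolling it over a sequence is a plain loop: start from `h = torch.zeros(num_nodes, hid_dim)` and call `h = cell(L, x[:, t, :], h)` for each time step `t`.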

Introduction to the dataset

Only the Shenzhen ("sz") dataset is used here. It contains traffic speed data for 156 roads in Shenzhen, sampled every 15 minutes, with one feature per road, together with an adjacency matrix between the roads. In other words, the author models each road as a node and adds an edge between two roads if they are connected.

Starting the reproduction

Reproducing T-GCN with pyg is fairly painful, because we have to cut the feet to fit the shoes, that is, bend the data to fit the library. Personally, I find pyg's support for temporal data not very friendly, although it may simply be that I have not found a DataLoader and Data object suited to temporal data.

What goes wrong with time-series graph data?

As mentioned in the previous pyg introduction, DataLoader treats each Data object as one graph, and when forming a mini-batch it packs the Data objects of the batch into one large graph. This works for non-temporal samples, but with temporal data each sample contains multiple graphs (one per time step), so the batch that DataLoader builds may not be what we want.

Static graph

Let us first assume the simplest case, where every sample has exactly the same graph structure and connectivity (a "static graph"), so every Data object gets the same adjacency matrix. For parallelism, we usually take the data of all nodes at one time point and compute them together. Assuming the initial hidden state is a zero vector, the rough pseudocode of T-GCN is:

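# x.shape is [num_nodes, seq_len, num_features]; h.shape is [num_nodes, hidden_size]
h = zeros(num_nodes, hidden_size)  # zero initial hidden state
for i in range(seq_len):
    # all nodes at time step i: graph convolution on [x_t, h], then the GRU update
    h = gru(gcn(concat(x[:, i, :], h), edge_index))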

At first glance there seems to be no problem, and in fact there is none: edge_index is spliced into one large graph across the mini-batch, and after concatenation x becomes a [batch_size*num_nodes, seq_len, num_features] tensor. Taking one time point per iteration, the input logic is correct.
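The batching described above can be imitated in plain PyTorch to check the shapes. This is a sketch of what pyg's collation does in this static case (concatenate x along the node dimension and shift each sample's edge_index by the number of nodes that precede it), followed by the time-step loop from the pseudocode. The helper names `batch_static_samples` and `unroll` are hypothetical, and `step_fn` stands in for any recurrent graph cell.

```python
import torch


def batch_static_samples(xs, edge_index):
    """Hypothetical helper: merge samples that share one static graph.

    xs: list of tensors, each [num_nodes, seq_len, num_features]
    edge_index: [2, num_edges] tensor shared by every sample
    """
    num_nodes = xs[0].shape[0]
    # [batch_size * num_nodes, seq_len, num_features]
    x = torch.cat(xs, dim=0)
    # shift node ids so sample k uses ids k*num_nodes .. (k+1)*num_nodes - 1
    batched_edges = torch.cat(
        [edge_index + k * num_nodes for k in range(len(xs))], dim=1
    )
    return x, batched_edges


def unroll(x, edge_index, step_fn, hidden_size):
    """Run a recurrent graph cell over the time axis, one time step at a time."""
    h = x.new_zeros((x.shape[0], hidden_size))  # zero initial hidden state
    for i in range(x.shape[1]):
        # x[:, i, :] holds every node of every sample at time step i
        h = step_fn(x[:, i, :], edge_index, h)
    return h
```

Because the same batched edge_index is valid at every time step, the loop can reuse it unchanged, which is exactly why the static case needs no special handling.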

Dynamic graph

Now consider the case where the graph structure within a sample (here, only the edges) changes over time. Each time point then needs its own graph, the single edge_index stored in a Data object no longer meets our needs, and the default DataLoader will build the wrong mini-batch.

After turning this over for a while, I came up with a reasonably suitable approach: customize the function the DataLoader uses to build a mini-batch, so that it splices the graphs of all samples at each time point and produces a list with logical shape [seq_len, 2, num_edges]. Note that num_edges may differ between time points, so the result cannot be packed into a single Tensor; the batched graphs for all time points have to be kept in a List.
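A minimal sketch of such a collate function, assuming each sample is an `(x, edge_index_seq)` pair where `edge_index_seq` is a list of `seq_len` edge_index tensors, and that all samples have the same number of nodes (true for the road network here). The name `collate_dynamic_edges` is hypothetical; a function like this could be passed as `collate_fn` to `torch.utils.data.DataLoader`.

```python
import torch


def collate_dynamic_edges(samples):
    """Hypothetical collate for samples whose graph changes at every time step.

    samples: list of (x, edge_index_seq) pairs where
        x is [num_nodes, seq_len, num_features] and
        edge_index_seq is a list of seq_len tensors, each [2, num_edges_t].
    Returns the stacked x and a list of seq_len batched edge_index tensors.
    """
    num_nodes = samples[0][0].shape[0]
    seq_len = samples[0][0].shape[1]
    x = torch.cat([s[0] for s in samples], dim=0)
    edge_index_per_step = []
    for t in range(seq_len):
        # offset node ids of sample k by k * num_nodes, as in the static case
        edges_t = [s[1][t] + k * num_nodes for k, s in enumerate(samples)]
        # num_edges can differ between time steps, so each step is its own tensor
        edge_index_per_step.append(torch.cat(edges_t, dim=1))
    return x, edge_index_per_step
```

The model then indexes the list with the time step (`edge_index[i]`), which is the branch that RNNProcessHelper below takes when `edge_index` is a list.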

Fortunately, T-GCN uses a static graph, so none of this is needed here. If you have a better method for the dynamic case, you are welcome to discuss it.

Building the dataset and model


First, build the Dataset object:

from typing import List, Union, Tuple

import numpy as np
import torch

from torch_geometric.data import InMemoryDataset, Dataset, Data
from utils.utils import dataset_path
from constant import DATASET_NAME_TRAFFIC
import pandas as pd

class TrafficDataSet(InMemoryDataset):
    # One point is 15 minutes
    seq_len = 4
    predict_len = 1
    DATASET_TYPE = 'sz'
    PROCESSED_DATASET_FILENAME = '%s_seq%d_pre%d' % (DATASET_TYPE, seq_len, predict_len)
    speed_name = DATASET_TYPE + '_speed.csv'
    adj_name = DATASET_TYPE + '_adj.csv'

    def __init__(self):
        super().__init__(root=dataset_path(DATASET_NAME_TRAFFIC))
        # load the collated samples plus the metadata saved by process()
        self.data, self.slices, self.max_speed, self.num_nodes, self.seq_len, self.pre_len = torch.load(self.processed_paths[0])

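    # pyg reads these two as properties, not as methods
    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        return [TrafficDataSet.speed_name, TrafficDataSet.adj_name]

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        return TrafficDataSet.PROCESSED_DATASET_FILENAME + '.pt'

    def download(self):
        # the raw CSV files are expected to already be in the raw directory
        pass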
    def process(self):
        # The speed file has a header row; the adjacency file does not
        speed = pd.read_csv(self.raw_paths[0]).values
        adj = pd.read_csv(self.raw_paths[1], header=None).values

        num_nodes = len(adj)

        adj = process_adj(adj)
        # Normalize speeds by the maximum; max_speed is saved so the test-set error can be rescaled to the original units
        max_speed = np.max(speed)
        speed = speed / max_speed

        speed = torch.tensor(speed, dtype=torch.float32)
        adj = torch.tensor(adj, dtype=torch.int64)

        time_len = speed.shape[0]
        seq_len = TrafficDataSet.seq_len
        pre_len = TrafficDataSet.predict_len
        data_list = []
        for i in range(time_len - seq_len - pre_len):
            # speed = [time_len, num_nodes]
            # x = [num_nodes, seq_len, num_features=1]
            x = speed[i: i + seq_len].transpose(0,1).reshape([num_nodes, seq_len, 1])

            # y = [pre_len, num_nodes] -> [num_nodes, pre_len]
            y = speed[i + seq_len: i + seq_len + pre_len].transpose(0, 1)

            pyg_data = Data(x, edge_index=adj, y=y)
            data_list.append(pyg_data)

        data, slices = self.collate(data_list)
        # save the collated data together with the metadata that __init__ unpacks
        torch.save((data, slices, max_speed, num_nodes, seq_len, pre_len), self.processed_paths[0])

# The dataset provides a dense adjacency matrix, which must be converted into the sparse COO edge_index form that pyg accepts
def process_adj(adj):
    node_cnt = len(adj)
    pyg_adj = [[],[]]
    for i in range(node_cnt):
        for j in range(node_cnt):
            if adj[i][j] == 1:
                # edge i -> j: source ids in row 0, target ids in row 1
                pyg_adj[0].append(i)
                pyg_adj[1].append(j)
    return np.array(pyg_adj)
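The double loop in process_adj visits all 156 x 156 entries in Python. A vectorized alternative, sketched below under the same assumption that an edge is marked by the value 1, uses `np.nonzero`, which returns indices in row-major order and therefore yields the same edge order as the loop. The name `process_adj_vectorized` is hypothetical. pyg also ships `torch_geometric.utils.dense_to_sparse`, which returns `(edge_index, edge_attr)` for every nonzero entry of a dense adjacency tensor.

```python
import numpy as np


def process_adj_vectorized(adj):
    """Dense 0/1 adjacency matrix -> [2, num_edges] COO edge index (int64).

    np.nonzero returns row and column indices in row-major order, which is
    the same order as the nested i/j loops in process_adj.
    """
    rows, cols = np.nonzero(np.asarray(adj) == 1)
    return np.stack([rows, cols]).astype(np.int64)
```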

T-GCN model

The model is not complicated, so we build it directly following the author's source code.

import torch
from torch_geometric.nn.conv import GCNConv

import torch.nn.functional as F
class TGCN_Conv_Module(torch.nn.Module):
    def __init__(self, args):
        super(TGCN_Conv_Module, self).__init__()
        self.args = args

        self.num_features = args.c_in
        self.nhid = args.c_out

        # The convolution takes the input features concatenated with the GRU hidden_state and outputs hidden_size features
        self.conv1 = GCNConv(self.num_features+self.nhid, self.nhid)

    def forward(self, x, edge_index):
        # In fact, only one layer of GCN convolution is used in the author's source code, while two layers are used in the paper
        x = F.relu(self.conv1(x, edge_index))
        x = torch.sigmoid(x)

        return x

class TGCNCell(torch.nn.Module):
    def __init__(self, args):
        super(TGCNCell, self).__init__()
        self.args = args
        self.num_features = args.c_in
        self.nhid = args.c_out
        self.seq_len = args.seq_len
        self.num_nodes = args.num_nodes

        # Modeled on the author's source code: graph_conv1 is effectively two GCNs (for the r and u gates), and its output is split in half in forward()
        self.graph_conv1 = GCNConv(self.nhid+self.num_features, self.nhid * 2)
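        self.graph_conv2 = GCNConv(self.nhid+self.num_features, self.nhid)
        # reset_parameters() is otherwise never called; it sets the gate bias to 1.0
        self.reset_parameters()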


    def reset_parameters(self):
        torch.nn.init.constant_(self.graph_conv1.bias, 1.0)

    def forward(self, x, edge_index, hidden_state):

        ru_input = torch.concat([x, hidden_state], dim=1)

        # The output of a single GCN is split into two parts (r and u); in matrix form this is equivalent to two GCNs
        # The split below also mirrors the source code. I believe the split dimension is wrong, but this version gives higher accuracy
        ru = torch.sigmoid(self.graph_conv1(ru_input, edge_index))
        r, u = torch.chunk(ru.reshape([-1, self.num_nodes * 2 * self.nhid]), chunks=2, dim=1)
        r = r.reshape([-1, self.nhid])
        u = u.reshape([-1, self.nhid])

        c_input = torch.concat([x, r * hidden_state], dim=1)
        c = torch.tanh(self.graph_conv2(c_input, edge_index))

        new_hidden_state = u * hidden_state + (1.0 - u) * c
        return new_hidden_state

# First perform graph level aggregation, and then perform sequence modeling
class RNNProcessHelper(torch.nn.Module):
    def __init__(self, args, rnn_cell):
        super(RNNProcessHelper, self).__init__()
        self.args = args
        self.num_features = args.c_in
        self.nhid = args.c_out
        self.out_dim = args.out_dim
        self.seq_len = args.seq_len
        self.num_nodes = args.num_nodes

        self.rnn_cell = rnn_cell

    def forward(self, data, hidden_state=None):
        x, edge_index = data.x, data.edge_index
        if type(edge_index) is torch.Tensor:
            # static graph: one edge_index shared by all time steps
            is_seq_edge_index = False
        elif type(edge_index) is list:
            # dynamic graph: one edge_index per time step
            is_seq_edge_index = True
        else:
            raise ValueError('No edge connection information!')

        if hidden_state is None:
            hidden_state = torch.zeros([x.shape[0], self.nhid]).to(self.args.device)

        hidden_state_list = []
        for i in range(self.seq_len):
            # hidden_state.shape = [batch_size*num_nodes, hidden_size]
            if is_seq_edge_index:
                hidden_state = self.rnn_cell(x[:, i, :], edge_index[i], hidden_state)
            else:
                hidden_state = self.rnn_cell(x[:, i, :], edge_index, hidden_state)
            hidden_state_list.append(hidden_state)

        return hidden_state_list

# Regression task
class TGCN_Reg_Net(torch.nn.Module):
    def __init__(self, args):
        super(TGCN_Reg_Net, self).__init__()
        self.args = args
        self.num_features = args.c_in
        self.nhid = args.c_out
        self.out_dim = args.out_dim
        self.seq_len = args.seq_len
        self.num_nodes = args.num_nodes

        # self.tgcn_cell = TGCN_Cell(args)
        tgcn_cell = TGCNCell(args)
        self.seq_process_helper = RNNProcessHelper(args, tgcn_cell)

        # Maps each node's final hidden_state to its predicted speed for the next out_dim time steps
        self.lin_out = torch.nn.Linear(self.nhid, self.out_dim)

    def forward(self, data):
        hidden_state_list = self.seq_process_helper(data)

        # Select the last output for prediction
        hidden_state_last = hidden_state_list[-1]
        out = self.lin_out(hidden_state_last)

        # According to the data set construction method, [batch*num_nodes, out_dim]
        return out

    @staticmethod
    def test(model, loader, args) -> float:
        import math
        loss = 0.0
        max_speed = args.max_speed
        # Because batch_size=1, each iteration computes the MSE of one sample
        with torch.no_grad():
            for data in loader:
                data = data.to(args.device)
                out = model(data)
                loss += F.mse_loss(out, data.y).item()

        mse_loss = loss / len(loader.dataset)
        # undo the max-speed normalization so the RMSE is in the original speed units
        rmse_loss = math.sqrt(mse_loss) * max_speed
        # print("val RMSE loss:{}".format(rmse_loss))

        return rmse_loss

    @staticmethod
    def get_loss_function():
        from utils.loss_utils import mse_loss
        return mse_loss
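The comment in TGCNCell.forward says the split dimension looks wrong. A tiny example shows what the two possible splits do for a single graph with 2 nodes and `nhid = 1`. `split_like_source` copies the reshape-then-chunk logic from TGCNCell; `split_per_node` is a hypothetical alternative that gives each node its own r and u. Both function names are illustrative only.

```python
import torch


def split_like_source(ru, num_nodes, nhid):
    """Split as in TGCNCell: flatten each graph's outputs, then cut the flat vector in half."""
    r, u = torch.chunk(ru.reshape([-1, num_nodes * 2 * nhid]), chunks=2, dim=1)
    return r.reshape([-1, nhid]), u.reshape([-1, nhid])


def split_per_node(ru, nhid):
    """Alternative: each node's 2*nhid outputs are halved into that node's r and u."""
    r, u = torch.chunk(ru, chunks=2, dim=1)
    return r, u
```

With `ru = torch.arange(4.).reshape(2, 2)` (node 0 outputs `[0, 1]`, node 1 outputs `[2, 3]`), the source-style split gives r = `[[0], [1]]`, so node 1's reset gate comes from node 0's second output, while the per-node split gives r = `[[0], [2]]`.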

Main function

import math

import torch
from torch_geometric.loader import DataLoader

from utils.dataset_utils import split_dataset_by_ratio
from classfiers.tgcn import TGCN_Reg_Net
from datasets.traffic import TrafficDataSet
from utils.args_utils import get_args_pred
from utils.task_utils import train

if __name__ == '__main__':

    dataset = TrafficDataSet()
    train_set, test_set = split_dataset_by_ratio(dataset)

    args = get_args_pred(dataset)

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    model = TGCN_Reg_Net(args).to(args.device)

    # args.lr is an assumed attribute name for the learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    train(model, train_loader, test_loader, optimizer, args)

That is the overall construction of the model. Some utility functions (such as split_dataset_by_ratio, get_args_pred and train) are not shown, but their names indicate roughly what they do.
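For example, a plausible minimal version of `split_dataset_by_ratio`, assuming a chronological split with an 80/20 train/test ratio and no shuffling, so that every test window lies after the training windows in time. This is a hypothetical sketch, not the helper used in the post; it relies only on slicing, which pyg's InMemoryDataset supports.

```python
def split_dataset_by_ratio(dataset, train_ratio=0.8):
    """Hypothetical chronological split: the first train_ratio of samples go to training.

    Samples are not shuffled, so the test set contains the latest windows.
    Works for any sequence that supports len() and slicing.
    """
    n_train = int(len(dataset) * train_ratio)
    return dataset[:n_train], dataset[n_train:]
```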

Other problems encountered during the reproduction

Even though I built the model exactly as in the source code, the final metrics were much worse. I then went through the hyperparameters in the source code and found a very subtle one, weight_decay: instead of passing it to the constructor of the Adam optimizer (that is, Adam's weight_decay = 0), the author adds an L2 regularization loss over the model parameters to mse_loss when computing the loss. It is computed as follows:

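import torch
import torch.nn.functional as F

# L2 penalty over all model parameters (weights and biases), scaled by lamda
def regular_loss(model, lamda=0):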
    reg_loss = 0.0
    for param in model.parameters():
        reg_loss += torch.sum(param ** 2)
    return lamda * reg_loss

def mse_loss(out, label, model, reg_weight=0):
    classify_loss = F.mse_loss(out.squeeze(), label.squeeze())
    reg_loss = regular_loss(model, reg_weight)
    return classify_loss + reg_loss

After consulting some material, the explanation I found is that with an adaptive optimizer such as Adam, an L2 penalty is rescaled by the adaptive per-parameter step sizes, so its effective strength changes as training proceeds. AdamW, which decouples weight decay from the adaptive gradient update, is meant to solve this, but in practice it made little difference for me. So, with only half a bucket of expertise, I chose to be practical and use the hyperparameters the author had already tuned.
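One detail worth checking when switching between the two styles: in PyTorch's `torch.optim.Adam`, `weight_decay` is added to the gradient (`grad + weight_decay * param`) before the adaptive scaling, so it is also coupled L2 regularization. Adding `lamda * sum(p ** 2)` to the loss contributes `2 * lamda * param` to the gradient, so it matches `Adam(weight_decay=2 * lamda)`, not `weight_decay=lamda`. The sketch below checks this numerically on a toy linear model; the function names are hypothetical. AdamW, by contrast, applies the decay directly to the weights outside the adaptive scaling, so it is not equivalent to either.

```python
import torch
import torch.nn.functional as F


def train_params(x, y, lamda=0.0, weight_decay=0.0, steps=5, lr=0.01):
    """Train a tiny linear model with Adam, using either an explicit L2 term
    (lamda * sum(p**2), as in regular_loss) or Adam's own weight_decay."""
    torch.manual_seed(0)  # identical initialization for every call
    model = torch.nn.Linear(x.shape[1], 1)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    for _ in range(steps):
        opt.zero_grad()
        loss = F.mse_loss(model(x), y)
        if lamda:
            loss = loss + lamda * sum(torch.sum(p ** 2) for p in model.parameters())
        loss.backward()
        opt.step()
    return [p.detach().clone() for p in model.parameters()]


def same_params(a, b, atol=1e-6):
    """True if two parameter lists match element-wise within atol."""
    return all(torch.allclose(p, q, atol=atol) for p, q in zip(a, b))
```

For instance, `same_params(train_params(x, y, lamda=0.01), train_params(x, y, weight_decay=0.02))` should hold for any data `x`, `y`.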


This post briefly walked through the process of reproducing T-GCN and presented the code for the model and the Dataset. I have not yet packaged runnable source code for download: the author has already open-sourced the original, and my version only re-implements it in pyg, so it adds little for learning either pyg itself or T-GCN.

Topics: Pytorch Deep Learning GNN pyg