Preface
The previous article introduced PyG through the basic tools it provides. In practice, however, we usually use a third-party library as a set of building blocks for a fairly large model trained on our own dataset, rather than stopping at the simple models and standard datasets of the demos. This post therefore reproduces T-GCN, following its paper and the authors' official source code, to show how to build a GNN-RNN model with PyG, covering both the dataset and the model.
When I started the reproduction I ran into many pitfalls. Some came from my unfamiliarity with PyG. Others came from inconsistencies between the model described in the paper and the one in the authors' source code; where the two disagree, this reproduction follows the source code. You may ask: if the source code is available, why bother with a "reproduction" at all? Because the source code does not use PyG. It implements GCN directly with raw matrix multiplications.
Preliminaries
Although I am not on the authors' team, it is worth briefly introducing the model and dataset first.
Introduction to T-GCN
T-GCN here stands for Temporal Graph Convolutional Network, a model proposed for traffic prediction. You may see the same abbreviation elsewhere, but it does not necessarily refer to this model.
The core of T-GCN is the combination of a GCN and a GRU. The GCN aggregates spatial features over the road graph to produce richer node features, and the GRU then aggregates these features over time. At each step, the current input features and the GRU hidden state are concatenated and passed through the graph convolutions inside the T-GCN cell, which can be thought of as one convolutional layer within T-GCN. The computation is as follows:
$$
\begin{aligned}
\mathrm{Conv}_t(X_t, h_t) &= L \cdot \mathrm{concat}(X_t, h_t) \\
u_t &= \sigma\left(W_u \, \mathrm{Conv}_t(X_t, h_t) + b_u\right) \\
r_t &= \sigma\left(W_r \, \mathrm{Conv}_t(X_t, h_t) + b_r\right) \\
c_t &= \tanh\left(W_c \, \mathrm{Conv}_t(X_t, r_t \odot h_t) + b_c\right) \\
h_{t+1} &= u_t \odot h_t + (1 - u_t) \odot c_t
\end{aligned}
$$
Here L is the normalized adjacency matrix used by the GCN, σ is the sigmoid function, and ⊙ is element-wise multiplication. The reset gate r_t enters through the candidate state c_t, as in the model code below.
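To make the equations concrete, here is a minimal NumPy sketch of one T-GCN cell step. It is not the authors' code. It assumes a dense adjacency matrix and one graph convolution per gate (L·concat followed by a weight matrix). Node features are stored row by row, so the weights multiply from the right (Conv·W rather than W·Conv). The function names and the random weights are hypothetical.

```python
import numpy as np


def normalized_adjacency(adj):
    """GCN normalization: D^{-1/2} (A + I) D^{-1/2}."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def tgcn_cell(x_t, h_t, L, params):
    """One T-GCN step: x_t is [num_nodes, num_features], h_t is [num_nodes, hidden]."""
    W_u, b_u, W_r, b_r, W_c, b_c = params
    conv = L @ np.concatenate([x_t, h_t], axis=1)          # Conv_t(X_t, h_t)
    u = sigmoid(conv @ W_u + b_u)                          # update gate u_t
    r = sigmoid(conv @ W_r + b_r)                          # reset gate r_t
    conv_c = L @ np.concatenate([x_t, r * h_t], axis=1)    # Conv_t(X_t, r_t * h_t)
    c = np.tanh(conv_c @ W_c + b_c)                        # candidate state c_t
    return u * h_t + (1.0 - u) * c                         # h_{t+1}


# Toy example with random data (hypothetical sizes and weights)
rng = np.random.default_rng(0)
num_nodes, num_features, hidden, seq_len = 5, 1, 8, 4
adj = (rng.random((num_nodes, num_nodes)) > 0.5).astype(float)
adj = np.maximum(adj, adj.T)          # make the graph undirected
np.fill_diagonal(adj, 0.0)            # self-loops are added by the normalization
L = normalized_adjacency(adj)
shapes = [(num_features + hidden, hidden), (hidden,)] * 3   # W_u, b_u, W_r, b_r, W_c, b_c
params = [0.1 * rng.normal(size=s) for s in shapes]

x = rng.random((num_nodes, seq_len, num_features))   # [num_nodes, seq_len, num_features]
h = np.zeros((num_nodes, hidden))                     # initial hidden state
for i in range(seq_len):
    h = tgcn_cell(x[:, i, :], h, L, params)
print(h.shape)  # (5, 8)
```

The hidden state starts at zero and each new state is a convex combination of the previous state and a tanh output. Its entries therefore always stay strictly between -1 and 1.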
Dataset introduction
Only the Shenzhen ("sz") dataset is used here. It contains traffic speeds for 156 roads in Shenzhen, sampled every 15 minutes, with a single feature (speed) per road per time step. It also comes with an adjacency matrix between roads. The model therefore treats each road as a node and adds an edge between two roads if they are connected.
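The raw data is a [time_len, num_nodes] speed matrix, and each training sample is a sliding window over it. The NumPy sketch below mirrors the windowing done in the Dataset code later in this post: the past seq_len steps become the input x and the following pre_len steps become the target y. The helper make_windows is hypothetical, and the speeds are synthetic random values, not the real CSV.

```python
import numpy as np


def make_windows(speed, seq_len=4, pre_len=1):
    """speed: [time_len, num_nodes] -> list of (x, y) samples,
    with x: [num_nodes, seq_len, 1] and y: [num_nodes, pre_len]."""
    time_len, num_nodes = speed.shape
    samples = []
    # Same loop bound as the Dataset's process() method in this post
    for i in range(time_len - seq_len - pre_len):
        x = speed[i:i + seq_len].T.reshape(num_nodes, seq_len, 1)
        y = speed[i + seq_len:i + seq_len + pre_len].T
        samples.append((x, y))
    return samples


# Synthetic stand-in for sz_speed.csv: 20 time steps, 156 roads
speed = np.random.default_rng(0).random((20, 156)) * 60.0
max_speed = speed.max()                    # normalization factor, kept for rescaling errors
samples = make_windows(speed / max_speed)
```

With this loop bound, a series of T steps yields T - seq_len - pre_len samples. Here that is 20 - 4 - 1 = 15, so the very last possible window is skipped.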
Reproducing T-GCN
Reproducing T-GCN with PyG is somewhat painful, because we have to "cut our feet to fit the shoes", bending the model to fit the library. Personally, I find PyG's support for time-series data not very friendly. That may also be because I have not found a DataLoader and Data object suited to time-series data.
What goes wrong with time-series graph data?
As mentioned in the previous PyG introduction, the DataLoader treats each Data object as one graph. When forming a mini-batch, it packs all the Data objects in the batch into one large, disconnected graph. This is fine for non-temporal samples. With temporal data, however, each sample contains multiple graphs (one per time step), so the batch the DataLoader builds may not be what we want.
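Here is a simplified NumPy sketch of that packing step. It is not PyG's actual implementation, which lives in its Batch class. Node features are stacked along the node dimension, and each graph's edge_index is shifted by the number of nodes that precede it, giving one large disconnected graph. The helper name collate_static is hypothetical.

```python
import numpy as np


def collate_static(graphs):
    """graphs: list of (x, edge_index), x: [num_nodes, ...], edge_index: [2, num_edges].
    Returns one big disconnected graph, mimicking how a mini-batch is assembled."""
    xs, edges, offset = [], [], 0
    for x, edge_index in graphs:
        xs.append(x)
        edges.append(edge_index + offset)   # shift node ids into this graph's block
        offset += x.shape[0]
    return np.concatenate(xs, axis=0), np.concatenate(edges, axis=1)


x = np.zeros((3, 4, 1))                    # 3 nodes, seq_len=4, 1 feature
edge_index = np.array([[0, 1], [1, 2]])    # edges 0->1 and 1->2
big_x, big_edge_index = collate_static([(x, edge_index), (x + 1, edge_index)])
```

For a static graph this is exactly what we want, because every time step shares the same shifted edge_index. The trouble starts only when each time step needs its own edges.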
Static graph
Let's first assume the simplest case: every sample has exactly the same graph structure and connectivity (a "static graph"), so every Data object gets the same adjacency matrix. For parallelism, we usually take the features of all nodes at one time step and process them together. Assuming the initial hidden state is a zero vector, the T-GCN forward pass looks roughly like the following pseudocode.
# x.shape is [num_nodes, seq_len, num_features]
h = zeros()
for i in range(seq_len):
    h = gru(gcn(concat(x[:, i, :], h), edge_index))
At first glance there seems to be no problem, and in fact there is none. Within a mini-batch, edge_index is merged into one large graph, and the x tensors are concatenated into a matrix of shape [batch_size*num_nodes, seq_len, num_features]. Taking one time step at a time, the inputs line up correctly.
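This claim can be checked numerically. The NumPy sketch below uses a simplified recurrent graph convolution, h = tanh(L·concat(x_t, h)·W), as an assumed stand-in for the full T-GCN cell. It runs two samples once as a single batched graph with a block-diagonal adjacency, and once separately, then compares the results. The helper names are hypothetical.

```python
import numpy as np


def norm_adj(adj):
    a = adj + np.eye(len(adj))
    d = 1.0 / np.sqrt(a.sum(axis=1))
    return d[:, None] * a * d[None, :]


def block_diag(*mats):
    """Place square matrices on the diagonal of one larger matrix."""
    n = sum(m.shape[0] for m in mats)
    out = np.zeros((n, n))
    i = 0
    for m in mats:
        k = m.shape[0]
        out[i:i + k, i:i + k] = m
        i += k
    return out


def run_sequence(x, L, W, hidden):
    """x: [nodes, seq_len, features]; simplified recurrent GCN h = tanh(L [x_t, h] W)."""
    h = np.zeros((x.shape[0], hidden))
    for i in range(x.shape[1]):          # one time step for all nodes at once
        h = np.tanh(L @ np.concatenate([x[:, i, :], h], axis=1) @ W)
    return h


rng = np.random.default_rng(1)
hidden, seq_len = 4, 3
L = norm_adj(np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float))
W = rng.normal(size=(1 + hidden, hidden))
xa, xb = rng.random((3, seq_len, 1)), rng.random((3, seq_len, 1))

# Batched: stacked node features + block-diagonal adjacency (same graph twice)
h_batch = run_sequence(np.concatenate([xa, xb]), block_diag(L, L), W, hidden)
# Separate: each sample on its own graph, results stacked afterwards
h_sep = np.concatenate([run_sequence(xa, L, W, hidden), run_sequence(xb, L, W, hidden)])
```

Because the two blocks of the adjacency never interact, the batched and separate results coincide up to floating-point rounding.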
Dynamic graph
Now consider the case where the graph structure within a sample changes over time (here we only consider changes in edges), so each time step needs its own graph. A single edge_index attribute on the Data object can no longer describe this, and the default DataLoader would build an incorrect mini-batch.
After turning it over for a while, the most workable approach I came up with was to customize the function the DataLoader uses to build mini-batches. The customized function merges the adjacency structures of all samples separately at each time step and produces a list of shape [seq_len, 2, num_edges]. Note that num_edges can differ from one time step to the next, so this cannot be packed into a single tensor. The batched graphs for all time steps have to be kept in a list.
Fortunately, T-GCN uses a static graph, so none of this is needed here. This situation only comes up if you add your own modifications to the model. If you have a better method, feel free to discuss it.
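Here is a sketch of such a batching function for the dynamic-graph case. It uses NumPy arrays in place of tensors, and the helper collate_dynamic is hypothetical, not a PyG API. Since torch.utils.data.DataLoader accepts a collate_fn argument, a function with this shape could be plugged in there.

```python
import numpy as np


def collate_dynamic(samples):
    """samples: list of (x, edge_index_list), where x: [num_nodes, seq_len, num_features]
    and edge_index_list[t]: [2, num_edges_t]. Returns (batched x, list of seq_len
    batched edge_index arrays)."""
    seq_len = len(samples[0][1])
    xs, per_step, offset = [], [[] for _ in range(seq_len)], 0
    for x, edge_index_list in samples:
        xs.append(x)
        for t, edge_index in enumerate(edge_index_list):
            per_step[t].append(edge_index + offset)   # shift node ids for this sample
        offset += x.shape[0]
    # num_edges differs per time step, so keep a Python list instead of one array
    return np.concatenate(xs, axis=0), [np.concatenate(e, axis=1) for e in per_step]


x = np.zeros((3, 2, 1))                                   # 3 nodes, seq_len=2
sample = (x, [np.array([[0], [1]]),                       # t=0: one edge
              np.array([[0, 1], [1, 2]])])                # t=1: two edges
bx, be = collate_dynamic([sample, sample])
```

The per-step list is what the `is_seq_edge_index` branch of the model code below expects: edge_index[i] is the batched graph for time step i.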
Build a model that fits your feet
DataSet
We start by building the Dataset object:
from typing import List, Union, Tuple

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import InMemoryDataset, Data

from utils.utils import dataset_path
from constant import DATASET_NAME_TRAFFIC


class TrafficDataSet(InMemoryDataset):
    # One time point is 15 minutes
    seq_len = 4
    predict_len = 1
    DATASET_TYPE = 'sz'
    PROCESSED_DATASET_FILENAME = '%s_seq%d_pre%d' % (DATASET_TYPE, seq_len, predict_len)
    speed_name = DATASET_TYPE + '_speed.csv'
    adj_name = DATASET_TYPE + '_adj.csv'

    def __init__(self):
        super().__init__(root=dataset_path(DATASET_NAME_TRAFFIC))
        (self.data, self.slices, self.max_speed,
         self.num_nodes, self.seq_len, self.pre_len) = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        return [TrafficDataSet.speed_name, TrafficDataSet.adj_name]

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        return TrafficDataSet.PROCESSED_DATASET_FILENAME + '.pt'

    def download(self):
        # The raw CSV files are placed in raw_dir manually
        pass

    def process(self):
        # The speed file has a header row; the adjacency file does not
        speed = pd.read_csv(self.raw_paths[0]).values
        adj = pd.read_csv(self.raw_paths[1], header=None).values
        num_nodes = len(adj)
        adj = process_adj(adj)
        # Normalize the samples. The normalization factor must be saved,
        # because it is needed to rescale the test-set error
        max_speed = float(np.max(speed))
        speed = speed / max_speed
        speed = torch.tensor(speed, dtype=torch.float32)
        adj = torch.tensor(adj, dtype=torch.int64)
        time_len = speed.shape[0]
        seq_len = TrafficDataSet.seq_len
        pre_len = TrafficDataSet.predict_len
        data_list = []
        for i in range(time_len - seq_len - pre_len):
            # speed = [time_len, num_nodes]
            # x = [num_nodes, seq_len, num_features=1]
            x = speed[i: i + seq_len].transpose(0, 1).reshape([num_nodes, seq_len, 1])
            # y = [pre_len, num_nodes] -> [num_nodes, pre_len]
            y = speed[i + seq_len: i + seq_len + pre_len].transpose(0, 1)
            pyg_data = Data(x, edge_index=adj, y=y)
            data_list.append(pyg_data)
        data, slices = self.collate(data_list)
        torch.save((data, slices, max_speed, num_nodes, seq_len, pre_len), self.processed_paths[0])


# The dataset provides a dense adjacency matrix, which has to be converted
# into the sparse (COO edge_index) format that PyG expects
def process_adj(adj):
    node_cnt = len(adj)
    pyg_adj = [[], []]
    for i in range(node_cnt):
        for j in range(node_cnt):
            if adj[i][j] == 1:
                pyg_adj[0].append(i)
                pyg_adj[1].append(j)
    return np.array(pyg_adj)
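The double loop in process_adj runs N² iterations in pure Python, which is fine for 156 nodes. A vectorized NumPy alternative uses np.nonzero, which returns indices in row-major order and therefore reproduces the loop's edge order exactly. This is a sketch that assumes a 0/1 adjacency matrix, as in the code above, and process_adj_vectorized is a hypothetical name. PyG also provides torch_geometric.utils.dense_to_sparse for this conversion on tensors.

```python
import numpy as np


def process_adj_loop(adj):
    # Same logic as process_adj above, kept for comparison
    n = len(adj)
    pyg_adj = [[], []]
    for i in range(n):
        for j in range(n):
            if adj[i][j] == 1:
                pyg_adj[0].append(i)
                pyg_adj[1].append(j)
    return np.array(pyg_adj)


def process_adj_vectorized(adj):
    """Dense 0/1 adjacency [N, N] -> COO edge_index [2, num_edges], row-major order."""
    rows, cols = np.nonzero(np.asarray(adj) == 1)
    return np.stack([rows, cols]).astype(np.int64)


adj = np.array([[0, 1, 1],
                [1, 0, 0],
                [1, 0, 0]])
edge_index = process_adj_vectorized(adj)
```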
T-GCN model
The model itself is not complicated, so let's build it directly, following the authors' source code.
import torch
import torch.nn.functional as F
from torch_geometric.nn.conv import GCNConv


class TGCN_Conv_Module(torch.nn.Module):
    def __init__(self, args):
        super(TGCN_Conv_Module, self).__init__()
        self.args = args
        self.num_features = args.c_in
        self.nhid = args.c_out
        # The input is the current features concatenated with the GRU hidden_state;
        # the output has hidden_size (nhid) features
        self.conv1 = GCNConv(self.num_features + self.nhid, self.nhid)

    def forward(self, x, edge_index):
        # The authors' source code uses only one GCN layer, while the paper uses two
        x = F.relu(self.conv1(x, edge_index))
        x = torch.sigmoid(x)
        return x


class TGCNCell(torch.nn.Module):
    def __init__(self, args):
        super(TGCNCell, self).__init__()
        self.args = args
        self.num_features = args.c_in
        self.nhid = args.c_out
        self.seq_len = args.seq_len
        self.num_nodes = args.num_nodes
        # Modeled on the authors' source code: graph_conv1 effectively acts as two GCNs,
        # and its output is split in half in forward() into the reset and update gates
        self.graph_conv1 = GCNConv(self.nhid + self.num_features, self.nhid * 2)
        self.graph_conv2 = GCNConv(self.nhid + self.num_features, self.nhid)
        self.reset_parameters()

    def reset_parameters(self):
        # Gate bias initialized to 1.0, as in the authors' source code
        torch.nn.init.constant_(self.graph_conv1.bias, 1.0)

    def forward(self, x, edge_index, hidden_state):
        ru_input = torch.cat([x, hidden_state], dim=1)
        # The output of one GCN is split into two parts; in matrix form this is
        # equivalent to using two GCNs. The split itself also follows the source code.
        # Personally, I think the split dimension is wrong, but this version gives higher accuracy
        ru = torch.sigmoid(self.graph_conv1(ru_input, edge_index))
        r, u = torch.chunk(ru.reshape([-1, self.num_nodes * 2 * self.nhid]), chunks=2, dim=1)
        r = r.reshape([-1, self.nhid])
        u = u.reshape([-1, self.nhid])
        c_input = torch.cat([x, r * hidden_state], dim=1)
        c = torch.tanh(self.graph_conv2(c_input, edge_index))
        new_hidden_state = u * hidden_state + (1.0 - u) * c
        return new_hidden_state


# Graph-level aggregation first, then sequence modeling
class RNNProcessHelper(torch.nn.Module):
    def __init__(self, args, rnn_cell):
        super(RNNProcessHelper, self).__init__()
        self.args = args
        self.num_features = args.c_in
        self.nhid = args.c_out
        self.out_dim = args.out_dim
        self.seq_len = args.seq_len
        self.num_nodes = args.num_nodes
        self.rnn_cell = rnn_cell

    def forward(self, data, hidden_state=None):
        x, edge_index = data.x, data.edge_index
        if isinstance(edge_index, torch.Tensor):
            is_seq_edge_index = False   # static graph: one edge_index for all time steps
        elif isinstance(edge_index, list):
            is_seq_edge_index = True    # dynamic graph: one edge_index per time step
        else:
            raise ValueError('No edge connection information!')
        if hidden_state is None:
            hidden_state = torch.zeros([x.shape[0], self.nhid]).to(self.args.device)
        hidden_state_list = []
        for i in range(self.seq_len):
            # hidden_state.shape = [batch_size*num_nodes, hidden_size]
            if is_seq_edge_index:
                hidden_state = self.rnn_cell(x[:, i, :], edge_index[i], hidden_state)
            else:
                hidden_state = self.rnn_cell(x[:, i, :], edge_index, hidden_state)
            hidden_state_list.append(hidden_state)
        return hidden_state_list


# Regression task
class TGCN_Reg_Net(torch.nn.Module):
    def __init__(self, args):
        super(TGCN_Reg_Net, self).__init__()
        self.args = args
        self.num_features = args.c_in
        self.nhid = args.c_out
        self.out_dim = args.out_dim
        self.seq_len = args.seq_len
        self.num_nodes = args.num_nodes
        # self.tgcn_cell = TGCN_Cell(args)
        tgcn_cell = TGCNCell(args)
        self.seq_process_helper = RNNProcessHelper(args, tgcn_cell)
        # Final hidden_state of each node -> predicted speed of that node for the next out_dim step(s)
        self.lin_out = torch.nn.Linear(self.nhid, self.out_dim)

    def forward(self, data):
        hidden_state_list = self.seq_process_helper(data)
        # Use the output of the last time step for the prediction
        hidden_state_last = hidden_state_list[-1]
        out = self.lin_out(hidden_state_last)
        # Given how the dataset is built, out.shape = [batch*num_nodes, out_dim]
        return out

    @staticmethod
    def test(model, loader, args) -> float:
        import math
        model.eval()
        loss = 0.0
        max_speed = args.max_speed
        # batch_size=1 here, so each iteration computes the MSE of one sample
        with torch.no_grad():
            for data in loader:
                data = data.to(args.device)
                out = model(data)
                loss += F.mse_loss(out, data.y).item()
        mse_loss = loss / len(loader.dataset)
        # Undo the max-speed normalization to report RMSE in the original units
        rmse_loss = math.sqrt(mse_loss) * max_speed
        # print("val RMSE loss:{}".format(rmse_loss))
        return rmse_loss

    @staticmethod
    def get_loss_function():
        from utils.loss_utils import mse_loss
        return mse_loss
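About the remark that the split dimension looks wrong, the small NumPy sketch below shows which values each variant actually selects. It uses np.split as a stand-in for torch.chunk and arbitrary sizes. With the reshape-then-chunk from the source code, a node's reset gate can come from another node's GCN output. Chunking ru along dim=1 directly instead gives every node its own first nhid channels as r. The sketch only shows the selection, not which variant trains better; the post reports that the source-code variant scored higher.

```python
import numpy as np

batch, num_nodes, nhid = 2, 4, 3
# Stand-in for the GCN output `ru`: [batch*num_nodes, 2*nhid], values 0, 1, 2, ...
ru = np.arange(batch * num_nodes * 2 * nhid, dtype=float).reshape(batch * num_nodes, 2 * nhid)

# Variant used in TGCNCell (modeled on the authors' source code)
r_src, u_src = np.split(ru.reshape(-1, num_nodes * 2 * nhid), 2, axis=1)
r_src, u_src = r_src.reshape(-1, nhid), u_src.reshape(-1, nhid)

# Per-node channel split: first nhid channels -> r, last nhid channels -> u
r_ch, u_ch = np.split(ru, 2, axis=1)

print(r_src[1])   # values from node 0's second half of channels
print(r_ch[1])    # node 1's own first nhid channels
```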
Main function
import math

import torch
from torch_geometric.loader import DataLoader

from utils.dataset_utils import split_dataset_by_ratio
from classfiers.tgcn import TGCN_Reg_Net
from datasets.traffic import TrafficDataSet
from utils.args_utils import get_args_pred
from utils.task_utils import train

if __name__ == '__main__':
    dataset = TrafficDataSet()
    train_set, test_set = split_dataset_by_ratio(dataset)
    args = get_args_pred(dataset)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    model = TGCN_Reg_Net(args).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    train(model, train_loader, test_loader, optimizer, args)
That is the overall construction of the model. Some utility functions are not shown, but their names should make it fairly clear what they do.
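split_dataset_by_ratio is one of the helpers not shown. The sketch below is a purely hypothetical version, named split_by_ratio here, with an assumed 80/20 ratio. It splits chronologically rather than randomly. Neighbouring windows overlap in time, so a random split would leak test-period values into training. PyG datasets support slicing, so the same code should work on TrafficDataSet; here it is exercised on a plain list of window indices.

```python
def split_by_ratio(dataset, train_ratio=0.8):
    """Hypothetical chronological split: the first train_ratio of the windows go to
    training and the rest to testing. Works on anything that supports len() and slicing."""
    n_train = int(len(dataset) * train_ratio)
    return dataset[:n_train], dataset[n_train:]


windows = list(range(10))          # stand-in for 10 time-ordered samples
train_set, test_set = split_by_ratio(windows)
```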
Other problems encountered in reproduction
Even though I built the model exactly as in the source code, my final metrics were much worse. I then went through the hyperparameter settings in the source code and found a very subtle one: weight_decay. The authors do not pass it to the Adam optimizer's constructor (their Adam effectively runs with weight_decay = 0). Instead, when computing the loss, they add an L2 regularization term over the model parameters on top of mse_loss, as follows:
import torch
import torch.nn.functional as F


def regular_loss(model, lamda=0):
    # L2 penalty: lamda * sum of the squared values of all model parameters
    reg_loss = 0.0
    for param in model.parameters():
        reg_loss += torch.sum(param ** 2)
    return lamda * reg_loss


def mse_loss(out, label, model, reg_weight=0):
    classify_loss = F.mse_loss(out.squeeze(), label.squeeze())
    reg_loss = regular_loss(model, reg_weight)
    return classify_loss + reg_loss
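Differentiating regular_loss shows that it adds 2·lamda·θ to each parameter's gradient. PyTorch's torch.optim.Adam with weight_decay=wd adds wd·θ to the gradient before its adaptive update. Both are therefore the same kind of coupled L2 penalty, and in exact arithmetic they give the same update when wd = 2·lamda. The NumPy sketch below checks the 2·lamda·θ gradient with central finite differences. The lamda value and weights are arbitrary, and numerical_grad is a hypothetical helper.

```python
import numpy as np


def regular_loss(params, lamda):
    # NumPy mirror of regular_loss above: lamda * sum of squared parameters
    return lamda * sum(np.sum(p ** 2) for p in params)


def numerical_grad(f, p, eps=1e-6):
    """Central finite-difference gradient of the scalar function f() w.r.t. array p."""
    g = np.zeros_like(p)
    for idx in np.ndindex(p.shape):
        old = p[idx]
        p[idx] = old + eps
        up = f()
        p[idx] = old - eps
        down = f()
        p[idx] = old                      # restore the original value
        g[idx] = (up - down) / (2 * eps)
    return g


lamda = 0.0015
w = np.array([[0.5, -1.0], [2.0, 0.25]])
g = numerical_grad(lambda: regular_loss([w], lamda), w)
print(np.allclose(g, 2 * lamda * w))   # True
```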
After reading up on it, I found the likely reason. With Adam, the effect of an L2 penalty is rescaled by the optimizer's adaptive per-parameter step sizes as training progresses, so it does not behave like plain weight decay. AdamW, which decouples weight decay from the gradient-based update, is meant to address this, but in my experiments it did not help much. Since I only half understand this area, I took the practical route and used the hyperparameters the authors had tuned.
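The following toy, single-parameter example illustrates the difference AdamW makes; it is not a simulation of T-GCN training. Take a parameter whose data-loss gradient is zero but which carries a small L2 penalty. On the first Adam step it moves by roughly the full learning rate, because Adam normalizes the penalty gradient by its own magnitude. With decoupled weight decay, the shrinkage is just lr·wd·θ. adam_update is a hand-written NumPy version of the standard Adam update, and the hyperparameter values are arbitrary.

```python
import numpy as np


def adam_update(theta, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Standard Adam moment estimates with bias correction
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return theta - lr * m_hat / (np.sqrt(v_hat) + eps), m, v


lr, wd, theta0 = 1e-3, 0.01, 2.0
loss_grad = 0.0   # assume the data loss is flat for this parameter

# Coupled L2 (penalty in the loss, or Adam's weight_decay): goes through Adam's scaling
coupled, _, _ = adam_update(theta0, loss_grad + wd * theta0, 0.0, 0.0, t=1, lr=lr)

# Decoupled weight decay (AdamW style): shrink theta directly, then a plain Adam step
decoupled = theta0 * (1 - lr * wd)
decoupled, _, _ = adam_update(decoupled, loss_grad, 0.0, 0.0, t=1, lr=lr)

print(theta0 - coupled)     # about 1e-3: the small penalty becomes a full lr-sized step
print(theta0 - decoupled)   # about 2e-5: lr * wd * theta0
```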
Postscript
This post briefly walked through reproducing T-GCN and showed how the model and the Dataset are built. I have not packaged a directly runnable version of the code for download. The authors have already open-sourced theirs, and my PyG re-implementation adds little for learning either PyG itself or T-GCN.