Md Shamim Hussain, Mohammed J. Zaki, Dharmashankar Subramanian
We propose an extension to the transformer neural network architecture for general-purpose graph learning: a dedicated pathway for pairwise structural information, called edge channels. The resulting framework, which we call the Edge-augmented Graph Transformer (EGT), can directly accept, process, and output structural information of arbitrary form, which is important for effective learning on graph-structured data. Our model uses global self-attention exclusively as its aggregation mechanism, rather than static, localized convolutional aggregation. This allows unconstrained, long-range, dynamic interactions between nodes. Moreover, the edge channels allow the structural information to evolve from layer to layer, and prediction tasks on edges/links can be performed directly from the output embeddings of these channels. We verify the performance of EGT in a wide range of graph-learning experiments on benchmark datasets, in which it outperforms convolutional/message-passing graph neural networks. EGT sets a new state of the art for the quantum-chemical regression task on the OGB-LSC PCQM4Mv2 dataset, which contains 3.8 million molecular graphs. Our findings indicate that global self-attention based aggregation can serve as a flexible, adaptive, and effective replacement for graph convolution in general-purpose graph learning, and therefore that convolutional local neighborhood aggregation is not an essential inductive bias.
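To make the idea of edge channels concrete, here is a minimal NumPy sketch of one edge-augmented global self-attention layer. It is an illustration under stated assumptions, not the authors' implementation: we assume that every node pair carries an edge embedding, that these embeddings enter attention as an additive per-head bias on the logits plus a sigmoid gate on the attention weights, and that the edge channels are updated from the per-pair logits. The function name `edge_augmented_attention` and the parameter keys (`Wq`, `We_bias`, `Wo_e`, and so on) are hypothetical. Layer normalization and feed-forward sublayers are omitted for brevity.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=axis, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def edge_augmented_attention(h, e, params, num_heads):
    """Hypothetical edge-augmented global self-attention layer (a sketch).

    h: node embeddings, shape (n, d_h)
    e: edge-channel embeddings for every node pair, shape (n, n, d_e)
    Returns the updated (h, e) with residual connections.
    """
    n, d_h = h.shape
    d_k = d_h // num_heads

    def split_heads(x):
        # (n, d_h) -> (heads, n, d_k)
        return x.reshape(n, num_heads, d_k).transpose(1, 0, 2)

    q = split_heads(h @ params["Wq"])
    k = split_heads(h @ params["Wk"])
    v = split_heads(h @ params["Wv"])
    # Assumption: edge channels give one additive bias and one gate per head.
    bias = (e @ params["We_bias"]).transpose(2, 0, 1)  # (heads, n, n)
    gate = (e @ params["We_gate"]).transpose(2, 0, 1)  # (heads, n, n)
    # Global attention: logits for ALL node pairs, no adjacency mask.
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(d_k) + bias
    attn = softmax(logits, axis=-1) * sigmoid(gate)
    out = (attn @ v).transpose(1, 0, 2).reshape(n, d_h)
    h_new = h + out @ params["Wo_h"]
    # Edge channels evolve from the per-pair, per-head logits.
    e_new = e + logits.transpose(1, 2, 0) @ params["Wo_e"]
    return h_new, e_new


# Usage: random toy graph with 5 nodes (all weights hypothetical).
rng = np.random.default_rng(0)
n, d_h, d_e, heads = 5, 8, 4, 2
params = {
    "Wq": rng.normal(size=(d_h, d_h)) * 0.1,
    "Wk": rng.normal(size=(d_h, d_h)) * 0.1,
    "Wv": rng.normal(size=(d_h, d_h)) * 0.1,
    "We_bias": rng.normal(size=(d_e, heads)) * 0.1,
    "We_gate": rng.normal(size=(d_e, heads)) * 0.1,
    "Wo_h": rng.normal(size=(d_h, d_h)) * 0.1,
    "Wo_e": rng.normal(size=(heads, d_e)) * 0.1,
}
h0 = rng.normal(size=(n, d_h))
e0 = rng.normal(size=(n, n, d_e))

h, e = h0, e0
for _ in range(2):  # two layers; weights shared here only for brevity
    h, e = edge_augmented_attention(h, e, params, heads)

# Link prediction read directly from the output edge channels:
# a hypothetical linear head giving one score per node pair.
link_scores = e @ rng.normal(size=d_e)  # shape (n, n)
```

Because no adjacency mask is applied, every node attends to every other node in a single layer, and graph structure reaches the attention only through the edge embeddings `e`. Reading `link_scores` from the final `e` mirrors the idea that edge/link tasks can be predicted from the edge-channel outputs.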
| Task | Dataset | Metric | Value | Model |
|---|---|---|---|---|
| Link Prediction | TSP/HCP Benchmark set | F1 | 0.853 | EGT |
| Graph Regression | PCQM4Mv2-LSC | Test MAE | 0.0683 | EGT + Triangular Attention |
| Graph Regression | PCQM4Mv2-LSC | Validation MAE | 0.0671 | EGT + Triangular Attention |
| Graph Regression | PCQM4Mv2-LSC | Test MAE | 0.0862 | EGT |
| Graph Regression | PCQM4Mv2-LSC | Validation MAE | 0.0857 | EGT |
| Graph Regression | ZINC-500k | MAE | 0.108 | EGT |
| Graph Regression | ZINC 100k | MAE | 0.143 | EGT |
| Graph Regression | PCQM4M-LSC | Validation MAE | 0.1224 | EGT |
| Graph Classification | MNIST | Accuracy (%) | 98.173 | EGT |
| Graph Classification | CIFAR10 100k | Accuracy (%) | 68.702 | EGT |
| Node Classification | PATTERN | Accuracy (%) | 86.821 | EGT |
| Node Classification | CLUSTER | Accuracy (%) | 79.232 | EGT |
| Node Classification | PATTERN 100k | Accuracy (%) | 86.816 | EGT |