Graph Neural Networks

Foundations

Graphs are the natural representation for relational data: social networks, molecules, knowledge bases, program structure, physical systems. Graph neural networks (GNNs) learn representations of nodes, edges, and graphs by combining graph topology with node and edge features. Unlike images (regular grids) or text (sequences), graphs have irregular structure and no canonical node ordering. GNN layers must therefore be permutation-equivariant (for node-level outputs) or permutation-invariant (for graph-level outputs).

Notation

Consider a graph G = (V, E) with adjacency matrix A, node feature matrix X ∈ ℝ^{|V|×d}, and optional edge features e_{vu}. Node v has feature vector x_v ∈ ℝ^d and neighbor set N(v). The degree matrix D is diagonal with D_ii = Σ_j A_ij.

Message Passing Neural Networks (MPNN)

Framework

Gilmer et al. (2017) unified GNN architectures under the message passing framework. Each layer performs:

  1. Message: m_v^{(l)} = AGG({M(h_v^{(l)}, h_u^{(l)}, e_{vu}) : u ∈ N(v)})
  2. Update: h_v^{(l+1)} = U(h_v^{(l)}, m_v^{(l)})

where M is the message function, AGG is a permutation-invariant aggregation (sum, mean, max), U is the update function, and h_v^{(l)} is node v's representation at layer l.

After L layers, node v's representation depends only on its L-hop neighborhood.

Readout: For graph-level tasks, aggregate all node representations with a permutation-invariant function: h_G = READOUT({h_v^{(L)} : v ∈ V}).
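The message and update steps above can be sketched in plain Python. This is a minimal illustration under several assumptions: node states are lists of floats, the graph is an undirected edge list, edge features are ignored, AGG is a sum, and `message` and `update` are hypothetical stand-ins for the learned functions M and U.

```python
from typing import Dict, List, Tuple

Vector = List[float]


def message(h_v: Vector, h_u: Vector) -> Vector:
    # Hypothetical message function M: forward the neighbor's state unchanged.
    return h_u


def update(h_v: Vector, m_v: Vector) -> Vector:
    # Hypothetical update function U: average the old state with the aggregated message.
    return [0.5 * a + 0.5 * b for a, b in zip(h_v, m_v)]


def mpnn_layer(h: Dict[int, Vector], edges: List[Tuple[int, int]]) -> Dict[int, Vector]:
    # Neighbor lists for an undirected graph.
    nbrs: Dict[int, List[int]] = {v: [] for v in h}
    for u, v in edges:
        nbrs[u].append(v)
        nbrs[v].append(u)
    new_h: Dict[int, Vector] = {}
    for v, h_v in h.items():
        # AGG = sum, which does not depend on the order of the neighbors.
        m_v = [0.0] * len(h_v)
        for u in nbrs[v]:
            m_v = [a + b for a, b in zip(m_v, message(h_v, h[u]))]
        new_h[v] = update(h_v, m_v)
    return new_h


def readout(h: Dict[int, Vector]) -> Vector:
    # Sum readout: a permutation-invariant graph-level vector.
    dim = len(next(iter(h.values())))
    return [sum(vec[i] for vec in h.values()) for i in range(dim)]


# Path graph 0 - 1 - 2 with a signal only on node 0.
h0 = {0: [1.0], 1: [0.0], 2: [0.0]}
path = [(0, 1), (1, 2)]
h1 = mpnn_layer(h0, path)
h2 = mpnn_layer(h1, path)
```

After one layer, node 2 still holds [0.0]. After two layers it holds [0.25], because the signal at node 0 is two hops away.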

Expressiveness: Weisfeiler-Leman Hierarchy

The 1-WL (Weisfeiler-Leman) graph isomorphism test iteratively refines node colors. In each round, a node's new color is a hash of its current color together with the multiset of its neighbors' colors. Xu et al. (2019) proved that standard MPNNs are at most as powerful as 1-WL: they cannot distinguish graphs that 1-WL cannot distinguish.

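Implications: MPNNs bounded by 1-WL cannot count triangles or detect certain other substructures. They also cannot distinguish some non-isomorphic regular graphs. For example, a 6-cycle and two disjoint triangles look identical to 1-WL, because every node in both graphs has degree 2. Higher-order GNNs modeled on k-WL gain expressiveness, but they operate on k-tuples of nodes, so their cost grows as O(n^k).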
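The refinement can be sketched as follows on a 6-cycle versus two disjoint triangles. This is a minimal sketch that assumes unlabeled graphs given as adjacency lists. Instead of a hash function it uses a shared dictionary that gives each new signature a fresh integer. That mapping is injective, and sharing it keeps colors comparable across the graphs being tested.

```python
from collections import Counter
from typing import Dict, List, Tuple

Adj = Dict[int, List[int]]


def wl_histograms(graphs: List[Adj], rounds: int) -> List[Counter]:
    """Run 1-WL color refinement jointly so colors are comparable across graphs."""
    # Shared injective map from signature to color id, playing the role of the hash.
    palette: Dict[Tuple[int, Tuple[int, ...]], int] = {}
    colors = [{v: 0 for v in adj} for adj in graphs]  # uniform initial colors
    for _ in range(rounds):
        refined = []
        for adj, col in zip(graphs, colors):
            new_col = {}
            for v in adj:
                # Signature: own color plus the sorted multiset of neighbor colors.
                sig = (col[v], tuple(sorted(col[u] for u in adj[v])))
                new_col[v] = palette.setdefault(sig, len(palette))
            refined.append(new_col)
        colors = refined
    return [Counter(col.values()) for col in colors]


def cycle(n: int, offset: int = 0) -> Adj:
    return {offset + i: [offset + (i - 1) % n, offset + (i + 1) % n] for i in range(n)}


hexagon = cycle(6)
two_triangles = {**cycle(3), **cycle(3, offset=3)}
star = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
path4 = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}

same = wl_histograms([hexagon, two_triangles], rounds=3)
diff = wl_histograms([star, path4], rounds=3)
```

The hexagon and the two triangles end with identical color histograms. A star and a path on four nodes separate after the first round.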

Spectral Methods

Spectral Graph Convolution

Spectral approaches define convolution via the graph Fourier transform. The graph Laplacian L = D - A (or normalized: L_norm = I - D^{-1/2}AD^{-1/2}) has eigendecomposition L = UΛU^T. The graph Fourier transform of signal x is x̂ = U^T x. Spectral convolution:

g *_G x = U g_θ(Λ) U^T x

where g_θ(Λ) = diag(g_θ(λ_1), ..., g_θ(λ_n)) is a spectral filter.
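As a small worked example (an illustration, not from the original text), take two nodes joined by one edge, so D = I. Apply the hypothetical filter g_θ(λ) = 1 − λ/2, which passes the lowest frequency and removes the highest:

```latex
A = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}, \quad
L = D - A = \begin{pmatrix} 1 & -1 \\ -1 & 1 \end{pmatrix}, \quad
U = \frac{1}{\sqrt{2}} \begin{pmatrix} 1 & 1 \\ 1 & -1 \end{pmatrix}, \quad
\Lambda = \operatorname{diag}(0, 2)

\hat{x} = U^\top x = \frac{1}{\sqrt{2}} \begin{pmatrix} x_1 + x_2 \\ x_1 - x_2 \end{pmatrix}, \quad
g_\theta(\Lambda) = \operatorname{diag}(1, 0)

g *_G x = U \, g_\theta(\Lambda) \, U^\top x
        = U \begin{pmatrix} (x_1 + x_2)/\sqrt{2} \\ 0 \end{pmatrix}
        = \frac{x_1 + x_2}{2} \begin{pmatrix} 1 \\ 1 \end{pmatrix}
```

The filter keeps only the constant (zero-frequency) component, so each node's output is the average of the two node values.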

Limitations: Computing the full eigendecomposition costs O(n³), and every filtering step multiplies by the dense matrix U. A filter defined freely in the spectral domain is not guaranteed to be localized in space. The filter is also tied to one graph's eigenbasis (different graphs have different eigenbases), which hinders transfer across graphs.

ChebNet

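Defferrard et al. (2016) approximate spectral filters with Chebyshev polynomials: g_θ(Λ) ≈ Σ_{k=0}^{K} θ_k T_k(Λ̃), where Λ̃ = 2Λ/λ_max − I rescales the eigenvalues to [−1, 1] and T_k are Chebyshev polynomials.

Key insight: with L̃ = 2L/λ_max − I, the filtered signal Σ_k θ_k T_k(L̃) x can be computed without any eigendecomposition. The recurrence is T_k(L̃)x = 2L̃ T_{k−1}(L̃)x − T_{k−2}(L̃)x, starting from T_0(L̃)x = x and T_1(L̃)x = L̃x. Because T_k(L̃) is a polynomial of degree k in L̃, the filter is K-localized: each output depends only on the K-hop neighborhood.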

Complexity: O(K|E|) per filter application, linear in both the number of edges and the filter order K. This removes the eigendecomposition and yields spatially localized filters, which addresses the first two limitations above.
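The recurrence can be sketched in plain Python. This is a minimal sketch under three assumptions:

- the graph is undirected and has no isolated nodes;
- the Laplacian is the normalized one;
- λ_max ≈ 2, a common approximation, so that L̃ reduces to −D^{-1/2}AD^{-1/2}.

The coefficients θ_k are hypothetical. Each term costs one pass over the edges, which is where the O(K|E|) cost comes from.

```python
import math
from typing import Dict, List

Adj = Dict[int, List[int]]
Signal = Dict[int, float]


def scaled_laplacian_matvec(adj: Adj, x: Signal) -> Signal:
    # L~ x with L~ = 2 L_norm / lambda_max - I and lambda_max assumed to be 2,
    # i.e. L~ = -D^{-1/2} A D^{-1/2}. One pass over the edges: O(|E|).
    deg = {v: len(adj[v]) for v in adj}
    return {v: -sum(x[u] / math.sqrt(deg[v] * deg[u]) for u in adj[v]) for v in adj}


def cheb_filter(adj: Adj, x: Signal, theta: List[float]) -> Signal:
    # y = sum_k theta[k] * T_k(L~) x using T_k = 2 L~ T_{k-1} - T_{k-2}.
    t_prev = dict(x)                          # T_0(L~) x = x
    y = {v: theta[0] * t_prev[v] for v in adj}
    if len(theta) == 1:
        return y
    t_curr = scaled_laplacian_matvec(adj, x)  # T_1(L~) x = L~ x
    for v in adj:
        y[v] += theta[1] * t_curr[v]
    for k in range(2, len(theta)):
        lt = scaled_laplacian_matvec(adj, t_curr)
        t_next = {v: 2.0 * lt[v] - t_prev[v] for v in adj}
        for v in adj:
            y[v] += theta[k] * t_next[v]
        t_prev, t_curr = t_curr, t_next
    return y


# Path 0-1-2-3-4, impulse at node 0, polynomial of order K = 2.
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
impulse = {v: 0.0 for v in path}
impulse[0] = 1.0
y = cheb_filter(path, impulse, theta=[1.0, 1.0, 1.0])
```

With K = 2, the output is nonzero at node 2 but exactly zero at nodes 3 and 4, which lie more than two hops from the impulse.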

Spatial Methods

GCN (Graph Convolutional Network)

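Kipf and Welling (2017) simplify ChebNet with a first-order approximation (K = 1). They also approximate λ_max ≈ 2 and tie the two remaining coefficients into a single parameter (θ = θ_0 = −θ_1). This gives the layer-wise rule: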

H^{(l+1)} = σ(D̃^{-1/2} Ã D̃^{-1/2} H^{(l)} W^{(l)})

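where Ã = A + I (adding self-loops), D̃ is the degree matrix of Ã, and W^{(l)} is a learnable weight matrix. Each node takes a weighted sum of its own and its neighbors' features, with weight 1/√(d̃_i d̃_j) on the pair (i, j), followed by a linear transformation. This resembles mean aggregation but is not identical to it; a plain mean would use D̃^{-1}Ã instead.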

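Renormalization trick: the first-order approximation produces the operator I + D^{-1/2}AD^{-1/2}, whose eigenvalues lie in [0, 2]. Stacking many layers of it can cause numerical instabilities and exploding or vanishing gradients. Replacing it with D̃^{-1/2}ÃD̃^{-1/2} (self-loops plus symmetric normalization) keeps the eigenvalues in (−1, 1], so repeated application does not inflate the feature scale.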
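A single GCN layer with the renormalization trick can be sketched without a matrix library. This is a minimal sketch with these assumptions:

- the graph is undirected and given as adjacency lists;
- weights are dense lists of lists;
- σ is ReLU;
- the example weight is hypothetical.

The example also shows that symmetric normalization differs from a plain mean. On a star graph with constant input, the center and the leaves receive different values.

```python
import math
from typing import Dict, List

Adj = Dict[int, List[int]]
Features = Dict[int, List[float]]


def gcn_layer(adj: Adj, H: Features, W: List[List[float]]) -> Features:
    # A~ = A + I: each node's neighborhood includes itself.
    nbrs = {v: set(adj[v]) | {v} for v in adj}
    deg = {v: len(nbrs[v]) for v in adj}  # diagonal of D~
    out: Features = {}
    for v in adj:
        # Row v of D~^{-1/2} A~ D~^{-1/2} H: weights 1 / sqrt(d~_v * d~_u).
        agg = [0.0] * len(H[v])
        for u in nbrs[v]:
            coef = 1.0 / math.sqrt(deg[v] * deg[u])
            agg = [a + coef * h for a, h in zip(agg, H[u])]
        # Multiply by W (d_in x d_out), then apply ReLU as sigma.
        z = [sum(agg[i] * W[i][j] for i in range(len(agg))) for j in range(len(W[0]))]
        out[v] = [max(0.0, zj) for zj in z]
    return out


# Star graph: center 0, leaves 1 and 2; constant input, identity weight.
star = {0: [1, 2], 1: [0], 2: [0]}
H = {v: [1.0] for v in star}
out = gcn_layer(star, H, W=[[1.0]])
```

A mean aggregator would return 1.0 everywhere. Here a leaf gets 1/2 + 1/√6 ≈ 0.908, and the center gets 1/3 + 2/√6 ≈ 1.150.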

GraphSAGE

Hamilton et al. (2017) introduced inductive graph learning via sampling and aggregation:

  1. Sample: Uniformly sample a fixed-size set of neighbors for each node (rather than using all neighbors).
  2. Aggregate: Apply an aggregation function (mean, LSTM, pool) to sampled neighbor features.
  3. Combine: Concatenate the aggregated neighbor features with the node's own feature, then apply a linear layer and a non-linearity.
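The three steps can be sketched for one layer with the mean aggregator. This is a minimal sketch under the following assumptions:

- neighbors are sampled with replacement when a node has fewer than `num_samples` of them;
- features are lists of floats;
- the weight matrix `W` is hypothetical;
- outputs are L2-normalized, as in the GraphSAGE paper's algorithm.

```python
import math
import random
from typing import Dict, List

Adj = Dict[int, List[int]]
Features = Dict[int, List[float]]


def sage_layer(adj: Adj, H: Features, W: List[List[float]],
               num_samples: int, rng: random.Random) -> Features:
    out: Features = {}
    for v in adj:
        dim = len(H[v])
        nb = adj[v]
        # 1. Sample a fixed-size neighbor set (with replacement if the node has too few).
        if len(nb) >= num_samples:
            sampled = rng.sample(nb, num_samples)
        else:
            sampled = [rng.choice(nb) for _ in range(num_samples)] if nb else []
        # 2. Aggregate with the mean aggregator.
        if sampled:
            agg = [sum(H[u][i] for u in sampled) / len(sampled) for i in range(dim)]
        else:
            agg = [0.0] * dim
        # 3. Combine: concatenate [h_v || agg], apply W (2*dim x d_out) and ReLU.
        cat = H[v] + agg
        z = [max(0.0, sum(cat[i] * W[i][j] for i in range(len(cat))))
             for j in range(len(W[0]))]
        # L2-normalize the output.
        norm = math.sqrt(sum(zj * zj for zj in z))
        out[v] = [zj / norm for zj in z] if norm > 0 else z
    return out


graph = {0: [1, 2], 1: [0], 2: [0]}
H = {0: [0.0], 1: [2.0], 2: [4.0]}
W = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical weights: keep self and neighbor parts separate
out = sage_layer(graph, H, W, num_samples=2, rng=random.Random(0))
```

Because the weights are learned once and applied to whatever neighbors a node has, the same layer runs on nodes that were not present during training.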

Inductive capability: GraphSAGE learns an aggregation function rather than a per-node embedding table. It can therefore generalize to unseen nodes and graphs, which is essential for evolving graphs and transfer learning.

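Mini-batch training: Neighborhood sampling enables mini-batch SGD on large graphs. Each mini-batch contains the target nodes and their sampled neighborhoods up to L hops. The number of nodes touched per target is therefore bounded by the product of the per-hop sample sizes, not by the graph size, which avoids full-graph computation.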
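The per-hop expansion of a mini-batch can be sketched as below. This is a minimal sketch with assumptions:

- per-hop fan-outs are hypothetical;
- sampling is without replacement, capped at each node's degree. This is a common variant, not exactly the GraphSAGE scheme.

The sketch only builds the node sets; a trainer would then run the layers from the outermost hop inward. The size of hop k is at most the batch size times the product of the first k fan-outs.

```python
import random
from typing import Dict, List, Set

Adj = Dict[int, List[int]]


def sample_blocks(adj: Adj, targets: List[int], fanouts: List[int],
                  rng: random.Random) -> List[Set[int]]:
    # layers[0] holds the target nodes; layers[k] holds nodes sampled at hop k.
    layers: List[Set[int]] = [set(targets)]
    frontier = set(targets)
    for fanout in fanouts:
        nxt: Set[int] = set()
        for v in sorted(frontier):  # sorted for reproducible sampling
            nb = adj[v]
            nxt.update(rng.sample(nb, min(fanout, len(nb))))
        layers.append(nxt)
        frontier = nxt
    return layers


# Complete graph on 50 nodes; one target node; fan-outs 3 (hop 1) and 2 (hop 2).
n = 50
complete = {v: [u for u in range(n) if u != v] for v in range(n)}
blocks = sample_blocks(complete, targets=[0], fanouts=[3, 2], rng=random.Random(0))
```

Even though every node here has 49 neighbors, the two-hop computation touches at most 1 + 3 + 6 nodes.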

GAT (Graph Attention Network)

Veličković et al. (2018) replace fixed aggregation weights with learned attention:

α_{ij} = softmax_j(LeakyReLU(a^T [Wh_i || Wh_j]))

h_i' = σ(Σ_{j∈N(i)} α_{ij} W h_j)

Attention lets the model weight neighbors differently based on their features. Multi-head attention: K independent attention heads, whose outputs are concatenated in hidden layers or averaged in the final layer.

GATv2 (Brody et al., 2022): Fixes an expressiveness limitation of GAT. In GAT, a^T[Wh_i || Wh_j] splits into a term that depends only on i plus a term that depends only on j, and LeakyReLU is monotonic. As a result, GAT's attention is static: the ranking of keys by attention score is the same for every query node. GATv2 applies the attention vector a after the non-linearity, α_{ij} = softmax_j(a^T LeakyReLU(W[h_i || h_j])). This yields dynamic attention whose ranking depends on both the query and the key.
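The static-versus-dynamic difference can be checked numerically. This is a minimal sketch under these assumptions:

- node features are one-dimensional;
- there is a single head;
- all parameter values are hypothetical, chosen for illustration.

It scores the same two keys for two different queries and reports which key wins.

```python
import math
from typing import List


def leaky_relu(x: float, slope: float = 0.2) -> float:
    return x if x > 0 else slope * x


def softmax(scores: List[float]) -> List[float]:
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def gat_scores(h_i: float, keys: List[float], w: float, a1: float, a2: float) -> List[float]:
    # GAT: e_ij = LeakyReLU(a^T [W h_i || W h_j]) = LeakyReLU(a1*W h_i + a2*W h_j)
    return [leaky_relu(a1 * w * h_i + a2 * w * h_j) for h_j in keys]


def gatv2_scores(h_i: float, keys: List[float], W: List[List[float]],
                 a: List[float]) -> List[float]:
    # GATv2: e_ij = a^T LeakyReLU(W [h_i || h_j]); each row of W acts on [h_i, h_j]
    scores = []
    for h_j in keys:
        z = [row[0] * h_i + row[1] * h_j for row in W]
        scores.append(sum(ak * leaky_relu(zk) for ak, zk in zip(a, z)))
    return scores


def top_key(alphas: List[float]) -> int:
    return max(range(len(alphas)), key=lambda k: alphas[k])


keys = [-1.0, 1.0]       # features of two neighbors
queries = [-10.0, 10.0]  # features of two different query nodes
gat_top = [top_key(softmax(gat_scores(q, keys, w=1.0, a1=0.5, a2=-0.3))) for q in queries]
W_v2 = [[1.0, 1.0], [-1.0, -1.0]]  # hypothetical GATv2 parameters
a_v2 = [1.0, 1.0]
gatv2_top = [top_key(softmax(gatv2_scores(q, keys, W_v2, a_v2))) for q in queries]
```

GAT prefers key 0 for both queries, and no choice of its parameters could make the preference depend on the query. With these GATv2 parameters the score is 0.8·|h_i + h_j|, so the preferred key flips with the query.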

GIN (Graph Isomorphism Network)

Xu et al. (2019) designed GIN to be maximally powerful among MPNNs (equivalent to 1-WL):

h_v^{(l+1)} = MLP((1 + ε^{(l)}) · h_v^{(l)} + Σ_{u∈N(v)} h_u^{(l)})

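Key design choices:

- Use sum aggregation rather than mean or max. Among these three, only sum can be injective on multisets of neighbor features, given a suitable feature encoding.
- Use an MLP rather than a single linear layer, so that the update has enough capacity to approximate the required injective functions.
- ε is a learnable or fixed scalar that lets the model tell a node's own features apart from its neighbors'.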

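Theoretical result: with sum aggregation, sufficient MLP capacity, and an injective readout for graph-level tasks, GIN is as powerful as 1-WL. Mean aggregation captures only the distribution (proportions) of neighbor features, and max captures only the set of distinct features, so both are strictly weaker.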
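The gap between the aggregators is visible on multisets of one-hot neighbor features, where the sum is exactly the count of each feature value. This is a minimal sketch that assumes two discrete node colors encoded as hypothetical one-hot tuples. `gin_update` is illustrative: it takes the MLP as a parameter instead of training one.

```python
from typing import Callable, List, Tuple

Vec = Tuple[float, ...]
RED: Vec = (1.0, 0.0)
BLUE: Vec = (0.0, 1.0)


def aggregate(multiset: List[Vec], kind: str) -> Vec:
    cols = list(zip(*multiset))  # one column per feature dimension
    if kind == "sum":
        return tuple(sum(c) for c in cols)
    if kind == "mean":
        return tuple(sum(c) / len(multiset) for c in cols)
    if kind == "max":
        return tuple(max(c) for c in cols)
    raise ValueError(f"unknown aggregation: {kind}")


neighborhoods = [
    [RED, BLUE],
    [RED, RED, BLUE, BLUE],  # same proportions as the first
    [RED, RED, BLUE],        # same set of colors as the first two
]


def distinct_outputs(kind: str) -> int:
    return len({aggregate(m, kind) for m in neighborhoods})


def gin_update(h_v: Vec, neighbors: List[Vec], eps: float,
               mlp: Callable[[Vec], Vec]) -> Vec:
    # GIN: h_v' = MLP((1 + eps) * h_v + sum of neighbor states)
    s = aggregate(neighbors, "sum")
    return mlp(tuple((1.0 + eps) * a + b for a, b in zip(h_v, s)))


counts = {k: distinct_outputs(k) for k in ("sum", "mean", "max")}
```

Sum separates all three neighborhoods, mean confuses the first two, and max confuses all three.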

Graph Pooling

Graph-level tasks (classification, regression) require collapsing node representations into a fixed-size graph representation.

Flat pooling: Sum, mean, or max over all node representations. Simple but loses structural information.

Hierarchical pooling: Coarsen the graph progressively:

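  • DiffPool (Ying et al., 2018): Learn a soft assignment matrix S ∈ ℝ^{n×k} mapping n nodes to k clusters. The coarsened adjacency is A' = S^T A S, and the coarsened features are X' = S^T Z, where Z holds node embeddings from a GNN. DiffPool learns to cluster nodes hierarchically, but its dense matrices make memory grow quadratically in n.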
  • Top-K pooling: Score each node, keep top-k nodes, induce subgraph. Sparse and memory-efficient.
  • SAGPool: Use GNN-computed attention scores for Top-K selection.
  • MinCut pooling: Optimize a differentiable relaxation of the normalized min-cut objective.
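The DiffPool coarsening step takes only a few lines once the assignment logits are known. This is a minimal sketch using dense list-of-lists matrices. In DiffPool both the logits and the pooled embeddings come from separate GNNs; here they are supplied directly as hypothetical inputs. Because every row of S sums to 1, the total edge weight of A' equals that of A.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(P: Matrix, Q: Matrix) -> Matrix:
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]


def transpose(P: Matrix) -> Matrix:
    return [list(col) for col in zip(*P)]


def row_softmax(logits: Matrix) -> Matrix:
    out = []
    for row in logits:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        total = sum(e)
        out.append([x / total for x in e])
    return out


def diffpool_coarsen(A: Matrix, Z: Matrix, assign_logits: Matrix):
    S = row_softmax(assign_logits)       # n x k soft assignment, rows sum to 1
    St = transpose(S)
    A_coarse = matmul(matmul(St, A), S)  # k x k: A' = S^T A S
    Z_coarse = matmul(St, Z)             # k x d: pooled cluster features S^T Z
    return A_coarse, Z_coarse


# Two triangles (nodes 0-2 and 3-5) joined by the edge 2-3.
n = 6
A = [[0.0] * n for _ in range(n)]
for u, v in [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]:
    A[u][v] = A[v][u] = 1.0
Z = [[1.0] for _ in range(n)]
logits = [[50.0, 0.0]] * 3 + [[0.0, 50.0]] * 3  # near-hard assignment to 2 clusters
A_c, Z_c = diffpool_coarsen(A, Z, logits)
```

The coarsened adjacency has weight about 6 inside each triangle (three edges, each counted in both directions) and about 1 between the two clusters.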

Graph Transformers

Standard Transformers treat input tokens as a fully-connected graph with learned attention. Graph Transformers adapt this to graph-structured data.

Positional encodings: Graphs lack inherent position. Approaches:

  • Laplacian eigenvectors: Use the k smallest non-trivial eigenvectors of the graph Laplacian as positional features. Each eigenvector is defined only up to sign. This ambiguity is handled either by randomly flipping signs during training or by learning sign-invariant functions (SignNet).
  • Random walk positional encoding (RWPE): For each node, the probabilities that a random walk starting at that node returns to it after 1, 2, ..., k steps.
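RWPE can be computed by pushing a probability distribution through the walk. This is a minimal sketch that assumes an unweighted, undirected graph without isolated nodes and a uniform random walk. Entry t of a node's encoding is the probability of being back at the start after t steps.

```python
from typing import Dict, List

Adj = Dict[int, List[int]]


def rwpe(adj: Adj, k: int) -> Dict[int, List[float]]:
    pe: Dict[int, List[float]] = {}
    for start in adj:
        dist = {start: 1.0}  # all probability mass on the start node
        feats: List[float] = []
        for _ in range(k):
            nxt: Dict[int, float] = {}
            for v, p in dist.items():
                for u in adj[v]:  # uniform step to a neighbor
                    nxt[u] = nxt.get(u, 0.0) + p / len(adj[v])
            dist = nxt
            feats.append(dist.get(start, 0.0))  # return probability after this step
        pe[start] = feats
    return pe


triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
pe_tri = rwpe(triangle, k=3)
pe_hex = rwpe(hexagon, k=3)
```

Nodes in a triangle and nodes in a 6-cycle receive identical 1-WL colors, because every node has degree 2. Their RWPE vectors still differ at step 3: a walk can return around a triangle but never around an odd-length path in the bipartite 6-cycle.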

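Graphormer (Ying et al., 2021): Adds three encodings to a standard Transformer:

- centrality encoding (degree-based node embeddings);
- spatial encoding (a learned bias on each attention score, indexed by shortest-path distance);
- edge encoding.

It won the quantum chemistry (PCQM4M-LSC) track of the OGB Large-Scale Challenge at KDD Cup 2021.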
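The spatial encoding amounts to adding a learned scalar, indexed by shortest-path distance, to each attention logit before the softmax. This is a minimal single-head sketch under these assumptions:

- query and key vectors are precomputed;
- the bias table `spd_bias` is hypothetical;
- unreachable pairs get a hypothetical fixed bias;
- centrality and edge encodings are omitted.

```python
import math
from collections import deque
from typing import Dict, List

Adj = Dict[int, List[int]]


def shortest_path_lengths(adj: Adj, src: int) -> Dict[int, int]:
    # Breadth-first search: hop distance from src to every reachable node.
    dist = {src: 0}
    queue = deque([src])
    while queue:
        v = queue.popleft()
        for u in adj[v]:
            if u not in dist:
                dist[u] = dist[v] + 1
                queue.append(u)
    return dist


def spd_biased_attention(adj: Adj, Q: Dict[int, List[float]], K: Dict[int, List[float]],
                         spd_bias: Dict[int, float],
                         unreachable_bias: float) -> Dict[int, List[float]]:
    nodes = sorted(adj)
    d = len(Q[nodes[0]])
    attn = {}
    for i in nodes:
        dist = shortest_path_lengths(adj, i)
        logits = []
        for j in nodes:
            dot = sum(q * k for q, k in zip(Q[i], K[j])) / math.sqrt(d)
            # Spatial encoding: add a scalar chosen by the shortest-path distance.
            bias = spd_bias.get(dist[j], 0.0) if j in dist else unreachable_bias
            logits.append(dot + bias)
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        attn[i] = [e / total for e in exps]
    return attn


graph = {0: [1], 1: [0, 2], 2: [1], 3: []}  # path 0-1-2 plus an isolated node 3
zeros = {v: [0.0, 0.0] for v in graph}      # zero queries/keys: only the bias matters
attn = spd_biased_attention(graph, zeros, zeros,
                            spd_bias={0: 0.0, 1: 1.0, 2: -1.0}, unreachable_bias=-1e9)
```

With uninformative queries and keys, node 0 attends mostly to its direct neighbor and not at all to the unreachable node. The graph structure enters only through the distance-indexed bias.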

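GPS (General, Powerful, Scalable; Rampášek et al., 2022): Combines local MPNN layers (for structural inductive bias) with global Transformer attention (for long-range interactions). In each GPS layer, an MPNN and a global attention module process the same node features in parallel. Their outputs are summed and passed through an MLP.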

Over-Smoothing

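As GNN depth increases, node representations converge toward indistinguishable vectors: the over-smoothing problem. Each message-passing layer averages a node's features with its neighbors', acting as a low-pass filter on the graph. After L layers, each node's representation is a function of its L-hop neighborhood. As L approaches the graph diameter, these neighborhoods overlap heavily, and repeated averaging washes out the differences between nodes.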

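Quantification: Consider repeatedly applying the GCN propagation matrix D̃^{-1/2}ÃD̃^{-1/2}, without non-linearities, on a connected graph. Every feature channel is driven toward the dominant eigenvector, which is proportional to D̃^{1/2}1 (the square roots of the self-loop-augmented degrees). Node representations then differ only by a degree-dependent scale, and on regular graphs they become identical. The non-dominant components shrink geometrically, so the distance to this limit decreases exponentially with depth. Analyses that include weights and ReLU reach similar exponential rates under conditions on the weight matrices.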
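The convergence can be observed by repeatedly applying D̃^{-1/2}ÃD̃^{-1/2} to a signal, with no weights or non-linearities. This is a minimal sketch that assumes a small connected path graph without self-loops in its adjacency lists and a hypothetical impulse input. It measures how far x_v / √d̃_v is from being constant across nodes; that measure is zero exactly when x is proportional to the dominant eigenvector.

```python
import math
from typing import Dict, List

Adj = Dict[int, List[int]]


def propagate(adj: Adj, x: Dict[int, float]) -> Dict[int, float]:
    # One application of D~^{-1/2} (A + I) D~^{-1/2}.
    nbrs = {v: set(adj[v]) | {v} for v in adj}
    deg = {v: len(nbrs[v]) for v in adj}
    return {v: sum(x[u] / math.sqrt(deg[v] * deg[u]) for u in nbrs[v]) for v in adj}


def degree_scaled_spread(adj: Adj, x: Dict[int, float], layers: int) -> float:
    for _ in range(layers):
        x = propagate(adj, x)
    deg = {v: len(set(adj[v]) | {v}) for v in adj}  # d~_v, counting the self-loop
    ratios = [x[v] / math.sqrt(deg[v]) for v in adj]
    return max(ratios) - min(ratios)  # 0 means x is proportional to D~^{1/2} 1


path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
impulse = {v: (1.0 if v == 0 else 0.0) for v in path}
spreads = {depth: degree_scaled_spread(path, impulse, depth) for depth in (2, 200)}
```

After 2 layers the impulse is still clearly localized. After 200 layers only the degree-scaled dominant direction remains; for this graph the second-largest eigenvalue is 5/6.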

Mitigations:

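  • Residual and jumping-knowledge connections: Skip connections between layers. JKNet (jumping knowledge) combines the representations from all layers at the output, by concatenation, max-pooling, or attention.
  • Normalization: PairNorm, NodeNorm, and GraphNorm normalize node representations, which can help mitigate representation collapse.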
  • DropEdge: Randomly remove edges during training, slowing information propagation.
  • Deep GNNs: GCNII adds initial residual connections and identity mapping, enabling 64+ layer GCNs.
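The effect of an initial residual can be illustrated with 64 layers of propagation. This is a minimal sketch with these assumptions:

- node signals are scalars;
- the mixing weight α = 0.1 is an arbitrary choice;
- only GCNII's initial-residual term is kept, H ← (1 − α)ÂH + αH^{(0)}. The identity-mapped weight matrices and the non-linearity are omitted.

Setting α = 0 recovers plain propagation.

```python
import math
from typing import Dict, List

Adj = Dict[int, List[int]]


def initial_residual_propagate(adj: Adj, h0: Dict[int, float], alpha: float,
                               layers: int) -> Dict[int, float]:
    nbrs = {v: set(adj[v]) | {v} for v in adj}
    deg = {v: len(nbrs[v]) for v in adj}
    h = dict(h0)
    for _ in range(layers):
        # Smoothing step: D~^{-1/2} (A + I) D~^{-1/2} h
        smoothed = {v: sum(h[u] / math.sqrt(deg[v] * deg[u]) for u in nbrs[v]) for v in adj}
        # Initial residual: mix back a fraction alpha of the layer-0 input.
        h = {v: (1.0 - alpha) * smoothed[v] + alpha * h0[v] for v in adj}
    return h


def degree_scaled_spread(adj: Adj, h: Dict[int, float]) -> float:
    # Zero when h is proportional to D~^{1/2} 1, i.e. fully over-smoothed.
    ratios = [h[v] / math.sqrt(len(set(adj[v]) | {v})) for v in adj]
    return max(ratios) - min(ratios)


path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
impulse = {v: (1.0 if v == 0 else 0.0) for v in path}
plain = degree_scaled_spread(path, initial_residual_propagate(path, impulse, 0.0, 64))
residual = degree_scaled_spread(path, initial_residual_propagate(path, impulse, 0.1, 64))
```

Without the residual, 64 layers collapse the signal onto the degree-scaled dominant direction. With it, part of the input is re-injected at every layer, so node-level differences survive at any depth.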

Knowledge Graph Embeddings

Knowledge graphs store facts as (head, relation, tail) triples. Embedding methods learn vector representations of entities and relations.

  • TransE: h + r ≈ t (translation in embedding space). Simple but cannot model symmetric or 1-to-N relations well.
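  • DistMult: Score = h^T diag(r) t (bilinear). The score is symmetric in h and t by construction, so DistMult cannot model antisymmetric relations.
  • ComplEx: DistMult in complex space, with score Re(h^T diag(r) t̄). Conjugating the tail embedding makes the score asymmetric in h and t, so ComplEx can model asymmetric relations.
  • RotatE: Each relation is an element-wise rotation in complex space: t = h ⊙ r, with |r_i| = 1. Models symmetry, antisymmetry, inversion, and composition.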
  • R-GCN (Relational GCN): Apply GCN with relation-specific weight matrices. Message passing on the knowledge graph.
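The four scoring functions can be compared on toy embeddings, using Python's built-in complex numbers for ComplEx and RotatE. This is a minimal sketch with assumptions:

- embeddings are hypothetical and one- or two-dimensional;
- TransE and RotatE use a sum-of-absolute-differences (L1-style) distance; both methods also allow other norms.

Higher scores mean more plausible triples.

```python
import cmath
import math
from typing import Sequence


def transe(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    # Negative L1 distance ||h + r - t||: 0 is a perfect translation.
    return -sum(abs(hi + ri - ti) for hi, ri, ti in zip(h, r, t))


def distmult(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    # Trilinear product h^T diag(r) t: unchanged when h and t are swapped.
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


def complex_score(h: Sequence[complex], r: Sequence[complex],
                  t: Sequence[complex]) -> float:
    # ComplEx: Re(sum_i h_i * r_i * conj(t_i)); the conjugate breaks the h/t symmetry.
    return sum(hi * ri * ti.conjugate() for hi, ri, ti in zip(h, r, t)).real


def rotate(h: Sequence[complex], phases: Sequence[float], t: Sequence[complex]) -> float:
    # RotatE: r_i = exp(i * phase_i) has modulus 1; score is -sum_i |h_i * r_i - t_i|.
    return -sum(abs(hi * cmath.exp(1j * p) - ti) for hi, p, ti in zip(h, phases, t))


h, t, r = [1.0, 2.0], [3.0, 4.0], [0.5, -1.0]
sym_ok = distmult(h, r, t) == distmult(t, r, h)
forward = complex_score([1 + 0j], [1j], [1j])   # Re(1 * i * conj(i)) = 1
backward = complex_score([1j], [1j], [1 + 0j])  # Re(i * i * conj(1)) = -1
rot = rotate([1 + 0j], [math.pi / 2], [1j])     # rotating 1 by 90 degrees gives i
```

Swapping head and tail leaves DistMult unchanged but flips the sign of this ComplEx score.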

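Link prediction: Given (h, r, ?), rank all candidate tail entities by score. The filtered setting excludes other known true tails from the ranking. Evaluation uses two metrics:

- mean reciprocal rank (MRR);
- Hits@k, the fraction of queries whose true tail ranks in the top k.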
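Filtered ranking and the two metrics can be sketched as follows. This is a minimal sketch that assumes a hypothetical score table in place of a trained model. It also assumes an optimistic tie policy: only strictly higher scores push the true tail down. Evaluation protocols differ in how they break ties.

```python
from typing import Callable, Iterable, List, Set, Tuple

ScoreFn = Callable[[str, str, str], float]


def tail_rank(score_fn: ScoreFn, h: str, r: str, true_t: str,
              entities: Iterable[str], known_tails: Set[str]) -> int:
    true_score = score_fn(h, r, true_t)
    better = 0
    for e in entities:
        # Filtered setting: other correct tails for (h, r) are not counted as errors.
        if e == true_t or e in known_tails:
            continue
        if score_fn(h, r, e) > true_score:
            better += 1
    return 1 + better


def mrr_and_hits(ranks: List[int], k: int) -> Tuple[float, float]:
    mrr = sum(1.0 / rank for rank in ranks) / len(ranks)
    hits_at_k = sum(1 for rank in ranks if rank <= k) / len(ranks)
    return mrr, hits_at_k


table = {"a": 0.9, "b": 0.5, "c": 0.7, "d": 0.1}


def score(h: str, r: str, t: str) -> float:
    return table[t]  # hypothetical: the score depends only on the candidate tail


raw = tail_rank(score, "x", "rel", "b", table, known_tails=set())
filtered = tail_rank(score, "x", "rel", "b", table, known_tails={"a"})
mrr, hits3 = mrr_and_hits([1, 2, 4], k=3)
```

Suppose "a" is also a correct tail for this query. Filtering it out moves "b" from rank 3 to rank 2.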

Graph Generation

Graph generation models produce new graphs with desired properties (molecules, social networks, circuits).

  • GraphRNN (You et al., 2018): Autoregressive generation — generate nodes and edges sequentially using RNNs.
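  • GraphVAE: A VAE whose decoder outputs, in one shot, a probabilistic adjacency matrix together with node and edge attributes for a graph of bounded size. Computing the reconstruction loss requires matching generated nodes to target nodes.
  • GDSS / DiGress: Diffusion models on graphs. GDSS (Jo et al., 2022) uses continuous score-based diffusion over node features and adjacency via a system of SDEs. DiGress (Vignac et al., 2023) performs discrete diffusion on node types and edge types, with a graph transformer as the denoising network.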
  • Molecular generation: Junction tree VAE (generating molecular graphs by assembling substructure vocabularies), MolGAN, equivariant diffusion (EDM) for 3D molecular coordinates.
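GraphRNN's graph-to-sequence step can be sketched as follows. Nodes are put in BFS order, and node i is described by a binary vector S_i of its connections to the nodes before it. This is a minimal sketch with assumptions:

- the graph is connected, undirected, and given as adjacency sets;
- the start node is hypothetical;
- ties are broken by node id.

The optional `window` keeps only the last entries of each S_i. A window at least as large as the longest backward edge in the BFS order loses nothing. The two RNNs that generate these vectors are omitted.

```python
from collections import deque
from typing import Dict, List, Optional, Set, Tuple

Adj = Dict[int, Set[int]]


def bfs_order(adj: Adj, start: int) -> List[int]:
    order, seen, queue = [], {start}, deque([start])
    while queue:
        v = queue.popleft()
        order.append(v)
        for u in sorted(adj[v]):  # break ties by node id
            if u not in seen:
                seen.add(u)
                queue.append(u)
    return order


def graph_to_sequence(adj: Adj, start: int,
                      window: Optional[int] = None) -> Tuple[List[int], List[List[int]]]:
    order = bfs_order(adj, start)
    seq = []
    for i in range(1, len(order)):
        # S_i[j] = 1 if the i-th node in BFS order is adjacent to the j-th (j < i).
        s_i = [1 if order[j] in adj[order[i]] else 0 for j in range(i)]
        if window is not None:
            s_i = s_i[-window:]  # keep only connections to the last `window` nodes
        seq.append(s_i)
    return order, seq


def sequence_to_edges(seq: List[List[int]]) -> Set[Tuple[int, int]]:
    # Inverse mapping in BFS-position space; a truncated S_i covers positions i-len..i-1.
    edges = set()
    for idx, s_i in enumerate(seq):
        i = idx + 1
        first = i - len(s_i)
        for k, bit in enumerate(s_i):
            if bit:
                edges.add((first + k, i))
    return edges


square = {0: {1, 3}, 1: {0, 2}, 2: {1, 3}, 3: {0, 2}}  # 4-cycle
order, seq = graph_to_sequence(square, start=0)
recovered = {frozenset((order[a], order[b])) for a, b in sequence_to_edges(seq)}
original = {frozenset((u, v)) for u in square for v in square[u]}
```

The 4-cycle becomes the sequence [1], [1, 0], [0, 1, 1], and decoding it recovers the original edge set.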

Scalability

Large graphs (millions of nodes) pose computational challenges for GNNs.

  • Neighbor sampling: GraphSAGE-style sampling limits the neighborhood per node. Variance reduction: historical embeddings (VR-GCN) or layer-wise importance sampling (FastGCN).
  • Cluster-GCN: Partition the graph into clusters (METIS). Mini-batches are subgraphs induced by sampled clusters. Preserves within-cluster structure.
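  • GraphSAINT: Samples training subgraphs with random node, edge, or random-walk samplers, and applies normalization coefficients to correct the resulting sampling bias.
  • Precomputation approaches: SGC (Simplified GCN) removes the non-linearities between layers, so the K-hop aggregation can be precomputed once: Y = softmax(S^K X W), with S = D̃^{-1/2}ÃD̃^{-1/2}. After the O(K|E|d) precomputation, training reduces to a linear (logistic regression) model.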
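SGC's precomputation is K rounds of sparse propagation, done once before training. This is a minimal sketch that assumes an undirected graph without self-loops in its adjacency lists and dense feature lists. The resulting features would then feed any linear softmax classifier, which is omitted here.

```python
import math
from typing import Dict, List

Adj = Dict[int, List[int]]
Features = Dict[int, List[float]]


def sgc_precompute(adj: Adj, X: Features, K: int) -> Features:
    # S = D~^{-1/2} (A + I) D~^{-1/2}; returns S^K X. Each hop costs O((|E| + |V|) d).
    nbrs = {v: set(adj[v]) | {v} for v in adj}
    deg = {v: len(nbrs[v]) for v in adj}
    feats = {v: list(X[v]) for v in adj}
    for _ in range(K):
        feats = {
            v: [sum(feats[u][i] / math.sqrt(deg[v] * deg[u]) for u in nbrs[v])
                for i in range(len(feats[v]))]
            for v in adj
        }
    return feats


pair = {0: [1], 1: [0]}
X = {0: [1.0], 1: [3.0]}
smoothed = sgc_precompute(pair, X, K=1)
```

Because there are no weights inside the propagation, S^K X is computed once and reused in every training epoch.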