ml-pytorch-geometric
npx machina-cli add skill nishide-dev/claude-code-ml-research/ml-pytorch-geometric --openclaw
PyTorch Geometric for Graph Neural Networks
Overview
PyTorch Geometric (PyG) is the standard library for geometric deep learning with PyTorch. It provides optimized implementations of Graph Neural Networks (GNNs), efficient data structures for graph data, and scalable solutions for large-scale graph learning.
Key Capabilities:
- Efficient sparse tensor operations for graphs
- 100+ pre-implemented GNN layers (GCN, GAT, GraphSAGE, etc.)
- Scalable data loaders with neighbor sampling
- Large-scale distributed graph learning
- Heterogeneous and temporal graph support
- Integration with PyTorch Lightning
Resources:
- Official docs: https://pytorch-geometric.readthedocs.io/
- Stanford CS224W: http://web.stanford.edu/class/cs224w/
- GitHub: https://github.com/pyg-team/pytorch_geometric
Core Concepts
1. Graph Data Representation
PyG uses a tensor-centric approach. Each graph is represented by a Data object containing node features, edge indices, and optional attributes.
Basic Data object:
```python
import torch
from torch_geometric.data import Data

# Undirected path graph 0 - 1 - 2, stored as directed edges in both directions:
# 0->1, 1->2, 1->0, 2->1
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 2, 0, 1]], dtype=torch.long)
x = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float)  # Node features
y = torch.tensor([0], dtype=torch.long)  # Graph-level label

data = Data(x=x, edge_index=edge_index, y=y)
print(data)
# Data(x=[3, 1], edge_index=[2, 4], y=[1])

# Access attributes
print(f"Nodes: {data.num_nodes}")  # 3
print(f"Edges: {data.num_edges}")  # 4 (each undirected edge is stored twice)
print(f"Features: {data.num_node_features}")  # 1
print(f"Directed: {data.is_directed()}")  # False (edge_index is symmetric)
```
Common Data attributes:
| Attribute | Shape | Type | Description |
|---|---|---|---|
| `data.x` | `[num_nodes, num_node_features]` | float | Node feature matrix |
| `data.edge_index` | `[2, num_edges]` | long | Graph connectivity in COO format |
| `data.edge_attr` | `[num_edges, num_edge_features]` | float | Edge feature matrix |
| `data.y` | Varies | Any | Target labels (node- or graph-level) |
| `data.pos` | `[num_nodes, num_dimensions]` | float | Node positions (e.g., for point clouds) |
| `data.batch` | `[num_nodes]` | long | Batch assignment vector (graph index of each node) |
2. Message Passing Framework
GNNs work through message passing: nodes aggregate information from neighbors to update their representations.
MessagePassing base class:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):  # From-scratch version of PyG's built-in GCNConv
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "add", "mean", or "max"
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Add self-loops to the adjacency matrix
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # Linear transformation
        x = self.lin(x)
        # Compute symmetric normalization 1 / sqrt(deg(j) * deg(i))
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        # Start propagating messages
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j holds the source-node features of each edge; scale them by the edge norm
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # Identity update: return the aggregated messages as the new embeddings
        return aggr_out
```
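For reference, the layer above computes the symmetric-normalized GCN update for each node $i$:

$$
\mathbf{x}_i' = \sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\deg(i)}\,\sqrt{\deg(j)}} \left( \mathbf{W}\mathbf{x}_j + \mathbf{b} \right)
$$

- $\mathbf{W}$ and $\mathbf{b}$ are the weight and bias of `self.lin`.
- $\deg(\cdot)$ counts the edges ending at a node after self-loops are added, which is what the `degree(col, ...)` call computes.
- The factor $1/(\sqrt{\deg(i)}\sqrt{\deg(j)})$ is the `norm` value passed to `message()`.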
Message passing steps (propagate() calls them in this order):
- message(): Constructs messages from source nodes to target nodes
- aggregate(): Aggregates messages (sum, mean, max)
- update(): Updates node embeddings based on aggregated messages
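To make the three steps concrete without the `MessagePassing` machinery, the sketch below runs one round of sum-aggregation message passing in plain PyTorch. It uses the three-node graph from the Data example. It mirrors `aggr='add'` with no weights or normalization, and `index_add_` plays the role of `aggregate()`:

```python
import torch

def sum_message_passing(x, edge_index):
    """One round of message passing with sum aggregation (no weights)."""
    src, dst = edge_index              # Edges flow from src (j) to dst (i)
    messages = x[src]                  # message(): copy x_j onto each edge
    out = torch.zeros_like(x)
    out.index_add_(0, dst, messages)   # aggregate(): sum messages per target node
    return out                         # update(): identity

x = torch.tensor([[1.0], [2.0], [3.0]])
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 2, 0, 1]])
print(sum_message_passing(x, edge_index))
# tensor([[2.],
#         [4.],
#         [2.]])
```

Node 1 receives messages from both neighbors (1 + 3 = 4). Nodes 0 and 2 each receive only node 1's feature (2).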
3. Common GNN Layers
PyG provides 100+ pre-implemented layers. Here are three of the most widely used:
GCN (Graph Convolutional Network):
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x  # Raw logits: [num_nodes, out_channels]
```
GAT (Graph Attention Networks):
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
```
GraphSAGE (inductive learning):
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
```
Graph-Level Tasks
Graph Classification
Aggregate node representations to classify entire graphs (molecules, social networks, etc.).
Complete example:
```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# Define model
class GraphClassifier(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.conv3 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, x, edge_index, batch):
        # Node embedding
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        # Graph-level readout (pooling)
        x = global_mean_pool(x, batch)  # [num_graphs, 64]
        # Classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)

# Load dataset (600 graphs, 6 classes) and shuffle before splitting
torch.manual_seed(12345)
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES').shuffle()

# Split dataset: 540 training graphs, 60 test graphs
train_dataset = dataset[:540]
test_dataset = dataset[540:]

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Train model
model = GraphClassifier(dataset.num_node_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()
for epoch in range(200):
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()

# Evaluate on the held-out graphs
model.eval()
correct = 0
with torch.no_grad():
    for data in test_loader:
        pred = model(data.x, data.edge_index, data.batch).argmax(dim=1)
        correct += int((pred == data.y).sum())
print(f'Test accuracy: {correct / len(test_dataset):.4f}')
```
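Graph pooling layers:
- global_mean_pool: Average node features
- global_max_pool: Max pooling over nodes
- global_add_pool: Sum node features
- global_sort_pool: Sort pooling (SortPool)
- SAGPooling: Self-attention graph pooling. This is hierarchical: it keeps a top-k subset of nodes rather than producing one vector per graph.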
Node-Level Tasks
Node Classification
Classify nodes within a single large graph.
Node classification example:
```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

# Load Cora dataset (citation network)
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]  # Single graph

# Model
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Training
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    # Only compute loss on training nodes
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Evaluation (no gradients needed)
model.eval()
with torch.no_grad():
    pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
```
Neighbor Sampling for Large Graphs
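Some graphs are too large for full-batch training. For these, NeighborLoader builds mini-batches from the sampled k-hop neighborhoods of a set of seed nodes, instead of running every layer on the entire graph at once.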
NeighborLoader:
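```python
import torch
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader

# Requires pyg-lib or torch-sparse for neighbor sampling.
# `data` and `dataset` are the Cora objects from the node classification example.
loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],  # Sample up to 25 neighbors in the 1st hop, 10 in the 2nd
    batch_size=128,
    input_nodes=data.train_mask,  # Seed nodes: only training nodes
    num_workers=4,
    shuffle=True,
)

# The model must take (x, edge_index) and return logits,
# e.g. the GraphSAGE class from "Common GNN Layers"
model = GraphSAGE(dataset.num_node_features, 64, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop with mini-batches
model.train()
for batch in loader:
    optimizer.zero_grad()
    # The first batch.batch_size nodes of each mini-batch are the seed nodes
    out = model(batch.x, batch.edge_index)[:batch.batch_size]
    loss = F.cross_entropy(out, batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
```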
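`num_neighbors` bounds how large each mini-batch can get. Every seed node pulls in at most 25 first-hop neighbors, and each of those pulls in at most 10 more. The hypothetical helper below computes that upper bound as a back-of-the-envelope check. Real batches are usually smaller, because sampled neighborhoods overlap and low-degree nodes have fewer neighbors to sample:

```python
def max_sampled_nodes(batch_size, num_neighbors):
    """Upper bound on nodes in one NeighborLoader mini-batch (ignores overlap)."""
    total, frontier = batch_size, batch_size
    for fanout in num_neighbors:
        frontier *= fanout  # Each node in the current hop samples `fanout` neighbors
        total += frontier
    return total

print(max_sampled_nodes(128, [25, 10]))  # 35328 = 128 * (1 + 25 + 25 * 10)
```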
Integration with PyTorch Lightning
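PyG provides Lightning data-module wrappers that plug directly into a Lightning Trainer: LightningDataset for graph-level tasks and LightningNodeData for node-level tasks.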
LightningDataset (Graph Classification)
```python
import lightning as L
import torch
import torch.nn.functional as F
from torch_geometric.data import LightningDataset
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GCNConv, global_mean_pool

# ENZYMES has 600 graphs: shuffle, then split 480 / 60 / 60
torch.manual_seed(12345)
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES').shuffle()
datamodule = LightningDataset(
    train_dataset=dataset[:480],
    val_dataset=dataset[480:540],
    test_dataset=dataset[540:],
    batch_size=32,
    num_workers=4,
)

# LightningModule
class LitGNN(L.LightningModule):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return global_mean_pool(x, batch)  # One row of logits per graph

    def training_step(self, batch, batch_idx):
        out = self(batch.x, batch.edge_index, batch.batch)
        loss = F.cross_entropy(out, batch.y)
        # Pass batch_size explicitly rather than letting Lightning guess it from a PyG batch
        self.log('train_loss', loss, batch_size=batch.num_graphs)
        return loss

    def validation_step(self, batch, batch_idx):
        out = self(batch.x, batch.edge_index, batch.batch)
        loss = F.cross_entropy(out, batch.y)
        acc = (out.argmax(dim=1) == batch.y).float().mean()
        self.log('val_loss', loss, prog_bar=True, batch_size=batch.num_graphs)
        self.log('val_acc', acc, prog_bar=True, batch_size=batch.num_graphs)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)

# Train
model = LitGNN(dataset.num_node_features, 64, dataset.num_classes)
trainer = L.Trainer(max_epochs=100, accelerator='auto', devices=1)
trainer.fit(model, datamodule)
```
LightningNodeData (Node Classification)
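```python
import lightning as L
import torch
import torch.nn.functional as F
from torch_geometric.data import LightningNodeData
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

# Load dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

# Create Lightning wrapper with neighbor sampling
datamodule = LightningNodeData(
    data,
    input_train_nodes=data.train_mask,
    input_val_nodes=data.val_mask,
    input_test_nodes=data.test_mask,
    loader='neighbor',       # Use NeighborLoader
    num_neighbors=[25, 10],  # 2-hop sampling
    batch_size=128,
    num_workers=4,
)

# Node-level module: no graph pooling; predict only for each batch's seed nodes
class LitNodeGNN(L.LightningModule):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv2(self.conv1(x, edge_index).relu(), edge_index)

    def training_step(self, batch, batch_idx):
        out = self(batch.x, batch.edge_index)[:batch.batch_size]
        loss = F.cross_entropy(out, batch.y[:batch.batch_size])
        self.log('train_loss', loss, batch_size=batch.batch_size)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)

model = LitNodeGNN(dataset.num_node_features, 64, dataset.num_classes)
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, datamodule)
```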
Large-Scale Graph Learning
Distributed Training
PyG 2.5+ supports distributed training for billion-scale graphs using DDP + RPC.
Architecture:
- Graph Partitioning: Split graph using METIS to minimize edge cuts
- DDP: Replicate model across GPUs with gradient synchronization
- RPC: Fetch features/structure from remote partitions
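METIS partitioning tries to minimize the number of edges whose endpoints land in different partitions. Each such edge can require a remote fetch over RPC during sampling. The hypothetical helper below counts those cut edges for a given node-to-partition assignment, using a 4-node path graph:

```python
import torch

def count_edge_cuts(edge_index, part):
    """Number of (directed) edges whose endpoints are in different partitions."""
    src, dst = edge_index
    return int((part[src] != part[dst]).sum())

# Undirected path 0 - 1 - 2 - 3, stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
contiguous = torch.tensor([0, 0, 1, 1])   # Partitions {0, 1} | {2, 3}
alternating = torch.tensor([0, 1, 0, 1])  # Partitions {0, 2} | {1, 3}

print(count_edge_cuts(edge_index, contiguous))   # 2 (edge 1-2, both directions)
print(count_edge_cuts(edge_index, alternating))  # 6 (every edge is cut)
```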
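Partition graph. This is a sketch using the Partitioner class from torch_geometric.distributed, which splits the graph with METIS. Argument names may differ between PyG versions, so check the distributed tutorial.
```python
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.distributed import Partitioner

dataset = PygNodePropPredDataset(name='ogbn-products', root='./data')
data = dataset[0]

# Split the graph into 4 parts with METIS and write them to ./partitions
partitioner = Partitioner(data, num_parts=4, root='./partitions')
partitioner.generate_partition()
```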
Distributed training script:
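```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch_geometric.distributed import DistNeighborLoader
from torch_geometric.nn import GraphSAGE

# Hypothetical sizes for illustration (ogbn-products: 100 features, 47 classes)
IN_CHANNELS, HIDDEN_CHANNELS, OUT_CHANNELS = 100, 256, 47

def run_distributed(rank, world_size):
    # Initialize process group (expects MASTER_ADDR / MASTER_PORT to be set)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Load this rank's partition (simplified: PyG's distributed API typically
    # wraps a partition in LocalFeatureStore / LocalGraphStore objects)
    data = torch.load(f'./partitions/part_{rank}.pt')

    # Create distributed loader (simplified: the real DistNeighborLoader also needs
    # a DistContext, master address/port and input nodes; see the distributed tutorial)
    loader = DistNeighborLoader(
        data,
        num_neighbors=[25, 10],
        batch_size=1024,
        shuffle=True,
    )

    # Model: PyG's GraphSAGE model takes num_layers before out_channels
    model = GraphSAGE(IN_CHANNELS, HIDDEN_CHANNELS, num_layers=2,
                      out_channels=OUT_CHANNELS).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    # Training loop
    for epoch in range(100):
        for batch in loader:
            pass  # Training logic goes here

if __name__ == '__main__':
    world_size = 4
    mp.spawn(run_distributed, args=(world_size,), nprocs=world_size)
```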
Remote Backend (Out-of-Core Learning)
Use FeatureStore and GraphStore for data that doesn't fit in memory.
Custom FeatureStore:
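```python
import numpy as np
import rocksdb  # Third-party python-rocksdb binding (assumed installed)
import torch
from torch_geometric.data import FeatureStore, TensorAttr
from torch_geometric.loader import NeighborLoader

class RocksDBFeatureStore(FeatureStore):
    # Sketch: FeatureStore subclasses implement these underscore-prefixed hooks,
    # which the public get_tensor()/put_tensor() methods and the loaders call
    def __init__(self, path):
        super().__init__()
        self.db = rocksdb.DB(path, rocksdb.Options(create_if_missing=True))
        self.shapes = {}  # (group_name, attr_name) -> tensor shape, kept in memory

    def _put_tensor(self, tensor, attr):
        # Store features on disk as raw float32 bytes
        key = (attr.group_name, attr.attr_name)
        self.db.put(repr(key).encode(), tensor.float().numpy().tobytes())
        self.shapes[key] = tuple(tensor.shape)
        return True

    def _get_tensor(self, attr):
        # Fetch features from disk, restore the shape, then select the requested rows
        key = (attr.group_name, attr.attr_name)
        raw = self.db.get(repr(key).encode())
        if raw is None:
            return None
        x = torch.from_numpy(np.frombuffer(raw, dtype=np.float32).copy()).view(self.shapes[key])
        return x if attr.index is None else x[attr.index]

    def _remove_tensor(self, attr):
        key = (attr.group_name, attr.attr_name)
        self.db.delete(repr(key).encode())
        return self.shapes.pop(key, None) is not None

    def _get_tensor_size(self, attr):
        return self.shapes.get((attr.group_name, attr.attr_name))

    def get_all_tensor_attrs(self):
        return [TensorAttr(group_name=g, attr_name=a) for g, a in self.shapes]

# Use with NeighborLoader; graph_store is a GraphStore holding the edge indices (not shown)
feature_store = RocksDBFeatureStore('./features_db')
loader = NeighborLoader(data=(feature_store, graph_store), num_neighbors=[15, 10], batch_size=128)
```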
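The core idea behind an out-of-core feature store is that only the rows for the node IDs sampled in a mini-batch are read from disk. The sketch below shows that access pattern with a NumPy memory-mapped file. This is not PyG's FeatureStore API, and the file path and sizes are arbitrary:

```python
import os
import tempfile

import numpy as np
import torch

num_nodes, num_features = 10_000, 16
path = os.path.join(tempfile.mkdtemp(), 'node_features.npy')

# Write a feature matrix to disk once (row i is filled with the value i)
features = np.lib.format.open_memmap(path, mode='w+', dtype=np.float32,
                                     shape=(num_nodes, num_features))
features[:] = np.arange(num_nodes, dtype=np.float32)[:, None]
features.flush()
del features

# Later: map the file without loading it, then read only the sampled rows
x_disk = np.load(path, mmap_mode='r')
n_id = torch.tensor([3, 42, 9_999])                    # Node IDs sampled for a mini-batch
x_batch = torch.from_numpy(np.array(x_disk[n_id.numpy()]))  # Copies just these rows
print(x_batch.shape)  # torch.Size([3, 16])
```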
Heterogeneous Graphs
Handle graphs with multiple node and edge types (e.g., knowledge graphs).
HeteroData:
```python
import torch
from torch_geometric.data import HeteroData

data = HeteroData()

# Add node types
data['paper'].x = torch.randn(1000, 128)  # 1000 papers with 128 features
data['author'].x = torch.randn(500, 64)   # 500 authors with 64 features

# Add edge types: row 0 indexes source nodes, row 1 indexes destination nodes
data['author', 'writes', 'paper'].edge_index = torch.stack([
    torch.randint(0, 500, (2000,)),   # Author indices
    torch.randint(0, 1000, (2000,)),  # Paper indices
])
data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 1000, (2, 5000))

print(data)
# HeteroData(
#   paper={ x=[1000, 128] },
#   author={ x=[500, 64] },
#   (author, writes, paper)={ edge_index=[2, 2000] },
#   (paper, cites, paper)={ edge_index=[2, 5000] }
# )
```
Heterogeneous GNN:
```python
import torch
from torch_geometric.nn import GATConv, GCNConv, HeteroConv, Linear, SAGEConv

# Expects reverse edges to be present, e.g. after data = T.ToUndirected()(data)
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(HeteroConv({
                # GCNConv only supports edges between nodes of the same type
                ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels),
                # Bipartite edge types need layers that accept (source, destination) inputs
                ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
                # Reverse relation so 'author' nodes also receive messages
                ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels,
                                                           add_self_loops=False),
            }, aggr='sum'))
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.lin(x_dict['paper'])
```
Advanced Techniques
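1. Graph Transforms
```python
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

# Compose transforms
transform = T.Compose([
    T.AddSelfLoops(),
    T.NormalizeFeatures(),
    T.RandomNodeSplit(split='train_rest', num_val=0.1, num_test=0.1),
])
# `transform` runs on every access, so RandomNodeSplit draws a new split
# each time dataset[0] is called: fetch the graph once and reuse it
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
data = dataset[0]
```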
2. Custom Datasets
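```python
import torch
from torch_geometric.data import Data, Dataset, download_url

class MyDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['data.csv']

    @property
    def processed_file_names(self):
        # One file per graph; list every processed file name here
        return ['data_0.pt', 'data_1.pt', ...]

    def download(self):
        # Download raw data into self.raw_dir
        download_url('https://example.com/data.csv', self.raw_dir)

    def process(self):
        # Parse self.raw_paths[0] into raw_data_list (parsing not shown),
        # then convert each entry into a Data object
        for idx, raw_data in enumerate(raw_data_list):
            data = Data(...)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            torch.save(data, f'{self.processed_dir}/data_{idx}.pt')

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # Data objects are not plain tensors, so disable weights-only loading
        # (PyTorch 2.6 changed the default to weights_only=True)
        return torch.load(f'{self.processed_dir}/data_{idx}.pt', weights_only=False)
```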
3. Explainability
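```python
from torch_geometric.explain import Explainer, GNNExplainer

# A trained node-classification model whose forward takes (x, edge_index),
# e.g. the GCN class from "Common GNN Layers", which returns raw logits
model = GCN(...)
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',
    ),
)

# Explain prediction for node 0
explanation = explainer(data.x, data.edge_index, index=0)
print(f'Node mask: {explanation.node_mask}')
print(f'Edge mask: {explanation.edge_mask}')
```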
Best Practices
✅ DO
- Use neighbor sampling for large graphs: Train on sampled subgraphs instead of the full graph to keep GPU memory bounded
- Leverage pre-implemented layers: PyG has 100+ optimized GNN layers
- Normalize features: Use the NormalizeFeatures() transform for stable training
- Add self-loops where needed: GCNConv and GATConv add them internally by default; call add_self_loops() in custom MessagePassing layers
- Use dropout: GNNs overfit easily; add dropout between layers
- Monitor gradient flow: Check for vanishing/exploding gradients in deep GNNs
- Profile memory usage: Graph data can be memory-intensive
- Use Lightning for distributed training: Simplifies multi-GPU setup
❌ DON'T
- Don't ignore edge direction: Undirected graphs need both directions of every edge in edge_index (to_undirected() adds the missing ones)
- Don't use too many layers: Deep GNNs suffer from over-smoothing; most models use only a few (typically 2-4) layers
- Don't forget to set model.eval(): Dropout and BatchNorm behave differently in eval mode
- Don't use dense adjacency matrices for large graphs: They need O(N²) memory; use the sparse edge_index (COO) format instead
- Don't mix node/graph-level tasks: Be clear about prediction granularity
- Don't skip data validation: Check for isolated nodes, NaN features, etc.
Essential Resources
Official Documentation
- PyG Docs: https://pytorch-geometric.readthedocs.io/
- API Reference: https://pytorch-geometric.readthedocs.io/en/latest/modules/root.html
- Examples: https://github.com/pyg-team/pytorch_geometric/tree/master/examples
Learning Resources
- Stanford CS224W: http://web.stanford.edu/class/cs224w/ (Machine Learning with Graphs)
- UvA Deep Learning: https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/06-graph-neural-networks.html
- Distill.pub GNN Article: https://distill.pub/2021/gnn-intro/
Advanced Topics
- Distributed Training: https://pytorch-geometric.readthedocs.io/en/latest/tutorial/distributed_pyg.html
- Remote Backends: https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html
- Heterogeneous Graphs: https://pytorch-geometric.readthedocs.io/en/latest/tutorial/heterogeneous.html
Summary
PyTorch Geometric provides:
- Efficient sparse operations: Optimized for graph data structures
- Rich model library: 100+ GNN layers, datasets, and transforms
- Scalability: Neighbor sampling, distributed training, out-of-core learning
- Flexibility: Custom message passing, heterogeneous graphs, temporal data
- Integration: Seamless PyTorch Lightning integration for production
Combined with Lightning, PyG supports graph deep learning workflows from small research prototypes to large, distributed training runs.
Source
https://github.com/nishide-dev/claude-code-ml-research/blob/main/skills/ml-pytorch-geometric/SKILL.md
Overview
PyTorch Geometric (PyG) is the standard library for geometric deep learning with PyTorch. It offers efficient sparse graph operations, 100+ pre-implemented GNN layers, scalable neighbor sampling, large-scale distributed training, and support for heterogeneous and temporal graphs, plus PyTorch Lightning integration.
How This Skill Works
PyG uses a Data object to represent graphs and the MessagePassing framework to implement GNN layers. Nodes exchange and aggregate messages via an edge_index structure, with scalable components like neighbor sampling and optimized sparse tensors, while layers like GCNConv or GATConv define the update rules. Training can be streamlined with PyTorch Lightning for robust, scalable pipelines.
When to Use It
- Prototype node/graph-level models on graphs with scalable neighbor sampling for medium to large datasets
- Work with heterogeneous or temporal graphs that require specialized data representations and layers
- Build end-to-end pipelines using PyTorch Lightning for cleaner training loops and checkpointing
- Scale training across multiple machines with large-scale distributed graph learning
- Experiment with a variety of GNN layers (GCN, GAT, GraphSAGE) and message passing strategies
Quick Start
- Step 1: Install PyG and dependencies (PyTorch, torch-geometric packages) and verify the environment
- Step 2: Create a simple Data object with node features x, edge_index, and label y
- Step 3: Define a basic GCN/GAT layer, run a forward pass, and set up a simple training loop (or use PyTorch Lightning)
Best Practices
- Start with a simple Data object (x, edge_index, y) to validate your pipeline before adding complex attributes
- Use neighbor sampling when training on large graphs to control memory and compute
- Choose an appropriate aggregation and normalization in MessagePassing (add/mean/max, degree-based norm)
- Leverage PyTorch Lightning to organize training, validation, and checkpointing for reproducibility
- If using heterogeneous/temporal graphs, utilize PyG's specialized support to model multiple node/edge types
Example Use Cases
- Node classification on a social network using GCN/GAT layers to predict user attributes
- Molecular property prediction with graph representations and GNN layers
- Large-scale product recommendations via distributed graph learning on a commerce graph
- Traffic forecasting on temporal graphs using time-aware message passing
- Knowledge graph link prediction with heterogeneous graph modeling