ray-train
Ray Train - Distributed Training Orchestration
Quick start
Ray Train scales machine learning training from single GPU to multi-node clusters with minimal code changes.
Installation:
pip install -U "ray[train]"
Basic PyTorch training (single node):
import ray
import torch
import torch.nn as nn
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Define training function
def train_func(config):
    # Your normal PyTorch code
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Prepare for distributed (Ray handles device placement)
    model = train.torch.prepare_model(model)
    device = train.torch.get_device()

    for epoch in range(10):
        # Your training loop (synthetic batch moved to the worker's device)
        output = model(torch.randn(32, 10).to(device))
        loss = output.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Report metrics (logged automatically)
        train.report({"loss": loss.item(), "epoch": epoch})

# Run distributed training
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=4,  # 4 GPUs/workers
        use_gpu=True,
    ),
)
result = trainer.fit()
print(f"Final loss: {result.metrics['loss']}")
That's it! Ray handles:
- Distributed coordination
- GPU allocation
- Fault tolerance
- Checkpointing
- Metric aggregation
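Fault tolerance and checkpoint retention are configured through a RunConfig passed alongside the ScalingConfig. A minimal sketch, assuming Ray 2.x's RunConfig/FailureConfig/CheckpointConfig API and an illustrative shared storage path:

from ray.train import RunConfig, FailureConfig, CheckpointConfig
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(
        name="quickstart-run",                           # experiment name (illustrative)
        storage_path="/mnt/shared/ray_results",          # shared or cloud path for checkpoints (assumption)
        failure_config=FailureConfig(max_failures=3),    # retry the run up to 3 times on worker failure
        checkpoint_config=CheckpointConfig(num_to_keep=2),  # keep only the 2 most recent checkpoints
    ),
)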
Common workflows
Workflow 1: Scale existing PyTorch code
Original single-GPU code:
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(epochs):
    for batch in dataloader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
Ray Train version (scales to multi-GPU/multi-node):
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config):
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())
    dataloader = ...  # build your DataLoader here (inside train_func, one per worker)

    # Prepare for distributed (automatic device placement)
    model = train.torch.prepare_model(model)
    dataloader = train.torch.prepare_data_loader(dataloader)

    for epoch in range(epochs):
        for batch in dataloader:
            loss = model(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Report metrics
        train.report({"loss": loss.item()})

# Scale to 8 GPUs
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
)
trainer.fit()
Benefits: Same code runs on 1 GPU or 1000 GPUs
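The only piece that changes between a laptop run and a cluster run is the ScalingConfig; train_func stays the same. A minimal sketch:

# Laptop or CI: 1 CPU worker
ScalingConfig(num_workers=1, use_gpu=False)

# Single node: 8 GPU workers
ScalingConfig(num_workers=8, use_gpu=True)

# Cluster: 64 GPU workers spread across nodes
ScalingConfig(num_workers=64, use_gpu=True, placement_strategy="SPREAD")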
Workflow 2: HuggingFace Transformers integration
from ray.train import ScalingConfig
from ray.train.huggingface import TransformersTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

def train_func(config):
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Training arguments (HuggingFace API)
    training_args = TrainingArguments(
        output_dir="./output",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        learning_rate=2e-5,
    )

    # Ray automatically handles distributed training
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,  # your tokenized dataset, prepared beforehand
    )
    trainer.train()

# Scale to multi-node (2 nodes × 8 GPUs = 16 workers)
trainer = TransformersTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=16,
        use_gpu=True,
        resources_per_worker={"GPU": 1},
    ),
)
result = trainer.fit()
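Note: recent Ray releases deprecate TransformersTrainer in favor of a plain TorchTrainer plus the helpers in ray.train.huggingface.transformers. If the import above fails on your Ray version, the newer pattern looks roughly like this (a sketch; check the Ray docs for your release):

import ray.train.huggingface.transformers as ray_transformers
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

def train_func(config):
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    args = TrainingArguments(output_dir="./output", num_train_epochs=3,
                             per_device_train_batch_size=8, learning_rate=2e-5)
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # dataset as above
    trainer.add_callback(ray_transformers.RayTrainReportCallback())  # report metrics/checkpoints to Ray
    trainer = ray_transformers.prepare_trainer(trainer)              # wire the Trainer into Ray's backend
    trainer.train()

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=16, use_gpu=True),
)
trainer.fit()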
Workflow 3: Hyperparameter tuning with Ray Tune
import torch
from ray import train, tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import ASHAScheduler

def train_func(config):
    # Use hyperparameters from config
    lr = config["lr"]
    batch_size = config["batch_size"]

    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model = train.torch.prepare_model(model)

    for epoch in range(10):
        # Training loop (train_epoch is your own epoch function)
        loss = train_epoch(model, optimizer, batch_size)
        train.report({"loss": loss, "epoch": epoch})

# Define search space; hyperparameters for train_func go under "train_loop_config"
param_space = {
    "train_loop_config": {
        "lr": tune.loguniform(1e-5, 1e-2),
        "batch_size": tune.choice([16, 32, 64, 128]),
    }
}

# Run 20 trials with early stopping
tuner = tune.Tuner(
    TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    ),
    param_space=param_space,
    tune_config=tune.TuneConfig(
        num_samples=20,
        scheduler=ASHAScheduler(metric="loss", mode="min"),
    ),
)
results = tuner.fit()

best = results.get_best_result(metric="loss", mode="min")
print(f"Best hyperparameters: {best.config}")
Result: Distributed hyperparameter search across the cluster
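Each trial's metrics come back on the ResultGrid returned by tuner.fit(); a short sketch of inspecting them and pulling the winning checkpoint (present only if train_func reported checkpoints):

# Inspect all trials as a pandas DataFrame
df = results.get_dataframe()
print(df.head())

# Load the best trial's checkpoint, if any was reported
best = results.get_best_result(metric="loss", mode="min")
if best.checkpoint:
    with best.checkpoint.as_directory() as ckpt_dir:
        print("Best checkpoint available at:", ckpt_dir)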
Workflow 4: Checkpointing and fault tolerance
import os
import tempfile
import torch
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config):
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters())

    # Try to resume from a previous checkpoint
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            state = torch.load(os.path.join(checkpoint_dir, "model.pt"))
            model.load_state_dict(state["model"])
            optimizer.load_state_dict(state["optimizer"])
            start_epoch = state["epoch"]
    else:
        start_epoch = 0

    model = train.torch.prepare_model(model)

    for epoch in range(start_epoch, 100):
        loss = train_epoch(model, optimizer)  # your own epoch function

        # Save a checkpoint every 10 epochs
        if epoch % 10 == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                torch.save({
                    "model": model.module.state_dict(),  # unwrap the DDP-wrapped model
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                }, os.path.join(tmpdir, "model.pt"))
                train.report({"loss": loss}, checkpoint=Checkpoint.from_directory(tmpdir))
        else:
            train.report({"loss": loss})

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
)
# If the run is retried after a failure, train_func resumes from the latest reported checkpoint
result = trainer.fit()
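After fit() returns, the last reported checkpoint is available on the result object; a minimal sketch of restoring the weights, assuming the model.pt layout used above:

import os
import torch

# result.checkpoint is the most recent checkpoint passed to train.report(...)
if result.checkpoint:
    with result.checkpoint.as_directory() as ckpt_dir:
        state = torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")
        model = MyModel()
        model.load_state_dict(state["model"])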
Workflow 5: Multi-node training
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Connect to Ray cluster
ray.init(address="auto")  # Or ray.init("ray://head-node:10001")

# Train across 4 nodes × 8 GPUs = 32 workers
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=32,
        use_gpu=True,
        resources_per_worker={"GPU": 1, "CPU": 4},
        placement_strategy="SPREAD",  # Spread workers across nodes
    ),
)
result = trainer.fit()
Launch Ray cluster:
# On head node
ray start --head --port=6379
# On worker nodes
ray start --address=<head-node-ip>:6379
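Before launching a large job, it's worth confirming that the head node actually sees every worker. A quick check, assuming the cluster was started with the commands above:

import ray

ray.init(address="auto")
# Total resources registered across the cluster, e.g. {"CPU": 128.0, "GPU": 32.0, ...}
print(ray.cluster_resources())
# One entry per node that has joined the cluster
print(len(ray.nodes()), "nodes connected")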
When to use vs alternatives
Use Ray Train when:
- Training across multiple machines (multi-node)
- Need hyperparameter tuning at scale
- Want fault tolerance (auto-restart failed workers)
- Elastic scaling (add/remove nodes during training)
- Unified framework (same code for PyTorch/TF/HF)
Key advantages:
- Multi-node orchestration: Easiest multi-node setup
- Ray Tune integration: Best-in-class hyperparameter tuning
- Fault tolerance: Automatic recovery from failures
- Elastic: Add/remove nodes without restarting
- Framework agnostic: PyTorch, TensorFlow, HuggingFace, XGBoost
Use alternatives instead:
- Accelerate: Single-node multi-GPU, simpler
- PyTorch Lightning: High-level abstractions, callbacks
- DeepSpeed: Maximum performance, complex setup
- Raw DDP: Maximum control, minimal overhead
Common issues
Issue: Ray cluster not connecting
Check ray status:
ray status
# Should show:
# - Nodes: 4
# - GPUs: 32
# - Workers: Ready
If not connected:
# Restart head node
ray stop
ray start --head --port=6379 --dashboard-host=0.0.0.0
# Restart worker nodes
ray stop
ray start --address=<head-ip>:6379
Issue: Out of memory
Reduce workers or use gradient accumulation:
scaling_config=ScalingConfig(
    num_workers=4,  # Reduce from 8
    use_gpu=True,
)

# In train_func, accumulate gradients
for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
Issue: Slow training
Check if data loading is bottleneck:
import time

def train_func(config):
    for epoch in range(epochs):
        start = time.time()
        for batch in dataloader:
            data_time = time.time() - start
            # Train...
            print(f"Data loading: {data_time:.3f}s")
            start = time.time()
If data loading is slow, increase workers:
dataloader = DataLoader(dataset, num_workers=8)
Advanced topics
Multi-node setup: See references/multi-node.md for Ray cluster deployment on AWS, GCP, Kubernetes, and SLURM.
Hyperparameter tuning: See references/hyperparameter-tuning.md for Ray Tune integration, search algorithms (Optuna, HyperOpt), and population-based training.
Custom training loops: See references/custom-loops.md for advanced Ray Train usage, custom backends, and integration with other frameworks.
Hardware requirements
- Single node: 1+ GPUs (or CPUs)
- Multi-node: 2+ machines with network connectivity
- Cloud: AWS, GCP, Azure (Ray autoscaling)
- On-prem: Kubernetes, SLURM clusters
Supported accelerators:
- NVIDIA GPUs (CUDA)
- AMD GPUs (ROCm)
- TPUs (Google Cloud)
- CPUs
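GPUs are not required: the same trainer runs on CPU workers by flipping use_gpu, and resources_per_worker tunes the allocation in either case. A minimal sketch:

# CPU-only training (debugging locally or on a CPU cluster)
ScalingConfig(num_workers=8, use_gpu=False, resources_per_worker={"CPU": 2})

# GPU training with extra CPU cores per worker for data loading
ScalingConfig(num_workers=8, use_gpu=True, resources_per_worker={"GPU": 1, "CPU": 4})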
Resources
- Docs: https://docs.ray.io/en/latest/train/train.html
- GitHub: https://github.com/ray-project/ray ⭐ 36,000+
- Version: 2.40.0+
- Examples: https://docs.ray.io/en/latest/train/examples.html
- Slack: https://forms.gle/9TSdDYUgxYs8SA9e8
- Used by: OpenAI, Uber, Spotify, Shopify, Instacart
Source
https://github.com/Orchestra-Research/AI-Research-SKILLs/blob/main/08-distributed-training/ray-train/SKILL.md
Overview
Ray Train provides distributed training orchestration across clusters for PyTorch, TensorFlow, and HuggingFace models. It scales workloads from a laptop to thousands of nodes and includes built-in hyperparameter tuning with Ray Tune, fault tolerance, and elastic scaling. This makes it practical to train massive models or run distributed hyperparameter sweeps with minimal code changes.
How This Skill Works
You write a standard training function and Ray handles distribution, device placement, and coordination. Wrappers like TorchTrainer and TransformersTrainer enable scalable training with a ScalingConfig that controls num_workers and GPUs. Ray also handles checkpointing, metric aggregation, fault tolerance, and automatic data/model preparation via helpers like train.torch.prepare_model and train.torch.prepare_data_loader.
When to Use It
- Training massive models across multiple machines
- Running distributed hyperparameter sweeps with Ray Tune
- Scaling existing PyTorch code from 1 GPU to multi-GPU/multi-node
- Training HuggingFace Transformers models on multi-node clusters
- Need fault-tolerant and elastic scaling as cluster availability changes
Quick Start
- Step 1: Install: pip install -U "ray[train]"
- Step 2: Define train_func and wrap your model with train.torch.prepare_model inside it
- Step 3: Create a TorchTrainer with a ScalingConfig (num_workers and use_gpu) and call trainer.fit()
Best Practices
- Configure ScalingConfig with accurate num_workers and use_gpu to match your hardware
- Wrap models with train.torch.prepare_model and data loaders with train.torch.prepare_data_loader
- Leverage Ray Tune for organized hyperparameter sweeps and early stopping
- Enable checkpointing and monitor metrics via train.report
- Test on a small multi-node subset before scaling to thousands of nodes
Example Use Cases
- Scale an existing PyTorch single-GPU script to multi-GPU/multi-node using TorchTrainer
- Scale to 16 workers across 2 nodes × 8 GPUs with TransformersTrainer (as in Workflow 2)
- Integrate HuggingFace Transformers with TransformersTrainer for distributed training
- Achieve elastic scaling by adjusting the number of workers during a job while maintaining progress
- Run distributed hyperparameter sweeps with Ray Tune across a cluster to optimize model performance
Related Skills
tensorboard
Orchestra-Research/AI-Research-SKILLs
Visualize training metrics, debug models with histograms, compare experiments, visualize model graphs, and profile performance with TensorBoard - Google's ML visualization toolkit
huggingface-accelerate
Orchestra-Research/AI-Research-SKILLs
Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
optimizing-attention-flash
Orchestra-Research/AI-Research-SKILLs
Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
crewai-multi-agent
Orchestra-Research/AI-Research-SKILLs
Multi-agent orchestration framework for autonomous AI collaboration. Use when building teams of specialized agents working together on complex tasks, when you need role-based agent collaboration with memory, or for production workflows requiring sequential/hierarchical execution. Built without LangChain dependencies for lean, fast execution.
pytorch-fsdp2
Orchestra-Research/AI-Research-SKILLs
Adds PyTorch FSDP2 (fully_shard) to training scripts with correct init, sharding, mixed precision/offload config, and distributed checkpointing. Use when models exceed single-GPU memory or when you need DTensor-based sharding with DeviceMesh.
ray-data
Orchestra-Research/AI-Research-SKILLs
Scalable data processing for ML workloads. Streaming execution across CPU/GPU, supports Parquet/CSV/JSON/images. Integrates with Ray Train, PyTorch, TensorFlow. Scales from single machine to 100s of nodes. Use for batch inference, data preprocessing, multi-modal data loading, or distributed ETL pipelines.