jax
npx machina-cli add skill G1Joshi/Agent-Skills/jax --openclaw
JAX
JAX is "NumPy on steroids": it combines Autograd-style automatic differentiation with XLA compilation. In 2025, Flax NNX (PyTorch-style OOP) is becoming the standard neural-network API.
When to Use
- TPU Training: JAX runs natively on Google TPUs.
- Research: If you need to compute high-order (e.g., 10th-order) derivatives or other exotic math.
- Massive Scale: DeepMind and OpenAI use JAX for training frontier models.
Core Concepts
Functional Transformations
grad() (differentiation), jit() (compilation), vmap() (auto-vectorization), pmap() (multi-device parallelism).
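A minimal sketch of how these transforms compose (assumes jax is installed; the cubic function is just an illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3

df = jax.grad(f)           # derivative: 3x^2
d2f = jax.grad(df)         # transforms compose: second derivative, 6x
fast_f = jax.jit(f)        # XLA-compiled version of f
batched_df = jax.vmap(df)  # apply the derivative over a batch without a Python loop

print(df(2.0))                       # 12.0
print(d2f(2.0))                      # 12.0
print(batched_df(jnp.arange(3.0)))  # [ 0.  3. 12.]
```

Because each transform returns an ordinary function, they can be nested freely (e.g., `jax.jit(jax.vmap(jax.grad(f)))`).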
Flax (NNX)
Neural network library. NNX introduces mutable state (OOP) to make JAX feel like PyTorch.
Statelessness
In legacy Flax, parameters are stored separately from the model as an explicit pytree.
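The stateless pattern can be sketched without any framework: parameters are a plain pytree passed explicitly to a pure function (the init/apply names below are illustrative, not a library API):

```python
import jax
import jax.numpy as jnp

def init_params(key, din, dout):
    # Parameters are just a dict (a pytree) — no hidden state anywhere.
    wkey, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(wkey, (din, dout)) * 0.01,
        "b": jnp.zeros((dout,)),
    }

def apply(params, x):
    # A pure function of (params, x); grad/jit/vmap all apply cleanly.
    return x @ params["w"] + params["b"]

params = init_params(jax.random.PRNGKey(0), 3, 2)
out = apply(params, jnp.ones((1, 3)))  # shape (1, 2)
```

Legacy Flax and Haiku generate this separation for you; NNX hides it behind an object-oriented facade.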
Best Practices (2025)
Do:
- Use jit: Always compile your functions.
- Use Flax NNX: Avoid the complexity of legacy immutable Flax/Haiku.
- Use shard_map: For distributed training across devices.
Don't:
- Don't use side effects: print() inside a jit function only runs once (during tracing).
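The tracing pitfall, and the supported workaround (`jax.debug.print`), in a short sketch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing!")             # runs once, at trace time, with an abstract tracer value
    jax.debug.print("x = {}", x)  # runs on every call, with the concrete runtime value
    return x * 2

f(jnp.float32(3.0))  # prints both lines
f(jnp.float32(4.0))  # prints only the jax.debug.print line
```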
References
Source
https://github.com/G1Joshi/Agent-Skills/blob/main/skills/ai-ml/jax/SKILL.md
Overview
JAX is a high-performance numerical computing library that combines automatic differentiation with XLA compilation, delivering fast, NumPy-like code for ML research. It enables differentiable programming at scale with transforms like grad, jit, vmap, and pmap, and is often used with Flax NNX for neural networks. This makes it ideal for experiments requiring high-order derivatives and hardware acceleration on TPUs and GPUs.
How This Skill Works
JAX exposes functional transformations such as grad, jit, vmap, and pmap that turn Python functions into optimized XLA-compiled kernels. It emphasizes stateless design, with model parameters stored separately from the code, while Flax NNX provides mutable state when needed. JAX runs natively on TPUs and scales across devices with pmap and shard_map.
When to Use It
- TPU training natively on Google TPUs
- Research requiring high-order derivatives or advanced math
- Training frontier-scale models across multiple devices
- Leveraging functional transforms (grad, jit, vmap, pmap) for speed
- Using Flax NNX for PyTorch-style mutable state
Quick Start
- Step 1: Install JAX and Flax NNX in your environment
- Step 2: Implement a function and apply @jit and grad to compute derivatives
- Step 3: Parallelize across devices with vmap/pmap and shard_map for distributed training
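Step 2 above can be sketched end to end with a toy least-squares loss (the function and data are illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss(w, x, y):
    # Mean squared error of a linear model — a pure function, so it is
    # safe to jit-compile and differentiate.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# grad differentiates w.r.t. the first argument (w) by default.
grad_fn = jax.jit(jax.grad(loss))

w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([0.0, 0.0])
print(loss(w, x, y))     # 2.5
print(grad_fn(w, x, y))  # [1. 2.]
```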
Best Practices
- Always compile functions with jit to generate fast XLA-compiled kernels
- Prefer Flax NNX over legacy Flax/Haiku to simplify mutable state
- Use shard_map for distributed training across devices
- Avoid side effects inside jit functions (e.g., prints during tracing)
- Leverage grad, vmap, and pmap to build scalable differentiable pipelines
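A minimal shard_map sketch, assuming jax.experimental.shard_map is available (it runs on however many devices are present, including a single CPU; the reduction is illustrative):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Build a 1-D mesh over whatever devices are available.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

@jax.jit
def sharded_sum(x):
    # Each device sums its local shard, then psum combines across the mesh.
    f = shard_map(
        lambda s: jax.lax.psum(jnp.sum(s), axis_name="data"),
        mesh=mesh,
        in_specs=P("data"),   # shard x along its leading axis
        out_specs=P(),        # result is replicated on every device
    )
    return f(x)

print(sharded_sum(jnp.arange(8.0)))  # 28.0
```

Note that the input's leading axis must be divisible by the number of devices along the "data" axis.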
Example Use Cases
- Training a frontier model on Google TPUs with JAX and pmap
- Exploring 10th-order derivatives for a research project
- Migrating NumPy code to JAX to gain speed and scalability
- Building neural networks with Flax NNX while keeping stateless parameters
- Scaling experiments across multiple devices using shard_map
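The NumPy-migration use case above often amounts to swapping `np` for `jnp` and adding a transform; a small softmax sketch (the function choice is arbitrary):

```python
import numpy as np
import jax
import jax.numpy as jnp

def softmax_np(x):
    # Plain NumPy: runs eagerly on CPU, not differentiable.
    e = np.exp(x - x.max())
    return e / e.sum()

@jax.jit
def softmax_jax(x):
    # Same arithmetic with jnp: XLA-compiled, GPU/TPU-ready, and
    # usable inside grad/vmap pipelines.
    e = jnp.exp(x - x.max())
    return e / e.sum()
```

In-place mutation (`x[i] = v`) is the main thing that does not port directly; JAX uses `x.at[i].set(v)` instead.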