
npx machina-cli add skill G1Joshi/Agent-Skills/jax --openclaw

JAX

JAX is "NumPy on steroids": it combines Autograd-style automatic differentiation with XLA compilation. As of 2025, Flax NNX (a PyTorch-style, object-oriented API) is becoming the standard way to build neural networks on top of it.

When to Use

  • TPU Training: JAX runs natively on Google TPUs.
  • Research: When you need higher-order derivatives (e.g., 10th-order) or otherwise exotic math.
  • Massive Scale: Labs such as Google DeepMind use JAX to train frontier models.

Core Concepts

Functional Transformations

  • grad(): automatic differentiation of Python functions.
  • jit(): just-in-time compilation to optimized XLA kernels.
  • vmap(): automatic vectorization over a batch axis.
  • pmap(): parallel execution across multiple devices.
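A minimal sketch of these transforms in action (assuming jax is installed; the function and values are arbitrary):

```python
import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(x ** 2)

# grad: differentiate loss with respect to its input
grad_loss = jax.grad(loss)

# jit: compile the gradient function with XLA
fast_grad = jax.jit(grad_loss)

# vmap: vectorize a scalar function over a batch axis
double = jax.vmap(lambda v: v * 2)

g = fast_grad(jnp.array([1.0, 2.0]))  # d/dx sum(x^2) = 2x
b = double(jnp.arange(3.0))
```

Note that the transforms compose: jit(grad(f)), vmap(grad(f)), and deeper stacks are all valid.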

Flax (NNX)

Neural network library. NNX introduces mutable state (OOP) to make JAX feel like PyTorch.

Statelessness

In legacy Flax (Linen), parameters are stored separately from the model code: functions stay pure, and all state is passed in and returned explicitly.
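The stateless pattern in plain JAX terms: parameters live in a pytree that is passed into pure functions, and grad differentiates with respect to that pytree (a sketch with toy values, not Linen-specific):

```python
import jax
import jax.numpy as jnp

# Parameters are plain data, held outside the model code
params = {"w": jnp.ones(3), "b": jnp.array(0.5)}

def predict(params, x):
    return jnp.dot(x, params["w"]) + params["b"]

def loss(params, x, y):
    return (predict(params, x) - y) ** 2

# grad returns a pytree of gradients with the same structure as params
grads = jax.grad(loss)(params, jnp.ones(3), 1.0)
```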

Best Practices (2025)

Do:

  • Use jit: Always compile your functions.
  • Use Flax NNX: Avoid the complexity of legacy immutable Flax/Haiku.
  • Use shard_map: For distributed training across devices.
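A shard_map sketch (the import location varies by JAX version, hence the fallback; on a single-device machine the mesh simply has one entry and the code still runs):

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
try:
    from jax import shard_map                         # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# 1-D device mesh over whatever devices are available
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

@functools.partial(shard_map, mesh=mesh, in_specs=P("data"), out_specs=P("data"))
def scaled(block):
    # Each device sees only its shard of the input array
    return block * 2

out = scaled(jnp.arange(8.0))
```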

Don't:

  • Don't rely on side effects: print() inside a jit-compiled function runs only once, during tracing. Use jax.debug.print() when you need output on every call.
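A quick demonstration of the tracing pitfall (the list append is just a visible Python side effect):

```python
import jax
import jax.numpy as jnp

calls = []

@jax.jit
def f(x):
    calls.append("traced")        # Python side effect: runs only while tracing
    jax.debug.print("x = {}", x)  # runs on every call; use this for debugging
    return x * 2

f(jnp.ones(3))
f(jnp.ones(3))  # compiled version is cached; the Python body does not re-run
```

After both calls, `calls` contains a single entry, because the Python function body executed only during the initial trace.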

References

Source

git clone https://github.com/G1Joshi/Agent-Skills
Skill file: skills/ai-ml/jax/SKILL.md (View on GitHub)

Overview

JAX is a high-performance numerical computing library that combines automatic differentiation with XLA compilation, delivering fast, NumPy-like code for ML research. It enables differentiable programming at scale with transforms like grad, jit, vmap, and pmap, and is often used with Flax NNX for neural networks. This makes it ideal for experiments requiring high-order derivatives and hardware acceleration on TPUs and GPUs.

How This Skill Works

JAX exposes functional transformations such as grad, jit, vmap, and pmap that turn Python functions into optimized XLA-compiled kernels. It emphasizes stateless design, with model parameters stored separately from the code, while Flax NNX provides mutable state when needed. JAX runs natively on TPUs and scales across devices via pmap and shard_map.

When to Use It

  • TPU training natively on Google TPUs
  • Research requiring high-order derivatives or advanced math
  • Training frontier-scale models across multiple devices
  • Leveraging functional transforms (grad, jit, vmap, pmap) for speed
  • Using Flax NNX for PyTorch-style mutable state

Quick Start

  1. Install JAX and Flax NNX in your environment
  2. Implement a function and apply @jit and grad to compute derivatives
  3. Parallelize across devices with vmap/pmap and shard_map for distributed training
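Steps 2 and 3 can be sketched as a single jit-compiled training step (a toy linear-regression example; the data, learning rate, and step count are arbitrary):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.1):
    grads = jax.grad(loss_fn)(params, x, y)
    # SGD update applied over the whole parameter pytree
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 2))
y = x @ jnp.array([1.0, -1.0]) + 0.5          # synthetic targets
params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
for _ in range(200):
    params = train_step(params, x, y)
```

Because train_step is jit-compiled, the whole gradient-plus-update computation runs as one fused XLA kernel after the first call.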

Best Practices

  • Always compile functions with jit to generate fast XLA-compiled kernels
  • Prefer Flax NNX over legacy Flax/Haiku to simplify mutable state
  • Use shard_map for distributed training across devices
  • Avoid side effects inside jit functions (e.g., prints during tracing)
  • Leverage grad, vmap, and pmap to build scalable differentiable pipelines

Example Use Cases

  • Training a frontier model on Google TPUs with JAX and pmap
  • Exploring 10th-order derivatives for a research project
  • Migrating NumPy code to JAX to gain speed and scalability
  • Building neural networks with Flax NNX while keeping stateless parameters
  • Scaling experiments across multiple devices using shard_map
