keras
npx machina-cli add skill G1Joshi/Agent-Skills/keras --openclawKeras
Keras 3 is a game changer: it is now multi-backend. You can write Keras code and run it on top of JAX, PyTorch, or TensorFlow.
When to Use
- Portability: Write once, run on any framework.
- Simplicity: model.fit() is still the cleanest API in the industry.
- XLA: Keras 3 enables XLA compilation on all backends by default.
Core Concepts
Backend Agnostic
The Model is just a blueprint. You choose the engine at runtime.
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before the first import of keras
Functional API
Define models as a graph of layers: x = Dense(64)(inputs).
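A minimal sketch of the Functional API pattern above; the layer sizes and the JAX backend choice here are illustrative, not prescribed by this skill:

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # pick the engine before importing keras

import keras
from keras import layers

# A small MLP defined as a graph of layers with the Functional API.
inputs = keras.Input(shape=(784,))
x = layers.Dense(64, activation="relu")(inputs)
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
model.summary()
```

The same definition runs unchanged under the "torch" or "tensorflow" backend.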
Keras Core (keras.ops)
A numpy-like API that works across all frameworks (differentiable numpy).
Best Practices (2025)
Do:
- Use Keras 3: migrate from tf.keras.
- Use JAX backend: for fastest training on TPUs/GPUs.
- Use PyTorch backend: If you need to integrate into a larger PyTorch codebase.
Don't:
- Don't mix tf.* ops: use keras.ops.* to remain framework-agnostic.
References
Source
https://github.com/G1Joshi/Agent-Skills/blob/main/skills/ai-ml/keras/SKILL.md

Overview
Keras 3 is a multi-backend deep learning API that runs on JAX, PyTorch, or TensorFlow. It keeps a simple, model.fit()-driven workflow and enables XLA compilation by default across backends. This makes it easy to port models between engines and optimize performance on GPUs/TPUs.
How This Skill Works
You define models as a graph of layers using the Functional API and choose the engine at runtime. The backend is selected via runtime configuration (e.g., os.environ['KERAS_BACKEND']). Keras Core (keras.ops) provides a numpy-like API that works across frameworks for differentiable operations.
When to Use It
- Need portability: write once, run on JAX, PyTorch, or TF
- Prefer a clean API with model.fit()
- Want XLA-enabled graphs by default across backends
- Need fast training on TPUs/GPUs (use JAX backend)
- Integrate Keras models into a larger PyTorch codebase
Quick Start
- Step 1: Install Keras 3 and ensure a backend is available (e.g., JAX, PyTorch, or TF)
- Step 2: Set the backend and define a simple model with the Functional API (e.g., os.environ['KERAS_BACKEND']='jax' and a small model)
- Step 3: Compile the model and run model.fit() to start training
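The three steps above can be sketched end to end; the toy data, layer sizes, and JAX backend are illustrative assumptions:

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # Step 2: set backend before importing keras

import numpy as np
import keras
from keras import layers

# Step 2: a small Functional model
inputs = keras.Input(shape=(8,))
x = layers.Dense(16, activation="relu")(inputs)
outputs = layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

# Step 3: compile and train on random toy data
model.compile(optimizer="adam", loss="mse")
X = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")
history = model.fit(X, y, epochs=2, batch_size=16, verbose=0)
print(history.history["loss"])  # one loss value per epoch
```

Swapping "jax" for "torch" or "tensorflow" is the only change needed to run the same script on a different engine.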
Best Practices
- Use Keras 3 and migrate from tf.keras
- Prefer JAX backend for fastest training on TPUs/GPUs
- Use PyTorch backend when integrating with PyTorch codebases
- Avoid mixing tf.* ops; use keras.ops for framework-agnostic ops
- Test models across backends to ensure portability
Example Use Cases
- Port an existing TF-Keras model to JAX for TPU acceleration
- Train a CNN with a JAX backend and compare with PyTorch backend
- Embed a Keras model within a larger PyTorch project
- Leverage XLA across backends for optimized graphs
- Experiment with backend portability by swapping frameworks without code changes