Adaptive gradient clipping for PyTorch, TensorFlow, and JAX.
SmartClip keeps training stable with adaptive, per-step clipping you can enable in one line of code.
See the full documentation for details of the algorithms, framework usage examples, and logging metrics.
## Algorithms

- AutoClip — Seetharaman et al., 2020 (MLSP). Adaptive percentile-based clipping of gradient norms.
- Adaptive Gradient Clipping (AGC, NFNets-style) — Brock, De, and Smith, 2021. The threshold scales with the parameter norm.
- Z-score clipping (EMA mean/std) — standard z-score thresholding using a streaming mean and variance of recent gradient norms.
For the z-score clipper, `zmax` controls how aggressive clipping is: the threshold is `mean + zmax * std` over recent gradient norms. A higher `zmax` clips less (more tolerant); a lower one clips more (more aggressive). Start at `zmax=3.0`; try `2.0`–`2.5` if you see instability from spikes, or `3.5`–`4.0` if training seems over-clipped. The sketch below shows how each threshold is computed.
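Here is a minimal NumPy sketch of the three rules. Function names, signatures, and default values are illustrative assumptions for exposition, not SmartClip's internals:

```python
import numpy as np

# Illustrative only: names and defaults are assumptions, not SmartClip's
# actual internals.

def autoclip_threshold(norm_history, percentile=90.0):
    # AutoClip: clip at a chosen percentile of gradient norms seen so far.
    return np.percentile(norm_history, percentile)

def agc_threshold(param_norm, clipping=0.01, eps=1e-3):
    # AGC: threshold proportional to the parameter norm (NFNets-style).
    return clipping * max(param_norm, eps)

def ema_update(mean, var, norm, beta=0.99):
    # Z-score bookkeeping: streaming mean/variance of recent norms via EMA.
    mean = beta * mean + (1 - beta) * norm
    var = beta * var + (1 - beta) * (norm - mean) ** 2
    return mean, var

def zscore_threshold(mean, var, zmax=3.0):
    # Z-score: allow norms up to zmax standard deviations above the mean.
    return mean + zmax * np.sqrt(var)

def apply_threshold(grad, threshold, eps=1e-6):
    # Common final step for every rule: rescale the gradient so its norm
    # never exceeds the threshold.
    norm = np.linalg.norm(grad)
    return grad * min(1.0, threshold / (norm + eps))
```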
## Installation

```
pip install smartclip
```

Optional extras provide helpers for specific frameworks (install the framework wheels first, per vendor docs):

```
pip install "smartclip[torch]" # PyTorch + Lightning/Transformers helpers
pip install "smartclip[tf]" # TensorFlow/Keras helpers
pip install "smartclip[jax]" # JAX/Flax/Optax helpersimport torch
import smartclip as sc
model = MyModel().to("cpu")  # your torch.nn.Module (placeholder)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
with sc.clip_context(model, opt):  # Defaults to AutoClip
    for x, y in loader:
        opt.zero_grad(set_to_none=True)
        loss = loss_fn(model(x), y)  # loss_fn: your loss criterion (placeholder)
        loss.backward()
        opt.step()  # clipped automatically
```
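To select an algorithm explicitly instead of the AutoClip default, pass a clipper to the context manager. A minimal sketch, assuming the `clipper` argument from the TensorFlow example below behaves the same way in PyTorch:

```python
# Assumption: clipper selection works identically across backends.
with sc.clip_context(model, opt, clipper=sc.ZScoreClip(zmax=3.0)):
    ...  # same training loop as above
```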
### TensorFlow / Keras

```python
import tensorflow as tf
import smartclip as sc
model = MyModel()
opt = tf.keras.optimizers.Adam(3e-4)
with sc.clip_context(model, opt, clipper=sc.ZScoreClip(zmax=3.0)): # Use the zscore algorithm
    model.fit(ds, epochs=5)
```
### JAX / Optax

```python
import jax
import optax
from flax import linen as nn
import smartclip as sc
model = MyModel()  # your Flax Module
tx = optax.adam(3e-4)
params = model.init(jax.random.PRNGKey(0), batch)  # batch: an example input (placeholder)
opt_state = tx.init(params)
with sc.clip_context(model, tx):  # wraps tx.update
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = tx.update(grads, opt_state, params)  # clipped automatically
    params = optax.apply_updates(params, updates)
```
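For contrast, fixed-threshold clipping in plain Optax looks like the chain below (standard Optax APIs only); SmartClip's clippers differ in that the threshold adapts per step rather than staying constant:

```python
import optax

# Static global-norm clipping: the 1.0 threshold never changes during training.
tx_static = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))
```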
See the documentation for full guides for TensorFlow, JAX, Lightning, Keras, and the HF Trainer.

## Contributing

We welcome issues and pull requests. See contribute.md for developer setup, testing, docs, and release workflows.
## License

MIT