smartclip


Adaptive gradient clipping for PyTorch, TensorFlow, and JAX.

SmartClip keeps training stable with adaptive, per-step clipping you can enable in one line of code.

See the full documentation for details on each algorithm, framework-specific usage examples, and logging metrics.

Supported Algorithms

The clippers shown in this README include AutoClip, the default percentile-based strategy, and ZScoreClip, which clips when the gradient norm deviates too far from its running statistics. See the documentation for the complete list and per-algorithm options.
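
For intuition, AutoClip-style clipping tracks the history of observed gradient norms and clips each step at a chosen percentile of that history. Below is a minimal PyTorch sketch of the idea, not SmartClip's internal implementation; the function name and the percentile value are illustrative:

import numpy as np
import torch

grad_history = []

def autoclip_(model, percentile=10.0):
    """Clip gradients at the given percentile of all norms seen so far."""
    # Global L2 norm across all parameter gradients.
    total_norm = torch.norm(
        torch.stack([p.grad.norm() for p in model.parameters() if p.grad is not None])
    ).item()
    grad_history.append(total_norm)
    # The clip threshold adapts as the history grows.
    clip_value = float(np.percentile(grad_history, percentile))
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

Called between loss.backward() and opt.step(), this keeps rare, unusually large gradients from destabilizing training while leaving typical steps untouched.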

Install

pip install smartclip

Optional extras add framework-specific helpers (install the framework itself first, following its vendor docs):

pip install "smartclip[torch]"    # PyTorch + Lightning/Transformers helpers
pip install "smartclip[tf]"       # TensorFlow/Keras helpers
pip install "smartclip[jax]"      # JAX/Flax/Optax helpers

Quickstart

PyTorch

import torch
import torch.nn.functional as F
import smartclip as sc

model = MyModel().to("cpu")  # any torch.nn.Module
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

with sc.clip_context(model, opt):  # Defaults to AutoClip
    for x, y in loader:
        opt.zero_grad(set_to_none=True)
        loss = F.cross_entropy(model(x), y)  # e.g. cross-entropy; use your own criterion
        loss.backward()
        opt.step()  # clipped automatically
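
Under the hood, a context manager like this can wrap optimizer.step so gradients are clipped just before each update. A rough sketch of that pattern, assuming nothing about SmartClip's internals (all names here are illustrative):

import contextlib
import torch

@contextlib.contextmanager
def clip_before_step(model, opt, max_norm_fn):
    """Wrap opt.step so gradients are clipped before every update.

    max_norm_fn() returns the current adaptive clipping threshold.
    """
    original_step = opt.step

    def step(*args, **kwargs):
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm_fn())
        return original_step(*args, **kwargs)

    opt.step = step
    try:
        yield
    finally:
        opt.step = original_step  # restore the unwrapped step on exit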

TensorFlow/Keras

import tensorflow as tf
import smartclip as sc

model = MyModel()  # any tf.keras.Model
opt = tf.keras.optimizers.Adam(3e-4)
model.compile(optimizer=opt, loss="sparse_categorical_crossentropy")  # use your own loss/metrics

with sc.clip_context(model, opt, clipper=sc.ZScoreClip(zmax=3.0)):  # Use the zscore algorithm
    model.fit(ds, epochs=5)
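
The z-score rule treats a step as an outlier when the current gradient norm is more than zmax standard deviations above the running mean of past norms. A minimal, framework-neutral sketch of that rule (the class and method names are illustrative, not SmartClip's API):

class ZScoreRule:
    """Clip the gradient norm when it exceeds mean + zmax * std of history."""

    def __init__(self, zmax=3.0):
        self.zmax = zmax
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations (Welford's algorithm)

    def threshold(self, norm):
        # Update the running mean/variance with the observed norm.
        self.n += 1
        delta = norm - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (norm - self.mean)
        std = (self.m2 / self.n) ** 0.5 if self.n > 1 else 0.0
        # Pass typical steps through; cap outliers at mean + zmax * std.
        limit = self.mean + self.zmax * std
        return min(norm, limit) if self.n > 1 else norm

A caller would scale gradients by threshold(grad_norm) / grad_norm whenever the returned value is smaller than the observed norm.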

JAX/Optax

import jax
import optax
from flax import linen as nn
import smartclip as sc

model = MyModel()  # Flax Module
tx = optax.adam(3e-4)
opt_state = tx.init(params)  # params from model.init(...); batch and loss_fn defined elsewhere

with sc.clip_context(model, tx):  # wraps tx.update
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = tx.update(grads, opt_state, params)  # clipped automatically
    params = optax.apply_updates(params, updates)
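
In Optax terms, clipping can be expressed as a GradientTransformation chained in front of the optimizer. A minimal sketch of that pattern, independent of SmartClip's internals; for brevity it uses a fixed threshold, whereas an adaptive version would carry norm statistics in the transformation state:

import jax
import jax.numpy as jnp
import optax

def clip_by_threshold(max_norm: float) -> optax.GradientTransformation:
    """Scale updates so their global norm never exceeds max_norm."""

    def init_fn(params):
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        g_norm = optax.global_norm(updates)
        scale = jnp.minimum(1.0, max_norm / (g_norm + 1e-6))
        updates = jax.tree_util.tree_map(lambda g: g * scale, updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)

# Chain the clipper in front of the optimizer.
tx = optax.chain(clip_by_threshold(1.0), optax.adam(3e-4))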

See the documentation for full guides covering TensorFlow, JAX, Lightning, Keras, and the HF Trainer.

Contributing

We welcome issues and pull requests. See contribute.md for developer setup, testing, docs, and release workflows.

License

MIT
