Hi, I know this is probably more of a JAX-side issue and has been discussed there (e.g. jax-ml/jax#743, jax-ml/jax#1539 and jax-ml/jax#6790), but I'm still wondering whether you know how to limit the number of threads JAX uses. The snippet below shows that JAX currently does not respect the threadpool limits:
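```python
import jax.numpy as jnp
from threadpoolctl import threadpool_limits

ja = jnp.ones((1000, 1000))
with threadpool_limits(5):
    for _ in range(100):
        # JAX dispatches asynchronously; block so each matmul finishes inside the `with`
        foo = (ja @ ja).block_until_ready()
```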
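A likely explanation is that threadpoolctl only adjusts the thread pools of BLAS and OpenMP runtimes it detects in the process, while XLA's CPU backend manages its own thread pool. Below is a minimal workaround sketch under stated assumptions. `limit_jax_cpus` is a hypothetical helper, not a JAX or threadpoolctl API. It pins the process to a set of CPUs with `os.sched_setaffinity` (Linux only). This bounds how many cores the work can run on, but not how many threads XLA creates. It can also append caller-supplied flags to `XLA_FLAGS`. Flags such as `--xla_cpu_multi_thread_eigen=false`, which come up in the linked threads, depend on the XLA version. Newer releases may reject unknown flags, so treat any flag as an assumption to verify.

```python
import os
from typing import Iterable, Optional


def limit_jax_cpus(cpus: Optional[Iterable[int]] = None,
                   xla_flags: Iterable[str] = ()) -> str:
    """Hypothetical helper; call it BEFORE the first `import jax`.

    cpus: CPU ids to pin this process to (Linux only). Threads created later,
        including XLA's, inherit this affinity; the thread count is unchanged.
    xla_flags: extra flags appended to XLA_FLAGS. XLA reads that variable once
        at backend initialization, and which flags exist depends on the XLA
        version, so every flag passed here is an assumption to verify.
    Returns the resulting XLA_FLAGS value ("" if it is unset).
    """
    if cpus is not None:
        if not hasattr(os, "sched_setaffinity"):
            raise OSError("CPU pinning via sched_setaffinity is Linux-only")
        os.sched_setaffinity(0, set(cpus))  # 0 = the calling process

    current = os.environ.get("XLA_FLAGS", "").split()
    for flag in xla_flags:
        if flag not in current:  # avoid duplicates on repeated calls
            current.append(flag)
    if current:
        os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ.get("XLA_FLAGS", "")
```

For example, calling `limit_jax_cpus(cpus=range(5))` at the top of the script, before `import jax.numpy as jnp`, should keep the matmul loop on CPUs 0 to 4. On Linux, launching the script with `taskset -c 0-4 python script.py` achieves the same pinning from outside Python.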