How to use all available cores? #25716
Replies: 1 comment
Answering this one for myself: you should use |
Here's a short program that compares (unsharded) `jax.jit` and `jax.shard_map` on a simple monolithic array operation.
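Something along these lines (not the exact script, which isn't reproduced here; the array sizes, the choice of 16 host-platform devices, and the sin/cos-plus-matmul workload are all just placeholders):

```python
import os
# Expose multiple CPU "devices" so shard_map has something to shard over.
# Must be set before importing jax; 16 roughly matches the 12+4 cores.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map  # newer JAX also exposes jax.shard_map

x = jnp.ones((16, 2048, 2048))

@jax.jit
def monolithic(x):
    # One big batched matmul over the whole array.
    return jnp.sin(x) @ jnp.cos(x)

mesh = Mesh(np.array(jax.devices()), axis_names=("i",))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
def sharded(x):
    # Same work, but each device gets one (1, 2048, 2048) block.
    return jnp.sin(x) @ jnp.cos(x)

# Run each variant repeatedly and watch core usage in htop.
for _ in range(20):
    monolithic(x).block_until_ready()
for _ in range(20):
    sharded(x).block_until_ready()
```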
The numbers are basically arbitrary; they just let me open `htop` and eyeball how many cores each operation is using:

- The `jax.jit` operation uses only 4 cores, no matter what. It reaches ~360% CPU utilization.
- `shard_map` uses all available cores. It reaches ~1350% CPU utilization on my 12+4 perf+efficiency M3 MacBook.

Here are some other permutations of the XLA_FLAGS settings:
- Trying to disable all threading, unsuccessfully: `jax.jit` ~360%, `shard_map` ~360% (with a mesh of size (1, 1, 1, 1)).
- Using 2 devices: `jax.jit` ~360%, `shard_map` ~1300% (with a mesh of size (2, 1, 1, 1)).
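Roughly, those settings were along these lines (a reconstruction, not the exact invocations; `--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1` is the combination commonly suggested for single-threaded XLA CPU execution, and the flags have to be set before importing jax):

```python
import os

# "Disable all threading" attempt, with a single CPU device (mesh of size (1, 1, 1, 1)).
os.environ["XLA_FLAGS"] = (
    "--xla_force_host_platform_device_count=1 "
    "--xla_cpu_multi_thread_eigen=false "
    "intra_op_parallelism_threads=1"
)

# For the 2-device case (mesh of size (2, 1, 1, 1)),
# use --xla_force_host_platform_device_count=2 instead.
```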
What I'm trying to do
My goal is really just to get the best possible performance out of my machine. My real workload is a plasma simulation written with the `jax.numpy` API, and the core timestepping loop has an iteration time of, say, 300 ms for a nontrivial problem. Based on the experiments above, I have to conclude that I have zero idea how to control which CPU cores the code runs on: I can't even reach full utilization of the available CPUs, let alone implement more sophisticated ideas like thread pinning.
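(For the record, the per-iteration figure comes from a timing loop like the one below; `step` here is just a stand-in for the real timestep, and the `block_until_ready()` call matters because JAX dispatches work asynchronously.)

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(state):
    # Stand-in for one timestep of the actual jax.numpy simulation.
    return jnp.sin(state) * 1.0001

state = jnp.ones((4096, 4096))
state = step(state)            # warm-up run triggers compilation
state.block_until_ready()

t0 = time.perf_counter()
n = 10
for _ in range(n):
    state = step(state)
state.block_until_ready()      # wait for async dispatch before stopping the clock
print(f"{(time.perf_counter() - t0) / n * 1e3:.1f} ms per iteration")
```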
I think there are two possible routes:

1. Stick with the `jax.jit` decorators on my `jax.numpy` code, and hope that the Eigen thread pool is sophisticated enough to make full use of the 40 cores on my production machines. What is the flag that controls the number of threads used by Eigen? Why is it apparently limited to 4 in my tests, no matter what I do?
2. Rewrite the code with `shard_map`. This is appealing in some sense because it is quite similar to MPI, which is a standard tool in our toolkit (a sketch of the analogy follows below). However, MPI is only reliably fast if data is local to a core and there aren't rogue thread pools stomping all over each other. If there's no way to ensure that each shard gets one thread, then I can't see any reason to use `shard_map` at all.
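To make the MPI analogy concrete, here is a minimal sketch (illustrative only, not from my simulation): inside a `shard_map` body each shard behaves like an MPI rank, and `psum` plays the role of `MPI_Allreduce`. Whether each of those shards actually gets its own dedicated thread is exactly the open question above.

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

ndev = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("i",))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P())
def global_sum(x):
    local = jnp.sum(x)                         # per-shard work, like per-rank work in MPI
    return jax.lax.psum(local, axis_name="i")  # collective, like MPI_Allreduce

# The leading axis must divide evenly across the devices in the mesh.
print(global_sum(jnp.ones((ndev, 1000))))
```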