Possible regression in jaxlib 0.4.25+ causing training deadlocks on GPU #25453
-
Unfortunately it's impossible to say what's going wrong from this information alone. It looks like a deadlock, but we can't tell how or why without knowing more. I can think of two things that might help: b) a reproducer that we could run would also help.
-
We have a model training script that began to experience deadlocks during GPU computation after upgrading from jax 0.4.13 to 0.4.25+. Specifically, the issue appears with jaxlib 0.4.25, disappears with 0.4.26, and is present again from jaxlib 0.4.27 onwards. We'd appreciate any insight into how we can further understand what's going on. We're attempting to create an MRE in the meantime, but our training code is quite complicated and we're still bisecting the issue. The deadlocks occur in both single- and multi-GPU training runs, and when testing the single-GPU case, removing all sharding-related code does not resolve them.
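For concreteness, the MRE we're working toward has roughly the shape below. This is only a sketch: every name, shape, and the toy loss are placeholders, and our real jitted step (described in the next section) is far more involved.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Stand-in loss: a linear model with squared error, purely illustrative.
    preds = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((preds - batch["y"]) ** 2)


@jax.jit
def single_step(params, batch):
    # Loss and gradients in one jitted function, mirroring the structure of
    # the step that hangs in our training code.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss


params = {"w": jnp.zeros((128, 1)), "b": jnp.zeros(())}
key = jax.random.PRNGKey(0)

for step in range(10_000_000):
    key, kx, ky = jax.random.split(key, 3)
    batch = {
        "x": jax.random.normal(kx, (256, 128)),
        "y": jax.random.normal(ky, (256, 1)),
    }
    params, loss = single_step(params, batch)
    if step % 1000 == 0:
        # float() blocks on the device result, so a hang inside the step
        # would surface here.
        print(step, float(loss))
```

The idea is to keep everything on-device inside a single jitted call so that, if the hang is tied to dispatch or execution of a jitted computation, a long-running loop like this would eventually hit it.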
Regression description:
At some point during training we call into a jitted `single_step` function (computing loss and gradients) and the function never exits (nor does it crash), as evidenced by a `py-spy` trace. This happens non-deterministically, minutes to hours into a training run. We use Weights & Biases for logging, and its system resource logs show that at the time of the deadlock our GPU power usage drops to a nontrivial level and stays there with extremely low variation (image below), while the Python process is simply waiting for control to return from the jitted call. To reiterate, this appears to be a regression: training runs just fine on jaxlib 0.4.24 and below.
Here's what the GPU power usage looks like, with the hang occurring at around ~26k on the x-axis:
For reference, our H100s idle at ~100W, so something is happening.
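In case it's useful, here's a sketch (not something we've validated against this particular hang) of how periodic Python-level stack dumps can be taken from inside the training process using only the standard library. Like a `py-spy` dump without `--native`, it will only show the Python frame blocked in the jitted call if the hang is inside XLA/CUDA code, but it confirms where Python is stuck without attaching an external tool.

```python
import faulthandler
import signal
import sys

# Every 10 minutes, write the traceback of every thread to stderr. The dump
# runs from faulthandler's watchdog thread, so it still fires even when the
# main thread is blocked inside a native call.
faulthandler.dump_traceback_later(600, repeat=True, file=sys.stderr)

# Also allow an on-demand dump via `kill -USR1 <pid>` from inside the pod.
faulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True)
```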
Attempt at diagnostics:
When I exec into the training pod after the hang has occurred, I see that the training Python process is alive (PID 1) but is waiting on a futex. I obtained a backtrace from `gdb`; I can provide the rest of the `bt` if anyone thinks it would be helpful.
Obligatory environment dump:
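For anyone comparing setups, the same information can be collected with something along these lines (assuming a jax build recent enough to expose `print_environment_info`):

```python
import jax
import jaxlib

# Versions of the two packages involved in the regression window.
print("jax:", jax.__version__, "jaxlib:", jaxlib.__version__)

# Prints jax/jaxlib/numpy/Python versions plus device info (and nvidia-smi
# output on GPU machines).
jax.print_environment_info()
```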
Any insights into what's happening or suggestions for debugging this issue would be massively appreciated!