Runtime scaling of JIT compilation #3478
-
How does jit-compile time scale with the size of the computation? If my computation can be represented by a computational graph of length L and width W, can I predict the scaling of jit-compilation time? It seems like it should definitely depend on L. Not sure about W, if the operations are vectorized. As we know with loops, you either split up the jit or use lax.scan, and that speeds up compilation of the loop immensely. IIRC, when I naively jitted an inner for loop with L steps rather than using scan, jit-compilation seemed to go like O(L^2). On my current project, a linear computational graph of fixed L but variable W, compile time depends on W (matrix sizes) very strangely: there doesn't seem to be any discernible pattern (sometimes it grows with W, sometimes it shrinks). Is there a way to reason about this? I'm trying to understand when it makes sense to split a big computation into smaller functions and jit those, versus jitting a single bigger computation.
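A minimal sketch of the loop-vs-scan comparison (the per-step function `step` and all sizes are made up; timings will of course vary by machine and backend):

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

def step(x):
    # Arbitrary per-iteration work, standing in for one node of the graph.
    return jnp.tanh(x) + 0.1

def unrolled(x, L):
    # Python loop: tracing bakes L copies of `step` into the jaxpr/HLO.
    for _ in range(L):
        x = step(x)
    return x

def scanned(x, L):
    # lax.scan: one copy of `step` inside a loop construct, so the
    # traced program size is roughly independent of L.
    x, _ = lax.scan(lambda c, _: (step(c), None), x, None, length=L)
    return x

x = jnp.ones(1000)
for L in (100, 200, 400):
    for f in (unrolled, scanned):
        jitted = jax.jit(f, static_argnums=1)
        t0 = time.perf_counter()
        jitted(x, L).block_until_ready()  # first call includes compile time
        print(f"{f.__name__} L={L}: {time.perf_counter() - t0:.2f}s")
```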
-
Congratulations on starting our first GitHub Discussions thread! I have to agree: my mental model is that XLA compilation time scales, in the worst case, approximately quadratically with the number of ops in the program. If you have interesting and well-motivated examples of programs with particularly slow compile times, we could raise them with the XLA developers.
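One way to probe that scaling empirically is to time XLA compilation separately from tracing and execution. A sketch using JAX's ahead-of-time `jit(...).lower(...).compile()` API (which is newer than this thread; the toy function is made up):

```python
import time

import jax
import jax.numpy as jnp

def big(x):
    for _ in range(300):  # deliberately large unrolled graph
        x = jnp.tanh(x)
    return x

x = jnp.ones((100, 100))
lowered = jax.jit(big).lower(x)  # tracing + lowering; no XLA work yet
t0 = time.perf_counter()
compiled = lowered.compile()     # XLA compilation only
print(f"XLA compile time: {time.perf_counter() - t0:.2f}s")
```

Sweeping the loop length (or array shapes) and plotting the compile time against program size gives a direct read on the scaling for your particular backend.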
-
XLA TPU lead here. It is important to distinguish between the time it takes to process the HLO graph and the time it takes to compile each op in the final HLO graph.

For HLO passes, the dominant complexity is typically linear in the number of ops. Passes like fusion have quadratic behavior in the number of parameters of a fusion node every time a fusion is created; in theory this could become cubic if everything in the graph fuses and each fusion adds an operand. Many HLO passes also produce a ReachabilityMap, which is quadratic in the size of the largest computation. In practice this is generally quite fast due to a small constant factor, but in the worst case the HLO passes are quadratic.

That said, typically more time is spent generating machine code from HLO ops, which is linear in the number of ops, but each of those ops can exhibit compile-time behavior correlated with its dimension sizes or with how amenable the generated code is to optimization. The reliance on LLVM for vectorization and loop unrolling in XLA CPU can create pretty nasty compile-time problems. For XLA GPU, the PTX-to-cubin step is a relatively slow black box. The XLA TPU backend is completely custom and, as a result, was co-designed around how the code generation process works.
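If you want to watch those passes at work on your own program, XLA's `--xla_dump_to` and `--xla_dump_hlo_pass_re` flags dump the HLO module before and after each pass, which makes it easy to see where op counts blow up. A sketch (the toy function is my own; the flags must be set before JAX initializes its backends):

```python
import os

# Ask XLA to dump the HLO module before and after every pass it runs.
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_pass_re=.*"
)

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(jnp.tanh(x) * 2.0 + 1.0)

f(jnp.ones((512, 512)))  # triggers compilation; dumps land in /tmp/xla_dump
```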
-
I was wondering if there's any place to find general info (like @blakehechtman gave above) about the flow of a jit compilation. What kinds of optimizations does XLA do, and when do they happen? For example, if I have a scan that is really two disconnected scans (two sets of in/out/carry that don't interact in the body), does XLA recognize this and break them apart if doing so is useful? This seems like a contrived example (made concrete in the sketch below), but if the answer is "while loops are hard and we don't currently do anything like that" then that's helpful to me.

I have a huge graph that I've been struggling with for over a year now, usually running out of host memory at jit time. I usually put one big jit around my entire computation, but I've also tried to jit subgraphs first, which seemed counterproductive. Over the course of many moons I've made several changes that I imagine (with my Theano upbringing) ought to cut some connections and vastly simplify the "effective" graph, but I have no way of telling (except by observing OOMs) whether JAX (at graph construction time) and XLA (at compilation time) are able to exploit these things, and whether they do so before the steps that scale poorly.
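For concreteness, here is the two-disconnected-scans pattern written both ways, with toy bodies of my own invention; whether XLA performs this split automatically is exactly the open question:

```python
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.ones(1000)
ys = jnp.ones(1000)

def run_combined():
    # One scan whose carry holds two states that never interact.
    def body(carry, inputs):
        a, b = carry
        x, y = inputs
        return (a + x, b * (1.0 + y)), None
    (a, b), _ = lax.scan(body, (jnp.zeros(()), jnp.ones(())), (xs, ys))
    return a, b

def run_split():
    # The same computation, manually split into two independent scans.
    a, _ = lax.scan(lambda c, x: (c + x, None), jnp.zeros(()), xs)
    b, _ = lax.scan(lambda c, y: (c * (1.0 + y), None), jnp.ones(()), ys)
    return a, b
```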
-
@cheshire is looking into optimizing while loops for XLA GPU, though possibly not the exact cases you're seeing.