Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] minimization with complex arguments #71

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

Randl
Copy link
Contributor

@Randl Randl commented Aug 7, 2024

Following up on #61 (comment)
This doesn't introduce fixes yet, just adds a simple test that fails to highlight the issue

@Randl
Copy link
Contributor Author

Randl commented Aug 8, 2024

Locally, with patrick-kidger/lineax#103 this fails three least square tests (two are related to some pytree structure, and one should be some complex-related bug since the wrong answer is returned) and most of the minimize_jvp tests.

@Randl
Copy link
Contributor Author

Randl commented Aug 11, 2024

Unfortunately, jvp is frustratingly hard to debug; for example, the call stack is useless due to all the wrappers around the function calls. @patrick-kidger, do you have any advice on how I can approach that?

Specifically, I see that the LU solver gets 0 instead of 2 as a matrix, but I have no idea where it is called from and what this value is.

We can, of course, go the old way and merge the working stuff (minimization as for now and hopefully least squares soon) and work out the jvp later.

@patrick-kidger
Copy link
Owner

So the stack goes function -> jaxpr -> runtime. I think if the trace-time callstack doesn't help you, then I do have some tricks up my sleeve to help with the jaxpr and runtime approaches.

To help with intercepting things at the jaxpr level then I like to use eqx.debug.announce_transform. Passing an appropriate announce function then you can e.g. insert a breakpoint into the point at which a JVP transform is performed. (Even if this happens after a function is parsed into a jaxpr.)

As for the runtime level then lots of jax.debug.prints / jax.debug.breakpoints would probably be my usual thing here -- bisect through the program.

@Randl
Copy link
Contributor Author

Randl commented Aug 12, 2024

Ok, it turns out many tests passed since I initialized the minimization at the real line. I've changed that, and now there are more fails, but that helped me to find where the problem is.

@Randl Randl mentioned this pull request Aug 17, 2024
3 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants