Incompatible with jax v0.5.0 #930

Open
rg936672 opened this issue Jan 20, 2025 · 1 comment
Labels: new (Something yet to be discussed by development team)

Comments

@rg936672 (Contributor)

What's the issue?

Updating to jax==0.5.0 causes a number of test failures. Most of them look like precision issues: arrays agree to only about four decimal places, which is not enough to pass the stricter tests. Some of the kernel gradients also appear to be flipped.
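
To make "agree to only about four decimal places" concrete, here is a minimal plain-Python sketch. It assumes the stricter tests use an absolute-tolerance check in the style of NumPy's assert_array_almost_equal, which passes when abs(desired - actual) < 1.5 * 10**(-decimal). The helper almost_equal and the sample values are hypothetical and are not taken from the project's test suite.

```python
def almost_equal(actual, desired, decimal):
    """Hypothetical helper: True if every pair of elements differs by less
    than 1.5 * 10**(-decimal), mirroring NumPy's documented
    assert_array_almost_equal criterion (an assumption about the tests)."""
    if len(actual) != len(desired):
        return False
    tol = 1.5 * 10 ** (-decimal)
    return all(abs(d - a) < tol for a, d in zip(actual, desired))


# Hypothetical values: a "new" result that matches the reference to
# roughly four decimal places (differences of about 1.3e-5 and 1.4e-5).
reference = [0.123456789, 1.987654321]
new_result = [0.123470000, 1.987640000]

print(almost_equal(new_result, reference, decimal=4))  # True: loose check passes
print(almost_equal(new_result, reference, decimal=7))  # False: strict check fails
```

Differences on the order of 1e-5 clear a four-decimal tolerance but not a seven-decimal one, which is the pattern described above.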

Once this is fixed, remove the !=0.5.0 specifier from the jax dependency in pyproject.toml.

rg936672 added the new label (Something yet to be discussed by development team) on Jan 20, 2025
@rg936672 (Contributor, Author)

Partially resolved by #936, but we may still want to investigate why the gradient was previously evaluating to 0.
