-
-
Notifications
You must be signed in to change notification settings - Fork 141
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Now using strict shape/dtype promotion rules.
This means that: 1. Tests now pass using `JAX_NUMPY_DTYPE_PROMOTION=strict` and `JAX_NUMPY_RANK_PROMOTION=raise`, and these are enabled in tests by default. 2. The values passed to `diffeqsolve` now more carefully determine the dtype used in the integration (previously things were mostly just left to behave in ad-hoc fashion; whatever the various interacting arrays promoted their dtypes to): a. The dtype of timelike values is the `jnp.result_type` of `t0`, `t1`, `dt0`, and `SaveAt(ts=...)`. If any of these are complex an error is raised. If these are all integers we use the default floating-point dtype. b. The `jnp.result_type` of the time dtype, and each leaf of `y0`, is the dtype of that leaf. 3. Of course, `diffeqsolve` accepts user-specified functions (e.g. the vector field of an `ODETerm`), and these could potentially return arrays with dtypes that do not match the ones we have selected above, which might cause further upcasting. For the sake of backward compatibility we don't try to prohibit this -- a user who feels strongly about this should enable `JAX_NUMPY_DTYPE_PROMOTION=strict` and fix their vector fields appropriately. (And can then be assured that the dtypes of these quantities are exactly as specified by the rules above.) So the key thing this commit enables is that using this flag to enforce this is now possible, without any false positives from Diffrax itself!
- Loading branch information
1 parent
7965e89
commit 0ee47c9
Showing
13 changed files
with
184 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.