`evaluate_pbc_fast()` is slow #441
Hi, and thanks for sharing this; it sounds very promising! I recall that earlier versions of the code had a trade-off between using `KDTree` and `scipy.stats.gaussian_kde`, depending on whether there were more grid nodes or more atoms. If I remember correctly, the older code could switch between the two methods, possibly guided by heuristics. Given the performance improvements, and since your patch doesn't change the default behavior unless `jax` is installed, I agree it's worth integrating. That said, I'm a little concerned about your remark:
Ideally, the results should be identical regardless of the implementation. If there are discrepancies, it's worth investigating whether the issue lies in this implementation or in the existing one. Could you please help me look into this to ensure consistency? Looking forward to reviewing your PR!
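For reference, the `KDTree`-style approach mentioned above can be sketched as follows: build a periodic tree over the atoms and, for each grid node, sum Gaussians only over atoms within a cutoff. This is a minimal sketch assuming `scipy.spatial.cKDTree` with its `boxsize` option for periodic boundaries; the helper name `kde_pbc_kdtree` and the 3σ cutoff are hypothetical choices, not pytim's API. Roughly, its cost scales with the number of atom/node pairs inside the cutoff rather than with the full (grid nodes × atoms) product of a dense evaluation, which is why the better choice can depend on the relative sizes of the grid and the atom set.

```python
import numpy as np
from scipy.spatial import cKDTree


def kde_pbc_kdtree(positions, grid, box, sigma, cutoff_sigmas=3.0):
    """Unnormalised periodic Gaussian density on `grid` (hypothetical helper).

    Only atoms within cutoff_sigmas * sigma of a grid node contribute, so the
    result is a truncated version of the full sum over atoms.
    """
    box = np.asarray(box, dtype=float)
    # cKDTree with boxsize needs coordinates in [0, box): wrap them first,
    # guarding against values that round up to exactly box.
    pos = np.mod(np.asarray(positions, dtype=float), box)
    pos = np.where(pos >= box, pos - box, pos)
    pts = np.mod(np.asarray(grid, dtype=float), box)
    pts = np.where(pts >= box, pts - box, pts)

    tree = cKDTree(pos, boxsize=box)
    cutoff = cutoff_sigmas * sigma
    density = np.zeros(len(pts))
    # For each grid node, indices of atoms within the (periodic) cutoff.
    for i, idx in enumerate(tree.query_ball_point(pts, r=cutoff)):
        if not idx:
            continue
        d = pos[idx] - pts[i]
        d -= box * np.round(d / box)  # minimum-image displacement
        r2 = np.einsum('ij,ij->i', d, d)
        density[i] = np.exp(-0.5 * r2 / sigma**2).sum()
    return density
```

The cutoff should stay below half the shortest box edge so that the minimum-image displacement and the tree's periodic distance refer to the same image.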
So I ran some tests with your version, modifying the code slightly so that I could switch between the two implementations, and this is what I get with the …:

```python
import MDAnalysis as mda
import pytim
from pytim.datafiles import WATER_GRO
from pytim.gaussian_kde_pbc import gaussian_kde_pbc

u = mda.Universe(WATER_GRO)
mesh = 2
ngrid, spacing = pytim.utilities.compute_compatible_mesh_params(mesh, box=u.dimensions[:3])
print('ngrid=', ngrid, 'spacing=', spacing)
grid = pytim.utilities.generate_grid_in_box(box=u.dimensions[:3], npoints=ngrid, order='xyz')
print('grid computed')

# Existing pytim implementation (use_jax is the switch added in the modified code)
kernel = gaussian_kde_pbc(u.select_atoms('name OW').positions, box=u.dimensions[:3], sigma=2.0, use_jax=False)
print('kernel inited')
%timeit -n 1 -r 1 density_field = kernel.evaluate(grid)

# JAX implementation
kernelJax = gaussian_kde_pbc(u.select_atoms('name OW').positions, box=u.dimensions[:3], sigma=2.0, use_jax=True)
print('Jax kernel inited')
%timeit -n 1 -r 1 density_fieldJax = kernelJax.evaluate(grid)
```

Output:

```
ngrid= [25 25 75] spacing= [2. 2. 2.]
grid computed
kernel inited
evaluating using pytim implementation
129 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Jax kernel inited
evaluating using jax
10.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```
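JAX compiles code the first time it is called for given input shapes and dispatches work asynchronously, so a single `%timeit -n 1 -r 1` run may mix compilation time with evaluation time. A minimal sketch of timing the first call separately from later calls follows; `time_first_and_steady` is a hypothetical helper, and `block_until_ready()` is called only if the returned object has that method, as JAX arrays do.

```python
import statistics
import time


def time_first_and_steady(fn, *args, repeats=5):
    """Return (first_call_seconds, median_of_later_calls_seconds).

    Hypothetical helper: the first call may include one-off costs such as
    JIT compilation, so it is reported separately from the later calls.
    """
    def run_once():
        start = time.perf_counter()
        out = fn(*args)
        # JAX returns before the computation finishes; wait for it if possible.
        block = getattr(out, "block_until_ready", None)
        if callable(block):
            block()
        return time.perf_counter() - start

    first = run_once()
    steady = [run_once() for _ in range(repeats)]
    return first, statistics.median(steady)
```

With the session above, this would be called as `first, steady = time_first_and_steady(kernelJax.evaluate, grid)`; the steady-state figure is the one that matters when the same kernel shape is evaluated over many trajectory frames.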
Could you post a configuration from one of the systems you are using, along with the parameters you use to initialise the `WillardChandler` class? The results above were obtained with the code on the branch https://github.com/Marcello-Sega/pytim/tree/faster-kde
JAX uses JIT compilation, so the initial run takes a relatively long time to complete. I think the speed improvements in my code must have come from the fact that I was analysing a whole trajectory, so there were many function calls. As for the difference in the location of the surface, I'm confident that the issue lies with the method I'm implementing here. To implement the minimum image convention, you need to change the way the kernel is evaluated, which you have done in the original code. But to take advantage of JAX's fast …
I think this is a useful alternative calculation for some use cases, like mine. Analysing all of my trajectories would have taken 8 hours with the original code, but less than 1 hour with this variation.
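The two strategies discussed here can be compared directly: a minimum-image sum over the N atoms, and a plain sum over the 27N atoms of a 3x3x3 supercell. Below is a minimal NumPy sketch, assuming unnormalised Gaussians with a fixed width `sigma` and coordinates already wrapped into `[0, box)`; the function names are hypothetical. Every non-nearest image lies at least half the shortest box edge away, so for `sigma` small compared with that distance the two sums agree to within a negligible amount, while for larger `sigma` the supercell sum picks up extra images and becomes strictly larger.

```python
import itertools

import numpy as np


def density_min_image(positions, grid, box, sigma):
    """Sum of Gaussians using only the nearest periodic image of each atom."""
    d = grid[:, None, :] - positions[None, :, :]
    d -= box * np.round(d / box)  # minimum-image convention
    r2 = np.sum(d * d, axis=-1)
    return np.exp(-0.5 * r2 / sigma**2).sum(axis=1)


def density_supercell(positions, grid, box, sigma):
    """Same kernel, summed over all 27 images of a 3x3x3 supercell (no wrapping)."""
    shifts = np.array(list(itertools.product((-1, 0, 1), repeat=3))) * box
    images = (positions[None, :, :] + shifts[:, None, :]).reshape(-1, 3)
    d = grid[:, None, :] - images[None, :, :]
    r2 = np.sum(d * d, axis=-1)
    return np.exp(-0.5 * r2 / sigma**2).sum(axis=1)
```

If the summed kernels agree but the surfaces still differ, one thing worth checking (an assumption, not verified against the PR) is the normalisation and bandwidth convention: `scipy.stats.gaussian_kde` divides by the number of points in its dataset (27N for a supercell) and, when `bw_method` is a scalar, scales the covariance of the data rather than using a fixed σ.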
I've been using `WillardChandler` a lot recently, but my analysis code has reached the point where it takes several hours to run over a series of long simulations. I've attached the cProfile output for the original code (run on small slices of my trajectories). You can see that each execution of `evaluate_pbc_fast()` takes 2.8 seconds on average.

JAX has a very fast `gaussian_kde()` implementation. To make full use of its `evaluate()` code, instead of using the minimum image convention, we can generate a 3x3x3 supercell of periodic images and use it as the dataset.

With this change, each call to `gaussian_kde()` takes around a millisecond, and the total runtime drops from 384 s to 15 s. With the default parameters, this results in a slightly larger interface being drawn, so you might need to change the density cutoff.

I'll submit a PR with this change so you can see the difference.
profiles.zip
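For anyone wanting to produce a profile like the attached one, here is a minimal sketch using the standard-library `cProfile` and `pstats` modules to profile a single call and list the most expensive functions by cumulative time. `profile_call` is a hypothetical helper; which function you pass in (for example, a per-frame analysis step) is up to you.

```python
import cProfile
import io
import pstats


def profile_call(fn, *args, top=10, **kwargs):
    """Run fn(*args, **kwargs) under cProfile; return (result, report_text)."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = fn(*args, **kwargs)
    finally:
        profiler.disable()
    buf = io.StringIO()
    # Sort by cumulative time so expensive call chains appear first.
    stats = pstats.Stats(profiler, stream=buf).sort_stats("cumulative")
    stats.print_stats(top)  # limit the report to the `top` entries
    return result, buf.getvalue()
```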