Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamped tuning #130

Open
wants to merge 49 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
5b47f0c
Initial version of `grid_search`
GardevoirX Nov 26, 2024
5936ccc
Remove error
GardevoirX Nov 26, 2024
6b8eba1
Allow a precomputed nl
GardevoirX Nov 26, 2024
78e2a30
Renamed examples, and added a tuning playground
ceriottm Nov 23, 2024
96e8552
Nelder mead (doesn't work because actual error is not a good target)
ceriottm Nov 24, 2024
886d8d8
Added a tuning class
ceriottm Nov 24, 2024
c5f0e66
I'm not a morning person it seems
ceriottm Nov 24, 2024
2047ca2
Examples
ceriottm Nov 24, 2024
4422082
Better plotting
ceriottm Nov 24, 2024
abca208
Fixes on `H` and `RMS_phi`
GardevoirX Nov 25, 2024
376c647
Some cleaning and test fix
GardevoirX Nov 25, 2024
adb25af
Further clean
GardevoirX Nov 26, 2024
fcd155e
Replace `loss` in tuning with `ErrorBounds` and draft for `Tuner`
GardevoirX Nov 27, 2024
964e427
Supress output
GardevoirX Nov 27, 2024
baf5fcc
Update `grid_search`
GardevoirX Nov 28, 2024
c05ef59
Return something when is cannot reach desired accuracy
GardevoirX Nov 28, 2024
565607d
Supress output
GardevoirX Nov 28, 2024
8a29f16
Repair some errors of the example
GardevoirX Nov 28, 2024
799cbe4
Add a warning for the case that no parameter can meet the accuracy re…
GardevoirX Dec 5, 2024
1889695
Update warning
GardevoirX Dec 5, 2024
104d9c1
Documentations and pytests update
GardevoirX Dec 18, 2024
14c10c1
Added a TIP4P example
ceriottm Dec 20, 2024
ab6fd50
Started to change the API to use full charges rather than the sum of …
ceriottm Dec 20, 2024
e0f7291
Move from `sum_squared_charges` to `charges`
GardevoirX Dec 28, 2024
567acdf
Refactor the tuning methods with a base class
GardevoirX Dec 28, 2024
9eaac38
Fix pytests and make linter happy
GardevoirX Dec 28, 2024
6bca873
Mini cleanups
ceriottm Dec 29, 2024
83532a3
Docs fix
GardevoirX Dec 29, 2024
22cfa3e
Separate timings calculator
ceriottm Dec 29, 2024
47c154d
Linting
ceriottm Dec 29, 2024
bc831d3
Try fix github action failures
GardevoirX Dec 29, 2024
0be1c88
Add tuning functions back
GardevoirX Jan 7, 2025
d7a3165
Allow doctests
GardevoirX Jan 7, 2025
848dfd1
Fix doctests and remove orphan functions
GardevoirX Jan 7, 2025
c813435
Fix ewald doctest again and remove unused members
GardevoirX Jan 7, 2025
949b94a
Formatting
GardevoirX Jan 7, 2025
9d43fb6
Draft for renovated tuning
GardevoirX Jan 13, 2025
2a2bf0f
For now move back to `CoulombPotential`
GardevoirX Jan 13, 2025
4da66b4
Rearange the tuning stuff
GardevoirX Jan 13, 2025
8021c35
Rearrange again
GardevoirX Jan 15, 2025
2ce197d
An initial version of refurnished documentation
GardevoirX Jan 15, 2025
6c7b593
Minor modification
GardevoirX Jan 16, 2025
d5b708a
ErrorBounds related updates
GardevoirX Jan 16, 2025
625c4f1
Update error formulas
GardevoirX Jan 16, 2025
7044024
Update tuning tests
GardevoirX Jan 16, 2025
0019092
Lint manually
GardevoirX Jan 16, 2025
0a83bbb
Update tuning doctests
GardevoirX Jan 16, 2025
4c22071
Update example
GardevoirX Jan 16, 2025
d55c9a0
some minor cleanup
PicoCentauri Jan 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/references/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ refer to the :ref:`userdoc-how-to` section.

potentials/index
calculators/index
tuning/index
metatensor
lib/index
utils/index
Expand Down
43 changes: 43 additions & 0 deletions docs/src/references/tuning/base_classes.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
Base Classes
############

GardevoirX marked this conversation as resolved.
Show resolved Hide resolved
Current scheme behind all tuning functions is grid-searching based, focusing on the Fourier
space parameters like ``lr_wavelength``, ``mesh_spacing`` and ``interpolation_nodes``.
For real space parameter ``cutoff``, it is treated as a hyperparameter here, which
should be manually specified by the user. The parameter ``smearing`` is determined by
the real space error formula and is set to achieve a real space error of
``desired_accuracy / 4``.

The Fourier space parameters are all discrete, so it's convenient to do the grid-search.
Default searching-ranges are provided for those parameters. For ``lr_wavelength``, the
values are chosen to be with a minimum of 1 and a maximum of 13 mesh points in each
spatial direction ``(x, y, z)``. For ``mesh_spacing``, the values are set to have
minimally 2 and maximally 7 mesh points in each spatial direction, for both the P3M and
PME method. The values of ``interpolation_nodes`` are the same as those supported in
:class:`torchpme.lib.MeshInterpolator`.

In the grid-searching, all possible parameter combinations are evaluated. The error
associated with the parameter is estimated by the error formulas implemented in the
subclasses of :class:`torchpme.tuning.tuner.TuningErrorBounds`. Parameter with
the error within the desired accuracy are benchmarked for computational time by
:class:`torchpme.tuning.tuner.TuningTimings` The timing of the other parameters are
not tested and set to infinity.

The return of these tuning functions contains the ``smearing`` and a dictionary, in
which there is parameter for the Fourier space. The parameter is that of the desired
accuracy and the shortest timing. The parameter of the smallest error will be returned
in the case that no parameter can fulfill the accuracy requirement.


.. autoclass:: torchpme.tuning.tuner.TunerBase
:members:

.. autoclass:: torchpme.tuning.tuner.GridSearchTuner
:members:

.. autoclass:: torchpme.tuning.tuner.TuningTimings
:members:

.. autoclass:: torchpme.tuning.tuner.TuningErrorBounds
:members:

14 changes: 14 additions & 0 deletions docs/src/references/tuning/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Tuning
######

The choice of parameters like the neighborlist ``cutoff``, the ``smearing`` or the
``lr_wavelength``/``mesh_spacing`` has a large influence one the accuracy of the
calculation. To help find the parameters that meet the accuracy requirements, this
module offers tuning methods for the calculators.


.. toctree::
:maxdepth: 1
:glob:

./*
24 changes: 24 additions & 0 deletions docs/src/references/tuning/tune_ewald.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
Tune Ewald
##########

GardevoirX marked this conversation as resolved.
Show resolved Hide resolved
The tuning is based on the following error formulas:

.. math::
\Delta F_\mathrm{real}
\approx \frac{Q^2}{\sqrt{N}}
\frac{2}{\sqrt{r_{\text{cutoff}} V}}
e^{-r_{\text{cutoff}}^2 / 2 \sigma^2}

.. math::
\Delta F_\mathrm{Fourier}^\mathrm{Ewald}
\approx \frac{Q^2}{\sqrt{N}}
\frac{\sqrt{2} / \sigma}{\pi\sqrt{2 V / h}} e^{-2\pi^2 \sigma^2 / h ^ 2}

where :math:`N` is the number of charges, :math:`Q^2 = \sum_{i = 1}^N q_i^2`, is the sum of squared
charges, :math:`r_{\text{cutoff}}` is the short-range cutoff, :math:`V` is the volume of the
simulation box and :math:`h^2` is the long range wavelength.

.. autofunction:: torchpme.tuning.ewald.tune_ewald

.. autoclass:: torchpme.tuning.ewald.EwaldErrorBounds
:members:
27 changes: 27 additions & 0 deletions docs/src/references/tuning/tune_p3m.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Tune P3M
#########

The tuning is based on the following error formulas:

.. math::
\Delta F_\mathrm{real}
\approx \frac{Q^2}{\sqrt{N}}
\frac{2}{\sqrt{r_{\text{cutoff}} V}}
e^{-r_{\text{cutoff}}^2 / 2 \sigma^2}

.. math::
\Delta F_\mathrm{Fourier}^\mathrm{P3M}
\approx \frac{Q^2}{L^2}(\frac{\sqrt{2}H}{\sigma})^p
\sqrt{\frac{\sqrt{2}L}{N\sigma}
\sqrt{2\pi}\sum_{m=0}^{p-1}a_m^{(p)}(\frac{\sqrt{2}H}{\sigma})^{2m}}

where :math:`N` is the number of charges, :math:`Q^2 = \sum_{i = 1}^N q_i^2`, is the sum of squared
charges, :math:`r_{\text{cutoff}}` is the short-range cutoff, :math:`V` is the volume of the
simulation box, :math:`p` is the order of the interpolation scheme, :math:`H` is the spacing of mesh
points and :math:`a_m^{(p)}` is an expansion coefficient.


.. autofunction:: torchpme.tuning.p3m.tune_p3m

.. autoclass:: torchpme.tuning.p3m.P3MErrorBounds
:members:
27 changes: 27 additions & 0 deletions docs/src/references/tuning/tune_pme.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Tune PME
#########

The tuning is based on the following error formulas:

.. math::
\Delta F_\mathrm{real}
\approx \frac{Q^2}{\sqrt{N}}
\frac{2}{\sqrt{r_{\text{cutoff}} V}}
e^{-r_{\text{cutoff}}^2 / 2 \sigma^2}

.. math::
\Delta F_\mathrm{Fourier}^\mathrm{PME}
\approx 2\pi^{1/4}\sqrt{\frac{3\sqrt{2} / \sigma}{N(2p+3)}}
\frac{Q^2}{L^2}\frac{(\sqrt{2}H/\sigma)^{p+1}}{(p+1)!} \times
\exp{\frac{(p+1)[\log{(p+1)} - \log 2 - 1]}{2}} \left< \phi_p^2 \right> ^{1/2}

where :math:`N` is the number of charges, :math:`Q^2 = \sum_{i = 1}^N q_i^2`, is the sum of squared
charges, :math:`r_{\text{cutoff}}` is the short-range cutoff, :math:`V` is the volume of the
simulation box, :math:`p` is the order of the interpolation scheme, :math:`H` is the spacing of mesh
points, and :math:`\phi_p^2 = H^{-(p+1)}\prod_{s\in S_H^{(p)}}(x - s)`, in which :math:`S_H^{(p)}` is
the :math:`p+1` mesh points closest to the point :math:`x`.

.. autofunction:: torchpme.tuning.pme.tune_pme

.. autoclass:: torchpme.tuning.pme.PMEErrorBounds
:members:
19 changes: 0 additions & 19 deletions docs/src/references/utils/tuning.rst

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@
from metatensor.torch.atomistic import NeighborListOptions, System

import torchpme
from torchpme.tuning import tune_pme

# %%
#
# Create the properties CsCl unit cell

symbols = ("Cs", "Cl")
types = torch.tensor([55, 17])
charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=torch.float64)
cell = torch.eye(3, dtype=torch.float64)
pbc = torch.tensor([True, True, True])
Expand All @@ -55,8 +57,9 @@
# The ``sum_squared_charges`` is equal to ``2.0`` becaue each atom either has a charge
# of 1 or -1 in units of elementary charges.

smearing, pme_params, cutoff = torchpme.utils.tune_pme(
sum_squared_charges=2.0, cell=cell, positions=positions
cutoff = 4.4
smearing, pme_params = tune_pme(
charges=charges, cell=cell, positions=positions, cutoff=cutoff
)

# %%
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import vesin.torch

import torchpme
from torchpme.tuning import tune_pme

# %%
#
Expand Down Expand Up @@ -92,9 +93,9 @@
cell = torch.from_numpy(atoms.cell.array)

sum_squared_charges = float(torch.sum(charges**2))

smearing, pme_params, cutoff = torchpme.utils.tune_pme(
sum_squared_charges=sum_squared_charges, cell=cell, positions=positions
cutoff = 4.4
smearing, pme_params = tune_pme(
charges=charges, cell=cell, positions=positions, cutoff=cutoff
)

# %%
Expand Down
File renamed without changes.
File renamed without changes.
86 changes: 84 additions & 2 deletions examples/5-autograd-demo.py → examples/05-autograd-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
exercise to the reader.
"""

# %%

from time import time

import ase
Expand Down Expand Up @@ -477,10 +479,11 @@ def forward(self, positions, cell, charges):
)

# %%
# We can also time the difference in execution
# We can also evaluate the difference in execution
# time between the Pytorch and scripted versions of the
# module (depending on the system, the relative efficiency
# of the two evaluations could go either way!)
# of the two evaluations could go either way, as this is
# a too small system to make a difference!)

duration = 0.0
for _i in range(20):
Expand Down Expand Up @@ -515,3 +518,82 @@ def forward(self, positions, cell, charges):
print(f"Evaluation time:\nPytorch: {time_python}ms\nJitted: {time_jit}ms")

# %%
# Other auto-differentiation ideas
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO opinion I wouldn't put this example here - even though I think it is good to have it. The tutorial is already 500 lines and with this super long. I rather vote for smaller examples tackling one specific tasks. Finding solutions is much easier if they are shorter. See also the beloved matplotlib examples.

# --------------------------------
#
# There are many other ways the auto-differentiation engine of
# ``torch`` can be used to facilitate the evaluation of atomistic
# models.

# %%
# 4-site water models
# ~~~~~~~~~~~~~~~~~~~
#
# Several water models (starting from the venerable TIP4P model of
# `Abascal and C. Vega, JCP (2005) <http://doi.org/10.1063/1.2121687>`_)
# use a center of negative charge that is displaced from the O position.
# This is easily implemented, yielding the forces on the O and H positions
# generated by the displaced charge.

structure = ase.Atoms(
positions=[
[0, 0, 0],
[0, 1, 0],
[1, -0.2, 0],
],
cell=[6, 6, 6],
symbols="OHH",
)

cell = torch.from_numpy(structure.cell.array).to(device=device, dtype=dtype)
positions = torch.from_numpy(structure.positions).to(device=device, dtype=dtype)

# %%
# The key step is to create a "fourth site" based on the O positions
# and use it in the ``interpolate`` step.

charges = torch.tensor(
[[-1.0], [0.5], [0.5]],
dtype=dtype,
device=device,
)

positions.requires_grad_(True)
charges.requires_grad_(True)
cell.requires_grad_(True)

positions_4site = torch.vstack(
[
((positions[1::3] + positions[2::3]) * 0.5 + positions[0::3] * 3) / 4,
positions[1::3],
positions[2::3],
]
)

ns = torch.tensor([5, 5, 5])
interpolator = torchpme.lib.MeshInterpolator(
cell=cell, ns_mesh=ns, interpolation_nodes=3, method="Lagrange"
)
interpolator.compute_weights(positions_4site)
mesh = interpolator.points_to_mesh(charges)

value = (mesh**2).sum()

# %%
# The gradients can be computed by just running `backward` on the
# end result. Gradients are computed on the H and O positions.

value.backward()

print(
f"""
Position gradients:
{positions.grad.T}

Cell gradients:
{cell.grad}

Charges gradients:
{charges.grad.T}
"""
)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading
Loading