From e6f9f88a23673108137e9327df3877ba57c0887e Mon Sep 17 00:00:00 2001 From: "Matthew W. Thompson" Date: Wed, 11 Dec 2024 12:37:31 -0600 Subject: [PATCH 1/2] PERF: Cache some charge increment calculations --- openff/interchange/smirnoff/_nonbonded.py | 28 ++++++++++------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/openff/interchange/smirnoff/_nonbonded.py b/openff/interchange/smirnoff/_nonbonded.py index 950cfe92..0295fa03 100644 --- a/openff/interchange/smirnoff/_nonbonded.py +++ b/openff/interchange/smirnoff/_nonbonded.py @@ -56,17 +56,13 @@ _ZERO_CHARGE = Quantity(0.0, unit.elementary_charge) -@unit.wraps( - ret=unit.elementary_charge, - args=(unit.elementary_charge, unit.elementary_charge), - strict=True, -) +@functools.lru_cache(None) def _add_charges( - charge1: "Quantity", - charge2: "Quantity", + charge1: float, + charge2: float, ) -> "Quantity": """Add two charges together.""" - return charge1 + charge2 + return Quantity(charge1 + charge2, "elementary_charge") def _upconvert_vdw_handler(vdw_handler: vdWHandler): @@ -358,8 +354,8 @@ def _get_charges( orientation_atom_index = topology_key.orientation_atom_indices[i] charges[orientation_atom_index] = _add_charges( - charges.get(orientation_atom_index, _ZERO_CHARGE), - increment, + charges.get(orientation_atom_index, _ZERO_CHARGE).m, + increment.m, ) elif parameter_key == "charge": @@ -374,8 +370,8 @@ def _get_charges( "ExternalSource", ): charges[atom_index] = _add_charges( - charges.get(atom_index, _ZERO_CHARGE), - parameter_value, + charges.get(atom_index, _ZERO_CHARGE).m, + parameter_value.m, ) elif potential_key.associated_handler in ( # type: ignore[operator] @@ -386,8 +382,8 @@ def _get_charges( # There should be a better way to do this. charges[atom_index] = _add_charges( - charges.get(atom_index, _ZERO_CHARGE), - parameter_value, + charges.get(atom_index, _ZERO_CHARGE).m, + parameter_value.m, ) else: @@ -401,8 +397,8 @@ def _get_charges( atom_index = topology_key.atom_indices[0] charges[atom_index] = _add_charges( - charges.get(atom_index, _ZERO_CHARGE), - parameter_value, + charges.get(atom_index, _ZERO_CHARGE).m, + parameter_value.m, ) logger.info( From f3037f0356cf7bfb9dd1fc04e04a3dbd1c326d74 Mon Sep 17 00:00:00 2001 From: "Matthew W. Thompson" Date: Wed, 15 Jan 2025 13:33:13 -0600 Subject: [PATCH 2/2] Fix --- openff/interchange/smirnoff/_nonbonded.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/openff/interchange/smirnoff/_nonbonded.py b/openff/interchange/smirnoff/_nonbonded.py index ea4cefc7..fd478025 100644 --- a/openff/interchange/smirnoff/_nonbonded.py +++ b/openff/interchange/smirnoff/_nonbonded.py @@ -53,7 +53,7 @@ LibraryChargeHandler, ] -_ZERO_CHARGE = Quantity(0.0, unit.elementary_charge) +_ZERO_CHARGE = Quantity(0.0, "elementary_charge") @functools.lru_cache(None) @@ -395,7 +395,10 @@ def _get_charges( atom_index = topology_key.atom_indices[0] - charges[atom_index] = charges.get(atom_index, 0.0) + parameter_value.m + charges[atom_index] = _add_charges( + charges.get(atom_index, _ZERO_CHARGE).m, + parameter_value.m, + ) logger.info( "Charge section ChargeIncrementModel, applying charge increment from atom " # type: ignore[union-attr]