diff --git a/gufe/ligandnetwork.py b/gufe/ligandnetwork.py index 3390f377..4db451da 100644 --- a/gufe/ligandnetwork.py +++ b/gufe/ligandnetwork.py @@ -197,10 +197,10 @@ def _to_rfe_alchemical_network( components: dict[str, :class:`.Component`] non-alchemical components (components that will be on both sides of a transformation) - leg_label: dict[str, list[str]] + leg_labels: dict[str, list[str]] mapping of the names for legs (the keys of this dict) to a list - of the component names. The componnent names must be the same as - as used in the ``componentns`` dict. + of the component names. The component names must be the same as + used in the ``components`` dict. protocol: :class:`.Protocol` the protocol to apply alchemical_label: str @@ -237,12 +237,9 @@ def sys_from_dict(component): else: name = "" - mapping: dict[str, gufe.ComponentMapping] = { - alchemical_label: edge, - } - transformation = gufe.Transformation(sysA, sysB, protocol, - mapping, name) + mapping=edge, + name=name) transformations.append(transformation) diff --git a/gufe/protocols/protocol.py b/gufe/protocols/protocol.py index 0064c0b3..4bf7f455 100644 --- a/gufe/protocols/protocol.py +++ b/gufe/protocols/protocol.py @@ -135,15 +135,16 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[dict[str, ComponentMapping]] = None, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], extends: Optional[ProtocolDAGResult] = None, ) -> list[ProtocolUnit]: """Method to override in custom :class:`Protocol` subclasses. - This method should take two `ChemicalSystem`s, and optionally a - dict mapping string to ``ComponentMapping``, and prepare a collection of ``ProtocolUnit`` instances - that when executed in order give sufficient information to estimate the - free energy difference between those two `ChemicalSystem`s. + This method should take two `ChemicalSystem`s, and optionally one or + more ``ComponentMapping`` objects, and prepare a collection of + ``ProtocolUnit`` instances that when executed in order give sufficient + information to estimate the free energy difference between those two + `ChemicalSystem`s. This method should return a list of `ProtocolUnit` instances. For an instance in which another `ProtocolUnit` is given as a parameter @@ -170,7 +171,7 @@ def create( *, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Union[dict[str, ComponentMapping], None], + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], extends: Optional[ProtocolDAGResult] = None, name: Optional[str] = None, transformation_key: Optional[GufeKey] = None @@ -192,9 +193,9 @@ def create( The starting `ChemicalSystem` for the transformation. stateB : ChemicalSystem The ending `ChemicalSystem` for the transformation. - mapping : Optional[dict[str, ComponentMapping]] + mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] Mappings of e.g. atoms between a labelled component in the - stateA and stateB `ChemicalSystem` . + stateA and stateB `ChemicalSystem` . extends : Optional[ProtocolDAGResult] If provided, then the `ProtocolDAG` produced will start from the end state of the given `ProtocolDAGResult`. This allows for diff --git a/gufe/tests/test_ligand_network.py b/gufe/tests/test_ligand_network.py index f6f5972a..a81a0888 100644 --- a/gufe/tests/test_ligand_network.py +++ b/gufe/tests/test_ligand_network.py @@ -343,8 +343,8 @@ def test_to_rbfe_alchemical_network( assert compsA.get('protein') == compsB.get('protein') assert compsA.get('cofactor') == compsB.get('cofactor') - assert list(edge.mapping) == ['ligand'] - assert edge.mapping['ligand'] in real_molecules_network.edges + assert isinstance(edge.mapping, gufe.ComponentMapping) + assert edge.mapping in real_molecules_network.edges def test_to_rbfe_alchemical_network_autoname_false( self, diff --git a/gufe/tests/test_protocol.py b/gufe/tests/test_protocol.py index f700fa45..a3ac0a44 100644 --- a/gufe/tests/test_protocol.py +++ b/gufe/tests/test_protocol.py @@ -112,7 +112,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[dict[str, ComponentMapping]] = None, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]]=None, extends: Optional[ProtocolDAGResult] = None, ) -> List[ProtocolUnit]: @@ -172,7 +172,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[dict[str, ComponentMapping]] = None, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]]=None, extends: Optional[ProtocolDAGResult] = None, ) -> list[ProtocolUnit]: @@ -513,7 +513,7 @@ def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[dict[str, ComponentMapping]] = None, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None, extends: Optional[ProtocolDAGResult] = None, ) -> List[ProtocolUnit]: return [NoDepUnit(settings=self.settings, diff --git a/gufe/tests/test_protocoldag.py b/gufe/tests/test_protocoldag.py index a4a08a7b..9454eda0 100644 --- a/gufe/tests/test_protocoldag.py +++ b/gufe/tests/test_protocoldag.py @@ -53,9 +53,9 @@ def _default_settings(cls): def _defaults(cls): return {} - def _create(self, stateA, stateB, mapping=None, extends=None) -> list[gufe.ProtocolUnit]: + def _create(self, stateA, stateB, mapping, extends=None) -> list[gufe.ProtocolUnit]: return [ - WriterUnit(identity=i) for i in range(self.settings.n_repeats) # type: ignore + WriterUnit(identity=i) for i in range(self.settings.n_repeats) # type: ignore ] def _gather(self, results): @@ -69,7 +69,7 @@ def writefile_dag(): p = WriterProtocol(settings=WriterProtocol.default_settings()) - return p.create(stateA=s1, stateB=s2, mapping={}) + return p.create(stateA=s1, stateB=s2, mapping=[]) @pytest.mark.parametrize('keep_shared', [False, True]) diff --git a/gufe/transformations/transformation.py b/gufe/transformations/transformation.py index f067faa1..bb8a6c08 100644 --- a/gufe/transformations/transformation.py +++ b/gufe/transformations/transformation.py @@ -1,7 +1,7 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe -from typing import Optional, Iterable +from typing import Optional, Iterable, Union import json from ..tokenization import GufeTokenizable, JSON_HANDLER @@ -17,16 +17,20 @@ class Transformation(GufeTokenizable): Connects two :class:`ChemicalSystem` objects, with directionality. """ + _stateA: ChemicalSystem + _stateB: ChemicalSystem + _name: Optional[str] + _mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] + _protocol: Protocol def __init__( self, stateA: ChemicalSystem, stateB: ChemicalSystem, protocol: Protocol, - mapping: Optional[dict[str, ComponentMapping]] = None, + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None, name: Optional[str] = None, ): - self._stateA = stateA self._stateB = stateB self._mapping = mapping @@ -64,10 +68,8 @@ def protocol(self) -> Protocol: return self._protocol @property - def mapping(self) -> Optional[dict[str, ComponentMapping]]: - """ - Mapping of e.g. atoms between ``stateA`` and ``stateB``. - """ + def mapping(self) -> Optional[Union[ComponentMapping, list[ComponentMapping]]]: + """The mappings relevant for this Transformation""" return self._mapping @property