From 65883549894e57d6b12bf9d409e4c6443b079c4c Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Sun, 6 Oct 2024 14:19:06 -0700 Subject: [PATCH 01/21] Created the ConnectivityStrategy class and added tests Initial implementation of _propose, using the mean connectivity of the edges' (transformations) nodes as unnormalized weights. The weights have yet to be normalized and the tests are just checking that the means are correct. --- src/stratocaster/strategies/__init__.py | 3 ++ src/stratocaster/strategies/connectivity.py | 31 ++++++++++++++++ .../tests/test_connectivity_strategy.py | 37 +++++++++++++++++++ 3 files changed, 71 insertions(+) create mode 100644 src/stratocaster/strategies/__init__.py create mode 100644 src/stratocaster/strategies/connectivity.py create mode 100644 src/stratocaster/tests/test_connectivity_strategy.py diff --git a/src/stratocaster/strategies/__init__.py b/src/stratocaster/strategies/__init__.py new file mode 100644 index 0000000..592b5e5 --- /dev/null +++ b/src/stratocaster/strategies/__init__.py @@ -0,0 +1,3 @@ +from stratocaster.strategies.connectivity import ConnectivityStrategy + +__all__ = ["ConnectivityStrategy"] diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py new file mode 100644 index 0000000..87be326 --- /dev/null +++ b/src/stratocaster/strategies/connectivity.py @@ -0,0 +1,31 @@ +from gufe import AlchemicalNetwork, ProtocolResult + +from stratocaster.base import Strategy, StrategyResult +from stratocaster.base.models import StrategySettings + + +class ConnectivityStrategy(Strategy): + + def _propose( + self, + alchemical_network: AlchemicalNetwork, + protocol_results: list[ProtocolResult], + ) -> StrategyResult: + + alchemical_network_mdg = alchemical_network.graph + weights = {} + + for state_a, state_b in alchemical_network_mdg.edges(): + num_neighbors_a = len(list(alchemical_network_mdg.neighbors(state_a))) + num_neighbors_b = len(list(alchemical_network_mdg.neighbors(state_b))) + transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[ + 0 + ]["object"].key + weights[transformation_key] = (num_neighbors_a + num_neighbors_b) / 2 + + results = StrategyResult(weights=weights) + return results + + @classmethod + def _default_settings(cls) -> StrategySettings: + return StrategySettings() diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py new file mode 100644 index 0000000..33135b4 --- /dev/null +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -0,0 +1,37 @@ +from gufe import AlchemicalNetwork +import pytest + +from stratocaster.strategies import ConnectivityStrategy +from stratocaster.base.strategy import StrategyResult +from stratocaster.base.models import StrategySettings + + +from gufe.tests.test_protocol import DummyProtocol +from gufe.tests.conftest import ( + benzene_variants_star_map, + benzene, + benzene_modifications, + toluene, + phenol, + benzonitrile, + anisole, + benzaldehyde, + styrene, + prot_comp, + solv_comp, + PDB_181L_path, +) + + +@pytest.fixture +def default_strategy(): + _settings = ConnectivityStrategy._default_settings() + return ConnectivityStrategy(_settings) + + +def test_propose( + default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork +): + proposal: StrategyResult = default_strategy.propose(benzene_variants_star_map, []) + + assert all([weight == 3 for weight in proposal._weights.values()]) From 975211bbe0eca36f6571e9beb92261225beb1fdb Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 16 Oct 2024 10:51:39 -0700 Subject: [PATCH 02/21] Added documentation base --- .gitignore | 5 +-- devtools/conda-envs/docs.yaml | 11 +++++++ devtools/conda-envs/{test.yml => test.yaml} | 0 docs/Makefile | 20 ++++++++++++ docs/api.rst | 3 ++ docs/conf.py | 27 ++++++++++++++++ docs/developer_guide.rst | 2 ++ docs/getting_started.rst | 2 ++ docs/index.rst | 16 ++++++++++ docs/installation.rst | 2 ++ docs/make.bat | 35 +++++++++++++++++++++ docs/user_guide.rst | 2 ++ 12 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 devtools/conda-envs/docs.yaml rename devtools/conda-envs/{test.yml => test.yaml} (100%) create mode 100644 docs/Makefile create mode 100644 docs/api.rst create mode 100644 docs/conf.py create mode 100644 docs/developer_guide.rst create mode 100644 docs/getting_started.rst create mode 100644 docs/index.rst create mode 100644 docs/installation.rst create mode 100644 docs/make.bat create mode 100644 docs/user_guide.rst diff --git a/.gitignore b/.gitignore index 46bbf13..1d40d00 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ -/stratocaster.egg-info/ -/stratocaster/__pycache__/ +/src/stratocaster/stratocaster.egg-info/ +/src/stratocaster/__pycache__/ +docs/_build/ diff --git a/devtools/conda-envs/docs.yaml b/devtools/conda-envs/docs.yaml new file mode 100644 index 0000000..7e4fcd9 --- /dev/null +++ b/devtools/conda-envs/docs.yaml @@ -0,0 +1,11 @@ +name: stratocaster-docs +channels: + - conda-forge + +dependencies: + - python>=3.9 + - gufe>=1.0.0 + - sphinx + + - pip: + - git+https://github.com/OpenFreeEnergy/ofe-sphinx-theme@a45f3edd5bc3e973c1a01b577c71efa1b62a65d6 diff --git a/devtools/conda-envs/test.yml b/devtools/conda-envs/test.yaml similarity index 100% rename from devtools/conda-envs/test.yml rename to devtools/conda-envs/test.yaml diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d4bb2cb --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..9c16f1b --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,3 @@ +API Reference +============= + diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..1e54a51 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,27 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'stratocaster' +copyright = '2024, Ian Kenney' +author = 'Ian Kenney' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [] + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'ofe_sphinx_theme' +html_static_path = ['_static'] diff --git a/docs/developer_guide.rst b/docs/developer_guide.rst new file mode 100644 index 0000000..cf18c40 --- /dev/null +++ b/docs/developer_guide.rst @@ -0,0 +1,2 @@ +Developer Guide +=============== diff --git a/docs/getting_started.rst b/docs/getting_started.rst new file mode 100644 index 0000000..3ce1d78 --- /dev/null +++ b/docs/getting_started.rst @@ -0,0 +1,2 @@ +Getting Started +=============== diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..aa24fb3 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,16 @@ +stratocaster +============ + +The stratocaster package is complimentary to gufe and provides suggestions, via Strategies, for optimally executing Transformation Protocols defined in AlchemicalNetworks. + +This library includes a set of Strategy implementations as well as base classes to facilitate the creation of custom Strategy implementations. + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + installation + getting_started + user_guide + developer_guide + api diff --git a/docs/installation.rst b/docs/installation.rst new file mode 100644 index 0000000..11e4437 --- /dev/null +++ b/docs/installation.rst @@ -0,0 +1,2 @@ +Installation +============ diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..32bb245 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/user_guide.rst b/docs/user_guide.rst new file mode 100644 index 0000000..415d843 --- /dev/null +++ b/docs/user_guide.rst @@ -0,0 +1,2 @@ +User guide +========== From 4620ed60f99a309e50c9675da31b72fbf7c038d3 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 29 Oct 2024 14:47:03 -0700 Subject: [PATCH 03/21] Implemented ConnectivityStrategySettings and weight penalty --- src/stratocaster/base/models.py | 4 +- src/stratocaster/base/strategy.py | 9 ++- src/stratocaster/strategies/connectivity.py | 74 +++++++++++++++++++-- 3 files changed, 74 insertions(+), 13 deletions(-) diff --git a/src/stratocaster/base/models.py b/src/stratocaster/base/models.py index 1718c93..0e9c0a7 100644 --- a/src/stratocaster/base/models.py +++ b/src/stratocaster/base/models.py @@ -2,6 +2,4 @@ class StrategySettings(SettingsBaseModel): - - def __init__(self): - normalize_weights: bool = True + pass diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index b64e533..e69f154 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -1,9 +1,8 @@ import abc from typing import Self -from gufe.tokenization import GufeTokenizable -from gufe import AlchemicalNetwork -from gufe.protocols import ProtocolResult +from gufe.tokenization import GufeTokenizable, GufeKey +from gufe import AlchemicalNetwork, ProtocolResult from .models import StrategySettings @@ -52,13 +51,13 @@ def _default_settings(cls) -> StrategySettings: def _propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: list[ProtocolResult], + protocol_results: dict[GufeKey, ProtocolResult], ) -> StrategyResult: raise NotImplementedError def propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: list[ProtocolResult], + protocol_results: dict[GufeKey, ProtocolResult], ) -> StrategyResult: return self._propose(alchemical_network, protocol_results) diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py index 87be326..c463ba7 100644 --- a/src/stratocaster/strategies/connectivity.py +++ b/src/stratocaster/strategies/connectivity.py @@ -1,27 +1,91 @@ +from typing import Optional from gufe import AlchemicalNetwork, ProtocolResult +from gufe.tokenization import GufeKey from stratocaster.base import Strategy, StrategyResult from stratocaster.base.models import StrategySettings +from pydantic import field_validator + + +class ConnectivityStrategySettings(StrategySettings): + + decay_rate: float = 0.5 + cutoff: Optional[float] = None + max_runs: Optional[int] = None + + @field_validator("cutoff", "decay_rate") + def validate_cutoff(cls, value): + if not (0 < value < 1): + raise ValueError("value must be between 0 and 1") + return value + class ConnectivityStrategy(Strategy): + def _exponential_decay_scaling(self, number_of_results: int, decay_rate: float): + return decay_rate**number_of_results + def _propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: list[ProtocolResult], + protocol_results: dict[GufeKey, ProtocolResult], ) -> StrategyResult: + """Propose `Transformation` weight recommendations based on high connectivity nodes. + + Parameters + ---------- + alchemical_network: AlchemicalNetwork + protocol_results: dict[GufeKey, ProtocolResult] + A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork` + and whose values are the `ProtocolResult`s for those `Transformation`s. + + Returns + ------- + StrategyResult + A `StrategyResult` containing the proposed `Transformation` weights. + + """ + + settings = self._settings alchemical_network_mdg = alchemical_network.graph - weights = {} + weights: dict[GufeKey, Optional[float]] = {} for state_a, state_b in alchemical_network_mdg.edges(): - num_neighbors_a = len(list(alchemical_network_mdg.neighbors(state_a))) - num_neighbors_b = len(list(alchemical_network_mdg.neighbors(state_b))) + num_neighbors_a = alchemical_network_mdg.degree(state_a) + num_neighbors_b = alchemical_network_mdg.degree(state_b) + + # linter-satisfying assertion + assert isinstance(num_neighbors_a, int) and isinstance(num_neighbors_b, int) + transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[ 0 ]["object"].key - weights[transformation_key] = (num_neighbors_a + num_neighbors_b) / 2 + + match (protocol_results.get(transformation_key)): + case None: + transformation_num_components = 0 + case pr: + transformation_num_components = pr.num_components + + scaling_factor = self._exponential_decay_scaling( + transformation_num_components, settings["decay_rate"] + ) + weight = scaling_factor * (num_neighbors_a + num_neighbors_b) / 2 + + match (settings.get("max_runs", None), settings.get("cutoff")): + case (None, cutoff): + if weight < cutoff: + weight = None + case (max_runs, None): + if transformation_num_components >= max_runs: + weight = None + case (max_runs, cutoff): + if weight < cutoff or transformation_num_components >= max_runs: + weight = None + + weights[transformation_key] = weight results = StrategyResult(weights=weights) return results From 1e1dbaa1bf7757a812ce5f80b7924213324e97c0 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 30 Oct 2024 13:24:40 -0700 Subject: [PATCH 04/21] Added more model validation --- src/stratocaster/strategies/connectivity.py | 46 +++++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py index c463ba7..71cadcd 100644 --- a/src/stratocaster/strategies/connectivity.py +++ b/src/stratocaster/strategies/connectivity.py @@ -1,25 +1,55 @@ -from typing import Optional from gufe import AlchemicalNetwork, ProtocolResult from gufe.tokenization import GufeKey from stratocaster.base import Strategy, StrategyResult from stratocaster.base.models import StrategySettings -from pydantic import field_validator +from pydantic import field_validator, Field, model_validator class ConnectivityStrategySettings(StrategySettings): - decay_rate: float = 0.5 - cutoff: Optional[float] = None - max_runs: Optional[int] = None - - @field_validator("cutoff", "decay_rate") + decay_rate: float = Field( + 0.5, description="decay rate of the exponential decay penalty factor" + ) + cutoff: float | None = Field( + default=None, + description="unnormalized weight cutoff used for termination condition", + ) + max_runs: int | None = Field( + default=None, + description="the upper limit of protocol DAG results needed before a transformation is no longer weighed", + ) + + @field_validator("cutoff") def validate_cutoff(cls, value): + if value is not None: + if not (0 < value < 1): + raise ValueError("`cutoff` must be between 0 and 1") + return value + + @field_validator("decay_rate") + def validate_decay_rate(cls, value): if not (0 < value < 1): - raise ValueError("value must be between 0 and 1") + raise ValueError("`decay_rate` must be between 0 and 1") return value + + @field_validator("max_runs") + def validate_max_runs(cls, value): + if value is not None: + if not value >= 1: + raise ValueError("`max_runs` must be greater than or equal to 1") + return value + + @model_validator(mode="before") + def check_cutoff_or_max_runs(cls, values): + max_runs, cutoff = values.get("max_runs"), values.get("cutoff") + + if max_runs is None and cutoff is None: + raise ValueError("At least one of `max_runs` or `cutoff` must be set") + return values + class ConnectivityStrategy(Strategy): From 58f0cc4155e53749af5b3d49844262c6cb712dff Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 30 Oct 2024 13:26:26 -0700 Subject: [PATCH 05/21] Change property name from num_components to n_protocol_dag_results --- src/stratocaster/strategies/connectivity.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py index 71cadcd..3e0106c 100644 --- a/src/stratocaster/strategies/connectivity.py +++ b/src/stratocaster/strategies/connectivity.py @@ -33,7 +33,7 @@ def validate_decay_rate(cls, value): if not (0 < value < 1): raise ValueError("`decay_rate` must be between 0 and 1") return value - + @field_validator("max_runs") def validate_max_runs(cls, value): if value is not None: @@ -49,7 +49,7 @@ def check_cutoff_or_max_runs(cls, values): raise ValueError("At least one of `max_runs` or `cutoff` must be set") return values - + class ConnectivityStrategy(Strategy): @@ -80,7 +80,7 @@ def _propose( settings = self._settings alchemical_network_mdg = alchemical_network.graph - weights: dict[GufeKey, Optional[float]] = {} + weights: dict[GufeKey, float | None] = {} for state_a, state_b in alchemical_network_mdg.edges(): num_neighbors_a = alchemical_network_mdg.degree(state_a) @@ -95,12 +95,12 @@ def _propose( match (protocol_results.get(transformation_key)): case None: - transformation_num_components = 0 + transformation_n_protcol_dag_results = 0 case pr: - transformation_num_components = pr.num_components + transformation_n_protcol_dag_results = pr.n_protocol_dag_results scaling_factor = self._exponential_decay_scaling( - transformation_num_components, settings["decay_rate"] + transformation_n_protcol_dag_results, settings["decay_rate"] ) weight = scaling_factor * (num_neighbors_a + num_neighbors_b) / 2 @@ -109,10 +109,13 @@ def _propose( if weight < cutoff: weight = None case (max_runs, None): - if transformation_num_components >= max_runs: + if transformation_n_protcol_dag_results >= max_runs: weight = None case (max_runs, cutoff): - if weight < cutoff or transformation_num_components >= max_runs: + if ( + weight < cutoff + or transformation_n_protcol_dag_results >= max_runs + ): weight = None weights[transformation_key] = weight From a66cb5558f5e23de9ad063d0bcb2a8bd372ed7d3 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 30 Oct 2024 13:44:48 -0700 Subject: [PATCH 06/21] Added normalization to StrategyResult In order to preserve unnormalized weights, which might be useful for debugging/have a more intuitive meaning, the StrategyResult now holds onto these values directly and normalization is handled by calling the new `StrategyResult.resolve` method. --- src/stratocaster/base/strategy.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index e69f154..a828ed6 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -9,7 +9,7 @@ class StrategyResult(GufeTokenizable): - def __init__(self, weights): + def __init__(self, weights: dict[GufeKey, float | None]): self._weights = weights @classmethod @@ -23,6 +23,21 @@ def _to_dict(self) -> dict: def _from_dict(cls, dct: dict) -> Self: return cls(**dct) + @property + def weights(self) -> dict[GufeKey, float | None]: + return self._weights + + def resolve(self) -> dict[GufeKey, float | None]: + weights = self.weights + weight_sum = sum([weight for weight in weights.values() if weight is not None]) + modified_weights = { + key: weight / weight_sum + for key, weight in weights.items() + if weight is not None + } + weights.update(modified_weights) + return weights + # TODO: docstrings class Strategy(GufeTokenizable): From a18dd825188b3587f48a9fafe17e00eb5e593931 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 5 Nov 2024 15:02:28 -0700 Subject: [PATCH 07/21] Enforce correct settings for each strategy * Strategy initialization now checks that the settings provided to the constructor are the correct type specified by the _settings_cls class attribute of any Strategy subclass. This has the effect that a Strategy author must explicitly specify what settings they are will to accept. This deviates slightly from the gufe Protocol class, which doesn't perform such a check. Tests were added to `test_strategy_abstraction.py` showing only specified StrategySettings will be allowed. * Added TypeVar for ProtocolResults in the `Strategy` base class. Any downstream use of the `propose` method will then check that it's being fed a dictionary with subclasses of ProtocolResults. I don't believe there is an easier way to do this since the ProtocolResult is embedded in a dictionary. * gufe doesn't work with pydantic v2, so I now use the pydantic v1 api. * The match statement for capturing the max_runs / cutoff settings was updated and properly catches which of these were set to None --- devtools/conda-envs/test.yml | 2 +- src/stratocaster/base/strategy.py | 20 ++++++- src/stratocaster/strategies/connectivity.py | 33 +++++----- .../tests/test_connectivity_strategy.py | 27 +++++++-- .../tests/test_strategy_abstraction.py | 60 +++++++++++++++++++ 5 files changed, 120 insertions(+), 22 deletions(-) create mode 100644 src/stratocaster/tests/test_strategy_abstraction.py diff --git a/devtools/conda-envs/test.yml b/devtools/conda-envs/test.yml index 2e895a4..cf8084c 100644 --- a/devtools/conda-envs/test.yml +++ b/devtools/conda-envs/test.yml @@ -10,4 +10,4 @@ dependencies: - pytest - pytest-xdist - pytest-cov - - coverage \ No newline at end of file + - coverage diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index a828ed6..1704faa 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -1,11 +1,13 @@ import abc -from typing import Self +from typing import Self, TypeVar from gufe.tokenization import GufeTokenizable, GufeKey from gufe import AlchemicalNetwork, ProtocolResult from .models import StrategySettings +_ProtocolResult = TypeVar("_ProtocolResult", bound=ProtocolResult) + class StrategyResult(GufeTokenizable): @@ -43,8 +45,20 @@ def resolve(self) -> dict[GufeKey, float | None]: class Strategy(GufeTokenizable): """An object that proposes the relative urgency of computing transformations within an AlchemicalNetwork.""" + _settings_cls: type[StrategySettings] + def __init__(self, settings: StrategySettings): + + # TODO better error error message + if not isinstance(settings, self._settings_cls): + raise ValueError() + self._settings = settings + super().__init__() + + @property + def settings(self) -> StrategySettings: + return self._settings @classmethod def _defaults(cls): @@ -66,13 +80,13 @@ def _default_settings(cls) -> StrategySettings: def _propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, ProtocolResult], + protocol_results: dict[GufeKey, _ProtocolResult], ) -> StrategyResult: raise NotImplementedError def propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, ProtocolResult], + protocol_results: dict[GufeKey, _ProtocolResult], ) -> StrategyResult: return self._propose(alchemical_network, protocol_results) diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py index 3e0106c..c15bedc 100644 --- a/src/stratocaster/strategies/connectivity.py +++ b/src/stratocaster/strategies/connectivity.py @@ -4,13 +4,13 @@ from stratocaster.base import Strategy, StrategyResult from stratocaster.base.models import StrategySettings -from pydantic import field_validator, Field, model_validator +from pydantic import validator, Field, root_validator class ConnectivityStrategySettings(StrategySettings): decay_rate: float = Field( - 0.5, description="decay rate of the exponential decay penalty factor" + default=0.5, description="decay rate of the exponential decay penalty factor" ) cutoff: float | None = Field( default=None, @@ -21,27 +21,27 @@ class ConnectivityStrategySettings(StrategySettings): description="the upper limit of protocol DAG results needed before a transformation is no longer weighed", ) - @field_validator("cutoff") + @validator("cutoff") def validate_cutoff(cls, value): if value is not None: if not (0 < value < 1): raise ValueError("`cutoff` must be between 0 and 1") return value - @field_validator("decay_rate") + @validator("decay_rate") def validate_decay_rate(cls, value): if not (0 < value < 1): raise ValueError("`decay_rate` must be between 0 and 1") return value - @field_validator("max_runs") + @validator("max_runs") def validate_max_runs(cls, value): if value is not None: if not value >= 1: raise ValueError("`max_runs` must be greater than or equal to 1") return value - @model_validator(mode="before") + @root_validator def check_cutoff_or_max_runs(cls, values): max_runs, cutoff = values.get("max_runs"), values.get("cutoff") @@ -53,6 +53,8 @@ def check_cutoff_or_max_runs(cls, values): class ConnectivityStrategy(Strategy): + _settings_cls = ConnectivityStrategySettings + def _exponential_decay_scaling(self, number_of_results: int, decay_rate: float): return decay_rate**number_of_results @@ -74,10 +76,12 @@ def _propose( ------- StrategyResult A `StrategyResult` containing the proposed `Transformation` weights. - """ - settings = self._settings + settings = self.settings + + # keep the type checker happy + assert isinstance(settings, ConnectivityStrategySettings) alchemical_network_mdg = alchemical_network.graph weights: dict[GufeKey, float | None] = {} @@ -97,21 +101,22 @@ def _propose( case None: transformation_n_protcol_dag_results = 0 case pr: + assert isinstance(pr, ProtocolResult) transformation_n_protcol_dag_results = pr.n_protocol_dag_results scaling_factor = self._exponential_decay_scaling( - transformation_n_protcol_dag_results, settings["decay_rate"] + transformation_n_protcol_dag_results, settings.decay_rate ) weight = scaling_factor * (num_neighbors_a + num_neighbors_b) / 2 - match (settings.get("max_runs", None), settings.get("cutoff")): - case (None, cutoff): + match (settings.max_runs, settings.cutoff): + case (None, cutoff) if cutoff is not None: if weight < cutoff: weight = None - case (max_runs, None): + case (max_runs, None) if max_runs is not None: if transformation_n_protcol_dag_results >= max_runs: weight = None - case (max_runs, cutoff): + case (max_runs, cutoff) if max_runs is not None and cutoff is not None: if ( weight < cutoff or transformation_n_protcol_dag_results >= max_runs @@ -125,4 +130,4 @@ def _propose( @classmethod def _default_settings(cls) -> StrategySettings: - return StrategySettings() + return ConnectivityStrategySettings(max_runs=3) diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 33135b4..2f62aa4 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -6,7 +6,7 @@ from stratocaster.base.models import StrategySettings -from gufe.tests.test_protocol import DummyProtocol +from gufe.tests.test_protocol import DummyProtocol, DummyProtocolResult from gufe.tests.conftest import ( benzene_variants_star_map, benzene, @@ -21,6 +21,7 @@ solv_comp, PDB_181L_path, ) +from gufe.tokenization import GufeKey @pytest.fixture @@ -29,9 +30,27 @@ def default_strategy(): return ConnectivityStrategy(_settings) -def test_propose( +def test_propose_no_results( default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork ): - proposal: StrategyResult = default_strategy.propose(benzene_variants_star_map, []) + proposal: StrategyResult = default_strategy.propose(benzene_variants_star_map, {}) - assert all([weight == 3 for weight in proposal._weights.values()]) + assert all([weight == 3.5 for weight in proposal._weights.values()]) + assert 1 == sum( + weight for weight in proposal.resolve().values() if weight is not None + ) + + +def test_propose_previous_results( + default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork +): + + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key + result = DummyProtocolResult( + n_protocol_dag_results=1, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + + default_strategy.propose(benzene_variants_star_map, result_data) diff --git a/src/stratocaster/tests/test_strategy_abstraction.py b/src/stratocaster/tests/test_strategy_abstraction.py new file mode 100644 index 0000000..a9b347b --- /dev/null +++ b/src/stratocaster/tests/test_strategy_abstraction.py @@ -0,0 +1,60 @@ +import pytest + +from stratocaster.base import StrategySettings, Strategy +from stratocaster.base.strategy import StrategyResult + +from gufe import AlchemicalNetwork, ProtocolResult +from gufe.tokenization import GufeKey + + +class StrategyASettings(StrategySettings): + pass + + +class StrategyBSettings(StrategySettings): + pass + + +class StrategyNoSettings(Strategy): + + @classmethod + def _default_settings(cls) -> StrategySettings: + return cls._settings_cls() + + def _propose( + self, + alchemical_network: AlchemicalNetwork, + protocol_results: dict[GufeKey, ProtocolResult], + ) -> StrategyResult: + return StrategyResult({}) + + +class StrategyA(StrategyNoSettings): + _settings_cls = StrategyASettings + + +class StrategyB(StrategyNoSettings): + _settings_cls = StrategyBSettings + + +@pytest.mark.parametrize( + ("strategy", "settings"), + ((StrategyA, StrategyBSettings), (StrategyB, StrategyASettings)), +) +def test_incorrect_strategy_settings_passed(strategy, settings): + try: + strategy(settings()) + assert False + except ValueError as e: + pass + + +@pytest.mark.parametrize( + ("strategy", "settings"), + ((StrategyA, StrategyASettings), (StrategyB, StrategyBSettings)), +) +def test_correct_strategy_settings_passed(strategy, settings): + strat_settings = settings() + strat = strategy(strat_settings) + + assert strat._settings_cls == settings From 6964024d830624f6110fb4072a6132c01379ee93 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 13 Nov 2024 06:57:27 -0700 Subject: [PATCH 08/21] Improved error messages around invalid settings arguments --- src/stratocaster/base/strategy.py | 10 ++++++++-- src/stratocaster/tests/test_strategy_abstraction.py | 11 +++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index 1704faa..542ea07 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -49,9 +49,15 @@ class Strategy(GufeTokenizable): def __init__(self, settings: StrategySettings): - # TODO better error error message + if not hasattr(self.__class__, "_settings_cls"): + raise NotImplementedError( + f"class `{self.__class__.__qualname__}` must implement the `_settings_cls` attribute." + ) + if not isinstance(settings, self._settings_cls): - raise ValueError() + raise ValueError( + f"`{self.__class__.__qualname__}` expected a `{self._settings_cls.__qualname__}` instance" + ) self._settings = settings super().__init__() diff --git a/src/stratocaster/tests/test_strategy_abstraction.py b/src/stratocaster/tests/test_strategy_abstraction.py index a9b347b..85bb8c2 100644 --- a/src/stratocaster/tests/test_strategy_abstraction.py +++ b/src/stratocaster/tests/test_strategy_abstraction.py @@ -42,11 +42,8 @@ class StrategyB(StrategyNoSettings): ((StrategyA, StrategyBSettings), (StrategyB, StrategyASettings)), ) def test_incorrect_strategy_settings_passed(strategy, settings): - try: + with pytest.raises(ValueError): strategy(settings()) - assert False - except ValueError as e: - pass @pytest.mark.parametrize( @@ -58,3 +55,9 @@ def test_correct_strategy_settings_passed(strategy, settings): strat = strategy(strat_settings) assert strat._settings_cls == settings + + +def test_no_settings_implemented(): + + with pytest.raises(NotImplementedError): + StrategyNoSettings(StrategyASettings()) From 9572c38abc602340b3787124d85306c53ecaf0a2 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 13 Nov 2024 14:21:17 -0700 Subject: [PATCH 09/21] Connectivity cutoff now applies to unnormalized weights --- src/stratocaster/base/models.py | 1 + src/stratocaster/base/strategy.py | 1 + src/stratocaster/strategies/connectivity.py | 6 +- .../tests/test_connectivity_strategy.py | 62 ++++++++++++++++++- 4 files changed, 65 insertions(+), 5 deletions(-) diff --git a/src/stratocaster/base/models.py b/src/stratocaster/base/models.py index 0e9c0a7..e443b75 100644 --- a/src/stratocaster/base/models.py +++ b/src/stratocaster/base/models.py @@ -1,5 +1,6 @@ from gufe.settings.models import SettingsBaseModel +# TODO: docstrings class StrategySettings(SettingsBaseModel): pass diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index 542ea07..1e41506 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -9,6 +9,7 @@ _ProtocolResult = TypeVar("_ProtocolResult", bound=ProtocolResult) +# TODO: docstrings class StrategyResult(GufeTokenizable): def __init__(self, weights: dict[GufeKey, float | None]): diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py index c15bedc..d5fe343 100644 --- a/src/stratocaster/strategies/connectivity.py +++ b/src/stratocaster/strategies/connectivity.py @@ -7,6 +7,7 @@ from pydantic import validator, Field, root_validator +# TODO: docstrings class ConnectivityStrategySettings(StrategySettings): decay_rate: float = Field( @@ -24,8 +25,8 @@ class ConnectivityStrategySettings(StrategySettings): @validator("cutoff") def validate_cutoff(cls, value): if value is not None: - if not (0 < value < 1): - raise ValueError("`cutoff` must be between 0 and 1") + if not (0 < value): + raise ValueError("`cutoff` must be greater than 0") return value @validator("decay_rate") @@ -51,6 +52,7 @@ def check_cutoff_or_max_runs(cls, values): return values +# TODO: docstrings class ConnectivityStrategy(Strategy): _settings_cls = ConnectivityStrategySettings diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 2f62aa4..918b0fa 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -1,7 +1,12 @@ +import math + from gufe import AlchemicalNetwork import pytest -from stratocaster.strategies import ConnectivityStrategy +from stratocaster.strategies.connectivity import ( + ConnectivityStrategy, + ConnectivityStrategySettings, +) from stratocaster.base.strategy import StrategyResult from stratocaster.base.models import StrategySettings @@ -49,8 +54,59 @@ def test_propose_previous_results( for transformation in benzene_variants_star_map.edges: transformation_key = transformation.key result = DummyProtocolResult( - n_protocol_dag_results=1, info=f"key: {transformation_key}" + n_protocol_dag_results=2, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + + results = default_strategy.propose(benzene_variants_star_map, result_data) + results_no_data = default_strategy.propose(benzene_variants_star_map, {}) + + # the raw weights should no longer be the same + assert results.weights != results_no_data.weights + # since each transformation had the same number of previous results, resolve + # should give back the same normalized weights + assert results.resolve() == results_no_data.resolve() + + +def test_propose_max_runs_termination( + default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork +): + + max_runs = default_strategy.settings.max_runs + + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key + result = DummyProtocolResult( + n_protocol_dag_results=max_runs, info=f"key: {transformation_key}" ) result_data[transformation_key] = result - default_strategy.propose(benzene_variants_star_map, result_data) + results = default_strategy.propose(benzene_variants_star_map, result_data) + + # since the default strategy has a max_runs of 3, we expect all Nones + assert not [weight for weight in results.resolve().values() if weight is not None] + + +def test_propose_cutoff(benzene_variants_star_map): + + settings = ConnectivityStrategySettings(cutoff=2, decay_rate=0.5) + strategy = ConnectivityStrategy(settings) + + assert isinstance(settings.cutoff, float) + num_runs = math.floor( + math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate) + ) + print(num_runs) + + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key + result = DummyProtocolResult( + n_protocol_dag_results=num_runs + 1, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + + results = strategy.propose(benzene_variants_star_map, result_data) + + assert not [weight for weight in results.weights.values() if weight is not None] From 5f9ff78f657a7975c59990a8e1dcd6baf784cac5 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 13 Nov 2024 14:29:53 -0700 Subject: [PATCH 10/21] Removed debugging print statement --- src/stratocaster/tests/test_connectivity_strategy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 918b0fa..e7e7592 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -97,7 +97,6 @@ def test_propose_cutoff(benzene_variants_star_map): num_runs = math.floor( math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate) ) - print(num_runs) result_data: dict[GufeKey, DummyProtocolResult] = {} for transformation in benzene_variants_star_map.edges: From 924256f27fcd9a2b66d321c950d25ae3fdf3b275 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 20 Nov 2024 12:41:24 -0700 Subject: [PATCH 11/21] Increase testing coverage * Provide a variety of parameters to ConnectivityStrategySettings Termination logic within _propose depends on the combination of parameters given to ConnectivityStrategySettings. A preset batch of valid settings are provided through SETTINGS_VALID. * Test ConnectivityStrategySettings validators Add a test for sets of ConnectivityStrategySettings including both valid and invalid settings. If an Exception is expected, check that it is raised, otherwise check that the settings could be instantiated. * Simulate the ConnectivityStrategy Test that the ConnectivityStrategy terminates after a set number of iterations. The max number of iterations is set to 100. * Remove abstract method bodies from coverage reports --- .coveragerc | 4 + pyproject.toml | 1 + .../tests/test_connectivity_strategy.py | 92 ++++++++++++++++++- src/stratocaster/tests/test_strategy_base.py | 47 ++++++++++ 4 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 .coveragerc create mode 100644 src/stratocaster/tests/test_strategy_base.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..d663fad --- /dev/null +++ b/.coveragerc @@ -0,0 +1,4 @@ +[report] +exclude_lines = + @abstractmethod + @abc.abstractmethod \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 7a0f407..1a07645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ Issues = "https://github.com/OpenFreeEnergy/stratocaster/issues" [project.optional-dependencies] test = [ "pytest", + "pytest-cov", ] dev = [ "stratocaster[test]", diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index e7e7592..9290559 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -1,4 +1,5 @@ import math +from random import shuffle from gufe import AlchemicalNetwork import pytest @@ -28,6 +29,32 @@ ) from gufe.tokenization import GufeKey +SETTINGS_VALID = [(0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)] + + +@pytest.mark.parametrize( + ["decay_rate", "cutoff", "max_runs", "raises"], + [ + (0, None, None, ValueError), + (1, None, None, ValueError), + (0.5, 0, None, ValueError), + (0.5, None, 0, ValueError), + ] + + [(*vals, None) for vals in SETTINGS_VALID], # include all valid settings +) +def test_connectivity_strategy_settings(decay_rate, cutoff, max_runs, raises): + + def instantiate_settings(): + ConnectivityStrategySettings( + decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs + ) + + if raises: + with pytest.raises(raises): + instantiate_settings() + else: + instantiate_settings() + @pytest.fixture def default_strategy(): @@ -71,8 +98,9 @@ def test_propose_previous_results( def test_propose_max_runs_termination( default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork ): - + assert isinstance(default_strategy.settings, ConnectivityStrategySettings) max_runs = default_strategy.settings.max_runs + assert isinstance(max_runs, int) result_data: dict[GufeKey, DummyProtocolResult] = {} for transformation in benzene_variants_star_map.edges: @@ -88,12 +116,17 @@ def test_propose_max_runs_termination( assert not [weight for weight in results.resolve().values() if weight is not None] -def test_propose_cutoff(benzene_variants_star_map): +def test_propose_cutoff_num_runs_predictioned_termination(benzene_variants_star_map): + """We can predict the number of runs needed to terminate with a given cutoff. + + Each edge in benzene_variants_star_map has a base weight of 3.5. + """ settings = ConnectivityStrategySettings(cutoff=2, decay_rate=0.5) strategy = ConnectivityStrategy(settings) assert isinstance(settings.cutoff, float) + num_runs = math.floor( math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate) ) @@ -109,3 +142,58 @@ def test_propose_cutoff(benzene_variants_star_map): results = strategy.propose(benzene_variants_star_map, result_data) assert not [weight for weight in results.weights.values() if weight is not None] + + +@pytest.mark.parametrize(["decay_rate", "cutoff", "max_runs"], SETTINGS_VALID) +def test_simulated_termination( + default_strategy, benzene_variants_star_map, decay_rate, cutoff, max_runs +): + + settings = ConnectivityStrategySettings( + decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs + ) + default_strategy = ConnectivityStrategy(settings) + + def counts_to_result_data(counts_dict): + result_data = {} + for transformation_key, count in counts_dict.items(): + result = DummyProtocolResult( + n_protocol_dag_results=count, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + return result_data + + def shuffle_take_n(keys_list, n): + shuffle(keys_list) + return keys_list[:n] + + # initial transforms + transformation_counts = { + transformation.key: 0 for transformation in benzene_variants_star_map.edges + } + + max_iterations = 100 + current_iteration = 0 + while current_iteration <= max_iterations: + + if current_iteration == max_iterations: + raise RuntimeError( + f"Strategy did not terminate in {max_iterations} iterations " + ) + + result_data = counts_to_result_data(transformation_counts) + proposal = default_strategy.propose(benzene_variants_star_map, result_data) + + # get random transformations from those with a non-None weight + resolved_keys = shuffle_take_n( + [key for key, weight in proposal.resolve().items() if weight is not None], 5 + ) + + if resolved_keys: + # pretend we ran each of the randomly selected protocols + for key in resolved_keys: + transformation_counts[key] += 1 + # if we got an empty list back, there are not more protocols to run + else: + break + current_iteration += 1 diff --git a/src/stratocaster/tests/test_strategy_base.py b/src/stratocaster/tests/test_strategy_base.py new file mode 100644 index 0000000..118c622 --- /dev/null +++ b/src/stratocaster/tests/test_strategy_base.py @@ -0,0 +1,47 @@ +from stratocaster.base import Strategy, StrategySettings, StrategyResult +from gufe import AlchemicalNetwork, ProtocolResult +from gufe.tokenization import GufeKey + + +class TestStrategyResult: + + result = StrategyResult( + { + GufeKey("MyTransformation-ABC123"): 1, + GufeKey("MyTransformation-321CBA"): None, + GufeKey("MyOtherTransformation-789xyz"): 10, + } + ) + + def test_dict_roundtrip(self): + assert StrategyResult.from_dict(self.result.to_dict()) == self.result + + +class DummyStrategySettings(StrategySettings): + pass + + +class DummyStrategy(Strategy): + + _settings_cls = DummyStrategySettings + + @classmethod + def _default_settings(cls) -> DummyStrategySettings: + return DummyStrategySettings() + + def _propose( + self, + alchemical_network: AlchemicalNetwork, + protocol_results: dict[GufeKey, ProtocolResult], + ): + assert alchemical_network, protocol_results + return StrategyResult({}) + + +class TestStrategy: + + strategy = DummyStrategy(DummyStrategySettings()) + + def test_dict_roundtrip(self): + strategy_dict_form = self.strategy.to_dict() + assert DummyStrategy.from_dict(strategy_dict_form) == self.strategy From f32656859807c902c28b609fb0e78ef8068a7d58 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 20 Nov 2024 13:55:41 -0700 Subject: [PATCH 12/21] Test determinism for connectivity strategies Add a test for ConnectivityStrategy demonstrating deterministc proposals. Given random ProtocolResults, show that StrategyResults are the same after multiple runs. --- .../tests/test_connectivity_strategy.py | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 9290559..122643e 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -1,5 +1,5 @@ import math -from random import shuffle +from random import shuffle, randint from gufe import AlchemicalNetwork import pytest @@ -197,3 +197,35 @@ def shuffle_take_n(keys_list, n): else: break current_iteration += 1 + + +def test_deterministic( + default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork +): + + settings = default_strategy.settings + assert isinstance(settings, ConnectivityStrategySettings) + + max_runs = settings.max_runs + assert isinstance(max_runs, int) + + def random_runs(): + """Generate random randomized inputs for propose.""" + return { + transformation.key: DummyProtocolResult( + n_protocol_dag_results=randint(0, max_runs), + info=f"key: {transformation.key}", + ) + for transformation in benzene_variants_star_map.edges + } + + for _ in range(10): + random_protocol_results = random_runs() + proposal = default_strategy.propose( + benzene_variants_star_map, protocol_results=random_protocol_results + ) + for _ in range(3): + _proposal = default_strategy.propose( + benzene_variants_star_map, protocol_results=random_protocol_results + ) + assert _proposal == proposal From 1c821e65be11ab270313d4bf2d8bbaa85bc16edd Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 7 Jan 2025 11:23:12 -0700 Subject: [PATCH 13/21] Add gufe submodule for testing until gufev1.2.0 release Since features in the planned gufev1.2.0 release are needed by stratocaster, we need to accurately recreate an environment that gufe will provide upon installation. By using a submodule, we can install the base testing environment from the gufe environment.yml file. This will be reverted once gufev1.2.0 is available on conda-forge. This also makes it easier to react to last minute changes to gufe prior to its v1.2.0 release. --- .gitmodules | 3 +++ devtools/conda-envs/test.yml | 7 +------ lib/gufe | 1 + src/stratocaster/tests/test_connectivity_strategy.py | 1 + 4 files changed, 6 insertions(+), 6 deletions(-) create mode 100644 .gitmodules create mode 160000 lib/gufe diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..d55cb12 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "lib/gufe"] + path = lib/gufe + url = git@github.com:OpenFreeEnergy/gufe diff --git a/devtools/conda-envs/test.yml b/devtools/conda-envs/test.yml index cf8084c..0b3f125 100644 --- a/devtools/conda-envs/test.yml +++ b/devtools/conda-envs/test.yml @@ -5,9 +5,4 @@ channels: dependencies: - python>=3.9 - - gufe>=1.0.0 - - - pytest - - pytest-xdist - - pytest-cov - - coverage + - gufe>=1.2.0 \ No newline at end of file diff --git a/lib/gufe b/lib/gufe new file mode 160000 index 0000000..71a9c66 --- /dev/null +++ b/lib/gufe @@ -0,0 +1 @@ +Subproject commit 71a9c6610a9e13c8f7d588bd8309150557f104a5 diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 122643e..cb37d20 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -15,6 +15,7 @@ from gufe.tests.test_protocol import DummyProtocol, DummyProtocolResult from gufe.tests.conftest import ( benzene_variants_star_map, + benzene_variants_star_map_transformations, benzene, benzene_modifications, toluene, From 85e2b9147e12aa8a112f2ea30a869bd0f089d1ed Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 7 Jan 2025 15:24:41 -0700 Subject: [PATCH 14/21] Update docs and environment information --- devtools/conda-envs/docs.yaml | 7 +------ docs/getting_started.rst | 29 +++++++++++++++++++++++++++++ docs/index.rst | 3 ++- docs/installation.rst | 26 ++++++++++++++++++++++++++ pyproject.toml | 4 ++++ 5 files changed, 62 insertions(+), 7 deletions(-) diff --git a/devtools/conda-envs/docs.yaml b/devtools/conda-envs/docs.yaml index 7e4fcd9..e604c69 100644 --- a/devtools/conda-envs/docs.yaml +++ b/devtools/conda-envs/docs.yaml @@ -3,9 +3,4 @@ channels: - conda-forge dependencies: - - python>=3.9 - - gufe>=1.0.0 - - sphinx - - - pip: - - git+https://github.com/OpenFreeEnergy/ofe-sphinx-theme@a45f3edd5bc3e973c1a01b577c71efa1b62a65d6 + - python>=3.12 diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 3ce1d78..2069331 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -1,2 +1,31 @@ Getting Started =============== + +This guide will help you quickly get started using stratocaster. + +1. Installation +~~~~~~~~~~~~~~~ + +For installation instructions, refer to the :ref:`installation page`. + +2. Verify the installation +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Verify the installation was successful in a Python interpreter + +.. code:: python + + import statocaster + print(stratocaster.__version__) + +3. Quick-start example +~~~~~~~~~~~~~~~~~~~~~~ + +TODO + +Other resources +~~~~~~~~~~~~~~~ + +- `Source code repository `_ +- `GitHub issue tracker `_ + diff --git a/docs/index.rst b/docs/index.rst index aa24fb3..eba5348 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,13 +1,14 @@ stratocaster ============ -The stratocaster package is complimentary to gufe and provides suggestions, via Strategies, for optimally executing Transformation Protocols defined in AlchemicalNetworks. +The stratocaster library is complimentary to gufe and provides suggestions, via Strategies, for optimally executing Transformation Protocols defined in AlchemicalNetworks. This library includes a set of Strategy implementations as well as base classes to facilitate the creation of custom Strategy implementations. .. toctree:: :maxdepth: 2 :caption: Contents: + :hidden: installation getting_started diff --git a/docs/installation.rst b/docs/installation.rst index 11e4437..24bfdca 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,2 +1,28 @@ +.. _installation-label: + Installation ============ + +The only requirement for installing statocaster is a working installation of gufe with a version 1.2.0 or higher. +For general use, we recommend installing from the conda-forge channel, which will also install gufe in the process. + +conda-forge channel +~~~~~~~~~~~~~~~~~~~ + +If you use conda, stratocaster can be installed through the conda-forge channel. + +.. code:: + + conda create -n statocaster-env + conda activate stratocaster-env + conda install -c conda-forge stratocaster + +Development version +~~~~~~~~~~~~~~~~~~~ + +If you want to install the latest development version of stratocaster, you can do so using pip, provided that you have a working installation of gufe (version >=1.2.0) in your environment. + +.. code:: + + pip install git+https://github.com/OpenFreeEnergy/stratocaster.git@main + diff --git a/pyproject.toml b/pyproject.toml index 1a07645..6290519 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,10 @@ dev = [ "stratocaster[test]", "black", ] +docs = [ + "sphinx", + "ofe-sphinx-theme @ git+https://github.com/OpenFreeEnergy/ofe-sphinx-theme.git@main", +] [build-system] requires = [ From 895525e3686767b05be930da713eba56c54dcc12 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 13 Jan 2025 09:02:35 -0700 Subject: [PATCH 15/21] Extract test fixtures from gufe The data structures returned by gufe fixtures have use outside of the package testing (think example systems in docs). I've extracted the network fixture so it is returned by a regular function in the stratocaster.tests.networks module. --- src/stratocaster/tests/networks.py | 117 ++++++++++++++++++ .../tests/test_connectivity_strategy.py | 25 ++-- 2 files changed, 126 insertions(+), 16 deletions(-) create mode 100644 src/stratocaster/tests/networks.py diff --git a/src/stratocaster/tests/networks.py b/src/stratocaster/tests/networks.py new file mode 100644 index 0000000..a6aae9e --- /dev/null +++ b/src/stratocaster/tests/networks.py @@ -0,0 +1,117 @@ +""" +This file contains a modified version of code originally from the +gufe test module (conftest.py). The modifications were made to allow +easier construction of AlchemicalNetworks for development testing and +user examples. + +Original Commit SHA: 71a9c6610a9e13c8f7d588bd8309150557f104a5 +""" + +import importlib + +import gufe +from gufe.tests.test_protocol import DummyProtocol +from openff.units import unit +from rdkit import Chem + + +class BenzeneModifications: + + @staticmethod + def load_benzene_modifications(): + path = ( + importlib.resources.files("gufe.tests.data") / "benzene_modifications.sdf" + ) + supp = Chem.SDMolSupplier(str(path), removeHs=False) + return {m.GetProp("_Name"): m for m in list(supp)} + + _mod = load_benzene_modifications() + + def __class_getitem__(cls, key): + return gufe.SmallMoleculeComponent(cls._mod[key]) + + +def PDB_181L_path(): + path = importlib.resources.files("gufe.tests.data") / "181l.pdb" + return str(path) + + +def benzene_variants_star_map_transformations(): + + benzene = BenzeneModifications["benzene"] + + variants = tuple( + map( + lambda x: BenzeneModifications[x], + [ + "toluene", + "phenol", + "benzonitrile", + "anisole", + "benzaldehyde", + "styrene", + ], + ) + ) + + solv_comp = gufe.SolventComponent( + positive_ion="K", negative_ion="Cl", ion_concentration=0.0 * unit.molar + ) + prot_comp = gufe.ProteinComponent.from_pdb_file(PDB_181L_path()) + + # define the solvent chemical systems and transformations between + # benzene and the others + solvated_ligands = {} + solvated_ligand_transformations = {} + + solvated_ligands["benzene"] = gufe.ChemicalSystem( + {"solvent": solv_comp, "ligand": benzene}, name="benzene-solvent" + ) + + for ligand in variants: + solvated_ligands[ligand.name] = gufe.ChemicalSystem( + {"solvent": solv_comp, "ligand": ligand}, name=f"{ligand.name}-solvnet" + ) + solvated_ligand_transformations[("benzene", ligand.name)] = gufe.Transformation( + solvated_ligands["benzene"], + solvated_ligands[ligand.name], + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, + ) + + # define the complex chemical systems and transformations between + # benzene and the others + solvated_complexes = {} + solvated_complex_transformations = {} + + solvated_complexes["benzene"] = gufe.ChemicalSystem( + {"protein": prot_comp, "solvent": solv_comp, "ligand": benzene}, + name="benzene-complex", + ) + + for ligand in variants: + solvated_complexes[ligand.name] = gufe.ChemicalSystem( + {"protein": prot_comp, "solvent": solv_comp, "ligand": ligand}, + name=f"{ligand.name}-complex", + ) + solvated_complex_transformations[("benzene", ligand.name)] = ( + gufe.Transformation( + solvated_complexes["benzene"], + solvated_complexes[ligand.name], + protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + mapping=None, + ) + ) + + return list(solvated_ligand_transformations.values()), list( + solvated_complex_transformations.values() + ) + + +def benzene_variants_star_map(): + solvated_ligand_transformations, solvated_complex_transformations = ( + benzene_variants_star_map_transformations() + ) + return gufe.AlchemicalNetwork( + solvated_ligand_transformations + solvated_complex_transformations + ) diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index cb37d20..552207c 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -3,6 +3,7 @@ from gufe import AlchemicalNetwork import pytest +from gufe.tests.test_protocol import DummyProtocol, DummyProtocolResult from stratocaster.strategies.connectivity import ( ConnectivityStrategy, @@ -10,24 +11,16 @@ ) from stratocaster.base.strategy import StrategyResult from stratocaster.base.models import StrategySettings +from stratocaster.tests.networks import ( + benzene_variants_star_map as _benzene_variants_star_map, +) + + +@pytest.fixture(scope="module") +def benzene_variants_star_map(): + return _benzene_variants_star_map() -from gufe.tests.test_protocol import DummyProtocol, DummyProtocolResult -from gufe.tests.conftest import ( - benzene_variants_star_map, - benzene_variants_star_map_transformations, - benzene, - benzene_modifications, - toluene, - phenol, - benzonitrile, - anisole, - benzaldehyde, - styrene, - prot_comp, - solv_comp, - PDB_181L_path, -) from gufe.tokenization import GufeKey SETTINGS_VALID = [(0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)] From b3a42db3d8ec48e5c2bec689538081ba2f8d8360 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 13 Jan 2025 09:08:54 -0700 Subject: [PATCH 16/21] Use isort for import formatting --- pyproject.toml | 4 ++++ src/stratocaster/base/__init__.py | 2 +- src/stratocaster/base/strategy.py | 2 +- src/stratocaster/strategies/connectivity.py | 3 +-- src/stratocaster/tests/test_connectivity_strategy.py | 8 ++++---- src/stratocaster/tests/test_strategy_abstraction.py | 7 +++---- src/stratocaster/tests/test_strategy_base.py | 3 ++- 7 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6290519..3163e41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ test = [ dev = [ "stratocaster[test]", "black", + "isort", ] docs = [ "sphinx", @@ -42,6 +43,9 @@ requires = [ ] build-backend = "setuptools.build_meta" +[tool.isort] +profile = "black" + [tool.versioningit] default-version = "1+unknown" diff --git a/src/stratocaster/base/__init__.py b/src/stratocaster/base/__init__.py index b0b88d2..dbf17b6 100644 --- a/src/stratocaster/base/__init__.py +++ b/src/stratocaster/base/__init__.py @@ -1,2 +1,2 @@ -from .strategy import Strategy, StrategyResult from .models import StrategySettings +from .strategy import Strategy, StrategyResult diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index 1e41506..c146a41 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -1,8 +1,8 @@ import abc from typing import Self, TypeVar -from gufe.tokenization import GufeTokenizable, GufeKey from gufe import AlchemicalNetwork, ProtocolResult +from gufe.tokenization import GufeKey, GufeTokenizable from .models import StrategySettings diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py index d5fe343..f2fe58d 100644 --- a/src/stratocaster/strategies/connectivity.py +++ b/src/stratocaster/strategies/connectivity.py @@ -1,11 +1,10 @@ from gufe import AlchemicalNetwork, ProtocolResult from gufe.tokenization import GufeKey +from pydantic import Field, root_validator, validator from stratocaster.base import Strategy, StrategyResult from stratocaster.base.models import StrategySettings -from pydantic import validator, Field, root_validator - # TODO: docstrings class ConnectivityStrategySettings(StrategySettings): diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 552207c..236c511 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -1,16 +1,16 @@ import math -from random import shuffle, randint +from random import randint, shuffle -from gufe import AlchemicalNetwork import pytest +from gufe import AlchemicalNetwork from gufe.tests.test_protocol import DummyProtocol, DummyProtocolResult +from stratocaster.base.models import StrategySettings +from stratocaster.base.strategy import StrategyResult from stratocaster.strategies.connectivity import ( ConnectivityStrategy, ConnectivityStrategySettings, ) -from stratocaster.base.strategy import StrategyResult -from stratocaster.base.models import StrategySettings from stratocaster.tests.networks import ( benzene_variants_star_map as _benzene_variants_star_map, ) diff --git a/src/stratocaster/tests/test_strategy_abstraction.py b/src/stratocaster/tests/test_strategy_abstraction.py index 85bb8c2..2d50150 100644 --- a/src/stratocaster/tests/test_strategy_abstraction.py +++ b/src/stratocaster/tests/test_strategy_abstraction.py @@ -1,11 +1,10 @@ import pytest - -from stratocaster.base import StrategySettings, Strategy -from stratocaster.base.strategy import StrategyResult - from gufe import AlchemicalNetwork, ProtocolResult from gufe.tokenization import GufeKey +from stratocaster.base import Strategy, StrategySettings +from stratocaster.base.strategy import StrategyResult + class StrategyASettings(StrategySettings): pass diff --git a/src/stratocaster/tests/test_strategy_base.py b/src/stratocaster/tests/test_strategy_base.py index 118c622..6b20058 100644 --- a/src/stratocaster/tests/test_strategy_base.py +++ b/src/stratocaster/tests/test_strategy_base.py @@ -1,7 +1,8 @@ -from stratocaster.base import Strategy, StrategySettings, StrategyResult from gufe import AlchemicalNetwork, ProtocolResult from gufe.tokenization import GufeKey +from stratocaster.base import Strategy, StrategyResult, StrategySettings + class TestStrategyResult: From dfe01c3a5cd1eba19a9d5ed17458263e3d02b81b Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 14 Jan 2025 08:31:13 -0700 Subject: [PATCH 17/21] Remove gufe submodule and fix pydantic import gufe v1.2.0 is available on conda-forge, no longer need the submodule for testing. pydantic imports now match the style used in gufe to address version issue. --- .gitmodules | 3 --- lib/gufe | 1 - src/stratocaster/strategies/connectivity.py | 12 +++++++++++- 3 files changed, 11 insertions(+), 5 deletions(-) delete mode 160000 lib/gufe diff --git a/.gitmodules b/.gitmodules index d55cb12..e69de29 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "lib/gufe"] - path = lib/gufe - url = git@github.com:OpenFreeEnergy/gufe diff --git a/lib/gufe b/lib/gufe deleted file mode 160000 index 71a9c66..0000000 --- a/lib/gufe +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 71a9c6610a9e13c8f7d588bd8309150557f104a5 diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py index f2fe58d..92c4367 100644 --- a/src/stratocaster/strategies/connectivity.py +++ b/src/stratocaster/strategies/connectivity.py @@ -1,10 +1,20 @@ from gufe import AlchemicalNetwork, ProtocolResult from gufe.tokenization import GufeKey -from pydantic import Field, root_validator, validator from stratocaster.base import Strategy, StrategyResult from stratocaster.base.models import StrategySettings +try: + from pydantic.v1 import Field, root_validator, validator +except ImportError: + from pydantic import ( + Field, + root_validator, + validator, + ) + +import pydantic + # TODO: docstrings class ConnectivityStrategySettings(StrategySettings): From 26a23143ba236a8c1317e710e35b0dba9b86e0a1 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 14 Jan 2025 09:11:27 -0700 Subject: [PATCH 18/21] Add tests GitHub action --- .github/workflows/tests.yaml | 59 ++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 .github/workflows/tests.yaml diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml new file mode 100644 index 0000000..836ee8e --- /dev/null +++ b/.github/workflows/tests.yaml @@ -0,0 +1,59 @@ +name: "CI" +on: + # on PR to main + pull_request: + branches: + - main + # on push to main + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +defaults: + run: + shell: bash -l {0} + +jobs: + tests: + runs-on: ${{ matrix.os }} + name: "${{ matrix.os }} python-${{ matrix.python-version }}" + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + python-version: + - 3.10 + - 3.11 + - 3.12 + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v2 + with: + environment-name: stratocaster-test + init-shell: bash + cache-environment: true + # Since gufe the single dependency, install it without env file + create-args: >- + python=${{ matrix.python-version }} + gufe + + - name: Install stratocaster + # install test dependencies + run: python -m pip install -e ".[test]" + + - name: Environment information + run: | + micromamba info + micromamba list + + - name: Run tests + run: | + pytest -v src/stratocaster/tests/ From 6ad97941354ce809fe7c5e97befad82f1293f90a Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 14 Jan 2025 09:22:12 -0700 Subject: [PATCH 19/21] Use quotes for python versions in workflow --- .github/workflows/tests.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 836ee8e..1d728da 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -26,9 +26,9 @@ jobs: matrix: os: [ubuntu-latest, macos-latest] python-version: - - 3.10 - - 3.11 - - 3.12 + - "3.10" + - "3.11" + - "3.12" steps: - uses: actions/checkout@v4 with: From d1d38d630ac94aa835a1ae3f15f5901983d7e2ca Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 14 Jan 2025 10:22:58 -0700 Subject: [PATCH 20/21] Remove typing.Self for Python 3.10 compatibility --- devtools/conda-envs/test.yaml | 4 ++-- src/stratocaster/base/strategy.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/devtools/conda-envs/test.yaml b/devtools/conda-envs/test.yaml index 0b3f125..c5ef553 100644 --- a/devtools/conda-envs/test.yaml +++ b/devtools/conda-envs/test.yaml @@ -3,6 +3,6 @@ channels: - conda-forge dependencies: - - python>=3.9 + - python>=3.10 - - gufe>=1.2.0 \ No newline at end of file + - gufe>=1.2.0 diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index c146a41..3acf336 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -1,13 +1,12 @@ import abc -from typing import Self, TypeVar +from typing import TypeVar from gufe import AlchemicalNetwork, ProtocolResult from gufe.tokenization import GufeKey, GufeTokenizable from .models import StrategySettings -_ProtocolResult = TypeVar("_ProtocolResult", bound=ProtocolResult) - +TProtocolResult = TypeVar("TProtocolResult", bound=ProtocolResult) # TODO: docstrings class StrategyResult(GufeTokenizable): @@ -22,8 +21,9 @@ def _defaults(cls): def _to_dict(self) -> dict: return {"weights": self._weights} + # TODO: Return type from typing.Self when Python 3.10 is no longer supported @classmethod - def _from_dict(cls, dct: dict) -> Self: + def _from_dict(cls, dct: dict): return cls(**dct) @property @@ -74,8 +74,9 @@ def _defaults(cls): def _to_dict(self) -> dict: return {"settings": self._settings} + # TODO: Return type from typing.Self when Python 3.10 is no longer supported @classmethod - def _from_dict(cls, dct: dict) -> Self: + def _from_dict(cls, dct: dict): return cls(**dct) @classmethod @@ -87,13 +88,13 @@ def _default_settings(cls) -> StrategySettings: def _propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, _ProtocolResult], + protocol_results: dict[GufeKey, TProtocolResult], ) -> StrategyResult: raise NotImplementedError def propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, _ProtocolResult], + protocol_results: dict[GufeKey, TProtocolResult], ) -> StrategyResult: return self._propose(alchemical_network, protocol_results) From 75e66dccca56e2431336463e8e91c0815e676b51 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 14 Jan 2025 10:31:51 -0700 Subject: [PATCH 21/21] Add test badge to README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 441ab6a..65ad925 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # stratocaster +[![CI](https://github.com/OpenFreeEnergy/stratocaster/actions/workflows/tests.yaml/badge.svg)](https://github.com/OpenFreeEnergy/stratocaster/actions/workflows/tests.yaml) + A library for proposing a prioritization of Transformations within AlchemicalNetworks. ## Installation