Skip to content

Commit

Permalink
Add gather support for parallel dir structure (#1044)
Browse files Browse the repository at this point in the history
* add docstrings

* Revert "add docstrings"

This reverts commit 6ae5d1d.

* adding fixture

* whitespace fixes

* added (failing) test for parallel results data

* add back all report types into testing

* include serial and paralell in paramterize

* adding support for raw output type

* expose uncertainty and estimate functions to gather

* pass dGs through - works but isn't pretty

* remove todo

* updating changelog

* updating changelog

* Update openfe/protocols/openmm_rfe/equil_rfe_methods.py

Co-authored-by: Hannah Baumann <43765638+hannahbaumann@users.noreply.github.com>

* fix type hint

---------

Co-authored-by: Hannah Baumann <43765638+hannahbaumann@users.noreply.github.com>
  • Loading branch information
atravitz and hannahbaumann authored Dec 20, 2024
1 parent bb90b07 commit 66476fd
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 45 deletions.
23 changes: 23 additions & 0 deletions news/support_gather_parallel.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* ``openfe gather`` now supports replicates that have been submitted in parallel across separate directories.

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
22 changes: 16 additions & 6 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,15 @@ def __init__(self, **data):
if any(len(pur_list) > 2 for pur_list in self.data.values()):
raise NotImplementedError("Can't stitch together results yet")

@staticmethod
def compute_mean_estimate(dGs:list[unit.Quantity]):
u = dGs[0].u
# convert all values to units of the first value, then take average of magnitude
# this would avoid a screwy case where each value was in different units
vals = [dG.to(u).m for dG in dGs]

return np.average(vals) * u

def get_estimate(self) -> unit.Quantity:
"""Average free energy difference of this transformation
Expand All @@ -267,24 +276,25 @@ def get_estimate(self) -> unit.Quantity:
"""
# TODO: Check this holds up completely for SAMS.
dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()]
return self.compute_mean_estimate(dGs)

@staticmethod
def compute_uncertainty(dGs:list[unit.Quantity]):
u = dGs[0].u
# convert all values to units of the first value, then take average of magnitude
# this would avoid a screwy case where each value was in different units
vals = [dG.to(u).m for dG in dGs]

return np.average(vals) * u
return np.std(vals) * u

def get_uncertainty(self) -> unit.Quantity:
"""The uncertainty/error in the dG value: The std of the estimates of
each independent repeat
"""

dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()]
u = dGs[0].u
# convert all values to units of the first value, then take average of magnitude
# this would avoid a screwy case where each value was in different units
vals = [dG.to(u).m for dG in dGs]
return self.compute_uncertainty(dGs)

return np.std(vals) * u

def get_individual_estimates(self) -> list[tuple[unit.Quantity, unit.Quantity]]:
"""Return a list of tuples containing the individual free energy
Expand Down
40 changes: 24 additions & 16 deletions openfecli/commands/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Callable, Literal
import warnings

from openfe.protocols.openmm_rfe.equil_rfe_methods import RelativeHybridTopologyProtocolResult as rfe_result
from openfe.protocols import openmm_rfe
from openfecli import OFECommandPlugin
from openfecli.clicktypes import HyphenAwareChoice

Expand Down Expand Up @@ -200,7 +202,6 @@ def _parse_raw_units(results: dict) -> list[tuple]:
pu[0]['outputs']['unit_estimate_error'])
for pu in list_of_pur]


def _get_ddgs(legs:dict, error_on_missing=True):
import numpy as np
DDGs = []
Expand All @@ -215,16 +216,20 @@ def _get_ddgs(legs:dict, error_on_missing=True):
do_rhfe = (len(set_vals & {'vacuum', 'solvent'}) == 2)

if do_rbfe:
DG1_mag, DG1_unc = vals['complex']
DG2_mag, DG2_unc = vals['solvent']
DG1_mag = rfe_result.compute_mean_estimate(vals['complex'])
DG1_unc = rfe_result.compute_uncertainty(vals['complex'])
DG2_mag = rfe_result.compute_mean_estimate(vals['solvent'])
DG2_unc = rfe_result.compute_uncertainty(vals['solvent'])
if not ((DG1_mag is None) or (DG2_mag is None)):
# DDG(2,1)bind = DG(1->2)complex - DG(1->2)solvent
DDGbind = (DG1_mag - DG2_mag).m
bind_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))

if do_rhfe:
DG1_mag, DG1_unc = vals['solvent']
DG2_mag, DG2_unc = vals['vacuum']
DG1_mag = rfe_result.compute_mean_estimate(vals['solvent'])
DG1_unc = rfe_result.compute_uncertainty(vals['solvent'])
DG2_mag = rfe_result.compute_mean_estimate(vals['vacuum'])
DG2_unc = rfe_result.compute_uncertainty(vals['vacuum'])
if not ((DG1_mag is None) or (DG2_mag is None)):
DDGhyd = (DG1_mag - DG2_mag).m
hyd_unc = np.sqrt(np.sum(np.square([DG1_unc.m, DG2_unc.m])))
Expand Down Expand Up @@ -258,14 +263,15 @@ def _write_raw(legs:dict, writer:Callable, allow_partial=True):
writer.writerow(["leg", "ligand_i", "ligand_j",
"DG(i->j) (kcal/mol)", "MBAR uncertainty (kcal/mol)"])

for ligpair, vals in sorted(legs.items()):
for simtype, repeats in sorted(vals.items()):
for m, u in repeats:
if m is None:
m, u = 'NaN', 'NaN'
else:
m, u = format_estimate_uncertainty(m.m, u.m)
writer.writerow([simtype, *ligpair, m, u])
for ligpair, results in sorted(legs.items()):
for simtype, repeats in sorted(results.items()):
for repeat in repeats:
for m, u in repeat:
if m is None:
m, u = 'NaN', 'NaN'
else:
m, u = format_estimate_uncertainty(m.m, u.m)
writer.writerow([simtype, *ligpair, m, u])


def _write_dg_raw(legs:dict, writer:Callable, allow_partial): # pragma: no-cover
Expand Down Expand Up @@ -400,7 +406,7 @@ def gather(rootdir:os.PathLike|str,
result_fns = filter(is_results_json, json_fns)

# 3) pair legs of simulations together into dict of dicts
legs = defaultdict(dict)
legs = defaultdict(lambda: defaultdict(list))

for result_fn in result_fns:
result = load_results(result_fn)
Expand All @@ -420,9 +426,11 @@ def gather(rootdir:os.PathLike|str,
simtype = legacy_get_type(result_fn)

if report.lower() == 'raw':
legs[names][simtype] = _parse_raw_units(result)
legs[names][simtype].append(_parse_raw_units(result))
else:
legs[names][simtype] = result['estimate'], result['uncertainty']
dGs = [v[0]['outputs']['unit_estimate'] for v in result['protocol_result']['data'].values()]
## for jobs run in parallel, we need to compute these values
legs[names][simtype].extend(dGs)

writer = csv.writer(
output,
Expand Down
46 changes: 23 additions & 23 deletions openfecli/tests/commands/test_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,6 @@ def test_format_estimate_uncertainty(est, unc, unc_prec, est_str, unc_str):
def test_get_column(val, col):
assert _get_column(val) == col


@pytest.fixture
def results_dir_serial(tmpdir):
"""Example output data, with replicates run in serial (3 replicates per results JSON)."""
with tmpdir.as_cwd():
with resources.files('openfecli.tests.data') as d:
t = tarfile.open(d / 'rbfe_results.tar.gz', mode='r')
t.extractall('.')

yield

@pytest.fixture
def results_dir_parallel(tmpdir):
"""Identical data to results_dir_serial(), with replicates run in parallel (1 replicate per results JSON)."""
with tmpdir.as_cwd():
with resources.files('openfecli.tests.data') as d:
t = tarfile.open(d / 'results_parallel.tar.gz', mode='r')
t.extractall('.')

yield

_EXPECTED_DG = b"""
ligand DG(MLE) (kcal/mol) uncertainty (kcal/mol)
lig_ejm_31 -0.09 0.05
Expand Down Expand Up @@ -155,9 +134,29 @@ def results_dir_parallel(tmpdir):
solvent lig_ejm_46 lig_jmc_28 23.4 0.8
"""

@pytest.fixture()
def results_dir_serial(tmpdir):
"""Example output data, with replicates run in serial (3 replicates per results JSON)."""
with tmpdir.as_cwd():
with resources.files('openfecli.tests.data') as d:
t = tarfile.open(d / 'rbfe_results.tar.gz', mode='r')
t.extractall('.')

return os.path.abspath(t.getnames()[0])

@pytest.fixture()
def results_dir_parallel(tmpdir):
"""Example output data, with replicates run in serial (3 replicates per results JSON)."""
with tmpdir.as_cwd():
with resources.files('openfecli.tests.data') as d:
t = tarfile.open(d / 'rbfe_results_parallel.tar.gz', mode='r')
t.extractall('.')

return os.path.abspath(t.getnames()[0])

@pytest.mark.parametrize('data_fixture', ['results_dir_serial', 'results_dir_parallel'])
@pytest.mark.parametrize('report', ["", "dg", "ddg", "raw"])
def test_gather(results_dir_serial, report):
def test_gather(request, data_fixture, report):
expected = {
"": _EXPECTED_DG,
"dg": _EXPECTED_DG,
Expand All @@ -171,7 +170,8 @@ def test_gather(results_dir_serial, report):
else:
args = []

result = runner.invoke(gather, ['results'] + args + ['-o', '-'])
results_dir = request.getfixturevalue(data_fixture)
result = runner.invoke(gather, [results_dir] + args + ['-o', '-'])

assert result.exit_code == 0

Expand Down

0 comments on commit 66476fd

Please sign in to comment.