Skip to content

Commit

Permalink
Initialize context from backend (#156)
Browse files Browse the repository at this point in the history
* add all

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dacheng Xu <dx2227@columbia.edu>
  • Loading branch information
3 people authored Apr 6, 2024
1 parent a5a631f commit b9ecb17
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
48 changes: 33 additions & 15 deletions appletree/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import appletree as apt
from appletree import randgen
from appletree import Parameter
from appletree.utils import load_json
from appletree.utils import load_json, get_file_path
from appletree.share import _cached_configs, set_global_config

os.environ["OMP_NUM_THREADS"] = "1"
Expand Down Expand Up @@ -55,6 +55,17 @@ def __init__(self, instruct, par_config=None):

self.register_all_likelihood(instruct)

@classmethod
def from_backend(cls, backend_h5_file_name):
    """Initialize a context from a backend_h5 file.

    The instruction, the number of walkers and the batch size are read from
    the attributes of the ``mcmc`` group of the file. A new context is then
    built from the instruction and prepared with ``pre_fitting``.

    Args:
        backend_h5_file_name: name or path of the backend h5 file, resolved
            with ``get_file_path``.

    Returns:
        The newly created context, after ``pre_fitting`` has been called.

    """
    # Only attributes are read here, so open the file explicitly read-only;
    # this guarantees that loading a context never creates or modifies it.
    with h5py.File(get_file_path(backend_h5_file_name), "r") as file:
        attrs = file["mcmc"].attrs
        # NOTE(review): eval executes whatever expression is stored in the
        # file, so only load backends from trusted sources. Consider a safe
        # parser (json.loads / ast.literal_eval) once the stored format of
        # "instruct" is confirmed.
        instruct = eval(attrs["instruct"])
        nwalkers = attrs["nwalkers"]
        batch_size = attrs["batch_size"]
    # The file is closed before the context is built, because pre_fitting
    # opens the backend file again through emcee.
    tree = cls(instruct)
    tree.pre_fitting(nwalkers, batch_size=batch_size)
    return tree

def __getitem__(self, keys):
"""Get likelihood in context."""
return self.likelihoods[keys]
Expand Down Expand Up @@ -174,19 +185,19 @@ def log_posterior(self, parameters, batch_size=1_000_000):
def _ndim(self):
return len(self.par_manager.parameter_fit_array)

def _set_backend(self, nwalkers=100, read_only=True):
def _set_backend(self, nwalkers=100, read_only=True, reset=False):
if self.backend_h5 is None:
self._backend = None
print("With no backend")
else:
self._backend = emcee.backends.HDFBackend(self.backend_h5, read_only=read_only)
if not read_only:
if reset:
self._backend.reset(nwalkers, self._ndim)
print(f"With h5 backend {self.backend_h5}")

def pre_fitting(self, nwalkers=100, read_only=True, batch_size=1_000_000):
def pre_fitting(self, nwalkers=100, read_only=True, reset=False, batch_size=1_000_000):
"""Prepare for fitting, initialize backend and sampler."""
self._set_backend(nwalkers, read_only=read_only)
self._set_backend(nwalkers, read_only=read_only, reset=reset)
self.sampler = emcee.EnsembleSampler(
nwalkers,
self._ndim,
Expand All @@ -212,7 +223,7 @@ def fitting(self, nwalkers=200, iteration=500, batch_size=1_000_000):
self.par_manager.sample_init()
p0.append(self.par_manager.parameter_fit_array)

self.pre_fitting(nwalkers=nwalkers, read_only=False, batch_size=batch_size)
self.pre_fitting(nwalkers=nwalkers, read_only=False, reset=True, batch_size=batch_size)

result = self.sampler.run_mcmc(
p0,
Expand All @@ -221,25 +232,30 @@ def fitting(self, nwalkers=200, iteration=500, batch_size=1_000_000):
progress=True,
)

self._dump_meta()
self._dump_meta(batch_size=batch_size)
return result

def continue_fitting(self, context, iteration=500, batch_size=1_000_000):
def continue_fitting(self, context=None, iteration=500, batch_size=1_000_000):
"""Continue a fitting of another context.
Args:
context: appletree context whose last chain step is used as the starting point. If None, continue the fitting defined in self.
iteration: int, number of steps to generate.
"""
# Final iteration
final_iteration = context.sampler.get_chain()[-1, :, :]
p0 = final_iteration.tolist()
# If context is None, use self, i.e. continue the fitting defined in self
if context is None:
context = self
p0 = None
else:
# Final iteration
final_iteration = context.sampler.get_chain()[-1, :, :]
p0 = final_iteration.tolist()

nwalkers = len(p0)
nwalkers = context.sampler.get_chain().shape[1]

# Init sampler for current context
self.pre_fitting(nwalkers=nwalkers, read_only=False, batch_size=batch_size)
self.pre_fitting(nwalkers=nwalkers, read_only=False, reset=False, batch_size=batch_size)

result = self.sampler.run_mcmc(
p0,
Expand All @@ -249,7 +265,7 @@ def continue_fitting(self, context, iteration=500, batch_size=1_000_000):
skip_initial_state_check=True,
)

self._dump_meta()
self._dump_meta(batch_size=batch_size)
return result

def get_post_parameters(self):
Expand All @@ -276,7 +292,7 @@ def dump_post_parameters(self, file_name):
with open(file_name, "w") as fp:
json.dump(parameters, fp)

def _dump_meta(self, metadata=None):
def _dump_meta(self, batch_size, metadata=None):
"""Save parameters name as attributes."""
if metadata is None:
metadata = {
Expand All @@ -299,6 +315,8 @@ def _dump_meta(self, metadata=None):
opt[name].attrs["config"] = json.dumps(self.config)
# configurations, maybe users will manually add some maps
opt[name].attrs["_cached_configs"] = json.dumps(_cached_configs)
# batch size
opt[name].attrs["batch_size"] = batch_size

def get_template(
self,
Expand Down
4 changes: 2 additions & 2 deletions appletree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def get_file_path(fname):
* can be downloaded from MongoDB, download and return cached path
"""
# 1. From absolute path
# 1. From absolute path if file exists
# Usually Config.default is an absolute path
if fname.startswith("/"):
if os.path.isfile(fname):
return fname

# 2. From local folder
Expand Down
16 changes: 16 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,19 @@ def test_literature_context():
parameters = context.get_post_parameters()
context.get_num_events_accepted(parameters, batch_size=batch_size)
check_unused_configs()


def test_backend():
    """Fit with an h5 backend, rebuild the context from it, and keep fitting."""
    backend_name = "test_backend.h5"
    n_steps = 2
    batch_size = int(1e4)

    # First pass: a fresh context that writes its chain to the h5 backend.
    _cached_functions.clear()
    _cached_configs.clear()
    instruct = apt.utils.load_json("rn220.json")
    instruct["backend_h5"] = backend_name
    context = apt.Context(instruct)
    context.fitting(nwalkers=100, iteration=n_steps, batch_size=batch_size)

    # Second pass: start from clean caches and restore the context from the
    # backend file alone, then continue the stored fitting.
    _cached_functions.clear()
    _cached_configs.clear()
    context = apt.Context.from_backend(backend_name)
    context.continue_fitting(iteration=n_steps, batch_size=batch_size)

    # The chain must hold the steps of both passes.
    assert context.sampler.get_chain().shape[0] == 2 * n_steps

0 comments on commit b9ecb17

Please sign in to comment.