Skip to content

Commit

Permalink
feat: support for loose init model from run (#1371)
Browse files Browse the repository at this point in the history
  • Loading branch information
LennartPurucker authored Oct 17, 2024
1 parent 26ae499 commit c30cd14
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
6 changes: 4 additions & 2 deletions openml/runs/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def get_run_trace(run_id: int) -> OpenMLRunTrace:
return OpenMLRunTrace.trace_from_xml(trace_xml)


def initialize_model_from_run(run_id: int) -> Any:
def initialize_model_from_run(run_id: int, *, strict_version: bool = True) -> Any:
"""
Initialized a model based on a run_id (i.e., using the exact
same parameter settings)
Expand All @@ -373,6 +373,8 @@ def initialize_model_from_run(run_id: int) -> Any:
----------
run_id : int
The Openml run_id
strict_version: bool (default=True)
See `flow_to_model` strict_version.
Returns
-------
Expand All @@ -382,7 +384,7 @@ def initialize_model_from_run(run_id: int) -> Any:
# TODO(eddiebergman): I imagine this is None if it's not published,
# might need to raise an explicit error for that
assert run.setup_id is not None
return initialize_model(run.setup_id)
return initialize_model(setup_id=run.setup_id, strict_version=strict_version)


def initialize_model_from_trace(
Expand Down
6 changes: 4 additions & 2 deletions openml/setups/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def __list_setups(
return setups


def initialize_model(setup_id: int) -> Any:
def initialize_model(setup_id: int, *, strict_version: bool = True) -> Any:
"""
Initialized a model based on a setup_id (i.e., using the exact
same parameter settings)
Expand All @@ -274,6 +274,8 @@ def initialize_model(setup_id: int) -> Any:
----------
setup_id : int
The Openml setup_id
strict_version: bool (default=True)
See `flow_to_model` strict_version.
Returns
-------
Expand All @@ -294,7 +296,7 @@ def initialize_model(setup_id: int) -> Any:
subflow = flow
subflow.parameters[hyperparameter.parameter_name] = hyperparameter.value

return flow.extension.flow_to_model(flow)
return flow.extension.flow_to_model(flow, strict_version=strict_version)


def _to_dict(
Expand Down
10 changes: 10 additions & 0 deletions tests/test_runs/test_run_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,6 +1905,16 @@ def test_delete_run(self):
_run_id = run.run_id
assert delete_run(_run_id)

@unittest.skipIf(
Version(sklearn.__version__) < Version("0.20"),
reason="SimpleImputer doesn't handle mixed type DataFrame as input",
)
def test_initialize_model_from_run_nonstrict(self):
# We cannot guarantee that a run with an older version exists on the server.
# Thus, we test it simply with a run that we know exists that might not be loose.
# This tests all lines of code for OpenML but not the initialization, which we do not want to guarantee anyhow.
_ = openml.runs.initialize_model_from_run(run_id=1, strict_version=False)


@mock.patch.object(requests.Session, "delete")
def test_delete_run_not_owned(mock_delete, test_files_directory, test_api_key):
Expand Down

0 comments on commit c30cd14

Please sign in to comment.