Skip to content

Commit

Permalink
Simplify processing of multistart results.
Browse files Browse the repository at this point in the history
  • Loading branch information
janosg committed Nov 30, 2024
1 parent 9b73242 commit d6f894b
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 54 deletions.
9 changes: 9 additions & 0 deletions src/optimagic/optimization/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,15 @@ def solve_internal_problem(
step_id, {"status": str(StepStatus.COMPLETE.value)}
)

# make sure the start params provided in static_result_fields are the same as x0
extra_fields = problem.static_result_fields
x0_problem = problem.converter.params_to_internal(extra_fields.start_params)
if not np.allclose(x0_problem, x0):
start_params = problem.converter.params_from_internal(x0)
extra_fields = replace(
extra_fields, start_params=start_params, start_fun=None
)

res = raw_res.create_optimize_result(
converter=problem.converter,
solver_type=self.algo_info.solver_type,
Expand Down
18 changes: 9 additions & 9 deletions src/optimagic/optimization/multistart.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,19 @@ def single_optimization(x0, step_id):
logger.step_store.update(step, {"status": new_status})
break

multistart_info = {
"start_parameters": state["start_history"],
"local_optima": state["result_history"],
"exploration_sample": sorted_sample,
"exploration_results": exploration_res["sorted_values"],
}

raw_res = state["best_res"]

expl_sample = [
internal_problem.converter.params_from_internal(s) for s in sorted_sample
]
expl_res = list(exploration_res["sorted_values"])

res = process_multistart_result(
raw_res=raw_res,
converter=internal_problem.converter,
extra_fields=internal_problem.static_result_fields,
multistart_info=multistart_info,
local_optima=state["result_history"],
exploration_sample=expl_sample,
exploration_results=expl_res,
)

return res
Expand Down
4 changes: 2 additions & 2 deletions src/optimagic/optimization/optimize_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class OptimizeResult:

params: Any
fun: float
start_fun: float
start_fun: float | None
start_params: Any
algorithm: str
direction: str
Expand Down Expand Up @@ -78,7 +78,7 @@ def criterion(self) -> float:
return self.fun

@property
def start_criterion(self) -> float:
def start_criterion(self) -> float | None:
msg = (
"The start_criterion attribute is deprecated. Use the start_fun attribute "
"instead."
Expand Down
60 changes: 19 additions & 41 deletions src/optimagic/optimization/process_multistart_result.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
import copy
from typing import Any

import numpy as np
from numpy.typing import NDArray

from optimagic.optimization.convergence_report import get_convergence_report
from optimagic.optimization.optimize_result import MultistartInfo, OptimizeResult
from optimagic.parameters.conversion import Converter
from optimagic.typing import Direction, ExtraResultFields


def process_multistart_result(
raw_res: OptimizeResult,
converter: Converter,
extra_fields: ExtraResultFields,
multistart_info: dict[str, Any],
local_optima: list[OptimizeResult],
exploration_sample: list[NDArray[np.float64]],
exploration_results: list[float],
) -> OptimizeResult:
"""Process results of internal optimizers."""

if isinstance(raw_res, str):
res = _dummy_result_from_traceback(raw_res, extra_fields)

Check warning on line 19 in src/optimagic/optimization/process_multistart_result.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/optimization/process_multistart_result.py#L19

Added line #L19 was not covered by tests
else:
res = raw_res
info = _process_multistart_info(
multistart_info,
converter=converter,
extra_fields=extra_fields,
if extra_fields.direction == Direction.MAXIMIZE:
exploration_results = [-res for res in exploration_results]

info = MultistartInfo(
start_parameters=[opt.start_params for opt in local_optima],
local_optima=local_optima,
exploration_sample=exploration_sample,
exploration_results=exploration_results,
)

# ==============================================================================
Expand Down Expand Up @@ -53,42 +55,18 @@ def process_multistart_result(
return res


def _process_multistart_info(
info: dict[str, Any],
converter: Converter,
extra_fields: ExtraResultFields,
) -> MultistartInfo:
starts = [converter.params_from_internal(x) for x in info["start_parameters"]]

optima = []
for res, start in zip(info["local_optima"], starts, strict=False):
processed = copy.copy(res)
processed.start_params = start
processed.start_fun = None
optima.append(processed)

sample = [converter.params_from_internal(x) for x in info["exploration_sample"]]

if extra_fields.direction == Direction.MINIMIZE:
exploration_res = info["exploration_results"]
else:
exploration_res = [-res for res in info["exploration_results"]]

return MultistartInfo(
start_parameters=starts,
local_optima=optima,
exploration_sample=sample,
exploration_results=exploration_res,
)


def _dummy_result_from_traceback(
candidate: str, extra_fields: ExtraResultFields
) -> OptimizeResult:
if extra_fields.start_fun is None:
start_fun = np.inf

Check warning on line 62 in src/optimagic/optimization/process_multistart_result.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/optimization/process_multistart_result.py#L61-L62

Added lines #L61 - L62 were not covered by tests
else:
start_fun = extra_fields.start_fun

Check warning on line 64 in src/optimagic/optimization/process_multistart_result.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/optimization/process_multistart_result.py#L64

Added line #L64 was not covered by tests

out = OptimizeResult(

Check warning on line 66 in src/optimagic/optimization/process_multistart_result.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/optimization/process_multistart_result.py#L66

Added line #L66 was not covered by tests
params=extra_fields.start_params,
fun=extra_fields.start_fun,
start_fun=extra_fields.start_fun,
fun=start_fun,
start_fun=start_fun,
start_params=extra_fields.start_params,
algorithm=extra_fields.algorithm,
direction=extra_fields.direction.value,
Expand Down
2 changes: 1 addition & 1 deletion src/optimagic/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class MultiStartIterationHistory(TupleLikeAccess):
class ExtraResultFields:
"""Fields for OptimizeResult that are not part of InternalOptimizeResult."""

start_fun: float
start_fun: float | None
start_params: PyTree
algorithm: str
direction: Direction
Expand Down
2 changes: 1 addition & 1 deletion src/optimagic/visualization/history_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _extract_plotting_data_from_results_object(
res.multistart_info.exploration_sample[::-1] + stacked["params"]
)
stacked["criterion"] = (
res.multistart_info.exploration_results.tolist()[::-1]
list(res.multistart_info.exploration_results)[::-1]
+ stacked["criterion"]
)
else:
Expand Down

0 comments on commit d6f894b

Please sign in to comment.