From da279acc76c38bbf2cc61b328506fe7a16c62b05 Mon Sep 17 00:00:00 2001 From: Tim Mensinger Date: Tue, 3 Dec 2024 17:53:46 +0100 Subject: [PATCH] Handle batch case --- src/optimagic/optimization/history.py | 91 +++++++++++++++++++- src/optimagic/timing.py | 4 +- tests/optimagic/optimization/test_history.py | 85 +++++++++++++++--- 3 files changed, 161 insertions(+), 19 deletions(-) diff --git a/src/optimagic/optimization/history.py b/src/optimagic/optimization/history.py index aa3af7709..dd9bad85b 100644 --- a/src/optimagic/optimization/history.py +++ b/src/optimagic/optimization/history.py @@ -1,7 +1,7 @@ import warnings from dataclasses import dataclass from functools import partial -from typing import Any, Literal +from typing import Any, Callable, Iterable, Literal import numpy as np import pandas as pd @@ -192,6 +192,16 @@ def flat_param_names(self) -> list[str]: def _get_time( self, cost_model: CostModel | Literal["wall_time"] ) -> NDArray[np.float64]: + """Return the cumulative time measure. + + Args: + cost_model: The cost model that is used to calculate the time measure. If + "wall_time", the wall time is returned. + + Returns: + np.ndarray: The time measure. + + """ if not isinstance(cost_model, CostModel) and cost_model != "wall_time": raise ValueError("cost_model must be a CostModel or 'wall_time'.") @@ -207,11 +217,31 @@ def _get_time( fun_and_jac_time = self._get_time_per_task( task=EvalTask.FUN_AND_JAC, cost_factor=cost_model.fun_and_jac ) - return fun_time + jac_time + fun_and_jac_time + + time = fun_time + jac_time + fun_and_jac_time + batch_time = _batch_apply( + data=time, + batch_ids=self.batches, + func=cost_model.aggregate_batch_time, + ) + return np.cumsum(batch_time) def _get_time_per_task( self, task: EvalTask, cost_factor: float | None ) -> NDArray[np.float64]: + """Return the time measure per task. + + Args: + task: The task for which the time is calculated. + cost_factor: The cost factor used to calculate the time. If None, the time + is the difference between the start and stop time, otherwise the time + is given by the cost factor. + + Returns: + np.ndarray: The time per task. For entries where the task is not the + requested task, the time is 0. + + """ dummy_task = np.array([1 if t == task else 0 for t in self.task]) if cost_factor is None: factor: float | NDArray[np.float64] = np.array( @@ -220,7 +250,7 @@ def _get_time_per_task( else: factor = cost_factor - return np.cumsum(factor * dummy_task) + return factor * dummy_task @property def start_time(self) -> list[float]: @@ -351,3 +381,58 @@ def _task_as_categorical(task: list[EvalTask]) -> pd.Categorical: return pd.Categorical( [t.value for t in task], categories=[t.value for t in EvalTask] ) + + +def _batch_apply( + data: NDArray[np.float64], + batch_ids: list[int], + func: Callable[[Iterable[float]], float], +) -> NDArray[np.float64]: + """Apply a reduction operator on batches of data. + + Args: + data: 1d array with data. + batch_ids: A list whose length is equal to the size of data. Values need to be + sorted and can be repeated. + func: A reduction function that takes an iterable of floats as input (e.g., a + numpy array or a list) and returns a scalar. + + Returns: + The transformed data. Has the same length as data. For each batch, the result of + the reduction operation is stored at the first index of that batch, and all + other values of that batch are set to zero. + + """ + batch_start = _get_batch_start(batch_ids) + batch_stop = [*batch_start, len(data)][1:] + + batch_result = [] + for batch, (start, stop) in zip( + batch_ids, zip(batch_start, batch_stop, strict=False), strict=False + ): + try: + batch_data = data[start:stop] + reduced = func(batch_data) + batch_result.append(reduced) + except Exception as e: + msg = ( + f"Calling function {func.__name__} on batch {batch} of the History " + f"History raised an Exception. Please verify that {func.__name__} is " + "properly defined." + ) + raise ValueError(msg) from e + + out = np.zeros_like(data) + out[batch_start] = batch_result + return out + + +def _get_batch_start(batch_ids: list[int]) -> list[int]: + """Get start indices of batch. + + This function assumes that batch_ids non-empty and sorted. + + """ + ids_arr = np.array(batch_ids, dtype=np.int64) + indices = np.where(ids_arr[:-1] != ids_arr[1:])[0] + 1 + return np.insert(indices, 0, 0).tolist() diff --git a/src/optimagic/timing.py b/src/optimagic/timing.py index 5814363f0..db83a76d2 100644 --- a/src/optimagic/timing.py +++ b/src/optimagic/timing.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Callable +from typing import Callable, Iterable @dataclass(frozen=True) @@ -8,7 +8,7 @@ class CostModel: jac: float | None fun_and_jac: float | None label: str - aggregate_batch_time: Callable[[list[float]], float] + aggregate_batch_time: Callable[[Iterable[float]], float] evaluation_time = CostModel( diff --git a/tests/optimagic/optimization/test_history.py b/tests/optimagic/optimization/test_history.py index 58b25b4ca..ab92b88a2 100644 --- a/tests/optimagic/optimization/test_history.py +++ b/tests/optimagic/optimization/test_history.py @@ -10,7 +10,9 @@ from optimagic.optimization.history import ( History, HistoryEntry, + _batch_apply, _calculate_monotone_sequence, + _get_batch_start, _get_flat_param_names, _get_flat_params, _is_1d_array, @@ -143,8 +145,8 @@ def params(): @pytest.fixture -def history(params): - data = { +def history_data(params): + return { "fun": [10, None, 9, None, 2, 5], "task": [ EvalTask.FUN, @@ -157,9 +159,19 @@ def history(params): "start_time": [0, 2, 5, 7, 10, 12], "stop_time": [1, 4, 6, 9, 11, 14], "params": params, - "batches": [0, 0, 1, 1, 2, 2], + "batches": [0, 1, 2, 3, 4, 5], } + +@pytest.fixture +def history(history_data): + return History(direction=Direction.MINIMIZE, **history_data) + + +@pytest.fixture +def history_with_batch_data(history_data): + data = history_data.copy() + data["batches"] = [0, 0, 1, 1, 2, 2] return History(direction=Direction.MINIMIZE, **data) @@ -211,9 +223,8 @@ def test_history_fun_data_with_fun_evaluations_cost_model_and_monotone(history): assert_frame_equal(got, exp, check_dtype=False, check_categorical=False) -@pytest.mark.xfail(reason="Must be fixed!") -def test_history_fun_data_with_fun_batches_cost_model(history): - got = history.fun_data( +def test_history_fun_data_with_fun_batches_cost_model(history_with_batch_data): + got = history_with_batch_data.fun_data( cost_model=om.timing.fun_batches, monotone=False, ) @@ -328,23 +339,23 @@ def test_flat_param_names(history): def test_get_time_per_task_fun(history): got = history._get_time_per_task(EvalTask.FUN, cost_factor=1) - exp = np.array([1, 1, 2, 2, 3, 3]) + exp = np.array([1, 0, 1, 0, 1, 0]) assert_array_equal(got, exp) -def test_get_time_per_task_jac(history): - got = history._get_time_per_task(EvalTask.JAC, cost_factor=1) - exp = np.array([0, 1, 1, 2, 2, 2]) +def test_get_time_per_task_jac_cost_factor_none(history): + got = history._get_time_per_task(EvalTask.JAC, cost_factor=None) + exp = np.array([0, 2, 0, 2, 0, 0]) assert_array_equal(got, exp) def test_get_time_per_task_fun_and_jac(history): - got = history._get_time_per_task(EvalTask.FUN_AND_JAC, cost_factor=1) - exp = np.array([0, 0, 0, 0, 0, 1]) + got = history._get_time_per_task(EvalTask.FUN_AND_JAC, cost_factor=-0.5) + exp = np.array([0, 0, 0, 0, 0, -0.5]) assert_array_equal(got, exp) -def test_get_time_cost_model(history): +def test_get_time_custom_cost_model(history): cost_model = om.timing.CostModel( fun=0.5, jac=1, fun_and_jac=2, label="test", aggregate_batch_time=sum ) @@ -362,6 +373,30 @@ def test_get_time_cost_model(history): assert_array_equal(got, exp) +def test_get_time_fun_evaluations(history): + got = history._get_time(cost_model=om.timing.fun_evaluations) + exp = np.array([1, 1, 2, 2, 3, 4]) + assert_array_equal(got, exp) + + +def test_get_time_fun_batches(history): + got = history._get_time(cost_model=om.timing.fun_batches) + exp = np.array([1, 1, 2, 2, 3, 4]) + assert_array_equal(got, exp) + + +def test_get_time_fun_batches_with_batch_data(history_with_batch_data): + got = history_with_batch_data._get_time(cost_model=om.timing.fun_batches) + exp = np.array([1, 1, 2, 2, 3, 3]) + assert_array_equal(got, exp) + + +def test_get_time_evaluation_time(history): + got = history._get_time(cost_model=om.timing.evaluation_time) + exp = np.array([1, 3, 4, 6, 7, 9]) + assert_array_equal(got, exp) + + def test_get_time_wall_time(history): got = history._get_time(cost_model="wall_time") exp = np.array([1, 4, 6, 9, 11, 14]) @@ -381,7 +416,7 @@ def test_stop_time_property(history): def test_batches_property(history): - assert history.batches == [0, 0, 1, 1, 2, 2] + assert history.batches == [0, 1, 2, 3, 4, 5] # Tasks @@ -466,3 +501,25 @@ def test_task_as_categorical(): got = _task_as_categorical(task) assert got.tolist() == ["fun", "jac", "fun_and_jac"] assert isinstance(got.dtype, pd.CategoricalDtype) + + +def test_get_batch_start(): + batches = [0, 0, 1, 1, 1, 2, 2, 3] + got = _get_batch_start(batches) + assert got == [0, 2, 5, 7] + + +def test_batch_apply_sum(): + data = np.array([0, 1, 2, 3, 4]) + batch_ids = [0, 0, 1, 1, 2] + exp = np.array([1, 0, 5, 0, 4]) + got = _batch_apply(data, batch_ids, sum) + assert_array_equal(exp, got) + + +def test_batch_apply_max(): + data = np.array([0, 1, 2, 3, 4]) + batch_ids = [0, 0, 1, 1, 2] + exp = np.array([1, 0, 3, 0, 4]) + got = _batch_apply(data, batch_ids, max) + assert_array_equal(exp, got)