diff --git a/torax/config/config_args.py b/torax/config/config_args.py index c2278d68..f237feaf 100644 --- a/torax/config/config_args.py +++ b/torax/config/config_args.py @@ -73,25 +73,6 @@ def _check(ft): return _check(field_type) # pytype: disable=bad-return-type -def _is_bool( - interp_input: interpolated_param.InterpolatedVarSingleAxisInput, -) -> bool: - if isinstance(interp_input, dict): - if not interp_input: - raise ValueError('InterpolatedVarSingleAxisInput must include values.') - value = list(interp_input.values())[0] - return isinstance(value, bool) - return isinstance(interp_input, bool) - - -def _convert_value_to_floats( - interp_input: interpolated_param.InterpolatedVarSingleAxisInput, -) -> interpolated_param.InterpolatedVarSingleAxisInput: - if isinstance(interp_input, dict): - return {key: float(value) for key, value in interp_input.items()} - return float(interp_input) - - def get_interpolated_var_single_axis( interpolated_var_single_axis_input: interpolated_param.InterpolatedVarSingleAxisInput, ) -> interpolated_param.InterpolatedVarSingleAxis: @@ -100,40 +81,17 @@ def get_interpolated_var_single_axis( Args: interpolated_var_single_axis_input: Input that can be used to construct a `interpolated_param.InterpolatedVarSingleAxis` object. Can be either: - Python primitives, an xr.DataArray, a tuple(axis_array, values_array). - See torax.readthedocs.io/en/latest/configuration.html#time-varying-scalars - for more information on the supported inputs. + Python primitives, an xr.DataArray, a tuple(axis_array, values_array). See + torax.readthedocs.io/en/latest/configuration.html#time-varying-scalars for + more information on the supported inputs. Returns: A constructed interpolated var. """ - interpolation_mode = interpolated_param.InterpolationMode.PIECEWISE_LINEAR - # The param is a InterpolatedVarSingleAxisInput, so we need to convert it to - # an InterpolatedVarSingleAxis first. - if isinstance(interpolated_var_single_axis_input, tuple): - if len(interpolated_var_single_axis_input) != 2: - raise ValueError( - 'Single axis interpolated var tuple length must be 2. The first ' - 'element are the values and the second element is the ' - 'interpolation mode or both values should be arrays to be directly ' - f'interpolated. Given: {interpolated_var_single_axis_input}.' + xs, ys, interpolation_mode, is_bool_param = ( + interpolated_param.convert_input_to_xs_ys( + interpolated_var_single_axis_input ) - if isinstance(interpolated_var_single_axis_input[1], str): - interpolation_mode = interpolated_param.InterpolationMode[ - interpolated_var_single_axis_input[1].upper() - ] - interpolated_var_single_axis_input = interpolated_var_single_axis_input[0] - - if _is_bool(interpolated_var_single_axis_input): - interpolated_var_single_axis_input = _convert_value_to_floats( - interpolated_var_single_axis_input - ) - is_bool_param = True - else: - is_bool_param = False - - xs, ys = interpolated_param.convert_input_to_xs_ys( - interpolated_var_single_axis_input ) interpolated_var_single_axis = interpolated_param.InterpolatedVarSingleAxis( @@ -205,11 +163,12 @@ def _load_from_primitives( if not primitive_values: raise ValueError('Values mapping must not be empty.') - primitive_values = { - t: interpolated_param.convert_input_to_xs_ys(v) - for t, v in primitive_values.items() - } - return primitive_values + loaded_values = {} + for t, v in primitive_values.items(): + x, y, _, _ = interpolated_param.convert_input_to_xs_ys(v) + loaded_values[t] = (x, y) + + return loaded_values def _load_from_xr_array( diff --git a/torax/interpolated_param.py b/torax/interpolated_param.py index dae4971e..da5d1fad 100644 --- a/torax/interpolated_param.py +++ b/torax/interpolated_param.py @@ -246,11 +246,63 @@ def rhonorm1_defined_in_timerhoinput( return True +def _is_bool( + interp_input: InterpolatedVarSingleAxisInput, +) -> bool: + if isinstance(interp_input, dict): + if not interp_input: + raise ValueError('InterpolatedVarSingleAxisInput must include values.') + value = list(interp_input.values())[0] + return isinstance(value, bool) + return isinstance(interp_input, bool) + + +def _convert_value_to_floats( + interp_input: InterpolatedVarSingleAxisInput, +) -> InterpolatedVarSingleAxisInput: + if isinstance(interp_input, dict): + return {key: float(value) for key, value in interp_input.items()} + return float(interp_input) + + def convert_input_to_xs_ys( interp_input: InterpolatedVarSingleAxisInput, -) -> tuple[chex.Array, chex.Array]: - """Converts config inputs into inputs suitable for constructors.""" +) -> tuple[chex.Array, chex.Array, InterpolationMode, bool]: + """Converts config inputs into inputs suitable for constructors. + + Args: + interp_input: The input to convert. + + Returns: + A tuple of (xs, ys, interpolation_mode, is_bool_param) where xs and ys are + the arrays to be used in the constructor, interpolation_mode is the + interpolation mode to be used, and is_bool_param is True if the input is a + bool and False otherwise. + """ # This function does NOT need to be jittable. + interpolation_mode = InterpolationMode.PIECEWISE_LINEAR + # The param is a InterpolatedVarSingleAxisInput, so we need to convert it to + # an InterpolatedVarSingleAxis first. + if isinstance(interp_input, tuple): + if len(interp_input) != 2: + raise ValueError( + 'Single axis interpolated var tuple length must be 2. The first ' + 'element are the values and the second element is the ' + 'interpolation mode or both values should be arrays to be directly ' + f'interpolated. Given: {interp_input}.' + ) + if isinstance(interp_input[1], str): + interpolation_mode = InterpolationMode[interp_input[1].upper()] + interp_input = interp_input[0] + + if _is_bool(interp_input): + interp_input = _convert_value_to_floats( + interp_input + ) + is_bool_param = True + else: + is_bool_param = False + if isinstance(interp_input, xr.DataArray): if len(interp_input.coords) != 1: raise ValueError( @@ -265,6 +317,8 @@ def convert_input_to_xs_ys( return ( interp_input[index].data, interp_input.values, + interpolation_mode, + is_bool_param, ) if isinstance(interp_input, tuple): if len(interp_input) != 2: @@ -276,33 +330,26 @@ def convert_input_to_xs_ys( sort_order = np.argsort(xs) xs = xs[sort_order] ys = ys[sort_order] - return np.asarray(xs), np.asarray(ys) + return np.asarray(xs), np.asarray(ys), interpolation_mode, is_bool_param if isinstance(interp_input, dict): if not interp_input: raise ValueError('InterpolatedVarSingleAxisInput must include values.') sorted_keys = sorted(interp_input.keys()) values = [interp_input[key] for key in sorted_keys] - return np.array(sorted_keys), np.array(values) + return ( + np.array(sorted_keys), + np.array(values), + interpolation_mode, + is_bool_param, + ) else: # The input is a single value. - return np.array([0]), np.array([interp_input]) - - -def _is_bool(interp_input: InterpolatedVarSingleAxisInput) -> bool: - if isinstance(interp_input, dict): - if not interp_input: - raise ValueError('InterpolatedVarSingleAxisInput must include values.') - value = list(interp_input.values())[0] - return isinstance(value, bool) - return isinstance(interp_input, bool) - - -def convert_value_to_floats( - interp_input: InterpolatedVarSingleAxisInput, -) -> InterpolatedVarSingleAxisInput: - if isinstance(interp_input, dict): - return {key: float(value) for key, value in interp_input.items()} - return float(interp_input) + return ( + np.array([0]), + np.array([interp_input]), + interpolation_mode, + is_bool_param, + ) class InterpolatedVarSingleAxis(InterpolatedParamBase): diff --git a/torax/tests/interpolated_param.py b/torax/tests/interpolated_param.py index 67a65289..5c975559 100644 --- a/torax/tests/interpolated_param.py +++ b/torax/tests/interpolated_param.py @@ -450,9 +450,9 @@ def test_interpolated_var_time_rho(self, values, x, y, expected_output): ) def test_convert_input_to_xs_ys(self, values, expected_output): """Test input conversion to numpy arrays.""" - output = interpolated_param.convert_input_to_xs_ys(values) - np.testing.assert_allclose(output[0], expected_output[0]) - np.testing.assert_allclose(output[1], expected_output[1]) + _, _, x, y = interpolated_param.convert_input_to_xs_ys(values) + np.testing.assert_allclose(x, expected_output[0]) + np.testing.assert_allclose(y, expected_output[1]) @parameterized.parameters( interpolated_param.InterpolationMode.PIECEWISE_LINEAR,