Skip to content

Commit

Permalink
Move mode parsing logic for single axis interpolated variables into m…
Browse files Browse the repository at this point in the history
…ethod that parses values.

It's useful to keep these together for being able to sanity check inputs without constructing the `InterpolatedVarSingleAxis`.

PiperOrigin-RevId: 713679682
  • Loading branch information
Nush395 authored and Torax team committed Jan 9, 2025
1 parent 847c3b0 commit ca1b69a
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 78 deletions.
65 changes: 12 additions & 53 deletions torax/config/config_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,6 @@ def _check(ft):
return _check(field_type) # pytype: disable=bad-return-type


def _is_bool(
interp_input: interpolated_param.InterpolatedVarSingleAxisInput,
) -> bool:
if isinstance(interp_input, dict):
if not interp_input:
raise ValueError('InterpolatedVarSingleAxisInput must include values.')
value = list(interp_input.values())[0]
return isinstance(value, bool)
return isinstance(interp_input, bool)


def _convert_value_to_floats(
interp_input: interpolated_param.InterpolatedVarSingleAxisInput,
) -> interpolated_param.InterpolatedVarSingleAxisInput:
if isinstance(interp_input, dict):
return {key: float(value) for key, value in interp_input.items()}
return float(interp_input)


def get_interpolated_var_single_axis(
interpolated_var_single_axis_input: interpolated_param.InterpolatedVarSingleAxisInput,
) -> interpolated_param.InterpolatedVarSingleAxis:
Expand All @@ -100,40 +81,17 @@ def get_interpolated_var_single_axis(
Args:
interpolated_var_single_axis_input: Input that can be used to construct a
`interpolated_param.InterpolatedVarSingleAxis` object. Can be either:
Python primitives, an xr.DataArray, a tuple(axis_array, values_array).
See torax.readthedocs.io/en/latest/configuration.html#time-varying-scalars
for more information on the supported inputs.
Python primitives, an xr.DataArray, a tuple(axis_array, values_array). See
torax.readthedocs.io/en/latest/configuration.html#time-varying-scalars for
more information on the supported inputs.
Returns:
A constructed interpolated var.
"""
interpolation_mode = interpolated_param.InterpolationMode.PIECEWISE_LINEAR
# The param is a InterpolatedVarSingleAxisInput, so we need to convert it to
# an InterpolatedVarSingleAxis first.
if isinstance(interpolated_var_single_axis_input, tuple):
if len(interpolated_var_single_axis_input) != 2:
raise ValueError(
'Single axis interpolated var tuple length must be 2. The first '
'element are the values and the second element is the '
'interpolation mode or both values should be arrays to be directly '
f'interpolated. Given: {interpolated_var_single_axis_input}.'
xs, ys, interpolation_mode, is_bool_param = (
interpolated_param.convert_input_to_xs_ys(
interpolated_var_single_axis_input
)
if isinstance(interpolated_var_single_axis_input[1], str):
interpolation_mode = interpolated_param.InterpolationMode[
interpolated_var_single_axis_input[1].upper()
]
interpolated_var_single_axis_input = interpolated_var_single_axis_input[0]

if _is_bool(interpolated_var_single_axis_input):
interpolated_var_single_axis_input = _convert_value_to_floats(
interpolated_var_single_axis_input
)
is_bool_param = True
else:
is_bool_param = False

xs, ys = interpolated_param.convert_input_to_xs_ys(
interpolated_var_single_axis_input
)

interpolated_var_single_axis = interpolated_param.InterpolatedVarSingleAxis(
Expand Down Expand Up @@ -205,11 +163,12 @@ def _load_from_primitives(
if not primitive_values:
raise ValueError('Values mapping must not be empty.')

primitive_values = {
t: interpolated_param.convert_input_to_xs_ys(v)
for t, v in primitive_values.items()
}
return primitive_values
loaded_values = {}
for t, v in primitive_values.items():
x, y, _, _ = interpolated_param.convert_input_to_xs_ys(v)
loaded_values[t] = (x, y)

return loaded_values


def _load_from_xr_array(
Expand Down
91 changes: 69 additions & 22 deletions torax/interpolated_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,63 @@ def rhonorm1_defined_in_timerhoinput(
return True


def _is_bool(
interp_input: InterpolatedVarSingleAxisInput,
) -> bool:
if isinstance(interp_input, dict):
if not interp_input:
raise ValueError('InterpolatedVarSingleAxisInput must include values.')
value = list(interp_input.values())[0]
return isinstance(value, bool)
return isinstance(interp_input, bool)


def _convert_value_to_floats(
interp_input: InterpolatedVarSingleAxisInput,
) -> InterpolatedVarSingleAxisInput:
if isinstance(interp_input, dict):
return {key: float(value) for key, value in interp_input.items()}
return float(interp_input)


def convert_input_to_xs_ys(
interp_input: InterpolatedVarSingleAxisInput,
) -> tuple[chex.Array, chex.Array]:
"""Converts config inputs into inputs suitable for constructors."""
) -> tuple[chex.Array, chex.Array, InterpolationMode, bool]:
"""Converts config inputs into inputs suitable for constructors.
Args:
interp_input: The input to convert.
Returns:
A tuple of (xs, ys, interpolation_mode, is_bool_param) where xs and ys are
the arrays to be used in the constructor, interpolation_mode is the
interpolation mode to be used, and is_bool_param is True if the input is a
bool and False otherwise.
"""
# This function does NOT need to be jittable.
interpolation_mode = InterpolationMode.PIECEWISE_LINEAR
# The param is a InterpolatedVarSingleAxisInput, so we need to convert it to
# an InterpolatedVarSingleAxis first.
if isinstance(interp_input, tuple):
if len(interp_input) != 2:
raise ValueError(
'Single axis interpolated var tuple length must be 2. The first '
'element are the values and the second element is the '
'interpolation mode or both values should be arrays to be directly '
f'interpolated. Given: {interp_input}.'
)
if isinstance(interp_input[1], str):
interpolation_mode = InterpolationMode[interp_input[1].upper()]
interp_input = interp_input[0]

if _is_bool(interp_input):
interp_input = _convert_value_to_floats(
interp_input
)
is_bool_param = True
else:
is_bool_param = False

if isinstance(interp_input, xr.DataArray):
if len(interp_input.coords) != 1:
raise ValueError(
Expand All @@ -265,6 +317,8 @@ def convert_input_to_xs_ys(
return (
interp_input[index].data,
interp_input.values,
interpolation_mode,
is_bool_param,
)
if isinstance(interp_input, tuple):
if len(interp_input) != 2:
Expand All @@ -276,33 +330,26 @@ def convert_input_to_xs_ys(
sort_order = np.argsort(xs)
xs = xs[sort_order]
ys = ys[sort_order]
return np.asarray(xs), np.asarray(ys)
return np.asarray(xs), np.asarray(ys), interpolation_mode, is_bool_param
if isinstance(interp_input, dict):
if not interp_input:
raise ValueError('InterpolatedVarSingleAxisInput must include values.')
sorted_keys = sorted(interp_input.keys())
values = [interp_input[key] for key in sorted_keys]
return np.array(sorted_keys), np.array(values)
return (
np.array(sorted_keys),
np.array(values),
interpolation_mode,
is_bool_param,
)
else:
# The input is a single value.
return np.array([0]), np.array([interp_input])


def _is_bool(interp_input: InterpolatedVarSingleAxisInput) -> bool:
if isinstance(interp_input, dict):
if not interp_input:
raise ValueError('InterpolatedVarSingleAxisInput must include values.')
value = list(interp_input.values())[0]
return isinstance(value, bool)
return isinstance(interp_input, bool)


def convert_value_to_floats(
interp_input: InterpolatedVarSingleAxisInput,
) -> InterpolatedVarSingleAxisInput:
if isinstance(interp_input, dict):
return {key: float(value) for key, value in interp_input.items()}
return float(interp_input)
return (
np.array([0]),
np.array([interp_input]),
interpolation_mode,
is_bool_param,
)


class InterpolatedVarSingleAxis(InterpolatedParamBase):
Expand Down
6 changes: 3 additions & 3 deletions torax/tests/interpolated_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,9 @@ def test_interpolated_var_time_rho(self, values, x, y, expected_output):
)
def test_convert_input_to_xs_ys(self, values, expected_output):
"""Test input conversion to numpy arrays."""
output = interpolated_param.convert_input_to_xs_ys(values)
np.testing.assert_allclose(output[0], expected_output[0])
np.testing.assert_allclose(output[1], expected_output[1])
_, _, x, y = interpolated_param.convert_input_to_xs_ys(values)
np.testing.assert_allclose(x, expected_output[0])
np.testing.assert_allclose(y, expected_output[1])

@parameterized.parameters(
interpolated_param.InterpolationMode.PIECEWISE_LINEAR,
Expand Down

0 comments on commit ca1b69a

Please sign in to comment.