Skip to content

Commit

Permalink
Simplify checks of the derived fields
Browse files Browse the repository at this point in the history
  • Loading branch information
ealerskans committed Dec 13, 2024
1 parent 2856c6b commit 98673ee
Showing 1 changed file with 5 additions and 37 deletions.
42 changes: 5 additions & 37 deletions mllam_data_prep/derived_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,17 @@ def derive_variables(ds, derived_variables, chunking):
# Calculate the derived variable
derived_field = func(**kwargs)

# Check the derived field(s)
derived_field = _check_field(derived_field, derived_variable_attributes)

# Add the derived field(s) to the dataset
# Check the derived field(s) and add it to the dataset
if isinstance(derived_field, xr.DataArray):
derived_field = _check_attributes(
derived_field, derived_variable_attributes
)
ds_derived_vars[derived_field.name] = derived_field
elif isinstance(derived_field, tuple) and all(
isinstance(field, xr.DataArray) for field in derived_field
):
for field in derived_field:
field = _check_attributes(field, derived_variable_attributes)
ds_derived_vars[field.name] = field
else:
raise TypeError(
Expand Down Expand Up @@ -201,39 +202,6 @@ def _get_derived_variable_function(function_namespace):
return function


def _check_field(derived_field, derived_field_attributes):
"""
Check the derived field.
Parameters
----------
derived_field: Union[xr.DataArray, Tuple[xr.DataArray]]
The derived variable
derived_field_attributes: Dict[str, str]
Dictionary with attributes for the derived variables.
Defined in the config file.
Returns
-------
derived_field: Union[xr.DataArray, Tuple[xr.DataArray]]
The derived field
"""
if isinstance(derived_field, xr.DataArray):
derived_field = _check_attributes(derived_field, derived_field_attributes)
elif isinstance(derived_field, tuple) and all(
isinstance(field, xr.DataArray) for field in derived_field
):
for field in derived_field:
field = _check_attributes(field, derived_field_attributes)
else:
raise TypeError(
"Expected an instance of xr.DataArray or tuple(xr.DataArray),"
f" but got {type(derived_field)}."
)

return derived_field


def _check_attributes(field, field_attributes):
"""
Check the attributes of the derived variable.
Expand Down

0 comments on commit 98673ee

Please sign in to comment.