Skip to content

Commit

Permalink
Fix wildcard aggregation by skipping (#448)
Browse files Browse the repository at this point in the history
* Add external repo cleanup

* Return variable name string directly

* Update usage of vars_default_args

* Add a simple test

* Implement skipping region-aggregation for wildcard variables

* Require that wildcard-variables have explicit skip-region-aggregation

* Make ruff

* Add region aggregation for wildcard aggregation test

* Fix pydantic validator

---------

Co-authored-by: Philip Hackstock <20710924+phackstock@users.noreply.github.com>
  • Loading branch information
danielhuppmann and phackstock authored Dec 20, 2024
1 parent fb124b2 commit bfa80ff
Show file tree
Hide file tree
Showing 13 changed files with 124 additions and 21 deletions.
24 changes: 19 additions & 5 deletions nomenclature/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@
from keyword import iskeyword
from pathlib import Path
from typing import Any

from pyam.utils import to_list
from pydantic import (
field_validator,
field_serializer,
ConfigDict,
BaseModel,
ConfigDict,
Field,
ValidationInfo,
field_serializer,
field_validator,
model_validator,
)
from typing_extensions import Self

from nomenclature.error import ErrorCollector

from pyam.utils import to_list

from .countries import countries


Expand Down Expand Up @@ -208,6 +210,14 @@ def deserialize_json(cls, v):
def convert_none_to_empty_string(cls, v):
return v if v is not None else ""

@model_validator(mode="after")
def wildcard_must_skip_region_aggregation(self) -> Self:
if self.is_wildcard and self.skip_region_aggregation is False:
raise ValueError(
f"Wildcard variable '{self.name}' must skip region aggregation"
)
return self

@field_validator("components", mode="before")
def cast_variable_components_args(cls, v):
"""Cast "components" list of dicts to a codelist"""
Expand All @@ -224,6 +234,10 @@ def cast_variable_components_args(cls, v):
def convert_str_to_none_for_writing(self, v):
return v if v != "" else None

@property
def is_wildcard(self) -> bool:
return "*" in self.name

@property
def units(self) -> list[str]:
return self.unit if isinstance(self.unit, list) else [self.unit]
Expand Down
15 changes: 9 additions & 6 deletions nomenclature/codelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,13 +595,15 @@ def check_weight_in_vars(cls, v):
)
return v

def vars_default_args(self, variables: list[str]) -> list[VariableCode]:
def vars_default_args(self, variables: list[str]) -> list[str]:
"""return subset of variables which does not feature any special pyam
aggregation arguments and where skip_region_aggregation is False"""
return [
self[var]
var
for var in variables
if not self[var].agg_kwargs and not self[var].skip_region_aggregation
if var in self.keys()
and not self[var].agg_kwargs
and not self[var].skip_region_aggregation
]

def vars_kwargs(self, variables: list[str]) -> list[VariableCode]:
Expand All @@ -610,7 +612,9 @@ def vars_kwargs(self, variables: list[str]) -> list[VariableCode]:
return [
self[var]
for var in variables
if self[var].agg_kwargs and not self[var].skip_region_aggregation
if var in self.keys()
and self[var].agg_kwargs
and not self[var].skip_region_aggregation
]

def validate_units(
Expand All @@ -621,8 +625,7 @@ def validate_units(
if invalid_units := [
(variable, unit, self.mapping[variable].unit)
for variable, unit in unit_mapping.items()
if variable in self.variables
and unit not in self.mapping[variable].units
if variable in self.variables and unit not in self.mapping[variable].units
]:
lst = [
f"'{v}' - expected: {'one of ' if isinstance(e, list) else ''}"
Expand Down
3 changes: 3 additions & 0 deletions nomenclature/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ def check_aggregate(self, df: IamDataFrame, **kwargs) -> None:

with adjust_log_level(level="WARNING"):
for code in df.variable:
if code not in self.variable.mapping:
continue

attr = self.variable.mapping[code]
if attr.check_aggregate:
components = attr.components
Expand Down
2 changes: 1 addition & 1 deletion nomenclature/processor/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ def _apply_region_processing(

# first, perform 'simple' aggregation (no arguments)
simple_vars = [
var.name
var
for var in self.variable_codelist.vars_default_args(
model_df.variable
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
model: m_a
model: model_a
native_regions:
- region_A
- region_B
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- common:
- World
- model_native:
- region_A
- region_B
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- Capital Cost|Electricity|*:
definition: Capital cost of electricity generation for a specific technology
unit: USD_2010/kW
skip-region-aggregation: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
model: model_a
native_regions:
- region_A
- region_B
common_regions:
- World:
- region_A
- region_B
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
model: model_b
native_regions:
- region_A
- region_b: region_B
common_regions:
- World:
- region_A
- region_b
6 changes: 6 additions & 0 deletions tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def test_variable_multiple_units():
assert var.unit == ["unit1", "unit2"]


def test_variable_wildcard_skip_region_aggregation_required():
"""Test that a VariableCode with wildcard must have skip_region_aggregation: True"""
with raises(ValueError, match="Wildcard variable 'Var1\*' must skip region"):
VariableCode.from_dict({"Var1*": {"unit": "unit1"}})


@pytest.mark.parametrize("unit", ["Test unit", ["Test unit 1", "Test unit 2"]])
def test_set_attributes_with_json(unit):
var = VariableCode(
Expand Down
13 changes: 10 additions & 3 deletions tests/test_codelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,16 @@ def test_illegal_char_ignores_external_repo():
"""Check that external repos are excluded from this check."""
# the config includes illegal characters known to be in common-definitions
# the test will not raise errors as the check is skipped for external repos
DataStructureDefinition(
MODULE_TEST_DATA_DIR / "illegal_chars" / "char_in_external_repo" / "definitions"
)

try:
dsd = DataStructureDefinition(
MODULE_TEST_DATA_DIR
/ "illegal_chars"
/ "char_in_external_repo"
/ "definitions"
)
finally:
clean_up_external_repos(dsd.config.repositories)


def test_end_whitespace_fails():
Expand Down
49 changes: 47 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ def test_region_processing_weighted_aggregation(folder, exp_df, args, caplog):
)
def test_region_processing_skip_aggregation(model_name, region_names):
# Testing two cases:
# * model "m_a" renames native regions and the world region is skipped
# * model "m_b" renames single constituent common regions
# * model "model_a" renames native regions and the world region is skipped
# * model "model_b" renames single constituent common regions

test_df = IamDataFrame(
pd.DataFrame(
Expand Down Expand Up @@ -296,6 +296,51 @@ def test_region_processing_skip_aggregation(model_name, region_names):
assert_iamframe_equal(obs, exp)


@pytest.mark.parametrize(
"model_name, region_names",
[("model_a", ("region_A", "region_B")), ("model_b", ("region_A", "region_b"))],
)
def test_region_processing_wildcard_skip_aggregation(model_name, region_names):
# Testing two cases:
# * model "model_a" keeps native regions as they are
# * model "model_b" renames one native region

variable = "Capital Cost|Electricity|Solar PV"
unit = "USD_2010/kW"
test_df = IamDataFrame(
pd.DataFrame(
[
[model_name, "s_a", region_names[0], variable, unit, 1, 2],
[model_name, "s_a", region_names[1], variable, unit, 3, 4],
],
columns=IAMC_IDX + [2005, 2010],
)
)
add_meta(test_df)

exp = IamDataFrame(
pd.DataFrame(
[
[model_name, "s_a", "region_A", variable, unit, 1, 2],
[model_name, "s_a", "region_B", variable, unit, 3, 4],
],
columns=IAMC_IDX + [2005, 2010],
)
)
add_meta(exp)

obs = process(
test_df,
dsd := DataStructureDefinition(
TEST_DATA_DIR / "region_processing/wildcard_skip_aggregation/dsd"
),
processor=RegionProcessor.from_directory(
TEST_DATA_DIR / "region_processing/wildcard_skip_aggregation/mappings", dsd
),
)
assert_iamframe_equal(obs, exp)


@pytest.mark.parametrize(
"input_data, exp_data, warning",
[
Expand Down
6 changes: 3 additions & 3 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def test_validation_fails_region(simple_definition, simple_df, caplog):
def test_validation_multiple_units(extras_definition, simple_df):
"""Validating against a VariableCode with multiple units works as expected"""
extras_definition.validate(
simple_df
.filter(variable="Primary Energy|Coal")
.rename(unit={"EJ/yr": "GWh/yr"})
simple_df.filter(variable="Primary Energy|Coal").rename(
unit={"EJ/yr": "GWh/yr"}
)
)


Expand Down

0 comments on commit bfa80ff

Please sign in to comment.