diff --git a/superset/charts/client_processing.py b/superset/charts/client_processing.py index 5a873ab8986df..66fb4d4b8384f 100644 --- a/superset/charts/client_processing.py +++ b/superset/charts/client_processing.py @@ -24,6 +24,7 @@ In order to do that, we reproduce the post-processing in Python for these chart types. """ +import logging from io import StringIO from typing import Any, Optional, TYPE_CHECKING, Union @@ -44,6 +45,9 @@ from superset.models.sql_lab import Query +logger = logging.getLogger(__name__) + + def get_column_key(label: tuple[str, ...], metrics: list[str]) -> tuple[Any, ...]: """ Sort columns when combining metrics. @@ -178,11 +182,20 @@ def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-s if rows and show_columns_total: # add subtotal for each group and overall total; we start from the # overall group, and iterate deeper into subgroups - groups = df.index + groups = df.index.copy() for level in range(df.index.nlevels): subgroups = {group[:level] for group in groups} for subgroup in subgroups: - slice_ = df.index.get_loc(subgroup) + try: + slice_ = groups.get_loc(subgroup) + except Exception: # pylint: disable=broad-except + logger.exception( + "Error getting location for subgroup %s from %s", + subgroup, + groups, + ) + raise + subtotal = pivot_v2_aggfunc_map[aggfunc]( df.iloc[slice_, :].apply(pd.to_numeric, errors="coerce"), axis=0 ) diff --git a/tests/unit_tests/charts/test_client_processing.py b/tests/unit_tests/charts/test_client_processing.py index e35229554b503..30905be063d2f 100644 --- a/tests/unit_tests/charts/test_client_processing.py +++ b/tests/unit_tests/charts/test_client_processing.py @@ -2499,3 +2499,62 @@ def test_apply_client_processing_verbose_map(session: Session): } ] } + + +def test_pivot_multi_level_index(): + """ + Pivot table with multi-level indexing. + """ + arrays = [ + ["Region1", "Region1", "Region1", "Region2", "Region2", "Region2"], + ["State1", "State1", "State2", "State3", "State3", "State4"], + ["City1", "City2", "City3", "City4", "City5", "City6"], + ] + index = pd.MultiIndex.from_tuples( + list(zip(*arrays)), + names=["Region", "State", "City"], + ) + + data = { + "Metric1": [10, 20, 30, 40, 50, 60], + "Metric2": [5, 10, 15, 20, 25, 30], + "Metric3": [None, None, None, None, None, None], + } + df = pd.DataFrame(data, index=index) + + pivoted = pivot_df( + df, + rows=["Region", "State", "City"], + columns=[], + metrics=["Metric1", "Metric2", "Metric3"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=False, + show_columns_total=True, + apply_metrics_on_rows=False, + ) + + # Sort the pivoted DataFrame to ensure deterministic output + pivoted_sorted = pivoted.sort_index() + + assert ( + pivoted_sorted.to_markdown() + == """ +| | ('Metric1',) | ('Metric2',) | ('Metric3',) | +|:----------------------------------|---------------:|---------------:|---------------:| +| ('Region1', 'State1', 'City1') | 10 | 5 | nan | +| ('Region1', 'State1', 'City2') | 20 | 10 | nan | +| ('Region1', 'State1', 'Subtotal') | 30 | 15 | 0 | +| ('Region1', 'State2', 'City3') | 30 | 15 | nan | +| ('Region1', 'State2', 'Subtotal') | 30 | 15 | 0 | +| ('Region1', 'Subtotal', '') | 60 | 30 | 0 | +| ('Region2', 'State3', 'City4') | 40 | 20 | nan | +| ('Region2', 'State3', 'City5') | 50 | 25 | nan | +| ('Region2', 'State3', 'Subtotal') | 100 | 50 | 0 | +| ('Region2', 'State4', 'City6') | 60 | 30 | nan | +| ('Region2', 'State4', 'Subtotal') | 40 | 20 | 0 | +| ('Region2', 'Subtotal', '') | 150 | 75 | 0 | +| ('Total (Sum)', '', '') | 210 | 105 | 0 | + """.strip() + )