From 40f5679f2c1773dfa5a06859e7e44535190aca00 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Fri, 17 Jan 2025 17:53:44 +0000 Subject: [PATCH 1/2] fix: unify integer/float filter approach to avoid shape mismatch Co-Authored-By: Myles Scolnick --- .../_impl/dataframes/transforms/handlers.py | 122 +++++++++++------- .../_impl/dataframes/transforms/print_code.py | 22 ++-- 2 files changed, 88 insertions(+), 56 deletions(-) diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py b/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py index 821557f0c77..43e78456c03 100644 --- a/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py +++ b/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py @@ -73,59 +73,89 @@ def handle_filter_rows( clauses: List[pd.Series[Any]] = [] for condition in transform.where: - try: - value = _coerce_value( - df[condition.column_id].dtype, condition.value + # Get column and its type + column = df[condition.column_id] + dtype = column.dtype + + # Handle string operations + if condition.operator in ["contains", "regex", "starts_with", "ends_with"]: + df_filter = column.str.contains( + str(condition.value or ""), + regex=condition.operator == "regex", + na=False + ) if condition.operator in ["contains", "regex"] else ( + column.str.startswith(str(condition.value or ""), na=False) + if condition.operator == "starts_with" + else column.str.endswith(str(condition.value or ""), na=False) ) - except Exception: - value = condition.value or "" - if condition.operator == "==": - df_filter = df[condition.column_id] == value - elif condition.operator == "!=": - df_filter = df[condition.column_id] != value - elif condition.operator == ">": - df_filter = df[condition.column_id] > value - elif condition.operator == "<": - df_filter = df[condition.column_id] < value - elif condition.operator == ">=": - df_filter = df[condition.column_id] >= value - elif condition.operator == "<=": - df_filter = df[condition.column_id] <= value - elif condition.operator == "is_true": - df_filter = df[condition.column_id].eq(True) - elif condition.operator == "is_false": - df_filter = df[condition.column_id].eq(False) + # Handle numeric comparisons + elif condition.operator in ["==", "!=", ">", "<", ">=", "<="]: + # Skip coercion for integer columns with float values + if dtype.kind == "i" and isinstance(condition.value, float): + value = condition.value + else: + try: + value = _coerce_value(dtype, condition.value) + except Exception: + value = condition.value + + if condition.operator == "==": + df_filter = column == value + elif condition.operator == "!=": + df_filter = column != value + else: + df_filter = eval(f"column {condition.operator} value") + # Handle list operations + elif condition.operator == "in": + value = condition.value if isinstance(condition.value, (list, tuple)) else [] + df_filter = column.isin(value) + # Handle boolean operations + elif condition.operator in ["is_true", "is_false"]: + df_filter = column.eq(True if condition.operator == "is_true" else False) + # Handle null checks elif condition.operator == "is_nan": - df_filter = df[condition.column_id].isna() + df_filter = column.isna() elif condition.operator == "is_not_nan": - df_filter = df[condition.column_id].notna() - elif condition.operator == "equals": - df_filter = df[condition.column_id].eq(value) - elif condition.operator == "does_not_equal": - df_filter = df[condition.column_id].ne(value) - elif condition.operator == "contains": - df_filter = df[condition.column_id].str.contains( - value, regex=False, na=False - ) - elif condition.operator == "regex": - df_filter = df[condition.column_id].str.contains( - value, regex=True, na=False - ) - elif condition.operator == "starts_with": - df_filter = df[condition.column_id].str.startswith( - value, na=False - ) - elif condition.operator == "ends_with": - df_filter = df[condition.column_id].str.endswith( - value, na=False - ) - elif condition.operator == "in": - df_filter = df[condition.column_id].isin(value) + df_filter = column.notna() + # Handle equality operations with proper type handling + elif condition.operator in ["equals", "does_not_equal"]: + # For numeric types, handle direct comparison + if dtype.kind in ["i", "f"] and isinstance(condition.value, (int, float)): + # For numeric comparisons, we can use the value directly + scalar_value = condition.value + else: + # For other types, try coercion with fallback + try: + scalar_value = _coerce_value(dtype, condition.value) + except Exception: + # Use the original value as a last resort + scalar_value = condition.value if condition.value is not None else None + + # Apply the comparison with proper scalar value + if scalar_value is not None: + df_filter = column.eq(scalar_value) if condition.operator == "equals" else column.ne(scalar_value) + else: + # Handle None values specially + df_filter = column.isna() if condition.operator == "equals" else column.notna() else: - assert_never(condition.operator) + # All valid operators should be handled above + # This is just to satisfy the type checker + from typing import get_args + + from marimo._plugins.ui._impl.dataframes.transforms.types import ( + Operator, + ) + + # Get all possible values of the Operator type + operator_values = get_args(Operator) + if condition.operator not in operator_values: + raise ValueError(f"Invalid operator: {condition.operator}") + # This line should never be reached + raise AssertionError("Unhandled operator case") clauses.append(df_filter) + if transform.operation == "keep_rows": df = df[pd.concat(clauses, axis=1).all(axis=1)] elif transform.operation == "remove_rows": diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py b/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py index a1e1e040eff..ed27e6293a2 100644 --- a/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py +++ b/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py @@ -38,17 +38,25 @@ def generate_where_clause(df_name: str, where: Condition) -> str: where.value, ) - if operator == "==": + # For numeric comparisons, check if we're comparing an integer column with a float + if operator in ["==", "!=", ">", "<", ">=", "<="]: + # Add dtype check in the generated code return ( - f"{df_name}[{_as_literal(column_id)}] == {_as_literal(value)}" + f"(lambda col: col {operator} {_as_literal(value)} " + f"if col.dtype.kind == 'i' and isinstance({_as_literal(value)}, float) " + f"else col {operator} {_as_literal(value)})({df_name}[{_as_literal(column_id)}])" ) elif operator == "equals": return ( - f"{df_name}[{_as_literal(column_id)}].eq({_as_literal(value)})" + f"(lambda col: col.eq({_as_literal(value)}) " + f"if col.dtype.kind == 'i' and isinstance({_as_literal(value)}, float) " + f"else col.eq({_as_literal(value)}))({df_name}[{_as_literal(column_id)}])" ) elif operator == "does_not_equal": return ( - f"{df_name}[{_as_literal(column_id)}].ne({_as_literal(value)})" + f"(lambda col: col.ne({_as_literal(value)}) " + f"if col.dtype.kind == 'i' and isinstance({_as_literal(value)}, float) " + f"else col.ne({_as_literal(value)}))({df_name}[{_as_literal(column_id)}])" ) elif operator == "contains": return f"{df_name}[{_as_literal(column_id)}].str.contains({_as_literal(value)})" # noqa: E501 @@ -60,12 +68,6 @@ def generate_where_clause(df_name: str, where: Condition) -> str: return f"{df_name}[{_as_literal(column_id)}].str.endswith({_as_literal(value)})" # noqa: E501 elif operator == "in": return f"{df_name}[{_as_literal(column_id)}].isin({_list_of_strings(value)})" # noqa: E501 - elif operator == "!=": - return ( - f"{df_name}[{_as_literal(column_id)}].ne({_as_literal(value)})" - ) - elif operator in [">", ">=", "<", "<="]: - return f"{df_name}[{_as_literal(column_id)}] {operator} {_as_literal(value)}" # noqa: E501 elif operator == "is_nan": return f"{df_name}[{_as_literal(column_id)}].isna()" elif operator == "is_not_nan": From 38e7c1c08883bb572154b6f6a6a3188af6389757 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Jan 2025 17:54:25 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../_impl/dataframes/transforms/handlers.py | 64 ++++++++++++++----- 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py b/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py index 43e78456c03..2b831a37549 100644 --- a/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py +++ b/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py @@ -78,15 +78,28 @@ def handle_filter_rows( dtype = column.dtype # Handle string operations - if condition.operator in ["contains", "regex", "starts_with", "ends_with"]: - df_filter = column.str.contains( - str(condition.value or ""), - regex=condition.operator == "regex", - na=False - ) if condition.operator in ["contains", "regex"] else ( - column.str.startswith(str(condition.value or ""), na=False) - if condition.operator == "starts_with" - else column.str.endswith(str(condition.value or ""), na=False) + if condition.operator in [ + "contains", + "regex", + "starts_with", + "ends_with", + ]: + df_filter = ( + column.str.contains( + str(condition.value or ""), + regex=condition.operator == "regex", + na=False, + ) + if condition.operator in ["contains", "regex"] + else ( + column.str.startswith( + str(condition.value or ""), na=False + ) + if condition.operator == "starts_with" + else column.str.endswith( + str(condition.value or ""), na=False + ) + ) ) # Handle numeric comparisons elif condition.operator in ["==", "!=", ">", "<", ">=", "<="]: @@ -107,11 +120,17 @@ def handle_filter_rows( df_filter = eval(f"column {condition.operator} value") # Handle list operations elif condition.operator == "in": - value = condition.value if isinstance(condition.value, (list, tuple)) else [] + value = ( + condition.value + if isinstance(condition.value, (list, tuple)) + else [] + ) df_filter = column.isin(value) # Handle boolean operations elif condition.operator in ["is_true", "is_false"]: - df_filter = column.eq(True if condition.operator == "is_true" else False) + df_filter = column.eq( + True if condition.operator == "is_true" else False + ) # Handle null checks elif condition.operator == "is_nan": df_filter = column.isna() @@ -120,7 +139,9 @@ def handle_filter_rows( # Handle equality operations with proper type handling elif condition.operator in ["equals", "does_not_equal"]: # For numeric types, handle direct comparison - if dtype.kind in ["i", "f"] and isinstance(condition.value, (int, float)): + if dtype.kind in ["i", "f"] and isinstance( + condition.value, (int, float) + ): # For numeric comparisons, we can use the value directly scalar_value = condition.value else: @@ -129,14 +150,26 @@ def handle_filter_rows( scalar_value = _coerce_value(dtype, condition.value) except Exception: # Use the original value as a last resort - scalar_value = condition.value if condition.value is not None else None + scalar_value = ( + condition.value + if condition.value is not None + else None + ) # Apply the comparison with proper scalar value if scalar_value is not None: - df_filter = column.eq(scalar_value) if condition.operator == "equals" else column.ne(scalar_value) + df_filter = ( + column.eq(scalar_value) + if condition.operator == "equals" + else column.ne(scalar_value) + ) else: # Handle None values specially - df_filter = column.isna() if condition.operator == "equals" else column.notna() + df_filter = ( + column.isna() + if condition.operator == "equals" + else column.notna() + ) else: # All valid operators should be handled above # This is just to satisfy the type checker @@ -155,7 +188,6 @@ def handle_filter_rows( clauses.append(df_filter) - if transform.operation == "keep_rows": df = df[pd.concat(clauses, axis=1).all(axis=1)] elif transform.operation == "remove_rows":