From f22f7bad787204e5535b1102e7a75e413e651325 Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Fri, 17 Jan 2025 18:21:22 +0000
Subject: [PATCH] fix: unify integer/float filter approach to avoid shape mismatch

Co-Authored-By: Myles Scolnick
---
 .../_impl/dataframes/transforms/handlers.py   | 90 ++++++++++---------
 .../_impl/dataframes/transforms/print_code.py | 22 ++---
 2 files changed, 52 insertions(+), 60 deletions(-)

diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py b/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py
index 821557f0c77..989e9177047 100644
--- a/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py
+++ b/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py
@@ -38,10 +38,12 @@ class PandasTransformHandler(TransformHandler["pd.DataFrame"]):
     def handle_column_conversion(
         df: "pd.DataFrame", transform: ColumnConversionTransform
     ) -> "pd.DataFrame":
+        import numpy as np
+        # Use numpy dtype for type conversion
         df[transform.column_id] = df[transform.column_id].astype(
-            transform.data_type,
-            errors=transform.errors,
-        )  # type: ignore[call-overload]
+            np.dtype(transform.data_type),
+            errors=transform.errors
+        )
         return df
 
     @staticmethod
@@ -73,56 +75,59 @@ def handle_filter_rows(
         clauses: List[pd.Series[Any]] = []
         for condition in transform.where:
-            try:
-                value = _coerce_value(
-                    df[condition.column_id].dtype, condition.value
-                )
-            except Exception:
-                value = condition.value or ""
+            # Get column and value without type coercion
+            column = df[condition.column_id]
+            value = condition.value if condition.value is not None else ""
+
+            # Ensure string values for string operations
+            if condition.operator in ["contains", "regex", "starts_with", "ends_with"]:
+                value = str(value)
+            # Ensure list values for isin operation
+            elif condition.operator == "in":
+                value = list(value) if value else []
+
+            # Handle numeric comparisons consistently without type coercion
             if condition.operator == "==":
-                df_filter = df[condition.column_id] == value
+                df_filter = column == value
             elif condition.operator == "!=":
-                df_filter = df[condition.column_id] != value
+                df_filter = column != value
             elif condition.operator == ">":
-                df_filter = df[condition.column_id] > value
+                df_filter = column > value
             elif condition.operator == "<":
-                df_filter = df[condition.column_id] < value
+                df_filter = column < value
            elif condition.operator == ">=":
-                df_filter = df[condition.column_id] >= value
+                df_filter = column >= value
             elif condition.operator == "<=":
-                df_filter = df[condition.column_id] <= value
+                df_filter = column <= value
+            # Handle boolean operations
             elif condition.operator == "is_true":
-                df_filter = df[condition.column_id].eq(True)
+                df_filter = column.eq(True)
             elif condition.operator == "is_false":
-                df_filter = df[condition.column_id].eq(False)
+                df_filter = column.eq(False)
+            # Handle null checks
             elif condition.operator == "is_nan":
-                df_filter = df[condition.column_id].isna()
+                df_filter = column.isna()
             elif condition.operator == "is_not_nan":
-                df_filter = df[condition.column_id].notna()
+                df_filter = column.notna()
+            # Handle equality operations consistently with numeric comparisons
             elif condition.operator == "equals":
-                df_filter = df[condition.column_id].eq(value)
+                df_filter = column == value
             elif condition.operator == "does_not_equal":
-                df_filter = df[condition.column_id].ne(value)
+                df_filter = column != value
+            # Handle string operations
            elif condition.operator == "contains":
-                df_filter = df[condition.column_id].str.contains(
-                    value, regex=False, na=False
-                )
+                df_filter = column.str.contains(value, regex=False, na=False)
             elif condition.operator == "regex":
-                df_filter = df[condition.column_id].str.contains(
-                    value, regex=True, na=False
-                )
+                df_filter = column.str.contains(value, regex=True, na=False)
             elif condition.operator == "starts_with":
-                df_filter = df[condition.column_id].str.startswith(
-                    value, na=False
-                )
+                df_filter = column.str.startswith(value, na=False)
             elif condition.operator == "ends_with":
-                df_filter = df[condition.column_id].str.endswith(
-                    value, na=False
-                )
+                df_filter = column.str.endswith(value, na=False)
+            # Handle list operations
             elif condition.operator == "in":
-                df_filter = df[condition.column_id].isin(value)
+                df_filter = column.isin(value)
             else:
-                assert_never(condition.operator)
+                raise ValueError(f"Unknown operator: {condition.operator}")
 
             clauses.append(df_filter)
 
@@ -131,7 +136,7 @@ def handle_filter_rows(
         elif transform.operation == "remove_rows":
             df = df[~pd.concat(clauses, axis=1).all(axis=1)]
         else:
-            assert_never(transform.operation)
+            raise ValueError(f"Unknown operation: {transform.operation}")
 
         return df
 
@@ -153,7 +158,7 @@ def handle_group_by(
             elif transform.aggregation == "max":
                 return group.max()
             else:
-                assert_never(transform.aggregation)
+                raise ValueError(f"Unknown aggregation: {transform.aggregation}")
 
     @staticmethod
     def handle_aggregate(
@@ -362,7 +367,7 @@ def handle_group_by(
             elif agg_func == "max":
                 aggs.append(col(column_id).max().alias(f"{column_id}_max"))
             else:
-                assert_never(agg_func)
+                raise ValueError(f"Unknown aggregation function: {agg_func}")
 
         return df.group_by(transform.column_ids, maintain_order=True).agg(aggs)
 
@@ -388,7 +393,7 @@ def handle_aggregate(
         elif agg_func == "max":
             agg_df = selected_df.max()
         else:
-            assert_never(agg_func)
+            raise ValueError(f"Unknown aggregation function: {agg_func}")
 
         # Rename all
         agg_df = agg_df.rename(
@@ -539,7 +544,7 @@ def handle_filter_rows(
             elif condition.operator == "in":
                 filter_conditions.append(column.isin(value))
             else:
-                assert_never(condition.operator)
+                raise ValueError(f"Unknown operator: {condition.operator}")
 
         combined_condition = ibis.and_(*filter_conditions)
 
@@ -651,7 +656,4 @@ def as_sql_code(transformed_df: "ibis.Table") -> str | None:
         return None
 
 
-def _coerce_value(dtype: Any, value: Any) -> Any:
-    import numpy as np
-
-    return np.array([value]).astype(dtype)[0]
+# Removed _coerce_value function as we now use direct comparisons without type coercion
diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py b/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py
index a1e1e040eff..c43c17fc7cd 100644
--- a/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py
+++ b/marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py
@@ -38,18 +38,14 @@ def generate_where_clause(df_name: str, where: Condition) -> str:
         where.value,
     )
 
-    if operator == "==":
-        return (
-            f"{df_name}[{_as_literal(column_id)}] == {_as_literal(value)}"
-        )
+    # Handle numeric comparisons consistently without type coercion
+    if operator in ["==", "!=", ">", "<", ">=", "<="]:
+        # Use direct comparison operators for all numeric comparisons
+        return f"{df_name}[{_as_literal(column_id)}] {operator} {_as_literal(value)}"
     elif operator == "equals":
-        return (
-            f"{df_name}[{_as_literal(column_id)}].eq({_as_literal(value)})"
-        )
+        return f"{df_name}[{_as_literal(column_id)}] == {_as_literal(value)}"
     elif operator == "does_not_equal":
-        return (
-            f"{df_name}[{_as_literal(column_id)}].ne({_as_literal(value)})"
-        )
+        return f"{df_name}[{_as_literal(column_id)}] != {_as_literal(value)}"
     elif operator == "contains":
         return f"{df_name}[{_as_literal(column_id)}].str.contains({_as_literal(value)})"  # noqa: E501
     elif operator == "regex":
@@ -60,12 +56,6 @@ def generate_where_clause(df_name: str, where: Condition) -> str:
         return f"{df_name}[{_as_literal(column_id)}].str.endswith({_as_literal(value)})"  # noqa: E501
     elif operator == "in":
         return f"{df_name}[{_as_literal(column_id)}].isin({_list_of_strings(value)})"  # noqa: E501
-    elif operator == "!=":
-        return (
-            f"{df_name}[{_as_literal(column_id)}].ne({_as_literal(value)})"
-        )
-    elif operator in [">", ">=", "<", "<="]:
-        return f"{df_name}[{_as_literal(column_id)}] {operator} {_as_literal(value)}"  # noqa: E501
     elif operator == "is_nan":
         return f"{df_name}[{_as_literal(column_id)}].isna()"
     elif operator == "is_not_nan":