From c0583072eda6845c74a901ab638b42aeeb562ea1 Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Fri, 17 Jan 2025 15:13:56 -0500 Subject: [PATCH] tests: fix flakey dataframe FilterTransform test (#3488) --- .../_impl/dataframes/transforms/handlers.py | 74 ++++++++++++------- 1 file changed, 47 insertions(+), 27 deletions(-) diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py b/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py index 821557f0c77..ebd38b79292 100644 --- a/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py +++ b/marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py @@ -73,52 +73,57 @@ def handle_filter_rows( clauses: List[pd.Series[Any]] = [] for condition in transform.where: + column: pd.Series[Any] = df[condition.column_id] + try: value = _coerce_value( df[condition.column_id].dtype, condition.value ) except Exception: value = condition.value or "" + + # Handle numeric comparisons if condition.operator == "==": - df_filter = df[condition.column_id] == value + df_filter = column == value elif condition.operator == "!=": - df_filter = df[condition.column_id] != value + df_filter = column != value elif condition.operator == ">": - df_filter = df[condition.column_id] > value + df_filter = column > value elif condition.operator == "<": - df_filter = df[condition.column_id] < value + df_filter = column < value elif condition.operator == ">=": - df_filter = df[condition.column_id] >= value + df_filter = column >= value elif condition.operator == "<=": - df_filter = df[condition.column_id] <= value + df_filter = column <= value + # Handle boolean operations elif condition.operator == "is_true": - df_filter = df[condition.column_id].eq(True) + df_filter = column.eq(True) elif condition.operator == "is_false": - df_filter = df[condition.column_id].eq(False) + df_filter = column.eq(False) + # Handle null checks elif condition.operator == "is_nan": - df_filter = df[condition.column_id].isna() + df_filter = column.isna() 
elif condition.operator == "is_not_nan": - df_filter = df[condition.column_id].notna() + df_filter = column.notna() + # Handle equality operations elif condition.operator == "equals": - df_filter = df[condition.column_id].eq(value) + df_filter = column == value elif condition.operator == "does_not_equal": - df_filter = df[condition.column_id].ne(value) + df_filter = column != value + # Handle string operations elif condition.operator == "contains": - df_filter = df[condition.column_id].str.contains( - value, regex=False, na=False + df_filter = column.str.contains( + str(value), regex=False, na=False ) elif condition.operator == "regex": - df_filter = df[condition.column_id].str.contains( - value, regex=True, na=False + df_filter = column.str.contains( + str(value), regex=True, na=False ) elif condition.operator == "starts_with": - df_filter = df[condition.column_id].str.startswith( - value, na=False - ) + df_filter = column.str.startswith(str(value), na=False) elif condition.operator == "ends_with": - df_filter = df[condition.column_id].str.endswith( - value, na=False - ) + df_filter = column.str.endswith(str(value), na=False) + # Handle list operations with proper Unicode handling elif condition.operator == "in": df_filter = df[condition.column_id].isin(value) else: @@ -488,9 +493,11 @@ def handle_sort_column( ) -> "ibis.Table": return df.order_by( [ - df[transform.column_id].asc() - if transform.ascending - else df[transform.column_id].desc() + ( + df[transform.column_id].asc() + if transform.ascending + else df[transform.column_id].desc() + ) ] ) @@ -548,7 +555,7 @@ def handle_filter_rows( elif transform.operation == "remove_rows": return df.filter(~combined_condition) else: - raise ValueError(f"Unsupported operation: {transform.operation}") + assert_never(transform.operation) @staticmethod def handle_group_by( @@ -652,6 +659,19 @@ def as_sql_code(transformed_df: "ibis.Table") -> str | None: def _coerce_value(dtype: Any, value: Any) -> Any: + """Coerce value to 
match column dtype while preserving numeric precision.""" import numpy as np - return np.array([value]).astype(dtype)[0] + # Handle None values + if value is None: + return None + + # If it's an int or float, return as is + if isinstance(value, (int, float)): + return value + + # Default coercion for other cases + try: + return np.array([value]).astype(dtype)[0] + except Exception: + return value