Skip to content

Commit

Permalink
tests: fix flakey dataframe FilterTransform test
Browse files Browse the repository at this point in the history
  • Loading branch information
mscolnick committed Jan 17, 2025
1 parent 75cc093 commit 876fe06
Showing 1 changed file with 47 additions and 27 deletions.
74 changes: 47 additions & 27 deletions marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,52 +73,57 @@ def handle_filter_rows(

clauses: List[pd.Series[Any]] = []
for condition in transform.where:
column: pd.Series[Any] = df[condition.column_id]

try:
value = _coerce_value(
df[condition.column_id].dtype, condition.value
)
except Exception:
value = condition.value or ""

# Handle numeric comparisons
if condition.operator == "==":
df_filter = df[condition.column_id] == value
df_filter = column == value
elif condition.operator == "!=":
df_filter = df[condition.column_id] != value
df_filter = column != value
elif condition.operator == ">":
df_filter = df[condition.column_id] > value
df_filter = column > value
elif condition.operator == "<":
df_filter = df[condition.column_id] < value
df_filter = column < value
elif condition.operator == ">=":
df_filter = df[condition.column_id] >= value
df_filter = column >= value
elif condition.operator == "<=":
df_filter = df[condition.column_id] <= value
df_filter = column <= value
# Handle boolean operations
elif condition.operator == "is_true":
df_filter = df[condition.column_id].eq(True)
df_filter = column.eq(True)
elif condition.operator == "is_false":
df_filter = df[condition.column_id].eq(False)
df_filter = column.eq(False)
# Handle null checks
elif condition.operator == "is_nan":
df_filter = df[condition.column_id].isna()
df_filter = column.isna()
elif condition.operator == "is_not_nan":
df_filter = df[condition.column_id].notna()
df_filter = column.notna()
# Handle equality operations
elif condition.operator == "equals":
df_filter = df[condition.column_id].eq(value)
df_filter = column == value
elif condition.operator == "does_not_equal":
df_filter = df[condition.column_id].ne(value)
df_filter = column != value
# Handle string operations
elif condition.operator == "contains":
df_filter = df[condition.column_id].str.contains(
value, regex=False, na=False
df_filter = column.str.contains(
str(value), regex=False, na=False
)
elif condition.operator == "regex":
df_filter = df[condition.column_id].str.contains(
value, regex=True, na=False
df_filter = column.str.contains(
str(value), regex=True, na=False
)
elif condition.operator == "starts_with":
df_filter = df[condition.column_id].str.startswith(
value, na=False
)
df_filter = column.str.startswith(str(value), na=False)
elif condition.operator == "ends_with":
df_filter = df[condition.column_id].str.endswith(
value, na=False
)
df_filter = column.str.endswith(str(value), na=False)
# Handle list operations with proper Unicode handling
elif condition.operator == "in":
df_filter = df[condition.column_id].isin(value)
else:
Expand Down Expand Up @@ -488,9 +493,11 @@ def handle_sort_column(
) -> "ibis.Table":
return df.order_by(
[
df[transform.column_id].asc()
if transform.ascending
else df[transform.column_id].desc()
(
df[transform.column_id].asc()
if transform.ascending
else df[transform.column_id].desc()
)
]
)

Expand Down Expand Up @@ -548,7 +555,7 @@ def handle_filter_rows(
elif transform.operation == "remove_rows":
return df.filter(~combined_condition)
else:
raise ValueError(f"Unsupported operation: {transform.operation}")
assert_never(transform.operation)

@staticmethod
def handle_group_by(
Expand Down Expand Up @@ -652,6 +659,19 @@ def as_sql_code(transformed_df: "ibis.Table") -> str | None:


def _coerce_value(dtype: Any, value: Any) -> Any:
"""Coerce value to match column dtype while preserving numeric precision."""
import numpy as np

return np.array([value]).astype(dtype)[0]
# Handle None/empty values
if value is None:
return None

# If its a int or float, return as is
if isinstance(value, (int, float)):
return value

# Default coercion for other cases
try:
return np.array([value]).astype(dtype)[0]
except Exception:
return value

0 comments on commit 876fe06

Please sign in to comment.