Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: unify integer/float filter approach to avoid shape mismatch #3485

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
154 changes: 108 additions & 46 deletions marimo/_plugins/ui/_impl/dataframes/transforms/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,56 +73,118 @@ def handle_filter_rows(

clauses: List[pd.Series[Any]] = []
for condition in transform.where:
try:
value = _coerce_value(
df[condition.column_id].dtype, condition.value
# Get column and its type
column = df[condition.column_id]
dtype = column.dtype

# Handle string operations
if condition.operator in [
"contains",
"regex",
"starts_with",
"ends_with",
]:
df_filter = (
column.str.contains(
str(condition.value or ""),
regex=condition.operator == "regex",
na=False,
)
if condition.operator in ["contains", "regex"]
else (
column.str.startswith(
str(condition.value or ""), na=False
)
if condition.operator == "starts_with"
else column.str.endswith(
str(condition.value or ""), na=False
)
)
)
except Exception:
value = condition.value or ""
if condition.operator == "==":
df_filter = df[condition.column_id] == value
elif condition.operator == "!=":
df_filter = df[condition.column_id] != value
elif condition.operator == ">":
df_filter = df[condition.column_id] > value
elif condition.operator == "<":
df_filter = df[condition.column_id] < value
elif condition.operator == ">=":
df_filter = df[condition.column_id] >= value
elif condition.operator == "<=":
df_filter = df[condition.column_id] <= value
elif condition.operator == "is_true":
df_filter = df[condition.column_id].eq(True)
elif condition.operator == "is_false":
df_filter = df[condition.column_id].eq(False)
elif condition.operator == "is_nan":
df_filter = df[condition.column_id].isna()
elif condition.operator == "is_not_nan":
df_filter = df[condition.column_id].notna()
elif condition.operator == "equals":
df_filter = df[condition.column_id].eq(value)
elif condition.operator == "does_not_equal":
df_filter = df[condition.column_id].ne(value)
elif condition.operator == "contains":
df_filter = df[condition.column_id].str.contains(
value, regex=False, na=False
)
elif condition.operator == "regex":
df_filter = df[condition.column_id].str.contains(
value, regex=True, na=False
)
elif condition.operator == "starts_with":
df_filter = df[condition.column_id].str.startswith(
value, na=False
# Handle numeric comparisons
elif condition.operator in ["==", "!=", ">", "<", ">=", "<="]:
# Skip coercion for integer columns with float values
if dtype.kind == "i" and isinstance(condition.value, float):
value = condition.value
else:
try:
value = _coerce_value(dtype, condition.value)
except Exception:
value = condition.value

if condition.operator == "==":
df_filter = column == value
elif condition.operator == "!=":
df_filter = column != value
else:
df_filter = eval(f"column {condition.operator} value")
# Handle list operations
elif condition.operator == "in":
value = (
condition.value
if isinstance(condition.value, (list, tuple))
else []
)
elif condition.operator == "ends_with":
df_filter = df[condition.column_id].str.endswith(
value, na=False
df_filter = column.isin(value)
# Handle boolean operations
elif condition.operator in ["is_true", "is_false"]:
df_filter = column.eq(
True if condition.operator == "is_true" else False
)
elif condition.operator == "in":
df_filter = df[condition.column_id].isin(value)
# Handle null checks
elif condition.operator == "is_nan":
df_filter = column.isna()
elif condition.operator == "is_not_nan":
df_filter = column.notna()
# Handle equality operations with proper type handling
elif condition.operator in ["equals", "does_not_equal"]:
# For numeric types, handle direct comparison
if dtype.kind in ["i", "f"] and isinstance(
condition.value, (int, float)
):
# For numeric comparisons, we can use the value directly
scalar_value = condition.value
else:
# For other types, try coercion with fallback
try:
scalar_value = _coerce_value(dtype, condition.value)
except Exception:
# Use the original value as a last resort
scalar_value = (
condition.value
if condition.value is not None
else None
)

# Apply the comparison with proper scalar value
if scalar_value is not None:
df_filter = (
column.eq(scalar_value)
if condition.operator == "equals"
else column.ne(scalar_value)
)
else:
# Handle None values specially
df_filter = (
column.isna()
if condition.operator == "equals"
else column.notna()
)
else:
assert_never(condition.operator)
# All valid operators should be handled above
# This is just to satisfy the type checker
from typing import get_args

from marimo._plugins.ui._impl.dataframes.transforms.types import (
Operator,
)

# Get all possible values of the Operator type
operator_values = get_args(Operator)
if condition.operator not in operator_values:
raise ValueError(f"Invalid operator: {condition.operator}")
# This line should never be reached
raise AssertionError("Unhandled operator case")

clauses.append(df_filter)

Expand Down
22 changes: 12 additions & 10 deletions marimo/_plugins/ui/_impl/dataframes/transforms/print_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,25 @@ def generate_where_clause(df_name: str, where: Condition) -> str:
where.value,
)

if operator == "==":
# For numeric comparisons, check if we're comparing an integer column with a float
if operator in ["==", "!=", ">", "<", ">=", "<="]:
# Add dtype check in the generated code
return (
f"{df_name}[{_as_literal(column_id)}] == {_as_literal(value)}"
f"(lambda col: col {operator} {_as_literal(value)} "
f"if col.dtype.kind == 'i' and isinstance({_as_literal(value)}, float) "
f"else col {operator} {_as_literal(value)})({df_name}[{_as_literal(column_id)}])"
)
elif operator == "equals":
return (
f"{df_name}[{_as_literal(column_id)}].eq({_as_literal(value)})"
f"(lambda col: col.eq({_as_literal(value)}) "
f"if col.dtype.kind == 'i' and isinstance({_as_literal(value)}, float) "
f"else col.eq({_as_literal(value)}))({df_name}[{_as_literal(column_id)}])"
)
elif operator == "does_not_equal":
return (
f"{df_name}[{_as_literal(column_id)}].ne({_as_literal(value)})"
f"(lambda col: col.ne({_as_literal(value)}) "
f"if col.dtype.kind == 'i' and isinstance({_as_literal(value)}, float) "
f"else col.ne({_as_literal(value)}))({df_name}[{_as_literal(column_id)}])"
)
elif operator == "contains":
return f"{df_name}[{_as_literal(column_id)}].str.contains({_as_literal(value)})" # noqa: E501
Expand All @@ -60,12 +68,6 @@ def generate_where_clause(df_name: str, where: Condition) -> str:
return f"{df_name}[{_as_literal(column_id)}].str.endswith({_as_literal(value)})" # noqa: E501
elif operator == "in":
return f"{df_name}[{_as_literal(column_id)}].isin({_list_of_strings(value)})" # noqa: E501
elif operator == "!=":
return (
f"{df_name}[{_as_literal(column_id)}].ne({_as_literal(value)})"
)
elif operator in [">", ">=", "<", "<="]:
return f"{df_name}[{_as_literal(column_id)}] {operator} {_as_literal(value)}" # noqa: E501
elif operator == "is_nan":
return f"{df_name}[{_as_literal(column_id)}].isna()"
elif operator == "is_not_nan":
Expand Down
Loading