Skip to content

Commit

Permalink
feat: consistent selector parameters (#983)
Browse files Browse the repository at this point in the history
### Summary of Changes

- Name all parameters `selector` that select a subset of the columns of
a `Table` by name or (later) by a `ColumnSelector`.
- Parameters of `Table.join` and `Table.to_tabular_dataset` are
deliberately unchanged, since I want users to always specify column
names here explicitly.
- Remove the ability to pass a lambda predicate to `select_columns`. 
  - This was inconsistent with the other `selector` parameters.
- It was also quite slow, since we sequentially looped over the columns.
The upcoming `ColumnSelector`s will cover common cases and be more
performant.
- For all other cases, simply use `Table.to_columns`, a list
comprehension, and `Table.from_columns`,
  • Loading branch information
lars-reimann authored Jan 14, 2025
1 parent 2db9069 commit dc4640b
Show file tree
Hide file tree
Showing 33 changed files with 220 additions and 253 deletions.
4 changes: 2 additions & 2 deletions docs/tutorials/classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@
"source": [
"from safeds.data.tabular.transformation import SimpleImputer\n",
"\n",
"simple_imputer = SimpleImputer(column_names=[\"age\", \"fare\"], strategy=SimpleImputer.Strategy.mean())\n",
"simple_imputer = SimpleImputer(selector=[\"age\", \"fare\"], strategy=SimpleImputer.Strategy.mean())\n",
"fitted_simple_imputer_train, transformed_train_data = simple_imputer.fit_and_transform(train_table)\n",
"transformed_test_data = fitted_simple_imputer_train.transform(test_table)"
]
Expand Down Expand Up @@ -241,7 +241,7 @@
"from safeds.data.tabular.transformation import OneHotEncoder\n",
"\n",
"fitted_one_hot_encoder_train, transformed_train_data = OneHotEncoder(\n",
" column_names=[\"sex\", \"port_embarked\"],\n",
" selector=[\"sex\", \"port_embarked\"],\n",
").fit_and_transform(transformed_train_data)\n",
"transformed_test_data = fitted_one_hot_encoder_train.transform(transformed_test_data)"
]
Expand Down
10 changes: 5 additions & 5 deletions docs/tutorials/data_processing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@
"source": [
"from safeds.data.tabular.transformation import SimpleImputer\n",
"\n",
"imputer = SimpleImputer(SimpleImputer.Strategy.constant(0), column_names=[\"age\", \"fare\", \"cabin\", \"port_embarked\"]).fit(\n",
"imputer = SimpleImputer(SimpleImputer.Strategy.constant(0), selector=[\"age\", \"fare\", \"cabin\", \"port_embarked\"]).fit(\n",
" titanic,\n",
")\n",
"imputer.transform(titanic_slice)"
Expand Down Expand Up @@ -583,7 +583,7 @@
"source": [
"from safeds.data.tabular.transformation import LabelEncoder\n",
"\n",
"encoder = LabelEncoder(column_names=[\"sex\", \"port_embarked\"]).fit(titanic)\n",
"encoder = LabelEncoder(selector=[\"sex\", \"port_embarked\"]).fit(titanic)\n",
"encoder.transform(titanic_slice)"
]
},
Expand Down Expand Up @@ -674,7 +674,7 @@
"source": [
"from safeds.data.tabular.transformation import OneHotEncoder\n",
"\n",
"encoder = OneHotEncoder(column_names=[\"sex\", \"port_embarked\"]).fit(titanic)\n",
"encoder = OneHotEncoder(selector=[\"sex\", \"port_embarked\"]).fit(titanic)\n",
"encoder.transform(titanic_slice)"
]
},
Expand Down Expand Up @@ -745,7 +745,7 @@
"source": [
"from safeds.data.tabular.transformation import RangeScaler\n",
"\n",
"scaler = RangeScaler(column_names=\"age\", min_=0.0, max_=1.0).fit(titanic)\n",
"scaler = RangeScaler(selector=\"age\", min_=0.0, max_=1.0).fit(titanic)\n",
"scaler.transform(titanic_slice)"
]
},
Expand Down Expand Up @@ -816,7 +816,7 @@
"source": [
"from safeds.data.tabular.transformation import StandardScaler\n",
"\n",
"scaler = StandardScaler(column_names=[\"age\", \"travel_class\"]).fit(titanic)\n",
"scaler = StandardScaler(selector=[\"age\", \"travel_class\"]).fit(titanic)\n",
"scaler.transform(titanic_slice)"
]
},
Expand Down
4 changes: 2 additions & 2 deletions src/safeds/_validation/_check_bounds_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _check_bounds(
if actual is None:
return # Skip the check if the actual value is None (i.e., not provided).

if lower_bound is None:
if lower_bound is None: # pragma: no cover
lower_bound = _OpenBound(float("-inf"))
if upper_bound is None:
upper_bound = _OpenBound(float("inf"))
Expand Down Expand Up @@ -148,7 +148,7 @@ def _to_string_as_upper_bound(self) -> str:


def _float_to_string(value: float) -> str:
if value == float("-inf"):
if value == float("-inf"): # pragma: no cover
return "-\u221e"
elif value == float("inf"):
return "\u221e"
Expand Down
18 changes: 8 additions & 10 deletions src/safeds/_validation/_check_column_is_numeric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,19 @@ def _check_column_is_numeric(

def _check_columns_are_numeric(
table_or_schema: Table | Schema,
column_names: str | list[str],
selector: str | list[str],
*,
operation: str = "do a numeric operation",
) -> None:
"""
Check if the columns with the specified names are numeric and raise an error if they are not.
Missing columns are ignored. Use `_check_columns_exist` to check for missing columns.
Check if the specified columns are numeric and raise an error if they are not. Missing columns are ignored.
Parameters
----------
table_or_schema:
The table or schema to check.
column_names:
The column names to check.
selector:
The columns to check.
operation:
The operation that is performed on the columns. This is used in the error message.
Expand All @@ -76,17 +74,17 @@ def _check_columns_are_numeric(

if isinstance(table_or_schema, Table):
table_or_schema = table_or_schema.schema
if isinstance(column_names, str):
column_names = [column_names]
if isinstance(selector, str): # pragma: no cover
selector = [selector]

if len(column_names) > 1:
if len(selector) > 1:
# Create a set for faster containment checks
known_names: Container = set(table_or_schema.column_names)
else:
known_names = table_or_schema.column_names

non_numeric_names = [
name for name in column_names if name in known_names and not table_or_schema.get_column_type(name).is_numeric
name for name in selector if name in known_names and not table_or_schema.get_column_type(name).is_numeric
]
if non_numeric_names:
message = _build_error_message(non_numeric_names, operation)
Expand Down
2 changes: 1 addition & 1 deletion src/safeds/_validation/_check_schema_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _check_schema(


def _check_types(expected_schema: Schema, actual_schema: Schema, *, check_types: _TypeCheckingMode) -> None:
if check_types == "off":
if check_types == "off": # pragma: no cover
return

mismatched_types: list[tuple[str, pl.DataType, pl.DataType]] = []
Expand Down
2 changes: 1 addition & 1 deletion src/safeds/data/labeled/containers/_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def __init__(self, column: Column) -> None:
)
# TODO: should not one-hot-encode the target. label encoding without order is sufficient. should also not
# be done automatically?
self._one_hot_encoder = OneHotEncoder(column_names=self._column_name).fit(column_as_table)
self._one_hot_encoder = OneHotEncoder(selector=self._column_name).fit(column_as_table)
self._tensor = torch.Tensor(
self._one_hot_encoder.transform(column_as_table)._data_frame.to_torch(dtype=pl.Float32),
).to(_get_device())
Expand Down
78 changes: 31 additions & 47 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def has_column(self, name: str) -> bool:

def remove_columns(
self,
names: str | list[str],
selector: str | list[str],
*,
ignore_unknown_names: bool = False,
) -> Table:
Expand All @@ -786,8 +786,8 @@ def remove_columns(
Parameters
----------
names:
The names of the columns to remove.
selector:
The columns to remove.
ignore_unknown_names:
If set to True, columns that are not present in the table will be ignored.
If set to False, an error will be raised if any of the specified columns do not exist.
Expand Down Expand Up @@ -831,18 +831,18 @@ def remove_columns(
Related
-------
- [select_columns][safeds.data.tabular.containers._table.Table.select_columns]:
Keep only a subset of the columns. This method accepts either column names, or a predicate.
Keep only a subset of the columns.
- [remove_columns_with_missing_values][safeds.data.tabular.containers._table.Table.remove_columns_with_missing_values]
- [remove_non_numeric_columns][safeds.data.tabular.containers._table.Table.remove_non_numeric_columns]
"""
if isinstance(names, str):
names = [names]
if isinstance(selector, str):
selector = [selector]

if not ignore_unknown_names:
_check_columns_exist(self, names)
_check_columns_exist(self, selector)

return Table._from_polars_lazy_frame(
self._lazy_frame.drop(names, strict=not ignore_unknown_names),
self._lazy_frame.drop(selector, strict=not ignore_unknown_names),
)

def remove_columns_with_missing_values(
Expand Down Expand Up @@ -900,7 +900,7 @@ def remove_columns_with_missing_values(
- [KNearestNeighborsImputer][safeds.data.tabular.transformation._k_nearest_neighbors_imputer.KNearestNeighborsImputer]:
Replace missing values with a value computed from the nearest neighbors.
- [select_columns][safeds.data.tabular.containers._table.Table.select_columns]:
Keep only a subset of the columns. This method accepts either column names, or a predicate.
Keep only a subset of the columns.
- [remove_columns][safeds.data.tabular.containers._table.Table.remove_columns]:
Remove columns from the table by name.
- [remove_non_numeric_columns][safeds.data.tabular.containers._table.Table.remove_non_numeric_columns]
Expand Down Expand Up @@ -955,7 +955,7 @@ def remove_non_numeric_columns(self) -> Table:
Related
-------
- [select_columns][safeds.data.tabular.containers._table.Table.select_columns]:
Keep only a subset of the columns. This method accepts either column names, or a predicate.
Keep only a subset of the columns.
- [remove_columns][safeds.data.tabular.containers._table.Table.remove_columns]:
Remove columns from the table by name.
- [remove_columns_with_missing_values][safeds.data.tabular.containers._table.Table.remove_columns_with_missing_values]
Expand Down Expand Up @@ -1113,21 +1113,17 @@ def replace_column(

def select_columns(
self,
selector: str | list[str] | Callable[[Column], bool],
selector: str | list[str],
) -> Table:
"""
Select a subset of the columns and return the result as a new table.
**Notes:**
- The original table is not modified.
- If the `selector` is a custom function, this operation must fully load the data into memory, which can be
expensive.
**Note:** The original table is not modified.
Parameters
----------
selector:
The names of the columns to keep, or a predicate that decides whether to keep a column.
The columns to keep.
Returns
-------
Expand Down Expand Up @@ -1161,23 +1157,11 @@ def select_columns(
- [remove_columns_with_missing_values][safeds.data.tabular.containers._table.Table.remove_columns_with_missing_values]
- [remove_non_numeric_columns][safeds.data.tabular.containers._table.Table.remove_non_numeric_columns]
"""
import polars as pl

# Select by predicate
if callable(selector):
return Table._from_polars_lazy_frame(
pl.LazyFrame(
[column._series for column in self.to_columns() if selector(column)],
),
)

# Select by column names
else:
_check_columns_exist(self, selector)
_check_columns_exist(self, selector)

return Table._from_polars_lazy_frame(
self._lazy_frame.select(selector),
)
return Table._from_polars_lazy_frame(
self._lazy_frame.select(selector),
)

def transform_columns(
self,
Expand Down Expand Up @@ -1611,7 +1595,7 @@ def remove_rows_by_column(
def remove_rows_with_missing_values(
self,
*,
column_names: str | list[str] | None = None,
selector: str | list[str] | None = None,
) -> Table:
"""
Remove rows that contain missing values in the specified columns and return the result as a new table.
Expand All @@ -1624,8 +1608,8 @@ def remove_rows_with_missing_values(
Parameters
----------
column_names:
The names of the columns to check. If None, all columns are checked.
selector:
The columns to check. If None, all columns are checked.
Returns
-------
Expand All @@ -1645,7 +1629,7 @@ def remove_rows_with_missing_values(
| 1 | 4 |
+-----+-----+
>>> table.remove_rows_with_missing_values(column_names=["b"])
>>> table.remove_rows_with_missing_values(selector=["b"])
+------+-----+
| a | b |
| --- | --- |
Expand All @@ -1669,18 +1653,18 @@ def remove_rows_with_missing_values(
- [remove_duplicate_rows][safeds.data.tabular.containers._table.Table.remove_duplicate_rows]
- [remove_rows_with_outliers][safeds.data.tabular.containers._table.Table.remove_rows_with_outliers]
"""
if isinstance(column_names, list) and not column_names:
if isinstance(selector, list) and not selector:
# polars panics in this case
return self

return Table._from_polars_lazy_frame(
self._lazy_frame.drop_nulls(subset=column_names),
self._lazy_frame.drop_nulls(subset=selector),
)

def remove_rows_with_outliers(
self,
*,
column_names: str | list[str] | None = None,
selector: str | list[str] | None = None,
z_score_threshold: float = 3,
) -> Table:
"""
Expand All @@ -1701,8 +1685,8 @@ def remove_rows_with_outliers(
Parameters
----------
column_names:
Names of the columns to consider. If None, all numeric columns are considered.
selector:
The columns to check. If None, all columns are checked.
z_score_threshold:
The z-score threshold for detecting outliers. Must be greater than or equal to 0.
Expand Down Expand Up @@ -1755,14 +1739,14 @@ def remove_rows_with_outliers(
lower_bound=_ClosedBound(0),
)

if column_names is None:
column_names = self.column_names
if selector is None:
selector = self.column_names

import polars as pl
import polars.selectors as cs

# polar's `all_horizontal` raises a `ComputeError` if there are no columns
selected = self._lazy_frame.select(cs.numeric() & cs.by_name(column_names))
selected = self._lazy_frame.select(cs.numeric() & cs.by_name(selector))
if not selected.collect_schema().names():
return self

Expand Down Expand Up @@ -2268,9 +2252,9 @@ def join(
right_table:
The table to join with the left table.
left_names:
Name or list of names of columns to join on in the left table.
Names of columns to join on in the left table.
right_names:
Name or list of names of columns to join on in the right table.
Names of columns to join on in the right table.
mode:
Specify which type of join you want to use.
Expand Down
Loading

0 comments on commit dc4640b

Please sign in to comment.