Skip to content

Commit

Permalink
feat: transform multiple columns of Table at once (#982)
Browse files Browse the repository at this point in the history
### Summary of Changes

* Rename `transform_column` to `transform_columns`
* Rename parameter `name` to `selector`
* `selector` can now be a list of column names
* `transformer` can now optionally have a second parameter to receive
the entire row
  • Loading branch information
lars-reimann authored Jan 14, 2025
1 parent 38dc89c commit 2db9069
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 99 deletions.
6 changes: 3 additions & 3 deletions benchmarks/table/row_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def _run_split_rows() -> None:
table_2._lazy_frame.collect()


def _run_transform_column() -> None:
table.transform_column("column_0", lambda value: value * 2)._lazy_frame.collect()
def _run_transform_columns() -> None:
table.transform_columns("column_0", lambda value: value * 2)._lazy_frame.collect()


if __name__ == "__main__":
Expand Down Expand Up @@ -101,7 +101,7 @@ def _run_transform_column() -> None:
number=REPETITIONS,
),
"transform_column": timeit(
_run_transform_column,
_run_transform_columns,
number=REPETITIONS,
),
}
Expand Down
4 changes: 1 addition & 3 deletions docs/tutorials/data_processing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -886,9 +886,7 @@
"output_type": "execute_result"
}
],
"source": [
"titanic_slice.transform_column(\"parents_children\", lambda cell: cell > 0)"
]
"source": "titanic_slice.transform_columns(\"parents_children\", lambda cell: cell > 0)"
}
],
"metadata": {
Expand Down
2 changes: 2 additions & 0 deletions src/safeds/_validation/_check_column_has_no_missing_values.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""The module name must differ from the function name, so it can be re-exported properly with apipkg."""

from __future__ import annotations

from typing import TYPE_CHECKING
Expand Down
16 changes: 8 additions & 8 deletions src/safeds/_validation/_check_columns_exist_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
from safeds.data.tabular.typing import Schema


def _check_columns_exist(table_or_schema: Table | Schema, requested_names: str | list[str]) -> None:
def _check_columns_exist(table_or_schema: Table | Schema, selector: str | list[str]) -> None:
"""
Check whether the specified column names exist, and raise an error if they do not.
Check whether the specified columns exist, and raise an error if they do not.
Parameters
----------
table_or_schema:
The table or schema to check.
requested_names:
The column names to check.
selector:
The columns to check.
Raises
------
Expand All @@ -33,16 +33,16 @@ def _check_columns_exist(table_or_schema: Table | Schema, requested_names: str |

if isinstance(table_or_schema, Table):
table_or_schema = table_or_schema.schema
if isinstance(requested_names, str):
requested_names = [requested_names]
if isinstance(selector, str):
selector = [selector]

if len(requested_names) > 1:
if len(selector) > 1:
# Create a set for faster containment checks
known_names: Container = set(table_or_schema.column_names)
else:
known_names = table_or_schema.column_names

unknown_names = [name for name in requested_names if name not in known_names]
unknown_names = [name for name in selector if name not in known_names]
if unknown_names:
message = _build_error_message(table_or_schema, unknown_names)
raise ColumnNotFoundError(message) from None
Expand Down
2 changes: 2 additions & 0 deletions src/safeds/_validation/_check_indices_module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""The module name must differ from the function name, so it can be re-exported properly with apipkg."""

from __future__ import annotations

from typing import TYPE_CHECKING
Expand Down
75 changes: 59 additions & 16 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,8 @@ def add_computed_column(
- [add_columns][safeds.data.tabular.containers._table.Table.add_columns]:
Add column objects to the table.
- [add_index_column][safeds.data.tabular.containers._table.Table.add_index_column]
- [transform_column][safeds.data.tabular.containers._table.Table.transform_column]:
Transform an existing column with a custom function.
- [transform_columns][safeds.data.tabular.containers._table.Table.transform_columns]:
Transform existing columns with a custom function.
"""
_check_columns_dont_exist(self, name)

Expand Down Expand Up @@ -1179,22 +1179,23 @@ def select_columns(
self._lazy_frame.select(selector),
)

def transform_column(
def transform_columns(
self,
name: str,
transformer: Callable[[Cell], Cell],
selector: str | list[str],
transformer: Callable[[Cell], Cell] | Callable[[Cell, Row], Cell],
) -> Table:
"""
Transform a column with a custom function and return the result as a new table.
Transform columns with a custom function and return the result as a new table.
**Note:** The original table is not modified.
Parameters
----------
name:
The name of the column to transform.
selector:
The names of the columns to transform.
transformer:
The function that computes the new values of the column.
The function that computes the new values. It may take either a single cell or a cell and the entire row as
arguments (see examples).
Returns
-------
Expand All @@ -1210,7 +1211,7 @@ def transform_column(
--------
>>> from safeds.data.tabular.containers import Table
>>> table = Table({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> table.transform_column("a", lambda cell: cell + 1)
>>> table.transform_columns("a", lambda cell: cell + 1)
+-----+-----+
| a | b |
| --- | --- |
Expand All @@ -1221,21 +1222,63 @@ def transform_column(
| 4 | 6 |
+-----+-----+
>>> table.transform_columns(["a", "b"], lambda cell: cell + 1)
+-----+-----+
| a | b |
| --- | --- |
| i64 | i64 |
+===========+
| 2 | 5 |
| 3 | 6 |
| 4 | 7 |
+-----+-----+
>>> table.transform_columns("a", lambda cell, row: cell + row["b"])
+-----+-----+
| a | b |
| --- | --- |
| i64 | i64 |
+===========+
| 5 | 4 |
| 7 | 5 |
| 9 | 6 |
+-----+-----+
Related
-------
- [add_computed_column][safeds.data.tabular.containers._table.Table.add_computed_column]:
Add a new column that is computed from other columns.
- [transform_table][safeds.data.tabular.containers._table.Table.transform_table]:
Transform the entire table with a fitted transformer.
"""
_check_columns_exist(self, name)

import polars as pl

expression = transformer(_LazyCell(pl.col(name)))
_check_columns_exist(self, selector)

if isinstance(selector, str):
selector = [selector]

parameter_count = transformer.__code__.co_argcount
if parameter_count == 1:
# Transformer only takes a cell
expressions = [
transformer( # type: ignore[call-arg]
_LazyCell(pl.col(name)),
)._polars_expression.alias(name)
for name in selector
]
else:
# Transformer takes a cell and the entire row
expressions = [
transformer( # type: ignore[call-arg]
_LazyCell(pl.col(name)),
_LazyVectorizedRow(self),
)._polars_expression.alias(name)
for name in selector
]

return Table._from_polars_lazy_frame(
self._lazy_frame.with_columns(expression._polars_expression.alias(name)),
self._lazy_frame.with_columns(*expressions),
)

# ------------------------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -2384,8 +2427,8 @@ def transform_table(self, fitted_transformer: TableTransformer) -> Table:
-------
- [inverse_transform_table][safeds.data.tabular.containers._table.Table.inverse_transform_table]:
Inverse-transform the table with a fitted, invertible transformer.
- [transform_column][safeds.data.tabular.containers._table.Table.transform_column]:
Transform a single column with a custom function.
- [transform_columns][safeds.data.tabular.containers._table.Table.transform_columns]:
Transform columns with a custom function.
"""
return fitted_transformer.transform(self)

Expand Down

This file was deleted.

111 changes: 111 additions & 0 deletions tests/safeds/data/tabular/containers/_table/test_transform_columns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from collections.abc import Callable

import pytest

from safeds.data.tabular.containers import Cell, Row, Table
from safeds.exceptions import ColumnNotFoundError


@pytest.mark.parametrize(
("table_factory", "selector", "transformer", "expected"),
[
# no rows (constant value)
(
lambda: Table({"col1": []}),
"col1",
lambda _: Cell.from_literal(None),
Table({"col1": []}),
),
# no rows (computed value)
(
lambda: Table({"col1": []}),
"col1",
lambda cell: 2 * cell,
Table({"col1": []}),
),
# non-empty (constant value)
(
lambda: Table({"col1": [1, 2]}),
"col1",
lambda _: Cell.from_literal(None),
Table({"col1": [None, None]}),
),
# non-empty (computed value)
(
lambda: Table({"col1": [1, 2]}),
"col1",
lambda cell: 2 * cell,
Table({"col1": [2, 4]}),
),
# multiple columns transformed (constant value)
(
lambda: Table({"col1": [1, 2], "col2": [3, 4]}),
["col1", "col2"],
lambda _: Cell.from_literal(None),
Table({"col1": [None, None], "col2": [None, None]}),
),
# multiple columns transformed (computed value)
(
lambda: Table({"col1": [1, 2], "col2": [3, 4]}),
["col1", "col2"],
lambda cell: 2 * cell,
Table({"col1": [2, 4], "col2": [6, 8]}),
),
# lambda takes row parameter
(
lambda: Table({"col1": [1, 2], "col2": [3, 4]}),
"col1",
lambda cell, row: 2 * cell + row["col2"],
Table({"col1": [5, 8], "col2": [3, 4]}),
),
],
ids=[
"no rows (constant value)",
"no rows (computed value)",
"non-empty (constant value)",
"non-empty (computed value)",
"multiple columns transformed (constant value)",
"multiple columns transformed (computed value)",
"lambda takes row parameter",
],
)
class TestHappyPath:
def test_should_transform_columns(
self,
table_factory: Callable[[], Table],
selector: str,
transformer: Callable[[Cell], Cell] | Callable[[Cell, Row], Cell],
expected: Table,
) -> None:
actual = table_factory().transform_columns(selector, transformer)
assert actual == expected

def test_should_not_mutate_receiver(
self,
table_factory: Callable[[], Table],
selector: str,
transformer: Callable[[Cell], Cell] | Callable[[Cell, Row], Cell],
expected: Table, # noqa: ARG002
) -> None:
original = table_factory()
original.transform_columns(selector, transformer)
assert original == table_factory()


@pytest.mark.parametrize(
("table", "selector"),
[
(Table({"col1": [1, 2]}), "col2"),
(Table({"col1": [1, 2]}), ["col1", "col2"]),
],
ids=[
"one column name",
"multiple column names",
],
)
def test_should_raise_if_column_not_found(
table: Table,
selector: str,
) -> None:
with pytest.raises(ColumnNotFoundError):
table.transform_columns(selector, lambda cell: cell * 2)

0 comments on commit 2db9069

Please sign in to comment.