From 8b3cbc2cbf434b81371ff688f7bc263782258b7b Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Wed, 22 Jan 2025 16:17:21 -0500 Subject: [PATCH] feat(weave): Provide helper methods to bridge pandas and weave datasets (#3384) --- docs/docs/guides/core-types/datasets.md | 54 ++++++++++++++----- tests/integrations/pandas-test/test_pandas.py | 36 +++++++++++++ weave/flow/dataset.py | 18 ++++++- 3 files changed, 93 insertions(+), 15 deletions(-) diff --git a/docs/docs/guides/core-types/datasets.md b/docs/docs/guides/core-types/datasets.md index 3388fd9c425b..eb8741dcd4fa 100644 --- a/docs/docs/guides/core-types/datasets.md +++ b/docs/docs/guides/core-types/datasets.md @@ -73,25 +73,51 @@ This guide will show you how to: - Datasets can also be constructed from common Weave objects like `list[Call]`, which is useful if you want to run an evaluation on a handful of examples. + Datasets can also be constructed from common Weave objects like `Call`s, and popular Python objects like `pandas.DataFrame`s. + + + This can be useful if you want to create a dataset from specific examples. -```python -@weave.op -def model(task: str) -> str: - return f"Now working on {task}" + ```python + @weave.op + def model(task: str) -> str: + return f"Now working on {task}" -res1, call1 = model.call(task="fetch") -res2, call2 = model.call(task="parse") + res1, call1 = model.call(task="fetch") + res2, call2 = model.call(task="parse") -dataset = Dataset.from_calls([call1, call2]) -# Now you can use the dataset to evaluate the model, etc. -``` + dataset = Dataset.from_calls([call1, call2]) + # Now you can use the dataset to evaluate the model, etc. + ``` + + + + You can also freely convert between `Dataset`s and `pandas.DataFrame`s. 
+ + ```python + import pandas as pd + + df = pd.DataFrame([ + {'id': '0', 'sentence': "He no likes ice cream.", 'correction': "He doesn't like ice cream."}, + {'id': '1', 'sentence': "She goed to the store.", 'correction': "She went to the store."}, + {'id': '2', 'sentence': "They plays video games all day.", 'correction': "They play video games all day."} + ]) + dataset = Dataset.from_pandas(df) + df2 = dataset.to_pandas() + + assert df.equals(df2) + ``` + + + + - - ```typescript - This feature is not available in TypeScript yet. Stay tuned! - ``` + +```typescript +This feature is not available in TypeScript yet. Stay tuned! +``` + diff --git a/tests/integrations/pandas-test/test_pandas.py b/tests/integrations/pandas-test/test_pandas.py index 9f07fb6a51bd..cf4160547744 100644 --- a/tests/integrations/pandas-test/test_pandas.py +++ b/tests/integrations/pandas-test/test_pandas.py @@ -1,6 +1,7 @@ import pandas as pd import weave +from weave import Dataset def test_op_save_with_global_df(client): @@ -20,3 +21,38 @@ def my_op(a: str) -> str: call = list(my_op.calls())[0] assert call.inputs == {"a": "d"} assert call.output == "a" + + +def test_dataset(client): + rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}] + ds = Dataset(rows=rows) + df = ds.to_pandas() + assert df["a"].tolist() == [1, 3, 5] + assert df["b"].tolist() == [2, 4, 6] + + df2 = pd.DataFrame(rows) + ds2 = Dataset.from_pandas(df2) + assert ds2.rows == rows + assert df.equals(df2) + assert ds.rows == ds2.rows + + +def test_calls_to_dataframe(client): + @weave.op + def greet(name: str, age: int) -> str: + return f"Hello, {name}! You are {age} years old." + + greet("Alice", 30) + greet("Bob", 25) + + calls = greet.calls() + dataset = Dataset.from_calls(calls) + df = dataset.to_pandas() + assert df["inputs"].tolist() == [ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + ] + assert df["output"].tolist() == [ + "Hello, Alice! You are 30 years old.", + "Hello, Bob! 
You are 25 years old.", + ] diff --git a/weave/flow/dataset.py b/weave/flow/dataset.py index 65d691da5ad6..340047abcb8f 100644 --- a/weave/flow/dataset.py +++ b/weave/flow/dataset.py @@ -1,5 +1,5 @@ from collections.abc import Iterable, Iterator -from typing import Any +from typing import TYPE_CHECKING, Any from pydantic import field_validator from typing_extensions import Self @@ -10,6 +10,9 @@ from weave.trace.vals import WeaveObject, WeaveTable from weave.trace.weave_client import Call +if TYPE_CHECKING: + import pandas as pd + def short_str(obj: Any, limit: int = 25) -> str: str_val = str(obj) @@ -60,6 +63,19 @@ def from_calls(cls, calls: Iterable[Call]) -> Self: rows = [call.to_dict() for call in calls] return cls(rows=rows) + @classmethod + def from_pandas(cls, df: "pd.DataFrame") -> Self: + rows = df.to_dict(orient="records") + return cls(rows=rows) + + def to_pandas(self) -> "pd.DataFrame": + try: + import pandas as pd + except ImportError as err: + raise ImportError("pandas is required to use this method") from err + + return pd.DataFrame(self.rows) + @field_validator("rows", mode="before") def convert_to_table(cls, rows: Any) -> weave.Table: if not isinstance(rows, weave.Table):