Skip to content

Commit

Permalink
Merge branch 'master' into griffin/obj_read-metadata_only-frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning authored Jan 22, 2025
2 parents c8ec850 + 7caea1b commit 9847e11
Show file tree
Hide file tree
Showing 19 changed files with 372 additions and 46 deletions.
55 changes: 54 additions & 1 deletion docs/docs/guides/core-types/datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ This guide will show you how to:
- Download the latest version
- Iterate over examples

## Sample code
## Quickstart

<Tabs groupId="programming-language" queryString>
<TabItem value="python" label="Python" default>
Expand Down Expand Up @@ -68,3 +68,56 @@ This guide will show you how to:

</TabItem>
</Tabs>

## Alternate constructors

<Tabs groupId="programming-language" queryString>
<TabItem value="python" label="Python" default>
Datasets can also be constructed from common Weave objects like `Call`s, and popular Python objects like `pandas.DataFrame`s.
<Tabs groupId="use-case">
<TabItem value="from-calls" label="From Calls">
This can be useful if you want to create a dataset from specific calls.

```python
@weave.op
def model(task: str) -> str:
return f"Now working on {task}"

res1, call1 = model.call(task="fetch")
res2, call2 = model.call(task="parse")

dataset = Dataset.from_calls([call1, call2])
# Now you can use the dataset to evaluate the model, etc.
```
</TabItem>

<TabItem value="from-pandas" label="From Pandas">
You can also freely convert between `Dataset`s and `pandas.DataFrame`s.

```python
import pandas as pd

df = pd.DataFrame([
{'id': '0', 'sentence': "He no likes ice cream.", 'correction': "He doesn't like ice cream."},
{'id': '1', 'sentence': "She goed to the store.", 'correction': "She went to the store."},
{'id': '2', 'sentence': "They plays video games all day.", 'correction': "They play video games all day."}
])
dataset = Dataset.from_pandas(df)
df2 = dataset.to_pandas()

assert df.equals(df2)
```

</TabItem>

</Tabs>

</TabItem>
<TabItem value="typescript" label="TypeScript">

```typescript
This feature is not available in TypeScript yet. Stay tuned!
```

</TabItem>
</Tabs>
36 changes: 36 additions & 0 deletions tests/integrations/pandas-test/test_pandas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd

import weave
from weave import Dataset


def test_op_save_with_global_df(client):
Expand All @@ -20,3 +21,38 @@ def my_op(a: str) -> str:
call = list(my_op.calls())[0]
assert call.inputs == {"a": "d"}
assert call.output == "a"


def test_dataset(client):
    """Round-trip a Dataset through pandas in both directions."""
    records = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]

    # Dataset -> DataFrame: each key becomes a column, row order is preserved.
    dataset_from_rows = Dataset(rows=records)
    exported_frame = dataset_from_rows.to_pandas()
    assert exported_frame["a"].tolist() == [1, 3, 5]
    assert exported_frame["b"].tolist() == [2, 4, 6]

    # DataFrame -> Dataset: the rows come back as the original records.
    source_frame = pd.DataFrame(records)
    dataset_from_frame = Dataset.from_pandas(source_frame)
    assert dataset_from_frame.rows == records

    # Both paths must agree with each other.
    assert exported_frame.equals(source_frame)
    assert dataset_from_rows.rows == dataset_from_frame.rows


def test_calls_to_dataframe(client):
    """Calls of an op can be turned into a Dataset and then a DataFrame."""

    @weave.op
    def greet(name: str, age: int) -> str:
        return f"Hello, {name}! You are {age} years old."

    greet("Alice", 30)
    greet("Bob", 25)

    # Build the frame from the recorded calls of the op.
    frame = Dataset.from_calls(greet.calls()).to_pandas()

    expected_inputs = [
        {"name": "Alice", "age": 30},
        {"name": "Bob", "age": 25},
    ]
    expected_outputs = [
        "Hello, Alice! You are 30 years old.",
        "Hello, Bob! You are 25 years old.",
    ]
    assert frame["inputs"].tolist() == expected_inputs
    assert frame["output"].tolist() == expected_outputs
25 changes: 25 additions & 0 deletions tests/trace/test_call_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import weave


def test_call_to_dict(client):
    """``Call.to_dict()`` exposes exactly the call's attributes, keyed by name."""

    @weave.op
    def greet(name: str, age: int) -> str:
        return f"Hello {name}, you are {age}!"

    _, call = greet.call("Alice", 30)
    actual = call.to_dict()

    # Every key of the dict must mirror the attribute of the same name.
    field_names = (
        "op_name",
        "display_name",
        "inputs",
        "output",
        "exception",
        "summary",
        "attributes",
        "started_at",
        "ended_at",
        "deleted_at",
        "id",
        "parent_id",
        "trace_id",
        "project_id",
    )
    expected = {field: getattr(call, field) for field in field_names}
    assert actual == expected
21 changes: 21 additions & 0 deletions tests/trace/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,24 @@ def test_pythonic_access(client):

with pytest.raises(IndexError):
ds[-1]


def test_dataset_from_calls(client):
    """A Dataset built from client calls has one row per call with inputs/output."""

    @weave.op
    def greet(name: str, age: int) -> str:
        return f"Hello {name}, you are {age}!"

    greet("Alice", 30)
    greet("Bob", 25)

    dataset = weave.Dataset.from_calls(client.get_calls())
    rows = list(dataset.rows)

    # (name, age, output) expected for each row, in call order.
    expected = [
        ("Alice", 30, "Hello Alice, you are 30!"),
        ("Bob", 25, "Hello Bob, you are 25!"),
    ]
    assert len(rows) == 2
    for row, (name, age, output) in zip(rows, expected):
        assert row["inputs"]["name"] == name
        assert row["inputs"]["age"] == age
        assert row["output"] == output
67 changes: 67 additions & 0 deletions tests/trace/test_uri_get.py → tests/trace/test_objectify.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from concurrent.futures import Future
from dataclasses import replace
from typing import TypeVar

import pytest

import weave
from weave.flow.obj import Object
from weave.flow.prompt.prompt import EasyPrompt
from weave.trace.objectify import register_object
from weave.trace.refs import RefWithExtra
from weave.trace_server.trace_server_interface import ObjectVersionFilter, ObjQueryReq

T = TypeVar("T")


@pytest.fixture(
params=[
Expand Down Expand Up @@ -94,3 +103,61 @@ def model(a: int) -> int:
for obj in res.objs:
assert obj.version_index == 0
assert obj.is_latest == 1


def resolve_ref_futures(ref: RefWithExtra) -> RefWithExtra:
    """Return a copy of ``ref`` whose extra path contains no pending futures.

    This is a bit of a hack for objects that were initially unsaved: entries
    of ``ref._extra`` can still be ``Future`` instances rather than the values
    they eventually produce. Each such entry is replaced by its
    ``Future.result()`` (which blocks until the future completes); every other
    entry is kept as-is.

    Every entry is visited individually rather than in (name, value) pairs, so
    an ``_extra`` with an odd number of entries keeps its last element instead
    of silently losing it.

    Args:
        ref: Ref whose ``_extra`` sequence may contain futures. Not modified.

    Returns:
        A new ref, built with ``dataclasses.replace``, whose ``_extra`` is a
        tuple of fully resolved entries.
    """
    resolved = tuple(
        entry.result() if isinstance(entry, Future) else entry
        for entry in ref._extra
    )
    return replace(ref, _extra=resolved)


def test_drill_down_dataset_refs_same_after_publishing(client):
    """Refs at every nesting level stay stable across publish/get round trips."""
    original = weave.Dataset(
        name="test",
        rows=[{"a": {"b": 1}}, {"a": {"b": 2}}, {"a": {"b": 3}}],
    )
    reloaded = weave.publish(original).get()
    republished = weave.publish(reloaded).get()

    # The in-memory original may still carry futures in its refs, so resolve
    # them before comparing against the refs of the fetched copy.
    assert resolve_ref_futures(original.rows.ref) == reloaded.rows.ref
    for local_row, fetched_row in zip(original.rows, reloaded.rows):
        assert resolve_ref_futures(local_row.ref) == fetched_row.ref
        assert resolve_ref_futures(local_row["a"].ref) == fetched_row["a"].ref
        assert (
            resolve_ref_futures(local_row["a"]["b"].ref)
            == fetched_row["a"]["b"].ref
        )

    # Publishing an already-fetched dataset again must not change any ref.
    assert reloaded.ref == republished.ref
    for fetched_row, refetched_row in zip(reloaded.rows, republished.rows):
        assert fetched_row.ref == refetched_row.ref
        assert fetched_row["a"].ref == refetched_row["a"].ref
        assert fetched_row["a"]["b"].ref == refetched_row["a"]["b"].ref

    # The values themselves survive both round trips at every level.
    assert republished.rows == [{"a": {"b": n}} for n in (1, 2, 3)]
    for expected_b, final_row in enumerate(republished.rows, start=1):
        assert final_row == {"a": {"b": expected_b}}
        assert final_row["a"] == {"b": expected_b}
        assert final_row["a"]["b"] == expected_b


def test_registration():
    """Registering a second class named ``Dataset`` must be rejected."""
    # A Dataset class has already been registered in weave.flow.obj, so
    # decorating another class with the same name should raise.
    duplicate_name_error = pytest.raises(
        ValueError, match="Class Dataset already registered as"
    )
    with duplicate_name_error:

        @register_object
        class Dataset(Object):
            anything: str
            doesnt_matter: int
13 changes: 13 additions & 0 deletions tests/trace/test_weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,6 +1608,19 @@ def test_object_version_read(client):
assert obj_res.obj.val == {"a": i}
assert obj_res.obj.version_index == i

# read each object one at a time, check the version, metadata only
for i in range(10):
obj_res = client.server.obj_read(
tsi.ObjReadReq(
project_id=client._project_id(),
object_id=refs[i].name,
digest=refs[i].digest,
metadata_only=True,
)
)
assert obj_res.obj.val == {}
assert obj_res.obj.version_index == i

# now grab the latest version of the object
obj_res = client.server.obj_read(
tsi.ObjReadReq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import React, {useMemo} from 'react';
import {Icon} from '../../../../Icon';
import {LoadingDots} from '../../../../LoadingDots';
import {Tailwind} from '../../../../Tailwind';
import {Timestamp} from '../../../../Timestamp';
import {WeaveCHTableSourceRefContext} from '../pages/CallPage/DataTableView';
import {ObjectViewerSection} from '../pages/CallPage/ObjectViewerSection';
import {objectVersionText} from '../pages/common/Links';
Expand All @@ -30,6 +31,7 @@ export const DatasetVersionPage: React.FC<{
const projectName = objectVersion.project;
const objectName = objectVersion.objectId;
const objectVersionIndex = objectVersion.versionIndex;
const {createdAtMs} = objectVersion;

const objectVersions = useRootObjectVersions(
entityName,
Expand Down Expand Up @@ -76,7 +78,7 @@ export const DatasetVersionPage: React.FC<{
}
headerContent={
<Tailwind>
<div className="grid w-full grid-flow-col grid-cols-[auto_auto_1fr] gap-[16px] text-[14px]">
<div className="grid w-full grid-flow-col grid-cols-[auto_auto_auto_1fr] gap-[16px] text-[14px]">
<div className="block">
<p className="text-moon-500">Name</p>
<ObjectVersionsLink
Expand Down Expand Up @@ -109,6 +111,12 @@ export const DatasetVersionPage: React.FC<{
<p className="text-moon-500">Version</p>
<p>{objectVersionIndex}</p>
</div>
<div className="block">
<p className="text-moon-500">Created</p>
<p>
<Timestamp value={createdAtMs / 1000} format="relative" />
</p>
</div>
{objectVersion.userId && (
<div className="block">
<p className="text-moon-500">Created by</p>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import {Button} from '@wandb/weave/components/Button';
import {maybePluralizeWord} from '@wandb/weave/core/util/string';
import React, {useState} from 'react';
import React, {useContext, useState} from 'react';
import {useHistory} from 'react-router-dom';

import {useClosePeek} from '../../context';
import {
useClosePeek,
useWeaveflowCurrentRouteContext,
WeaveflowPeekContext,
} from '../../context';
import {DeleteModal} from '../common/DeleteModal';
import {useWFHooks} from '../wfReactInterface/context';
import {ObjectVersionSchema} from '../wfReactInterface/wfDataModelHooksInterface';
Expand All @@ -13,13 +18,32 @@ export const DeleteObjectButtonWithModal: React.FC<{
}> = ({objVersionSchema, overrideDisplayStr}) => {
const {useObjectDeleteFunc} = useWFHooks();
const closePeek = useClosePeek();
const {isPeeking} = useContext(WeaveflowPeekContext);
const routerContext = useWeaveflowCurrentRouteContext();
const history = useHistory();
const {objectVersionsDelete} = useObjectDeleteFunc();
const [deleteModalOpen, setDeleteModalOpen] = useState(false);

const deleteStr =
overrideDisplayStr ??
`${objVersionSchema.objectId}:v${objVersionSchema.versionIndex}`;

// Runs after a successful delete: leave the view of the deleted version.
const onSuccess = () => {
  // Inside the peek drawer, closing the drawer is all that is needed.
  if (isPeeking) {
    closePeek();
    return;
  }

  // On the full page, navigate to the object-versions URL for this object
  // (presumably the versions list filtered by object name; confirm against
  // the router context).
  const versionsUrl = routerContext.objectVersionsUIUrl(
    objVersionSchema.entity,
    objVersionSchema.project,
    {
      objectName: objVersionSchema.objectId,
    }
  );
  history.push(versionsUrl);
};

return (
<>
<Button
Expand All @@ -39,7 +63,7 @@ export const DeleteObjectButtonWithModal: React.FC<{
[objVersionSchema.versionHash]
)
}
onSuccess={closePeek}
onSuccess={onSuccess}
/>
</>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {maybePluralizeWord} from '../../../../../../core/util/string';
import {Icon, IconName} from '../../../../../Icon';
import {LoadingDots} from '../../../../../LoadingDots';
import {Tailwind} from '../../../../../Tailwind';
import {Timestamp} from '../../../../../Timestamp';
import {Tooltip} from '../../../../../Tooltip';
import {DatasetVersionPage} from '../../datasets/DatasetVersionPage';
import {NotFoundPanel} from '../../NotFoundPanel';
Expand Down Expand Up @@ -121,7 +122,7 @@ const ObjectVersionPageInner: React.FC<{
const projectName = objectVersion.project;
const objectName = objectVersion.objectId;
const objectVersionIndex = objectVersion.versionIndex;
const refExtra = objectVersion.refExtra;
const {refExtra, createdAtMs} = objectVersion;
const objectVersions = useRootObjectVersions(
entityName,
projectName,
Expand Down Expand Up @@ -231,7 +232,7 @@ const ObjectVersionPageInner: React.FC<{
}
headerContent={
<Tailwind>
<div className="grid w-full grid-flow-col grid-cols-[auto_auto_1fr] gap-[16px] text-[14px]">
<div className="grid w-full grid-flow-col grid-cols-[auto_auto_auto_1fr] gap-[16px] text-[14px]">
<div className="block">
<p className="text-moon-500">Name</p>
<div className="flex items-center">
Expand Down Expand Up @@ -266,6 +267,12 @@ const ObjectVersionPageInner: React.FC<{
<p className="text-moon-500">Version</p>
<p>{objectVersionIndex}</p>
</div>
<div className="block">
<p className="text-moon-500">Created</p>
<p>
<Timestamp value={createdAtMs / 1000} format="relative" />
</p>
</div>
{objectVersion.userId && (
<div className="block">
<p className="text-moon-500">Created by</p>
Expand Down
Loading

0 comments on commit 9847e11

Please sign in to comment.