Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(weave_query): propgate tags on mapped run ops during gql key propagation #3386

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions weave_query/tests/test_propagate_gql_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
from weave_query import (
weave_types as types,
graph,
op_def,
op_args
)
from weave_query.language_features.tagging import (
tagged_value_type,
)
from weave_query.propagate_gql_keys import _propagate_gql_keys_for_node
from weave_query.ops_domain import wb_domain_types as wdt

def test_mapped_tag_propagation():
test_op = op_def.OpDef(
name="run-base_op",
input_type=op_args.OpNamedArgs({"run": wdt.RunType}),
output_type=types.List(types.Number()),
resolve_fn=lambda: None
)

mapped_opdef = op_def.OpDef(
name="mapped_run-base_op",
input_type=op_args.OpNamedArgs({"run": types.List(wdt.RunType)}),
output_type=types.List(types.List(types.Number())),
resolve_fn=lambda: None
)

mapped_opdef.derived_from = test_op
test_op.derived_ops = {"mapped": mapped_opdef}

test_node = graph.OutputNode(
types.List(types.Number()),
"mapped_run-base_op",
{
"run": graph.OutputNode(
tagged_value_type.TaggedValueType(types.TypedDict({"project": wdt.ProjectType}), types.List(wdt.RunType)),
"limit",
{
"arr": graph.OutputNode(
tagged_value_type.TaggedValueType(
types.TypedDict({"project": wdt.ProjectType}),
types.List(wdt.RunType)
),
"project-filteredRuns",
{}
)
}
)
}
)

def mock_key_fn(ip, input_type):
return types.List(types.Number())

result = _propagate_gql_keys_for_node(mapped_opdef, test_node, mock_key_fn, None)

assert isinstance(result, tagged_value_type.TaggedValueType)
# existing project tag from inputs flowed to output
assert result.tag.property_types["project"]
# run input propagated as tag on output
assert result.value.object_type.tag.property_types["run"]
assert isinstance(result.value.object_type.value, types.List)
assert isinstance(result.value.object_type.value.object_type, types.Number)
5 changes: 5 additions & 0 deletions weave_query/weave_query/propagate_gql_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def _propagate_gql_keys_for_node(
raise ValueError('GQL key function returned "Invalid" type')

if is_mapped:
# Handle tag propagation for mapped run ops
if opdef_util.should_tag_op_def_outputs(opdef.derived_from):
new_output_type = tagged_value_type.TaggedValueType(
types.TypedDict({first_arg_name: unwrapped_input_type}), new_output_type
)
new_output_type = types.List(new_output_type)

# now we rewrap the types to propagate the tags
Expand Down
Loading