Skip to content

Commit

Permalink
Merge branch 'master' into tim/huggingface_datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Jan 22, 2025
2 parents 3a1cdf1 + 7b2b7e0 commit 64b12c5
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 37 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ docs:
build:
uv build

prerelease-dry-run:
uv run ./weave/scripts/prerelease_dry_run.py

prepare-release: docs build

synchronize-base-object-schemas:
Expand Down
4 changes: 3 additions & 1 deletion dev_docs/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ This document outlines how to publish a new Weave release to our public [PyPI pa

1. Verify the head of master is ready for release and announce merge freeze to the Weave team while the release is being published (Either ask an admin on the Weave repo to place a freeze on https://www.mergefreeze.com/ or use the mergefreeze Slack app if it is set up or just post in Slack)

2. You should also run through this [sample notebook](https://colab.research.google.com/drive/1DmkLzhFCFC0OoN-ggBDoG1nejGw2jQZy#scrollTo=29hJrcJQA7jZ) remember to install from master. You can also just run the [quickstart](http://wandb.me/weave_colab).
2. Manual Verifications:
- Run `make prerelease-dry-run` to verify that the dry run script works.
- You should also run through this [sample notebook](https://colab.research.google.com/drive/1DmkLzhFCFC0OoN-ggBDoG1nejGw2jQZy#scrollTo=29hJrcJQA7jZ); remember to install from master. Alternatively, you can just run the [quickstart](http://wandb.me/weave_colab).

3. To prepare a PATCH release, go to GitHub Actions and run the [bump-python-sdk-version](https://github.com/wandb/weave/actions/workflows/bump_version.yaml) workflow on master. This will:

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ module = "weave_query.*"
ignore_errors = true

[tool.bumpversion]
current_version = "0.51.30-dev0"
current_version = "0.51.31-dev0"
parse = """(?x)
(?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\.
Expand Down
22 changes: 22 additions & 0 deletions tests/trace/image_patch_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from tempfile import NamedTemporaryFile

from weave.initialization import pil_image_thread_safety


def test_patching_import_order():
    """Verify that an image opened *before* the patch is applied still loads.

    The image is created and opened while the thread-safety patch is removed,
    then the patch is applied and a load is triggered via ``crop``.
    """
    assert pil_image_thread_safety._patched
    pil_image_thread_safety.undo_threadsafe_patch_to_pil_image()
    try:
        assert not pil_image_thread_safety._patched

        # Import the submodule explicitly: a bare `import PIL` does not
        # guarantee that `PIL.Image` is available as an attribute.
        from PIL import Image

        with NamedTemporaryFile(suffix=".png") as f:
            Image.new("RGB", (10, 10)).save(f.name)

            # Opened while unpatched; the context manager closes the handle.
            with Image.open(f.name) as image:
                pil_image_thread_safety.apply_threadsafe_patch_to_pil_image()
                assert pil_image_thread_safety._patched

                # Keep the load inside both `with` blocks so the backing
                # file still exists when the pixel data is read.
                image.crop((0, 0, 10, 10))
    finally:
        # Always restore the patch so a failure above cannot leak an
        # unpatched PIL into later tests. Applying twice is harmless because
        # the apply function is documented as idempotent.
        pil_image_thread_safety.apply_threadsafe_patch_to_pil_image()
34 changes: 24 additions & 10 deletions tests/trace/type_handlers/Image/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,19 @@ def accept_image_jpg_pillow(val):
file_path.unlink()


def make_random_image(image_size: tuple[int, int] = (1024, 1024)):
    """Build a solid-colour RGB image of ``image_size`` in a random colour."""
    # One draw per channel, in R, G, B order; each value is in 0..255 inclusive.
    red = random.randint(0, 255)
    green = random.randint(0, 255)
    blue = random.randint(0, 255)
    return Image.new("RGB", image_size, (red, green, blue))


@pytest.fixture
def dataset_ref(client):
# This fixture represents a saved dataset containing images
IMAGE_SIZE = (1024, 1024)
N_ROWS = 50

def make_random_image():
random_colour = (
random.randint(0, 255),
random.randint(0, 255),
random.randint(0, 255),
)
return Image.new("RGB", IMAGE_SIZE, random_colour)

rows = [{"img": make_random_image()} for _ in range(N_ROWS)]
dataset = weave.Dataset(rows=rows)
ref = weave.publish(dataset)
Expand Down Expand Up @@ -202,3 +201,18 @@ async def test_many_images_will_consistently_log():

# But if there's an issue, the stderr will contain `Task failed:`
assert "Task failed" not in res.stderr


def test_images_in_load_of_dataset(client):
    """Publish a dataset of images and check each image survives the round trip."""
    N_ROWS = 50
    rows = [{"img": make_random_image()} for _ in range(N_ROWS)]
    dataset = weave.Dataset(rows=rows)
    ref = weave.publish(dataset)

    gotten_rows = list(ref.get())

    # zip() stops at the shorter input, so without this check a truncated
    # (or empty) dataset would make the loop below pass vacuously.
    assert len(gotten_rows) == len(rows)

    for gotten_row, local_row in zip(gotten_rows, rows):
        assert isinstance(gotten_row["img"], Image.Image)
        assert gotten_row["img"].size == local_row["img"].size
        assert gotten_row["img"].tobytes() == local_row["img"].tobytes()

    # No return value: pytest warns about (and newer releases reject) test
    # functions that return anything other than None.
88 changes: 65 additions & 23 deletions weave/initialization/pil_image_thread_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,39 @@
- `undo_threadsafe_patch_to_pil_image`
There is a discussion here: https://github.com/python-pillow/Pillow/issues/4848#issuecomment-671339193 in which
the author claims that the Pillow library is thread-safe. However, my reasoning leads me to a different conclusion.
the author claims that the Pillow library is thread-safe. However, empirical evidence suggests otherwise.
Specifically, the `ImageFile.load` method is not thread-safe. This is because `load` will both close and delete
an open file handler as well as modify properties of the ImageFile object (namely the `im` property which contains
the underlying image data). Inside of Weave we use threads to parallelize work which may involve Images. This bug
has presented itself not only in our own persistence layer, but also in user code where they are consuming loaded
images across threads.
Specifically, the `ImageFile.load` method is not thread-safe because it:
1. Closes and deletes an open file handler
2. Modifies properties of the ImageFile object (namely the `im` property which contains the underlying image data)
We call `apply_threadsafe_patch_to_pil_image` in the `__init__.py` file to ensure that the ImageFile class is thread-safe.
Inside Weave, we use threads to parallelize work which may involve Images. This thread-safety issue has manifested
not only in our persistence layer but also in user code where loaded images are accessed across threads.
We call `apply_threadsafe_patch_to_pil_image` in the `__init__.py` file to ensure thread-safety for the ImageFile class.
"""

import threading
from functools import wraps
from typing import Any, Callable, Optional

# Global state
# `_patched` is a boolean that indicates whether the thread-safety patch has been applied
# `_original_methods` is a dictionary that stores the original methods of the ImageFile class
# `_new_lock_lock` is a lock that is used to create a new lock for each ImageFile instance
# `_fallback_load_lock` is a global lock that is used to ensure thread-safe image loading when per-instance locking fails
_patched = False
_original_methods: dict[str, Optional[Callable]] = {"init": None, "load": None}
_original_methods: dict[str, Optional[Callable]] = {"load": None}
_new_lock_lock = threading.Lock()
_fallback_load_lock = threading.Lock()


def apply_threadsafe_patch_to_pil_image() -> None:
"""Apply thread-safety patch to PIL ImageFile class."""
"""Apply thread-safety patch to PIL ImageFile class.
This function is idempotent - calling it multiple times has no additional effect.
If PIL is not installed or if patching fails, the function will handle the error gracefully.
"""
global _patched

if _patched:
Expand All @@ -41,35 +53,61 @@ def apply_threadsafe_patch_to_pil_image() -> None:


def _apply_threadsafe_patch() -> None:
"""Internal function that performs the actual thread-safety patching of PIL ImageFile.
Raises:
ImportError: If PIL is not installed
Exception: For any other unexpected errors during patching
"""
from PIL.ImageFile import ImageFile

global _original_methods

# Store original methods
_original_methods["init"] = ImageFile.__init__
_original_methods["load"] = ImageFile.load

old_load = ImageFile.load
old_init = ImageFile.__init__

@wraps(old_init)
def new_init(self: ImageFile, *args: Any, **kwargs: Any) -> None:
self._weave_load_lock = threading.Lock() # type: ignore
return old_init(self, *args, **kwargs)

@wraps(old_load)
def new_load(self: ImageFile, *args: Any, **kwargs: Any) -> Any:
with self._weave_load_lock: # type: ignore
# This function wraps PIL's ImageFile.load method to make it thread-safe
# by ensuring only one thread can load an image at a time per ImageFile instance.

# We use a per-instance lock to allow concurrent loading of different images
# while preventing concurrent access to the same image.
try:
# Create a new lock for this ImageFile instance if it doesn't exist.
# The lock creation itself needs to be thread-safe, hence _new_lock_lock.
# Note: this `_new_lock_lock` is global as opposed to per-instance, else
# it would be possible for the same ImageFile to be loaded by multiple threads
# thereby creating a race where different threads would be each minting their
# own lock for the same ImageFile!
if not hasattr(self, "_weave_load_lock"):
with _new_lock_lock:
# Double-check pattern: verify the attribute still doesn't exist
# after acquiring the lock to prevent race conditions
if not hasattr(self, "_weave_load_lock"):
setattr(self, "_weave_load_lock", threading.Lock())
lock = getattr(self, "_weave_load_lock")

except Exception:
# If anything goes wrong with the locking mechanism,
# fall back to the global lock for safety
lock = _fallback_load_lock
# Acquire the instance-specific lock before loading the image.
# This ensures thread-safety by preventing concurrent:
# - Modification of the 'im' property
# - Access to the file handler
with lock:
return old_load(self, *args, **kwargs)

ImageFile.__init__ = new_init # type: ignore
# Replace the load method with our thread-safe version
ImageFile.load = new_load # type: ignore


def undo_threadsafe_patch_to_pil_image() -> None:
"""Revert the thread-safety patch applied to PIL ImageFile class.
If the patch hasn't been applied, this function does nothing.
This function is idempotent - if the patch hasn't been applied, this function does nothing.
If the patch has been applied but can't be reverted, an error message is printed.
"""
global _patched
Expand All @@ -90,13 +128,17 @@ def undo_threadsafe_patch_to_pil_image() -> None:


def _undo_threadsafe_patch() -> None:
"""Internal function that performs the actual removal of thread-safety patches.
Raises:
ImportError: If PIL is not installed
Exception: For any other unexpected errors during unpatching
"""
from PIL.ImageFile import ImageFile

global _original_methods

if _original_methods["init"] is not None:
ImageFile.__init__ = _original_methods["init"] # type: ignore
if _original_methods["load"] is not None:
ImageFile.load = _original_methods["load"] # type: ignore

_original_methods = {"init": None, "load": None}
_original_methods = {"load": None}
37 changes: 37 additions & 0 deletions weave/scripts/prerelease_dry_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# /// script
# requires-python = ">=3.9"
# dependencies = [
# "weave @ git+https://github.com/wandb/weave.git@master",
# ]
# ///

# Run this script with `uv run prerelease_dry_run.py`

import datetime

import weave

# This unique id ensures that the op is not cached
uniq_id = datetime.datetime.now().timestamp()


@weave.op
def func(a: int) -> float:
    """Return ``a`` plus the module-level ``uniq_id``."""
    total = a + uniq_id
    return total


def main() -> None:
    """Run ``func`` once against ``test-project`` and verify the logged call.

    Raises:
        AssertionError: If the calls read back do not match the call that was
            just made (wrong count, output, or inputs).
    """
    client = weave.init("test-project")
    res = func(42)

    # NOTE(review): presumably pushes pending writes so the call is readable
    # below -- confirm against the client implementation.
    client._flush()
    calls = func.calls()

    # Raise explicitly instead of using `assert`: assert statements are removed
    # under `python -O`, which would let the dry run "pass" without checking.
    if len(calls) != 1:
        raise AssertionError(f"expected exactly 1 call, found {len(calls)}")

    call = calls[0]
    if not (call.output == res):
        raise AssertionError(f"call output {call.output!r} != returned {res!r}")
    if not (call.inputs == {"a": 42}):
        raise AssertionError(f"unexpected call inputs: {call.inputs!r}")


if __name__ == "__main__":
    main()
    print("Dry run passed")
2 changes: 1 addition & 1 deletion weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
FILE_CHUNK_SIZE = 100000

MAX_DELETE_CALLS_COUNT = 100
INITIAL_CALLS_STREAM_BATCH_SIZE = 100
INITIAL_CALLS_STREAM_BATCH_SIZE = 50
MAX_CALLS_STREAM_BATCH_SIZE = 500


Expand Down
2 changes: 1 addition & 1 deletion weave/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@
"""

VERSION = "0.51.30-dev0"
VERSION = "0.51.31-dev0"
64 changes: 64 additions & 0 deletions weave_query/tests/test_propagate_gql_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
from weave_query import (
weave_types as types,
graph,
op_def,
op_args
)
from weave_query.language_features.tagging import (
tagged_value_type,
)
from weave_query.propagate_gql_keys import _propagate_gql_keys_for_node
from weave_query.ops_domain import wb_domain_types as wdt

def test_mapped_tag_propagation():
    """Check the output type computed for a mapped op derived from a run op.

    The result is expected to keep the ``project`` tag carried by the inputs
    and to tag each element of the output list with the ``run`` input.
    """

    def project_tagged_runs():
        # A fresh instance per call: List[Run] tagged with {"project": Project}.
        return tagged_value_type.TaggedValueType(
            types.TypedDict({"project": wdt.ProjectType}),
            types.List(wdt.RunType),
        )

    # Base op over a single run, and the mapped variant over a list of runs.
    test_op = op_def.OpDef(
        name="run-base_op",
        input_type=op_args.OpNamedArgs({"run": wdt.RunType}),
        output_type=types.List(types.Number()),
        resolve_fn=lambda: None,
    )
    mapped_opdef = op_def.OpDef(
        name="mapped_run-base_op",
        input_type=op_args.OpNamedArgs({"run": types.List(wdt.RunType)}),
        output_type=types.List(types.List(types.Number())),
        resolve_fn=lambda: None,
    )

    # Link the two ops in both directions.
    mapped_opdef.derived_from = test_op
    test_op.derived_ops = {"mapped": mapped_opdef}

    # Graph under test: mapped_run-base_op(run=limit(arr=project-filteredRuns())).
    node_type = types.List(types.Number())
    limit_type = project_tagged_runs()
    filtered_runs_node = graph.OutputNode(
        project_tagged_runs(),
        "project-filteredRuns",
        {},
    )
    limit_node = graph.OutputNode(limit_type, "limit", {"arr": filtered_runs_node})
    test_node = graph.OutputNode(
        node_type,
        "mapped_run-base_op",
        {"run": limit_node},
    )

    def mock_key_fn(ip, input_type):
        return types.List(types.Number())

    result = _propagate_gql_keys_for_node(mapped_opdef, test_node, mock_key_fn, None)

    assert isinstance(result, tagged_value_type.TaggedValueType)
    # existing project tag from inputs flowed to output
    assert result.tag.property_types["project"]
    # run input propagated as tag on output
    assert result.value.object_type.tag.property_types["run"]
    element_value = result.value.object_type.value
    assert isinstance(element_value, types.List)
    assert isinstance(element_value.object_type, types.Number)
5 changes: 5 additions & 0 deletions weave_query/weave_query/propagate_gql_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def _propagate_gql_keys_for_node(
raise ValueError('GQL key function returned "Invalid" type')

if is_mapped:
# Handle tag propagation for mapped run ops
if opdef_util.should_tag_op_def_outputs(opdef.derived_from):
new_output_type = tagged_value_type.TaggedValueType(
types.TypedDict({first_arg_name: unwrapped_input_type}), new_output_type
)
new_output_type = types.List(new_output_type)

# now we rewrap the types to propagate the tags
Expand Down

0 comments on commit 64b12c5

Please sign in to comment.