Skip to content

Commit

Permalink
Fixed the patch
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Jan 22, 2025
1 parent 212bf14 commit 71d2093
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 23 deletions.
21 changes: 21 additions & 0 deletions tests/trace/image_patch_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from tempfile import NamedTemporaryFile

from weave.initialization import pil_image_thread_safety


def test_patching_import_order():
# This test verifies the correct behavior if patching occurs after the construction
# of an image
assert pil_image_thread_safety._patched
pil_image_thread_safety.undo_threadsafe_patch_to_pil_image()
assert not pil_image_thread_safety._patched
import PIL

image = PIL.Image.new("RGB", (10, 10))
with NamedTemporaryFile(suffix=".png") as f:
image.save(f.name)
image = PIL.Image.open(f.name)

pil_image_thread_safety.apply_threadsafe_patch_to_pil_image()
assert pil_image_thread_safety._patched
image.crop((0, 0, 10, 10))
34 changes: 24 additions & 10 deletions tests/trace/type_handlers/Image/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,19 @@ def accept_image_jpg_pillow(val):
file_path.unlink()


def make_random_image(image_size: tuple[int, int] = (1024, 1024)):
random_colour = (
random.randint(0, 255),
random.randint(0, 255),
random.randint(0, 255),
)
return Image.new("RGB", image_size, random_colour)


@pytest.fixture
def dataset_ref(client):
# This fixture represents a saved dataset containing images
IMAGE_SIZE = (1024, 1024)
N_ROWS = 50

def make_random_image():
random_colour = (
random.randint(0, 255),
random.randint(0, 255),
random.randint(0, 255),
)
return Image.new("RGB", IMAGE_SIZE, random_colour)

rows = [{"img": make_random_image()} for _ in range(N_ROWS)]
dataset = weave.Dataset(rows=rows)
ref = weave.publish(dataset)
Expand Down Expand Up @@ -202,3 +201,18 @@ async def test_many_images_will_consistently_log():

# But if there's an issue, the stderr will contain `Task failed:`
assert "Task failed" not in res.stderr


def test_images_in_load_of_dataset(client):
N_ROWS = 50
rows = [{"img": make_random_image()} for _ in range(N_ROWS)]
dataset = weave.Dataset(rows=rows)
ref = weave.publish(dataset)

dataset = ref.get()
for gotten_row, local_row in zip(dataset, rows):
assert isinstance(gotten_row["img"], Image.Image)
assert gotten_row["img"].size == local_row["img"].size
assert gotten_row["img"].tobytes() == local_row["img"].tobytes()

return ref
45 changes: 32 additions & 13 deletions weave/initialization/pil_image_thread_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,28 @@
import threading
from functools import wraps
from typing import Any, Callable, Optional
from weakref import WeakKeyDictionary

_patched = False
_original_methods: dict[str, Optional[Callable]] = {"init": None, "load": None}
_original_methods: dict[str, Optional[Callable]] = {"load": None}


class TheadSafeLockLookup:
global_lock: threading.Lock
lock_map: WeakKeyDictionary[Any, threading.Lock]

def __init__(self) -> None:
self.global_lock = threading.Lock()
self.lock_map = WeakKeyDictionary()

def get_lock(self, obj: Any) -> threading.Lock:
with self.global_lock:
if obj not in self.lock_map:
self.lock_map[obj] = threading.Lock()
return self.lock_map[obj]


_global_thread_safe_lock_lookup = TheadSafeLockLookup()


def apply_threadsafe_patch_to_pil_image() -> None:
Expand All @@ -46,23 +65,25 @@ def _apply_threadsafe_patch() -> None:
global _original_methods

# Store original methods
_original_methods["init"] = ImageFile.__init__
_original_methods["load"] = ImageFile.load

old_load = ImageFile.load
old_init = ImageFile.__init__

@wraps(old_init)
def new_init(self: ImageFile, *args: Any, **kwargs: Any) -> None:
self._weave_load_lock = threading.Lock() # type: ignore
return old_init(self, *args, **kwargs)

@wraps(old_load)
def new_load(self: ImageFile, *args: Any, **kwargs: Any) -> Any:
# There is an edge case at play here: If the ImageFile is constructed
# before patching (ie import weave), then the ImageFile will not have
# the _weave_load_lock attribute and this will raise an AttributeError.
# To avoid this, we check for the existence of the _weave_load_lock
# attribute before acquiring the lock.
#
# Unfortunately, this means in these cases, the ImageFile will not be
# thread-safe.
if not hasattr(self, "_weave_load_lock"):
self._weave_load_lock = _global_thread_safe_lock_lookup.get_lock(self) # type: ignore
with self._weave_load_lock: # type: ignore
return old_load(self, *args, **kwargs)

ImageFile.__init__ = new_init # type: ignore
# ImageFile.__init__ = new_init # type: ignore
ImageFile.load = new_load # type: ignore


Expand Down Expand Up @@ -94,9 +115,7 @@ def _undo_threadsafe_patch() -> None:

global _original_methods

if _original_methods["init"] is not None:
ImageFile.__init__ = _original_methods["init"] # type: ignore
if _original_methods["load"] is not None:
ImageFile.load = _original_methods["load"] # type: ignore

_original_methods = {"init": None, "load": None}
_original_methods = {"load": None}

0 comments on commit 71d2093

Please sign in to comment.