Skip to content

Commit

Permalink
chore(weave): Improve Weave patch to support pre-constructed images (#…
Browse files Browse the repository at this point in the history
…3457)

* Fixed the patch

* Revision

* Revision

* Revision

* Comment

* Comment
  • Loading branch information
tssweeney authored Jan 22, 2025
1 parent dbd1f65 commit 7738592
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 33 deletions.
22 changes: 22 additions & 0 deletions tests/trace/image_patch_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from tempfile import NamedTemporaryFile

from weave.initialization import pil_image_thread_safety


def test_patching_import_order():
# This test verifies the correct behavior if patching occurs after the construction
# of an image
assert pil_image_thread_safety._patched
pil_image_thread_safety.undo_threadsafe_patch_to_pil_image()
assert not pil_image_thread_safety._patched
import PIL

image = PIL.Image.new("RGB", (10, 10))
with NamedTemporaryFile(suffix=".png") as f:
image.save(f.name)
image = PIL.Image.open(f.name)

pil_image_thread_safety.apply_threadsafe_patch_to_pil_image()
assert pil_image_thread_safety._patched

image.crop((0, 0, 10, 10))
34 changes: 24 additions & 10 deletions tests/trace/type_handlers/Image/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,19 @@ def accept_image_jpg_pillow(val):
file_path.unlink()


def make_random_image(image_size: tuple[int, int] = (1024, 1024)):
random_colour = (
random.randint(0, 255),
random.randint(0, 255),
random.randint(0, 255),
)
return Image.new("RGB", image_size, random_colour)


@pytest.fixture
def dataset_ref(client):
# This fixture represents a saved dataset containing images
IMAGE_SIZE = (1024, 1024)
N_ROWS = 50

def make_random_image():
random_colour = (
random.randint(0, 255),
random.randint(0, 255),
random.randint(0, 255),
)
return Image.new("RGB", IMAGE_SIZE, random_colour)

rows = [{"img": make_random_image()} for _ in range(N_ROWS)]
dataset = weave.Dataset(rows=rows)
ref = weave.publish(dataset)
Expand Down Expand Up @@ -202,3 +201,18 @@ async def test_many_images_will_consistently_log():

# But if there's an issue, the stderr will contain `Task failed:`
assert "Task failed" not in res.stderr


def test_images_in_load_of_dataset(client):
N_ROWS = 50
rows = [{"img": make_random_image()} for _ in range(N_ROWS)]
dataset = weave.Dataset(rows=rows)
ref = weave.publish(dataset)

dataset = ref.get()
for gotten_row, local_row in zip(dataset, rows):
assert isinstance(gotten_row["img"], Image.Image)
assert gotten_row["img"].size == local_row["img"].size
assert gotten_row["img"].tobytes() == local_row["img"].tobytes()

return ref
88 changes: 65 additions & 23 deletions weave/initialization/pil_image_thread_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,39 @@
- `undo_threadsafe_patch_to_pil_image`
There is a discussion here: https://github.com/python-pillow/Pillow/issues/4848#issuecomment-671339193 in which
the author claims that the Pillow library is thread-safe. However, my reasoning leads me to a different conclusion.
the author claims that the Pillow library is thread-safe. However, empirical evidence suggests otherwise.
Specifically, the `ImageFile.load` method is not thread-safe. This is because `load` will both close and delete
an open file handler as well as modify properties of the ImageFile object (namely the `im` property which contains
the underlying image data). Inside of Weave we use threads to parallelize work which may involve Images. This bug
has presented itself not only in our own persistence layer, but also in user code where they are consuming loaded
images across threads.
Specifically, the `ImageFile.load` method is not thread-safe because it:
1. Closes and deletes an open file handler
2. Modifies properties of the ImageFile object (namely the `im` property which contains the underlying image data)
We call `apply_threadsafe_patch_to_pil_image` in the `__init__.py` file to ensure that the ImageFile class is thread-safe.
Inside Weave, we use threads to parallelize work which may involve Images. This thread-safety issue has manifested
not only in our persistence layer but also in user code where loaded images are accessed across threads.
We call `apply_threadsafe_patch_to_pil_image` in the `__init__.py` file to ensure thread-safety for the ImageFile class.
"""

import threading
from functools import wraps
from typing import Any, Callable, Optional

# Global state
# `_patched` is a boolean that indicates whether the thread-safety patch has been applied
# `_original_methods` is a dictionary that stores the original methods of the ImageFile class
# `_new_lock_lock` is a lock that is used to create a new lock for each ImageFile instance
# `_fallback_load_lock` is a global lock that is used to ensure thread-safe image loading when per-instance locking fails
_patched = False
_original_methods: dict[str, Optional[Callable]] = {"init": None, "load": None}
_original_methods: dict[str, Optional[Callable]] = {"load": None}
_new_lock_lock = threading.Lock()
_fallback_load_lock = threading.Lock()


def apply_threadsafe_patch_to_pil_image() -> None:
"""Apply thread-safety patch to PIL ImageFile class."""
"""Apply thread-safety patch to PIL ImageFile class.
This function is idempotent - calling it multiple times has no additional effect.
If PIL is not installed or if patching fails, the function will handle the error gracefully.
"""
global _patched

if _patched:
Expand All @@ -41,35 +53,61 @@ def apply_threadsafe_patch_to_pil_image() -> None:


def _apply_threadsafe_patch() -> None:
"""Internal function that performs the actual thread-safety patching of PIL ImageFile.
Raises:
ImportError: If PIL is not installed
Exception: For any other unexpected errors during patching
"""
from PIL.ImageFile import ImageFile

global _original_methods

# Store original methods
_original_methods["init"] = ImageFile.__init__
_original_methods["load"] = ImageFile.load

old_load = ImageFile.load
old_init = ImageFile.__init__

@wraps(old_init)
def new_init(self: ImageFile, *args: Any, **kwargs: Any) -> None:
self._weave_load_lock = threading.Lock() # type: ignore
return old_init(self, *args, **kwargs)

@wraps(old_load)
def new_load(self: ImageFile, *args: Any, **kwargs: Any) -> Any:
with self._weave_load_lock: # type: ignore
# This function wraps PIL's ImageFile.load method to make it thread-safe
# by ensuring only one thread can load an image at a time per ImageFile instance.

# We use a per-instance lock to allow concurrent loading of different images
# while preventing concurrent access to the same image.
try:
# Create a new lock for this ImageFile instance if it doesn't exist.
# The lock creation itself needs to be thread-safe, hence _new_lock_lock.
# Note: this `_new_lock_lock` is global as opposed to per-instance, else
# it would be possible for the same ImageFile to be loaded by multiple threads
# thereby creating a race where different threads would be each minting their
# own lock for the same ImageFile!
if not hasattr(self, "_weave_load_lock"):
with _new_lock_lock:
# Double-check pattern: verify the attribute still doesn't exist
# after acquiring the lock to prevent race conditions
if not hasattr(self, "_weave_load_lock"):
setattr(self, "_weave_load_lock", threading.Lock())
lock = getattr(self, "_weave_load_lock")

except Exception:
# If anything goes wrong with the locking mechanism,
# fall back to the global lock for safety
lock = _fallback_load_lock
# Acquire the instance-specific lock before loading the image.
# This ensures thread-safety by preventing concurrent:
# - Modification of the 'im' property
# - Access to the file handler
with lock:
return old_load(self, *args, **kwargs)

ImageFile.__init__ = new_init # type: ignore
# Replace the load method with our thread-safe version
ImageFile.load = new_load # type: ignore


def undo_threadsafe_patch_to_pil_image() -> None:
"""Revert the thread-safety patch applied to PIL ImageFile class.
If the patch hasn't been applied, this function does nothing.
This function is idempotent - if the patch hasn't been applied, this function does nothing.
If the patch has been applied but can't be reverted, an error message is printed.
"""
global _patched
Expand All @@ -90,13 +128,17 @@ def undo_threadsafe_patch_to_pil_image() -> None:


def _undo_threadsafe_patch() -> None:
"""Internal function that performs the actual removal of thread-safety patches.
Raises:
ImportError: If PIL is not installed
Exception: For any other unexpected errors during unpatching
"""
from PIL.ImageFile import ImageFile

global _original_methods

if _original_methods["init"] is not None:
ImageFile.__init__ = _original_methods["init"] # type: ignore
if _original_methods["load"] is not None:
ImageFile.load = _original_methods["load"] # type: ignore

_original_methods = {"init": None, "load": None}
_original_methods = {"load": None}

0 comments on commit 7738592

Please sign in to comment.