diff --git a/tests/trace/image_patch_test.py b/tests/trace/image_patch_test.py new file mode 100644 index 000000000000..6eca3482b5b9 --- /dev/null +++ b/tests/trace/image_patch_test.py @@ -0,0 +1,22 @@ +from tempfile import NamedTemporaryFile + +from weave.initialization import pil_image_thread_safety + + +def test_patching_import_order(): + # This test verifies the correct behavior if patching occurs after the construction + # of an image + assert pil_image_thread_safety._patched + pil_image_thread_safety.undo_threadsafe_patch_to_pil_image() + assert not pil_image_thread_safety._patched + import PIL + + image = PIL.Image.new("RGB", (10, 10)) + with NamedTemporaryFile(suffix=".png") as f: + image.save(f.name) + image = PIL.Image.open(f.name) + + pil_image_thread_safety.apply_threadsafe_patch_to_pil_image() + assert pil_image_thread_safety._patched + + image.crop((0, 0, 10, 10)) diff --git a/tests/trace/type_handlers/Image/image_test.py b/tests/trace/type_handlers/Image/image_test.py index 316bbef917ab..a5dcbc37dc3a 100644 --- a/tests/trace/type_handlers/Image/image_test.py +++ b/tests/trace/type_handlers/Image/image_test.py @@ -149,20 +149,19 @@ def accept_image_jpg_pillow(val): file_path.unlink() +def make_random_image(image_size: tuple[int, int] = (1024, 1024)): + random_colour = ( + random.randint(0, 255), + random.randint(0, 255), + random.randint(0, 255), + ) + return Image.new("RGB", image_size, random_colour) + + @pytest.fixture def dataset_ref(client): # This fixture represents a saved dataset containing images - IMAGE_SIZE = (1024, 1024) N_ROWS = 50 - - def make_random_image(): - random_colour = ( - random.randint(0, 255), - random.randint(0, 255), - random.randint(0, 255), - ) - return Image.new("RGB", IMAGE_SIZE, random_colour) - rows = [{"img": make_random_image()} for _ in range(N_ROWS)] dataset = weave.Dataset(rows=rows) ref = weave.publish(dataset) @@ -202,3 +201,18 @@ async def test_many_images_will_consistently_log(): # But if there's an issue, the stderr will contain `Task failed:` assert "Task failed" not in res.stderr + + +def test_images_in_load_of_dataset(client): + N_ROWS = 50 + rows = [{"img": make_random_image()} for _ in range(N_ROWS)] + dataset = weave.Dataset(rows=rows) + ref = weave.publish(dataset) + + dataset = ref.get() + for gotten_row, local_row in zip(dataset, rows): + assert isinstance(gotten_row["img"], Image.Image) + assert gotten_row["img"].size == local_row["img"].size + assert gotten_row["img"].tobytes() == local_row["img"].tobytes() + + return ref diff --git a/weave/initialization/pil_image_thread_safety.py b/weave/initialization/pil_image_thread_safety.py index bac688a5d8ce..00093b50b799 100644 --- a/weave/initialization/pil_image_thread_safety.py +++ b/weave/initialization/pil_image_thread_safety.py @@ -4,27 +4,39 @@ - `undo_threadsafe_patch_to_pil_image` There is a discussion here: https://github.com/python-pillow/Pillow/issues/4848#issuecomment-671339193 in which -the author claims that the Pillow library is thread-safe. However, my reasoning leads me to a different conclusion. +the author claims that the Pillow library is thread-safe. However, empirical evidence suggests otherwise. -Specifically, the `ImageFile.load` method is not thread-safe. This is because `load` will both close and delete -an open file handler as well as modify properties of the ImageFile object (namely the `im` property which contains -the underlying image data). Inside of Weave we use threads to parallelize work which may involve Images. This bug -has presented itself not only in our own persistence layer, but also in user code where they are consuming loaded -images across threads. +Specifically, the `ImageFile.load` method is not thread-safe because it: +1. Closes and deletes an open file handler +2. Modifies properties of the ImageFile object (namely the `im` property which contains the underlying image data) -We call `apply_threadsafe_patch_to_pil_image` in the `__init__.py` file to ensure that the ImageFile class is thread-safe. +Inside Weave, we use threads to parallelize work which may involve Images. This thread-safety issue has manifested +not only in our persistence layer but also in user code where loaded images are accessed across threads. + +We call `apply_threadsafe_patch_to_pil_image` in the `__init__.py` file to ensure thread-safety for the ImageFile class. """ import threading from functools import wraps from typing import Any, Callable, Optional +# Global state +# `_patched` is a boolean that indicates whether the thread-safety patch has been applied +# `_original_methods` is a dictionary that stores the original methods of the ImageFile class +# `_new_lock_lock` is a lock that is used to create a new lock for each ImageFile instance +# `_fallback_load_lock` is a global lock that is used to ensure thread-safe image loading when per-instance locking fails _patched = False -_original_methods: dict[str, Optional[Callable]] = {"init": None, "load": None} +_original_methods: dict[str, Optional[Callable]] = {"load": None} +_new_lock_lock = threading.Lock() +_fallback_load_lock = threading.Lock() def apply_threadsafe_patch_to_pil_image() -> None: - """Apply thread-safety patch to PIL ImageFile class.""" + """Apply thread-safety patch to PIL ImageFile class. + + This function is idempotent - calling it multiple times has no additional effect. + If PIL is not installed or if patching fails, the function will handle the error gracefully. + """ global _patched if _patched: @@ -41,35 +53,61 @@ def apply_threadsafe_patch_to_pil_image() -> None: def _apply_threadsafe_patch() -> None: + """Internal function that performs the actual thread-safety patching of PIL ImageFile. + + Raises: + ImportError: If PIL is not installed + Exception: For any other unexpected errors during patching + """ from PIL.ImageFile import ImageFile global _original_methods # Store original methods - _original_methods["init"] = ImageFile.__init__ _original_methods["load"] = ImageFile.load - old_load = ImageFile.load - old_init = ImageFile.__init__ - - @wraps(old_init) - def new_init(self: ImageFile, *args: Any, **kwargs: Any) -> None: - self._weave_load_lock = threading.Lock() # type: ignore - return old_init(self, *args, **kwargs) @wraps(old_load) def new_load(self: ImageFile, *args: Any, **kwargs: Any) -> Any: - with self._weave_load_lock: # type: ignore + # This function wraps PIL's ImageFile.load method to make it thread-safe + # by ensuring only one thread can load an image at a time per ImageFile instance. + + # We use a per-instance lock to allow concurrent loading of different images + # while preventing concurrent access to the same image. + try: + # Create a new lock for this ImageFile instance if it doesn't exist. + # The lock creation itself needs to be thread-safe, hence _new_lock_lock. + # Note: this `_new_lock_lock` is global as opposed to per-instance, else + # it would be possible for the same ImageFile to be loaded by multiple threads + # thereby creating a race where different threads would be each minting their + # own lock for the same ImageFile! + if not hasattr(self, "_weave_load_lock"): + with _new_lock_lock: + # Double-check pattern: verify the attribute still doesn't exist + # after acquiring the lock to prevent race conditions + if not hasattr(self, "_weave_load_lock"): + setattr(self, "_weave_load_lock", threading.Lock()) + lock = getattr(self, "_weave_load_lock") + + except Exception: + # If anything goes wrong with the locking mechanism, + # fall back to the global lock for safety + lock = _fallback_load_lock + # Acquire the instance-specific lock before loading the image. + # This ensures thread-safety by preventing concurrent: + # - Modification of the 'im' property + # - Access to the file handler + with lock: return old_load(self, *args, **kwargs) - ImageFile.__init__ = new_init # type: ignore + # Replace the load method with our thread-safe version ImageFile.load = new_load # type: ignore def undo_threadsafe_patch_to_pil_image() -> None: """Revert the thread-safety patch applied to PIL ImageFile class. - If the patch hasn't been applied, this function does nothing. + This function is idempotent - if the patch hasn't been applied, this function does nothing. If the patch has been applied but can't be reverted, an error message is printed. """ global _patched @@ -90,13 +128,17 @@ def undo_threadsafe_patch_to_pil_image() -> None: def _undo_threadsafe_patch() -> None: + """Internal function that performs the actual removal of thread-safety patches. + + Raises: + ImportError: If PIL is not installed + Exception: For any other unexpected errors during unpatching + """ from PIL.ImageFile import ImageFile global _original_methods - if _original_methods["init"] is not None: - ImageFile.__init__ = _original_methods["init"] # type: ignore if _original_methods["load"] is not None: ImageFile.load = _original_methods["load"] # type: ignore - _original_methods = {"init": None, "load": None} + _original_methods = {"load": None}