diff --git a/lib/sycamore/sycamore/executor.py b/lib/sycamore/sycamore/executor.py index e8e885185..ec03abea0 100644 --- a/lib/sycamore/sycamore/executor.py +++ b/lib/sycamore/sycamore/executor.py @@ -1,4 +1,5 @@ import logging +from threading import Lock from typing import Callable, Iterable, TYPE_CHECKING if TYPE_CHECKING: @@ -9,6 +10,7 @@ from sycamore.plan_nodes import Node +ray_init_lock = Lock() logger = logging.getLogger(__name__) @@ -42,22 +44,23 @@ def _ray_logging_setup(): def sycamore_ray_init(**ray_args) -> None: import ray - if ray.is_initialized(): - logging.warning("Ignoring explicit request to initialize ray when it is already initialized") - return + with ray_init_lock: + if ray.is_initialized(): + logging.warning("Ignoring explicit request to initialize ray when it is already initialized") + return - if "logging_level" not in ray_args: - ray_args.update({"logging_level": logging.INFO}) + if "logging_level" not in ray_args: + ray_args.update({"logging_level": logging.INFO}) - if "runtime_env" not in ray_args: - ray_args["runtime_env"] = {} + if "runtime_env" not in ray_args: + ray_args["runtime_env"] = {} - if "worker_process_setup_hook" not in ray_args["runtime_env"]: - # logging.error("Spurious log 0: If you do not see spurious log 1 & 2, - # log messages are being dropped") - ray_args["runtime_env"]["worker_process_setup_hook"] = _ray_logging_setup + if "worker_process_setup_hook" not in ray_args["runtime_env"]: + # logging.error("Spurious log 0: If you do not see spurious log 1 & 2, + # log messages are being dropped") + ray_args["runtime_env"]["worker_process_setup_hook"] = _ray_logging_setup - ray.init(**ray_args) + ray.init(**ray_args) def visit_parallelism(n: Node): diff --git a/lib/sycamore/sycamore/tests/integration/test_executor.py b/lib/sycamore/sycamore/tests/integration/test_executor.py index ab45e5e7a..6448d9417 100644 --- a/lib/sycamore/sycamore/tests/integration/test_executor.py +++ b/lib/sycamore/sycamore/tests/integration/test_executor.py @@ -1,4 +1,8 @@ +from concurrent.futures import ThreadPoolExecutor +import tempfile +import sycamore from sycamore.context import ExecMode +from sycamore.tests.config import TEST_DIR import sycamore.tests.unit.test_executor as unit @@ -6,3 +10,29 @@ class TestPrepare(unit.TestPrepare): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.exec_mode = ExecMode.RAY + +def test_multiple_ray_init(): + import ray + + with tempfile.TemporaryDirectory() as tempdir: + context = sycamore.init(exec_mode=ExecMode.RAY) + + def write(): + ( + context.read.materialize(path=TEST_DIR / "resources/data/materialize/json_writer") + .write.json(tempdir) + ) + + num_workers = 10 + executor = ThreadPoolExecutor(max_workers=num_workers) + futures = [executor.submit(write) for _ in range(num_workers)] + got = 0 + for future in futures: + e = future.exception() + if e is not None: + assert False, e + future.result() + got += 1 + + assert got == num_workers + executor.shutdown() \ No newline at end of file