diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 98493a913..50623967e 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -6,14 +6,13 @@ from types import ModuleType, UnionType from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints -from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop +from bluesky.run_engine import RunEngine from pydantic import create_model from pydantic.fields import FieldInfo, ModelField from blueapi.config import EnvironmentConfig, SourceKind from blueapi.utils import ( BlueapiPlanModelConfig, - connect_ophyd_async_devices, load_module_all, ) @@ -45,7 +44,6 @@ class BlueskyContext: plans: dict[str, Plan] = field(default_factory=dict) devices: dict[str, Device] = field(default_factory=dict) plan_functions: dict[str, PlanGenerator] = field(default_factory=dict) - sim: bool = field(default=False) _reference_cache: dict[type, type] = field(default_factory=dict) @@ -78,13 +76,6 @@ def with_config(self, config: EnvironmentConfig) -> None: elif source.kind is SourceKind.DODAL: self.with_dodal_module(mod) - call_in_bluesky_event_loop( - connect_ophyd_async_devices( - self.devices.values(), - self.sim, - ) - ) - def with_plan_module(self, module: ModuleType) -> None: """ Register all functions in the module supplied as plans. diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index afa9818c1..24cdd312a 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -154,9 +154,7 @@ def setup_handler( handler = Handler( config, - context=BlueskyContext( - sim=False, - ), + context=BlueskyContext(), ) handler.start() diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index b3c212a51..b871f842a 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,7 +1,6 @@ from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig from .invalid_config_error import InvalidConfigError from .modules import load_module_all -from .ophyd_async_connect import connect_ophyd_async_devices from .serialization import serialize from .thread_exception import handle_all_exceptions @@ -14,5 +13,4 @@ "BlueapiModelConfig", "BlueapiPlanModelConfig", "InvalidConfigError", - "connect_ophyd_async_devices", ] diff --git a/src/blueapi/utils/ophyd_async_connect.py b/src/blueapi/utils/ophyd_async_connect.py deleted file mode 100644 index 382b412bf..000000000 --- a/src/blueapi/utils/ophyd_async_connect.py +++ /dev/null @@ -1,54 +0,0 @@ -import asyncio -import logging -from collections.abc import Iterable -from contextlib import suppress -from typing import Any - -from ophyd_async.core import DEFAULT_TIMEOUT, NotConnected -from ophyd_async.core import Device as OphydAsyncDevice - - -async def connect_ophyd_async_devices( - devices: Iterable[Any], - sim: bool = False, - timeout: float = DEFAULT_TIMEOUT, -) -> None: - tasks: dict[asyncio.Task, str] = {} - for device in devices: - if isinstance(device, OphydAsyncDevice): - task = asyncio.create_task(device.connect(sim=sim)) - tasks[task] = device.name - if tasks: - await _wait_for_tasks(tasks, timeout=timeout) - - -async def _wait_for_tasks(tasks: dict[asyncio.Task, str], timeout: float): - done, pending = await asyncio.wait(tasks, timeout=timeout) - if pending: - msg = f"{len(pending)} Devices did not connect:" - for t in pending: - t.cancel() - with suppress(Exception): - await t - msg += _format_awaited_task_error_message(tasks, t) - logging.error(msg) - raised = [t for t in done if t.exception()] - if raised: - logging.error(f"{len(raised)} Devices raised an error:") - for t in raised: - logging.exception(f" {tasks[t]}:", exc_info=t.exception()) - if pending or raised: - raise NotConnected("Not all Devices connected") - - -def _format_awaited_task_error_message( - tasks: dict[asyncio.Task, str], t: asyncio.Task -) -> str: - e = t.exception() - part_one = f"\n {tasks[t]}: {type(e).__name__}" - lines = str(e).splitlines() - - part_two = ( - f": {e}" if len(lines) <= 1 else "".join(f"\n {line}" for line in lines) - ) - return part_one + part_two diff --git a/tests/utils/test_ophyd_async_connect.py b/tests/utils/test_ophyd_async_connect.py deleted file mode 100644 index f3dcba767..000000000 --- a/tests/utils/test_ophyd_async_connect.py +++ /dev/null @@ -1,77 +0,0 @@ -import asyncio -import unittest - -from blueapi.utils.ophyd_async_connect import _format_awaited_task_error_message -from blueapi.worker.task import Task - -_SIMPLE_TASK = Task(name="sleep", params={"time": 0.0}) -_LONG_TASK = Task(name="sleep", params={"time": 1.0}) - - -class TestFormatErrorMessage(unittest.TestCase): - def setUp(self): - # Setup the asyncio event loop for each test - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - # Close the loop at the end of each test - self.loop.close() - - async def _create_task_with_exception(self, exception): - """Helper coroutine to create a task that raises an exception.""" - - async def raise_exception(): - raise exception - - task = self.loop.create_task(raise_exception()) - await asyncio.sleep(0.1) # Allow time for the task to raise the exception - return task - - def test_format_error_message_single_line(self): - # Test formatting with an exception that has a single-line message - exception = ValueError("A single-line error") - task = self.loop.run_until_complete(self._create_task_with_exception(exception)) - tasks = {task: "Task1"} - expected_output = "\n Task1: ValueError: A single-line error" - self.assertEqual( - _format_awaited_task_error_message(tasks, task), expected_output - ) - - def test_format_error_message_multi_line(self): - # Test formatting with an exception that has a multi-line message - exception = ValueError("A multi-line\nerror message") - task = self.loop.run_until_complete(self._create_task_with_exception(exception)) - tasks = {task: "Task2"} - expected_output = "\n Task2: ValueError\n A multi-line\n error message" - self.assertEqual( - _format_awaited_task_error_message(tasks, task), expected_output - ) - - def test_format_error_message_simple_task_failure(self): - # Test formatting with the _SIMPLE_TASK key and a failing asyncio task - exception = RuntimeError("Simple task error") - failing_task = self.loop.run_until_complete( - self._create_task_with_exception(exception) - ) - tasks = {failing_task: _SIMPLE_TASK.name} - expected_output = "\n sleep: RuntimeError: Simple task error" - self.assertEqual( - _format_awaited_task_error_message(tasks, failing_task), expected_output - ) - - def test_format_error_message_long_task_failure(self): - # Test formatting with the _LONG_TASK key and a failing asyncio task - exception = RuntimeError("Long task error") - failing_task = self.loop.run_until_complete( - self._create_task_with_exception(exception) - ) - tasks = {failing_task: _LONG_TASK.name} - expected_output = "\n sleep: RuntimeError: Long task error" - self.assertEqual( - _format_awaited_task_error_message(tasks, failing_task), expected_output - ) - - -if __name__ == "__main__": - unittest.main()