From 4f1f44cb45e63ce5ed81c1d722a50317044210b9 Mon Sep 17 00:00:00 2001 From: Joe Shannon Date: Mon, 13 May 2024 14:15:46 +0100 Subject: [PATCH] Remove ophyd_async_connect (#462) The device connection is now handled by device_instantiation in dodal. This function also provides the option on whether to wait for connection, so it is not needed here too. Additionally it can lead to undefined (currently) behaviour if the device is initially created with fake_with_ophyd_sim = True but then later connected again by blueapi with fake_with_ophyd_sim = False. This also leaves the sim property on BlueskyContext redundant so that is removed too. For full customisation and flexibility of lazy connect we need #440. Fixes #461. --- src/blueapi/core/context.py | 11 +--- src/blueapi/service/handler.py | 4 +- src/blueapi/utils/__init__.py | 2 - src/blueapi/utils/ophyd_async_connect.py | 54 ----------------- tests/utils/test_ophyd_async_connect.py | 77 ------------------------ 5 files changed, 2 insertions(+), 146 deletions(-) delete mode 100644 src/blueapi/utils/ophyd_async_connect.py delete mode 100644 tests/utils/test_ophyd_async_connect.py diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 98493a913..50623967e 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -6,14 +6,13 @@ from types import ModuleType, UnionType from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints -from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop +from bluesky.run_engine import RunEngine from pydantic import create_model from pydantic.fields import FieldInfo, ModelField from blueapi.config import EnvironmentConfig, SourceKind from blueapi.utils import ( BlueapiPlanModelConfig, - connect_ophyd_async_devices, load_module_all, ) @@ -45,7 +44,6 @@ class BlueskyContext: plans: dict[str, Plan] = field(default_factory=dict) devices: dict[str, Device] = field(default_factory=dict) plan_functions: dict[str, PlanGenerator] = field(default_factory=dict) - sim: bool = field(default=False) _reference_cache: dict[type, type] = field(default_factory=dict) @@ -78,13 +76,6 @@ def with_config(self, config: EnvironmentConfig) -> None: elif source.kind is SourceKind.DODAL: self.with_dodal_module(mod) - call_in_bluesky_event_loop( - connect_ophyd_async_devices( - self.devices.values(), - self.sim, - ) - ) - def with_plan_module(self, module: ModuleType) -> None: """ Register all functions in the module supplied as plans. diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index afa9818c1..24cdd312a 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -154,9 +154,7 @@ def setup_handler( handler = Handler( config, - context=BlueskyContext( - sim=False, - ), + context=BlueskyContext(), ) handler.start() diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index b3c212a51..b871f842a 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,7 +1,6 @@ from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig from .invalid_config_error import InvalidConfigError from .modules import load_module_all -from .ophyd_async_connect import connect_ophyd_async_devices from .serialization import serialize from .thread_exception import handle_all_exceptions @@ -14,5 +13,4 @@ "BlueapiModelConfig", "BlueapiPlanModelConfig", "InvalidConfigError", - "connect_ophyd_async_devices", ] diff --git a/src/blueapi/utils/ophyd_async_connect.py b/src/blueapi/utils/ophyd_async_connect.py deleted file mode 100644 index 382b412bf..000000000 --- a/src/blueapi/utils/ophyd_async_connect.py +++ /dev/null @@ -1,54 +0,0 @@ -import asyncio -import logging -from collections.abc import Iterable -from contextlib import suppress -from typing import Any - -from ophyd_async.core import DEFAULT_TIMEOUT, NotConnected -from ophyd_async.core import Device as OphydAsyncDevice - - -async def connect_ophyd_async_devices( - devices: Iterable[Any], - sim: bool = False, - timeout: float = DEFAULT_TIMEOUT, -) -> None: - tasks: dict[asyncio.Task, str] = {} - for device in devices: - if isinstance(device, OphydAsyncDevice): - task = asyncio.create_task(device.connect(sim=sim)) - tasks[task] = device.name - if tasks: - await _wait_for_tasks(tasks, timeout=timeout) - - -async def _wait_for_tasks(tasks: dict[asyncio.Task, str], timeout: float): - done, pending = await asyncio.wait(tasks, timeout=timeout) - if pending: - msg = f"{len(pending)} Devices did not connect:" - for t in pending: - t.cancel() - with suppress(Exception): - await t - msg += _format_awaited_task_error_message(tasks, t) - logging.error(msg) - raised = [t for t in done if t.exception()] - if raised: - logging.error(f"{len(raised)} Devices raised an error:") - for t in raised: - logging.exception(f" {tasks[t]}:", exc_info=t.exception()) - if pending or raised: - raise NotConnected("Not all Devices connected") - - -def _format_awaited_task_error_message( - tasks: dict[asyncio.Task, str], t: asyncio.Task -) -> str: - e = t.exception() - part_one = f"\n {tasks[t]}: {type(e).__name__}" - lines = str(e).splitlines() - - part_two = ( - f": {e}" if len(lines) <= 1 else "".join(f"\n {line}" for line in lines) - ) - return part_one + part_two diff --git a/tests/utils/test_ophyd_async_connect.py b/tests/utils/test_ophyd_async_connect.py deleted file mode 100644 index f3dcba767..000000000 --- a/tests/utils/test_ophyd_async_connect.py +++ /dev/null @@ -1,77 +0,0 @@ -import asyncio -import unittest - -from blueapi.utils.ophyd_async_connect import _format_awaited_task_error_message -from blueapi.worker.task import Task - -_SIMPLE_TASK = Task(name="sleep", params={"time": 0.0}) -_LONG_TASK = Task(name="sleep", params={"time": 1.0}) - - -class TestFormatErrorMessage(unittest.TestCase): - def setUp(self): - # Setup the asyncio event loop for each test - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - # Close the loop at the end of each test - self.loop.close() - - async def _create_task_with_exception(self, exception): - """Helper coroutine to create a task that raises an exception.""" - - async def raise_exception(): - raise exception - - task = self.loop.create_task(raise_exception()) - await asyncio.sleep(0.1) # Allow time for the task to raise the exception - return task - - def test_format_error_message_single_line(self): - # Test formatting with an exception that has a single-line message - exception = ValueError("A single-line error") - task = self.loop.run_until_complete(self._create_task_with_exception(exception)) - tasks = {task: "Task1"} - expected_output = "\n Task1: ValueError: A single-line error" - self.assertEqual( - _format_awaited_task_error_message(tasks, task), expected_output - ) - - def test_format_error_message_multi_line(self): - # Test formatting with an exception that has a multi-line message - exception = ValueError("A multi-line\nerror message") - task = self.loop.run_until_complete(self._create_task_with_exception(exception)) - tasks = {task: "Task2"} - expected_output = "\n Task2: ValueError\n A multi-line\n error message" - self.assertEqual( - _format_awaited_task_error_message(tasks, task), expected_output - ) - - def test_format_error_message_simple_task_failure(self): - # Test formatting with the _SIMPLE_TASK key and a failing asyncio task - exception = RuntimeError("Simple task error") - failing_task = self.loop.run_until_complete( - self._create_task_with_exception(exception) - ) - tasks = {failing_task: _SIMPLE_TASK.name} - expected_output = "\n sleep: RuntimeError: Simple task error" - self.assertEqual( - _format_awaited_task_error_message(tasks, failing_task), expected_output - ) - - def test_format_error_message_long_task_failure(self): - # Test formatting with the _LONG_TASK key and a failing asyncio task - exception = RuntimeError("Long task error") - failing_task = self.loop.run_until_complete( - self._create_task_with_exception(exception) - ) - tasks = {failing_task: _LONG_TASK.name} - expected_output = "\n sleep: RuntimeError: Long task error" - self.assertEqual( - _format_awaited_task_error_message(tasks, failing_task), expected_output - ) - - -if __name__ == "__main__": - unittest.main()