diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 86aa08b59..964c15792 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -57,9 +57,12 @@ def _runner() -> WorkerDispatcher: return RUNNER -def setup_runner(config: ApplicationConfig | None = None, use_subprocess: bool = True): +def setup_runner( + config: ApplicationConfig | None = None, + runner: WorkerDispatcher | None = None, +): global RUNNER - runner = WorkerDispatcher(config, use_subprocess) + runner = runner or WorkerDispatcher(config) runner.start() RUNNER = runner diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 578dea13e..675d2d2ab 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -1,7 +1,7 @@ import uuid from collections.abc import Iterator from dataclasses import dataclass -from unittest.mock import MagicMock, patch +from unittest.mock import Mock, patch import pytest from fastapi import status @@ -12,32 +12,52 @@ from blueapi.core.bluesky_types import Plan from blueapi.service import main +from blueapi.service.interface import ( + cancel_active_task, + get_device, + get_plan, + pause_worker, + resume_worker, + submit_task, +) from blueapi.service.model import ( DeviceModel, + EnvironmentResponse, PlanModel, StateChangeRequest, WorkerTask, ) +from blueapi.service.runner import WorkerDispatcher from blueapi.worker.event import WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TrackableTask +class MockCountModel(BaseModel): ... + + +COUNT = Plan(name="count", model=MockCountModel) + + +@pytest.fixture +def mock_runner() -> Mock: + return Mock(spec=WorkerDispatcher) + + @pytest.fixture -def client() -> Iterator[TestClient]: +def client(mock_runner: Mock) -> Iterator[TestClient]: with patch("blueapi.service.interface.worker"): - main.setup_runner(use_subprocess=False) + main.setup_runner(runner=mock_runner) yield TestClient(main.get_app()) main.teardown_runner() -@patch("blueapi.service.interface.get_plans") -def test_get_plans(get_plans_mock: MagicMock, client: TestClient) -> None: +def test_get_plans(mock_runner: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - get_plans_mock.return_value = [PlanModel.from_plan(plan)] + mock_runner.run.return_value = [PlanModel.from_plan(plan)] response = client.get("/plans") @@ -58,17 +78,16 @@ class MyModel(BaseModel): } -@patch("blueapi.service.interface.get_plan") -def test_get_plan_by_name(get_plan_mock: MagicMock, client: TestClient) -> None: +def test_get_plan_by_name(mock_runner: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - get_plan_mock.return_value = PlanModel.from_plan(plan) + mock_runner.run.return_value = PlanModel.from_plan(plan) response = client.get("/plans/my-plan") - get_plan_mock.assert_called_once_with("my-plan") + mock_runner.run.assert_called_once_with(get_plan, "my-plan") assert response.status_code == status.HTTP_200_OK assert response.json() == { "description": None, @@ -82,25 +101,21 @@ class MyModel(BaseModel): } -@patch("blueapi.service.interface.get_plan") -def test_get_non_existent_plan_by_name( - get_plan_mock: MagicMock, client: TestClient -) -> None: - get_plan_mock.side_effect = KeyError("my-plan") +def test_get_non_existent_plan_by_name(mock_runner: Mock, client: TestClient) -> None: + mock_runner.run.side_effect = KeyError("my-plan") response = client.get("/plans/my-plan") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == {"detail": "Item not found"} -@patch("blueapi.service.interface.get_devices") -def test_get_devices(get_devices_mock: MagicMock, client: TestClient) -> None: +def test_get_devices(mock_runner: Mock, client: TestClient) -> None: @dataclass class MyDevice: name: str device = MyDevice("my-device") - get_devices_mock.return_value = [DeviceModel.from_device(device)] + mock_runner.run.return_value = [DeviceModel.from_device(device)] response = client.get("/devices") @@ -115,18 +130,17 @@ class MyDevice: } -@patch("blueapi.service.interface.get_device") -def test_get_device_by_name(get_device_mock: MagicMock, client: TestClient) -> None: +def test_get_device_by_name(mock_runner: Mock, client: TestClient) -> None: @dataclass class MyDevice: name: str device = MyDevice("my-device") - get_device_mock.return_value = DeviceModel.from_device(device) + mock_runner.run.return_value = DeviceModel.from_device(device) response = client.get("/devices/my-device") - get_device_mock.assert_called_once_with("my-device") + mock_runner.run.assert_called_once_with(get_device, "my-device") assert response.status_code == status.HTTP_200_OK assert response.json() == { "name": "my-device", @@ -134,51 +148,44 @@ class MyDevice: } -@patch("blueapi.service.interface.get_device") -def test_get_non_existent_device_by_name( - get_device_mock: MagicMock, client: TestClient -) -> None: - get_device_mock.side_effect = KeyError("my-device") +def test_get_non_existent_device_by_name(mock_runner: Mock, client: TestClient) -> None: + mock_runner.run.side_effect = KeyError("my-device") response = client.get("/devices/my-device") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == {"detail": "Item not found"} -@patch("blueapi.service.interface.submit_task") -@patch("blueapi.service.interface.get_plan") -def test_create_task( - get_plan_mock: MagicMock, submit_task_mock: MagicMock, client: TestClient -) -> None: +def test_create_task(mock_runner: Mock, client: TestClient) -> None: task = Task(name="count", params={"detectors": ["x"]}) task_id = str(uuid.uuid4()) - submit_task_mock.return_value = task_id + mock_runner.run.side_effect = [COUNT, task_id] response = client.post("/tasks", json=task.model_dump()) - submit_task_mock.assert_called_once_with(task) + mock_runner.run.assert_called_with(submit_task, task) assert response.json() == {"task_id": task_id} -@patch("blueapi.service.interface.submit_task") -@patch("blueapi.service.interface.get_plan") -def test_create_task_validation_error( - get_plan_mock: MagicMock, submit_task_mock: MagicMock, client: TestClient -) -> None: +def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - get_plan_mock.return_value = PlanModel.from_plan(plan) - submit_task_mock.side_effect = ValidationError.from_exception_data( - title="ValueError", - line_errors=[ - InitErrorDetails( - type="missing", loc=("id",), msg="value is required for Identifier" - ) # type: ignore - ], - ) + + mock_runner.run.side_effect = [ + PlanModel.from_plan(plan), + ValidationError.from_exception_data( + title="ValueError", + line_errors=[ + InitErrorDetails( + type="missing", loc=("id",), msg="value is required for Identifier" + ) # type: ignore + ], + ), + ] + response = client.post("/tasks", json={"name": "my-plan"}) assert response.status_code == 422 assert response.json() == { @@ -192,32 +199,21 @@ class MyModel(BaseModel): } -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") -def test_put_plan_begins_task( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient -) -> None: +def test_put_plan_begins_task(client: TestClient) -> None: task_id = "04cd9aa6-b902-414b-ae4b-49ea4200e957" - # Set to idle - get_active_task_mock.return_value = None - begin_task_mock.return_value = WorkerTask(task_id=task_id) - resp = client.put("/worker/task", json={"task_id": task_id}) assert resp.status_code == status.HTTP_200_OK assert resp.json() == {"task_id": task_id} -@patch("blueapi.service.interface.get_active_task") -def test_put_plan_fails_if_not_idle( - get_active_task_mock: MagicMock, client: TestClient -) -> None: +def test_put_plan_fails_if_not_idle(mock_runner: Mock, client: TestClient) -> None: task_id_current = "260f7de3-b608-4cdc-a66c-257e95809792" task_id_new = "07e98d68-21b5-4ad7-ac34-08b2cb992d42" # Set to non idle - get_active_task_mock.return_value = TrackableTask( + mock_runner.run.return_value = TrackableTask( task=None, task_id=task_id_current, is_complete=False ) @@ -227,8 +223,7 @@ def test_put_plan_fails_if_not_idle( assert resp.json() == {"detail": "Worker already active"} -@patch("blueapi.service.interface.get_tasks") -def test_get_tasks(get_tasks_mock: MagicMock, client: TestClient) -> None: +def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: tasks = [ TrackableTask(task_id="0", task=Task(name="sleep", params={"time": 0.0})), TrackableTask( @@ -239,7 +234,7 @@ def test_get_tasks(get_tasks_mock: MagicMock, client: TestClient) -> None: ), ] - get_tasks_mock.return_value = tasks + mock_runner.run.return_value = tasks response = client.get("/tasks") assert response.status_code == status.HTTP_200_OK @@ -266,10 +261,7 @@ def test_get_tasks(get_tasks_mock: MagicMock, client: TestClient) -> None: } -@patch("blueapi.service.interface.get_tasks_by_status") -def test_get_tasks_by_status( - get_tasks_by_status_mock: MagicMock, client: TestClient -) -> None: +def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: tasks = [ TrackableTask( task_id="3", @@ -279,7 +271,7 @@ def test_get_tasks_by_status( ), ] - get_tasks_by_status_mock.return_value = tasks + mock_runner.run.return_value = tasks response = client.get("/tasks", params={"task_status": "PENDING"}) assert response.json() == { @@ -301,19 +293,14 @@ def test_get_tasks_by_status_invalid(client: TestClient) -> None: assert response.status_code == status.HTTP_400_BAD_REQUEST -@patch("blueapi.service.interface.clear_task") -def test_delete_submitted_task(clear_task_mock: MagicMock, client: TestClient) -> None: +def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: task_id = str(uuid.uuid4()) - clear_task_mock.return_value = task_id + mock_runner.run.return_value = task_id response = client.delete(f"/tasks/{task_id}") assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") -def test_set_active_task( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient -) -> None: +def test_set_active_task(client: TestClient) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) @@ -323,15 +310,13 @@ def test_set_active_task( assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") def test_set_active_task_active_task_complete( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient + mock_runner: Mock, client: TestClient ) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - get_active_task_mock.return_value = TrackableTask( + mock_runner.run.return_value = TrackableTask( task_id="1", task=Task(name="a_completed_task"), is_complete=True, @@ -344,15 +329,13 @@ def test_set_active_task_active_task_complete( assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") def test_set_active_task_worker_already_running( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient + mock_runner: Mock, client: TestClient ) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - get_active_task_mock.return_value = TrackableTask( + mock_runner.run.return_value = TrackableTask( task_id="1", task=Task(name="a_running_task"), is_complete=False, @@ -365,15 +348,14 @@ def test_set_active_task_worker_already_running( assert response.json() == {"detail": "Worker already active"} -@patch("blueapi.service.interface.get_task_by_id") -def test_get_task(get_task_by_id: MagicMock, client: TestClient): +def test_get_task(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) task = TrackableTask( task_id=task_id, task=Task(name="third_task"), ) - get_task_by_id.return_value = task + mock_runner.run.return_value = task response = client.get(f"/tasks/{task_id}") assert response.json() == { @@ -386,8 +368,7 @@ def test_get_task(get_task_by_id: MagicMock, client: TestClient): } -@patch("blueapi.service.interface.get_tasks") -def test_get_all_tasks(get_all_tasks: MagicMock, client: TestClient): +def test_get_all_tasks(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) tasks = [ TrackableTask( @@ -396,7 +377,7 @@ def test_get_all_tasks(get_all_tasks: MagicMock, client: TestClient): ) ] - get_all_tasks.return_value = tasks + mock_runner.run.return_value = tasks response = client.get("/tasks") assert response.status_code == status.HTTP_200_OK assert response.json() == { @@ -413,138 +394,108 @@ def test_get_all_tasks(get_all_tasks: MagicMock, client: TestClient): } -@patch("blueapi.service.interface.get_task_by_id") -def test_get_task_error(get_task_by_id_mock: MagicMock, client: TestClient): +def test_get_task_error(mock_runner: Mock, client: TestClient): task_id = 567 - get_task_by_id_mock.return_value = None + mock_runner.run.return_value = None response = client.get(f"/tasks/{task_id}") assert response.json() == {"detail": "Item not found"} -@patch("blueapi.service.interface.get_active_task") -def test_get_active_task(get_active_task_mock: MagicMock, client: TestClient): +def test_get_active_task(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) task = TrackableTask( task_id=task_id, task=Task(name="third_task"), ) - get_active_task_mock.return_value = task + mock_runner.run.return_value = task response = client.get("/worker/task") assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.get_active_task") -def test_get_active_task_none(get_active_task_mock: MagicMock, client: TestClient): - get_active_task_mock.return_value = None +def test_get_active_task_none(mock_runner: Mock, client: TestClient): + mock_runner.run.return_value = None response = client.get("/worker/task") assert response.json() == {"task_id": None} -@patch("blueapi.service.interface.get_worker_state") -def test_get_state(get_worker_state_mock: MagicMock, client: TestClient): +def test_get_state(mock_runner: Mock, client: TestClient): state = WorkerState.SUSPENDING - get_worker_state_mock.return_value = state + mock_runner.run.return_value = state response = client.get("/worker/state") assert response.json() == state -@patch("blueapi.service.interface.pause_worker") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_running_to_paused( - get_worker_state_mock: MagicMock, pause_worker_mock: MagicMock, client: TestClient -): +def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.PAUSED - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - pause_worker_mock.assert_called_once_with(False) + mock_runner.run.assert_any_call(pause_worker, False) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.resume_worker") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_paused_to_running( - get_worker_state_mock: MagicMock, resume_worker_mock: MagicMock, client: TestClient -): +def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): current_state = WorkerState.PAUSED final_state = WorkerState.RUNNING - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - resume_worker_mock.assert_called_once() + mock_runner.run.assert_any_call(resume_worker) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.cancel_active_task") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_running_to_aborting( - get_worker_state_mock: MagicMock, - cancel_active_task_mock: MagicMock, - client: TestClient, -): +def test_set_state_running_to_aborting(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.ABORTING - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - cancel_active_task_mock.assert_called_once_with(True, None) + mock_runner.run.assert_any_call(cancel_active_task, True, None) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.cancel_active_task") -@patch("blueapi.service.interface.get_worker_state") def test_set_state_running_to_stopping_including_reason( - get_worker_state_mock: MagicMock, - cancel_active_task_mock: MagicMock, - client: TestClient, + mock_runner: Mock, client: TestClient ): current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING reason = "blueapi is being stopped" - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state, reason=reason).model_dump(), ) - cancel_active_task_mock.assert_called_once_with(False, reason) + mock_runner.run.assert_any_call(cancel_active_task, False, reason) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.cancel_active_task") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_transition_error( - get_worker_state_mock: MagicMock, - cancel_active_task_mock: MagicMock, - client: TestClient, -): +def test_set_state_transition_error(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING - get_worker_state_mock.side_effect = [current_state, final_state] - - cancel_active_task_mock.side_effect = TransitionError() + mock_runner.run.side_effect = [current_state, TransitionError(), final_state] response = client.put( "/worker/state", @@ -555,15 +506,12 @@ def test_set_state_transition_error( assert response.json() == final_state -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_invalid_transition( - get_worker_state_mock: MagicMock, client: TestClient -): +def test_set_state_invalid_transition(mock_runner: Mock, client: TestClient): current_state = WorkerState.STOPPING requested_state = WorkerState.PAUSED final_state = WorkerState.STOPPING - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, final_state] response = client.put( "/worker/state", @@ -574,21 +522,18 @@ def test_set_state_invalid_transition( assert response.json() == final_state -def test_get_environment_idle(client: TestClient) -> None: +def test_get_environment_idle(mock_runner: Mock, client: TestClient) -> None: + mock_runner.state = EnvironmentResponse( + initialized=True, + error_message=None, + ) + assert client.get("/environment").json() == { "initialized": True, "error_message": None, } -def test_delete_environment(client: TestClient) -> None: +def test_delete_environment(mock_runner: Mock, client: TestClient) -> None: response = client.delete("/environment") assert response.status_code is status.HTTP_200_OK - - -@patch("blueapi.service.runner.Pool") -def test_subprocess_enabled_by_default(mp_pool_mock: MagicMock): - """Ensure that in the default rest app a subprocess runner is used""" - main.setup_runner() - mp_pool_mock.assert_called_once() - main.teardown_runner()