diff --git a/packages/api-server/api_server/app.py b/packages/api-server/api_server/app.py index 750564f56..f55d155c3 100644 --- a/packages/api-server/api_server/app.py +++ b/packages/api-server/api_server/app.py @@ -21,6 +21,7 @@ get_alert_events, get_beacon_events, get_fleet_events, + get_rio_events, get_rmf_events, get_task_events, ) @@ -77,6 +78,7 @@ async def lifespan(_app: FastIO): await stack.enter_async_context(get_fleet_events) await stack.enter_async_context(get_alert_events) await stack.enter_async_context(get_beacon_events) + await stack.enter_async_context(get_rio_events) await Tortoise.init( db_url=app_config.db_url, diff --git a/packages/api-server/api_server/rmf_io/__init__.py b/packages/api-server/api_server/rmf_io/__init__.py index 7c8ca563a..255db9e1c 100644 --- a/packages/api-server/api_server/rmf_io/__init__.py +++ b/packages/api-server/api_server/rmf_io/__init__.py @@ -2,11 +2,13 @@ AlertEvents, BeaconEvents, FleetEvents, + RioEvents, RmfEvents, TaskEvents, get_alert_events, get_beacon_events, get_fleet_events, + get_rio_events, get_rmf_events, get_task_events, ) diff --git a/packages/api-server/api_server/rmf_io/events.py b/packages/api-server/api_server/rmf_io/events.py index d862009bc..d89d92989 100644 --- a/packages/api-server/api_server/rmf_io/events.py +++ b/packages/api-server/api_server/rmf_io/events.py @@ -71,4 +71,6 @@ def __init__(self): self.rios = Subject[mdl.Rio]() -rio_events = RioEvents() +@singleton_dep +def get_rio_events(): + return RioEvents() diff --git a/packages/api-server/api_server/routes/rios.py b/packages/api-server/api_server/routes/rios.py index 267ae4fb9..cefeb62d3 100644 --- a/packages/api-server/api_server/routes/rios.py +++ b/packages/api-server/api_server/routes/rios.py @@ -1,11 +1,11 @@ from typing import Annotated -from fastapi import Query, Response +from fastapi import Depends, Query, Response from api_server.fast_io import FastIORouter, SubscriptionRequest from api_server.models import Rio from api_server.models.tortoise_models import Rio as DbRio -from api_server.rmf_io import rio_events +from api_server.rmf_io import RioEvents, get_rio_events router = FastIORouter(tags=["RIOs"]) @@ -31,11 +31,16 @@ async def query_rios( @router.sub("", response_model=Rio) async def sub_rio(_req: SubscriptionRequest): + rio_events = get_rio_events() return rio_events.rios @router.put("", response_model=None) -async def put_rio(rio: Rio, resp: Response): +async def put_rio( + rio: Rio, + resp: Response, + rio_events: Annotated[RioEvents, Depends(get_rio_events)], +): rio_dict = rio.model_dump() del rio_dict["id"] _, created = await DbRio.update_or_create(rio_dict, id=rio.id) diff --git a/packages/api-server/api_server/routes/test_rios.py b/packages/api-server/api_server/routes/test_rios.py index e21e4b4e3..2b3a34aea 100644 --- a/packages/api-server/api_server/routes/test_rios.py +++ b/packages/api-server/api_server/routes/test_rios.py @@ -2,7 +2,7 @@ from api_server.models import Rio from api_server.models.tortoise_models import Rio as DbRio -from api_server.rmf_io import rio_events +from api_server.rmf_io import get_rio_events from api_server.test import AppFixture @@ -35,7 +35,7 @@ def test_get_rios(self): def test_sub_rios(self): with self.subscribe_sio("/rios") as sub: - rio_events.rios.on_next( + get_rio_events().rios.on_next( Rio(id="test_rio", type="test_type", data={"battery": 1}) ) rio = Rio(**next(sub))