diff --git a/packages/api-server/api_server/gateway.py b/packages/api-server/api_server/gateway.py index 5d1605f71..e37b1123b 100644 --- a/packages/api-server/api_server/gateway.py +++ b/packages/api-server/api_server/gateway.py @@ -35,9 +35,12 @@ from rmf_task_msgs.srv import SubmitTask as RmfSubmitTask from rosidl_runtime_py.convert import message_to_ordereddict from std_msgs.msg import Bool as BoolMsg +from tortoise.exceptions import IntegrityError +from api_server.exceptions import AlreadyExistsError, InvalidInputError, NotFoundError from api_server.fast_io.singleton_dep import singleton_dep from api_server.models.user import User +from api_server.repositories.alerts import AlertRepository from api_server.repositories.cached_files import get_cached_file_repo from api_server.repositories.rmf import RmfRepository from api_server.rmf_io.events import ( @@ -51,7 +54,6 @@ from .models import ( AlertParameter, AlertRequest, - AlertResponse, BeaconState, BuildingMap, DeliveryAlert, @@ -70,6 +72,7 @@ def __init__( cached_files: CachedFilesRepository, ros_node: rclpy.node.Node, alert_events: AlertEvents, + alert_repo: AlertRepository, rmf_events: RmfEvents, rmf_repo: RmfRepository, loop: asyncio.AbstractEventLoop, @@ -79,6 +82,7 @@ def __init__( self._cached_files = cached_files self._ros_node = ros_node self._alert_events = alert_events + self._alert_repo = alert_repo self._rmf_events = rmf_events self._rmf_repo = rmf_repo self._loop = loop @@ -354,8 +358,24 @@ def convert_alert(msg): ) def handle_alert(alert: AlertRequest): + async def create_alert(alert: AlertRequest): + try: + created_alert = await self._alert_repo.create_new_alert(alert) + except IntegrityError as e: + logging.error("%s, %s", str(e), alert) + return + except AlreadyExistsError as e: + logging.error("%s, %s", str(e), alert) + return + if not created_alert: + logging.error("Failed to create alert: %s", alert) + return + + self._alert_events.alert_requests.on_next(created_alert) + logging.debug("%s", alert) + logging.info(f"Received alert: {alert}") - self._alert_events.alert_requests.on_next(alert) + self._loop.create_task(create_alert(alert)) alert_sub = self._ros_node.create_subscription( RmfAlert, @@ -370,22 +390,56 @@ def handle_alert(alert: AlertRequest): ) self._subscriptions.append(alert_sub) - def convert_alert_response(msg): - alert_response = cast(RmfAlertResponse, msg) - return AlertResponse( - id=alert_response.id, - unix_millis_response_time=round(datetime.now().timestamp() * 1000), - response=alert_response.response, - ) - - def handle_alert_response(alert_response: AlertResponse): - logging.info(f"Received alert response: {alert_response}") - self._alert_events.alert_responses.on_next(alert_response) + # FIXME(ac): Due to also subscribing to alert responses, this callback + # gets triggered as well even if the response is called through REST, + # which publishes a ROS 2 message and gets picked up by this subscriber. + # This causes alert_repo.create_response to be called twice in total, + # resulting in a conflict of responses for the same alert ID. This does + # not cause any issues, just that an error log is produced. + def handle_alert_response(msg): + msg = cast(RmfAlertResponse, msg) + + async def create_response(alert_id: str, response: str): + try: + created_response = await self._alert_repo.create_response( + msg.id, msg.response + ) + except IntegrityError as e: + logging.error( + "%s, id: %s, response: %s", str(e), alert_id, response + ) + return + except AlreadyExistsError as e: + logging.error( + "%s, id: %s, response: %s", str(e), alert_id, response + ) + return + except NotFoundError as e: + logging.error( + "%s, id: %s, response: %s", str(e), alert_id, response + ) + return + except InvalidInputError as e: + logging.error( + "%s, id: %s, response: %s", str(e), alert_id, response + ) + return + if not created_response: + logging.error( + f"Failed to create alert response [{msg.response}] for alert id [{msg.id}]" + ) + return + + self._alert_events.alert_responses.on_next(created_response) + logging.debug("%s", created_response) + + logging.info(f"Received response [{msg.response}] for alert id [{msg.id}]") + self._loop.create_task(create_response(msg.id, msg.response)) alert_response_sub = self._ros_node.create_subscription( RmfAlertResponse, "alert_response", - lambda msg: handle_alert_response(convert_alert_response(msg)), + handle_alert_response, rclpy.qos.QoSProfile( history=rclpy.qos.HistoryPolicy.KEEP_LAST, depth=10, @@ -505,6 +559,7 @@ def get_rmf_gateway(): get_cached_file_repo(), get_ros_node(), get_alert_events(), + AlertRepository(), get_rmf_events(), RmfRepository(User.get_system_user()), asyncio.get_event_loop(), diff --git a/packages/api-server/api_server/models/alerts.py b/packages/api-server/api_server/models/alerts.py index c64846350..4dff213f0 100644 --- a/packages/api-server/api_server/models/alerts.py +++ b/packages/api-server/api_server/models/alerts.py @@ -1,4 +1,3 @@ -from datetime import datetime from enum import Enum from pydantic import BaseModel @@ -20,18 +19,6 @@ class AlertResponse(BaseModel): def from_tortoise(tortoise: ttm.AlertResponse) -> "AlertResponse": return AlertResponse(**dict(tortoise.data)) - async def save(self) -> None: - await ttm.AlertResponse.update_or_create( - { - "response_time": datetime.fromtimestamp( - self.unix_millis_response_time / 1000 - ), - "response": self.response, - "data": self.json(), - }, - id=self.id, - ) - class AlertRequest(BaseModel): class Tier(str, Enum): @@ -53,16 +40,3 @@ class Tier(str, Enum): @staticmethod def from_tortoise(tortoise: ttm.AlertRequest) -> "AlertRequest": return AlertRequest(**dict(tortoise.data)) - - async def save(self) -> None: - await ttm.AlertRequest.update_or_create( - { - "request_time": datetime.fromtimestamp( - self.unix_millis_alert_time / 1000 - ), - "response_expected": (len(self.responses_available) > 0), - "task_id": self.task_id, - "data": self.json(), - }, - id=self.id, - ) diff --git a/packages/api-server/api_server/repositories/alerts.py b/packages/api-server/api_server/repositories/alerts.py index 427e2582e..136eee6b4 100644 --- a/packages/api-server/api_server/repositories/alerts.py +++ b/packages/api-server/api_server/repositories/alerts.py @@ -17,7 +17,7 @@ async def create_new_alert(self, alert: AlertRequest) -> AlertRequest: request_time=datetime.fromtimestamp(alert.unix_millis_alert_time / 1000), response_expected=(len(alert.responses_available) > 0), task_id=alert.task_id, - data=alert.json(), + data=alert.model_dump(), ) return alert @@ -30,6 +30,13 @@ async def get_alert(self, alert_id: str) -> AlertRequest: return alert_model async def create_response(self, alert_id: str, response: str) -> AlertResponse: + existing_response = await ttm.AlertResponse.get_or_none(id=alert_id) + if existing_response is not None: + existing_response_model = AlertResponse.from_tortoise(existing_response) + raise AlreadyExistsError( + f"Alert with ID {alert_id} already has a response of {existing_response_model.response}" + ) + alert = await ttm.AlertRequest.get_or_none(id=alert_id) if alert is None: raise NotFoundError(f"Alert with ID {alert_id} does not exists") @@ -51,7 +58,7 @@ async def create_response(self, alert_id: str, response: str) -> AlertResponse: alert_response_model.unix_millis_response_time / 1000 ), response=response, - data=alert_response_model.json(), + data=alert_response_model.model_dump(), alert_request=alert, ) return alert_response_model diff --git a/packages/api-server/api_server/routes/alerts.py b/packages/api-server/api_server/routes/alerts.py index 4d0b66a06..5c1f2de3a 100644 --- a/packages/api-server/api_server/routes/alerts.py +++ b/packages/api-server/api_server/routes/alerts.py @@ -79,6 +79,8 @@ async def respond_to_alert( alert_response_model = await repo.create_response(alert_id, response) except IntegrityError as e: raise HTTPException(400, e) from e + except AlreadyExistsError as e: + raise HTTPException(409, str(e)) from e except NotFoundError as e: raise HTTPException(404, str(e)) from e except InvalidInputError as e: