diff --git a/.idea/ruff.xml b/.idea/ruff.xml index b87cfd6..e5d944a 100644 --- a/.idea/ruff.xml +++ b/.idea/ruff.xml @@ -1,6 +1,6 @@ - \ No newline at end of file diff --git a/src/ephys_link/common.py b/src/ephys_link/common.py index 162a8db..2208b64 100644 --- a/src/ephys_link/common.py +++ b/src/ephys_link/common.py @@ -124,6 +124,20 @@ def __init__(self, angles: list, error: str) -> None: super(AngularOutputData, self).__init__(angles=angles, error=error) +class ShankCountOutputData(dict): + """Output format for (num_shanks, error) + + :param shank_count: Number of shanks on the probe + :type shank_count: int + :param error: Error message + :type error: str + """ + + def __init__(self, shank_count: int, error: str) -> None: + """Constructor""" + super(ShankCountOutputData, self).__init__(shank_count=shank_count, error=error) + + class DriveToDepthOutputData(dict): """Output format for depth driving (depth, error) diff --git a/src/ephys_link/platform_handler.py b/src/ephys_link/platform_handler.py index 185f660..06ad24b 100644 --- a/src/ephys_link/platform_handler.py +++ b/src/ephys_link/platform_handler.py @@ -176,8 +176,8 @@ def get_angles(self, manipulator_id: str) -> com.AngularOutputData: :param manipulator_id: The ID of the manipulator to get the position of. :type manipulator_id: str - :return: Callback parameters (manipulator ID, angles in (yaw, pitch, roll) (or an - empty array on error) in degrees, error message) + :return: Callback parameters (angles in (yaw, pitch, roll) or an + empty array on error in degrees, error message) :rtype: :class:`ephys_link.common.AngularOutputData` """ try: @@ -197,6 +197,15 @@ def get_angles(self, manipulator_id: str) -> com.AngularOutputData: print(f"[ERROR]\t\t Manipulator not registered: {manipulator_id}") return com.AngularOutputData([], "Manipulator not registered") + def get_shank_count(self, manipulator_id: str) -> com.ShankCountOutputData: + """Get the number of shanks on the probe + + :param manipulator_id: The ID of the manipulator to get the number of shanks of. + :type manipulator_id: str + :return: Callback parameters (number of shanks or -1 on error, error message) + """ + return self._get_shank_count(manipulator_id) + async def goto_pos( self, manipulator_id: str, position: list[float], speed: int ) -> com.PositionalOutputData: @@ -453,6 +462,15 @@ def _get_angles(self, manipulator_id: str) -> com.AngularOutputData: :rtype: :class:`ephys_link.common.AngularOutputData` """ + @abstractmethod + def _get_shank_count(self, manipulator_id: str) -> com.ShankCountOutputData: + """Get the number of shanks on the probe + + :param manipulator_id: The ID of the manipulator to get the number of shanks of. + :type manipulator_id: str + :return: Callback parameters (number of shanks or -1 on error, error message) + """ + @abstractmethod async def _goto_pos( self, manipulator_id: str, position: list[float], speed: int diff --git a/src/ephys_link/platforms/new_scale_handler.py b/src/ephys_link/platforms/new_scale_handler.py index 6ec244b..1620b35 100644 --- a/src/ephys_link/platforms/new_scale_handler.py +++ b/src/ephys_link/platforms/new_scale_handler.py @@ -75,6 +75,9 @@ def _get_pos(self, manipulator_id: str) -> com.PositionalOutputData: def _get_angles(self, manipulator_id: str) -> com.AngularOutputData: raise NotImplementedError + def _get_shank_count(self, manipulator_id: str) -> com.ShankCountOutputData: + raise NotImplementedError + async def _goto_pos( self, manipulator_id: str, position: list[float], speed: int ) -> com.PositionalOutputData: diff --git a/src/ephys_link/platforms/new_scale_pathfinder_handler.py b/src/ephys_link/platforms/new_scale_pathfinder_handler.py index 94342d0..9af696a 100644 --- a/src/ephys_link/platforms/new_scale_pathfinder_handler.py +++ b/src/ephys_link/platforms/new_scale_pathfinder_handler.py @@ -191,6 +191,18 @@ def _get_angles(self, manipulator_id: str) -> com.AngularOutputData: "", ) + def _get_shank_count(self, manipulator_id: str) -> com.ShankCountOutputData: + """Get the number of shanks on the probe + + :param manipulator_id: manipulator ID + :return: Callback parameters (number of shanks (or -1 on error), error message) + """ + for probe in self.query_data()["ProbeArray"]: + if probe["Id"] == manipulator_id: + return com.ShankCountOutputData(probe["ShankCount"], "") + + return com.ShankCountOutputData(-1, "Unable to find manipulator") + async def _goto_pos( self, manipulator_id: str, position: list[float], speed: int ) -> com.PositionalOutputData: diff --git a/src/ephys_link/platforms/sensapex_handler.py b/src/ephys_link/platforms/sensapex_handler.py index 5f33732..17ca643 100644 --- a/src/ephys_link/platforms/sensapex_handler.py +++ b/src/ephys_link/platforms/sensapex_handler.py @@ -50,6 +50,9 @@ def _get_pos(self, manipulator_id: str) -> com.PositionalOutputData: def _get_angles(self, manipulator_id: str) -> com.AngularOutputData: raise NotImplementedError + def _get_shank_count(self, manipulator_id: str) -> com.ShankCountOutputData: + raise NotImplementedError + async def _goto_pos( self, manipulator_id: str, position: list[float], speed: int ) -> com.PositionalOutputData: diff --git a/src/ephys_link/platforms/ump3_handler.py b/src/ephys_link/platforms/ump3_handler.py index 15a8694..28b349b 100644 --- a/src/ephys_link/platforms/ump3_handler.py +++ b/src/ephys_link/platforms/ump3_handler.py @@ -43,6 +43,9 @@ def _get_pos(self, manipulator_id: str) -> com.PositionalOutputData: def _get_angles(self, manipulator_id: str) -> com.AngularOutputData: raise NotImplementedError + def _get_shank_count(self, manipulator_id: str) -> com.ShankCountOutputData: + raise NotImplementedError + async def _goto_pos( self, manipulator_id: str, position: list[float], speed: int ) -> com.PositionalOutputData: diff --git a/src/ephys_link/server.py b/src/ephys_link/server.py index 691f54a..76eb0be 100644 --- a/src/ephys_link/server.py +++ b/src/ephys_link/server.py @@ -171,6 +171,22 @@ async def get_angles(_, manipulator_id: str) -> com.AngularOutputData: return platform.get_angles(manipulator_id) +@sio.event +async def get_shank_count(_, manipulator_id: str) -> com.ShankCountOutputData: + """Number of shanks of manipulator request + + :param _: Socket session ID (unused) + :type _: str + :param manipulator_id: ID of manipulator to pull number of shanks from + :type manipulator_id: str + :return: Callback parameters (manipulator ID, number of shanks (or -1 on error), error + message) + :rtype: :class:`ephys_link.common.ShankCountOutputData` + """ + + return platform.get_shank_count(manipulator_id) + + @sio.event async def goto_pos( _, data: com.GotoPositionInputDataFormat