diff --git a/payments_py/ai_query_api.py b/payments_py/ai_query_api.py index ab49ef6..af18e1f 100644 --- a/payments_py/ai_query_api.py +++ b/payments_py/ai_query_api.py @@ -51,8 +51,7 @@ async def subscribe(self, callback: Any, join_account_room: bool = True, join_ag subscribe_event_types (Optional[List[str]]): The event types to subscribe to. get_pending_events_on_subscribe (bool): If True, it will get the pending events on subscribe. """ - self.set_subscriber(callback=callback, join_account_room=join_account_room, join_agent_rooms=join_agent_rooms, subscribe_event_types=subscribe_event_types, get_pending_events_on_subscribe=get_pending_events_on_subscribe) - await self.connect_socket() + await self.connect_socket_subscriber(callback=callback, join_account_room=join_account_room, join_agent_rooms=join_agent_rooms, subscribe_event_types=subscribe_event_types, get_pending_events_on_subscribe=get_pending_events_on_subscribe) await asyncio.Event().wait() async def log_task(self, task_log: TaskLog): @@ -68,7 +67,7 @@ async def log_task(self, task_log: TaskLog): await self.socket_client.emit('_task-log', json.dumps(data)) - def create_task(self, did: str, task: Task): + async def create_task(self, did: str, task: Task, _callback: Optional[Any]=None): """ Subscribers can create an AI Task for an Agent. The task must contain the input query that will be used by the AI Agent. This method is used by subscribers of a Payment Plan required to access a specific AI Agent or Service. Users who are not subscribers won't be able to create AI Tasks for that Agent. @@ -78,6 +77,8 @@ def create_task(self, did: str, task: Task): Args: did (str): The DID of the service. task (Task): The task to create. + _callback (Any): The callback to execute when a new task log event is received (optional) + Example: task = { @@ -91,7 +92,10 @@ def create_task(self, did: str, task: Task): """ endpoint = self.parse_url_to_proxy(TASK_ENDPOINT).replace('{did}', did) token = self.get_service_token(did) - return self.post(endpoint, task, headers={'Authorization': f'Bearer {token.accessToken}'}) + result = self.post(endpoint, task, headers={'Authorization': f'Bearer {token.accessToken}'}) + if(result.status_code == 201 and _callback): + await self.subscribe_task_logs(_callback, [result["task"]["task_id"]]) + return result def create_steps(self, did: str, task_id: str, steps: List[Step]): """ @@ -212,4 +216,25 @@ def get_tasks_from_agents(self): """ return self.get(self.parse_url(GET_AGENTS_ENDPOINT)) + async def subscribe_task_logs(self, callback: Any, tasks: List[str]): + try: + if not tasks: + raise Exception('No task rooms to join in configuration') + await self.connect_socket() + + await self.socket_client.on('_connected', self._on_connected(callback, tasks)) + except Exception as error: + raise Exception(f"Unable to initialize websocket client: {self.web_socket_host} - {str(error)}") + + + def _on_connected(self, callback: Any, tasks: List[str]): + async def handle_connected_event(*args): + print(f"connectTasksSocket:: Joining tasks: {tasks}") + await self.socket_client.emit('_join-tasks', {'tasks': tasks}) + + async def handle_task_log_event(data: Any): + callback(data) + + await self.socket_client.on('task-log', handle_task_log_event) + return handle_connected_event diff --git a/payments_py/nvm_backend.py b/payments_py/nvm_backend.py index b01581c..854f197 100644 --- a/payments_py/nvm_backend.py +++ b/payments_py/nvm_backend.py @@ -67,13 +67,14 @@ def __init__(self, opts: BackendApiOptions): raise ValueError(f"Invalid URL: {self.opts.backend_host} - {str(error)}") - def set_subscriber(self, callback, join_account_room, join_agent_rooms, subscribe_event_types, get_pending_events_on_subscribe): + async def connect_socket_subscriber(self, callback, join_account_room, join_agent_rooms, subscribe_event_types, get_pending_events_on_subscribe): self.callback = callback self.join_account_room = join_account_room self.join_agent_rooms = join_agent_rooms self.subscribe_event_types = subscribe_event_types self.get_pending_events_on_subscribe = get_pending_events_on_subscribe - self.socket_client.on('_connected', self.connect_handler) + self.socket_client.on('_connected', self._subscribe) + await self.connect_socket() async def connect_socket(self): if not self.has_key: @@ -96,19 +97,9 @@ async def connect_socket(self): async def disconnect_socket(self): if self.socket_client and self.socket_client.connected: - await self.socket_client.disconnect() + await self.socket_client.disconnect() - async def connect_handler(self, data): - await self._subscribe() - if self.get_pending_events_on_subscribe: - try: - print('Emiting pending events') - if(self.get_pending_events_on_subscribe and self.join_agent_rooms): - await self._emit_step_events(AgentExecutionStatus.Pending, self.join_agent_rooms) - except Exception as e: - print('query-api:: Unable to get pending events', e) - - async def _subscribe(self): + async def _subscribe(self, data): if not self.join_account_room and not self.join_agent_rooms: raise ValueError('No rooms to join in configuration') if not self.socket_client.connected: @@ -126,6 +117,14 @@ async def event_handler(data): self.socket_client.on(event, event_handler) else: self.socket_client.on('step-updated', event_handler) + if self.get_pending_events_on_subscribe: + try: + print('Emiting pending events') + if(self.get_pending_events_on_subscribe and self.join_agent_rooms): + await self._emit_step_events(AgentExecutionStatus.Pending, self.join_agent_rooms) + except Exception as e: + print('query-api:: Unable to get pending events', e) + async def _emit_step_events(self, status: AgentExecutionStatus = AgentExecutionStatus.Pending, dids: List[str] = []): message = { "status": status.value, "dids": dids } diff --git a/pyproject.toml b/pyproject.toml index b4f6180..af96e86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "payments-py" -version = "0.5.1" +version = "0.5.2" description = "" authors = ["enrique "] readme = "README.md" diff --git a/tests/protocol_test.py b/tests/protocol_test.py index b76111b..aa12c93 100644 --- a/tests/protocol_test.py +++ b/tests/protocol_test.py @@ -121,7 +121,7 @@ async def test_AIQueryApi_create_task_in_plan_purchased(ai_query_api_build_fixtu assert builder.user_room_id, "User room ID is not set" - task = subscriber.ai_protocol.create_task(agent.did, {'query': 'sample_query', 'name': 'sample_task'}) + task = await subscriber.ai_protocol.create_task(agent.did, {'query': 'sample_query', 'name': 'sample_task'}) print('Task created:', task.json()) await asyncio.wait_for(response_event.wait(), timeout=120) @@ -144,7 +144,7 @@ async def test_AIQueryApi_create_task_in_plan_purchased(ai_query_api_build_fixtu assert int(balance2.balance) == int(balance_before_task.balance) - 2 with pytest.raises(Exception) as excinfo: - task = subscriber.ai_protocol.create_task(did=agent.did, task={}) + task = await subscriber.ai_protocol.create_task(did=agent.did, task={}) exception_args = excinfo.value.args[0] assert exception_args['status'] == 400