diff --git a/payments_py/ai_query_api.py b/payments_py/ai_query_api.py index c142b58..03ec542 100644 --- a/payments_py/ai_query_api.py +++ b/payments_py/ai_query_api.py @@ -50,17 +50,10 @@ async def subscribe(self, callback: Any, join_account_room: bool = True, join_ag subscribe_event_types (Optional[List[str]]): The event types to subscribe to. get_pending_events_on_subscribe (bool): If True, it will get the pending events on subscribe. """ - await self._subscribe(callback, join_account_room, join_agent_rooms, subscribe_event_types) - print('query-api:: Connected to the server') - if get_pending_events_on_subscribe: - try: - if(get_pending_events_on_subscribe and join_agent_rooms): - await self._emit_step_events(AgentExecutionStatus.Pending, join_agent_rooms) - except Exception as e: - print('query-api:: Unable to get pending events', e) + self.set_subscriber(callback=callback, join_account_room=join_account_room, join_agent_rooms=join_agent_rooms, subscribe_event_types=subscribe_event_types, get_pending_events_on_subscribe=get_pending_events_on_subscribe) + await self.connect_socket() await asyncio.Event().wait() - def create_task(self, did: str, task: Any): """ Subscribers can create an AI Task for an Agent. The task must contain the input query that will be used by the AI Agent. diff --git a/payments_py/nvm_backend.py b/payments_py/nvm_backend.py index cbfa380..0827bce 100644 --- a/payments_py/nvm_backend.py +++ b/payments_py/nvm_backend.py @@ -39,7 +39,12 @@ def __init__(self, opts: BackendApiOptions): self.socket_client = sio self.user_room_id = None self.has_key = False - + self.callback = None + self.join_account_room = None + self.join_agent_rooms = None + self.subscribe_event_types = None + self.get_pending_events_on_subscribe = None + default_headers = { 'Accept': 'application/json', **(opts.headers or {}), @@ -76,7 +81,16 @@ def __init__(self, opts: BackendApiOptions): self.opts.backend_host = backend_url except Exception as error: raise ValueError(f"Invalid URL: {self.opts.backend_host} - {str(error)}") - + + + def set_subscriber(self, callback, join_account_room, join_agent_rooms, subscribe_event_types, get_pending_events_on_subscribe): + self.callback = callback + self.join_account_room = join_account_room + self.join_agent_rooms = join_agent_rooms + self.subscribe_event_types = subscribe_event_types + self.get_pending_events_on_subscribe = get_pending_events_on_subscribe + self.socket_client.on('_connected', self.connect_handler) + async def connect_socket(self): if not self.has_key: raise ValueError('Unable to subscribe to the server because a key was not provided') @@ -99,28 +113,36 @@ async def disconnect_socket(self): if self.socket_client and self.socket_client.connected: self.socket_client.disconnect() - async def _subscribe(self, callback, join_account_room: bool = True, join_agent_rooms: Optional[Union[str, List[str]]] = None, subscribe_event_types: Optional[List[str]] = None): - if not join_account_room and not join_agent_rooms: + async def connect_handler(self, data): + await self._subscribe() + if self.get_pending_events_on_subscribe: + try: + print('Emiting pending events') + if(self.get_pending_events_on_subscribe and self.join_agent_rooms): + await self._emit_step_events(AgentExecutionStatus.Pending, self.join_agent_rooms) + except Exception as e: + print('query-api:: Unable to get pending events', e) + + async def _subscribe(self): + if not self.join_account_room and not self.join_agent_rooms: raise ValueError('No rooms to join in configuration') - await self.connect_socket() if not self.socket_client.connected: raise ConnectionError('Failed to connect to the WebSocket server.') async def event_handler(data): parsed_data = json.loads(data) - await callback(parsed_data) + await self.callback(parsed_data) - await self.join_room(join_account_room, join_agent_rooms) + await self.join_room(self.join_account_room, self.join_agent_rooms) - if subscribe_event_types: - for event in subscribe_event_types: + if self.subscribe_event_types: + for event in self.subscribe_event_types: print(f"nvm-backend:: Subscribing to event: {event}") self.socket_client.on(event, event_handler) else: self.socket_client.on('step-updated', event_handler) - + async def _emit_step_events(self, status: AgentExecutionStatus = AgentExecutionStatus.Pending, dids: List[str] = []): - await self.connect_socket() message = { "status": status.value, "dids": dids } print(f"nvm-backend:: Emitting step: {json.dumps(message)}") await self.socket_client.emit(event='_emit-steps', data=json.dumps(message)) diff --git a/payments_py/payments.py b/payments_py/payments.py index d8cbb96..0f93332 100644 --- a/payments_py/payments.py +++ b/payments_py/payments.py @@ -508,7 +508,6 @@ def create_agent(self, plan_did: str, name: str, description: str, service_charg query_protocol_version, service_host) - def order_plan(self, plan_did: str, agreementId: Optional[str] = None) -> OrderPlanResultDto: """ Orders a Payment Plan. The user needs to have enough balance in the token selected by the owner of the Payment Plan. diff --git a/pyproject.toml b/pyproject.toml index e5ebcaa..34ed282 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "payments-py" -version = "0.4.8" +version = "0.4.9" description = "" authors = ["enrique "] readme = "README.md" diff --git a/tests/protocol_test.py b/tests/protocol_test.py index 90694ab..8a6845e 100644 --- a/tests/protocol_test.py +++ b/tests/protocol_test.py @@ -7,7 +7,6 @@ from payments_py.data_models import AgentExecutionStatus, CreateAssetResultDto, OrderPlanResultDto response_event = asyncio.Event() -room_joined_event = asyncio.Event() global response_data response_data = None @@ -73,9 +72,6 @@ async def eventsReceived(data): }) print(result.json()) -async def on_join_rooms(data): - print("Joined room:", data) - room_joined_event.set() @pytest.mark.asyncio(loop_scope="session") async def test_AIQueryApi_create_task_in_plan_purchased(ai_query_api_build_fixture, ai_query_api_subscriber_fixture): @@ -103,18 +99,7 @@ async def test_AIQueryApi_create_task_in_plan_purchased(ai_query_api_build_fixtu auth_type="none", use_ai_hub=True, ) - # agent = builder.create_service( - # plan_did=plan.did, - # service_type='agent', - # name="Agent service", - # description="test", - # amount_of_credits=1, - # service_charge_type="fixed", - # auth_type="none", - # is_nevermined_hosted=True, - # implements_query_protocol=True, - # query_protocol_version='v1' - # ) + assert isinstance(agent, CreateAssetResultDto) assert agent.did.startswith("did:") print('Agent service created:', agent.did) @@ -135,8 +120,6 @@ async def test_AIQueryApi_create_task_in_plan_purchased(ai_query_api_build_fixtu assert builder.ai_protocol.socket_client.connected, "WebSocket connection failed" assert builder.user_room_id, "User room ID is not set" - builder.ai_protocol.socket_client.on("_join-rooms_", on_join_rooms) - await asyncio.wait_for(room_joined_event.wait(), timeout=10) task = subscriber.ai_protocol.create_task(agent.did, {'query': 'sample_query', 'name': 'sample_task'}) print('Task created:', task.json()) @@ -176,14 +159,14 @@ async def test_AIQueryApi_create_task_in_plan_purchased(ai_query_api_build_fixtu pass # @pytest.mark.asyncio(loop_scope="session") -# async def test_AI_send_task(ai_query_api_build_fixture): -# builder = ai_query_api_build_fixture +# async def test_AI_send_task(ai_query_api_subscriber_fixture): +# builder = ai_query_api_subscriber_fixture # task = builder.ai_protocol.create_task('did:nv:7d86045034ea8a14c133c487374a175c56a9c6144f6395581435bc7f1dc9e0cc', -# {'query': 'https://www.youtube.com/watch?v=SB7eoaVw4Sk', 'name': 'Summarize video'}) +# {'query': 'https://www.youtube.com/watch?v=0q_BrgesfF4', 'name': 'Summarize video'}) # print('Task created:', task.json()) # @pytest.mark.asyncio(loop_scope="session") -# async def test_AI_send_task2(ai_query_api_build_fixture): -# builder = ai_query_api_build_fixture -# task = builder.ai_protocol.get_task_with_steps(did='did:nv:a8983b06c0f25fb4064fc61d6527c84ca1813e552bfad5fa1c974caa3c5ccf49', task_id='task-cd5a90e6-688f-45a3-a299-1845d10db625') +# async def test_AI_send_task2(ai_query_api_subscriber_fixture): +# builder = ai_query_api_subscriber_fixture +# task = builder.ai_protocol.get_task_with_steps(did='did:nv:7d86045034ea8a14c133c487374a175c56a9c6144f6395581435bc7f1dc9e0cc', task_id='task-6b16b12e-3aa2-43c3-a756-a150b07665e2') # print('Task result:', task.json()) \ No newline at end of file