Skip to content

Commit

Permalink
feat: update to support callback in createTask (#38)
Browse files Browse the repository at this point in the history
* feat: update to support callback in createTask

* feat: bumpversion
  • Loading branch information
eruizgar91 authored Nov 12, 2024
1 parent 103aede commit b5b2f85
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 21 deletions.
33 changes: 29 additions & 4 deletions payments_py/ai_query_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ async def subscribe(self, callback: Any, join_account_room: bool = True, join_ag
subscribe_event_types (Optional[List[str]]): The event types to subscribe to.
get_pending_events_on_subscribe (bool): If True, it will get the pending events on subscribe.
"""
self.set_subscriber(callback=callback, join_account_room=join_account_room, join_agent_rooms=join_agent_rooms, subscribe_event_types=subscribe_event_types, get_pending_events_on_subscribe=get_pending_events_on_subscribe)
await self.connect_socket()
await self.connect_socket_subscriber(callback=callback, join_account_room=join_account_room, join_agent_rooms=join_agent_rooms, subscribe_event_types=subscribe_event_types, get_pending_events_on_subscribe=get_pending_events_on_subscribe)
await asyncio.Event().wait()

async def log_task(self, task_log: TaskLog):
Expand All @@ -68,7 +67,7 @@ async def log_task(self, task_log: TaskLog):
await self.socket_client.emit('_task-log', json.dumps(data))


def create_task(self, did: str, task: Task):
async def create_task(self, did: str, task: Task, _callback: Optional[Any]=None):
"""
Subscribers can create an AI Task for an Agent. The task must contain the input query that will be used by the AI Agent.
This method is used by subscribers of a Payment Plan required to access a specific AI Agent or Service. Users who are not subscribers won't be able to create AI Tasks for that Agent.
Expand All @@ -78,6 +77,8 @@ def create_task(self, did: str, task: Task):
Args:
did (str): The DID of the service.
task (Task): The task to create.
_callback (Any): The callback to execute when a new task log event is received (optional)
Example:
task = {
Expand All @@ -91,7 +92,10 @@ def create_task(self, did: str, task: Task):
"""
endpoint = self.parse_url_to_proxy(TASK_ENDPOINT).replace('{did}', did)
token = self.get_service_token(did)
return self.post(endpoint, task, headers={'Authorization': f'Bearer {token.accessToken}'})
result = self.post(endpoint, task, headers={'Authorization': f'Bearer {token.accessToken}'})
if(result.status_code == 201 and _callback):
await self.subscribe_task_logs(_callback, [result["task"]["task_id"]])
return result

def create_steps(self, did: str, task_id: str, steps: List[Step]):
"""
Expand Down Expand Up @@ -212,4 +216,25 @@ def get_tasks_from_agents(self):
"""
return self.get(self.parse_url(GET_AGENTS_ENDPOINT))

async def subscribe_task_logs(self, callback: Any, tasks: List[str]):
try:
if not tasks:
raise Exception('No task rooms to join in configuration')
await self.connect_socket()

await self.socket_client.on('_connected', self._on_connected(callback, tasks))
except Exception as error:
raise Exception(f"Unable to initialize websocket client: {self.web_socket_host} - {str(error)}")


def _on_connected(self, callback: Any, tasks: List[str]):
async def handle_connected_event(*args):
print(f"connectTasksSocket:: Joining tasks: {tasks}")
await self.socket_client.emit('_join-tasks', {'tasks': tasks})

async def handle_task_log_event(data: Any):
callback(data)

await self.socket_client.on('task-log', handle_task_log_event)

return handle_connected_event
27 changes: 13 additions & 14 deletions payments_py/nvm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ def __init__(self, opts: BackendApiOptions):
raise ValueError(f"Invalid URL: {self.opts.backend_host} - {str(error)}")


def set_subscriber(self, callback, join_account_room, join_agent_rooms, subscribe_event_types, get_pending_events_on_subscribe):
async def connect_socket_subscriber(self, callback, join_account_room, join_agent_rooms, subscribe_event_types, get_pending_events_on_subscribe):
self.callback = callback
self.join_account_room = join_account_room
self.join_agent_rooms = join_agent_rooms
self.subscribe_event_types = subscribe_event_types
self.get_pending_events_on_subscribe = get_pending_events_on_subscribe
self.socket_client.on('_connected', self.connect_handler)
self.socket_client.on('_connected', self._subscribe)
await self.connect_socket()

async def connect_socket(self):
if not self.has_key:
Expand All @@ -96,19 +97,9 @@ async def connect_socket(self):

async def disconnect_socket(self):
if self.socket_client and self.socket_client.connected:
await self.socket_client.disconnect()
await self.socket_client.disconnect()

async def connect_handler(self, data):
await self._subscribe()
if self.get_pending_events_on_subscribe:
try:
print('Emiting pending events')
if(self.get_pending_events_on_subscribe and self.join_agent_rooms):
await self._emit_step_events(AgentExecutionStatus.Pending, self.join_agent_rooms)
except Exception as e:
print('query-api:: Unable to get pending events', e)

async def _subscribe(self):
async def _subscribe(self, data):
if not self.join_account_room and not self.join_agent_rooms:
raise ValueError('No rooms to join in configuration')
if not self.socket_client.connected:
Expand All @@ -126,6 +117,14 @@ async def event_handler(data):
self.socket_client.on(event, event_handler)
else:
self.socket_client.on('step-updated', event_handler)
if self.get_pending_events_on_subscribe:
try:
print('Emiting pending events')
if(self.get_pending_events_on_subscribe and self.join_agent_rooms):
await self._emit_step_events(AgentExecutionStatus.Pending, self.join_agent_rooms)
except Exception as e:
print('query-api:: Unable to get pending events', e)


async def _emit_step_events(self, status: AgentExecutionStatus = AgentExecutionStatus.Pending, dids: List[str] = []):
message = { "status": status.value, "dids": dids }
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "payments-py"
version = "0.5.1"
version = "0.5.2"
description = ""
authors = ["enrique <enrique@nevermined.io>"]
readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions tests/protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def test_AIQueryApi_create_task_in_plan_purchased(ai_query_api_build_fixtu
assert builder.user_room_id, "User room ID is not set"


task = subscriber.ai_protocol.create_task(agent.did, {'query': 'sample_query', 'name': 'sample_task'})
task = await subscriber.ai_protocol.create_task(agent.did, {'query': 'sample_query', 'name': 'sample_task'})
print('Task created:', task.json())

await asyncio.wait_for(response_event.wait(), timeout=120)
Expand All @@ -144,7 +144,7 @@ async def test_AIQueryApi_create_task_in_plan_purchased(ai_query_api_build_fixtu
assert int(balance2.balance) == int(balance_before_task.balance) - 2

with pytest.raises(Exception) as excinfo:
task = subscriber.ai_protocol.create_task(did=agent.did, task={})
task = await subscriber.ai_protocol.create_task(did=agent.did, task={})
exception_args = excinfo.value.args[0]
assert exception_args['status'] == 400

Expand Down

0 comments on commit b5b2f85

Please sign in to comment.