diff --git a/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py b/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py index 347687039..a77174ce1 100644 --- a/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py +++ b/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py @@ -8,6 +8,7 @@ from cdp_agentkit_core.actions.request_faucet_funds import RequestFaucetFundsAction from cdp_agentkit_core.actions.trade import TradeAction from cdp_agentkit_core.actions.transfer import TransferAction +from cdp_agentkit_core.actions.webhook import CreateWebhookAction from cdp_agentkit_core.actions.wow.buy_token import WowBuyTokenAction from cdp_agentkit_core.actions.wow.create_token import WowCreateTokenAction from cdp_agentkit_core.actions.wow.sell_token import WowSellTokenAction @@ -28,6 +29,7 @@ def get_all_cdp_actions() -> list[type[CdpAction]]: __all__ = [ "CDP_ACTIONS", "CdpAction", + "CreateWebhookAction", "DeployNftAction", "DeployTokenAction", "GetBalanceAction", diff --git a/cdp-agentkit-core/cdp_agentkit_core/actions/webhook.py b/cdp-agentkit-core/cdp_agentkit_core/actions/webhook.py new file mode 100644 index 000000000..65c82846e --- /dev/null +++ b/cdp-agentkit-core/cdp_agentkit_core/actions/webhook.py @@ -0,0 +1,173 @@ +from enum import Enum +from typing import Any + +from cdp import Webhook +from cdp.client.models.webhook import WebhookEventTypeFilter +from cdp.client.models.webhook_smart_contract_event_filter import WebhookSmartContractEventFilter +from cdp.client.models.webhook_wallet_activity_filter import WebhookWalletActivityFilter +from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator + +from cdp_agentkit_core.actions import CdpAction + +CREATE_WEBHOOK_PROMPT = """ +Create a new webhook to receive real-time updates for on-chain events. +Supports monitoring wallet activity or smart contract events by specifying: +- Callback URL for receiving events +- Event type (wallet_activity, smart_contract_event_activity, erc20_transfer or erc721_transfer) +- wallet or contract addresses to listen +Also supports monitoring erc20_transfer or erc721_transfer, when those are defined at least one of these filters needs to be provided (only one of them is required): +- Contract address to listen for token transfers +- Sender address for erc20_transfer and erc721_transfer (listen on transfers originating from this address) +- Recipient address for erc20_transfer and erc721_transfer (listen on transfers being made to this address) +Ensure event_type_filter is only sent when eventy_type is wallet_activity or smart_contract_event_activity and event_filters is only sent when event_type is erc20_transfer or erc721_transfer +""" + +class WebhookEventType(str, Enum): + """Valid webhook event types.""" + + WALLET_ACTIVITY = "wallet_activity" + SMART_CONTRACT_EVENT_ACTIVITY = "smart_contract_event_activity" + ERC20_TRANSFER = "erc20_transfer" + ERC721_TRANSFER = "erc721_transfer" + +class WebhookNetworks(str, Enum): + """Networks available for creating webhooks.""" + + BASE_MAINNET = "base-mainnet" + BASE_SEPOLIA = "base-sepolia" + +class EventFilter(BaseModel): + """Schema for event filters.""" + + from_address: str | None = Field(None, description="Sender address for token transfers") + to_address: str | None = Field(None, description="Recipient address for token transfers") + contract_address: str | None = Field(None, description="Contract address for token transfers") + + @model_validator(mode='after') + def validate_at_least_one_filter(self) -> 'EventFilter': + """Ensure at least one filter is provided.""" + if not any([self.from_address, self.to_address, self.contract_address]): + raise ValueError("At least one filter must be provided") + return self + +class EventTypeFilter(BaseModel): + """Schema for event type filter.""" + + addresses: list[str] | None = Field(None, description="List of wallet or contract addresses to monitor") + + @field_validator('addresses') + @classmethod + def validate_addresses_not_empty(cls, v: list[str] | None) -> list[str] | None: + """Ensure addresses list is not empty when provided.""" + if v is not None and len(v) == 0: + raise ValueError("addresses must contain at least one value when provided") + return v + +class WebhookInput(BaseModel): + """Input schema for create webhook action.""" + + notification_uri: HttpUrl = Field(..., description="The callback URL where webhook events will be sent") + event_type: WebhookEventType + event_type_filter: EventTypeFilter | None = None + event_filters: list[EventFilter] | None = None + network_id: WebhookNetworks + + @model_validator(mode='after') + def validate_filters(self) -> 'WebhookInput': + """Validate that the correct filter is provided based on event type.""" + if self.event_type in [WebhookEventType.WALLET_ACTIVITY, WebhookEventType.SMART_CONTRACT_EVENT_ACTIVITY]: + if self.event_filters is not None: + raise ValueError( + f"event_filters should not be provided when event_type is {self.event_type}. " + "Use event_type_filter instead." + ) + if self.event_type_filter is None: + raise ValueError( + f"event_type_filter must be provided when event_type is {self.event_type}" + ) + + if self.event_type in [WebhookEventType.ERC20_TRANSFER, WebhookEventType.ERC721_TRANSFER]: + if self.event_type_filter is not None: + raise ValueError( + f"event_type_filter should not be provided when event_type is {self.event_type}. " + "Use event_filters instead." + ) + if not self.event_filters: + raise ValueError( + f"event_filters must be provided when event_type is {self.event_type}" + ) + + return self + +def create_webhook( + notification_uri: str | HttpUrl, + event_type: str, + network_id: str, + event_type_filter: dict[str, Any] | None = None, + event_filters: list[dict[str, Any]] | None = None, +) -> str: + """Create a new webhook for monitoring on-chain events. + + Args: + notification_uri: The callback URL where webhook events will be sent + event_type: Type of events to monitor + network_id: Network to monitor + event_type_filter: Filter for event types, this will only be used when eventy_type is wallet_activity or smart_contract_event_activity + event_filters: Filters for events, this filter will only be used when event_type is erc20_transfer or erc721_transfer + + Returns: + str: Details of the created webhook + + """ + print(f"notification_uri: {notification_uri}") + print(f"event_type_filter: {event_type_filter}") + print(f"event_filters: {event_filters}") + try: + webhook_options = { + "notification_uri": str(notification_uri), + "event_type": event_type, + "network_id": network_id, + } + + # Handle different event types with appropriate filtering + if event_type == WebhookEventType.WALLET_ACTIVITY: + wallet_activity_filter = WebhookWalletActivityFilter( + addresses=event_type_filter.get("addresses", []) if event_type_filter else [], + wallet_id="" + ) + webhook_options["event_type_filter"] = WebhookEventTypeFilter(actual_instance=wallet_activity_filter) + + elif event_type == WebhookEventType.SMART_CONTRACT_EVENT_ACTIVITY: + contract_activity_filter = WebhookSmartContractEventFilter( + contract_addresses=event_type_filter.get("addresses", []) if event_type_filter else [], + ) + webhook_options["event_type_filter"] = WebhookEventTypeFilter(actual_instance=contract_activity_filter) + + elif event_type in [WebhookEventType.ERC20_TRANSFER, WebhookEventType.ERC721_TRANSFER]: + if event_filters and event_filters[0]: + filter_dict = {} + if event_filters[0].get("contract_address"): + filter_dict["contract_address"] = event_filters[0]["contract_address"] + if event_filters[0].get("from_address"): + filter_dict["from_address"] = event_filters[0]["from_address"] + if event_filters[0].get("to_address"): + filter_dict["to_address"] = event_filters[0]["to_address"] + webhook_options["event_filters"] = [filter_dict] + else: + raise ValueError(f"Unsupported event type: {event_type}") + + # Create webhook using Webhook.create() + print(f"webhook_options: {webhook_options}") + webhook = Webhook.create(**webhook_options) + return f"The webhook was successfully created: {webhook}\n\n" + + except Exception as error: + return f"Error: {error!s}" + +class CreateWebhookAction(CdpAction): + """Create webhook action.""" + + name: str = "create_webhook" + description: str = CREATE_WEBHOOK_PROMPT + args_schema: type[BaseModel] = WebhookInput + func = create_webhook diff --git a/cdp-agentkit-core/tests/actions/test_webhook.py b/cdp-agentkit-core/tests/actions/test_webhook.py new file mode 100644 index 000000000..4337e2290 --- /dev/null +++ b/cdp-agentkit-core/tests/actions/test_webhook.py @@ -0,0 +1,148 @@ +from unittest.mock import Mock, patch + +import pytest + +from cdp_agentkit_core.actions.webhook import ( + WebhookInput, + create_webhook, +) + +# Test constants +MOCK_NETWORK = "base-sepolia" +MOCK_URL = "https://example.com/" +MOCK_ADDRESS = "0x321" +MOCK_EVENT_TYPE = "wallet_activity" +SUCCESS_MESSAGE = "The webhook was successfully created:" + +@pytest.fixture +def mock_webhook(): + """Provide a mocked Webhook instance for testing.""" + with patch('cdp_agentkit_core.actions.webhook.Webhook') as mock: + mock_instance = Mock() + mock.create.return_value = mock_instance + yield mock + +def test_webhook_input_valid_parsing(): + """Test successful parsing of valid webhook inputs.""" + # Test wallet activity webhook input + valid_input = { + "notification_uri": MOCK_URL, + "event_type": MOCK_EVENT_TYPE, + "event_type_filter": { + "addresses": [MOCK_ADDRESS] + }, + "network_id": MOCK_NETWORK + } + + result = WebhookInput.model_validate(valid_input) + assert str(result.notification_uri) == MOCK_URL + assert result.event_type == MOCK_EVENT_TYPE + assert result.event_type_filter.addresses == [MOCK_ADDRESS] + assert result.network_id == MOCK_NETWORK + + # Test ERC721 transfer webhook input + another_valid_input = { + "notification_uri": MOCK_URL, + "event_type": "erc721_transfer", + "event_filters": [{ + "from_address": MOCK_ADDRESS + }], + "network_id": MOCK_NETWORK + } + + result = WebhookInput.model_validate(another_valid_input) + assert str(result.notification_uri) == MOCK_URL + assert result.event_type == "erc721_transfer" + assert result.event_filters[0].from_address == MOCK_ADDRESS + +def test_webhook_input_invalid_parsing(): + """Test parsing failure for invalid webhook input.""" + empty_input = {} + with pytest.raises(ValueError): + WebhookInput.model_validate(empty_input) + +def test_create_wallet_activity_webhook(mock_webhook): + """Test creating wallet activity webhook.""" + args = { + "notification_uri": MOCK_URL, + "event_type": MOCK_EVENT_TYPE, + "event_type_filter": { + "addresses": [MOCK_ADDRESS] + }, + "network_id": MOCK_NETWORK + } + + response = create_webhook(**args) + + assert mock_webhook.create.call_count == 1 + assert SUCCESS_MESSAGE in response + +def test_create_smart_contract_activity_webhook(mock_webhook): + """Test creating smart contract activity webhook.""" + args = { + "notification_uri": MOCK_URL, + "event_type": "smart_contract_event_activity", + "event_type_filter": { + "addresses": [MOCK_ADDRESS] + }, + "network_id": MOCK_NETWORK + } + + response = create_webhook(**args) + + assert mock_webhook.create.call_count == 1 + assert SUCCESS_MESSAGE in response + +def test_create_erc20_transfer_webhook(mock_webhook): + """Test creating ERC20 transfer webhook.""" + args = { + "notification_uri": MOCK_URL, + "event_type": "erc20_transfer", + "event_type_filter": { + "addresses": [MOCK_ADDRESS] + }, + "event_filters": [{ + "from_address": MOCK_ADDRESS + }], + "network_id": MOCK_NETWORK + } + + response = create_webhook(**args) + + assert mock_webhook.create.call_count == 1 + assert SUCCESS_MESSAGE in response + +def test_create_erc721_transfer_webhook(mock_webhook): + """Test creating ERC721 transfer webhook.""" + args = { + "notification_uri": MOCK_URL, + "event_type": "erc721_transfer", + "event_filters": [{ + "from_address": MOCK_ADDRESS + }], + "network_id": MOCK_NETWORK + } + + response = create_webhook(**args) + + assert mock_webhook.create.call_count == 1 + assert SUCCESS_MESSAGE in response + +def test_create_webhook_error_handling(mock_webhook): + """Test error handling when creating webhook fails.""" + error_msg = "Failed to create webhook" + mock_webhook.create.side_effect = Exception(error_msg) + + args = { + "notification_uri": MOCK_URL, + "event_type": MOCK_EVENT_TYPE, + "event_type_filter": { + "addresses": ["test"] + }, + "network_id": MOCK_NETWORK + } + + response = create_webhook(**args) + + assert mock_webhook.create.call_count == 1 + assert f"Error: {error_msg}" in response