diff --git a/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py b/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py index 321730f73..1961776b9 100644 --- a/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py +++ b/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py @@ -43,6 +43,11 @@ TransferInput, transfer, ) +from cdp_agentkit_core.actions.uniswap_v3.collect import ( + UNISWAP_V3_COLLECT_PROMPT, + UniswapV3CollectInput, + uniswap_v3_collect, +) from cdp_agentkit_core.actions.uniswap_v3.create_pool import ( UNISWAP_V3_CREATE_POOL_PROMPT, UniswapV3CreatePoolInput, @@ -85,6 +90,9 @@ "UNISWAP_V3_GET_POOL_SLOT0_PROMPT", "UniswapV3GetPoolSlot0Input", "uniswap_v3_get_pool_slot0", + "UNISWAP_V3_COLLECT_PROMPT", + "UniswapV3CollectInput", + "uniswap_v3_collect", "DEPLOY_NFT_PROMPT", "DeployNftInput", "deploy_nft", diff --git a/cdp-agentkit-core/cdp_agentkit_core/actions/uniswap_v3/collect.py b/cdp-agentkit-core/cdp_agentkit_core/actions/uniswap_v3/collect.py new file mode 100644 index 000000000..dff8b63b1 --- /dev/null +++ b/cdp-agentkit-core/cdp_agentkit_core/actions/uniswap_v3/collect.py @@ -0,0 +1,75 @@ +from cdp import Wallet +from pydantic import BaseModel, Field + +from cdp_agentkit_core.actions.uniswap_v3.constants import UNISWAP_V3_POOL_ABI + +UNISWAP_V3_COLLECT_PROMPT = """ +This tool will collect tokens from a Uniswap v3 pool. This tool takes the pool address, recipient address, tickLower, tickUpper, amount0 requested for token0, and amount1 requested for token1 as inputs. tickLower is the lower tick of the position for which to collect fees, while tickUppwer is the upper tick of the position for which to collect fees.""" + + +class UniswapV3CollectInput(BaseModel): + """Input argument schema for collect action.""" + + pool_address: str = Field( + ..., + description="The address of the pool to collect from.", + ) + recipient_address: str = Field( + ..., + description="The address of the recipient of the collected tokens.", + ) + tick_lower: str = Field( + ..., + description="The lower tick of the position for which to collect fees.", + ) + tick_upper: str = Field( + ..., + description="The upper tick of the position for which to collect fees.", + ) + amount0_requested: str = Field( + ..., + description="The amount of token0 requested for collection.", + ) + amount1_requested: str = Field( + ..., + description="The amount of token1 requested for collection.", + ) + + +def uniswap_v3_collect( + wallet: Wallet, + pool_address: str, + recipient_address: str, + tick_lower: str, + tick_upper: str, + amount0_requested: str, + amount1_requested: str, +) -> str: + """Collect tokens from a Uniswap v3 pool. + + Args: + wallet (Wallet): The wallet to collect tokens from. + pool_address (str): The address of the pool to collect from. + recipient_address (str): The address of the recipient of the collected tokens. + tick_lower (str): The lower tick of the position for which to collect fees. + tick_upper (str): The upper tick of the position for which to collect fees. + amount0_requested (str): The amount of token0 requested for collection. + amount1_requested (str): The amount of token1 requested for collection. + + Returns: + str: A message containing the details of the collected tokens. + + """ + pool = wallet.invoke_contract( + contract_address=pool_address, + method="collect", + abi=UNISWAP_V3_POOL_ABI, + args={ + "recipient": recipient_address, + "tickLower": tick_lower, + "tickUpper": tick_upper, + "amount0Requested": amount0_requested, + "amount1Requested": amount1_requested, + }, + ).wait() + return f"Requested collection of {amount0_requested} of token0 and {amount1_requested} of token1 from pool {pool_address}.\nTransaction hash for the collection: {pool.transaction.transaction_hash}\nTransaction link for the collection: {pool.transaction.transaction_link}" diff --git a/cdp-agentkit-core/tests/actions/uniswap_v3/test_collect.py b/cdp-agentkit-core/tests/actions/uniswap_v3/test_collect.py new file mode 100644 index 000000000..c677d215c --- /dev/null +++ b/cdp-agentkit-core/tests/actions/uniswap_v3/test_collect.py @@ -0,0 +1,114 @@ +from unittest.mock import patch + +import pytest + +from cdp_agentkit_core.actions.uniswap_v3.collect import ( + UniswapV3CollectInput, + uniswap_v3_collect, +) +from cdp_agentkit_core.actions.uniswap_v3.constants import UNISWAP_V3_POOL_ABI + +MOCK_POOL_ADDRESS = "0x4200000000000000000000000000000000000006" +MOCK_RECIPIENT_ADDRESS = "0x1234567890123456789012345678901234567890" +MOCK_TICK_LOWER = "100" +MOCK_TICK_UPPER = "200" +MOCK_AMOUNT0_REQUESTED = "1000" +MOCK_AMOUNT1_REQUESTED = "2000" + + +def test_collect_input_model_valid(): + """Test that CollectInput accepts valid parameters.""" + input_model = UniswapV3CollectInput( + pool_address=MOCK_POOL_ADDRESS, + recipient_address=MOCK_RECIPIENT_ADDRESS, + tick_lower=MOCK_TICK_LOWER, + tick_upper=MOCK_TICK_UPPER, + amount0_requested=MOCK_AMOUNT0_REQUESTED, + amount1_requested=MOCK_AMOUNT1_REQUESTED, + ) + + assert input_model.pool_address == MOCK_POOL_ADDRESS + assert input_model.recipient_address == MOCK_RECIPIENT_ADDRESS + assert input_model.tick_lower == MOCK_TICK_LOWER + assert input_model.tick_upper == MOCK_TICK_UPPER + assert input_model.amount0_requested == MOCK_AMOUNT0_REQUESTED + assert input_model.amount1_requested == MOCK_AMOUNT1_REQUESTED + + +def test_collect_input_model_missing_params(): + """Test that CollectInput raises error when params are missing.""" + with pytest.raises(ValueError): + UniswapV3CollectInput() + + +def test_collect_success(wallet_factory, contract_invocation_factory): + """Test successful token collection with valid parameters.""" + mock_wallet = wallet_factory() + mock_contract_instance = contract_invocation_factory() + + with ( + patch.object( + mock_wallet, "invoke_contract", return_value=mock_contract_instance + ) as mock_invoke, + patch.object( + mock_contract_instance, "wait", return_value=mock_contract_instance + ) as mock_contract_wait, + ): + action_response = uniswap_v3_collect( + mock_wallet, + MOCK_POOL_ADDRESS, + MOCK_RECIPIENT_ADDRESS, + MOCK_TICK_LOWER, + MOCK_TICK_UPPER, + MOCK_AMOUNT0_REQUESTED, + MOCK_AMOUNT1_REQUESTED, + ) + + expected_response = f"Requested collection of {MOCK_AMOUNT0_REQUESTED} of token0 and {MOCK_AMOUNT1_REQUESTED} of token1 from pool {MOCK_POOL_ADDRESS}.\nTransaction hash for the collection: {mock_contract_instance.transaction.transaction_hash}\nTransaction link for the collection: {mock_contract_instance.transaction.transaction_link}" + assert action_response == expected_response + + mock_invoke.assert_called_once_with( + contract_address=MOCK_POOL_ADDRESS, + method="collect", + abi=UNISWAP_V3_POOL_ABI, + args={ + "recipient": MOCK_RECIPIENT_ADDRESS, + "tickLower": MOCK_TICK_LOWER, + "tickUpper": MOCK_TICK_UPPER, + "amount0Requested": MOCK_AMOUNT0_REQUESTED, + "amount1Requested": MOCK_AMOUNT1_REQUESTED, + }, + ) + mock_contract_wait.assert_called_once_with() + + +def test_collect_api_error(wallet_factory): + """Test collect when API error occurs.""" + mock_wallet = wallet_factory() + + with patch.object( + mock_wallet, "invoke_contract", side_effect=Exception("API error") + ) as mock_invoke: + with pytest.raises(Exception, match="API error"): + uniswap_v3_collect( + mock_wallet, + MOCK_POOL_ADDRESS, + MOCK_RECIPIENT_ADDRESS, + MOCK_TICK_LOWER, + MOCK_TICK_UPPER, + MOCK_AMOUNT0_REQUESTED, + MOCK_AMOUNT1_REQUESTED, + ) + + mock_invoke.assert_called_once_with( + contract_address=MOCK_POOL_ADDRESS, + method="collect", + abi=UNISWAP_V3_POOL_ABI, + args={ + "recipient": MOCK_RECIPIENT_ADDRESS, + "tickLower": MOCK_TICK_LOWER, + "tickUpper": MOCK_TICK_UPPER, + "amount0Requested": MOCK_AMOUNT0_REQUESTED, + "amount1Requested": MOCK_AMOUNT1_REQUESTED, + }, + ) diff --git a/cdp-langchain/cdp_langchain/agent_toolkits/cdp_toolkit.py b/cdp-langchain/cdp_langchain/agent_toolkits/cdp_toolkit.py index 78b4dcadb..b290a27ce 100644 --- a/cdp-langchain/cdp_langchain/agent_toolkits/cdp_toolkit.py +++ b/cdp-langchain/cdp_langchain/agent_toolkits/cdp_toolkit.py @@ -13,6 +13,7 @@ REQUEST_FAUCET_FUNDS_PROMPT, TRADE_PROMPT, TRANSFER_PROMPT, + UNISWAP_V3_COLLECT_PROMPT, UNISWAP_V3_CREATE_POOL_PROMPT, UNISWAP_V3_GET_POOL_LIQUIDITY_PROMPT, UNISWAP_V3_GET_POOL_OBSERVE_PROMPT, @@ -27,6 +28,7 @@ RequestFaucetFundsInput, TradeInput, TransferInput, + UniswapV3CollectInput, UniswapV3CreatePoolInput, UniswapV3GetPoolInput, UniswapV3GetPoolLiquidityInput, @@ -94,6 +96,7 @@ class CdpToolkit(BaseToolkit): uniswap_v3_get_pool_observe uniswap_v3_get_pool_slot0 uniswap_v3_get_pool_liquidity + uniswap_v3_collect Use within an agent: .. code-block:: python @@ -185,6 +188,12 @@ def from_cdp_agentkit_wrapper(cls, cdp_agentkit_wrapper: CdpAgentkitWrapper) -> "description": UNISWAP_V3_GET_POOL_LIQUIDITY_PROMPT, "args_schema": UniswapV3GetPoolLiquidityInput, }, + { + "mode": "uniswap_v3_collect", + "name": "uniswap_v3_collect", + "description": UNISWAP_V3_COLLECT_PROMPT, + "args_schema": UniswapV3CollectInput, + }, { "mode": "get_wallet_details", "name": "get_wallet_details", diff --git a/cdp-langchain/cdp_langchain/utils/cdp_agentkit_wrapper.py b/cdp-langchain/cdp_langchain/utils/cdp_agentkit_wrapper.py index eed940d89..ac6e35b0b 100644 --- a/cdp-langchain/cdp_langchain/utils/cdp_agentkit_wrapper.py +++ b/cdp-langchain/cdp_langchain/utils/cdp_agentkit_wrapper.py @@ -16,6 +16,7 @@ request_faucet_funds, trade, transfer, + uniswap_v3_collect, uniswap_v3_create_pool, uniswap_v3_get_pool, uniswap_v3_get_pool_liquidity, @@ -92,6 +93,26 @@ def uniswap_v3_create_pool_wrapper(self, token_a: str, token_b: str, fee: str) - """ return uniswap_v3_create_pool(wallet=self.wallet, token_a=token_a, token_b=token_b, fee=fee) + def uniswap_v3_collect_wrapper( + self, + pool_address: str, + recipient_address: str, + tick_lower: str, + tick_upper: str, + amount0_requested: str, + amount1_requested: str, + ) -> str: + """Collect tokens from a Uniswap v3 pool by wrapping call to CDP Agentkit Core.""" + return uniswap_v3_collect( + wallet=self.wallet, + pool_address=pool_address, + recipient_address=recipient_address, + tick_lower=tick_lower, + tick_upper=tick_upper, + amount0_requested=amount0_requested, + amount1_requested=amount1_requested, + ) + def uniswap_v3_get_pool_wrapper( self, network_id: str, token_a: str, token_b: str, fee: str ) -> str: @@ -313,6 +334,8 @@ def run(self, mode: str, **kwargs) -> str: return self.uniswap_v3_get_pool_observe_wrapper(**kwargs) elif mode == "uniswap_v3_get_pool_slot0": return self.uniswap_v3_get_pool_slot0_wrapper(**kwargs) + elif mode == "uniswap_v3_collect": + return self.uniswap_v3_collect_wrapper(**kwargs) elif mode == "deploy_token": return self.deploy_token_wrapper(**kwargs) elif mode == "mint_nft":