From 670b804f036714e559d2a52dc8806d3ed2c2991a Mon Sep 17 00:00:00 2001 From: Rohan Agarwal Date: Wed, 6 Nov 2024 22:32:31 -0500 Subject: [PATCH] Collect support --- .../cdp_agentkit_core/actions/__init__.py | 10 +- .../actions/uniswap_v3/collect.py | 12 +- .../tests/actions/uniswap_v3/test_collect.py | 114 ++++++++++++++++++ .../agent_toolkits/cdp_toolkit.py | 9 +- .../utils/cdp_agentkit_wrapper.py | 23 ++++ 5 files changed, 159 insertions(+), 9 deletions(-) create mode 100644 cdp-agentkit-core/tests/actions/uniswap_v3/test_collect.py diff --git a/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py b/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py index b0cf79af1..1961776b9 100644 --- a/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py +++ b/cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py @@ -43,6 +43,11 @@ TransferInput, transfer, ) +from cdp_agentkit_core.actions.uniswap_v3.collect import ( + UNISWAP_V3_COLLECT_PROMPT, + UniswapV3CollectInput, + uniswap_v3_collect, +) from cdp_agentkit_core.actions.uniswap_v3.create_pool import ( UNISWAP_V3_CREATE_POOL_PROMPT, UniswapV3CreatePoolInput, @@ -68,11 +73,6 @@ UniswapV3GetPoolSlot0Input, uniswap_v3_get_pool_slot0, ) -from cdp_agentkit_core.actions.uniswap_v3.collect import ( - UNISWAP_V3_COLLECT_PROMPT, - UniswapV3CollectInput, - uniswap_v3_collect, -) __all__ = [ "UNISWAP_V3_CREATE_POOL_PROMPT", diff --git a/cdp-agentkit-core/cdp_agentkit_core/actions/uniswap_v3/collect.py b/cdp-agentkit-core/cdp_agentkit_core/actions/uniswap_v3/collect.py index ea837141b..dff8b63b1 100644 --- a/cdp-agentkit-core/cdp_agentkit_core/actions/uniswap_v3/collect.py +++ b/cdp-agentkit-core/cdp_agentkit_core/actions/uniswap_v3/collect.py @@ -3,7 +3,6 @@ from cdp_agentkit_core.actions.uniswap_v3.constants import UNISWAP_V3_POOL_ABI - UNISWAP_V3_COLLECT_PROMPT = """ This tool will collect tokens from a Uniswap v3 pool. This tool takes the pool address, recipient address, tickLower, tickUpper, amount0 requested for token0, and amount1 requested for token1 as inputs. tickLower is the lower tick of the position for which to collect fees, while tickUppwer is the upper tick of the position for which to collect fees.""" @@ -37,7 +36,15 @@ class UniswapV3CollectInput(BaseModel): ) -def uniswap_v3_collect(wallet: Wallet, pool_address: str, recipient_address: str, tick_lower: str, tick_upper: str, amount0_requested: str, amount1_requested: str) -> str: +def uniswap_v3_collect( + wallet: Wallet, + pool_address: str, + recipient_address: str, + tick_lower: str, + tick_upper: str, + amount0_requested: str, + amount1_requested: str, +) -> str: """Collect tokens from a Uniswap v3 pool. Args: @@ -53,7 +60,6 @@ def uniswap_v3_collect(wallet: Wallet, pool_address: str, recipient_address: str str: A message containing the details of the collected tokens. """ - pool = wallet.invoke_contract( contract_address=pool_address, method="collect", diff --git a/cdp-agentkit-core/tests/actions/uniswap_v3/test_collect.py b/cdp-agentkit-core/tests/actions/uniswap_v3/test_collect.py new file mode 100644 index 000000000..c677d215c --- /dev/null +++ b/cdp-agentkit-core/tests/actions/uniswap_v3/test_collect.py @@ -0,0 +1,114 @@ +from unittest.mock import patch + +import pytest + +from cdp_agentkit_core.actions.uniswap_v3.collect import ( + UniswapV3CollectInput, + uniswap_v3_collect, +) +from cdp_agentkit_core.actions.uniswap_v3.constants import UNISWAP_V3_POOL_ABI + +MOCK_POOL_ADDRESS = "0x4200000000000000000000000000000000000006" +MOCK_RECIPIENT_ADDRESS = "0x1234567890123456789012345678901234567890" +MOCK_TICK_LOWER = "100" +MOCK_TICK_UPPER = "200" +MOCK_AMOUNT0_REQUESTED = "1000" +MOCK_AMOUNT1_REQUESTED = "2000" + + +def test_collect_input_model_valid(): + """Test that CollectInput accepts valid parameters.""" + input_model = UniswapV3CollectInput( + pool_address=MOCK_POOL_ADDRESS, + recipient_address=MOCK_RECIPIENT_ADDRESS, + tick_lower=MOCK_TICK_LOWER, + tick_upper=MOCK_TICK_UPPER, + amount0_requested=MOCK_AMOUNT0_REQUESTED, + amount1_requested=MOCK_AMOUNT1_REQUESTED, + ) + + assert input_model.pool_address == MOCK_POOL_ADDRESS + assert input_model.recipient_address == MOCK_RECIPIENT_ADDRESS + assert input_model.tick_lower == MOCK_TICK_LOWER + assert input_model.tick_upper == MOCK_TICK_UPPER + assert input_model.amount0_requested == MOCK_AMOUNT0_REQUESTED + assert input_model.amount1_requested == MOCK_AMOUNT1_REQUESTED + + +def test_collect_input_model_missing_params(): + """Test that CollectInput raises error when params are missing.""" + with pytest.raises(ValueError): + UniswapV3CollectInput() + + +def test_collect_success(wallet_factory, contract_invocation_factory): + """Test successful token collection with valid parameters.""" + mock_wallet = wallet_factory() + mock_contract_instance = contract_invocation_factory() + + with ( + patch.object( + mock_wallet, "invoke_contract", return_value=mock_contract_instance + ) as mock_invoke, + patch.object( + mock_contract_instance, "wait", return_value=mock_contract_instance + ) as mock_contract_wait, + ): + action_response = uniswap_v3_collect( + mock_wallet, + MOCK_POOL_ADDRESS, + MOCK_RECIPIENT_ADDRESS, + MOCK_TICK_LOWER, + MOCK_TICK_UPPER, + MOCK_AMOUNT0_REQUESTED, + MOCK_AMOUNT1_REQUESTED, + ) + + expected_response = f"Requested collection of {MOCK_AMOUNT0_REQUESTED} of token0 and {MOCK_AMOUNT1_REQUESTED} of token1 from pool {MOCK_POOL_ADDRESS}.\nTransaction hash for the collection: {mock_contract_instance.transaction.transaction_hash}\nTransaction link for the collection: {mock_contract_instance.transaction.transaction_link}" + assert action_response == expected_response + + mock_invoke.assert_called_once_with( + contract_address=MOCK_POOL_ADDRESS, + method="collect", + abi=UNISWAP_V3_POOL_ABI, + args={ + "recipient": MOCK_RECIPIENT_ADDRESS, + "tickLower": MOCK_TICK_LOWER, + "tickUpper": MOCK_TICK_UPPER, + "amount0Requested": MOCK_AMOUNT0_REQUESTED, + "amount1Requested": MOCK_AMOUNT1_REQUESTED, + }, + ) + mock_contract_wait.assert_called_once_with() + + +def test_collect_api_error(wallet_factory): + """Test collect when API error occurs.""" + mock_wallet = wallet_factory() + + with patch.object( + mock_wallet, "invoke_contract", side_effect=Exception("API error") + ) as mock_invoke: + with pytest.raises(Exception, match="API error"): + uniswap_v3_collect( + mock_wallet, + MOCK_POOL_ADDRESS, + MOCK_RECIPIENT_ADDRESS, + MOCK_TICK_LOWER, + MOCK_TICK_UPPER, + MOCK_AMOUNT0_REQUESTED, + MOCK_AMOUNT1_REQUESTED, + ) + + mock_invoke.assert_called_once_with( + contract_address=MOCK_POOL_ADDRESS, + method="collect", + abi=UNISWAP_V3_POOL_ABI, + args={ + "recipient": MOCK_RECIPIENT_ADDRESS, + "tickLower": MOCK_TICK_LOWER, + "tickUpper": MOCK_TICK_UPPER, + "amount0Requested": MOCK_AMOUNT0_REQUESTED, + "amount1Requested": MOCK_AMOUNT1_REQUESTED, + }, + ) diff --git a/cdp-langchain/cdp_langchain/agent_toolkits/cdp_toolkit.py b/cdp-langchain/cdp_langchain/agent_toolkits/cdp_toolkit.py index 03127f658..b290a27ce 100644 --- a/cdp-langchain/cdp_langchain/agent_toolkits/cdp_toolkit.py +++ b/cdp-langchain/cdp_langchain/agent_toolkits/cdp_toolkit.py @@ -28,8 +28,8 @@ RequestFaucetFundsInput, TradeInput, TransferInput, - UniswapV3CreatePoolInput, UniswapV3CollectInput, + UniswapV3CreatePoolInput, UniswapV3GetPoolInput, UniswapV3GetPoolLiquidityInput, UniswapV3GetPoolObserveInput, @@ -96,6 +96,7 @@ class CdpToolkit(BaseToolkit): uniswap_v3_get_pool_observe uniswap_v3_get_pool_slot0 uniswap_v3_get_pool_liquidity + uniswap_v3_collect Use within an agent: .. code-block:: python @@ -187,6 +188,12 @@ def from_cdp_agentkit_wrapper(cls, cdp_agentkit_wrapper: CdpAgentkitWrapper) -> "description": UNISWAP_V3_GET_POOL_LIQUIDITY_PROMPT, "args_schema": UniswapV3GetPoolLiquidityInput, }, + { + "mode": "uniswap_v3_collect", + "name": "uniswap_v3_collect", + "description": UNISWAP_V3_COLLECT_PROMPT, + "args_schema": UniswapV3CollectInput, + }, { "mode": "get_wallet_details", "name": "get_wallet_details", diff --git a/cdp-langchain/cdp_langchain/utils/cdp_agentkit_wrapper.py b/cdp-langchain/cdp_langchain/utils/cdp_agentkit_wrapper.py index eed940d89..ac6e35b0b 100644 --- a/cdp-langchain/cdp_langchain/utils/cdp_agentkit_wrapper.py +++ b/cdp-langchain/cdp_langchain/utils/cdp_agentkit_wrapper.py @@ -16,6 +16,7 @@ request_faucet_funds, trade, transfer, + uniswap_v3_collect, uniswap_v3_create_pool, uniswap_v3_get_pool, uniswap_v3_get_pool_liquidity, @@ -92,6 +93,26 @@ def uniswap_v3_create_pool_wrapper(self, token_a: str, token_b: str, fee: str) - """ return uniswap_v3_create_pool(wallet=self.wallet, token_a=token_a, token_b=token_b, fee=fee) + def uniswap_v3_collect_wrapper( + self, + pool_address: str, + recipient_address: str, + tick_lower: str, + tick_upper: str, + amount0_requested: str, + amount1_requested: str, + ) -> str: + """Collect tokens from a Uniswap v3 pool by wrapping call to CDP Agentkit Core.""" + return uniswap_v3_collect( + wallet=self.wallet, + pool_address=pool_address, + recipient_address=recipient_address, + tick_lower=tick_lower, + tick_upper=tick_upper, + amount0_requested=amount0_requested, + amount1_requested=amount1_requested, + ) + def uniswap_v3_get_pool_wrapper( self, network_id: str, token_a: str, token_b: str, fee: str ) -> str: @@ -313,6 +334,8 @@ def run(self, mode: str, **kwargs) -> str: return self.uniswap_v3_get_pool_observe_wrapper(**kwargs) elif mode == "uniswap_v3_get_pool_slot0": return self.uniswap_v3_get_pool_slot0_wrapper(**kwargs) + elif mode == "uniswap_v3_collect": + return self.uniswap_v3_collect_wrapper(**kwargs) elif mode == "deploy_token": return self.deploy_token_wrapper(**kwargs) elif mode == "mint_nft":