Skip to content

Commit

Permalink
Collect support
Browse files Browse the repository at this point in the history
  • Loading branch information
rohan-agarwal-coinbase committed Nov 7, 2024
1 parent 7acf69c commit 670b804
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 9 deletions.
10 changes: 5 additions & 5 deletions cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
TransferInput,
transfer,
)
from cdp_agentkit_core.actions.uniswap_v3.collect import (
UNISWAP_V3_COLLECT_PROMPT,
UniswapV3CollectInput,
uniswap_v3_collect,
)
from cdp_agentkit_core.actions.uniswap_v3.create_pool import (
UNISWAP_V3_CREATE_POOL_PROMPT,
UniswapV3CreatePoolInput,
Expand All @@ -68,11 +73,6 @@
UniswapV3GetPoolSlot0Input,
uniswap_v3_get_pool_slot0,
)
from cdp_agentkit_core.actions.uniswap_v3.collect import (
UNISWAP_V3_COLLECT_PROMPT,
UniswapV3CollectInput,
uniswap_v3_collect,
)

__all__ = [
"UNISWAP_V3_CREATE_POOL_PROMPT",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from cdp_agentkit_core.actions.uniswap_v3.constants import UNISWAP_V3_POOL_ABI


UNISWAP_V3_COLLECT_PROMPT = """
This tool will collect tokens from a Uniswap v3 pool. This tool takes the pool address, recipient address, tickLower, tickUpper, amount0 requested for token0, and amount1 requested for token1 as inputs. tickLower is the lower tick of the position for which to collect fees, while tickUppwer is the upper tick of the position for which to collect fees."""

Expand Down Expand Up @@ -37,7 +36,15 @@ class UniswapV3CollectInput(BaseModel):
)


def uniswap_v3_collect(wallet: Wallet, pool_address: str, recipient_address: str, tick_lower: str, tick_upper: str, amount0_requested: str, amount1_requested: str) -> str:
def uniswap_v3_collect(
wallet: Wallet,
pool_address: str,
recipient_address: str,
tick_lower: str,
tick_upper: str,
amount0_requested: str,
amount1_requested: str,
) -> str:
"""Collect tokens from a Uniswap v3 pool.
Args:
Expand All @@ -53,7 +60,6 @@ def uniswap_v3_collect(wallet: Wallet, pool_address: str, recipient_address: str
str: A message containing the details of the collected tokens.
"""

pool = wallet.invoke_contract(
contract_address=pool_address,
method="collect",
Expand Down
114 changes: 114 additions & 0 deletions cdp-agentkit-core/tests/actions/uniswap_v3/test_collect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from unittest.mock import patch

import pytest

from cdp_agentkit_core.actions.uniswap_v3.collect import (
UniswapV3CollectInput,
uniswap_v3_collect,
)
from cdp_agentkit_core.actions.uniswap_v3.constants import UNISWAP_V3_POOL_ABI

MOCK_POOL_ADDRESS = "0x4200000000000000000000000000000000000006"
MOCK_RECIPIENT_ADDRESS = "0x1234567890123456789012345678901234567890"
MOCK_TICK_LOWER = "100"
MOCK_TICK_UPPER = "200"
MOCK_AMOUNT0_REQUESTED = "1000"
MOCK_AMOUNT1_REQUESTED = "2000"


def test_collect_input_model_valid():
"""Test that CollectInput accepts valid parameters."""
input_model = UniswapV3CollectInput(
pool_address=MOCK_POOL_ADDRESS,
recipient_address=MOCK_RECIPIENT_ADDRESS,
tick_lower=MOCK_TICK_LOWER,
tick_upper=MOCK_TICK_UPPER,
amount0_requested=MOCK_AMOUNT0_REQUESTED,
amount1_requested=MOCK_AMOUNT1_REQUESTED,
)

assert input_model.pool_address == MOCK_POOL_ADDRESS
assert input_model.recipient_address == MOCK_RECIPIENT_ADDRESS
assert input_model.tick_lower == MOCK_TICK_LOWER
assert input_model.tick_upper == MOCK_TICK_UPPER
assert input_model.amount0_requested == MOCK_AMOUNT0_REQUESTED
assert input_model.amount1_requested == MOCK_AMOUNT1_REQUESTED


def test_collect_input_model_missing_params():
"""Test that CollectInput raises error when params are missing."""
with pytest.raises(ValueError):
UniswapV3CollectInput()


def test_collect_success(wallet_factory, contract_invocation_factory):
"""Test successful token collection with valid parameters."""
mock_wallet = wallet_factory()
mock_contract_instance = contract_invocation_factory()

with (
patch.object(
mock_wallet, "invoke_contract", return_value=mock_contract_instance
) as mock_invoke,
patch.object(
mock_contract_instance, "wait", return_value=mock_contract_instance
) as mock_contract_wait,
):
action_response = uniswap_v3_collect(
mock_wallet,
MOCK_POOL_ADDRESS,
MOCK_RECIPIENT_ADDRESS,
MOCK_TICK_LOWER,
MOCK_TICK_UPPER,
MOCK_AMOUNT0_REQUESTED,
MOCK_AMOUNT1_REQUESTED,
)

expected_response = f"Requested collection of {MOCK_AMOUNT0_REQUESTED} of token0 and {MOCK_AMOUNT1_REQUESTED} of token1 from pool {MOCK_POOL_ADDRESS}.\nTransaction hash for the collection: {mock_contract_instance.transaction.transaction_hash}\nTransaction link for the collection: {mock_contract_instance.transaction.transaction_link}"
assert action_response == expected_response

mock_invoke.assert_called_once_with(
contract_address=MOCK_POOL_ADDRESS,
method="collect",
abi=UNISWAP_V3_POOL_ABI,
args={
"recipient": MOCK_RECIPIENT_ADDRESS,
"tickLower": MOCK_TICK_LOWER,
"tickUpper": MOCK_TICK_UPPER,
"amount0Requested": MOCK_AMOUNT0_REQUESTED,
"amount1Requested": MOCK_AMOUNT1_REQUESTED,
},
)
mock_contract_wait.assert_called_once_with()


def test_collect_api_error(wallet_factory):
"""Test collect when API error occurs."""
mock_wallet = wallet_factory()

with patch.object(
mock_wallet, "invoke_contract", side_effect=Exception("API error")
) as mock_invoke:
with pytest.raises(Exception, match="API error"):
uniswap_v3_collect(
mock_wallet,
MOCK_POOL_ADDRESS,
MOCK_RECIPIENT_ADDRESS,
MOCK_TICK_LOWER,
MOCK_TICK_UPPER,
MOCK_AMOUNT0_REQUESTED,
MOCK_AMOUNT1_REQUESTED,
)

mock_invoke.assert_called_once_with(
contract_address=MOCK_POOL_ADDRESS,
method="collect",
abi=UNISWAP_V3_POOL_ABI,
args={
"recipient": MOCK_RECIPIENT_ADDRESS,
"tickLower": MOCK_TICK_LOWER,
"tickUpper": MOCK_TICK_UPPER,
"amount0Requested": MOCK_AMOUNT0_REQUESTED,
"amount1Requested": MOCK_AMOUNT1_REQUESTED,
},
)
9 changes: 8 additions & 1 deletion cdp-langchain/cdp_langchain/agent_toolkits/cdp_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
RequestFaucetFundsInput,
TradeInput,
TransferInput,
UniswapV3CreatePoolInput,
UniswapV3CollectInput,
UniswapV3CreatePoolInput,
UniswapV3GetPoolInput,
UniswapV3GetPoolLiquidityInput,
UniswapV3GetPoolObserveInput,
Expand Down Expand Up @@ -96,6 +96,7 @@ class CdpToolkit(BaseToolkit):
uniswap_v3_get_pool_observe
uniswap_v3_get_pool_slot0
uniswap_v3_get_pool_liquidity
uniswap_v3_collect
Use within an agent:
.. code-block:: python
Expand Down Expand Up @@ -187,6 +188,12 @@ def from_cdp_agentkit_wrapper(cls, cdp_agentkit_wrapper: CdpAgentkitWrapper) ->
"description": UNISWAP_V3_GET_POOL_LIQUIDITY_PROMPT,
"args_schema": UniswapV3GetPoolLiquidityInput,
},
{
"mode": "uniswap_v3_collect",
"name": "uniswap_v3_collect",
"description": UNISWAP_V3_COLLECT_PROMPT,
"args_schema": UniswapV3CollectInput,
},
{
"mode": "get_wallet_details",
"name": "get_wallet_details",
Expand Down
23 changes: 23 additions & 0 deletions cdp-langchain/cdp_langchain/utils/cdp_agentkit_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
request_faucet_funds,
trade,
transfer,
uniswap_v3_collect,
uniswap_v3_create_pool,
uniswap_v3_get_pool,
uniswap_v3_get_pool_liquidity,
Expand Down Expand Up @@ -92,6 +93,26 @@ def uniswap_v3_create_pool_wrapper(self, token_a: str, token_b: str, fee: str) -
"""
return uniswap_v3_create_pool(wallet=self.wallet, token_a=token_a, token_b=token_b, fee=fee)

def uniswap_v3_collect_wrapper(
self,
pool_address: str,
recipient_address: str,
tick_lower: str,
tick_upper: str,
amount0_requested: str,
amount1_requested: str,
) -> str:
"""Collect tokens from a Uniswap v3 pool by wrapping call to CDP Agentkit Core."""
return uniswap_v3_collect(
wallet=self.wallet,
pool_address=pool_address,
recipient_address=recipient_address,
tick_lower=tick_lower,
tick_upper=tick_upper,
amount0_requested=amount0_requested,
amount1_requested=amount1_requested,
)

def uniswap_v3_get_pool_wrapper(
self, network_id: str, token_a: str, token_b: str, fee: str
) -> str:
Expand Down Expand Up @@ -313,6 +334,8 @@ def run(self, mode: str, **kwargs) -> str:
return self.uniswap_v3_get_pool_observe_wrapper(**kwargs)
elif mode == "uniswap_v3_get_pool_slot0":
return self.uniswap_v3_get_pool_slot0_wrapper(**kwargs)
elif mode == "uniswap_v3_collect":
return self.uniswap_v3_collect_wrapper(**kwargs)
elif mode == "deploy_token":
return self.deploy_token_wrapper(**kwargs)
elif mode == "mint_nft":
Expand Down

0 comments on commit 670b804

Please sign in to comment.