Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for collecting token rewards for liquidity positions #22

Draft
wants to merge 2 commits into
base: rohan/collect-and-read
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cdp-agentkit-core/cdp_agentkit_core/actions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
TransferInput,
transfer,
)
from cdp_agentkit_core.actions.uniswap_v3.collect import (
UNISWAP_V3_COLLECT_PROMPT,
UniswapV3CollectInput,
uniswap_v3_collect,
)
from cdp_agentkit_core.actions.uniswap_v3.create_pool import (
UNISWAP_V3_CREATE_POOL_PROMPT,
UniswapV3CreatePoolInput,
Expand Down Expand Up @@ -85,6 +90,9 @@
"UNISWAP_V3_GET_POOL_SLOT0_PROMPT",
"UniswapV3GetPoolSlot0Input",
"uniswap_v3_get_pool_slot0",
"UNISWAP_V3_COLLECT_PROMPT",
"UniswapV3CollectInput",
"uniswap_v3_collect",
"DEPLOY_NFT_PROMPT",
"DeployNftInput",
"deploy_nft",
Expand Down
75 changes: 75 additions & 0 deletions cdp-agentkit-core/cdp_agentkit_core/actions/uniswap_v3/collect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from cdp import Wallet
from pydantic import BaseModel, Field

from cdp_agentkit_core.actions.uniswap_v3.constants import UNISWAP_V3_POOL_ABI

UNISWAP_V3_COLLECT_PROMPT = """
This tool will collect tokens from a Uniswap v3 pool. This tool takes the pool address, recipient address, tickLower, tickUpper, amount0 requested for token0, and amount1 requested for token1 as inputs. tickLower is the lower tick of the position for which to collect fees, while tickUppwer is the upper tick of the position for which to collect fees."""


class UniswapV3CollectInput(BaseModel):
"""Input argument schema for collect action."""

pool_address: str = Field(
...,
description="The address of the pool to collect from.",
)
recipient_address: str = Field(
...,
description="The address of the recipient of the collected tokens.",
)
tick_lower: str = Field(
...,
description="The lower tick of the position for which to collect fees.",
)
tick_upper: str = Field(
...,
description="The upper tick of the position for which to collect fees.",
)
amount0_requested: str = Field(
...,
description="The amount of token0 requested for collection.",
)
amount1_requested: str = Field(
...,
description="The amount of token1 requested for collection.",
)


def uniswap_v3_collect(
wallet: Wallet,
pool_address: str,
recipient_address: str,
tick_lower: str,
tick_upper: str,
amount0_requested: str,
amount1_requested: str,
) -> str:
"""Collect tokens from a Uniswap v3 pool.

Args:
wallet (Wallet): The wallet to collect tokens from.
pool_address (str): The address of the pool to collect from.
recipient_address (str): The address of the recipient of the collected tokens.
tick_lower (str): The lower tick of the position for which to collect fees.
tick_upper (str): The upper tick of the position for which to collect fees.
amount0_requested (str): The amount of token0 requested for collection.
amount1_requested (str): The amount of token1 requested for collection.

Returns:
str: A message containing the details of the collected tokens.

"""
pool = wallet.invoke_contract(
contract_address=pool_address,
method="collect",
abi=UNISWAP_V3_POOL_ABI,
args={
"recipient": recipient_address,
"tickLower": tick_lower,
"tickUpper": tick_upper,
"amount0Requested": amount0_requested,
"amount1Requested": amount1_requested,
},
).wait()
return f"Requested collection of {amount0_requested} of token0 and {amount1_requested} of token1 from pool {pool_address}.\nTransaction hash for the collection: {pool.transaction.transaction_hash}\nTransaction link for the collection: {pool.transaction.transaction_link}"
114 changes: 114 additions & 0 deletions cdp-agentkit-core/tests/actions/uniswap_v3/test_collect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from unittest.mock import patch

import pytest

from cdp_agentkit_core.actions.uniswap_v3.collect import (
UniswapV3CollectInput,
uniswap_v3_collect,
)
from cdp_agentkit_core.actions.uniswap_v3.constants import UNISWAP_V3_POOL_ABI

MOCK_POOL_ADDRESS = "0x4200000000000000000000000000000000000006"
MOCK_RECIPIENT_ADDRESS = "0x1234567890123456789012345678901234567890"
MOCK_TICK_LOWER = "100"
MOCK_TICK_UPPER = "200"
MOCK_AMOUNT0_REQUESTED = "1000"
MOCK_AMOUNT1_REQUESTED = "2000"


def test_collect_input_model_valid():
"""Test that CollectInput accepts valid parameters."""
input_model = UniswapV3CollectInput(
pool_address=MOCK_POOL_ADDRESS,
recipient_address=MOCK_RECIPIENT_ADDRESS,
tick_lower=MOCK_TICK_LOWER,
tick_upper=MOCK_TICK_UPPER,
amount0_requested=MOCK_AMOUNT0_REQUESTED,
amount1_requested=MOCK_AMOUNT1_REQUESTED,
)

assert input_model.pool_address == MOCK_POOL_ADDRESS
assert input_model.recipient_address == MOCK_RECIPIENT_ADDRESS
assert input_model.tick_lower == MOCK_TICK_LOWER
assert input_model.tick_upper == MOCK_TICK_UPPER
assert input_model.amount0_requested == MOCK_AMOUNT0_REQUESTED
assert input_model.amount1_requested == MOCK_AMOUNT1_REQUESTED


def test_collect_input_model_missing_params():
"""Test that CollectInput raises error when params are missing."""
with pytest.raises(ValueError):
UniswapV3CollectInput()


def test_collect_success(wallet_factory, contract_invocation_factory):
"""Test successful token collection with valid parameters."""
mock_wallet = wallet_factory()
mock_contract_instance = contract_invocation_factory()

with (
patch.object(
mock_wallet, "invoke_contract", return_value=mock_contract_instance
) as mock_invoke,
patch.object(
mock_contract_instance, "wait", return_value=mock_contract_instance
) as mock_contract_wait,
):
action_response = uniswap_v3_collect(
mock_wallet,
MOCK_POOL_ADDRESS,
MOCK_RECIPIENT_ADDRESS,
MOCK_TICK_LOWER,
MOCK_TICK_UPPER,
MOCK_AMOUNT0_REQUESTED,
MOCK_AMOUNT1_REQUESTED,
)

expected_response = f"Requested collection of {MOCK_AMOUNT0_REQUESTED} of token0 and {MOCK_AMOUNT1_REQUESTED} of token1 from pool {MOCK_POOL_ADDRESS}.\nTransaction hash for the collection: {mock_contract_instance.transaction.transaction_hash}\nTransaction link for the collection: {mock_contract_instance.transaction.transaction_link}"
assert action_response == expected_response

mock_invoke.assert_called_once_with(
contract_address=MOCK_POOL_ADDRESS,
method="collect",
abi=UNISWAP_V3_POOL_ABI,
args={
"recipient": MOCK_RECIPIENT_ADDRESS,
"tickLower": MOCK_TICK_LOWER,
"tickUpper": MOCK_TICK_UPPER,
"amount0Requested": MOCK_AMOUNT0_REQUESTED,
"amount1Requested": MOCK_AMOUNT1_REQUESTED,
},
)
mock_contract_wait.assert_called_once_with()


def test_collect_api_error(wallet_factory):
"""Test collect when API error occurs."""
mock_wallet = wallet_factory()

with patch.object(
mock_wallet, "invoke_contract", side_effect=Exception("API error")
) as mock_invoke:
with pytest.raises(Exception, match="API error"):
uniswap_v3_collect(
mock_wallet,
MOCK_POOL_ADDRESS,
MOCK_RECIPIENT_ADDRESS,
MOCK_TICK_LOWER,
MOCK_TICK_UPPER,
MOCK_AMOUNT0_REQUESTED,
MOCK_AMOUNT1_REQUESTED,
)

mock_invoke.assert_called_once_with(
contract_address=MOCK_POOL_ADDRESS,
method="collect",
abi=UNISWAP_V3_POOL_ABI,
args={
"recipient": MOCK_RECIPIENT_ADDRESS,
"tickLower": MOCK_TICK_LOWER,
"tickUpper": MOCK_TICK_UPPER,
"amount0Requested": MOCK_AMOUNT0_REQUESTED,
"amount1Requested": MOCK_AMOUNT1_REQUESTED,
},
)
9 changes: 9 additions & 0 deletions cdp-langchain/cdp_langchain/agent_toolkits/cdp_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
REQUEST_FAUCET_FUNDS_PROMPT,
TRADE_PROMPT,
TRANSFER_PROMPT,
UNISWAP_V3_COLLECT_PROMPT,
UNISWAP_V3_CREATE_POOL_PROMPT,
UNISWAP_V3_GET_POOL_LIQUIDITY_PROMPT,
UNISWAP_V3_GET_POOL_OBSERVE_PROMPT,
Expand All @@ -27,6 +28,7 @@
RequestFaucetFundsInput,
TradeInput,
TransferInput,
UniswapV3CollectInput,
UniswapV3CreatePoolInput,
UniswapV3GetPoolInput,
UniswapV3GetPoolLiquidityInput,
Expand Down Expand Up @@ -94,6 +96,7 @@ class CdpToolkit(BaseToolkit):
uniswap_v3_get_pool_observe
uniswap_v3_get_pool_slot0
uniswap_v3_get_pool_liquidity
uniswap_v3_collect
Use within an agent:
.. code-block:: python

Expand Down Expand Up @@ -185,6 +188,12 @@ def from_cdp_agentkit_wrapper(cls, cdp_agentkit_wrapper: CdpAgentkitWrapper) ->
"description": UNISWAP_V3_GET_POOL_LIQUIDITY_PROMPT,
"args_schema": UniswapV3GetPoolLiquidityInput,
},
{
"mode": "uniswap_v3_collect",
"name": "uniswap_v3_collect",
"description": UNISWAP_V3_COLLECT_PROMPT,
"args_schema": UniswapV3CollectInput,
},
{
"mode": "get_wallet_details",
"name": "get_wallet_details",
Expand Down
23 changes: 23 additions & 0 deletions cdp-langchain/cdp_langchain/utils/cdp_agentkit_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
request_faucet_funds,
trade,
transfer,
uniswap_v3_collect,
uniswap_v3_create_pool,
uniswap_v3_get_pool,
uniswap_v3_get_pool_liquidity,
Expand Down Expand Up @@ -92,6 +93,26 @@ def uniswap_v3_create_pool_wrapper(self, token_a: str, token_b: str, fee: str) -
"""
return uniswap_v3_create_pool(wallet=self.wallet, token_a=token_a, token_b=token_b, fee=fee)

def uniswap_v3_collect_wrapper(
self,
pool_address: str,
recipient_address: str,
tick_lower: str,
tick_upper: str,
amount0_requested: str,
amount1_requested: str,
) -> str:
"""Collect tokens from a Uniswap v3 pool by wrapping call to CDP Agentkit Core."""
return uniswap_v3_collect(
wallet=self.wallet,
pool_address=pool_address,
recipient_address=recipient_address,
tick_lower=tick_lower,
tick_upper=tick_upper,
amount0_requested=amount0_requested,
amount1_requested=amount1_requested,
)

def uniswap_v3_get_pool_wrapper(
self, network_id: str, token_a: str, token_b: str, fee: str
) -> str:
Expand Down Expand Up @@ -313,6 +334,8 @@ def run(self, mode: str, **kwargs) -> str:
return self.uniswap_v3_get_pool_observe_wrapper(**kwargs)
elif mode == "uniswap_v3_get_pool_slot0":
return self.uniswap_v3_get_pool_slot0_wrapper(**kwargs)
elif mode == "uniswap_v3_collect":
return self.uniswap_v3_collect_wrapper(**kwargs)
elif mode == "deploy_token":
return self.deploy_token_wrapper(**kwargs)
elif mode == "mint_nft":
Expand Down
Loading