diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b4d99d..2723d63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Unreleased +### Added + +- Contract invocation support. + ## [0.0.3] - 2024-09-25 ### Added diff --git a/cdp/api_clients.py b/cdp/api_clients.py index c20f0eb..a3615ad 100644 --- a/cdp/api_clients.py +++ b/cdp/api_clients.py @@ -1,6 +1,7 @@ from cdp.cdp_api_client import CdpApiClient from cdp.client.api.addresses_api import AddressesApi from cdp.client.api.assets_api import AssetsApi +from cdp.client.api.contract_invocations_api import ContractInvocationsApi from cdp.client.api.external_addresses_api import ExternalAddressesApi from cdp.client.api.networks_api import NetworksApi from cdp.client.api.trades_api import TradesApi @@ -23,6 +24,7 @@ class ApiClients: _networks (Optional[NetworksApi]): The NetworksApi client instance. _assets (Optional[AssetsApi]): The AssetsApi client instance. _trades (Optional[TradesApi]): The TradesApi client instance. + _contract_invocations (Optional[ContractInvocationsApi]): The ContractInvocationsApi client instance. """ @@ -41,6 +43,7 @@ def __init__(self, cdp_client: CdpApiClient) -> None: self._networks: NetworksApi | None = None self._assets: AssetsApi | None = None self._trades: TradesApi | None = None + self._contract_invocations: ContractInvocationsApi | None = None @property def wallets(self) -> WalletsApi: @@ -146,3 +149,18 @@ def trades(self) -> TradesApi: if self._trades is None: self._trades = TradesApi(api_client=self._cdp_client) return self._trades + + @property + def contract_invocations(self) -> ContractInvocationsApi: + """Get the ContractInvocationsApi client instance. + + Returns: + ContractInvocationsApi: The ContractInvocationsApi client instance. + + Note: + This property lazily initializes the ContractInvocationsApi client on first access. + + """ + if self._contract_invocations is None: + self._contract_invocations = ContractInvocationsApi(api_client=self._cdp_client) + return self._contract_invocations diff --git a/cdp/contract_invocation.py b/cdp/contract_invocation.py new file mode 100644 index 0000000..c65430b --- /dev/null +++ b/cdp/contract_invocation.py @@ -0,0 +1,350 @@ +import json +import time +from collections.abc import Iterator +from decimal import Decimal +from typing import Any + +from eth_account.signers.local import LocalAccount + +from cdp.asset import Asset +from cdp.cdp import Cdp +from cdp.client.models.broadcast_contract_invocation_request import ( + BroadcastContractInvocationRequest, +) +from cdp.client.models.contract_invocation import ContractInvocation as ContractInvocationModel +from cdp.client.models.create_contract_invocation_request import CreateContractInvocationRequest +from cdp.errors import TransactionNotSignedError +from cdp.transaction import Transaction + + +class ContractInvocation: + """A class representing a contract invocation.""" + + def __init__(self, model: ContractInvocationModel) -> None: + """Initialize the ContractInvocation class. + + Args: + model (ContractInvocationModel): The model representing the contract invocation. + + """ + self._model = model + self._transaction = None + + @property + def contract_invocation_id(self) -> str: + """Get the contract invocation ID. + + Returns: + str: The contract invocation ID. + + """ + return self._model.contract_invocation_id + + @property + def wallet_id(self) -> str: + """Get the wallet ID of the contract invocation. + + Returns: + str: The wallet ID. + + """ + return self._model.wallet_id + + @property + def address_id(self) -> str: + """Get the address ID of the contract invocation. + + Returns: + str: The address ID. + + """ + return self._model.address_id + + @property + def network_id(self) -> str: + """Get the network ID of the contract invocation. + + Returns: + str: The network ID. + + """ + return self._model.network_id + + @property + def contract_address(self) -> str: + """Get the contract address. + + Returns: + str: The contract address. + + """ + return self._model.contract_address + + @property + def abi(self) -> dict[str, Any]: + """Get the ABI of the contract invocation. + + Returns: + Dict: The ABI JSON. + + """ + return dict(json.loads(self._model.abi).items()) + + @property + def method(self) -> str: + """Get the method being invoked in the contract. + + Returns: + str: The method being invoked. + + """ + return self._model.method + + @property + def args(self) -> dict[str, Any]: + """Get the arguments passed to the contract method. + + Returns: + Dict: The arguments passed to the contract method. + + """ + return dict(json.loads(self._model.args).items()) + + @property + def amount(self) -> Decimal: + """Get the amount sent to the contract in atomic units. + + Returns: + Decimal: The amount in atomic units. + + """ + return Decimal(self._model.amount) + + @property + def transaction(self) -> Transaction | None: + """Get the transaction associated with the contract invocation. + + Returns: + Transaction: The transaction. + + """ + if self._transaction is None and self._model.transaction is not None: + self._update_transaction(self._model) + return self._transaction + + @property + def transaction_link(self) -> str: + """Get the link to the transaction on the blockchain explorer. + + Returns: + str: The transaction link. + + """ + return self.send_tx_delegate.transaction_link + + @property + def transaction_hash(self) -> str: + """Get the transaction hash of the contract invocation. + + Returns: + str: The transaction hash. + + """ + return self.send_tx_delegate.transaction_hash + + @property + def status(self) -> str: + """Get the status of the contract invocation. + + Returns: + str: The status. + + """ + return self.transaction.status if self.transaction else None + + def sign(self, key: LocalAccount) -> "ContractInvocation": + """Sign the contract invocation transaction with the given key. + + Args: + key (LocalAccount): The key to sign the contract invocation with. + + Returns: + ContractInvocation: The signed ContractInvocation object. + + Raises: + ValueError: If the key is not a LocalAccount. + + """ + if not isinstance(key, LocalAccount): + raise ValueError("key must be a LocalAccount") + + self.transaction.sign(key) + return self + + def broadcast(self) -> "ContractInvocation": + """Broadcast the contract invocation to the network. + + Returns: + ContractInvocation: The broadcasted ContractInvocation object. + + Raises: + TransactionNotSignedError: If the transaction is not signed. + + """ + if not self.transaction.signed: + raise TransactionNotSignedError("Contract invocation is not signed") + + broadcast_contract_invocation_request = BroadcastContractInvocationRequest( + signed_payload=self.transaction.signature + ) + + model = Cdp.api_clients.contract_invocations.broadcast_contract_invocation( + wallet_id=self.wallet_id, + address_id=self.address_id, + contract_invocation_id=self.contract_invocation_id, + broadcast_contract_invocation_request=broadcast_contract_invocation_request, + ) + self._model = model + return self + + def reload(self) -> "ContractInvocation": + """Reload the Contract Invocation model with the latest version from the server. + + Returns: + ContractInvocation: The updated ContractInvocation object. + + """ + model = Cdp.api_clients.contract_invocations.get_contract_invocation( + wallet_id=self.wallet_id, + address_id=self.address_id, + contract_invocation_id=self.contract_invocation_id, + ) + + self._model = model + self._update_transaction(model) + + return self + + def wait( + self, interval_seconds: float = 0.2, timeout_seconds: float = 20 + ) -> "ContractInvocation": + """Wait until the contract invocation is signed or fails by polling the server. + + Args: + interval_seconds: The interval at which to poll the server. + timeout_seconds: The maximum time to wait before timing out. + + Returns: + ContractInvocation: The completed contract invocation. + + Raises: + TimeoutError: If the invocation takes longer than the given timeout. + + """ + start_time = time.time() + while not self.transaction.terminal_state: + self.reload() + + if time.time() - start_time > timeout_seconds: + raise TimeoutError("Contract Invocation timed out") + + time.sleep(interval_seconds) + + return self + + @classmethod + def create( + cls, + address_id: str, + wallet_id: str, + network_id: str, + contract_address: str, + method: str, + abi: list[dict] | None = None, + args: dict | None = None, + amount: Decimal | None = None, + asset_id: str | None = None, + ) -> "ContractInvocation": + """Create a new ContractInvocation object. + + Args: + address_id (str): The address ID of the signing address. + wallet_id (str): The wallet ID associated with the signing address. + network_id (str): The Network ID. + contract_address (str): The contract address. + method (str): The contract method. + abi (Optional[list[dict]]): The contract ABI, if provided. + args (Optional[dict]): The arguments to pass to the contract method. + amount (Optional[Decimal]): The amount of native asset to send to a payable contract method. + asset_id (Optional[str]): The asset ID to send to the contract. + + Returns: + ContractInvocation: The new ContractInvocation object. + + """ + atomic_amount = None + abi_json = None + + if asset_id and amount: + asset = Asset.fetch(network_id, asset_id) + atomic_amount = str(int(asset.to_atomic_amount(Decimal(amount)))) + + if abi: + abi_json = json.dumps(abi, separators=(",", ":")) + + create_contract_invocation_request = CreateContractInvocationRequest( + contract_address=contract_address, + abi=abi_json, + method=method, + args=json.dumps(args or {}, separators=(",", ":")), + amount=atomic_amount, + ) + + model = Cdp.api_clients.contract_invocations.create_contract_invocation( + wallet_id=wallet_id, + address_id=address_id, + create_contract_invocation_request=create_contract_invocation_request, + ) + + return cls(model) + + @classmethod + def list(cls, wallet_id: str, address_id: str) -> Iterator["ContractInvocation"]: + """List Contract Invocations. + + Args: + wallet_id (str): The wallet ID. + address_id (str): The address ID. + + Returns: + Iterator[ContractInvocation]: An iterator of ContractInvocation objects. + + """ + page = None + while True: + response = Cdp.api_clients.contract_invocations.list_contract_invocations( + wallet_id=wallet_id, address_id=address_id, limit=100, page=page + ) + for contract_invocation in response.data: + yield cls(contract_invocation) + + if not response.has_more: + break + + page = response.next_page + + def _update_transaction(self, model: ContractInvocationModel) -> None: + """Update the transaction with the new model.""" + if model.transaction is not None: + self._transaction = Transaction(model.transaction) + + def __str__(self) -> str: + """Return a string representation of the contract invocation.""" + return ( + f"ContractInvocation: (id: {self.contract_invocation_id}, wallet_id: {self.wallet_id}, address_id: {self.address_id}, " + f"network_id: {self.network_id}, method: {self.method}, status: {self.status})" + ) + + def __repr__(self) -> str: + """Return a string representation of the contract invocation.""" + return str(self) diff --git a/cdp/wallet.py b/cdp/wallet.py index 41c2414..f67fd70 100644 --- a/cdp/wallet.py +++ b/cdp/wallet.py @@ -1,3 +1,4 @@ +import builtins import hashlib import json import os @@ -25,6 +26,7 @@ ) from cdp.client.models.wallet import Wallet as WalletModel from cdp.client.models.wallet_list import WalletList +from cdp.contract_invocation import ContractInvocation from cdp.faucet_transaction import FaucetTransaction from cdp.trade import Trade from cdp.wallet_address import WalletAddress @@ -384,6 +386,41 @@ def trade(self, amount: Number | Decimal | str, from_asset_id: str, to_asset_id: return self.default_address.trade(amount, from_asset_id, to_asset_id) + def invoke_contract( + self, + contract_address: str, + method: str, + abi: builtins.list[dict] | None = None, + args: dict | None = None, + amount: Number | Decimal | str | None = None, + asset_id: str | None = None, + ) -> ContractInvocation: + """Invoke a method on the specified contract address, with the given ABI and arguments. + + Args: + contract_address (str): The address of the contract to invoke. + method (str): The name of the method to call on the contract. + abi (Optional[list[dict]]): The ABI of the contract, if provided. + args (Optional[dict]): The arguments to pass to the method. + amount (Optional[Union[Number, Decimal, str]]): The amount to send with the invocation, if applicable. + asset_id (Optional[str]): The asset ID associated with the amount, if applicable. + + Returns: + ContractInvocation: The contract invocation object. + + Raises: + ValueError: If the default address does not exist. + + """ + if self.default_address is None: + raise ValueError("Default address does not exist") + + invocation = self.default_address.invoke_contract( + contract_address, method, abi, args, amount, asset_id + ) + + return invocation + @property def default_address(self) -> WalletAddress | None: """Get the default address of the wallet. diff --git a/cdp/wallet_address.py b/cdp/wallet_address.py index 7dcb607..b261474 100644 --- a/cdp/wallet_address.py +++ b/cdp/wallet_address.py @@ -8,6 +8,7 @@ from cdp.address import Address from cdp.cdp import Cdp from cdp.client.models.address import Address as AddressModel +from cdp.contract_invocation import ContractInvocation from cdp.errors import InsufficientFundsError from cdp.trade import Trade from cdp.transfer import Transfer @@ -146,6 +147,55 @@ def trade(self, amount: Number | Decimal | str, from_asset_id: str, to_asset_id: return trade + def invoke_contract( + self, + contract_address: str, + method: str, + abi: list[dict] | None = None, + args: dict | None = None, + amount: Number | Decimal | str | None = None, + asset_id: str | None = None, + ) -> ContractInvocation: + """Invoke a method on the specified contract address, with the given ABI and arguments. + + Args: + contract_address (str): The address of the contract to invoke. + method (str): The name of the method to call on the contract. + abi (Optional[list[dict]]): The ABI of the contract, if provided. + args (Optional[dict]): The arguments to pass to the method. + amount (Optional[Union[Number, Decimal, str]]): The amount to send with the invocation, if applicable. + asset_id (Optional[str]): The asset ID associated with the amount, if applicable. + + Returns: + ContractInvocation: The contract invocation object. + + """ + normalied_amount = Decimal(amount) if amount else Decimal("0") + + if amount and asset_id: + self._ensure_sufficient_balance(amount, asset_id) + + invocation = ContractInvocation.create( + address_id=self.address_id, + wallet_id=self.wallet_id, + network_id=self.network_id, + contract_address=contract_address, + method=method, + abi=abi, + args=args, + amount=normalied_amount, + asset_id=asset_id, + ) + + if Cdp.use_server_signer: + return invocation + + invocation.sign(self.key) + + invocation.broadcast() + + return invocation + def transfers(self) -> Iterator[Transfer]: """List transfers for this wallet address. diff --git a/tests/test_contract_invocation.py b/tests/test_contract_invocation.py new file mode 100644 index 0000000..a472cfe --- /dev/null +++ b/tests/test_contract_invocation.py @@ -0,0 +1,262 @@ +from decimal import Decimal +from unittest.mock import ANY, Mock, call, patch + +import pytest + +from cdp.asset import Asset +from cdp.client.models.asset import Asset as AssetModel +from cdp.client.models.contract_invocation import ContractInvocation as ContractInvocationModel +from cdp.client.models.transaction import Transaction as TransactionModel +from cdp.contract_invocation import ContractInvocation +from cdp.errors import TransactionNotSignedError + + +@pytest.fixture +def asset_model_factory(): + """Create and return a factory for creating AssetModel fixtures.""" + + def _create_asset_model(network_id="base-sepolia", asset_id="usdc", decimals=6): + return AssetModel(network_id=network_id, asset_id=asset_id, decimals=decimals) + + return _create_asset_model + + +@pytest.fixture +def asset_factory(asset_model_factory): + """Create and return a factory for creating Asset fixtures.""" + + def _create_asset(network_id="base-sepolia", asset_id="usdc", decimals=6): + asset_model = asset_model_factory(network_id, asset_id, decimals) + return Asset.from_model(asset_model) + + return _create_asset + + +@pytest.fixture +def transaction_model_factory(): + """Create and return a factory for creating TransactionModel fixtures.""" + + def _create_transaction_model(status="complete"): + return TransactionModel( + network_id="base-sepolia", + transaction_hash="0xtransactionhash", + from_address_id="0xaddressid", + to_address_id="0xdestination", + unsigned_payload="0xunsignedpayload", + signed_payload="0xsignedpayload" + if status in ["signed", "broadcast", "complete"] + else None, + status=status, + transaction_link="https://sepolia.basescan.org/tx/0xtransactionlink" + if status == "complete" + else None, + ) + + return _create_transaction_model + + +@pytest.fixture +def contract_invocation_model_factory(transaction_model_factory): + """Create and return a factory for creating ContractInvocationModel fixtures.""" + + def _create_contract_invocation_model(status="complete"): + return ContractInvocationModel( + network_id="base-sepolia", + wallet_id="test-wallet-id", + address_id="0xaddressid", + contract_invocation_id="test-invocation-id", + contract_address="0xcontractaddress", + method="testMethod", + args='{"arg1": "value1"}', + abi='{"abi": "data"}', + amount="1", + transaction=transaction_model_factory(status), + ) + + return _create_contract_invocation_model + + +@pytest.fixture +def contract_invocation_factory(contract_invocation_model_factory): + """Create and return a factory for creating ContractInvocation fixtures.""" + + def _create_contract_invocation(status="complete"): + contract_invocation_model = contract_invocation_model_factory(status) + return ContractInvocation(contract_invocation_model) + + return _create_contract_invocation + + +def test_contract_invocation_initialization(contract_invocation_factory): + """Test the initialization of a ContractInvocation object.""" + contract_invocation = contract_invocation_factory() + assert isinstance(contract_invocation, ContractInvocation) + + +def test_contract_invocation_properties(contract_invocation_factory): + """Test the properties of a ContractInvocation object.""" + contract_invocation = contract_invocation_factory() + assert contract_invocation.contract_invocation_id == "test-invocation-id" + assert contract_invocation.wallet_id == "test-wallet-id" + assert contract_invocation.address_id == "0xaddressid" + assert contract_invocation.contract_address == "0xcontractaddress" + assert contract_invocation.method == "testMethod" + assert contract_invocation.args == {"arg1": "value1"} + assert contract_invocation.abi == {"abi": "data"} + assert contract_invocation.amount == Decimal("1") # 1 in atomic units + assert contract_invocation.status.value == "complete" + assert contract_invocation.transaction_link == "https://sepolia.basescan.org/tx/0xtransactionlink" + assert contract_invocation.transaction_hash == "0xtransactionhash" + + +@patch("cdp.Cdp.api_clients") +@patch("cdp.contract_invocation.Asset") +def test_create_contract_invocation( + mock_asset, mock_api_clients, contract_invocation_factory, asset_factory +): + """Test the creation of a ContractInvocation object.""" + mock_fetch = Mock() + mock_fetch.return_value = asset_factory() + mock_asset.fetch = mock_fetch + mock_asset.to_atomic_amount = Mock(return_value=Decimal("1")) + + mock_create_invocation = Mock() + mock_create_invocation.return_value = contract_invocation_factory()._model + mock_api_clients.contract_invocations.create_contract_invocation = mock_create_invocation + + contract_invocation = ContractInvocation.create( + address_id="0xaddressid", + wallet_id="test-wallet-id", + network_id="base-sepolia", + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + + assert isinstance(contract_invocation, ContractInvocation) + mock_create_invocation.assert_called_once_with( + wallet_id="test-wallet-id", + address_id="0xaddressid", + create_contract_invocation_request=ANY, + ) + + +@patch("cdp.Cdp.api_clients") +def test_broadcast_contract_invocation(mock_api_clients, contract_invocation_factory): + """Test the broadcasting of a ContractInvocation object.""" + contract_invocation = contract_invocation_factory(status="signed") + mock_broadcast = Mock(return_value=contract_invocation_factory(status="broadcast")._model) + mock_api_clients.contract_invocations.broadcast_contract_invocation = mock_broadcast + + response = contract_invocation.broadcast() + + assert isinstance(response, ContractInvocation) + mock_broadcast.assert_called_once_with( + wallet_id=contract_invocation.wallet_id, + address_id=contract_invocation.address_id, + contract_invocation_id=contract_invocation.contract_invocation_id, + broadcast_contract_invocation_request=ANY, + ) + + +def test_broadcast_unsigned_contract_invocation(contract_invocation_factory): + """Test the broadcasting of an unsigned ContractInvocation object.""" + contract_invocation = contract_invocation_factory(status="pending") + with pytest.raises(TransactionNotSignedError, match="Contract invocation is not signed"): + contract_invocation.broadcast() + + +@patch("cdp.Cdp.api_clients") +def test_reload_contract_invocation(mock_api_clients, contract_invocation_factory): + """Test the reloading of a ContractInvocation object.""" + contract_invocation = contract_invocation_factory(status="pending") + complete_invocation = contract_invocation_factory(status="complete") + mock_get_invocation = Mock() + mock_api_clients.contract_invocations.get_contract_invocation = mock_get_invocation + mock_get_invocation.return_value = complete_invocation._model + + contract_invocation.reload() + + mock_get_invocation.assert_called_once_with( + wallet_id=contract_invocation.wallet_id, + address_id=contract_invocation.address_id, + contract_invocation_id=contract_invocation.contract_invocation_id, + ) + assert contract_invocation.status.value == "complete" + + +@patch("cdp.Cdp.api_clients") +@patch("cdp.contract_invocation.time.sleep") +@patch("cdp.contract_invocation.time.time") +def test_wait_for_contract_invocation( + mock_time, mock_sleep, mock_api_clients, contract_invocation_factory +): + """Test the waiting for a ContractInvocation object to complete.""" + pending_invocation = contract_invocation_factory(status="pending") + complete_invocation = contract_invocation_factory(status="complete") + mock_get_invocation = Mock() + mock_api_clients.contract_invocations.get_contract_invocation = mock_get_invocation + mock_get_invocation.side_effect = [pending_invocation._model, complete_invocation._model] + + mock_time.side_effect = [0, 0.2, 0.4] + + result = pending_invocation.wait(interval_seconds=0.2, timeout_seconds=1) + + assert result.status.value == "complete" + mock_get_invocation.assert_called_with( + wallet_id=pending_invocation.wallet_id, + address_id=pending_invocation.address_id, + contract_invocation_id=pending_invocation.contract_invocation_id, + ) + assert mock_get_invocation.call_count == 2 + mock_sleep.assert_has_calls([call(0.2)] * 2) + assert mock_time.call_count == 3 + + +@patch("cdp.Cdp.api_clients") +@patch("cdp.contract_invocation.time.sleep") +@patch("cdp.contract_invocation.time.time") +def test_wait_for_contract_invocation_timeout( + mock_time, mock_sleep, mock_api_clients, contract_invocation_factory +): + """Test the waiting for a ContractInvocation object to complete with a timeout.""" + pending_invocation = contract_invocation_factory(status="pending") + mock_get_invocation = Mock(return_value=pending_invocation._model) + mock_api_clients.contract_invocations.get_contract_invocation = mock_get_invocation + + mock_time.side_effect = [0, 0.5, 1.0, 1.5, 2.0, 2.5] + + with pytest.raises(TimeoutError, match="Contract Invocation timed out"): + pending_invocation.wait(interval_seconds=0.5, timeout_seconds=2) + + assert mock_get_invocation.call_count == 5 + mock_sleep.assert_has_calls([call(0.5)] * 4) + assert mock_time.call_count == 6 + + +def test_sign_contract_invocation_invalid_key(contract_invocation_factory): + """Test the signing of a ContractInvocation object with an invalid key.""" + contract_invocation = contract_invocation_factory() + with pytest.raises(ValueError, match="key must be a LocalAccount"): + contract_invocation.sign("invalid_key") + + +def test_contract_invocation_str_representation(contract_invocation_factory): + """Test the string representation of a ContractInvocation object.""" + contract_invocation = contract_invocation_factory() + expected_str = ( + f"ContractInvocation: (id: {contract_invocation.contract_invocation_id}, " + f"wallet_id: {contract_invocation.wallet_id}, address_id: {contract_invocation.address_id}, " + f"network_id: {contract_invocation.network_id}, method: {contract_invocation.method}, " + f"status: {contract_invocation.status})" + ) + assert str(contract_invocation) == expected_str + + +def test_contract_invocation_repr(contract_invocation_factory): + """Test the representation of a ContractInvocation object.""" + contract_invocation = contract_invocation_factory() + assert repr(contract_invocation) == str(contract_invocation) diff --git a/tests/test_wallet.py b/tests/test_wallet.py index afac1f3..575395a 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -10,6 +10,7 @@ from cdp.client.models.create_wallet_request import CreateWalletRequest, CreateWalletRequestWallet from cdp.client.models.feature_set import FeatureSet from cdp.client.models.wallet import Wallet as WalletModel +from cdp.contract_invocation import ContractInvocation from cdp.trade import Trade from cdp.transfer import Transfer from cdp.wallet import Wallet @@ -420,6 +421,60 @@ def test_wallet_transfer_no_default_address(wallet_factory): ) +@patch("cdp.Cdp.use_server_signer", True) +def test_wallet_invoke_contract_with_server_signer(wallet_factory): + """Test the invoke_contract method of a Wallet with server-signer.""" + wallet = wallet_factory() + mock_default_address = Mock(spec=WalletAddress) + mock_contract_invocation_instance = Mock(spec=ContractInvocation) + mock_default_address.invoke_contract.return_value = mock_contract_invocation_instance + + with patch.object( + Wallet, "default_address", new_callable=PropertyMock + ) as mock_default_address_prop: + mock_default_address_prop.return_value = mock_default_address + + contract_invocation = wallet.invoke_contract( + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + + assert isinstance(contract_invocation, ContractInvocation) + mock_default_address.invoke_contract.assert_called_once_with( + "0xcontractaddress", + "testMethod", + [{"abi": "data"}], + {"arg1": "value1"}, + Decimal("1"), + "wei", + ) + + +@patch("cdp.Cdp.use_server_signer", True) +def test_wallet_contract_invocation_no_default_address(wallet_factory): + """Test the invoke_contract method of a Wallet with no default address.""" + wallet = wallet_factory() + + with patch.object( + Wallet, "default_address", new_callable=PropertyMock + ) as mock_default_address_prop: + mock_default_address_prop.return_value = None + + with pytest.raises(ValueError, match="Default address does not exist"): + wallet.invoke_contract( + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + + @patch("cdp.Cdp.api_clients") def test_wallet_reload(mock_api_clients, wallet_factory): """Test the reload method of a Wallet.""" diff --git a/tests/test_wallet_address.py b/tests/test_wallet_address.py index 96362e9..633ee6b 100644 --- a/tests/test_wallet_address.py +++ b/tests/test_wallet_address.py @@ -7,6 +7,7 @@ from cdp.client.models.address import Address as AddressModel from cdp.client.models.asset import Asset as AssetModel from cdp.client.models.balance import Balance as BalanceModel +from cdp.contract_invocation import ContractInvocation from cdp.errors import InsufficientFundsError from cdp.trade import Trade from cdp.transfer import Transfer @@ -334,6 +335,178 @@ def test_trades_api_error(mock_trade, wallet_address): wallet_address.trades() +@patch("cdp.wallet_address.ContractInvocation") +@patch("cdp.Cdp.api_clients") +@patch("cdp.Cdp.use_server_signer", True) +def test_invoke_contract_with_server_signer( + mock_api_clients, mock_contract_invocation, wallet_address, balance_model +): + """Test the invoke_contract method with a server signer.""" + mock_contract_invocation_instance = Mock(spec=ContractInvocation) + mock_contract_invocation.create.return_value = mock_contract_invocation_instance + + mock_get_balance = Mock() + mock_get_balance.return_value = balance_model + mock_api_clients.external_addresses.get_external_address_balance = mock_get_balance + + contract_invocation = wallet_address.invoke_contract( + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + + assert isinstance(contract_invocation, ContractInvocation) + mock_get_balance.assert_called_once_with( + network_id=wallet_address.network_id, address_id=wallet_address.address_id, asset_id="eth" + ) + mock_contract_invocation.create.assert_called_once_with( + address_id=wallet_address.address_id, + wallet_id=wallet_address.wallet_id, + network_id=wallet_address.network_id, + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + mock_contract_invocation_instance.sign.assert_not_called() + mock_contract_invocation_instance.broadcast.assert_not_called() + + +@patch("cdp.wallet_address.ContractInvocation") +@patch("cdp.Cdp.api_clients") +@patch("cdp.Cdp.use_server_signer", False) +def test_invoke_contract( + mock_api_clients, mock_contract_invocation, wallet_address_with_key, balance_model +): + """Test the invoke_contract method.""" + mock_contract_invocation_instance = Mock(spec=ContractInvocation) + mock_contract_invocation.create.return_value = mock_contract_invocation_instance + + mock_get_balance = Mock() + mock_get_balance.return_value = balance_model + mock_api_clients.external_addresses.get_external_address_balance = mock_get_balance + + contract_invocation = wallet_address_with_key.invoke_contract( + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + + assert isinstance(contract_invocation, ContractInvocation) + mock_get_balance.assert_called_once_with( + network_id=wallet_address_with_key.network_id, + address_id=wallet_address_with_key.address_id, + asset_id="eth", + ) + mock_contract_invocation.create.assert_called_once_with( + address_id=wallet_address_with_key.address_id, + wallet_id=wallet_address_with_key.wallet_id, + network_id=wallet_address_with_key.network_id, + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + mock_contract_invocation_instance.sign.assert_called_once_with(wallet_address_with_key.key) + mock_contract_invocation_instance.broadcast.assert_called_once() + + +@patch("cdp.wallet_address.ContractInvocation") +@patch("cdp.Cdp.api_clients") +@patch("cdp.Cdp.use_server_signer", False) +def test_invoke_contract_api_error( + mock_api_clients, mock_contract_invocation, wallet_address_with_key, balance_model +): + """Test the invoke_contract method raises an error when the create API call fails.""" + mock_contract_invocation.create.side_effect = Exception("API Error") + + mock_get_balance = Mock() + mock_get_balance.return_value = balance_model + mock_api_clients.external_addresses.get_external_address_balance = mock_get_balance + + with pytest.raises(Exception, match="API Error"): + wallet_address_with_key.invoke_contract( + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + + mock_get_balance.assert_called_once_with( + network_id=wallet_address_with_key.network_id, + address_id=wallet_address_with_key.address_id, + asset_id="eth", + ) + mock_contract_invocation.create.assert_called_once_with( + address_id=wallet_address_with_key.address_id, + wallet_id=wallet_address_with_key.wallet_id, + network_id=wallet_address_with_key.network_id, + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + + +@patch("cdp.wallet_address.ContractInvocation") +@patch("cdp.Cdp.api_clients") +@patch("cdp.Cdp.use_server_signer", False) +def test_invoke_contract_broadcast_api_error( + mock_api_clients, mock_contract_invocation, wallet_address_with_key, balance_model +): + """Test the invoke_contract method raises an error when the broadcast API call fails.""" + mock_contract_invocation_instance = Mock(spec=ContractInvocation) + mock_contract_invocation.create.return_value = mock_contract_invocation_instance + mock_contract_invocation_instance.broadcast.side_effect = Exception("API Error") + + mock_get_balance = Mock() + mock_get_balance.return_value = balance_model + mock_api_clients.external_addresses.get_external_address_balance = mock_get_balance + + with pytest.raises(Exception, match="API Error"): + wallet_address_with_key.invoke_contract( + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + + mock_get_balance.assert_called_once_with( + network_id=wallet_address_with_key.network_id, + address_id=wallet_address_with_key.address_id, + asset_id="eth", + ) + mock_contract_invocation.create.assert_called_once_with( + address_id=wallet_address_with_key.address_id, + wallet_id=wallet_address_with_key.wallet_id, + network_id=wallet_address_with_key.network_id, + contract_address="0xcontractaddress", + method="testMethod", + abi=[{"abi": "data"}], + args={"arg1": "value1"}, + amount=Decimal("1"), + asset_id="wei", + ) + mock_contract_invocation_instance.sign.assert_called_once_with(wallet_address_with_key.key) + mock_contract_invocation_instance.broadcast.assert_called_once() + + @patch("cdp.Cdp.api_clients") def test_ensure_sufficient_balance_sufficient(mock_api_clients, wallet_address, balance_model): """Test the ensure_sufficient_balance method with sufficient balance."""