From 6b3ba381f26e62cb692034c9f5f671914b3f7d43 Mon Sep 17 00:00:00 2001 From: John Peterson Date: Thu, 19 Sep 2024 11:41:28 -0400 Subject: [PATCH] unit tests for asset, balance, errors, & transaction + more docs --- cdp/errors.py | 68 ++++++++++++++++- cdp/sponsored_send.py | 6 ++ cdp/trade.py | 9 +++ cdp/transaction.py | 6 ++ cdp/transfer.py | 1 + cdp/wallet.py | 2 +- tests/__init__.py | 0 tests/test_asset.py | 120 ++++++++++++++++++++++++++++++ tests/test_balance.py | 99 +++++++++++++++++++++++++ tests/test_errors.py | 140 +++++++++++++++++++++++++++++++++++ tests/test_transaction.py | 150 ++++++++++++++++++++++++++++++++++++++ 11 files changed, 596 insertions(+), 5 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_asset.py create mode 100644 tests/test_balance.py create mode 100644 tests/test_errors.py create mode 100644 tests/test_transaction.py diff --git a/cdp/errors.py b/cdp/errors.py index e47dfe7..31b9a11 100644 --- a/cdp/errors.py +++ b/cdp/errors.py @@ -5,6 +5,8 @@ class APIError(Exception): + """A wrapper for API exceptions to provide more context.""" + def __init__( self, err: ApiException, @@ -20,6 +22,18 @@ def __init__( @classmethod def from_error(cls, err: ApiException) -> "APIError": + """Create an APIError from an ApiException. + + Args: + err (ApiException): The ApiException to create an APIError from. + + Returns: + APIError: The APIError. + + Raises: + ValueError: If the argument is not an ApiException. + + """ if not isinstance(err, ApiException): raise ValueError("argument must be an ApiException") @@ -41,21 +55,51 @@ def from_error(cls, err: ApiException) -> "APIError": @property def http_code(self) -> int: + """Get the HTTP status code. + + Returns: + int: The HTTP status code. + + """ return self._http_code @property def api_code(self) -> str | None: + """Get the API error code. + + Returns: + str | None: The API error code. + + """ return self._api_code @property def api_message(self) -> str | None: + """Get the API error message. + + Returns: + str | None: The API error message. + + """ return self._api_message @property def handled(self) -> bool: + """Get whether the error is handled. + + Returns: + bool: True if the error is handled, False otherwise. + + """ return self._handled def __str__(self) -> str: + """Get a string representation of the APIError. + + Returns: + str: The string representation of the APIError. + + """ if self.handled: return f"APIError(http_code={self.http_code}, api_code={self.api_code}, api_message={self.api_message})" else: @@ -66,6 +110,12 @@ class InvalidConfigurationError(Exception): """Exception raised for errors in the configuration of the Coinbase SDK.""" def __init__(self, message: str = "Invalid configuration provided") -> None: + """Initialize the InvalidConfigurationError. + + Args: + message (str): The error message. + + """ self.message = message super().__init__(self.message) @@ -74,6 +124,12 @@ class InvalidAPIKeyFormatError(Exception): """Exception raised for errors in the format of the API key.""" def __init__(self, message: str = "Invalid API key format") -> None: + """Initialize the InvalidAPIKeyFormatError. + + Args: + message (str): The error message. + + """ self.message = message super().__init__(self.message) @@ -90,7 +146,8 @@ def __init__(self, expected: Decimal, exact: Decimal, msg: str = "Insufficient f msg (str): The error message prefix. """ - super().__init__(f"{msg}: have {exact}, need {expected}.") + self.message = f"{msg}: have {exact}, need {expected}." + super().__init__(self.message) class AlreadySignedError(Exception): @@ -103,7 +160,8 @@ def __init__(self, msg: str = "Resource already signed") -> None: msg (str): The error message. """ - super().__init__(msg) + self.message = msg + super().__init__(self.message) class TransactionNotSignedError(Exception): @@ -116,7 +174,8 @@ def __init__(self, msg: str = "Transaction must be signed") -> None: msg (str): The error message. """ - super().__init__(msg) + self.message = msg + super().__init__(self.message) class AddressCannotSignError(Exception): @@ -131,7 +190,8 @@ def __init__( msg (str): The error message. """ - super().__init__(msg) + self.message = msg + super().__init__(self.message) class UnimplementedError(APIError): diff --git a/cdp/sponsored_send.py b/cdp/sponsored_send.py index 784f9ea..cf3e3d6 100644 --- a/cdp/sponsored_send.py +++ b/cdp/sponsored_send.py @@ -35,6 +35,12 @@ class Status(Enum): @classmethod def terminal_states(cls): + """Get the terminal states. + + Returns: + List[str]: The terminal states. + + """ return [cls.COMPLETE, cls.FAILED] def __str__(self) -> str: diff --git a/cdp/trade.py b/cdp/trade.py index fa9e7a9..8d3929d 100644 --- a/cdp/trade.py +++ b/cdp/trade.py @@ -80,6 +80,15 @@ def list(wallet_id: str, address_id: str) -> list["Trade"]: return [Trade(model) for model in models.data] def broadcast(self) -> "Trade": + """Broadcast the trade. + + Returns: + Trade: The broadcasted trade. + + Raises: + TransactionNotSignedError: If the trade is not signed. + + """ if not all(self.transaction.signed, self.approve_transaction.signed): raise TransactionNotSignedError("Trade is not signed") diff --git a/cdp/transaction.py b/cdp/transaction.py index b679272..7d23612 100644 --- a/cdp/transaction.py +++ b/cdp/transaction.py @@ -23,6 +23,12 @@ class Status(Enum): @classmethod def terminal_states(cls): + """Get the terminal states. + + Returns: + List[str]: The terminal states. + + """ return [cls.COMPLETE, cls.FAILED] def __str__(self) -> str: diff --git a/cdp/transfer.py b/cdp/transfer.py index 10becd8..0cec1cd 100644 --- a/cdp/transfer.py +++ b/cdp/transfer.py @@ -312,4 +312,5 @@ def __str__(self) -> str: ) def __repr__(self) -> str: + """Get a string representation of the Transfer.""" return str(self) diff --git a/cdp/wallet.py b/cdp/wallet.py index d299251..a139732 100644 --- a/cdp/wallet.py +++ b/cdp/wallet.py @@ -376,7 +376,7 @@ def save_seed(self, file_path: str, encrypt: bool | None = False) -> None: ValueError: If the wallet does not have a seed loaded. """ - if self._master is None or self._seed == None: + if self._master is None or self._seed is None: raise ValueError("Wallet does not have seed loaded") key = self._encryption_key() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_asset.py b/tests/test_asset.py new file mode 100644 index 0000000..3b3a322 --- /dev/null +++ b/tests/test_asset.py @@ -0,0 +1,120 @@ +from decimal import Decimal +from unittest.mock import Mock, patch + +import pytest + +from cdp.asset import Asset +from cdp.client.models.asset import Asset as AssetModel + + +@pytest.fixture +def asset_model(): + """Fixture for asset model.""" + return AssetModel(network_id="ethereum-goerli", asset_id="eth", decimals=18) + + +@pytest.fixture +def asset(asset_model): + """Fixture for asset.""" + return Asset.from_model(asset_model) + + +def test_asset_initialization(asset): + """Test asset initialization.""" + assert asset.network_id == "ethereum-goerli" + assert asset.asset_id == "eth" + assert asset.decimals == 18 + + +def test_asset_from_model(asset_model): + """Test asset from model.""" + asset = Asset.from_model(asset_model) + assert isinstance(asset, Asset) + assert asset.network_id == asset_model.network_id + assert asset.asset_id == asset_model.asset_id + assert asset.decimals == asset_model.decimals + + +def test_asset_from_model_with_gwei(asset_model): + """Test asset from model with gwei.""" + asset = Asset.from_model(asset_model, asset_id="gwei") + assert asset.decimals == 9 + + +def test_asset_from_model_with_wei(asset_model): + """Test asset from model with wei.""" + asset = Asset.from_model(asset_model, asset_id="wei") + assert asset.decimals == 0 + + +def test_asset_from_model_with_invalid_asset_id(asset_model): + """Test asset from model with invalid asset ID.""" + with pytest.raises(ValueError, match="Unsupported asset ID: invalid"): + Asset.from_model(asset_model, asset_id="invalid") + + +@patch("cdp.Cdp.api_clients") +def test_asset_fetch(mock_api_clients, asset_model): + """Test asset fetch.""" + mock_get_asset = Mock() + mock_get_asset.return_value = asset_model + mock_api_clients.assets.get_asset = mock_get_asset + + asset = Asset.fetch("ethereum-goerli", "eth") + assert isinstance(asset, Asset) + assert asset.network_id == "ethereum-goerli" + assert asset.asset_id == "eth" + mock_get_asset.assert_called_once_with(network_id="ethereum-goerli", asset_id="eth") + + +@patch("cdp.Cdp.api_clients") +def test_asset_fetch_api_error(mock_api_clients): + """Test asset fetch API error.""" + mock_get_asset = Mock() + mock_get_asset.side_effect = Exception("API error") + mock_api_clients.assets.get_asset = mock_get_asset + + with pytest.raises(Exception, match="API error"): + Asset.fetch("ethereum-goerli", "eth") + + +@pytest.mark.parametrize( + "input_asset_id, expected_output", + [ + ("eth", "eth"), + ("wei", "eth"), + ("gwei", "eth"), + ("usdc", "usdc"), + ], +) +def test_primary_denomination(input_asset_id, expected_output): + """Test primary denomination.""" + assert Asset.primary_denomination(input_asset_id) == expected_output + + +def test_from_atomic_amount(asset): + """Test from atomic amount.""" + assert asset.from_atomic_amount(Decimal("1000000000000000000")) == Decimal("1") + assert asset.from_atomic_amount(Decimal("500000000000000000")) == Decimal("0.5") + + +def test_to_atomic_amount(asset): + """Test to atomic amount.""" + assert asset.to_atomic_amount(Decimal("1")) == Decimal("1000000000000000000") + assert asset.to_atomic_amount(Decimal("0.5")) == Decimal("500000000000000000") + + +def test_asset_str_representation(asset): + """Test asset string representation.""" + expected_str = ( + "Asset: (asset_id: eth, network_id: ethereum-goerli, contract_address: None, decimals: 18)" + ) + assert str(asset) == expected_str + + +def test_asset_repr(asset): + """Test asset repr.""" + expected_repr = ( + "Asset: (asset_id: eth, network_id: ethereum-goerli, contract_address: None, decimals: 18)" + ) + assert repr(asset) == expected_repr diff --git a/tests/test_balance.py b/tests/test_balance.py new file mode 100644 index 0000000..602f7bb --- /dev/null +++ b/tests/test_balance.py @@ -0,0 +1,99 @@ +from decimal import Decimal + +import pytest + +from cdp.asset import Asset +from cdp.balance import Balance +from cdp.client.models.asset import Asset as AssetModel +from cdp.client.models.balance import Balance as BalanceModel + + +@pytest.fixture +def asset_model(): + """Fixture for asset model.""" + return AssetModel(network_id="ethereum-goerli", asset_id="eth", decimals=18) + + +@pytest.fixture +def asset(asset_model): + """Fixture for asset.""" + return Asset.from_model(asset_model) + + +@pytest.fixture +def balance_model(asset_model): + """Fixture for balance model.""" + return BalanceModel(amount="1000000000000000000", asset=asset_model) + + +@pytest.fixture +def balance(asset): + """Fixture for balance.""" + return Balance(Decimal("1.5"), asset) + + +def test_balance_initialization(asset): + """Test balance initialization.""" + balance = Balance(Decimal("1"), asset) + assert balance.amount == Decimal("1") + assert balance.asset == asset + assert balance.asset_id == "eth" + + +def test_balance_initialization_with_asset_id(asset): + """Test balance initialization with asset ID.""" + balance = Balance(Decimal("10"), asset, asset_id="gwei") + assert balance.amount == Decimal("10") + assert balance.asset == asset + assert balance.asset_id == "gwei" + + +def test_balance_from_model(balance_model): + """Test balance from model.""" + balance = Balance.from_model(balance_model) + assert balance.amount == Decimal("1") + assert isinstance(balance.asset, Asset) + assert balance.asset.asset_id == "eth" + assert balance.asset_id == "eth" + + +def test_balance_from_model_with_asset_id(balance_model): + """Test balance from model with asset ID.""" + balance = Balance.from_model(balance_model, asset_id="gwei") + assert balance.amount == Decimal("1000000000") + assert isinstance(balance.asset, Asset) + assert balance.asset.asset_id == "eth" + assert balance.asset_id == "gwei" + + +def test_balance_amount(balance): + """Test balance amount.""" + assert balance.amount == Decimal("1.5") + + +def test_balance_asset(balance, asset): + """Test balance asset.""" + assert balance.asset == asset + + +def test_balance_asset_id(balance, asset): + """Test balance asset ID.""" + assert balance.asset_id == asset.asset_id + + +def test_balance_str_representation(asset): + """Test balance string representation.""" + balance = Balance(Decimal("1.5"), asset) + assert ( + str(balance) + == "Balance: (amount: 1.5, asset: Asset: (asset_id: eth, network_id: ethereum-goerli, contract_address: None, decimals: 18))" + ) + + +def test_balance_repr(asset): + """Test balance repr.""" + balance = Balance(Decimal("1.5"), asset) + assert ( + repr(balance) + == "Balance: (amount: 1.5, asset: Asset: (asset_id: eth, network_id: ethereum-goerli, contract_address: None, decimals: 18))" + ) diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..a6411e9 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,140 @@ +import json +from decimal import Decimal + +import pytest + +from cdp.client.exceptions import ApiException +from cdp.errors import ( + ERROR_CODE_TO_ERROR_CLASS, + AddressCannotSignError, + AlreadySignedError, + APIError, + InsufficientFundsError, + InvalidAPIKeyFormatError, + InvalidConfigurationError, + TransactionNotSignedError, +) + + +def test_api_error_init(): + """Test API error initialization.""" + err = ApiException(400, "Bad Request") + api_error = APIError(err, code="test_code", message="Test message") + + assert api_error.http_code == 400 + assert api_error.api_code == "test_code" + assert api_error.api_message == "Test message" + assert api_error.handled is True + + +def test_api_error_from_error_with_valid_json(): + """Test API error from error with valid JSON.""" + err = ApiException(400, "Bad Request") + err.body = json.dumps({"code": "invalid_wallet_id", "message": "Invalid wallet ID"}) + api_error = APIError.from_error(err) + + assert isinstance(api_error, ERROR_CODE_TO_ERROR_CLASS["invalid_wallet_id"]) + assert api_error.api_code == "invalid_wallet_id" + assert api_error.api_message == "Invalid wallet ID" + + +def test_api_error_from_error_with_invalid_json(): + """Test API error from error with invalid JSON.""" + err = ApiException(400, "Bad Request") + err.body = "Invalid JSON" + api_error = APIError.from_error(err) + + assert isinstance(api_error, APIError) + assert api_error.api_code is None + assert api_error.api_message is None + + +def test_api_error_from_error_with_unknown_code(): + """Test API error from error with unknown code.""" + err = ApiException(400, "Bad Request") + err.body = json.dumps({"code": "unknown_code", "message": "Unknown error"}) + api_error = APIError.from_error(err) + + assert isinstance(api_error, APIError) + assert api_error.api_code == "unknown_code" + assert api_error.api_message == "Unknown error" + assert api_error.handled is False + + +def test_api_error_str_representation(): + """Test API error string representation.""" + err = ApiException(400, "Bad Request") + api_error = APIError(err, code="test_code", message="Test message") + + assert str(api_error) == "APIError(http_code=400, api_code=test_code, api_message=Test message)" + + +def test_invalid_configuration_error(): + """Test invalid configuration error.""" + with pytest.raises(InvalidConfigurationError, match="Custom configuration error"): + raise InvalidConfigurationError("Custom configuration error") + + +def test_invalid_api_key_format_error(): + """Test invalid API key format error.""" + with pytest.raises(InvalidAPIKeyFormatError, match="Invalid API key format"): + raise InvalidAPIKeyFormatError() + + +def test_insufficient_funds_error(): + """Test insufficient funds error.""" + with pytest.raises(InsufficientFundsError, match="Insufficient funds: have 50, need 100"): + raise InsufficientFundsError(Decimal(100), Decimal(50)) + + +def test_already_signed_error(): + """Test already signed error.""" + with pytest.raises(AlreadySignedError, match="Resource already signed"): + raise AlreadySignedError() + + +def test_transaction_not_signed_error(): + """Test transaction not signed error.""" + with pytest.raises(TransactionNotSignedError, match="Transaction must be signed"): + raise TransactionNotSignedError() + + +def test_address_cannot_sign_error(): + """Test address cannot sign error.""" + with pytest.raises( + AddressCannotSignError, match="Address cannot sign transaction without private key loaded" + ): + raise AddressCannotSignError() + + +@pytest.mark.parametrize( + "error_code, expected_class", + [ + ("unimplemented", "UnimplementedError"), + ("unauthorized", "UnauthorizedError"), + ("internal", "InternalError"), + ("not_found", "NotFoundError"), + ("invalid_wallet_id", "InvalidWalletIDError"), + ("invalid_address_id", "InvalidAddressIDError"), + ("invalid_wallet", "InvalidWalletError"), + ("invalid_address", "InvalidAddressError"), + ("invalid_amount", "InvalidAmountError"), + ("invalid_transfer_id", "InvalidTransferIDError"), + ("invalid_page_token", "InvalidPageError"), + ("invalid_page_limit", "InvalidLimitError"), + ("already_exists", "AlreadyExistsError"), + ("malformed_request", "MalformedRequestError"), + ("unsupported_asset", "UnsupportedAssetError"), + ("invalid_asset_id", "InvalidAssetIDError"), + ("invalid_destination", "InvalidDestinationError"), + ("invalid_network_id", "InvalidNetworkIDError"), + ("resource_exhausted", "ResourceExhaustedError"), + ("faucet_limit_reached", "FaucetLimitReachedError"), + ("invalid_signed_payload", "InvalidSignedPayloadError"), + ("invalid_transfer_status", "InvalidTransferStatusError"), + ("network_feature_unsupported", "NetworkFeatureUnsupportedError"), + ], +) +def test_error_code_mapping(error_code, expected_class): + """Test error code mapping.""" + assert ERROR_CODE_TO_ERROR_CLASS[error_code].__name__ == expected_class diff --git a/tests/test_transaction.py b/tests/test_transaction.py new file mode 100644 index 0000000..268ff57 --- /dev/null +++ b/tests/test_transaction.py @@ -0,0 +1,150 @@ +import pytest + +from cdp.client.models.transaction import Transaction as TransactionModel +from cdp.transaction import Transaction + + +@pytest.fixture +def transaction_model(): + """Fixture for a transaction model.""" + return TransactionModel( + network_id="ethereum-goerli", + transaction_hash="0xtransactionhash", + from_address_id="0xfromaddressid", + to_address_id="0xtoaddressid", + unsigned_payload="0xunsignedpayload", + signed_payload="0xsignedpayload", + status="complete", + block_hash="0xblockhash", + block_height="123456", + transaction_link="https://basescan.org/tx/0xtransactionlink", + ) + + +@pytest.fixture +def unsigned_transaction_model(): + """Fixture for an unsigned transaction model.""" + return TransactionModel( + network_id="ethereum-goerli", + from_address_id="0xfromaddressid", + to_address_id="0xtoaddressid", + unsigned_payload="0xunsignedpayload", + status="pending", + ) + + +@pytest.fixture +def transaction(transaction_model): + """Fixture for a transaction.""" + return Transaction(transaction_model) + + +@pytest.fixture +def unsigned_transaction(unsigned_transaction_model): + """Fixture for an unsigned transaction.""" + return Transaction(unsigned_transaction_model) + + +def test_transaction_initialization(transaction): + """Test transaction initialization.""" + assert isinstance(transaction._model, TransactionModel) + assert transaction._raw is None + assert transaction._signature == "0xsignedpayload" + + +def test_transaction_initialization_invalid_model(): + """Test transaction initialization with an invalid model.""" + with pytest.raises(TypeError, match="model must be of type TransactionModel"): + Transaction("invalid_model") + + +def test_unsigned_payload(transaction): + """Test unsigned payload.""" + assert transaction.unsigned_payload == "0xunsignedpayload" + + +def test_signed_payload(transaction): + """Test signed payload.""" + assert transaction.signed_payload == "0xsignedpayload" + + +def test_transaction_hash(transaction): + """Test transaction hash.""" + assert transaction.transaction_hash == "0xtransactionhash" + + +@pytest.mark.parametrize( + "status, expected_status", + [ + ("pending", Transaction.Status.PENDING), + ("signed", Transaction.Status.SIGNED), + ("broadcast", Transaction.Status.BROADCAST), + ("complete", Transaction.Status.COMPLETE), + ("failed", Transaction.Status.FAILED), + ("unspecified", Transaction.Status.UNSPECIFIED), + ], +) +def test_status(transaction, status, expected_status): + """Test transaction status.""" + transaction._model.status = status + assert transaction.status == expected_status + + +def test_from_address_id(transaction): + """Test from address ID.""" + assert transaction.from_address_id == "0xfromaddressid" + + +def test_to_address_id(transaction): + """Test to address ID.""" + assert transaction.to_address_id == "0xtoaddressid" + + +def test_terminal_state(unsigned_transaction, transaction): + """Test terminal state.""" + assert not unsigned_transaction.terminal_state + assert transaction.terminal_state + + +def test_block_hash(transaction): + """Test block hash.""" + assert transaction.block_hash == "0xblockhash" + + +def test_block_height(transaction): + """Test block height.""" + assert transaction.block_height == "123456" + + +def test_transaction_link(transaction): + """Test transaction link.""" + assert transaction.transaction_link == "https://basescan.org/tx/0xtransactionlink" + + +def test_signed(unsigned_transaction, transaction): + """Test signed.""" + assert not unsigned_transaction.signed + assert transaction.signed + + +def test_signature(transaction): + """Test signature.""" + assert transaction.signature == "0xsignedpayload" + + +def test_signature_not_signed(unsigned_transaction): + """Test signature not signed.""" + with pytest.raises(ValueError, match="Transaction is not signed"): + signature = unsigned_transaction.signature + + +def test_str_representation(transaction): + """Test string representation.""" + expected_str = "Transaction{transaction_hash: '0xtransactionhash', status: 'complete'}" + assert str(transaction) == expected_str + + +def test_repr(transaction): + """Test repr.""" + expected_repr = "Transaction{transaction_hash: '0xtransactionhash', status: 'complete'}" + assert repr(transaction) == expected_repr