Skip to content

Commit

Permalink
unit tests for asset, balance, errors, & transaction + more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
John-peterson-coinbase committed Sep 19, 2024
1 parent 2a99c99 commit 6b3ba38
Show file tree
Hide file tree
Showing 11 changed files with 596 additions and 5 deletions.
68 changes: 64 additions & 4 deletions cdp/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@


class APIError(Exception):
"""A wrapper for API exceptions to provide more context."""

def __init__(
self,
err: ApiException,
Expand All @@ -20,6 +22,18 @@ def __init__(

@classmethod
def from_error(cls, err: ApiException) -> "APIError":
"""Create an APIError from an ApiException.
Args:
err (ApiException): The ApiException to create an APIError from.
Returns:
APIError: The APIError.
Raises:
ValueError: If the argument is not an ApiException.
"""
if not isinstance(err, ApiException):
raise ValueError("argument must be an ApiException")

Expand All @@ -41,21 +55,51 @@ def from_error(cls, err: ApiException) -> "APIError":

@property
def http_code(self) -> int:
"""Get the HTTP status code.
Returns:
int: The HTTP status code.
"""
return self._http_code

@property
def api_code(self) -> str | None:
"""Get the API error code.
Returns:
str | None: The API error code.
"""
return self._api_code

@property
def api_message(self) -> str | None:
"""Get the API error message.
Returns:
str | None: The API error message.
"""
return self._api_message

@property
def handled(self) -> bool:
"""Get whether the error is handled.
Returns:
bool: True if the error is handled, False otherwise.
"""
return self._handled

def __str__(self) -> str:
"""Get a string representation of the APIError.
Returns:
str: The string representation of the APIError.
"""
if self.handled:
return f"APIError(http_code={self.http_code}, api_code={self.api_code}, api_message={self.api_message})"
else:
Expand All @@ -66,6 +110,12 @@ class InvalidConfigurationError(Exception):
"""Exception raised for errors in the configuration of the Coinbase SDK."""

def __init__(self, message: str = "Invalid configuration provided") -> None:
"""Initialize the InvalidConfigurationError.
Args:
message (str): The error message.
"""
self.message = message
super().__init__(self.message)

Expand All @@ -74,6 +124,12 @@ class InvalidAPIKeyFormatError(Exception):
"""Exception raised for errors in the format of the API key."""

def __init__(self, message: str = "Invalid API key format") -> None:
"""Initialize the InvalidAPIKeyFormatError.
Args:
message (str): The error message.
"""
self.message = message
super().__init__(self.message)

Expand All @@ -90,7 +146,8 @@ def __init__(self, expected: Decimal, exact: Decimal, msg: str = "Insufficient f
msg (str): The error message prefix.
"""
super().__init__(f"{msg}: have {exact}, need {expected}.")
self.message = f"{msg}: have {exact}, need {expected}."
super().__init__(self.message)


class AlreadySignedError(Exception):
Expand All @@ -103,7 +160,8 @@ def __init__(self, msg: str = "Resource already signed") -> None:
msg (str): The error message.
"""
super().__init__(msg)
self.message = msg
super().__init__(self.message)


class TransactionNotSignedError(Exception):
Expand All @@ -116,7 +174,8 @@ def __init__(self, msg: str = "Transaction must be signed") -> None:
msg (str): The error message.
"""
super().__init__(msg)
self.message = msg
super().__init__(self.message)


class AddressCannotSignError(Exception):
Expand All @@ -131,7 +190,8 @@ def __init__(
msg (str): The error message.
"""
super().__init__(msg)
self.message = msg
super().__init__(self.message)


class UnimplementedError(APIError):
Expand Down
6 changes: 6 additions & 0 deletions cdp/sponsored_send.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ class Status(Enum):

@classmethod
def terminal_states(cls):
"""Get the terminal states.
Returns:
List[str]: The terminal states.
"""
return [cls.COMPLETE, cls.FAILED]

def __str__(self) -> str:
Expand Down
9 changes: 9 additions & 0 deletions cdp/trade.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ def list(wallet_id: str, address_id: str) -> list["Trade"]:
return [Trade(model) for model in models.data]

def broadcast(self) -> "Trade":
"""Broadcast the trade.
Returns:
Trade: The broadcasted trade.
Raises:
TransactionNotSignedError: If the trade is not signed.
"""
if not all(self.transaction.signed, self.approve_transaction.signed):
raise TransactionNotSignedError("Trade is not signed")

Expand Down
6 changes: 6 additions & 0 deletions cdp/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class Status(Enum):

@classmethod
def terminal_states(cls):
"""Get the terminal states.
Returns:
List[str]: The terminal states.
"""
return [cls.COMPLETE, cls.FAILED]

def __str__(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions cdp/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,4 +312,5 @@ def __str__(self) -> str:
)

def __repr__(self) -> str:
"""Get a string representation of the Transfer."""
return str(self)
2 changes: 1 addition & 1 deletion cdp/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def save_seed(self, file_path: str, encrypt: bool | None = False) -> None:
ValueError: If the wallet does not have a seed loaded.
"""
if self._master is None or self._seed == None:
if self._master is None or self._seed is None:
raise ValueError("Wallet does not have seed loaded")

key = self._encryption_key()
Expand Down
Empty file added tests/__init__.py
Empty file.
120 changes: 120 additions & 0 deletions tests/test_asset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from decimal import Decimal
from unittest.mock import Mock, patch

import pytest

from cdp.asset import Asset
from cdp.client.models.asset import Asset as AssetModel


@pytest.fixture
def asset_model():
"""Fixture for asset model."""
return AssetModel(network_id="ethereum-goerli", asset_id="eth", decimals=18)


@pytest.fixture
def asset(asset_model):
"""Fixture for asset."""
return Asset.from_model(asset_model)


def test_asset_initialization(asset):
"""Test asset initialization."""
assert asset.network_id == "ethereum-goerli"
assert asset.asset_id == "eth"
assert asset.decimals == 18


def test_asset_from_model(asset_model):
"""Test asset from model."""
asset = Asset.from_model(asset_model)
assert isinstance(asset, Asset)
assert asset.network_id == asset_model.network_id
assert asset.asset_id == asset_model.asset_id
assert asset.decimals == asset_model.decimals


def test_asset_from_model_with_gwei(asset_model):
"""Test asset from model with gwei."""
asset = Asset.from_model(asset_model, asset_id="gwei")
assert asset.decimals == 9


def test_asset_from_model_with_wei(asset_model):
"""Test asset from model with wei."""
asset = Asset.from_model(asset_model, asset_id="wei")
assert asset.decimals == 0


def test_asset_from_model_with_invalid_asset_id(asset_model):
"""Test asset from model with invalid asset ID."""
with pytest.raises(ValueError, match="Unsupported asset ID: invalid"):
Asset.from_model(asset_model, asset_id="invalid")


@patch("cdp.Cdp.api_clients")
def test_asset_fetch(mock_api_clients, asset_model):
"""Test asset fetch."""
mock_get_asset = Mock()
mock_get_asset.return_value = asset_model
mock_api_clients.assets.get_asset = mock_get_asset

asset = Asset.fetch("ethereum-goerli", "eth")
assert isinstance(asset, Asset)
assert asset.network_id == "ethereum-goerli"
assert asset.asset_id == "eth"
mock_get_asset.assert_called_once_with(network_id="ethereum-goerli", asset_id="eth")


@patch("cdp.Cdp.api_clients")
def test_asset_fetch_api_error(mock_api_clients):
"""Test asset fetch API error."""
mock_get_asset = Mock()
mock_get_asset.side_effect = Exception("API error")
mock_api_clients.assets.get_asset = mock_get_asset

with pytest.raises(Exception, match="API error"):
Asset.fetch("ethereum-goerli", "eth")


@pytest.mark.parametrize(
"input_asset_id, expected_output",
[
("eth", "eth"),
("wei", "eth"),
("gwei", "eth"),
("usdc", "usdc"),
],
)
def test_primary_denomination(input_asset_id, expected_output):
"""Test primary denomination."""
assert Asset.primary_denomination(input_asset_id) == expected_output


def test_from_atomic_amount(asset):
"""Test from atomic amount."""
assert asset.from_atomic_amount(Decimal("1000000000000000000")) == Decimal("1")
assert asset.from_atomic_amount(Decimal("500000000000000000")) == Decimal("0.5")


def test_to_atomic_amount(asset):
"""Test to atomic amount."""
assert asset.to_atomic_amount(Decimal("1")) == Decimal("1000000000000000000")
assert asset.to_atomic_amount(Decimal("0.5")) == Decimal("500000000000000000")


def test_asset_str_representation(asset):
"""Test asset string representation."""
expected_str = (
"Asset: (asset_id: eth, network_id: ethereum-goerli, contract_address: None, decimals: 18)"
)
assert str(asset) == expected_str


def test_asset_repr(asset):
"""Test asset repr."""
expected_repr = (
"Asset: (asset_id: eth, network_id: ethereum-goerli, contract_address: None, decimals: 18)"
)
assert repr(asset) == expected_repr
Loading

0 comments on commit 6b3ba38

Please sign in to comment.