Skip to content

Commit

Permalink
fix: Support extracting generic bound from TypeVar (#88)
Browse files Browse the repository at this point in the history
Adds support for Pydantic models that are generic with a currency type.

Fixes #87.
  • Loading branch information
antonagestam authored Aug 18, 2024
1 parent 9e0ff83 commit 041fcf9
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 8 deletions.
14 changes: 8 additions & 6 deletions goose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ environments:
- prettier

hooks:
- id: check-manifest
environment: check-manifest
command: check-manifest
parameterize: false
read_only: true
args: [--no-build-isolation]
# Commented out until there's a fix for pinning setuptools.
# https://github.com/antonagestam/goose/issues/30
# - id: check-manifest
# environment: check-manifest
# command: check-manifest
# parameterize: false
# read_only: true
# args: [--no-build-isolation]

- id: prettier
environment: node
Expand Down
11 changes: 9 additions & 2 deletions src/immoney/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,16 @@ class OverdraftDict(TypedDict):

def extract_currency_type_arg(source_type: type) -> type[Currency]:
match get_args(source_type):
case (type() as currency_type,):
assert issubclass(currency_type, Currency)
case (type() as currency_type,) if issubclass(currency_type, Currency):
return currency_type
# TypeVar with Currency bound.
case (TypeVar(__bound__=type() as currency_type),) if issubclass(
currency_type, Currency
):
return currency_type
# TypeVar without bound.
case (TypeVar(__bound__=None),):
return Currency
case invalid: # pragma: no cover
raise TypeError(f"Invalid type args: {invalid!r}.")

Expand Down
196 changes: 196 additions & 0 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
from fractions import Fraction
from typing import Generic
from typing import TypeVar

import pytest
from pydantic import BaseModel
Expand Down Expand Up @@ -263,6 +265,200 @@ def test_can_generate_schema(self) -> None:
}


C_bounded = TypeVar("C_bounded", bound=Currency)


class BoundedGenericMoneyModel(BaseModel, Generic[C_bounded]):
money: Money[C_bounded]


class TestBoundedGenericMoneyModel:
@pytest.mark.parametrize(
("subunits", "currency_code", "expected"),
(
(4990, "USD", USD("49.90")),
(4990, "EUR", EUR("49.90")),
(0, "NOK", NOK(0)),
),
)
def test_can_roundtrip_valid_data(
self,
subunits: int,
currency_code: str,
expected: Money[C_bounded],
) -> None:
data = {
"money": {
"subunits": subunits,
"currency": currency_code,
}
}

instance = BoundedGenericMoneyModel[C_bounded].model_validate(data)
assert instance.money == expected
assert json.loads(instance.model_dump_json()) == data

def test_parsing_raises_validation_error_for_negative_value(self) -> None:
with pytest.raises(
ValidationError,
match=r"Input should be greater than or equal to 0",
):
BoundedGenericMoneyModel.model_validate(
{
"money": {
"currency": "EUR",
"subunits": -1,
},
}
)

def test_parsing_raises_validation_error_for_invalid_currency(self) -> None:
with pytest.raises(
ValidationError,
match=r"Input should be.*\[type=literal_error",
):
BoundedGenericMoneyModel.model_validate(
{
"money": {
"currency": "JCN",
"subunits": 4990,
},
}
)

def test_can_instantiate_valid_value(self) -> None:
instance = BoundedGenericMoneyModel(money=USD("49.90"))
assert instance.money == USD("49.90")

def test_instantiation_raises_validation_error_for_invalid_currency(self) -> None:
with pytest.raises(ValidationError, match=r"Currency is not registered"):
BoundedGenericMoneyModel(money=JCN(1))

def test_can_generate_schema(self) -> None:
assert BoundedGenericMoneyModel.model_json_schema() == {
"properties": {
"money": {
"properties": {
"currency": {
"enum": sorted_items_equal(default_registry.keys()),
"title": "Currency",
"type": "string",
},
"subunits": {
"minimum": 0,
"title": "Subunits",
"type": "integer",
},
},
"required": sorted_items_equal(["subunits", "currency"]),
"title": "Money",
"type": "object",
},
},
"required": ["money"],
"title": BoundedGenericMoneyModel.__name__,
"type": "object",
}


C_unbound = TypeVar("C_unbound")


class UnboundGenericMoneyModel(BaseModel, Generic[C_unbound]):
# mypy rightfully errors here, demanding that the type var is bounded to
# Currency, but we still want to test this case.
money: Money[C_unbound] # type: ignore[type-var]


class TestUnboundGenericMoneyModel:
@pytest.mark.parametrize(
("subunits", "currency_code", "expected"),
(
(4990, "USD", USD("49.90")),
(4990, "EUR", EUR("49.90")),
(0, "NOK", NOK(0)),
),
)
def test_can_roundtrip_valid_data(
self,
subunits: int,
currency_code: str,
expected: Money[C_unbound], # type: ignore[type-var]
) -> None:
data = {
"money": {
"subunits": subunits,
"currency": currency_code,
}
}

instance = UnboundGenericMoneyModel[C_unbound].model_validate(data)
assert instance.money == expected
assert json.loads(instance.model_dump_json()) == data

def test_parsing_raises_validation_error_for_negative_value(self) -> None:
with pytest.raises(
ValidationError,
match=r"Input should be greater than or equal to 0",
):
UnboundGenericMoneyModel.model_validate(
{
"money": {
"currency": "EUR",
"subunits": -1,
},
}
)

def test_parsing_raises_validation_error_for_invalid_currency(self) -> None:
with pytest.raises(
ValidationError,
match=r"Input should be.*\[type=literal_error",
):
UnboundGenericMoneyModel.model_validate(
{
"money": {
"currency": "JCN",
"subunits": 4990,
},
}
)

def test_can_instantiate_valid_value(self) -> None:
instance = UnboundGenericMoneyModel(money=USD("49.90"))
assert instance.money == USD("49.90")

def test_instantiation_raises_validation_error_for_invalid_currency(self) -> None:
with pytest.raises(ValidationError, match=r"Currency is not registered"):
UnboundGenericMoneyModel(money=JCN(1))

def test_can_generate_schema(self) -> None:
assert UnboundGenericMoneyModel.model_json_schema() == {
"properties": {
"money": {
"properties": {
"currency": {
"enum": sorted_items_equal(default_registry.keys()),
"title": "Currency",
"type": "string",
},
"subunits": {
"minimum": 0,
"title": "Subunits",
"type": "integer",
},
},
"required": sorted_items_equal(["subunits", "currency"]),
"title": "Money",
"type": "object",
},
},
"required": ["money"],
"title": UnboundGenericMoneyModel.__name__,
"type": "object",
}


class SpecializedMoneyModel(BaseModel):
money: Money[USDType]

Expand Down

0 comments on commit 041fcf9

Please sign in to comment.