diff --git a/django_prices/models.py b/django_prices/models.py index 48a2de5..131bc9c 100644 --- a/django_prices/models.py +++ b/django_prices/models.py @@ -15,6 +15,11 @@ class MoneyField(models.DecimalField): def __init__(self, verbose_name=None, currency=None, **kwargs): self.currency = currency + if (isinstance(kwargs.get('default'), Money) and + kwargs['default'].currency != self.currency): + raise ValueError( + 'Invalid currency for default value: %r (expected %r)' % ( + kwargs['default'].currency, self.currency)) super(MoneyField, self).__init__(verbose_name, **kwargs) def from_db_value(self, value, expression, connection, context): @@ -67,6 +72,8 @@ def validators(self): def deconstruct(self): name, path, args, kwargs = super(MoneyField, self).deconstruct() kwargs['currency'] = self.currency + if isinstance(kwargs.get('default'), Money): + kwargs['default'] = kwargs['default'].amount return name, path, args, kwargs diff --git a/tests/models.py b/tests/models.py index a86fed5..f946dc7 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,5 +1,6 @@ from django.db import models from django_prices.models import MoneyField, TaxedMoneyField +from prices import Money class Model(models.Model): @@ -10,3 +11,9 @@ class Model(models.Model): 'gross', currency='BTC', default='5', max_digits=9, decimal_places=2) price = TaxedMoneyField(net_field='price_net', gross_field='price_gross') + + +class ModelWithDefault(models.Model): + price = MoneyField( + currency='USD', default=Money('0', 'USD'), max_digits=9, + decimal_places=2) diff --git a/tests/test_prices.py b/tests/test_prices.py index 15bb5a3..7d619f7 100644 --- a/tests/test_prices.py +++ b/tests/test_prices.py @@ -6,7 +6,7 @@ import pytest from django.core.exceptions import ValidationError -from django.db import connection +from django.db import connection, models from django.utils import translation from prices import Money, TaxedMoney, percentage_discount @@ -19,7 +19,7 @@ from .forms import ( ModelForm, OptionalPriceForm, RequiredPriceForm, ValidatedPriceForm) -from .models import Model +from .models import Model, ModelWithDefault @pytest.fixture(scope='module') @@ -197,6 +197,27 @@ def test_combined_field_validation(): instance.full_clean() +def test_money_field_default_money(): + instance = ModelWithDefault() + assert instance.price == Money(0, 'USD') + + +def test_money_field_default_money_deconstruct(): + instance = ModelWithDefault() + _, _, _, kwargs = instance._meta.get_field('price').deconstruct() + assert kwargs == { + 'currency': 'USD', 'default': Decimal(0), 'max_digits': 9, + 'decimal_places': 2} + + +def test_money_field_default_money_invalid_currency(): + with pytest.raises(ValueError): + class InvalidModel(models.Model): + price = MoneyField( + currency='USD', default=Money('0', 'BTC'), max_digits=9, + decimal_places=2) + + def test_field_passes_all_validations(): form = RequiredPriceForm(data={'price_net': '20'}) form.full_clean()