diff --git a/tests/apiv2/test_countries.py b/tests/apiv2/test_countries.py new file mode 100644 index 00000000..65bee73f --- /dev/null +++ b/tests/apiv2/test_countries.py @@ -0,0 +1,80 @@ +from tests.util import APITestCase +from tracker import models +from tracker.api.serializers import CountryRegionSerializer, CountrySerializer + + +class TestCountry(APITestCase): + serializer_class = CountrySerializer + model_name = 'country' + lookup_key = 'numeric_or_alpha' + id_field = 'alpha2' + + def test_fetch(self): + with self.saveSnapshot(): + country = models.Country.objects.first() + data = self.get_list() + self.assertEqual( + data['count'], + models.Country.objects.count(), + msg='Country count did not match', + ) + self.assertV2ModelPresent(country, data['results']) + + with self.subTest('via numeric code'): + data = self.get_detail( + country, kwargs={'numeric_or_alpha': country.numeric} + ) + self.assertV2ModelPresent(country, data) + + with self.subTest('via alpha2'): + data = self.get_detail( + country, kwargs={'numeric_or_alpha': country.alpha2} + ) + self.assertV2ModelPresent(country, data) + + with self.subTest('via alpha3'): + data = self.get_detail( + country, kwargs={'numeric_or_alpha': country.alpha3} + ) + self.assertV2ModelPresent(country, data) + + with self.subTest('error cases'): + self.get_detail(None, kwargs={'numeric_or_alpha': '00'}, status_code=404) + self.get_detail(None, kwargs={'numeric_or_alpha': '000'}, status_code=404) + self.get_detail(None, kwargs={'numeric_or_alpha': 'XX'}, status_code=404) + self.get_detail(None, kwargs={'numeric_or_alpha': 'XXX'}, status_code=404) + self.get_detail( + None, kwargs={'numeric_or_alpha': 'foobar'}, status_code=404 + ) + + +class TestCountryRegions(APITestCase): + serializer_class = CountryRegionSerializer + model_name = 'countryregion' + + def test_fetch(self): + region = models.CountryRegion.objects.create( + name='Test Region', country=models.Country.objects.first() + ) + + with self.saveSnapshot(): + data = self.get_list(model_name='region') + self.assertEqual( + data['count'], + models.CountryRegion.objects.count(), + msg='Region count did not match', + ) + self.assertV2ModelPresent(region, data['results']) + + data = self.get_detail(region, model_name='region') + self.assertV2ModelPresent(region, data) + + with self.subTest('via country'): + data = self.get_noun( + 'regions', + region.country, + lookup_key='numeric_or_alpha', + model_name='country', + kwargs={'numeric_or_alpha': region.country.alpha3}, + ) + self.assertV2ModelPresent(region, data['results']) diff --git a/tests/util.py b/tests/util.py index 0b5090f1..7c3a193e 100644 --- a/tests/util.py +++ b/tests/util.py @@ -228,7 +228,9 @@ class APITestCase(TransactionTestCase, AssertionHelpers): view_user_permissions = [] # trickles to add_user and locked_user add_user_permissions = [] # trickles to locked_user locked_user_permissions = [] + lookup_key = 'pk' encoder = DjangoJSONEncoder() + id_field = 'id' def parseJSON(self, response, status_code=200): self.assertEqual( @@ -278,11 +280,15 @@ def get_detail( self.client.force_authenticate(user=other_kwargs['user']) model_name = model_name or self.model_name assert model_name is not None - pk = obj if isinstance(obj, int) else obj.pk + lookup_kwargs = {**kwargs} + if self.lookup_key == 'pk': + pk = obj if isinstance(obj, int) else obj.pk + lookup_kwargs['pk'] = pk url = reverse( self._get_viewname(model_name, 'detail', **kwargs), - kwargs={'pk': pk, **kwargs}, + kwargs=lookup_kwargs, ) + with self._snapshot('GET', url, data) as snapshot: response = self.client.get( url, @@ -336,10 +342,13 @@ def get_noun( status_code=200, data=None, kwargs=None, + lookup_key=None, **other_kwargs, ): kwargs = kwargs or {} - if obj is not None: + if lookup_key is None: + lookup_key = self.lookup_key + if obj is not None and lookup_key == 'pk': kwargs['pk'] = obj.pk if 'user' in other_kwargs: self.client.force_authenticate(user=other_kwargs['user']) @@ -720,14 +729,14 @@ def assertV2ModelPresent( m for m in data if expected_model['type'] == m.get('type', None) - and expected_model['id'] == m.get('id', None) + and expected_model[self.id_field] == m.get(self.id_field, None) ), None, ) ) is None: self.fail( 'Could not find model "%s:%s" in data' - % (expected_model['type'], expected_model['id']) + % (expected_model['type'], expected_model[self.id_field]) ) problems = self._compare_model( expected_model, found_model, partial, missing_ok=missing_ok @@ -738,7 +747,7 @@ def assertV2ModelPresent( % ( f'{msg}\n' if msg else '', expected_model['type'], - expected_model['id'], + expected_model[self.id_field], '\n'.join(problems), ) ) @@ -759,8 +768,8 @@ def assertV2ModelNotPresent(self, unexpected_model, data): model for model in data if ( - model.get('id', None) == unexpected_model['id'] - and model.get('type', None) == unexpected_model['type'] + unexpected_model['type'] == model['type'] + and unexpected_model[self.id_field] == model[self.id_field] ) ), None, @@ -769,7 +778,7 @@ def assertV2ModelNotPresent(self, unexpected_model, data): ): self.fail( 'Found model "%s:%s" in data' - % (unexpected_model['type'], unexpected_model['id']) + % (unexpected_model['type'], unexpected_model[self.id_field]) ) def assertExactV2Models( @@ -902,6 +911,7 @@ def _snapshot(self, method, url, data): self._snapshot_num = 1 # obscure ids from url since they can drift depending on test order/results, remove leading tracker since it's redundant, and slugify everything else + # FIXME: this doesn't quite work for Country since we don't use PK lookups in the urls pieces += [ f'S{self._snapshot_num}', re.sub( diff --git a/tracker/api/serializers.py b/tracker/api/serializers.py index feecbaa1..7101a9c3 100644 --- a/tracker/api/serializers.py +++ b/tracker/api/serializers.py @@ -17,10 +17,11 @@ from rest_framework.validators import UniqueTogetherValidator from tracker.api import messages -from tracker.models import Ad, Interstitial, Interview from tracker.models.bid import Bid, DonationBid +from tracker.models.country import Country, CountryRegion from tracker.models.donation import Donation, Donor, Milestone from tracker.models.event import Event, SpeedRun, Tag, Talent, VideoLink, VideoLinkType +from tracker.models.interstitial import Ad, Interstitial, Interview log = logging.getLogger(__name__) @@ -298,6 +299,46 @@ def to_representation(self, obj): return obj.__class__.__name__.lower() +class CountrySerializer(PrimaryOrNaturalKeyLookup, TrackerModelSerializer): + type = ClassNameField() + + class Meta: + model = Country + fields = ( + 'type', + 'name', + 'alpha2', + 'alpha3', + 'numeric', + ) + + def to_representation(self, instance): + if self.root == self or getattr(self.root, 'child', None) == self: + return super().to_representation(instance) + else: + return instance.alpha3 + + +class CountryRegionSerializer(PrimaryOrNaturalKeyLookup, TrackerModelSerializer): + type = ClassNameField() + country = CountrySerializer() + + class Meta: + model = CountryRegion + fields = ( + 'type', + 'id', + 'name', + 'country', + ) + + def to_representation(self, instance): + if self.root == self or getattr(self.root, 'child', None) == self: + return super().to_representation(instance) + else: + return [instance.name, instance.country.alpha3] + + class EventNestedSerializerMixin: event_move = False @@ -582,6 +623,9 @@ def get_donor_name(self, donation: Donation): class EventSerializer(PrimaryOrNaturalKeyLookup, TrackerModelSerializer): type = ClassNameField() + # include these later + # allowed_prize_countries = CountrySerializer(many=True) + # disallowed_prize_regions = CountryRegionSerializer(many=True) timezone = serializers.SerializerMethodField() amount = serializers.SerializerMethodField() donation_count = serializers.SerializerMethodField() @@ -603,6 +647,8 @@ class Meta: 'datetime', 'timezone', 'use_one_step_screening', + # 'allowed_prize_countries', + # 'disallowed_prize_regions', ) def get_fields(self): diff --git a/tracker/api/urls.py b/tracker/api/urls.py index bf3cc26e..dddf9dbf 100644 --- a/tracker/api/urls.py +++ b/tracker/api/urls.py @@ -4,7 +4,17 @@ from rest_framework import routers from tracker.api import views -from tracker.api.views import ad, bids, donations, interview, me, milestone, run, talent +from tracker.api.views import ( + ad, + bids, + country, + donations, + interview, + me, + milestone, + run, + talent, +) router = routers.DefaultRouter() @@ -34,6 +44,8 @@ def event_nested_route(path, viewset, *, basename=None, feed=False): event_nested_route(r'milestones', milestone.MilestoneViewSet) router.register(r'donations', donations.DonationViewSet, basename='donations') router.register(r'me', me.MeViewSet, basename='me') +router.register(r'countries', country.CountryViewSet) +router.register(r'regions', country.CountryRegionViewSet, basename='region') # use the router-generated URLs, and also link to the browsable API urlpatterns = [ diff --git a/tracker/api/views/country.py b/tracker/api/views/country.py new file mode 100644 index 00000000..c613bc9c --- /dev/null +++ b/tracker/api/views/country.py @@ -0,0 +1,51 @@ +import re + +from django.db.models import Q +from rest_framework.decorators import action +from rest_framework.exceptions import NotFound +from rest_framework.generics import get_object_or_404 + +from tracker.api.pagination import TrackerPagination +from tracker.api.serializers import CountryRegionSerializer, CountrySerializer +from tracker.api.views import TrackerReadViewSet +from tracker.models import Country, CountryRegion + + +class CountryViewSet(TrackerReadViewSet): + serializer_class = CountrySerializer + pagination_class = TrackerPagination + queryset = Country.objects.all() + lookup_field = 'numeric_or_alpha' + + def get_object(self): + queryset = self.get_queryset() + pk = self.kwargs['numeric_or_alpha'] + if re.match('[0-9]{3}', pk): + return get_object_or_404(queryset, numeric=pk) + elif re.match('[A-Z]{2,3}', pk): + return get_object_or_404(queryset, Q(alpha2=pk) | Q(alpha3=pk)) + raise NotFound( + detail='Provide either an ISO 3166-1 numeric, alpha2, or alpha3 code', + code='invalid_lookup', + ) + + @action(detail=True) + def regions(self, request, *args, **kwargs): + viewset = CountryRegionViewSet(request=request, country=self.get_object()) + viewset.initial(request, *args, **kwargs) + return viewset.list(request, *args, **kwargs) + + +class CountryRegionViewSet(TrackerReadViewSet): + serializer_class = CountryRegionSerializer + pagination_class = TrackerPagination + queryset = CountryRegion.objects.select_related('country') + + def __init__(self, country=None, *args, **kwargs): + self.country = country + super().__init__(*args, **kwargs) + + def filter_queryset(self, queryset): + if self.country: + queryset = queryset.filter(country=self.country) + return queryset