Skip to content

Commit

Permalink
add Country/Region to V2
Browse files Browse the repository at this point in the history
[#184870680]
  • Loading branch information
uraniumanchor committed Nov 12, 2024
1 parent fda6c87 commit eb3c605
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 12 deletions.
80 changes: 80 additions & 0 deletions tests/apiv2/test_countries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from tests.util import APITestCase
from tracker import models
from tracker.api.serializers import CountryRegionSerializer, CountrySerializer


class TestCountry(APITestCase):
serializer_class = CountrySerializer
model_name = 'country'
lookup_key = 'numeric_or_alpha'
id_field = 'alpha2'

def test_fetch(self):
with self.saveSnapshot():
country = models.Country.objects.first()
data = self.get_list()
self.assertEqual(
data['count'],
models.Country.objects.count(),
msg='Country count did not match',
)
self.assertV2ModelPresent(country, data['results'])

with self.subTest('via numeric code'):
data = self.get_detail(
country, kwargs={'numeric_or_alpha': country.numeric}
)
self.assertV2ModelPresent(country, data)

with self.subTest('via alpha2'):
data = self.get_detail(
country, kwargs={'numeric_or_alpha': country.alpha2}
)
self.assertV2ModelPresent(country, data)

with self.subTest('via alpha3'):
data = self.get_detail(
country, kwargs={'numeric_or_alpha': country.alpha3}
)
self.assertV2ModelPresent(country, data)

with self.subTest('error cases'):
self.get_detail(None, kwargs={'numeric_or_alpha': '00'}, status_code=404)
self.get_detail(None, kwargs={'numeric_or_alpha': '000'}, status_code=404)
self.get_detail(None, kwargs={'numeric_or_alpha': 'XX'}, status_code=404)
self.get_detail(None, kwargs={'numeric_or_alpha': 'XXX'}, status_code=404)
self.get_detail(
None, kwargs={'numeric_or_alpha': 'foobar'}, status_code=404
)


class TestCountryRegions(APITestCase):
serializer_class = CountryRegionSerializer
model_name = 'countryregion'

def test_fetch(self):
region = models.CountryRegion.objects.create(
name='Test Region', country=models.Country.objects.first()
)

with self.saveSnapshot():
data = self.get_list(model_name='region')
self.assertEqual(
data['count'],
models.CountryRegion.objects.count(),
msg='Region count did not match',
)
self.assertV2ModelPresent(region, data['results'])

data = self.get_detail(region, model_name='region')
self.assertV2ModelPresent(region, data)

with self.subTest('via country'):
data = self.get_noun(
'regions',
region.country,
lookup_key='numeric_or_alpha',
model_name='country',
kwargs={'numeric_or_alpha': region.country.alpha3},
)
self.assertV2ModelPresent(region, data['results'])
30 changes: 20 additions & 10 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ class APITestCase(TransactionTestCase, AssertionHelpers):
view_user_permissions = [] # trickles to add_user and locked_user
add_user_permissions = [] # trickles to locked_user
locked_user_permissions = []
lookup_key = 'pk'
encoder = DjangoJSONEncoder()
id_field = 'id'

def parseJSON(self, response, status_code=200):
self.assertEqual(
Expand Down Expand Up @@ -277,11 +279,15 @@ def get_detail(
self.client.force_authenticate(user=other_kwargs['user'])
model_name = model_name or self.model_name
assert model_name is not None
pk = obj if isinstance(obj, int) else obj.pk
lookup_kwargs = {**kwargs}
if self.lookup_key == 'pk':
pk = obj if isinstance(obj, int) else obj.pk
lookup_kwargs['pk'] = pk
url = reverse(
self._get_viewname(model_name, 'detail', **kwargs),
kwargs={'pk': pk, **kwargs},
kwargs=lookup_kwargs,
)

with self._snapshot('GET', url, data) as snapshot:
response = self.client.get(
url,
Expand Down Expand Up @@ -335,10 +341,13 @@ def get_noun(
status_code=200,
data=None,
kwargs=None,
lookup_key=None,
**other_kwargs,
):
kwargs = kwargs or {}
if obj is not None:
if lookup_key is None:
lookup_key = self.lookup_key
if obj is not None and lookup_key == 'pk':
kwargs['pk'] = obj.pk
if 'user' in other_kwargs:
self.client.force_authenticate(user=other_kwargs['user'])
Expand Down Expand Up @@ -717,15 +726,15 @@ def assertV2ModelPresent(
(
m
for m in data
if expected_model['type'] == m['type']
and expected_model['id'] == m['id']
if expected_model['type'] == m.get('type', None)
and expected_model[self.id_field] == m.get(self.id_field, None)
),
None,
)
) is None:
self.fail(
'Could not find model "%s:%s" in data'
% (expected_model['type'], expected_model['id'])
% (expected_model['type'], expected_model[self.id_field])
)
problems = self._compare_model(
expected_model, found_model, partial, missing_ok=missing_ok
Expand All @@ -736,7 +745,7 @@ def assertV2ModelPresent(
% (
f'{msg}\n' if msg else '',
expected_model['type'],
expected_model['id'],
expected_model[self.id_field],
'\n'.join(problems),
)
)
Expand All @@ -755,8 +764,8 @@ def assertV2ModelNotPresent(self, unexpected_model, data):
model
for model in data
if (
model['id'] == unexpected_model['id']
and model['type'] == unexpected_model['type']
unexpected_model['type'] == model['type']
and unexpected_model[self.id_field] == model[self.id_field]
)
),
None,
Expand All @@ -765,7 +774,7 @@ def assertV2ModelNotPresent(self, unexpected_model, data):
):
self.fail(
'Found model "%s:%s" in data'
% (unexpected_model['type'], unexpected_model['id'])
% (unexpected_model['type'], unexpected_model[self.id_field])
)

def assertLogEntry(self, model_name: str, pk: int, change_type, message: str):
Expand Down Expand Up @@ -858,6 +867,7 @@ def _snapshot(self, method, url, data):
self._snapshot_num = 1

# obscure ids from url since they can drift depending on test order/results, remove leading tracker since it's redundant, and slugify everything else
# FIXME: this doesn't quite work for Country since we don't use PK lookups in the urls
pieces += [
f'S{self._snapshot_num}',
re.sub(
Expand Down
48 changes: 47 additions & 1 deletion tracker/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from rest_framework.validators import UniqueTogetherValidator

from tracker.api import messages
from tracker.models import Ad, Interstitial, Interview
from tracker.models.bid import Bid, DonationBid
from tracker.models.country import Country, CountryRegion
from tracker.models.donation import Donation, Donor, Milestone
from tracker.models.event import Event, SpeedRun, Tag, Talent, VideoLink, VideoLinkType
from tracker.models.interstitial import Ad, Interstitial, Interview

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -298,6 +299,46 @@ def to_representation(self, obj):
return obj.__class__.__name__.lower()


class CountrySerializer(PrimaryOrNaturalKeyLookup, TrackerModelSerializer):
type = ClassNameField()

class Meta:
model = Country
fields = (
'type',
'name',
'alpha2',
'alpha3',
'numeric',
)

def to_representation(self, instance):
if self.root == self or getattr(self.root, 'child', None) == self:
return super().to_representation(instance)
else:
return instance.alpha3


class CountryRegionSerializer(PrimaryOrNaturalKeyLookup, TrackerModelSerializer):
type = ClassNameField()
country = CountrySerializer()

class Meta:
model = CountryRegion
fields = (
'type',
'id',
'name',
'country',
)

def to_representation(self, instance):
if self.root == self or getattr(self.root, 'child', None) == self:
return super().to_representation(instance)
else:
return [instance.name, instance.country.alpha3]


class EventNestedSerializerMixin:
event_move = False

Expand Down Expand Up @@ -562,6 +603,9 @@ def get_donor_name(self, donation: Donation):

class EventSerializer(PrimaryOrNaturalKeyLookup, TrackerModelSerializer):
type = ClassNameField()
# include these later
# allowed_prize_countries = CountrySerializer(many=True)
# disallowed_prize_regions = CountryRegionSerializer(many=True)
timezone = serializers.SerializerMethodField()
amount = serializers.SerializerMethodField()
donation_count = serializers.SerializerMethodField()
Expand All @@ -583,6 +627,8 @@ class Meta:
'datetime',
'timezone',
'use_one_step_screening',
# 'allowed_prize_countries',
# 'disallowed_prize_regions',
)

def get_fields(self):
Expand Down
14 changes: 13 additions & 1 deletion tracker/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
from rest_framework import routers

from tracker.api import views
from tracker.api.views import ad, bids, donations, interview, me, milestone, run, talent
from tracker.api.views import (
ad,
bids,
country,
donations,
interview,
me,
milestone,
run,
talent,
)

router = routers.DefaultRouter()

Expand Down Expand Up @@ -34,6 +44,8 @@ def event_nested_route(path, viewset, *, basename=None, feed=False):
event_nested_route(r'milestones', milestone.MilestoneViewSet)
router.register(r'donations', donations.DonationViewSet, basename='donations')
router.register(r'me', me.MeViewSet, basename='me')
router.register(r'countries', country.CountryViewSet)
router.register(r'regions', country.CountryRegionViewSet, basename='region')

# use the router-generated URLs, and also link to the browsable API
urlpatterns = [
Expand Down
51 changes: 51 additions & 0 deletions tracker/api/views/country.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import re

from django.db.models import Q
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound
from rest_framework.generics import get_object_or_404

from tracker.api.pagination import TrackerPagination
from tracker.api.serializers import CountryRegionSerializer, CountrySerializer
from tracker.api.views import TrackerReadViewSet
from tracker.models import Country, CountryRegion


class CountryViewSet(TrackerReadViewSet):
serializer_class = CountrySerializer
pagination_class = TrackerPagination
queryset = Country.objects.all()
lookup_field = 'numeric_or_alpha'

def get_object(self):
queryset = self.get_queryset()
pk = self.kwargs['numeric_or_alpha']
if re.match('[0-9]{3}', pk):
return get_object_or_404(queryset, numeric=pk)
elif re.match('[A-Z]{2,3}', pk):
return get_object_or_404(queryset, Q(alpha2=pk) | Q(alpha3=pk))
raise NotFound(
detail='Provide either an ISO 3166-1 numeric, alpha2, or alpha3 code',
code='invalid_lookup',
)

@action(detail=True)
def regions(self, request, *args, **kwargs):
viewset = CountryRegionViewSet(request=request, country=self.get_object())
viewset.initial(request, *args, **kwargs)
return viewset.list(request, *args, **kwargs)


class CountryRegionViewSet(TrackerReadViewSet):
serializer_class = CountryRegionSerializer
pagination_class = TrackerPagination
queryset = CountryRegion.objects.select_related('country')

def __init__(self, country=None, *args, **kwargs):
self.country = country
super().__init__(*args, **kwargs)

def filter_queryset(self, queryset):
if self.country:
queryset = queryset.filter(country=self.country)
return queryset

0 comments on commit eb3c605

Please sign in to comment.