-
Notifications
You must be signed in to change notification settings - Fork 171
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add SearchAfterMixin for ES search_after capability
- Loading branch information
1 parent
01a8aaa
commit e2052c4
Showing
7 changed files
with
305 additions
and
21 deletions.
There are no files selected for viewing
94 changes: 94 additions & 0 deletions
94
course_discovery/apps/api/v2/tests/test_views/test_catalog_queries.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import urllib | ||
|
||
from rest_framework.reverse import reverse | ||
|
||
from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase | ||
from course_discovery.apps.core.tests.factories import UserFactory | ||
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin | ||
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory | ||
|
||
|
||
class CatalogQueryViewSetTests(ElasticsearchTestMixin, APITestCase): | ||
""" | ||
Unit tests for CatalogQueryViewSet. | ||
""" | ||
def setUp(self): | ||
super().setUp() | ||
self.user = UserFactory(is_staff=True, is_superuser=True) | ||
self.client.force_authenticate(self.user) | ||
self.course = CourseFactory(partner=self.partner, key='simple_key') | ||
self.course_run = CourseRunFactory(course=self.course, key='simple/key/run') | ||
self.url_base = reverse('api:v2:catalog-query_contains') | ||
self.error_message = 'CatalogQueryContains endpoint requires query and identifiers list(s)' | ||
self.refresh_index() | ||
|
||
def test_contains_single_course_run(self): | ||
""" Verify that a single course_run is contained in a query. """ | ||
qs = urllib.parse.urlencode({ | ||
'query': 'id:' + self.course_run.key, | ||
'course_run_ids': self.course_run.key, | ||
'course_uuids': self.course.uuid, | ||
}) | ||
url = f'{self.url_base}/?{qs}' | ||
response = self.client.get(url) | ||
assert response.status_code == 200 | ||
assert response.data == {self.course_run.key: True, str(self.course.uuid): False} | ||
|
||
def test_contains_single_course(self): | ||
""" Verify that a single course is contained in a query. """ | ||
qs = urllib.parse.urlencode({ | ||
'query': 'key:' + self.course.key, | ||
'course_run_ids': self.course_run.key, | ||
'course_uuids': self.course.uuid, | ||
}) | ||
url = f'{self.url_base}/?{qs}' | ||
response = self.client.get(url) | ||
assert response.status_code == 200 | ||
assert response.data == {self.course_run.key: False, str(self.course.uuid): True} | ||
|
||
def test_contains_course_and_run(self): | ||
""" Verify that both the course and the run are contained in the broadest query. """ | ||
self.course.course_runs.add(self.course_run) | ||
self.course.save() | ||
qs = urllib.parse.urlencode({ | ||
'query': 'org:*', | ||
'course_run_ids': self.course_run.key, | ||
'course_uuids': self.course.uuid, | ||
}) | ||
url = f'{self.url_base}/?{qs}' | ||
response = self.client.get(url) | ||
assert response.status_code == 200 | ||
assert response.data == {self.course_run.key: True, str(self.course.uuid): True} | ||
|
||
def test_no_identifiers(self): | ||
""" Verify that a 400 status is returned if request does not contain any identifier lists. """ | ||
qs = urllib.parse.urlencode({ | ||
'query': 'id:*' | ||
}) | ||
url = f'{self.url_base}/?{qs}' | ||
response = self.client.get(url) | ||
assert response.status_code == 400 | ||
assert response.data == self.error_message | ||
|
||
def test_no_query(self): | ||
""" Verify that a 400 status is returned if request does not contain a querystring. """ | ||
qs = urllib.parse.urlencode({ | ||
'course_run_ids': self.course_run.key, | ||
'course_uuids': self.course.uuid, | ||
}) | ||
url = f'{self.url_base}/?{qs}' | ||
response = self.client.get(url) | ||
assert response.status_code == 400 | ||
assert response.data == self.error_message | ||
|
||
def test_incorrect_queries(self): | ||
""" Verify that a 400 status is returned if request contains incorrect query string. """ | ||
qs = urllib.parse.urlencode({ | ||
'query': 'title:', | ||
'course_run_ids': self.course_run.key, | ||
'course_uuids': self.course.uuid, | ||
}) | ||
url = f'{self.url_base}/?{qs}' | ||
|
||
response = self.client.get(url) | ||
assert response.status_code == 400 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,17 @@ | ||
"""API v2 URLs.""" | ||
|
||
from django.urls import re_path | ||
from rest_framework import routers | ||
|
||
from course_discovery.apps.api.v2.views import search as search_views | ||
from course_discovery.apps.api.v2.views.catalog_queries import CatalogQueryContainsViewSet | ||
|
||
app_name = 'v2' | ||
|
||
urlpatterns = [ | ||
re_path(r'^catalog/query_contains/?', CatalogQueryContainsViewSet.as_view(), name='catalog-query_contains'), | ||
] | ||
|
||
router = routers.SimpleRouter() | ||
router.register(r'search/all', search_views.AggregateSearchViewSet, basename='search-all') | ||
urlpatterns = router.urls | ||
urlpatterns += router.urls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import logging | ||
from uuid import UUID | ||
|
||
from elasticsearch_dsl.query import Q as ESDSLQ | ||
from rest_framework import status | ||
from rest_framework.generics import GenericAPIView | ||
from rest_framework.permissions import DjangoModelPermissions, IsAuthenticated | ||
from rest_framework.response import Response | ||
|
||
from course_discovery.apps.api.mixins import ValidElasticSearchQueryRequiredMixin | ||
from course_discovery.apps.course_metadata.models import Course, CourseRun, SearchAfterMixin | ||
from course_discovery.apps.course_metadata.search_indexes.documents import CourseDocument, CourseRunDocument | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class CatalogQueryContainsViewSet(ValidElasticSearchQueryRequiredMixin, GenericAPIView, SearchAfterMixin): | ||
permission_classes = (IsAuthenticated, DjangoModelPermissions) | ||
queryset = Course.objects.all() | ||
|
||
def get(self, request): | ||
""" | ||
Determine if a set of courses and/or course runs is found in the query results. | ||
Returns | ||
dict: mapping of course and run identifiers included in the request to boolean values | ||
indicating whether the associated course or run is contained in the queryset | ||
described by the query found in the request. | ||
""" | ||
query = request.GET.get('query') | ||
course_run_ids = request.GET.get('course_run_ids', None) | ||
course_uuids = request.GET.get('course_uuids', None) | ||
partner = self.request.site.partner | ||
|
||
if query and (course_run_ids or course_uuids): | ||
log.info( | ||
f"Attempting search against query {query} with course UUIDs {course_uuids} " | ||
f"and course run IDs {course_run_ids}" | ||
) | ||
identified_course_ids = set() | ||
specified_course_ids = [] | ||
if course_run_ids: | ||
course_run_ids = course_run_ids.split(',') | ||
specified_course_ids = course_run_ids | ||
identified_course_ids.update( | ||
i.key | ||
for i in self.search( | ||
query, | ||
queryset=CourseRun.objects.all(), | ||
partner=ESDSLQ('term', partner=partner.short_code), | ||
identifiers=ESDSLQ('terms', **{'key.raw': course_run_ids}), | ||
document=CourseRunDocument | ||
) | ||
) | ||
if course_uuids: | ||
course_uuids = [UUID(course_uuid) for course_uuid in course_uuids.split(',')] | ||
specified_course_ids += course_uuids | ||
|
||
log.info(f"Specified course ids: {specified_course_ids}") | ||
identified_course_ids.update( | ||
self.search( | ||
query, | ||
queryset=Course.objects.all(), | ||
partner=ESDSLQ('term', partner=partner.short_code), | ||
identifiers=ESDSLQ('terms', **{'uuid': course_uuids}), | ||
document=CourseDocument | ||
).values_list('uuid', flat=True) | ||
) | ||
log.info(f"Identified {len(identified_course_ids)} course ids: {identified_course_ids}") | ||
|
||
contains = {str(identifier): identifier in identified_course_ids for identifier in specified_course_ids} | ||
return Response(contains) | ||
return Response( | ||
'CatalogQueryContains endpoint requires query and identifiers list(s)', status=status.HTTP_400_BAD_REQUEST | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.