Skip to content

Commit

Permalink
feat: add SearchAfterMixin for ES search_after capability
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali-D-Akbar committed Jan 17, 2025
1 parent 01a8aaa commit b566d5e
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import urllib

from rest_framework.reverse import reverse

from course_discovery.apps.api.v1.tests.test_views.mixins import APITestCase
from course_discovery.apps.core.tests.factories import UserFactory
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory


class CatalogQueryViewSetTests(ElasticsearchTestMixin, APITestCase):
"""
Unit tests for CatalogQueryViewSet.
"""
def setUp(self):
super().setUp()
self.user = UserFactory(is_staff=True, is_superuser=True)
self.client.force_authenticate(self.user)
self.course = CourseFactory(partner=self.partner, key='simple_key')
self.course_run = CourseRunFactory(course=self.course, key='simple/key/run')
self.url_base = reverse('api:v2:catalog-query_contains')
self.error_message = 'CatalogQueryContains endpoint requires query and identifiers list(s)'
self.refresh_index()

def test_contains_single_course_run(self):
""" Verify that a single course_run is contained in a query. """
qs = urllib.parse.urlencode({
'query': 'id:' + self.course_run.key,
'course_run_ids': self.course_run.key,
'course_uuids': self.course.uuid,
})
url = f'{self.url_base}/?{qs}'
response = self.client.get(url)
assert response.status_code == 200
assert response.data == {self.course_run.key: True, str(self.course.uuid): False}

def test_contains_single_course(self):
""" Verify that a single course is contained in a query. """
qs = urllib.parse.urlencode({
'query': 'key:' + self.course.key,
'course_run_ids': self.course_run.key,
'course_uuids': self.course.uuid,
})
url = f'{self.url_base}/?{qs}'
response = self.client.get(url)
assert response.status_code == 200
assert response.data == {self.course_run.key: False, str(self.course.uuid): True}

def test_contains_course_and_run(self):
""" Verify that both the course and the run are contained in the broadest query. """
self.course.course_runs.add(self.course_run)
self.course.save()
qs = urllib.parse.urlencode({
'query': 'org:*',
'course_run_ids': self.course_run.key,
'course_uuids': self.course.uuid,
})
url = f'{self.url_base}/?{qs}'
response = self.client.get(url)
assert response.status_code == 200
assert response.data == {self.course_run.key: True, str(self.course.uuid): True}

def test_no_identifiers(self):
""" Verify that a 400 status is returned if request does not contain any identifier lists. """
qs = urllib.parse.urlencode({
'query': 'id:*'
})
url = f'{self.url_base}/?{qs}'
response = self.client.get(url)
assert response.status_code == 400
assert response.data == self.error_message

def test_no_query(self):
""" Verify that a 400 status is returned if request does not contain a querystring. """
qs = urllib.parse.urlencode({
'course_run_ids': self.course_run.key,
'course_uuids': self.course.uuid,
})
url = f'{self.url_base}/?{qs}'
response = self.client.get(url)
assert response.status_code == 400
assert response.data == self.error_message

def test_incorrect_queries(self):
""" Verify that a 400 status is returned if request contains incorrect query string. """
qs = urllib.parse.urlencode({
'query': 'title:',
'course_run_ids': self.course_run.key,
'course_uuids': self.course.uuid,
})
url = f'{self.url_base}/?{qs}'

response = self.client.get(url)
assert response.status_code == 400
8 changes: 7 additions & 1 deletion course_discovery/apps/api/v2/urls.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""API v2 URLs."""

from django.urls import re_path
from rest_framework import routers

from course_discovery.apps.api.v2.views import search as search_views
from course_discovery.apps.api.v2.views.catalog_queries import CatalogQueryContainsViewSet

app_name = 'v2'

urlpatterns = [
re_path(r'^catalog/query_contains/?', CatalogQueryContainsViewSet.as_view(), name='catalog-query_contains'),
]

router = routers.SimpleRouter()
router.register(r'search/all', search_views.AggregateSearchViewSet, basename='search-all')
urlpatterns = router.urls
urlpatterns += router.urls
75 changes: 75 additions & 0 deletions course_discovery/apps/api/v2/views/catalog_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import logging
from uuid import UUID

from elasticsearch_dsl.query import Q as ESDSLQ
from rest_framework import status
from rest_framework.generics import GenericAPIView
from rest_framework.permissions import DjangoModelPermissions, IsAuthenticated
from rest_framework.response import Response

from course_discovery.apps.api.mixins import ValidElasticSearchQueryRequiredMixin
from course_discovery.apps.course_metadata.models import Course, CourseRun, SearchAfterMixin
from course_discovery.apps.course_metadata.search_indexes.documents import CourseDocument, CourseRunDocument

log = logging.getLogger(__name__)


class CatalogQueryContainsViewSet(ValidElasticSearchQueryRequiredMixin, GenericAPIView, SearchAfterMixin):
permission_classes = (IsAuthenticated, DjangoModelPermissions)
queryset = Course.objects.all()

def get(self, request):
"""
Determine if a set of courses and/or course runs is found in the query results.
Returns
dict: mapping of course and run identifiers included in the request to boolean values
indicating whether the associated course or run is contained in the queryset
described by the query found in the request.
"""
query = request.GET.get('query')
course_run_ids = request.GET.get('course_run_ids', None)
course_uuids = request.GET.get('course_uuids', None)
partner = self.request.site.partner

if query and (course_run_ids or course_uuids):
log.info(
f"Attempting search against query {query} with course UUIDs {course_uuids} "
f"and course run IDs {course_run_ids}"
)
identified_course_ids = set()
specified_course_ids = []
if course_run_ids:
course_run_ids = course_run_ids.split(',')
specified_course_ids = course_run_ids
identified_course_ids.update(
i.key
for i in self.search(
query,
queryset=CourseRun.objects.all(),
partner=ESDSLQ('term', partner=partner.short_code),
identifiers=ESDSLQ('terms', **{'key.raw': course_run_ids}),
document=CourseRunDocument
)
)
if course_uuids:
course_uuids = [UUID(course_uuid) for course_uuid in course_uuids.split(',')]
specified_course_ids += course_uuids

log.info(f"Specified course ids: {specified_course_ids}")
identified_course_ids.update(
self.search(
query,
queryset=Course.objects.all(),
partner=ESDSLQ('term', partner=partner.short_code),
identifiers=ESDSLQ('terms', **{'uuid': course_uuids}),
document=CourseDocument
).values_list('uuid', flat=True)
)
log.info(f"Identified {len(identified_course_ids)} course ids: {identified_course_ids}")

contains = {str(identifier): identifier in identified_course_ids for identifier in specified_course_ids}
return Response(contains)
return Response(
'CatalogQueryContains endpoint requires query and identifiers list(s)', status=status.HTTP_400_BAD_REQUEST
)
69 changes: 69 additions & 0 deletions course_discovery/apps/course_metadata/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,75 @@ def search(cls, query, queryset=None):
return filtered_queryset


class SearchAfterMixin:
"""
Represents objects to query Elasticsearch with `search_after` pagination and load by primary key.
"""

@classmethod
def search(cls, query, queryset=None, page_size=settings.ELASTICSEARCH_DSL_QUERYSET_PAGINATION, partner=None,
identifiers=None, document=None):
"""
Queries the Elasticsearch index with optional pagination using `search_after`.
Args:
query (str) -- Elasticsearch querystring (e.g. `title:intro*`)
queryset (models.QuerySet) -- base queryset to search, defaults to objects.all()
page_size (int) -- Number of results per page.
partner (object) -- To be included in the ES query.
identifiers (object) -- UUID or key of a product.
Returns:
QuerySet
"""
query = clean_query(query)
queryset = queryset or cls.objects.all()

if query == '(*)':
# Early-exit optimization. Wildcard searching is very expensive in elasticsearch. And since we just
# want everything, we don't need to actually query elasticsearch at all.
return queryset

logger.info(f"Attempting Elasticsearch document search against query: {query}")
es_document = document or next(iter(registry.get_documents(models=(cls,))), None)

must_queries = [ESDSLQ('query_string', query=query, analyze_wildcard=True)]
if partner:
must_queries.append(partner)
if identifiers:
must_queries.append(identifiers)

dsl_query = ESDSLQ('bool', must=must_queries)

all_ids = set()
search_after = None

while True:
search = (
es_document.search()
.query(dsl_query)
.sort('id')
.extra(size=page_size)
)

search = search.extra(search_after=search_after) if search_after else search

results = search.execute()

ids = {result.pk for result in results}
if not ids:
logger.info("No more results found.")
break

all_ids.update(ids)
search_after = results[-1].meta.sort if results[-1] else None
logger.info(f"Fetched {len(ids)} records; total so far: {len(all_ids)}")

filtered_queryset = queryset.filter(pk__in=all_ids)
logger.info(f"Filtered queryset of size {len(filtered_queryset)} for query: {query}")
return filtered_queryset


class Collaborator(TimeStampedModel):
"""
Collaborator model, defining any collaborators who helped write course content.
Expand Down
12 changes: 12 additions & 0 deletions course_discovery/apps/course_metadata/tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,3 +1011,15 @@ class Meta:

course_run = factory.SubFactory(CourseRunFactory)
restriction_type = FuzzyChoice([name for name, __ in CourseRunRestrictionType.choices])


class CourseProxy(SearchAfterMixin, Course):
"""Proxy model for testing SearchAfterMixin with Course."""
class Meta:
proxy = True


class CourseProxyFactory(CourseFactory):
"""Factory for the CourseProxy proxy model."""
class Meta:
model = CourseProxy
25 changes: 23 additions & 2 deletions course_discovery/apps/course_metadata/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from course_discovery.apps.api.v1.tests.test_views.mixins import OAuth2Mixin
from course_discovery.apps.core.models import Currency
from course_discovery.apps.core.tests.helpers import make_image_file
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.core.utils import SearchQuerySetWrapper
from course_discovery.apps.course_metadata.choices import (
CourseRunRestrictionType, CourseRunStatus, ExternalProductStatus, ProgramStatus
Expand All @@ -44,13 +45,15 @@
from course_discovery.apps.course_metadata.publishers import (
CourseRunMarketingSitePublisher, ProgramMarketingSitePublisher
)
from course_discovery.apps.course_metadata.search_indexes.documents import CourseDocument
from course_discovery.apps.course_metadata.signals import (
connect_course_data_modified_timestamp_related_models, disconnect_course_data_modified_timestamp_related_models
)
from course_discovery.apps.course_metadata.tests import factories
from course_discovery.apps.course_metadata.tests.factories import (
AdditionalMetadataFactory, CourseFactory, CourseRunFactory, CourseTypeFactory, CourseUrlSlugFactory, ImageFactory,
OrganizationFactory, PartnerFactory, ProgramFactory, SeatFactory, SeatTypeFactory, SourceFactory, SubjectFactory
AdditionalMetadataFactory, CourseFactory, CourseProxy, CourseRunFactory, CourseTypeFactory, CourseUrlSlugFactory,
ImageFactory, OrganizationFactory, PartnerFactory, ProgramFactory, SeatFactory, SeatTypeFactory, SourceFactory,
SubjectFactory
)
from course_discovery.apps.course_metadata.tests.mixins import MarketingSitePublisherTestMixin
from course_discovery.apps.course_metadata.toggles import (
Expand Down Expand Up @@ -4192,3 +4195,21 @@ def test_basic(self):
self.assertEqual(course_run.restricted_run, restricted_course_run)
self.assertEqual(restricted_course_run.restriction_type, 'custom-b2b-enterprise')
self.assertEqual(str(restricted_course_run), "course-v1:SC+BreadX+3T2015: <custom-b2b-enterprise>")


class TestSearchAfterMixin(ElasticsearchTestMixin, TestCase):
def setUp(self):
super().setUp()

self.total_courses = 5
for _ in range(self.total_courses):
CourseFactory()

@patch("course_discovery.apps.course_metadata.models.registry.get_documents")
def test_fetch_all_courses(self, mock_get_documents):
query = "Course*"
mock_get_documents.return_value = [CourseDocument]

queryset = CourseProxy.search(query=query, page_size=2)

self.assertEqual(len(queryset), self.total_courses)
Loading

0 comments on commit b566d5e

Please sign in to comment.