From f36adf8a6b8e555d6fc67ef2f100dbd6a82a56b7 Mon Sep 17 00:00:00 2001 From: Kira Miller Date: Wed, 23 Oct 2024 20:26:54 +0000 Subject: [PATCH] fix: in progress commit --- enterprise/api/v1/serializers.py | 32 +++++ enterprise/api/v1/urls.py | 8 ++ .../v1/views/enterprise_customer_members.py | 111 ++++++++++++++++++ tests/test_enterprise/api/test_serializers.py | 43 ++++++- tests/test_enterprise/api/test_views.py | 38 ++++++ 5 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 enterprise/api/v1/views/enterprise_customer_members.py diff --git a/enterprise/api/v1/serializers.py b/enterprise/api/v1/serializers.py index 4bef00173..0b0140663 100644 --- a/enterprise/api/v1/serializers.py +++ b/enterprise/api/v1/serializers.py @@ -1905,3 +1905,35 @@ def get_role_assignments(self, obj): return role_assignments_by_ecu_id else: return None + +class EnterpriseMemberSerializer(serializers.Serializer): + """ + Serializer for EnterpriseCustomerUser model with additions. + """ + class Meta: + model = models.EnterpriseCustomerUser + + fields = ( + 'enterprise_customer_user', + 'user_email', + 'enrollments', + 'created', + ) + enterprise_customer_user = UserSerializer(source="user", required=False, default=None) + user_email = serializers.EmailField() + enrollments = serializers.SerializerMethodField() + + + def get_enrollments(self, obj): + """ + Fetch all of user's enterprise enrollments + """ + import pdb; + pdb.set_trace() + print('in here!') + if hasattr(obj, 'user_id'): + user_id = obj.user_id + enrollments = models.EnterpriseCourseEnrollment.objects.filter( + enterprise_customer_user=user_id, + ) + return len(enrollments) diff --git a/enterprise/api/v1/urls.py b/enterprise/api/v1/urls.py index 866ed4292..168f5a066 100644 --- a/enterprise/api/v1/urls.py +++ b/enterprise/api/v1/urls.py @@ -16,6 +16,7 @@ enterprise_customer_branding_configuration, enterprise_customer_catalog, enterprise_customer_invite_key, + enterprise_customer_members, enterprise_customer_reporting, enterprise_customer_sso_configuration, enterprise_customer_support, @@ -205,6 +206,13 @@ ), name='enterprise-customer-support' ), + re_path( + r'^enterprise-customer-members/(?P[A-Za-z0-9-]+)$', + enterprise_customer_members.EnterpriseCustomerMembersViewSet.as_view( + {'get': 'get_members'} + ), + name='enterprise-customer-members' + ), ] urlpatterns += router.urls diff --git a/enterprise/api/v1/views/enterprise_customer_members.py b/enterprise/api/v1/views/enterprise_customer_members.py new file mode 100644 index 000000000..ddbba015f --- /dev/null +++ b/enterprise/api/v1/views/enterprise_customer_members.py @@ -0,0 +1,111 @@ +""" +Views for the ``enterprise-customer-members`` API endpoint. +""" + +from collections import OrderedDict + +from django_filters.rest_framework import DjangoFilterBackend +from rest_framework import filters, permissions, response, status +from rest_framework.pagination import PageNumberPagination + +from django.contrib import auth +from django.core.exceptions import ValidationError +from django.db.models import Q + +from enterprise import models +from enterprise.api.v1 import serializers +from enterprise.api.v1.views.base_views import EnterpriseReadOnlyModelViewSet +from enterprise.logging import getEnterpriseLogger + +User = auth.get_user_model() + +LOGGER = getEnterpriseLogger(__name__) + +class EnterpriseCustomerMembersPaginator(PageNumberPagination): + """Custom paginator for the enterprise customer members.""" + + page_size = 6 + + def get_paginated_response(self, data): + """Return a paginated style `Response` object for the given output data.""" + return response.Response( + OrderedDict( + [ + ("count", self.page.paginator.count), + ("num_pages", self.page.paginator.num_pages), + ("next", self.get_next_link()), + ("previous", self.get_previous_link()), + ("results", data), + ] + ) + ) + + def paginate_queryset(self, queryset, request, view=None): + """ + Paginate a queryset if required, either returning a page object, + or `None` if pagination is not configured for this view. + + """ + if isinstance(queryset, filter): + queryset = list(queryset) + + return super().paginate_queryset(queryset, request, view) + + +class EnterpriseCustomerMembersViewSet(EnterpriseReadOnlyModelViewSet): + """ + API views for the ``enterprise-customer-members`` API endpoint. + """ + + queryset = models.PendingEnterpriseCustomerUser.objects.all() + filter_backends = (DjangoFilterBackend, filters.OrderingFilter) + permission_classes = (permissions.IsAuthenticated,) + paginator = EnterpriseCustomerMembersPaginator() + + def filter_queryset_by_user_query(self, queryset, is_pending_user=False): + """ + Filter queryset based on user provided query + """ + user_query = self.request.query_params.get("user_query", None) + if user_query: + queryset = models.EnterpriseCustomerUser.objects.filter( + user_id__in=User.objects.filter( + Q(email__icontains=user_query) | Q(username__icontains=user_query) + ) + ) + return queryset + + def get_members(self, request, *args, **kwargs): + """ + Filter down the queryset of groups available to the requesting uuid. + """ + enterprise_uuid = kwargs.get("enterprise_uuid", None) + users = [] + + try: + enterprise_customer_queryset = models.EnterpriseCustomerUser.objects.filter( + enterprise_customer__uuid=enterprise_uuid, + ) + enterprise_customer_queryset = self.filter_queryset_by_user_query( + enterprise_customer_queryset + ) + users.extend(enterprise_customer_queryset) + + except ValidationError: + # did not find UUID match in either EnterpriseCustomerUser + return response.Response( + {"detail": "Could not find enterprise uuid {}".format(enterprise_uuid)}, + status=status.HTTP_404_NOT_FOUND, + ) + + # default sort criteria + is_reversed = False + + # paginate the queryset + users_page = self.paginator.paginate_queryset(users, request, view=self) + + # serialize the paged dataset + serializer = serializers.EnterpriseMemberSerializer(users_page, many=True) + serializer_data = serializer.data + + return self.paginator.get_paginated_response(serializer_data) diff --git a/tests/test_enterprise/api/test_serializers.py b/tests/test_enterprise/api/test_serializers.py index aa330a356..77a76f168 100644 --- a/tests/test_enterprise/api/test_serializers.py +++ b/tests/test_enterprise/api/test_serializers.py @@ -19,6 +19,7 @@ EnterpriseCustomerReportingConfigurationSerializer, EnterpriseCustomerSerializer, EnterpriseCustomerUserReadOnlySerializer, + EnterpriseMemberSerializer, EnterpriseUserSerializer, ImmutableStateSerializer, ) @@ -470,7 +471,7 @@ def setUp(self): super().setUp() - # setup Enteprise Customer + # setup Enterprise Customer self.user_1 = factories.UserFactory() self.user_2 = factories.UserFactory() self.enterprise_customer_user_1 = factories.EnterpriseCustomerUserFactory(user_id=self.user_1.id) @@ -573,3 +574,43 @@ def test_serialize_pending_users(self): serialized_pending_admin_user = serializer.data self.assertEqual(expected_pending_admin_user, serialized_pending_admin_user) + +@mark.django_db +class TestEnterpriseUserSerializer(TestCase): + def setUp(self): + """ + Perform operations common for all tests. + """ + super().setUp() + + # setup Enterprise Customer + self.user_1 = factories.UserFactory() + self.user_2 = factories.UserFactory() + self.enterprise_customer_user_1 = factories.EnterpriseCustomerUserFactory(user_id=self.user_1.id) + self.enterprise_customer_user_2 = factories.EnterpriseCustomerUserFactory(user_id=self.user_2.id) + self.enterprise_customer_1 = self.enterprise_customer_user_1.enterprise_customer + self.enterprise_customer_2 = self.enterprise_customer_user_2.enterprise_customer + + self.enrollment_1 = factories.EnterpriseCourseEnrollmentFactory( + enterprise_customer_user=self.enterprise_customer_user_1, + ) + self.enrollment_1 = factories.EnterpriseCourseEnrollmentFactory( + enterprise_customer_user=self.enterprise_customer_user_1, + ) + self.enrollment_1 = factories.EnterpriseCourseEnrollmentFactory( + enterprise_customer_user=self.enterprise_customer_user_2, + ) + + def test_serialize_users(self): + for customer_user in [ + (self.enterprise_customer_user_1), + (self.enterprise_customer_user_2), + ]: + user = customer_user.user + serializer = EnterpriseMemberSerializer(customer_user) + print("serializer ", serializer) + print("data ", serializer.data) + serialized_user = serializer.data + + self.assertEqual('', serialized_user) + diff --git a/tests/test_enterprise/api/test_views.py b/tests/test_enterprise/api/test_views.py index 8c7fd276b..d709fac28 100644 --- a/tests/test_enterprise/api/test_views.py +++ b/tests/test_enterprise/api/test_views.py @@ -9768,3 +9768,41 @@ def test_list_users_filtered(self): assert expected_json == response.json().get('results') assert response.json().get('count') == 1 + + +@ddt.ddt +@mark.django_db +class TestEnterpriseCustomerMembers(BaseTestEnterpriseAPIViews): + """ + Test enterprise customer members list endpoint + """ + ECM_ENDPOINT = 'enterprise-customer-members' + ECM_KWARG = 'enterprise_uuid' + + def test_get_enterprise_org_members(self): + """ + Assert whether the response is valid. + """ + user = factories.UserFactory() + enterprise_customer = factories.EnterpriseCustomerFactory(uuid=FAKE_UUIDS[0]) + user = factories.EnterpriseCustomerUserFactory( + user_id=user.id, + enterprise_customer=enterprise_customer + ) + factories.EnterpriseCourseEnrollment( + enterprise_customer_user=user, + ) + + expected_json = {'this-is-going-to-fail': False} + # Test valid UUID + url = reverse(self.ECM_ENDPOINT, kwargs={self.ECM_KWARG: enterprise_customer.uuid}) + response = self.client.get(settings.TEST_SERVER + url) + + assert expected_json == response.json().get('results')[0] + + # Test invalid UUID + url = reverse(self.ECM_ENDPOINT, kwargs={self.ECM_KWARG: 123}) + response = self.client.get(settings.TEST_SERVER + url) + self.assertEqual(response.status_code, 404) + + \ No newline at end of file