diff --git a/openedx_tagging/core/tagging/api.py b/openedx_tagging/core/tagging/api.py
index 52e4f84b..9cde39a4 100644
--- a/openedx_tagging/core/tagging/api.py
+++ b/openedx_tagging/core/tagging/api.py
@@ -21,7 +21,7 @@
 
 from .data import TagDataQuerySet
 from .models import ObjectTag, Tag, Taxonomy
-from .models.utils import ConcatNull
+from .models.utils import ConcatNull, StringAgg
 
 # Export this as part of the API
 TagDoesNotExist = Tag.DoesNotExist
@@ -218,13 +218,32 @@ def get_object_tag_counts(object_id_pattern: str, count_implicit=False) -> dict[
     qs = qs.exclude(taxonomy__enabled=False)  # The whole taxonomy is disabled
     qs = qs.exclude(tag_id=None, taxonomy__allow_free_text=False)  # The taxonomy exists but the tag is deleted
     if count_implicit:
-        tags = Tag.annotate_depth(Tag.objects.filter(pk=models.OuterRef("tag_id")))
-        qs = qs.annotate(tag_depth=models.Subquery(tags.values('depth')))
+        # Counting the implicit tags is tricky, because if two "grandchild" tags have the same implicit parent tag, we
+        # need to count that parent tag only once. To do that, we collect all the ancestor tag IDs into an aggregate
+        # string, and then count the unique values using Python
         qs = qs.values("object_id").annotate(
             num_tags=models.Count("id"),
-            num_implicit_tags=models.Sum("tag_depth"),
+            tag_ids_str_1=StringAgg("tag_id"),
+            tag_ids_str_2=StringAgg("tag__parent_id"),
+            tag_ids_str_3=StringAgg("tag__parent__parent_id"),
+            tag_ids_str_4=StringAgg("tag__parent__parent__parent_id"),
         ).order_by("object_id")
-        return {row["object_id"]: row["num_tags"] + (row["num_implicit_tags"] or 0) for row in qs}
+        result = {}
+        for row in qs:
+            # ObjectTags for free text taxonomies will be included in "num_tags" count, but not "tag_ids_str_1" since
+            # they have no tag ID. We can compute how many free text tags each object has now:
+            if row["tag_ids_str_1"]:
+                num_free_text_tags = row["num_tags"] - len(row["tag_ids_str_1"].split(","))
+            else:
+                num_free_text_tags = row["num_tags"]
+            # Then we count the total number of *unique* Tags for this object, both implicit and explicit:
+            other_tag_ids = set()
+            for field in ("tag_ids_str_1", "tag_ids_str_2", "tag_ids_str_3", "tag_ids_str_4"):
+                if row[field] is not None:
+                    for tag_id in row[field].split(","):
+                        other_tag_ids.add(int(tag_id))
+            result[row["object_id"]] = num_free_text_tags + len(other_tag_ids)
+        return result
     else:
         qs = qs.values("object_id").annotate(num_tags=models.Count("id")).order_by("object_id")
         return {row["object_id"]: row["num_tags"] for row in qs}
diff --git a/openedx_tagging/core/tagging/models/utils.py b/openedx_tagging/core/tagging/models/utils.py
index 1f7dcb92..c2e2201e 100644
--- a/openedx_tagging/core/tagging/models/utils.py
+++ b/openedx_tagging/core/tagging/models/utils.py
@@ -1,7 +1,7 @@
 """
 Utilities for tagging and taxonomy models
 """
-
+from django.db.models import Aggregate, CharField
 from django.db.models.expressions import Func
 
 
@@ -22,3 +22,29 @@ def as_sqlite(self, compiler, connection, **extra_context):
             arg_joiner=" || ",
             **extra_context,
         )
+
+
+class StringAgg(Aggregate):  # pylint: disable=abstract-method
+    """
+    Aggregate function that collects the values of some column across all rows,
+    and creates a string by concatenating those values, with "," as a separator.
+
+    This is the same as Django's django.contrib.postgres.aggregates.StringAgg,
+    but this version works with MySQL and SQLite.
+
+    Pass distinct=True to include each distinct value only once.
+    """
+    function = 'GROUP_CONCAT'
+    template = '%(function)s(%(distinct)s%(expressions)s)'
+    # Opt in to DISTINCT: Aggregate.__init__() raises TypeError when distinct
+    # is truthy unless allow_distinct is set, and Aggregate.as_sql() fills in
+    # the "DISTINCT " prefix for %(distinct)s itself.
+    allow_distinct = True
+
+    def __init__(self, expression, distinct=False, **extra):
+        super().__init__(
+            expression,
+            distinct=distinct,
+            output_field=CharField(),
+            **extra,
+        )
diff --git a/tests/openedx_tagging/core/tagging/test_api.py b/tests/openedx_tagging/core/tagging/test_api.py
index 400073cd..5befef5c 100644
--- a/tests/openedx_tagging/core/tagging/test_api.py
+++ b/tests/openedx_tagging/core/tagging/test_api.py
@@ -711,20 +711,27 @@ def test_get_object_tag_counts_implicit(self) -> None:
         Note that:
         - "DPANN" is "Archaea > DPANN" (2 tags, 1 implicit), and
         - "Chordata" is "Eukaryota > Animalia > Chordata" (3 tags, 2 implicit)
+        - "Arthropoda" is "Eukaryota > Animalia > Arthropoda" (same)
         """
-        obj1 = "object_id1"
-        obj2 = "object_id2"
+        self.taxonomy.allow_multiple = True
+        self.taxonomy.save()
+        obj1, obj2, obj3 = "object_id1", "object_id2", "object_id3"
         other = "other_object"
         # Give each object 1-2 tags:
         tagging_api.tag_object(object_id=obj1, taxonomy=self.taxonomy, tags=["DPANN"])
         tagging_api.tag_object(object_id=obj2, taxonomy=self.taxonomy, tags=["Chordata"])
         tagging_api.tag_object(object_id=obj2, taxonomy=self.free_text_taxonomy, tags=["has a notochord"])
+        tagging_api.tag_object(object_id=obj3, taxonomy=self.taxonomy, tags=["Chordata", "Arthropoda"])
         tagging_api.tag_object(object_id=other, taxonomy=self.free_text_taxonomy, tags=["other"])
 
         assert tagging_api.get_object_tag_counts(obj1, count_implicit=True) == {obj1: 2}
         assert tagging_api.get_object_tag_counts(obj2, count_implicit=True) == {obj2: 4}
         assert tagging_api.get_object_tag_counts(f"{obj1},{obj2}", count_implicit=True) == {obj1: 2, obj2: 4}
-        assert tagging_api.get_object_tag_counts("object_*", count_implicit=True) == {obj1: 2, obj2: 4}
+        assert tagging_api.get_object_tag_counts("object_*", count_implicit=True) == {
+            obj1: 2,
+            obj2: 4,
+            obj3: 4,  # obj3 has 2 explicit tags and 2 implicit tags (not 4 because the implicit tags are the same)
+        }
         assert tagging_api.get_object_tag_counts(other, count_implicit=True) == {other: 1}
 
     def test_get_object_tag_counts_deleted_disabled(self) -> None: