Skip to content

Commit

Permalink
fix: Incorrect count of implicit tags when object had multiple tags
Browse files Browse the repository at this point in the history
  • Loading branch information
bradenmacdonald committed Dec 18, 2023
1 parent 43547b1 commit eb845c4
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 9 deletions.
29 changes: 24 additions & 5 deletions openedx_tagging/core/tagging/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .data import TagDataQuerySet
from .models import ObjectTag, Tag, Taxonomy
from .models.utils import ConcatNull
from .models.utils import ConcatNull, StringAgg

# Export this as part of the API
TagDoesNotExist = Tag.DoesNotExist
Expand Down Expand Up @@ -218,13 +218,32 @@ def get_object_tag_counts(object_id_pattern: str, count_implicit=False) -> dict[
qs = qs.exclude(taxonomy__enabled=False) # The whole taxonomy is disabled
qs = qs.exclude(tag_id=None, taxonomy__allow_free_text=False) # The taxonomy exists but the tag is deleted
if count_implicit:
tags = Tag.annotate_depth(Tag.objects.filter(pk=models.OuterRef("tag_id")))
qs = qs.annotate(tag_depth=models.Subquery(tags.values('depth')))
# Counting the implicit tags is tricky, because if two "grandchild" tags have the same implicit parent tag, we
# need to count that parent tag only once. To do that, we collect all the ancestor tag IDs into an aggregate
# string, and then count the unique values using python
qs = qs.values("object_id").annotate(
num_tags=models.Count("id"),
num_implicit_tags=models.Sum("tag_depth"),
tag_ids_str_1=StringAgg("tag_id"),
tag_ids_str_2=StringAgg("tag__parent_id"),
tag_ids_str_3=StringAgg("tag__parent__parent_id"),
tag_ids_str_4=StringAgg("tag__parent__parent__parent_id"),
).order_by("object_id")
return {row["object_id"]: row["num_tags"] + (row["num_implicit_tags"] or 0) for row in qs}
result = {}
for row in qs:
# ObjectTags for free text taxonomies will be included in "num_tags" count, but not "tag_ids_str_1" since
# they have no tag ID. We can compute how many free text tags each object has now:
if row["tag_ids_str_1"]:
num_free_text_tags = row["num_tags"] - len(row["tag_ids_str_1"].split(","))
else:
num_free_text_tags = row["num_tags"]
# Then we count the total number of *unique* Tags for this object, both implicit and explicit:
other_tag_ids = set()
for field in ("tag_ids_str_1", "tag_ids_str_2", "tag_ids_str_3", "tag_ids_str_4"):
if row[field] is not None:
for tag_id in row[field].split(","):
other_tag_ids.add(int(tag_id))
result[row["object_id"]] = num_free_text_tags + len(other_tag_ids)
return result
else:
qs = qs.values("object_id").annotate(num_tags=models.Count("id")).order_by("object_id")
return {row["object_id"]: row["num_tags"] for row in qs}
Expand Down
22 changes: 21 additions & 1 deletion openedx_tagging/core/tagging/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Utilities for tagging and taxonomy models
"""

from django.db.models import Aggregate, CharField
from django.db.models.expressions import Func


Expand All @@ -22,3 +22,23 @@ def as_sqlite(self, compiler, connection, **extra_context):
arg_joiner=" || ",
**extra_context,
)


class StringAgg(Aggregate): # pylint: disable=abstract-method
"""
Aggregate function that collects the values of some column across all rows,
and creates a string by concatenating those values, with "," as a separator.
This is the same as Django's django.contrib.postgres.aggregates.StringAgg,
but this version works with MySQL and SQLite.
"""
function = 'GROUP_CONCAT'
template = '%(function)s(%(distinct)s%(expressions)s)'

def __init__(self, expression, distinct=False, **extra):
super().__init__(
expression,
distinct='DISTINCT ' if distinct else '',
output_field=CharField(),
**extra,
)
13 changes: 10 additions & 3 deletions tests/openedx_tagging/core/tagging/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,20 +711,27 @@ def test_get_object_tag_counts_implicit(self) -> None:
Note that:
- "DPANN" is "Archaea > DPANN" (2 tags, 1 implicit), and
- "Chordata" is "Eukaryota > Animalia > Chordata" (3 tags, 2 implicit)
- "Arthropoda" is "Eukaryota > Animalia > Arthropoda" (same)
"""
obj1 = "object_id1"
obj2 = "object_id2"
self.taxonomy.allow_multiple = True
self.taxonomy.save()
obj1, obj2, obj3 = "object_id1", "object_id2", "object_id3"
other = "other_object"
# Give each object 1-2 tags:
tagging_api.tag_object(object_id=obj1, taxonomy=self.taxonomy, tags=["DPANN"])
tagging_api.tag_object(object_id=obj2, taxonomy=self.taxonomy, tags=["Chordata"])
tagging_api.tag_object(object_id=obj2, taxonomy=self.free_text_taxonomy, tags=["has a notochord"])
tagging_api.tag_object(object_id=obj3, taxonomy=self.taxonomy, tags=["Chordata", "Arthropoda"])
tagging_api.tag_object(object_id=other, taxonomy=self.free_text_taxonomy, tags=["other"])

assert tagging_api.get_object_tag_counts(obj1, count_implicit=True) == {obj1: 2}
assert tagging_api.get_object_tag_counts(obj2, count_implicit=True) == {obj2: 4}
assert tagging_api.get_object_tag_counts(f"{obj1},{obj2}", count_implicit=True) == {obj1: 2, obj2: 4}
assert tagging_api.get_object_tag_counts("object_*", count_implicit=True) == {obj1: 2, obj2: 4}
assert tagging_api.get_object_tag_counts("object_*", count_implicit=True) == {
obj1: 2,
obj2: 4,
obj3: 4, # obj3 has 2 explicit tags and 2 implicit tags (not 4 because the implicit tags are the same)
}
assert tagging_api.get_object_tag_counts(other, count_implicit=True) == {other: 1}

def test_get_object_tag_counts_deleted_disabled(self) -> None:
Expand Down

0 comments on commit eb845c4

Please sign in to comment.