Skip to content

Commit

Permalink
Refs django#373 - Added composite GenericForeignKey support.
Browse files Browse the repository at this point in the history
  • Loading branch information
csirmazbendeguz committed Aug 21, 2024
1 parent 22c80e7 commit 28ae360
Show file tree
Hide file tree
Showing 12 changed files with 297 additions and 22 deletions.
46 changes: 38 additions & 8 deletions django/contrib/contenttypes/fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import itertools
import json
import warnings
from collections import defaultdict

Expand All @@ -8,7 +9,8 @@
from django.contrib.contenttypes.models import ContentType
from django.core import checks
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, models, router, transaction
from django.core.serializers.json import DjangoJSONEncoder
from django.db import DEFAULT_DB_ALIAS, connection, models, router, transaction
from django.db.models import DO_NOTHING, ForeignObject, ForeignObjectRel
from django.db.models.base import ModelBase, make_foreign_order_accessors
from django.db.models.fields import Field
Expand All @@ -25,6 +27,29 @@
from django.utils.functional import cached_property


def serialize_pk(obj):
if isinstance(obj.pk, tuple):
return json.dumps(
[
field.get_db_prep_value(value, connection)
for field, value in zip(obj._meta.pk, obj.pk)
],
cls=DjangoJSONEncoder,
)
else:
return obj.pk


def deserialize_pk(pk):
if isinstance(pk, str) and pk.startswith("[") and pk.endswith("]"):
try:
return tuple(json.loads(pk))
except json.JSONDecodeError:
return pk
else:
return pk


class GenericForeignKey(FieldCacheMixin, Field):
"""
Provide a generic many-to-one relation through the ``content_type`` and
Expand Down Expand Up @@ -195,11 +220,12 @@ def get_prefetch_querysets(self, instances, querysets=None):
if ct_id is not None:
fk_val = getattr(instance, self.fk_field)
if fk_val is not None:
fk_dict[ct_id].add(fk_val)
fk_dict[ct_id].add(deserialize_pk(fk_val))
instance_dict[ct_id] = instance

ret_val = []
for ct_id, fkeys in fk_dict.items():
fkeys = list(fkeys)
if ct_id in custom_queryset_dict:
# Return values from the custom queryset, if provided.
ret_val.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys))
Expand All @@ -225,7 +251,7 @@ def gfk_key(obj):

return (
ret_val,
lambda obj: (obj.pk, obj.__class__),
lambda obj: (serialize_pk(obj), obj.__class__),
gfk_key,
True,
self.name,
Expand All @@ -242,7 +268,8 @@ def __get__(self, instance, cls=None):
# use ContentType.objects.get_for_id(), which has a global cache.
f = self.model._meta.get_field(self.ct_field)
ct_id = getattr(instance, f.attname, None)
pk_val = getattr(instance, self.fk_field)
fk_val = getattr(instance, self.fk_field)
pk_val = deserialize_pk(fk_val)

rel_obj = self.get_cached_value(instance, default=None)
if rel_obj is None and self.is_cached(instance):
Expand All @@ -251,7 +278,9 @@ def __get__(self, instance, cls=None):
ct_match = (
ct_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id
)
pk_match = ct_match and rel_obj._meta.pk.to_python(pk_val) == rel_obj.pk
pk_match = ct_match and (
rel_obj._meta.pk.to_python(pk_val)
) == deserialize_pk(serialize_pk(rel_obj))
if pk_match:
return rel_obj
else:
Expand All @@ -272,7 +301,7 @@ def __set__(self, instance, value):
fk = None
if value is not None:
ct = self.get_content_type(obj=value)
fk = value.pk
fk = serialize_pk(value)

setattr(instance, self.ct_field, ct)
setattr(instance, self.fk_field, fk)
Expand Down Expand Up @@ -541,7 +570,8 @@ def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
% self.content_type_field_name: ContentType.objects.db_manager(using)
.get_for_model(self.model, for_concrete_model=self.for_concrete_model)
.pk,
"%s__in" % self.object_id_field_name: [obj.pk for obj in objs],
"%s__in"
% self.object_id_field_name: [serialize_pk(obj) for obj in objs],
}
)

Expand Down Expand Up @@ -589,7 +619,7 @@ def __init__(self, instance=None):
self.content_type_field_name = rel.field.content_type_field_name
self.object_id_field_name = rel.field.object_id_field_name
self.prefetch_cache_name = rel.field.attname
self.pk_val = instance.pk
self.pk_val = serialize_pk(instance)

self.core_filters = {
"%s__pk" % self.content_type_field_name: self.content_type.id,
Expand Down
32 changes: 29 additions & 3 deletions django/db/backends/base/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from django.conf import settings
from django.db import NotSupportedError, transaction
from django.db.backends import utils
from django.db.models.expressions import Col
from django.db.models.expressions import Col, ColPairs
from django.utils import timezone
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.encoding import force_str
Expand Down Expand Up @@ -800,7 +800,33 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel
return ""

def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
lhs_expr = Col(lhs_table, lhs_field)
rhs_expr = Col(rhs_table, rhs_field)
lhs_expr = lhs_field.get_col(lhs_table)
rhs_expr = rhs_field.get_col(rhs_table)

if (
isinstance(lhs_expr, Col)
and isinstance(rhs_expr, ColPairs)
and len(rhs_expr) > 1
):
return self.prepare_join_on_json_clause(lhs_expr, rhs_expr)
elif (
isinstance(lhs_expr, ColPairs)
and isinstance(rhs_expr, Col)
and len(lhs_expr) > 1
):
return self.prepare_join_on_json_clause(rhs_expr, lhs_expr)

return lhs_expr, rhs_expr

def prepare_join_on_json_clause(self, lhs_expr, rhs_expr):
"""
If a generic foreign key refers to a composite primary key, the object_id is in
a JSON format backed by a CharField / TextField field.
To support joining on generic foreign keys, use the backend-specific JSON
functions.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a "
"prepare_join_on_json_clause() method."
)
5 changes: 4 additions & 1 deletion django/db/backends/mysql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models import Exists, ExpressionWrapper, Func, Lookup
from django.db.models.constants import OnConflict
from django.utils import timezone
from django.utils.encoding import force_str
Expand Down Expand Up @@ -456,3 +456,6 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel
update_fields,
unique_fields,
)

def prepare_join_on_json_clause(self, lhs_expr, rhs_expr):
return lhs_expr, Func(rhs_expr, function="JSON_ARRAY")
7 changes: 6 additions & 1 deletion django/db/backends/oracle/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup
from django.db.models.expressions import RawSQL
from django.db.models.expressions import Func, RawSQL
from django.db.models.sql.where import WhereNode
from django.utils import timezone
from django.utils.encoding import force_bytes, force_str
Expand Down Expand Up @@ -729,3 +729,8 @@ def conditional_expression_supported_in_where_clause(self, expression):
if isinstance(expression, RawSQL) and expression.conditional:
return True
return False

def prepare_join_on_json_clause(self, lhs_expr, rhs_expr):
return Func(lhs_expr, function="JSON_SERIALIZE"), Func(
rhs_expr, function="JSON_ARRAY"
)
5 changes: 5 additions & 0 deletions django/db/backends/postgresql/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.db.models import Func


class JSONBBuildArray(Func):
function = "JSONB_BUILD_ARRAY"
9 changes: 8 additions & 1 deletion django/db/backends/postgresql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.postgresql.functions import JSONBBuildArray
from django.db.backends.postgresql.psycopg_any import (
Inet,
Jsonb,
Expand All @@ -11,6 +12,7 @@
mogrify,
)
from django.db.backends.utils import split_tzname_delta
from django.db.models import JSONField
from django.db.models.constants import OnConflict
from django.db.models.functions import Cast
from django.utils.regex_helper import _lazy_re_compile
Expand Down Expand Up @@ -407,7 +409,12 @@ def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
lhs_table, lhs_field, rhs_table, rhs_field
)

if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
if not isinstance(rhs_expr, JSONBBuildArray) and lhs_field.db_type(
self.connection
) != rhs_field.db_type(self.connection):
rhs_expr = Cast(rhs_expr, lhs_field)

return lhs_expr, rhs_expr

def prepare_join_on_json_clause(self, lhs_expr, rhs_expr):
return Cast(lhs_expr, JSONField()), JSONBBuildArray(rhs_expr)
5 changes: 4 additions & 1 deletion django/db/backends/sqlite3/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from django.db import DatabaseError, NotSupportedError, models
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models.constants import OnConflict
from django.db.models.expressions import Col
from django.db.models.expressions import Col, Func
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.functional import cached_property
Expand Down Expand Up @@ -431,3 +431,6 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel

def force_group_by(self):
return ["GROUP BY TRUE"] if Database.sqlite_version_info < (3, 39) else []

def prepare_join_on_json_clause(self, lhs_expr, rhs_expr):
return Func(lhs_expr, function="JSON"), Func(rhs_expr, function="JSON_ARRAY")
11 changes: 6 additions & 5 deletions django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,12 @@ def __len__(self):
def __iter__(self):
return iter(self.get_cols())

def __repr__(self):
return "{}({})".format(
self.__class__.__name__,
", ".join(repr(col) for col in self.get_cols()),
)

def get_cols(self):
return [
Col(self.alias, target, source)
Expand Down Expand Up @@ -1374,11 +1380,6 @@ def relabeled_clone(self, relabels):
def resolve_expression(self, *args, **kwargs):
return self

@staticmethod
def db_converter(value, *_):
assert isinstance(value, list)
return (tuple(value),)


class Ref(Expression):
"""
Expand Down
3 changes: 2 additions & 1 deletion tests/backends/base/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def test_datetime_extract_sql(self):
def test_prepare_join_on_clause(self):
author_table = Author._meta.db_table
author_id_field = Author._meta.get_field("id")
author_name_field = Author._meta.get_field("name")
book_table = Book._meta.db_table
book_fk_field = Book._meta.get_field("author")
lhs_expr, rhs_expr = self.ops.prepare_join_on_clause(
Expand All @@ -167,7 +168,7 @@ def test_prepare_join_on_clause(self):
book_fk_field,
)
self.assertEqual(lhs_expr, Col(author_table, author_id_field))
self.assertEqual(rhs_expr, Col(book_table, book_fk_field))
self.assertEqual(rhs_expr, Col(book_table, book_fk_field, author_name_field))


class DatabaseOperationTests(TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tests/composite_pk/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .tenant import Comment, Post, Tenant, Token, User
from .tenant import CharTag, Comment, Post, Tenant, Token, User

__all__ = [
"Comment",
"Post",
"CharTag",
"Tenant",
"Token",
"User",
Expand Down
12 changes: 12 additions & 0 deletions tests/composite_pk/models/tenant.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db import models


Expand Down Expand Up @@ -48,3 +50,13 @@ class Post(models.Model):
pk = models.CompositePrimaryKey("tenant_id", "id")
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
id = models.UUIDField()
chartags = GenericRelation("CharTag", related_query_name="post")


class CharTag(models.Model):
name = models.CharField(max_length=5)
content_type = models.ForeignKey(
ContentType, on_delete=models.CASCADE, related_name="composite_pk_chartags"
)
object_id = models.CharField(max_length=50)
content_object = GenericForeignKey("content_type", "object_id")
Loading

0 comments on commit 28ae360

Please sign in to comment.