Skip to content

Commit

Permalink
Refs django#373 - Add composite generic foreign key support
Browse files Browse the repository at this point in the history
  • Loading branch information
csirmazbendeguz committed Aug 12, 2024
1 parent 52f819e commit de85dc0
Show file tree
Hide file tree
Showing 14 changed files with 273 additions and 10 deletions.
37 changes: 32 additions & 5 deletions django/contrib/contenttypes/fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import itertools
import json
import warnings
from collections import defaultdict

Expand All @@ -8,7 +9,8 @@
from django.contrib.contenttypes.models import ContentType
from django.core import checks
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, models, router, transaction
from django.core.serializers.json import DjangoJSONEncoder
from django.db import DEFAULT_DB_ALIAS, connection, models, router, transaction
from django.db.models import DO_NOTHING, ForeignObject, ForeignObjectRel
from django.db.models.base import ModelBase, make_foreign_order_accessors
from django.db.models.fields import Field
Expand All @@ -25,6 +27,29 @@
from django.utils.functional import cached_property


def serialize_pk(obj):
if isinstance(obj.pk, tuple):
return json.dumps(
[
field.get_db_prep_value(value, connection)
for field, value in zip(obj._meta.pk, obj.pk)
],
cls=DjangoJSONEncoder,
)
else:
return obj.pk


def deserialize_pk(pk):
if isinstance(pk, str) and pk.startswith("[") and pk.endswith("]"):
try:
return json.loads(pk)
except json.JSONDecodeError:
return pk
else:
return pk


class GenericForeignKey(FieldCacheMixin, Field):
"""
Provide a generic many-to-one relation through the ``content_type`` and
Expand Down Expand Up @@ -242,7 +267,8 @@ def __get__(self, instance, cls=None):
# use ContentType.objects.get_for_id(), which has a global cache.
f = self.model._meta.get_field(self.ct_field)
ct_id = getattr(instance, f.attname, None)
pk_val = getattr(instance, self.fk_field)
fk_val = getattr(instance, self.fk_field)
pk_val = deserialize_pk(fk_val)

rel_obj = self.get_cached_value(instance, default=None)
if rel_obj is None and self.is_cached(instance):
Expand Down Expand Up @@ -272,7 +298,7 @@ def __set__(self, instance, value):
fk = None
if value is not None:
ct = self.get_content_type(obj=value)
fk = value.pk
fk = serialize_pk(value)

setattr(instance, self.ct_field, ct)
setattr(instance, self.fk_field, fk)
Expand Down Expand Up @@ -541,7 +567,8 @@ def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS):
% self.content_type_field_name: ContentType.objects.db_manager(using)
.get_for_model(self.model, for_concrete_model=self.for_concrete_model)
.pk,
"%s__in" % self.object_id_field_name: [obj.pk for obj in objs],
"%s__in"
% self.object_id_field_name: [serialize_pk(obj) for obj in objs],
}
)

Expand Down Expand Up @@ -589,7 +616,7 @@ def __init__(self, instance=None):
self.content_type_field_name = rel.field.content_type_field_name
self.object_id_field_name = rel.field.object_id_field_name
self.prefetch_cache_name = rel.field.attname
self.pk_val = instance.pk
self.pk_val = serialize_pk(instance)

self.core_filters = {
"%s__pk" % self.content_type_field_name: self.content_type.id,
Expand Down
26 changes: 23 additions & 3 deletions django/db/backends/base/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from django.conf import settings
from django.db import NotSupportedError, transaction
from django.db.backends import utils
from django.db.models.expressions import Col
from django.db.models.expressions import Col, ColPairs
from django.utils import timezone
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.encoding import force_str
Expand Down Expand Up @@ -800,7 +800,27 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel
return ""

def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
lhs_expr = Col(lhs_table, lhs_field)
rhs_expr = Col(rhs_table, rhs_field)
lhs_expr = lhs_field.get_col(lhs_table)
rhs_expr = rhs_field.get_col(rhs_table)

if (
isinstance(lhs_expr, Col)
and isinstance(rhs_expr, ColPairs)
and len(rhs_expr) > 1
):
return self.prepare_json_join(lhs_expr, rhs_expr)

return lhs_expr, rhs_expr

def prepare_json_join(self, lhs_expr, rhs_expr):
"""
If a generic foreign key refers to a composite primary key, the object_id is in
a JSON format and is backed by a CharField / TextField / JSONField.
To support joining on generic foreign keys, use the backend-specific JSON
functions.
"""
raise NotImplementedError(
"subclasses of BaseDatabaseOperations may require a prepare_json_join() "
"method."
)
5 changes: 5 additions & 0 deletions django/db/backends/mysql/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.db.models import expressions


class JSONArray(expressions.Func):
function = "JSON_ARRAY"
4 changes: 4 additions & 0 deletions django/db/backends/mysql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.mysql.functions import JSONArray
from django.db.backends.utils import split_tzname_delta
from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models.constants import OnConflict
Expand Down Expand Up @@ -456,3 +457,6 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel
update_fields,
unique_fields,
)

def prepare_json_join(self, lhs_expr, rhs_expr):
return lhs_expr, JSONArray(rhs_expr)
8 changes: 8 additions & 0 deletions django/db/backends/oracle/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,11 @@ def __init__(self, expression, *, output_field=None, **extra):
super().__init__(
expression, output_field=output_field or DurationField(), **extra
)


class JSONSerialize(Func):
function = "JSON_SERIALIZE"


class JSONArray(Func):
function = "JSON_ARRAY"
4 changes: 4 additions & 0 deletions django/db/backends/oracle/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from django.utils.regex_helper import _lazy_re_compile

from .base import Database
from .functions import JSONArray, JSONSerialize
from .utils import BulkInsertMapper, InsertVar, Oracle_datetime


Expand Down Expand Up @@ -729,3 +730,6 @@ def conditional_expression_supported_in_where_clause(self, expression):
if isinstance(expression, RawSQL) and expression.conditional:
return True
return False

def prepare_json_join(self, lhs_expr, rhs_expr):
return JSONSerialize(lhs_expr), JSONArray(rhs_expr)
5 changes: 5 additions & 0 deletions django/db/backends/postgresql/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django.db.models import expressions


class JSONBBuildArray(expressions.Func):
function = "JSONB_BUILD_ARRAY"
9 changes: 8 additions & 1 deletion django/db/backends/postgresql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.postgresql.functions import JSONBBuildArray
from django.db.backends.postgresql.psycopg_any import (
Inet,
Jsonb,
Expand All @@ -12,6 +13,7 @@
)
from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict
from django.db.models.expressions import ColPairs
from django.db.models.functions import Cast
from django.utils.regex_helper import _lazy_re_compile

Expand Down Expand Up @@ -407,7 +409,12 @@ def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
lhs_table, lhs_field, rhs_table, rhs_field
)

if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
if not isinstance(rhs_expr, ColPairs) and lhs_field.db_type(
self.connection
) != rhs_field.db_type(self.connection):
rhs_expr = Cast(rhs_expr, lhs_field)

return lhs_expr, rhs_expr

def prepare_json_join(self, lhs_expr, rhs_expr):
return lhs_expr, JSONBBuildArray(rhs_expr)
9 changes: 9 additions & 0 deletions django/db/backends/sqlite3/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from django.db.models import expressions


class JSON(expressions.Func):
function = "JSON"


class JSONArray(expressions.Func):
function = "JSON_ARRAY"
4 changes: 4 additions & 0 deletions django/db/backends/sqlite3/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from django.utils.functional import cached_property

from .base import Database
from .functions import JSON, JSONArray


class DatabaseOperations(BaseDatabaseOperations):
Expand Down Expand Up @@ -431,3 +432,6 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel

def force_group_by(self):
return ["GROUP BY TRUE"] if Database.sqlite_version_info < (3, 39) else []

def prepare_json_join(self, lhs_expr, rhs_expr):
return JSON(lhs_expr), JSONArray(rhs_expr)
6 changes: 6 additions & 0 deletions django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,6 +1338,12 @@ def __len__(self):
def __iter__(self):
return iter(self.get_cols())

def __repr__(self):
return "{}({})".format(
self.__class__.__name__,
", ".join(repr(col) for col in self.get_cols()),
)

def get_cols(self):
return [
Col(self.alias, target, source)
Expand Down
3 changes: 2 additions & 1 deletion tests/composite_pk/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .tenant import Comment, Post, Tenant, Token, User
from .tenant import CharTag, Comment, Post, Tenant, Token, User

__all__ = [
"Comment",
"Post",
"CharTag",
"Tenant",
"Token",
"User",
Expand Down
12 changes: 12 additions & 0 deletions tests/composite_pk/models/tenant.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db import models


Expand Down Expand Up @@ -48,3 +50,13 @@ class Post(models.Model):
pk = models.CompositePrimaryKey("tenant_id", "id")
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
id = models.UUIDField()
chartags = GenericRelation("CharTag", related_query_name="post")


class CharTag(models.Model):
name = models.CharField(max_length=5)
content_type = models.ForeignKey(
ContentType, on_delete=models.CASCADE, related_name="composite_pk_chartags"
)
object_id = models.CharField(max_length=50)
content_object = GenericForeignKey("content_type", "object_id")
Loading

0 comments on commit de85dc0

Please sign in to comment.