From de85dc02bca4898e7808eec2c58e2250c4e32256 Mon Sep 17 00:00:00 2001 From: Bendeguz Csirmaz Date: Tue, 6 Aug 2024 23:21:27 +0800 Subject: [PATCH] Refs #373 - Add composite generic foreign key support --- django/contrib/contenttypes/fields.py | 37 ++++- django/db/backends/base/operations.py | 26 +++- django/db/backends/mysql/functions.py | 5 + django/db/backends/mysql/operations.py | 4 + django/db/backends/oracle/functions.py | 8 ++ django/db/backends/oracle/operations.py | 4 + django/db/backends/postgresql/functions.py | 5 + django/db/backends/postgresql/operations.py | 9 +- django/db/backends/sqlite3/functions.py | 9 ++ django/db/backends/sqlite3/operations.py | 4 + django/db/models/expressions.py | 6 + tests/composite_pk/models/__init__.py | 3 +- tests/composite_pk/models/tenant.py | 12 ++ tests/composite_pk/test_generic.py | 151 ++++++++++++++++++++ 14 files changed, 273 insertions(+), 10 deletions(-) create mode 100644 django/db/backends/mysql/functions.py create mode 100644 django/db/backends/postgresql/functions.py create mode 100644 django/db/backends/sqlite3/functions.py create mode 100644 tests/composite_pk/test_generic.py diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index a3e87f6ed45ee..7557141bfe4c2 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -1,5 +1,6 @@ import functools import itertools +import json import warnings from collections import defaultdict @@ -8,7 +9,8 @@ from django.contrib.contenttypes.models import ContentType from django.core import checks from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist -from django.db import DEFAULT_DB_ALIAS, models, router, transaction +from django.core.serializers.json import DjangoJSONEncoder +from django.db import DEFAULT_DB_ALIAS, connection, models, router, transaction from django.db.models import DO_NOTHING, ForeignObject, ForeignObjectRel from django.db.models.base import ModelBase, make_foreign_order_accessors from django.db.models.fields import Field @@ -25,6 +27,29 @@ from django.utils.functional import cached_property +def serialize_pk(obj): + if isinstance(obj.pk, tuple): + return json.dumps( + [ + field.get_db_prep_value(value, connection) + for field, value in zip(obj._meta.pk, obj.pk) + ], + cls=DjangoJSONEncoder, + ) + else: + return obj.pk + + +def deserialize_pk(pk): + if isinstance(pk, str) and pk.startswith("[") and pk.endswith("]"): + try: + return json.loads(pk) + except json.JSONDecodeError: + return pk + else: + return pk + + class GenericForeignKey(FieldCacheMixin, Field): """ Provide a generic many-to-one relation through the ``content_type`` and @@ -242,7 +267,8 @@ def __get__(self, instance, cls=None): # use ContentType.objects.get_for_id(), which has a global cache. f = self.model._meta.get_field(self.ct_field) ct_id = getattr(instance, f.attname, None) - pk_val = getattr(instance, self.fk_field) + fk_val = getattr(instance, self.fk_field) + pk_val = deserialize_pk(fk_val) rel_obj = self.get_cached_value(instance, default=None) if rel_obj is None and self.is_cached(instance): @@ -272,7 +298,7 @@ def __set__(self, instance, value): fk = None if value is not None: ct = self.get_content_type(obj=value) - fk = value.pk + fk = serialize_pk(value) setattr(instance, self.ct_field, ct) setattr(instance, self.fk_field, fk) @@ -541,7 +567,8 @@ def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS): % self.content_type_field_name: ContentType.objects.db_manager(using) .get_for_model(self.model, for_concrete_model=self.for_concrete_model) .pk, - "%s__in" % self.object_id_field_name: [obj.pk for obj in objs], + "%s__in" + % self.object_id_field_name: [serialize_pk(obj) for obj in objs], } ) @@ -589,7 +616,7 @@ def __init__(self, instance=None): self.content_type_field_name = rel.field.content_type_field_name self.object_id_field_name = rel.field.object_id_field_name self.prefetch_cache_name = rel.field.attname - self.pk_val = instance.pk + self.pk_val = serialize_pk(instance) self.core_filters = { "%s__pk" % self.content_type_field_name: self.content_type.id, diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 889e4d87b444f..f50231adccd18 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -9,7 +9,7 @@ from django.conf import settings from django.db import NotSupportedError, transaction from django.db.backends import utils -from django.db.models.expressions import Col +from django.db.models.expressions import Col, ColPairs from django.utils import timezone from django.utils.deprecation import RemovedInDjango60Warning from django.utils.encoding import force_str @@ -800,7 +800,27 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel return "" def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field): - lhs_expr = Col(lhs_table, lhs_field) - rhs_expr = Col(rhs_table, rhs_field) + lhs_expr = lhs_field.get_col(lhs_table) + rhs_expr = rhs_field.get_col(rhs_table) + + if ( + isinstance(lhs_expr, Col) + and isinstance(rhs_expr, ColPairs) + and len(rhs_expr) > 1 + ): + return self.prepare_json_join(lhs_expr, rhs_expr) return lhs_expr, rhs_expr + + def prepare_json_join(self, lhs_expr, rhs_expr): + """ + If a generic foreign key refers to a composite primary key, the object_id is in + a JSON format and is backed by a CharField / TextField / JSONField. + + To support joining on generic foreign keys, use the backend-specific JSON + functions. + """ + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a prepare_json_join() " + "method." + ) diff --git a/django/db/backends/mysql/functions.py b/django/db/backends/mysql/functions.py new file mode 100644 index 0000000000000..acffb559b8722 --- /dev/null +++ b/django/db/backends/mysql/functions.py @@ -0,0 +1,5 @@ +from django.db.models import expressions + + +class JSONArray(expressions.Func): + function = "JSON_ARRAY" diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index 9741e6a985fc0..267a4c8fb0893 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -2,6 +2,7 @@ from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations +from django.db.backends.mysql.functions import JSONArray from django.db.backends.utils import split_tzname_delta from django.db.models import Exists, ExpressionWrapper, Lookup from django.db.models.constants import OnConflict @@ -456,3 +457,6 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel update_fields, unique_fields, ) + + def prepare_json_join(self, lhs_expr, rhs_expr): + return lhs_expr, JSONArray(rhs_expr) diff --git a/django/db/backends/oracle/functions.py b/django/db/backends/oracle/functions.py index 936cc9e73f19d..984f4dacb0161 100644 --- a/django/db/backends/oracle/functions.py +++ b/django/db/backends/oracle/functions.py @@ -24,3 +24,11 @@ def __init__(self, expression, *, output_field=None, **extra): super().__init__( expression, output_field=output_field or DurationField(), **extra ) + + +class JSONSerialize(Func): + function = "JSON_SERIALIZE" + + +class JSONArray(Func): + function = "JSON_ARRAY" diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 86340bbf4ac1e..fa453e719d36f 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -15,6 +15,7 @@ from django.utils.regex_helper import _lazy_re_compile from .base import Database +from .functions import JSONArray, JSONSerialize from .utils import BulkInsertMapper, InsertVar, Oracle_datetime @@ -729,3 +730,6 @@ def conditional_expression_supported_in_where_clause(self, expression): if isinstance(expression, RawSQL) and expression.conditional: return True return False + + def prepare_json_join(self, lhs_expr, rhs_expr): + return JSONSerialize(lhs_expr), JSONArray(rhs_expr) diff --git a/django/db/backends/postgresql/functions.py b/django/db/backends/postgresql/functions.py new file mode 100644 index 0000000000000..59a7a9891874f --- /dev/null +++ b/django/db/backends/postgresql/functions.py @@ -0,0 +1,5 @@ +from django.db.models import expressions + + +class JSONBBuildArray(expressions.Func): + function = "JSONB_BUILD_ARRAY" diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 4b179ca83f3e6..1b38ae63c998d 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -3,6 +3,7 @@ from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations +from django.db.backends.postgresql.functions import JSONBBuildArray from django.db.backends.postgresql.psycopg_any import ( Inet, Jsonb, @@ -12,6 +13,7 @@ ) from django.db.backends.utils import split_tzname_delta from django.db.models.constants import OnConflict +from django.db.models.expressions import ColPairs from django.db.models.functions import Cast from django.utils.regex_helper import _lazy_re_compile @@ -407,7 +409,12 @@ def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field): lhs_table, lhs_field, rhs_table, rhs_field ) - if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection): + if not isinstance(rhs_expr, ColPairs) and lhs_field.db_type( + self.connection + ) != rhs_field.db_type(self.connection): rhs_expr = Cast(rhs_expr, lhs_field) return lhs_expr, rhs_expr + + def prepare_json_join(self, lhs_expr, rhs_expr): + return lhs_expr, JSONBBuildArray(rhs_expr) diff --git a/django/db/backends/sqlite3/functions.py b/django/db/backends/sqlite3/functions.py new file mode 100644 index 0000000000000..00ff7b5e59645 --- /dev/null +++ b/django/db/backends/sqlite3/functions.py @@ -0,0 +1,9 @@ +from django.db.models import expressions + + +class JSON(expressions.Func): + function = "JSON" + + +class JSONArray(expressions.Func): + function = "JSON_ARRAY" diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 0078cc077a97e..a70b19506b146 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -15,6 +15,7 @@ from django.utils.functional import cached_property from .base import Database +from .functions import JSON, JSONArray class DatabaseOperations(BaseDatabaseOperations): @@ -431,3 +432,6 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel def force_group_by(self): return ["GROUP BY TRUE"] if Database.sqlite_version_info < (3, 39) else [] + + def prepare_json_join(self, lhs_expr, rhs_expr): + return JSON(lhs_expr), JSONArray(rhs_expr) diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 89b7838e6ce5e..7cb4b05a77be8 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1338,6 +1338,12 @@ def __len__(self): def __iter__(self): return iter(self.get_cols()) + def __repr__(self): + return "{}({})".format( + self.__class__.__name__, + ", ".join(repr(col) for col in self.get_cols()), + ) + def get_cols(self): return [ Col(self.alias, target, source) diff --git a/tests/composite_pk/models/__init__.py b/tests/composite_pk/models/__init__.py index 35c394371696d..85c73ad350837 100644 --- a/tests/composite_pk/models/__init__.py +++ b/tests/composite_pk/models/__init__.py @@ -1,8 +1,9 @@ -from .tenant import Comment, Post, Tenant, Token, User +from .tenant import CharTag, Comment, Post, Tenant, Token, User __all__ = [ "Comment", "Post", + "CharTag", "Tenant", "Token", "User", diff --git a/tests/composite_pk/models/tenant.py b/tests/composite_pk/models/tenant.py index 01f6e0aea36ad..230bc4d32b7c7 100644 --- a/tests/composite_pk/models/tenant.py +++ b/tests/composite_pk/models/tenant.py @@ -1,3 +1,5 @@ +from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation +from django.contrib.contenttypes.models import ContentType from django.db import models @@ -48,3 +50,13 @@ class Post(models.Model): pk = models.CompositePrimaryKey("tenant_id", "id") tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) id = models.UUIDField() + chartags = GenericRelation("CharTag", related_query_name="post") + + +class CharTag(models.Model): + name = models.CharField(max_length=5) + content_type = models.ForeignKey( + ContentType, on_delete=models.CASCADE, related_name="composite_pk_chartags" + ) + object_id = models.CharField(max_length=50) + content_object = GenericForeignKey("content_type", "object_id") diff --git a/tests/composite_pk/test_generic.py b/tests/composite_pk/test_generic.py new file mode 100644 index 0000000000000..96cfe68da4974 --- /dev/null +++ b/tests/composite_pk/test_generic.py @@ -0,0 +1,151 @@ +from uuid import UUID + +from django.contrib.contenttypes.models import ContentType +from django.db import connection +from django.test import TestCase + +from .models import CharTag, Comment, Post, Tenant, User + + +class CompositePKGenericTests(TestCase): + POST_1_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" + + @classmethod + def setUpTestData(cls): + cls.tenant_1 = Tenant.objects.create() + cls.tenant_2 = Tenant.objects.create() + cls.user_1 = User.objects.create(tenant=cls.tenant_1, id=1) + cls.user_2 = User.objects.create(tenant=cls.tenant_1, id=2) + cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1) + cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1) + cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=UUID(cls.POST_1_ID)) + cls.chartag_1 = CharTag.objects.create(name="a", content_object=cls.comment_1) + cls.chartag_2 = CharTag.objects.create(name="b", content_object=cls.post_1) + cls.comment_ct = ContentType.objects.get_for_model(Comment) + cls.post_ct = ContentType.objects.get_for_model(Post) + post_1_id = cls.POST_1_ID + if not connection.features.has_native_uuid_field: + post_1_id = cls.POST_1_ID.replace("-", "") + cls.post_1_fk = f'[{cls.tenant_1.id}, "{post_1_id}"]' + cls.comment_1_fk = f"[{cls.tenant_1.id}, {cls.comment_1.id}]" + + def test_fields(self): + tag_1 = CharTag.objects.get(pk=self.chartag_1.pk) + self.assertEqual(tag_1.content_type, self.comment_ct) + self.assertEqual(tag_1.object_id, self.comment_1_fk) + self.assertEqual(tag_1.content_object, self.comment_1) + + tag_2 = CharTag.objects.get(pk=self.chartag_2.pk) + self.assertEqual(tag_2.content_type, self.post_ct) + self.assertEqual(tag_2.object_id, self.post_1_fk) + self.assertEqual(tag_2.content_object, self.post_1) + + post_1 = Post.objects.get(pk=self.post_1.pk) + self.assertSequenceEqual(post_1.chartags.all(), (self.chartag_2,)) + + def test_cascade_delete_if_generic_relation(self): + Post.objects.get(pk=self.post_1.pk).delete() + self.assertFalse(CharTag.objects.filter(pk=self.chartag_2.pk).exists()) + + def test_no_cascade_delete_if_no_generic_relation(self): + Comment.objects.get(pk=self.comment_1.pk).delete() + tag_1 = CharTag.objects.get(pk=self.chartag_1.pk) + self.assertIsNone(tag_1.content_object) + + def test_tags_clear(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.chartags.clear() + self.assertEqual(post_1.chartags.count(), 0) + self.assertFalse(CharTag.objects.filter(pk=self.chartag_2.pk).exists()) + + def test_tags_remove(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.chartags.remove(self.chartag_2) + self.assertEqual(post_1.chartags.count(), 0) + self.assertFalse(CharTag.objects.filter(pk=self.chartag_2.pk).exists()) + + def test_tags_create(self): + tag_count = CharTag.objects.count() + + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.chartags.create(name="c") + self.assertEqual(post_1.chartags.count(), 2) + self.assertEqual(CharTag.objects.count(), tag_count + 1) + + tag_3 = CharTag.objects.get(name="c") + self.assertEqual(tag_3.content_type, self.post_ct) + self.assertEqual(tag_3.object_id, self.post_1_fk) + self.assertEqual(tag_3.content_object, post_1) + + def test_tags_add(self): + tag_count = CharTag.objects.count() + post_1 = Post.objects.get(pk=self.post_1.pk) + + tag_3 = CharTag(name="c") + post_1.chartags.add(tag_3, bulk=False) + self.assertEqual(post_1.chartags.count(), 2) + self.assertEqual(CharTag.objects.count(), tag_count + 1) + + tag_3 = CharTag.objects.get(name="c") + self.assertEqual(tag_3.content_type, self.post_ct) + self.assertEqual(tag_3.object_id, self.post_1_fk) + self.assertEqual(tag_3.content_object, post_1) + + tag_4 = CharTag.objects.create(name="d", content_object=self.comment_2) + post_1.chartags.add(tag_4) + self.assertEqual(post_1.chartags.count(), 3) + self.assertEqual(CharTag.objects.count(), tag_count + 2) + + tag_4 = CharTag.objects.get(name="d") + self.assertEqual(tag_4.content_type, self.post_ct) + self.assertEqual(tag_4.object_id, self.post_1_fk) + self.assertEqual(tag_4.content_object, post_1) + + def test_tags_set(self): + tag_count = CharTag.objects.count() + tag_1 = CharTag.objects.get(name="a") + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.chartags.set([tag_1]) + self.assertEqual(post_1.chartags.count(), 1) + self.assertEqual(CharTag.objects.count(), tag_count - 1) + self.assertFalse(CharTag.objects.filter(pk=self.chartag_2.pk).exists()) + + def test_tags_get_or_create(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + + tag_2, created = post_1.chartags.get_or_create(name="b") + self.assertFalse(created) + self.assertEqual(tag_2.pk, self.chartag_2.pk) + self.assertEqual(tag_2.content_type, self.post_ct) + self.assertEqual(tag_2.object_id, self.post_1_fk) + self.assertEqual(tag_2.content_object, post_1) + + tag_3, created = post_1.chartags.get_or_create(name="c") + self.assertTrue(created) + self.assertEqual(tag_3.content_type, self.post_ct) + self.assertEqual(tag_3.object_id, self.post_1_fk) + self.assertEqual(tag_3.content_object, post_1) + + def test_tags_update_or_create(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + + tag_2, created = post_1.chartags.update_or_create( + name="b", defaults={"name": "b2"} + ) + self.assertFalse(created) + self.assertEqual(tag_2.pk, self.chartag_2.pk) + self.assertEqual(tag_2.name, "b2") + self.assertEqual(tag_2.content_type, self.post_ct) + self.assertEqual(tag_2.object_id, self.post_1_fk) + self.assertEqual(tag_2.content_object, post_1) + + tag_3, created = post_1.chartags.update_or_create(name="c") + self.assertTrue(created) + self.assertEqual(tag_3.content_type, self.post_ct) + self.assertEqual(tag_3.object_id, self.post_1_fk) + self.assertEqual(tag_3.content_object, post_1) + + def test_filter_by_related_query_name(self): + self.assertSequenceEqual( + CharTag.objects.filter(post__id=self.post_1.id), (self.chartag_2,) + )