diff --git a/django/contrib/contenttypes/fields.py b/django/contrib/contenttypes/fields.py index a3e87f6ed45e..3757276f8670 100644 --- a/django/contrib/contenttypes/fields.py +++ b/django/contrib/contenttypes/fields.py @@ -1,5 +1,6 @@ import functools import itertools +import json import warnings from collections import defaultdict @@ -8,7 +9,8 @@ from django.contrib.contenttypes.models import ContentType from django.core import checks from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist -from django.db import DEFAULT_DB_ALIAS, models, router, transaction +from django.core.serializers.json import DjangoJSONEncoder +from django.db import DEFAULT_DB_ALIAS, connection, models, router, transaction from django.db.models import DO_NOTHING, ForeignObject, ForeignObjectRel from django.db.models.base import ModelBase, make_foreign_order_accessors from django.db.models.fields import Field @@ -25,6 +27,29 @@ from django.utils.functional import cached_property +def serialize_pk(obj): + if isinstance(obj.pk, tuple): + return json.dumps( + [ + field.get_db_prep_value(value, connection) + for field, value in zip(obj._meta.pk, obj.pk) + ], + cls=DjangoJSONEncoder, + ) + else: + return obj.pk + + +def deserialize_pk(pk): + if isinstance(pk, str) and pk.startswith("[") and pk.endswith("]"): + try: + return tuple(json.loads(pk)) + except json.JSONDecodeError: + return pk + else: + return pk + + class GenericForeignKey(FieldCacheMixin, Field): """ Provide a generic many-to-one relation through the ``content_type`` and @@ -195,11 +220,12 @@ def get_prefetch_querysets(self, instances, querysets=None): if ct_id is not None: fk_val = getattr(instance, self.fk_field) if fk_val is not None: - fk_dict[ct_id].add(fk_val) + fk_dict[ct_id].add(deserialize_pk(fk_val)) instance_dict[ct_id] = instance ret_val = [] for ct_id, fkeys in fk_dict.items(): + fkeys = list(fkeys) if ct_id in custom_queryset_dict: # Return values from the custom queryset, if provided. ret_val.extend(custom_queryset_dict[ct_id].filter(pk__in=fkeys)) @@ -225,7 +251,7 @@ def gfk_key(obj): return ( ret_val, - lambda obj: (obj.pk, obj.__class__), + lambda obj: (serialize_pk(obj), obj.__class__), gfk_key, True, self.name, @@ -242,7 +268,8 @@ def __get__(self, instance, cls=None): # use ContentType.objects.get_for_id(), which has a global cache. f = self.model._meta.get_field(self.ct_field) ct_id = getattr(instance, f.attname, None) - pk_val = getattr(instance, self.fk_field) + fk_val = getattr(instance, self.fk_field) + pk_val = deserialize_pk(fk_val) rel_obj = self.get_cached_value(instance, default=None) if rel_obj is None and self.is_cached(instance): @@ -251,7 +278,9 @@ def __get__(self, instance, cls=None): ct_match = ( ct_id == self.get_content_type(obj=rel_obj, using=instance._state.db).id ) - pk_match = ct_match and rel_obj._meta.pk.to_python(pk_val) == rel_obj.pk + pk_match = ct_match and ( + rel_obj._meta.pk.to_python(pk_val) + ) == deserialize_pk(serialize_pk(rel_obj)) if pk_match: return rel_obj else: @@ -272,7 +301,7 @@ def __set__(self, instance, value): fk = None if value is not None: ct = self.get_content_type(obj=value) - fk = value.pk + fk = serialize_pk(value) setattr(instance, self.ct_field, ct) setattr(instance, self.fk_field, fk) @@ -541,7 +570,8 @@ def bulk_related_objects(self, objs, using=DEFAULT_DB_ALIAS): % self.content_type_field_name: ContentType.objects.db_manager(using) .get_for_model(self.model, for_concrete_model=self.for_concrete_model) .pk, - "%s__in" % self.object_id_field_name: [obj.pk for obj in objs], + "%s__in" + % self.object_id_field_name: [serialize_pk(obj) for obj in objs], } ) @@ -589,7 +619,7 @@ def __init__(self, instance=None): self.content_type_field_name = rel.field.content_type_field_name self.object_id_field_name = rel.field.object_id_field_name self.prefetch_cache_name = rel.field.attname - self.pk_val = instance.pk + self.pk_val = serialize_pk(instance) self.core_filters = { "%s__pk" % self.content_type_field_name: self.content_type.id, diff --git a/django/db/backends/base/operations.py b/django/db/backends/base/operations.py index 889e4d87b444..00159fcc7865 100644 --- a/django/db/backends/base/operations.py +++ b/django/db/backends/base/operations.py @@ -9,7 +9,7 @@ from django.conf import settings from django.db import NotSupportedError, transaction from django.db.backends import utils -from django.db.models.expressions import Col +from django.db.models.expressions import Col, ColPairs from django.utils import timezone from django.utils.deprecation import RemovedInDjango60Warning from django.utils.encoding import force_str @@ -800,7 +800,33 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel return "" def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field): - lhs_expr = Col(lhs_table, lhs_field) - rhs_expr = Col(rhs_table, rhs_field) + lhs_expr = lhs_field.get_col(lhs_table) + rhs_expr = rhs_field.get_col(rhs_table) + + if ( + isinstance(lhs_expr, Col) + and isinstance(rhs_expr, ColPairs) + and len(rhs_expr) > 1 + ): + return self.prepare_join_on_json_clause(lhs_expr, rhs_expr) + elif ( + isinstance(lhs_expr, ColPairs) + and isinstance(rhs_expr, Col) + and len(lhs_expr) > 1 + ): + return self.prepare_join_on_json_clause(rhs_expr, lhs_expr) return lhs_expr, rhs_expr + + def prepare_join_on_json_clause(self, lhs_expr, rhs_expr): + """ + If a generic foreign key refers to a composite primary key, the object_id is in + a JSON format backed by a CharField / TextField field. + + To support joining on generic foreign keys, use the backend-specific JSON + functions. + """ + raise NotImplementedError( + "subclasses of BaseDatabaseOperations may require a " + "prepare_join_on_json_clause() method." + ) diff --git a/django/db/backends/mysql/operations.py b/django/db/backends/mysql/operations.py index 9741e6a985fc..aff1e8f35491 100644 --- a/django/db/backends/mysql/operations.py +++ b/django/db/backends/mysql/operations.py @@ -3,7 +3,7 @@ from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.utils import split_tzname_delta -from django.db.models import Exists, ExpressionWrapper, Lookup +from django.db.models import Exists, ExpressionWrapper, Func, Lookup from django.db.models.constants import OnConflict from django.utils import timezone from django.utils.encoding import force_str @@ -456,3 +456,6 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel update_fields, unique_fields, ) + + def prepare_join_on_json_clause(self, lhs_expr, rhs_expr): + return lhs_expr, Func(rhs_expr, function="JSON_ARRAY") diff --git a/django/db/backends/oracle/operations.py b/django/db/backends/oracle/operations.py index 86340bbf4ac1..876e28ff0394 100644 --- a/django/db/backends/oracle/operations.py +++ b/django/db/backends/oracle/operations.py @@ -7,7 +7,7 @@ from django.db.backends.base.operations import BaseDatabaseOperations from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name from django.db.models import AutoField, Exists, ExpressionWrapper, Lookup -from django.db.models.expressions import RawSQL +from django.db.models.expressions import Func, RawSQL from django.db.models.sql.where import WhereNode from django.utils import timezone from django.utils.encoding import force_bytes, force_str @@ -729,3 +729,8 @@ def conditional_expression_supported_in_where_clause(self, expression): if isinstance(expression, RawSQL) and expression.conditional: return True return False + + def prepare_join_on_json_clause(self, lhs_expr, rhs_expr): + return Func(lhs_expr, function="JSON_SERIALIZE"), Func( + rhs_expr, function="JSON_ARRAY" + ) diff --git a/django/db/backends/postgresql/functions.py b/django/db/backends/postgresql/functions.py new file mode 100644 index 000000000000..e06a4b71074f --- /dev/null +++ b/django/db/backends/postgresql/functions.py @@ -0,0 +1,5 @@ +from django.db.models import Func + + +class JSONBBuildArray(Func): + function = "JSONB_BUILD_ARRAY" diff --git a/django/db/backends/postgresql/operations.py b/django/db/backends/postgresql/operations.py index 4b179ca83f3e..eef8c164248a 100644 --- a/django/db/backends/postgresql/operations.py +++ b/django/db/backends/postgresql/operations.py @@ -3,6 +3,7 @@ from django.conf import settings from django.db.backends.base.operations import BaseDatabaseOperations +from django.db.backends.postgresql.functions import JSONBBuildArray from django.db.backends.postgresql.psycopg_any import ( Inet, Jsonb, @@ -11,6 +12,7 @@ mogrify, ) from django.db.backends.utils import split_tzname_delta +from django.db.models import JSONField from django.db.models.constants import OnConflict from django.db.models.functions import Cast from django.utils.regex_helper import _lazy_re_compile @@ -407,7 +409,12 @@ def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field): lhs_table, lhs_field, rhs_table, rhs_field ) - if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection): + if not isinstance(rhs_expr, JSONBBuildArray) and lhs_field.db_type( + self.connection + ) != rhs_field.db_type(self.connection): rhs_expr = Cast(rhs_expr, lhs_field) return lhs_expr, rhs_expr + + def prepare_join_on_json_clause(self, lhs_expr, rhs_expr): + return Cast(lhs_expr, JSONField()), JSONBBuildArray(rhs_expr) diff --git a/django/db/backends/sqlite3/operations.py b/django/db/backends/sqlite3/operations.py index 0078cc077a97..d58c87a42582 100644 --- a/django/db/backends/sqlite3/operations.py +++ b/django/db/backends/sqlite3/operations.py @@ -9,7 +9,7 @@ from django.db import DatabaseError, NotSupportedError, models from django.db.backends.base.operations import BaseDatabaseOperations from django.db.models.constants import OnConflict -from django.db.models.expressions import Col +from django.db.models.expressions import Col, Func from django.utils import timezone from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.functional import cached_property @@ -431,3 +431,6 @@ def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fiel def force_group_by(self): return ["GROUP BY TRUE"] if Database.sqlite_version_info < (3, 39) else [] + + def prepare_join_on_json_clause(self, lhs_expr, rhs_expr): + return Func(lhs_expr, function="JSON"), Func(rhs_expr, function="JSON_ARRAY") diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 4c09613a8594..b4929b3549ff 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -1340,6 +1340,12 @@ def __len__(self): def __iter__(self): return iter(self.get_cols()) + def __repr__(self): + return "{}({})".format( + self.__class__.__name__, + ", ".join(repr(col) for col in self.get_cols()), + ) + def get_cols(self): return [ Col(self.alias, target, source) @@ -1374,11 +1380,6 @@ def relabeled_clone(self, relabels): def resolve_expression(self, *args, **kwargs): return self - @staticmethod - def db_converter(value, *_): - assert isinstance(value, list) - return (tuple(value),) - class Ref(Expression): """ diff --git a/tests/backends/base/test_operations.py b/tests/backends/base/test_operations.py index 8df02ee76b44..22cb1c8c7081 100644 --- a/tests/backends/base/test_operations.py +++ b/tests/backends/base/test_operations.py @@ -158,6 +158,7 @@ def test_datetime_extract_sql(self): def test_prepare_join_on_clause(self): author_table = Author._meta.db_table author_id_field = Author._meta.get_field("id") + author_name_field = Author._meta.get_field("name") book_table = Book._meta.db_table book_fk_field = Book._meta.get_field("author") lhs_expr, rhs_expr = self.ops.prepare_join_on_clause( @@ -167,7 +168,7 @@ def test_prepare_join_on_clause(self): book_fk_field, ) self.assertEqual(lhs_expr, Col(author_table, author_id_field)) - self.assertEqual(rhs_expr, Col(book_table, book_fk_field)) + self.assertEqual(rhs_expr, Col(book_table, book_fk_field, author_name_field)) class DatabaseOperationTests(TestCase): diff --git a/tests/composite_pk/models/__init__.py b/tests/composite_pk/models/__init__.py index 35c394371696..85c73ad35083 100644 --- a/tests/composite_pk/models/__init__.py +++ b/tests/composite_pk/models/__init__.py @@ -1,8 +1,9 @@ -from .tenant import Comment, Post, Tenant, Token, User +from .tenant import CharTag, Comment, Post, Tenant, Token, User __all__ = [ "Comment", "Post", + "CharTag", "Tenant", "Token", "User", diff --git a/tests/composite_pk/models/tenant.py b/tests/composite_pk/models/tenant.py index ac0b3d9715a1..ee2cd2a76c92 100644 --- a/tests/composite_pk/models/tenant.py +++ b/tests/composite_pk/models/tenant.py @@ -1,3 +1,5 @@ +from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation +from django.contrib.contenttypes.models import ContentType from django.db import models @@ -48,3 +50,13 @@ class Post(models.Model): pk = models.CompositePrimaryKey("tenant_id", "id") tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) id = models.UUIDField() + chartags = GenericRelation("CharTag", related_query_name="post") + + +class CharTag(models.Model): + name = models.CharField(max_length=5) + content_type = models.ForeignKey( + ContentType, on_delete=models.CASCADE, related_name="composite_pk_chartags" + ) + object_id = models.CharField(max_length=50) + content_object = GenericForeignKey("content_type", "object_id") diff --git a/tests/composite_pk/test_generic.py b/tests/composite_pk/test_generic.py new file mode 100644 index 000000000000..11da69cf5990 --- /dev/null +++ b/tests/composite_pk/test_generic.py @@ -0,0 +1,181 @@ +from uuid import UUID + +from django.contrib.contenttypes.models import ContentType +from django.contrib.contenttypes.prefetch import GenericPrefetch +from django.db import connection +from django.db.models import Count +from django.test import TestCase + +from .models import CharTag, Comment, Post, Tenant, User + + +class CompositePKGenericTests(TestCase): + POST_1_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb" + + @classmethod + def setUpTestData(cls): + cls.tenant_1 = Tenant.objects.create() + cls.tenant_2 = Tenant.objects.create() + cls.user_1 = User.objects.create( + tenant=cls.tenant_1, + id=1, + email="user0001@example.com", + ) + cls.user_2 = User.objects.create( + tenant=cls.tenant_1, + id=2, + email="user0002@example.com", + ) + cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1) + cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1) + cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=UUID(cls.POST_1_ID)) + cls.chartag_1 = CharTag.objects.create(name="a", content_object=cls.comment_1) + cls.chartag_2 = CharTag.objects.create(name="b", content_object=cls.post_1) + cls.comment_ct = ContentType.objects.get_for_model(Comment) + cls.post_ct = ContentType.objects.get_for_model(Post) + post_1_id = cls.POST_1_ID + if not connection.features.has_native_uuid_field: + post_1_id = cls.POST_1_ID.replace("-", "") + cls.post_1_fk = f'[{cls.tenant_1.id}, "{post_1_id}"]' + cls.comment_1_fk = f"[{cls.tenant_1.id}, {cls.comment_1.id}]" + + def test_fields(self): + tag_1 = CharTag.objects.get(pk=self.chartag_1.pk) + self.assertEqual(tag_1.content_type, self.comment_ct) + self.assertEqual(tag_1.object_id, self.comment_1_fk) + self.assertEqual(tag_1.content_object, self.comment_1) + + tag_2 = CharTag.objects.get(pk=self.chartag_2.pk) + self.assertEqual(tag_2.content_type, self.post_ct) + self.assertEqual(tag_2.object_id, self.post_1_fk) + self.assertEqual(tag_2.content_object, self.post_1) + + post_1 = Post.objects.get(pk=self.post_1.pk) + self.assertSequenceEqual(post_1.chartags.all(), (self.chartag_2,)) + + def test_cascade_delete_if_generic_relation(self): + Post.objects.get(pk=self.post_1.pk).delete() + self.assertFalse(CharTag.objects.filter(pk=self.chartag_2.pk).exists()) + + def test_no_cascade_delete_if_no_generic_relation(self): + Comment.objects.get(pk=self.comment_1.pk).delete() + tag_1 = CharTag.objects.get(pk=self.chartag_1.pk) + self.assertIsNone(tag_1.content_object) + + def test_tags_clear(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.chartags.clear() + self.assertEqual(post_1.chartags.count(), 0) + self.assertFalse(CharTag.objects.filter(pk=self.chartag_2.pk).exists()) + + def test_tags_remove(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.chartags.remove(self.chartag_2) + self.assertEqual(post_1.chartags.count(), 0) + self.assertFalse(CharTag.objects.filter(pk=self.chartag_2.pk).exists()) + + def test_tags_create(self): + tag_count = CharTag.objects.count() + + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.chartags.create(name="c") + self.assertEqual(post_1.chartags.count(), 2) + self.assertEqual(CharTag.objects.count(), tag_count + 1) + + tag_3 = CharTag.objects.get(name="c") + self.assertEqual(tag_3.content_type, self.post_ct) + self.assertEqual(tag_3.object_id, self.post_1_fk) + self.assertEqual(tag_3.content_object, post_1) + + def test_tags_add(self): + tag_count = CharTag.objects.count() + post_1 = Post.objects.get(pk=self.post_1.pk) + + tag_3 = CharTag(name="c") + post_1.chartags.add(tag_3, bulk=False) + self.assertEqual(post_1.chartags.count(), 2) + self.assertEqual(CharTag.objects.count(), tag_count + 1) + + tag_3 = CharTag.objects.get(name="c") + self.assertEqual(tag_3.content_type, self.post_ct) + self.assertEqual(tag_3.object_id, self.post_1_fk) + self.assertEqual(tag_3.content_object, post_1) + + tag_4 = CharTag.objects.create(name="d", content_object=self.comment_2) + post_1.chartags.add(tag_4) + self.assertEqual(post_1.chartags.count(), 3) + self.assertEqual(CharTag.objects.count(), tag_count + 2) + + tag_4 = CharTag.objects.get(name="d") + self.assertEqual(tag_4.content_type, self.post_ct) + self.assertEqual(tag_4.object_id, self.post_1_fk) + self.assertEqual(tag_4.content_object, post_1) + + def test_tags_set(self): + tag_count = CharTag.objects.count() + tag_1 = CharTag.objects.get(name="a") + post_1 = Post.objects.get(pk=self.post_1.pk) + post_1.chartags.set([tag_1]) + self.assertEqual(post_1.chartags.count(), 1) + self.assertEqual(CharTag.objects.count(), tag_count - 1) + self.assertFalse(CharTag.objects.filter(pk=self.chartag_2.pk).exists()) + + def test_tags_get_or_create(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + + tag_2, created = post_1.chartags.get_or_create(name="b") + self.assertFalse(created) + self.assertEqual(tag_2.pk, self.chartag_2.pk) + self.assertEqual(tag_2.content_type, self.post_ct) + self.assertEqual(tag_2.object_id, self.post_1_fk) + self.assertEqual(tag_2.content_object, post_1) + + tag_3, created = post_1.chartags.get_or_create(name="c") + self.assertTrue(created) + self.assertEqual(tag_3.content_type, self.post_ct) + self.assertEqual(tag_3.object_id, self.post_1_fk) + self.assertEqual(tag_3.content_object, post_1) + + def test_tags_update_or_create(self): + post_1 = Post.objects.get(pk=self.post_1.pk) + + tag_2, created = post_1.chartags.update_or_create( + name="b", defaults={"name": "b2"} + ) + self.assertFalse(created) + self.assertEqual(tag_2.pk, self.chartag_2.pk) + self.assertEqual(tag_2.name, "b2") + self.assertEqual(tag_2.content_type, self.post_ct) + self.assertEqual(tag_2.object_id, self.post_1_fk) + self.assertEqual(tag_2.content_object, post_1) + + tag_3, created = post_1.chartags.update_or_create(name="c") + self.assertTrue(created) + self.assertEqual(tag_3.content_type, self.post_ct) + self.assertEqual(tag_3.object_id, self.post_1_fk) + self.assertEqual(tag_3.content_object, post_1) + + def test_filter_by_related_query_name(self): + self.assertSequenceEqual( + CharTag.objects.filter(post__id=self.post_1.id), (self.chartag_2,) + ) + + def test_aggregate(self): + self.assertEqual( + Post.objects.aggregate(Count("chartags")), {"chartags__count": 1} + ) + + def test_generic_prefetch(self): + chartag_1, chartag_2 = CharTag.objects.prefetch_related( + GenericPrefetch( + "content_object", [Post.objects.all(), Comment.objects.all()] + ) + ).order_by("pk") + + self.assertEqual(chartag_1, self.chartag_1) + self.assertEqual(chartag_2, self.chartag_2) + + with self.assertNumQueries(0): + self.assertEqual(chartag_1.content_object, self.comment_1) + with self.assertNumQueries(0): + self.assertEqual(chartag_2.content_object, self.post_1)