Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Motor GenericReference field #352

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ class ToRef2(Document):

@self.instance.register
class MyDoc(Document):
gref = fields.GenericReferenceField(attribute='in_mongo_gref', allow_none=True)
gref = fields.GenericReferenceField(attribute='in_mongo_gref', reference_cls=Reference, allow_none=True)

MySchema = MyDoc.Schema

Expand Down
3 changes: 2 additions & 1 deletion tests/test_marshmallow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test marshmallow-related features"""
import datetime as dt
from umongo.data_objects import Reference

import pytest

Expand Down Expand Up @@ -317,7 +318,7 @@ def test_marshmallow_bonus_fields(self):
class Doc(Document):
id = fields.ObjectIdField(attribute='_id')
ref = fields.ReferenceField('Doc')
gen_ref = fields.GenericReferenceField()
gen_ref = fields.GenericReferenceField(reference_cls=Reference)

for name, field_cls in (
('id', ma_bonus_fields.ObjectId),
Expand Down
2 changes: 1 addition & 1 deletion umongo/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def _deserialize_from_mongo(self, value):

class GenericReferenceField(BaseField, ma_bonus_fields.GenericReference):

def __init__(self, *args, reference_cls=Reference, **kwargs):
def __init__(self, *args, reference_cls=None, **kwargs):
super().__init__(*args, **kwargs)
self.reference_cls = reference_cls
self._document_implementation_cls = DocumentImplementation
Expand Down
6 changes: 5 additions & 1 deletion umongo/frameworks/motor_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..document import DocumentImplementation
from ..data_objects import Reference
from ..exceptions import NotCreatedError, UpdateError, DeleteError, NoneReferenceError
from ..fields import ReferenceField, ListField, DictField, EmbeddedField
from ..fields import ReferenceField, GenericReferenceField, ListField, DictField, EmbeddedField
from ..query_mapper import map_query

from .tools import cook_find_filter, remove_cls_field_from_embedded_docs
Expand Down Expand Up @@ -431,6 +431,10 @@ def _patch_field(self, field):
if isinstance(field, ReferenceField):
field.io_validate.append(_reference_io_validate)
field.reference_cls = MotorAsyncIOReference
if isinstance(field, GenericReferenceField):
field.io_validate.append(_reference_io_validate)
if field.reference_cls is None:
field.reference_cls = MotorAsyncIOReference
if isinstance(field, EmbeddedField):
field.io_validate_recursive = _embedded_document_io_validate

Expand Down