diff --git a/examples/tortoise/article/factories.py b/examples/tortoise/article/factories.py index 8fe3a84..2a4fd7e 100644 --- a/examples/tortoise/article/factories.py +++ b/examples/tortoise/article/factories.py @@ -17,7 +17,9 @@ class UserFactory(TortoiseModelFactory): - id = FACTORY.pyint() + """User factory.""" + + # id = FACTORY.pyint() username = FACTORY.username() first_name = FACTORY.first_name() last_name = FACTORY.last_name() @@ -30,10 +32,13 @@ class UserFactory(TortoiseModelFactory): class Meta: model = User + get_or_create = ("username",) class ArticleFactory(TortoiseModelFactory): - id = FACTORY.pyint() + """Article factory.""" + + # id = FACTORY.pyint() title = FACTORY.sentence() slug = FACTORY.slug() content = FACTORY.text() diff --git a/examples/tortoise/article/models.py b/examples/tortoise/article/models.py index 44bd3e9..047e487 100644 --- a/examples/tortoise/article/models.py +++ b/examples/tortoise/article/models.py @@ -10,8 +10,10 @@ class User(Model): + """User model.""" + id = fields.IntField(pk=True) - username = fields.CharField(max_length=255) + username = fields.CharField(max_length=255, unique=True) first_name = fields.CharField(max_length=255) last_name = fields.CharField(max_length=255) email = fields.CharField(max_length=255) @@ -27,6 +29,8 @@ def __str__(self): class Article(Model): + """Article model.""" + id = fields.IntField(pk=True) title = fields.CharField(max_length=255) slug = fields.CharField(max_length=255, unique=True) diff --git a/fake.py b/fake.py index b9bfda4..2c2d627 100644 --- a/fake.py +++ b/fake.py @@ -1625,6 +1625,27 @@ def post_save(func): class ModelFactory: """ModelFactory.""" + class Meta: + get_or_create = ("id",) # Default fields for get_or_create + + def __init_subclass__(cls, **kwargs): + base_meta = getattr( + cls.__bases__[0], + "_meta", + { + attr: getattr(cls.__bases__[0].Meta, attr) + for attr in dir(cls.__bases__[0].Meta) + if not attr.startswith("_") + }, + ) + cls_meta = { + attr: getattr(cls.Meta, attr) + for attr in dir(cls.Meta) + if not attr.startswith("_") + } + + cls._meta = {**base_meta, **cls_meta} + @classmethod def _run_hooks(cls, hooks, instance): for method in hooks: @@ -1678,27 +1699,6 @@ def save(cls, instance): class DjangoModelFactory(ModelFactory): """Django ModelFactory.""" - class Meta: - get_or_create = ("id",) # Default fields for get_or_create - - def __init_subclass__(cls, **kwargs): - base_meta = getattr( - cls.__bases__[0], - "_meta", - { - attr: getattr(cls.__bases__[0].Meta, attr) - for attr in dir(cls.__bases__[0].Meta) - if not attr.startswith("_") - }, - ) - cls_meta = { - attr: getattr(cls.Meta, attr) - for attr in dir(cls.Meta) - if not attr.startswith("_") - } - - cls._meta = {**base_meta, **cls_meta} - @classmethod def save(cls, instance): instance.save() @@ -1784,6 +1784,30 @@ async def async_save(): asyncio.run(async_save()) + @classmethod + def create(cls, **kwargs): + model = cls.Meta.model + unique_fields = cls._meta.get("get_or_create", ["id"]) + + # Construct a query for unique fields + query = { + field: kwargs[field] for field in unique_fields if field in kwargs + } + + # Try to get an existing instance + if query: + + async def async_filter(): + return await model.filter(**query).first() + + instance = asyncio.run(async_filter()) + + if instance: + return instance + + # Create a new instance if none found + return super().create(**kwargs) + class TestFaker(unittest.TestCase): def setUp(self) -> None: