Skip to content

Commit

Permalink
More
Browse files Browse the repository at this point in the history
  • Loading branch information
barseghyanartur committed Dec 1, 2023
1 parent aad30a3 commit 21a7d76
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 24 deletions.
9 changes: 7 additions & 2 deletions examples/tortoise/article/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@


class UserFactory(TortoiseModelFactory):
id = FACTORY.pyint()
"""User factory."""

# id = FACTORY.pyint()
username = FACTORY.username()
first_name = FACTORY.first_name()
last_name = FACTORY.last_name()
Expand All @@ -30,10 +32,13 @@ class UserFactory(TortoiseModelFactory):

class Meta:
model = User
get_or_create = ("username",)


class ArticleFactory(TortoiseModelFactory):
id = FACTORY.pyint()
"""Article factory."""

# id = FACTORY.pyint()
title = FACTORY.sentence()
slug = FACTORY.slug()
content = FACTORY.text()
Expand Down
6 changes: 5 additions & 1 deletion examples/tortoise/article/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@


class User(Model):
"""User model."""

id = fields.IntField(pk=True)
username = fields.CharField(max_length=255)
username = fields.CharField(max_length=255, unique=True)
first_name = fields.CharField(max_length=255)
last_name = fields.CharField(max_length=255)
email = fields.CharField(max_length=255)
Expand All @@ -27,6 +29,8 @@ def __str__(self):


class Article(Model):
"""Article model."""

id = fields.IntField(pk=True)
title = fields.CharField(max_length=255)
slug = fields.CharField(max_length=255, unique=True)
Expand Down
66 changes: 45 additions & 21 deletions fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,27 @@ def post_save(func):
class ModelFactory:
"""ModelFactory."""

class Meta:
get_or_create = ("id",) # Default fields for get_or_create

def __init_subclass__(cls, **kwargs):
base_meta = getattr(
cls.__bases__[0],
"_meta",
{
attr: getattr(cls.__bases__[0].Meta, attr)
for attr in dir(cls.__bases__[0].Meta)
if not attr.startswith("_")
},
)
cls_meta = {
attr: getattr(cls.Meta, attr)
for attr in dir(cls.Meta)
if not attr.startswith("_")
}

cls._meta = {**base_meta, **cls_meta}

@classmethod
def _run_hooks(cls, hooks, instance):
for method in hooks:
Expand Down Expand Up @@ -1678,27 +1699,6 @@ def save(cls, instance):
class DjangoModelFactory(ModelFactory):
"""Django ModelFactory."""

class Meta:
get_or_create = ("id",) # Default fields for get_or_create

def __init_subclass__(cls, **kwargs):
base_meta = getattr(
cls.__bases__[0],
"_meta",
{
attr: getattr(cls.__bases__[0].Meta, attr)
for attr in dir(cls.__bases__[0].Meta)
if not attr.startswith("_")
},
)
cls_meta = {
attr: getattr(cls.Meta, attr)
for attr in dir(cls.Meta)
if not attr.startswith("_")
}

cls._meta = {**base_meta, **cls_meta}

@classmethod
def save(cls, instance):
instance.save()
Expand Down Expand Up @@ -1784,6 +1784,30 @@ async def async_save():

asyncio.run(async_save())

@classmethod
def create(cls, **kwargs):
model = cls.Meta.model
unique_fields = cls._meta.get("get_or_create", ["id"])

# Construct a query for unique fields
query = {
field: kwargs[field] for field in unique_fields if field in kwargs
}

# Try to get an existing instance
if query:

async def async_filter():
return await model.filter(**query).first()

instance = asyncio.run(async_filter())

if instance:
return instance

# Create a new instance if none found
return super().create(**kwargs)


class TestFaker(unittest.TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit 21a7d76

Please sign in to comment.