diff --git a/Makefile b/Makefile index f9d502e..b823e74 100644 --- a/Makefile +++ b/Makefile @@ -130,7 +130,7 @@ test-release: source $(VENV) && twine upload --repository testpypi dist/* mypy: - source $(VENV) && mypy . + source $(VENV) && mypy fake.py %: @: diff --git a/examples/sqlalchemy/article/factories.py b/examples/sqlalchemy/article/factories.py index d26a48e..92e4415 100644 --- a/examples/sqlalchemy/article/factories.py +++ b/examples/sqlalchemy/article/factories.py @@ -3,14 +3,15 @@ from fake import ( FACTORY, FileSystemStorage, + SQLAlchemyModelFactory, SubFactory, post_save, pre_save, trait, ) -from sqlalchemy_model_factory import SQLAlchemyModelFactory from article.models import Article, User +from config import SESSION __author__ = "Artur Barseghyan " __copyright__ = "2023 Artur Barseghyan" @@ -20,13 +21,16 @@ "UserFactory", ) -# Build paths inside the project like this: BASE_DIR / 'subdir'. +# Storage config. Build paths inside the project like this: BASE_DIR / 'subdir' BASE_DIR = Path(__file__).resolve().parent.parent MEDIA_ROOT = BASE_DIR / "media" - STORAGE = FileSystemStorage(root_path=MEDIA_ROOT, rel_path="tmp") +def get_session(): + return SESSION() + + class UserFactory(SQLAlchemyModelFactory): """User factory.""" @@ -44,6 +48,9 @@ class Meta: model = User get_or_create = ("username",) + class MetaSQLAlchemy: + get_session = get_session + @trait def is_admin_user(self, instance: User) -> None: instance.is_superuser = True @@ -74,6 +81,9 @@ class ArticleFactory(SQLAlchemyModelFactory): class Meta: model = Article + class MetaSQLAlchemy: + get_session = get_session + @pre_save def _pre_save_method(self, instance): instance.pre_save_called = True diff --git a/examples/sqlalchemy/article/tests.py b/examples/sqlalchemy/article/tests.py index 772378f..9f0f996 100644 --- a/examples/sqlalchemy/article/tests.py +++ b/examples/sqlalchemy/article/tests.py @@ -1,8 +1,7 @@ import unittest -from sqlalchemy_model_factory import SESSION - from article.factories import ArticleFactory, UserFactory +from config import SESSION __author__ = "Artur Barseghyan " __copyright__ = "2023 Artur Barseghyan" diff --git a/examples/sqlalchemy/config.py b/examples/sqlalchemy/config.py new file mode 100644 index 0000000..71c66b7 --- /dev/null +++ b/examples/sqlalchemy/config.py @@ -0,0 +1,16 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session, sessionmaker + +__author__ = "Artur Barseghyan " +__copyright__ = "2023 Artur Barseghyan" +__license__ = "MIT" +__all__ = ( + "DATABASE_URL", + "ENGINE", + "SESSION", +) + +# SQLAlchemy +DATABASE_URL = "sqlite:///test_database.db" +ENGINE = create_engine(DATABASE_URL) +SESSION = scoped_session(sessionmaker(bind=ENGINE)) diff --git a/examples/sqlalchemy/manage.py b/examples/sqlalchemy/manage.py index a0ca893..f3ec200 100644 --- a/examples/sqlalchemy/manage.py +++ b/examples/sqlalchemy/manage.py @@ -19,6 +19,12 @@ def main(): """Run administrative tasks based on command line arguments.""" sys.path.insert(0, os.path.abspath(os.path.join("..", ".."))) sys.path.insert(0, os.path.abspath(".")) + + from article.models import Base # noqa + from config import ENGINE # noqa + + Base.metadata.create_all(ENGINE) + parser = argparse.ArgumentParser( description="Management script for the project." ) @@ -26,12 +32,6 @@ def main(): args = parser.parse_args() - from sqlalchemy_model_factory import ENGINE - - from article.models import Base - - Base.metadata.create_all(ENGINE) - if args.command == "test": run_tests() elif args.command == "shell": diff --git a/examples/sqlalchemy/sqlalchemy_model_factory.py b/examples/sqlalchemy/sqlalchemy_model_factory.py deleted file mode 100644 index 863f6d6..0000000 --- a/examples/sqlalchemy/sqlalchemy_model_factory.py +++ /dev/null @@ -1,105 +0,0 @@ -from fake import FactoryMethod, ModelFactory, SubFactory -from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker - -DATABASE_URL = "sqlite:///test_database.db" - -ENGINE = create_engine(DATABASE_URL) -SESSION = scoped_session(sessionmaker(bind=ENGINE)) - -__author__ = "Artur Barseghyan " -__copyright__ = "2023 Artur Barseghyan" -__license__ = "MIT" -__all__ = ( - "SQLAlchemyModelFactory", - "ENGINE", - "SESSION", -) - - -class SQLAlchemyModelFactory(ModelFactory): - """SQLAlchemy ModelFactory.""" - - @classmethod - def get_session(cls): - return SESSION() - - @classmethod - def save(cls, instance): - session = cls.get_session() - session.add(instance) - session.commit() - - @classmethod - def create(cls, **kwargs): - session = cls.get_session() - - model = cls.Meta.model - unique_fields = cls._meta.get("get_or_create", ["id"]) - - # Check for existing instance - if unique_fields: - query_kwargs = {field: kwargs.get(field) for field in unique_fields} - instance = session.query(model).filter_by(**query_kwargs).first() - if instance: - return instance - - # Construct model_data from class attributes - model_data = { - field: (value() if isinstance(value, FactoryMethod) else value) - for field, value in cls.__dict__.items() - if ( - not field.startswith("_") - and not field == "Meta" - and not getattr(value, "is_trait", False) - and not getattr(value, "is_pre_save", False) - and not getattr(value, "is_post_save", False) - ) - } - - # Separate nested attributes and direct attributes - nested_attrs = {k: v for k, v in kwargs.items() if "__" in k} - direct_attrs = {k: v for k, v in kwargs.items() if "__" not in k} - - # Update direct attributes with callable results - for field, value in model_data.items(): - if isinstance(value, (FactoryMethod, SubFactory)): - model_data[field] = ( - value() - if field not in direct_attrs - else direct_attrs[field] - ) - - # Create a new instance - instance = model(**model_data) - cls._apply_traits(instance, **kwargs) - - # Handle nested attributes - for attr, value in nested_attrs.items(): - field_name, nested_attr = attr.split("__", 1) - if isinstance(getattr(cls, field_name, None), SubFactory): - related_instance = getattr( - cls, field_name - ).factory_class.create(**{nested_attr: value}) - setattr(instance, field_name, related_instance) - - # Run pre-save hooks - pre_save_hooks = [ - method - for method in dir(cls) - if getattr(getattr(cls, method), "is_pre_save", False) - ] - cls._run_hooks(pre_save_hooks, instance) - - # Save instance - cls.save(instance) - - # Run post-save hooks - post_save_hooks = [ - method - for method in dir(cls) - if getattr(getattr(cls, method), "is_post_save", False) - ] - cls._run_hooks(post_save_hooks, instance) - - return instance diff --git a/fake.py b/fake.py index 134da66..0906616 100644 --- a/fake.py +++ b/fake.py @@ -59,6 +59,7 @@ "ModelFactory", "PROVIDER_REGISTRY", "StringValue", + "SQLAlchemyModelFactory", "SubFactory", "TextPdfGenerator", "TortoiseModelFactory", @@ -2053,6 +2054,90 @@ async def async_related_instance(): return instance +class SQLAlchemyModelFactory(ModelFactory): + """SQLAlchemy ModelFactory.""" + + @classmethod + def save(cls, instance): + session = cls.MetaSQLAlchemy.get_session() + session.add(instance) + session.commit() + + @classmethod + def create(cls, **kwargs): + session = cls.MetaSQLAlchemy.get_session() + + model = cls.Meta.model + unique_fields = cls._meta.get("get_or_create", ["id"]) + + # Check for existing instance + if unique_fields: + query_kwargs = {field: kwargs.get(field) for field in unique_fields} + instance = session.query(model).filter_by(**query_kwargs).first() + if instance: + return instance + + # Construct model_data from class attributes + model_data = { + field: value + for field, value in cls.__dict__.items() + if ( + not field.startswith("_") + and not field.startswith("Meta") + and not getattr(value, "is_trait", False) + and not getattr(value, "is_pre_save", False) + and not getattr(value, "is_post_save", False) + ) + } + + # Separate nested attributes and direct attributes + nested_attrs = {k: v for k, v in kwargs.items() if "__" in k} + direct_attrs = {k: v for k, v in kwargs.items() if "__" not in k} + + # Update direct attributes with callable results + for field, value in model_data.items(): + if isinstance(value, (FactoryMethod, SubFactory)): + model_data[field] = ( + value() + if field not in direct_attrs + else direct_attrs[field] + ) + + # Create a new instance + instance = model(**model_data) + cls._apply_traits(instance, **kwargs) + + # Handle nested attributes + for attr, value in nested_attrs.items(): + field_name, nested_attr = attr.split("__", 1) + if isinstance(getattr(cls, field_name, None), SubFactory): + related_instance = getattr( + cls, field_name + ).factory_class.create(**{nested_attr: value}) + setattr(instance, field_name, related_instance) + + # Run pre-save hooks + pre_save_hooks = [ + method + for method in dir(cls) + if getattr(getattr(cls, method), "is_pre_save", False) + ] + cls._run_hooks(pre_save_hooks, instance) + + # Save instance + cls.save(instance) + + # Run post-save hooks + post_save_hooks = [ + method + for method in dir(cls) + if getattr(getattr(cls, method), "is_post_save", False) + ] + cls._run_hooks(post_save_hooks, instance) + + return instance + + # TODO: Remove once Python 3.8 support is dropped class ClassProperty(property): """ClassProperty.""" @@ -2763,10 +2848,10 @@ def objects(cls): # *********** Other ********** # **************************** - BASE_DIR = Path(__file__).resolve().parent.parent - MEDIA_ROOT = BASE_DIR / "media" + base_dir = Path(__file__).resolve().parent.parent + media_root = base_dir / "media" - STORAGE = FileSystemStorage(root_path=MEDIA_ROOT, rel_path="tmp") + storage = FileSystemStorage(root_path=media_root, rel_path="tmp") # **************************** # ******* ModelFactory ******* @@ -2806,7 +2891,7 @@ class ArticleFactory(ModelFactory): title = FACTORY.sentence() slug = FACTORY.slug() content = FACTORY.text() - image = FACTORY.png_file(storage=STORAGE) + image = FACTORY.png_file(storage=storage) pub_date = FACTORY.date() safe_for_work = FACTORY.pybool() minutes_to_read = FACTORY.pyint(min_value=1, max_value=10) @@ -2883,7 +2968,7 @@ class DjangoArticleFactory(DjangoModelFactory): title = FACTORY.sentence() slug = FACTORY.slug() content = FACTORY.text() - image = FACTORY.png_file(storage=STORAGE) + image = FACTORY.png_file(storage=storage) pub_date = FACTORY.date() safe_for_work = FACTORY.pybool() minutes_to_read = FACTORY.pyint(min_value=1, max_value=10) @@ -3059,7 +3144,7 @@ class TortoiseArticleFactory(TortoiseModelFactory): title = FACTORY.sentence() slug = FACTORY.slug() content = FACTORY.text() - image = FACTORY.png_file(storage=STORAGE) + image = FACTORY.png_file(storage=storage) pub_date = FACTORY.date() safe_for_work = FACTORY.pybool() minutes_to_read = FACTORY.pyint(min_value=1, max_value=10) @@ -3115,6 +3200,200 @@ def _post_save_method(self, instance): and tortoise_admin_user.is_active ) + # ********************************** + # ***** SQLAlchemyModelFactory ***** + # ********************************** + + class SQLAlchemySession: + return_instance_on_query_first: bool + + def __init__(self) -> None: + self.model = None + self.instance = None + + def query(self, model) -> "SQLAlchemySession": + self.model = model + return self + + def filter_by(self, **kwargs) -> "SQLAlchemySession": + return self + + def add(self, instance) -> None: + self.instance = instance + + def commit(self) -> None: + pass + + def first(self): + if not self.return_instance_on_query_first: + return None + + if self.model == SQLAlchemyUser: + return self.model( # noqa + id=FAKER.pyint(), + username=FAKER.username(), + first_name=FAKER.first_name(), + last_name=FAKER.last_name(), + email=FAKER.email(), + last_login=FAKER.date_time(), + date_joined=FAKER.date_time(), + ) + elif self.model == SQLAlchemyArticle: + return self.model( # noqa + id=FAKER.pyint(), + title=FAKER.word(), + slug=FAKER.slug(), + content=FAKER.text(), + author=TortoiseUser( + id=FAKER.pyint(), + username=FAKER.username(), + first_name=FAKER.first_name(), + last_name=FAKER.last_name(), + email=FAKER.email(), + last_login=FAKER.date_time(), + date_joined=FAKER.date_time(), + ), + ) + + class SQLAlchemySessionReturnNoneOnQueryFirst(SQLAlchemySession): + return_instance_on_query_first: bool = False + + class SQLAlchemySessionReturnInstanceOnQueryFirst(SQLAlchemySession): + return_instance_on_query_first: bool = True + + def get_session_return_instance_on_query_first(): + return SQLAlchemySessionReturnInstanceOnQueryFirst() + + def get_session_return_none_on_query_first(): + return SQLAlchemySessionReturnNoneOnQueryFirst() + + @dataclass + class SQLAlchemyUser: + """User model.""" + + id: int + username: str + first_name: str + last_name: str + email: str + last_login: Optional[datetime] + date_joined: Optional[datetime] + password: Optional[str] = None + is_superuser: bool = False + is_staff: bool = False + is_active: bool = True + + @dataclass + class SQLAlchemyArticle: + id: int + title: str + slug: str + content: str + author: User + image: Optional[ + str + ] = None # Use str to represent the image path or URL + pub_date: datetime = datetime.now() + safe_for_work: bool = False + minutes_to_read: int = 5 + + class SQLAlchemyUserFactory(SQLAlchemyModelFactory): + id = FACTORY.pyint() + username = FACTORY.username() + first_name = FACTORY.first_name() + last_name = FACTORY.last_name() + email = FACTORY.email() + last_login = FACTORY.date_time() + is_superuser = False + is_staff = False + is_active = FACTORY.pybool() + date_joined = FACTORY.date_time() + + class Meta: + model = SQLAlchemyUser + get_or_create = ("username",) + + class MetaSQLAlchemy: + get_session = get_session_return_none_on_query_first + + @trait + def is_admin_user(self, instance: SQLAlchemyUser) -> None: + instance.is_superuser = True + instance.is_staff = True + instance.is_active = True + + @pre_save + def _pre_save_method(self, instance): + instance.pre_save_called = True + + @post_save + def _post_save_method(self, instance): + instance.post_save_called = True + + class SQLAlchemyArticleFactory(SQLAlchemyModelFactory): + id = FACTORY.pyint() + title = FACTORY.sentence() + slug = FACTORY.slug() + content = FACTORY.text() + image = FACTORY.png_file(storage=storage) + pub_date = FACTORY.date() + safe_for_work = FACTORY.pybool() + minutes_to_read = FACTORY.pyint(min_value=1, max_value=10) + author = SubFactory(SQLAlchemyUserFactory) + + class Meta: + model = SQLAlchemyArticle + + class MetaSQLAlchemy: + get_session = get_session_return_none_on_query_first + + @pre_save + def _pre_save_method(self, instance): + instance.pre_save_called = True + + @post_save + def _post_save_method(self, instance): + instance.post_save_called = True + + sqlalchemy_article = SQLAlchemyArticleFactory(author__username="admin") + + # Testing SubFactory + self.assertIsInstance(sqlalchemy_article.author, SQLAlchemyUser) + self.assertIsInstance(sqlalchemy_article.author.id, int) + self.assertIsInstance(sqlalchemy_article.author.is_staff, bool) + self.assertIsInstance(sqlalchemy_article.author.date_joined, datetime) + # Since we're mimicking Tortoise's behaviour, the following line would + # fail on test, however would pass when testing against real Tortoise + # model (as done in the examples). + # self.assertEqual(tortoise_article.author.username, "admin") + + # Testing Factory + self.assertIsInstance(sqlalchemy_article.id, int) + self.assertIsInstance(sqlalchemy_article.slug, str) + + # Testing hooks + self.assertTrue( + hasattr(sqlalchemy_article, "pre_save_called") + and sqlalchemy_article.pre_save_called + ) + self.assertTrue( + hasattr(sqlalchemy_article, "post_save_called") + and sqlalchemy_article.post_save_called + ) + + # Testing batch creation + sqlalchemy_articles = SQLAlchemyArticleFactory.create_batch(5) + self.assertEqual(len(sqlalchemy_articles), 5) + self.assertIsInstance(sqlalchemy_articles[0], SQLAlchemyArticle) + + # Testing traits + sqlalchemy_admin_user = SQLAlchemyUserFactory(is_admin_user=True) + self.assertTrue( + sqlalchemy_admin_user.is_staff + and sqlalchemy_admin_user.is_superuser + and sqlalchemy_admin_user.is_active + ) + def test_registry_integration(self) -> None: """Test `add`.""" # Create a TXT file. diff --git a/pyproject.toml b/pyproject.toml index 8663231..3abe461 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,14 @@ force_grid_wrap = 0 use_parentheses = true ensure_newline_before_comments = true line_length = 80 -known_first_party = ["article", "address", "fake_address", "fake_band", "data"] +known_first_party = [ + "address", + "article", + "config", + "data", + "fake_address", + "fake_band", +] known_third_party = ["fake"] skip = ["wsgi.py", "builddocs/"]