diff --git a/.gitignore b/.gitignore index ce0b4d8..9e048bf 100755 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ fake.py.egg-info matyan.log* db.sqlite3 sample_db.sqlite +test_database.db local_settings.py /prof/ *.cast @@ -40,3 +41,4 @@ examples/pydantic/media/ examples/tortoise/media/ examples/django/media/ examples/dataclasses/media/ +examples/sqlalchemy/media/ diff --git a/.secrets.baseline b/.secrets.baseline index 53d2dab..39a309e 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -118,16 +118,16 @@ "filename": "Makefile", "hashed_secret": "ee783f2421477b5483c23f47eca1f69a1f2bf4fb", "is_verified": true, - "line_number": 80 + "line_number": 86 }, { "type": "Secret Keyword", "filename": "Makefile", "hashed_secret": "1457a35245051927fac6fa556074300f4162ed66", "is_verified": true, - "line_number": 83 + "line_number": 89 } ] }, - "generated_at": "2023-12-07T21:38:21Z" + "generated_at": "2023-12-10T22:39:54Z" } diff --git a/Makefile b/Makefile index 1a2c7e6..f9d502e 100644 --- a/Makefile +++ b/Makefile @@ -55,6 +55,9 @@ hypothesis-test: pydantic-test: source $(VENV) && cd examples/pydantic/ && python manage.py test +sqlalchemy-test: + source $(VENV) && cd examples/sqlalchemy/ && python manage.py test + tortoise-test: source $(VENV) && cd examples/tortoise/ && python manage.py test @@ -73,6 +76,9 @@ django-shell: pydantic-shell: source $(VENV) && cd examples/pydantic/ && python manage.py shell +sqlalchemy-shell: + source $(VENV) && cd examples/sqlalchemy/ && python manage.py shell + tortoise-shell: source $(VENV) && cd examples/tortoise/ && python manage.py shell diff --git a/examples/sqlalchemy/article/__init__.py b/examples/sqlalchemy/article/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/sqlalchemy/article/factories.py b/examples/sqlalchemy/article/factories.py new file mode 100644 index 0000000..d26a48e --- /dev/null +++ b/examples/sqlalchemy/article/factories.py @@ -0,0 +1,83 @@ +from pathlib import Path + +from fake import ( + FACTORY, + FileSystemStorage, + SubFactory, + post_save, + pre_save, + trait, +) +from sqlalchemy_model_factory import SQLAlchemyModelFactory + +from article.models import Article, User + +__author__ = "Artur Barseghyan " +__copyright__ = "2023 Artur Barseghyan" +__license__ = "MIT" +__all__ = ( + "ArticleFactory", + "UserFactory", +) + +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent +MEDIA_ROOT = BASE_DIR / "media" + +STORAGE = FileSystemStorage(root_path=MEDIA_ROOT, rel_path="tmp") + + +class UserFactory(SQLAlchemyModelFactory): + """User factory.""" + + username = FACTORY.username() + first_name = FACTORY.first_name() + last_name = FACTORY.last_name() + email = FACTORY.email() + last_login = FACTORY.date_time() + is_superuser = False + is_staff = False + is_active = FACTORY.pybool() + date_joined = FACTORY.date_time() + + class Meta: + model = User + get_or_create = ("username",) + + @trait + def is_admin_user(self, instance: User) -> None: + instance.is_superuser = True + instance.is_staff = True + instance.is_active = True + + @pre_save + def _pre_save_method(self, instance): + instance.pre_save_called = True + + @post_save + def _post_save_method(self, instance): + instance.post_save_called = True + + +class ArticleFactory(SQLAlchemyModelFactory): + """Article factory.""" + + title = FACTORY.sentence() + slug = FACTORY.slug() + content = FACTORY.text() + image = FACTORY.png_file(storage=STORAGE) + pub_date = FACTORY.date() + safe_for_work = FACTORY.pybool() + minutes_to_read = FACTORY.pyint(min_value=1, max_value=10) + author = SubFactory(UserFactory) + + class Meta: + model = Article + + @pre_save + def _pre_save_method(self, instance): + instance.pre_save_called = True + + @post_save + def _post_save_method(self, instance): + instance.post_save_called = True diff --git a/examples/sqlalchemy/article/models.py b/examples/sqlalchemy/article/models.py new file mode 100644 index 0000000..e3a046e --- /dev/null +++ b/examples/sqlalchemy/article/models.py @@ -0,0 +1,71 @@ +from datetime import datetime + +from sqlalchemy import ( + Boolean, + Column, + DateTime, + ForeignKey, + Integer, + String, + Text, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship + +__author__ = "Artur Barseghyan " +__copyright__ = "2023 Artur Barseghyan" +__license__ = "MIT" +__all__ = ( + "Article", + "Base", + "User", +) + + +Base = declarative_base() + + +class User(Base): + """User model.""" + + __tablename__ = "users" + + id = Column(Integer, primary_key=True) + username = Column(String(255), unique=True) + first_name = Column(String(255)) + last_name = Column(String(255)) + email = Column(String(255)) + password = Column(String(255), nullable=True) + last_login = Column(DateTime, nullable=True) + is_superuser = Column(Boolean, default=False) + is_staff = Column(Boolean, default=False) + is_active = Column(Boolean, default=True) + date_joined = Column(DateTime, nullable=True) + + articles = relationship("Article", back_populates="author") + + def __repr__(self): + return f"" + + +class Article(Base): + """Article model.""" + + __tablename__ = "articles" + + id = Column(Integer, primary_key=True) + title = Column(String(255)) + slug = Column(String(255), unique=True) + content = Column(Text) + image = Column(Text, nullable=True) + pub_date = Column(DateTime, default=datetime.utcnow) + safe_for_work = Column(Boolean, default=False) + minutes_to_read = Column(Integer, default=5) + author_id = Column(Integer, ForeignKey("users.id")) + + author = relationship("User", back_populates="articles") + + def __repr__(self): + return ( + f"" + ) diff --git a/examples/sqlalchemy/article/tests.py b/examples/sqlalchemy/article/tests.py new file mode 100644 index 0000000..772378f --- /dev/null +++ b/examples/sqlalchemy/article/tests.py @@ -0,0 +1,32 @@ +import unittest + +from sqlalchemy_model_factory import SESSION + +from article.factories import ArticleFactory, UserFactory + +__author__ = "Artur Barseghyan " +__copyright__ = "2023 Artur Barseghyan" +__license__ = "MIT" +__all__ = ("TestFactories",) + + +class TestFactories(unittest.TestCase): + def setUp(self): + # Set up database session, if needed + self.session = SESSION() + + def tearDown(self): + # Clean up the session after each test + self.session.rollback() + self.session.close() + + def test_user_creation(self): + user = UserFactory(username="testuser") + self.assertIsNotNone(user.id) + self.assertEqual(user.username, "testuser") + + def test_article_creation(self): + user = UserFactory(username="authoruser") + article = ArticleFactory(title="Test Article", author=user) + self.assertIsNotNone(article.id) + self.assertEqual(article.author.username, "authoruser") diff --git a/examples/sqlalchemy/manage.py b/examples/sqlalchemy/manage.py new file mode 100644 index 0000000..a0ca893 --- /dev/null +++ b/examples/sqlalchemy/manage.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import unittest + +import IPython + + +def run_tests(): + """Function to run tests in the article directory.""" + loader = unittest.TestLoader() + suite = loader.discover(start_dir=".", pattern="tests.py") + runner = unittest.TextTestRunner() + runner.run(suite) + + +def main(): + """Run administrative tasks based on command line arguments.""" + sys.path.insert(0, os.path.abspath(os.path.join("..", ".."))) + sys.path.insert(0, os.path.abspath(".")) + parser = argparse.ArgumentParser( + description="Management script for the project." + ) + parser.add_argument("command", help="The command to run (test or shell)") + + args = parser.parse_args() + + from sqlalchemy_model_factory import ENGINE + + from article.models import Base + + Base.metadata.create_all(ENGINE) + + if args.command == "test": + run_tests() + elif args.command == "shell": + IPython.embed() + else: + print("Unknown command. Use 'test' or 'shell'.") + + +if __name__ == "__main__": + main() diff --git a/examples/sqlalchemy/requirements.in b/examples/sqlalchemy/requirements.in new file mode 100644 index 0000000..39fb2be --- /dev/null +++ b/examples/sqlalchemy/requirements.in @@ -0,0 +1 @@ +sqlalchemy diff --git a/examples/sqlalchemy/sqlalchemy_model_factory.py b/examples/sqlalchemy/sqlalchemy_model_factory.py new file mode 100644 index 0000000..863f6d6 --- /dev/null +++ b/examples/sqlalchemy/sqlalchemy_model_factory.py @@ -0,0 +1,105 @@ +from fake import FactoryMethod, ModelFactory, SubFactory +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session, sessionmaker + +DATABASE_URL = "sqlite:///test_database.db" + +ENGINE = create_engine(DATABASE_URL) +SESSION = scoped_session(sessionmaker(bind=ENGINE)) + +__author__ = "Artur Barseghyan " +__copyright__ = "2023 Artur Barseghyan" +__license__ = "MIT" +__all__ = ( + "SQLAlchemyModelFactory", + "ENGINE", + "SESSION", +) + + +class SQLAlchemyModelFactory(ModelFactory): + """SQLAlchemy ModelFactory.""" + + @classmethod + def get_session(cls): + return SESSION() + + @classmethod + def save(cls, instance): + session = cls.get_session() + session.add(instance) + session.commit() + + @classmethod + def create(cls, **kwargs): + session = cls.get_session() + + model = cls.Meta.model + unique_fields = cls._meta.get("get_or_create", ["id"]) + + # Check for existing instance + if unique_fields: + query_kwargs = {field: kwargs.get(field) for field in unique_fields} + instance = session.query(model).filter_by(**query_kwargs).first() + if instance: + return instance + + # Construct model_data from class attributes + model_data = { + field: (value() if isinstance(value, FactoryMethod) else value) + for field, value in cls.__dict__.items() + if ( + not field.startswith("_") + and not field == "Meta" + and not getattr(value, "is_trait", False) + and not getattr(value, "is_pre_save", False) + and not getattr(value, "is_post_save", False) + ) + } + + # Separate nested attributes and direct attributes + nested_attrs = {k: v for k, v in kwargs.items() if "__" in k} + direct_attrs = {k: v for k, v in kwargs.items() if "__" not in k} + + # Update direct attributes with callable results + for field, value in model_data.items(): + if isinstance(value, (FactoryMethod, SubFactory)): + model_data[field] = ( + value() + if field not in direct_attrs + else direct_attrs[field] + ) + + # Create a new instance + instance = model(**model_data) + cls._apply_traits(instance, **kwargs) + + # Handle nested attributes + for attr, value in nested_attrs.items(): + field_name, nested_attr = attr.split("__", 1) + if isinstance(getattr(cls, field_name, None), SubFactory): + related_instance = getattr( + cls, field_name + ).factory_class.create(**{nested_attr: value}) + setattr(instance, field_name, related_instance) + + # Run pre-save hooks + pre_save_hooks = [ + method + for method in dir(cls) + if getattr(getattr(cls, method), "is_pre_save", False) + ] + cls._run_hooks(pre_save_hooks, instance) + + # Save instance + cls.save(instance) + + # Run post-save hooks + post_save_hooks = [ + method + for method in dir(cls) + if getattr(getattr(cls, method), "is_post_save", False) + ] + cls._run_hooks(post_save_hooks, instance) + + return instance diff --git a/fake.py b/fake.py index ab8eaba..134da66 100644 --- a/fake.py +++ b/fake.py @@ -50,6 +50,7 @@ "FAKER", "FILE_REGISTRY", "Factory", + "FactoryMethod", "Faker", "FileRegistry", "FileSystemStorage",