-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e94f38c
commit 591ac33
Showing
11 changed files
with
348 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from pathlib import Path | ||
|
||
from fake import ( | ||
FACTORY, | ||
FileSystemStorage, | ||
SubFactory, | ||
post_save, | ||
pre_save, | ||
trait, | ||
) | ||
from sqlalchemy_model_factory import SQLAlchemyModelFactory | ||
|
||
from article.models import Article, User | ||
|
||
__author__ = "Artur Barseghyan <artur.barseghyan@gmail.com>" | ||
__copyright__ = "2023 Artur Barseghyan" | ||
__license__ = "MIT" | ||
__all__ = ( | ||
"ArticleFactory", | ||
"UserFactory", | ||
) | ||
|
||
# Build paths inside the project like this: BASE_DIR / 'subdir'. | ||
BASE_DIR = Path(__file__).resolve().parent.parent | ||
MEDIA_ROOT = BASE_DIR / "media" | ||
|
||
STORAGE = FileSystemStorage(root_path=MEDIA_ROOT, rel_path="tmp") | ||
|
||
|
||
class UserFactory(SQLAlchemyModelFactory): | ||
"""User factory.""" | ||
|
||
username = FACTORY.username() | ||
first_name = FACTORY.first_name() | ||
last_name = FACTORY.last_name() | ||
email = FACTORY.email() | ||
last_login = FACTORY.date_time() | ||
is_superuser = False | ||
is_staff = False | ||
is_active = FACTORY.pybool() | ||
date_joined = FACTORY.date_time() | ||
|
||
class Meta: | ||
model = User | ||
get_or_create = ("username",) | ||
|
||
@trait | ||
def is_admin_user(self, instance: User) -> None: | ||
instance.is_superuser = True | ||
instance.is_staff = True | ||
instance.is_active = True | ||
|
||
@pre_save | ||
def _pre_save_method(self, instance): | ||
instance.pre_save_called = True | ||
|
||
@post_save | ||
def _post_save_method(self, instance): | ||
instance.post_save_called = True | ||
|
||
|
||
class ArticleFactory(SQLAlchemyModelFactory): | ||
"""Article factory.""" | ||
|
||
title = FACTORY.sentence() | ||
slug = FACTORY.slug() | ||
content = FACTORY.text() | ||
image = FACTORY.png_file(storage=STORAGE) | ||
pub_date = FACTORY.date() | ||
safe_for_work = FACTORY.pybool() | ||
minutes_to_read = FACTORY.pyint(min_value=1, max_value=10) | ||
author = SubFactory(UserFactory) | ||
|
||
class Meta: | ||
model = Article | ||
|
||
@pre_save | ||
def _pre_save_method(self, instance): | ||
instance.pre_save_called = True | ||
|
||
@post_save | ||
def _post_save_method(self, instance): | ||
instance.post_save_called = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from datetime import datetime | ||
|
||
from sqlalchemy import ( | ||
Boolean, | ||
Column, | ||
DateTime, | ||
ForeignKey, | ||
Integer, | ||
String, | ||
Text, | ||
) | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from sqlalchemy.orm import relationship | ||
|
||
__author__ = "Artur Barseghyan <artur.barseghyan@gmail.com>" | ||
__copyright__ = "2023 Artur Barseghyan" | ||
__license__ = "MIT" | ||
__all__ = ( | ||
"Article", | ||
"Base", | ||
"User", | ||
) | ||
|
||
|
||
Base = declarative_base() | ||
|
||
|
||
class User(Base): | ||
"""User model.""" | ||
|
||
__tablename__ = "users" | ||
|
||
id = Column(Integer, primary_key=True) | ||
username = Column(String(255), unique=True) | ||
first_name = Column(String(255)) | ||
last_name = Column(String(255)) | ||
email = Column(String(255)) | ||
password = Column(String(255), nullable=True) | ||
last_login = Column(DateTime, nullable=True) | ||
is_superuser = Column(Boolean, default=False) | ||
is_staff = Column(Boolean, default=False) | ||
is_active = Column(Boolean, default=True) | ||
date_joined = Column(DateTime, nullable=True) | ||
|
||
articles = relationship("Article", back_populates="author") | ||
|
||
def __repr__(self): | ||
return f"<User(username='{self.username}', email='{self.email}')>" | ||
|
||
|
||
class Article(Base): | ||
"""Article model.""" | ||
|
||
__tablename__ = "articles" | ||
|
||
id = Column(Integer, primary_key=True) | ||
title = Column(String(255)) | ||
slug = Column(String(255), unique=True) | ||
content = Column(Text) | ||
image = Column(Text, nullable=True) | ||
pub_date = Column(DateTime, default=datetime.utcnow) | ||
safe_for_work = Column(Boolean, default=False) | ||
minutes_to_read = Column(Integer, default=5) | ||
author_id = Column(Integer, ForeignKey("users.id")) | ||
|
||
author = relationship("User", back_populates="articles") | ||
|
||
def __repr__(self): | ||
return ( | ||
f"<Article(title='{self.title}', author='{self.author.username}')>" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import unittest | ||
|
||
from sqlalchemy_model_factory import SESSION | ||
|
||
from article.factories import ArticleFactory, UserFactory | ||
|
||
__author__ = "Artur Barseghyan <artur.barseghyan@gmail.com>" | ||
__copyright__ = "2023 Artur Barseghyan" | ||
__license__ = "MIT" | ||
__all__ = ("TestFactories",) | ||
|
||
|
||
class TestFactories(unittest.TestCase): | ||
def setUp(self): | ||
# Set up database session, if needed | ||
self.session = SESSION() | ||
|
||
def tearDown(self): | ||
# Clean up the session after each test | ||
self.session.rollback() | ||
self.session.close() | ||
|
||
def test_user_creation(self): | ||
user = UserFactory(username="testuser") | ||
self.assertIsNotNone(user.id) | ||
self.assertEqual(user.username, "testuser") | ||
|
||
def test_article_creation(self): | ||
user = UserFactory(username="authoruser") | ||
article = ArticleFactory(title="Test Article", author=user) | ||
self.assertIsNotNone(article.id) | ||
self.assertEqual(article.author.username, "authoruser") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#!/usr/bin/env python | ||
import argparse | ||
import os | ||
import sys | ||
import unittest | ||
|
||
import IPython | ||
|
||
|
||
def run_tests(): | ||
"""Function to run tests in the article directory.""" | ||
loader = unittest.TestLoader() | ||
suite = loader.discover(start_dir=".", pattern="tests.py") | ||
runner = unittest.TextTestRunner() | ||
runner.run(suite) | ||
|
||
|
||
def main(): | ||
"""Run administrative tasks based on command line arguments.""" | ||
sys.path.insert(0, os.path.abspath(os.path.join("..", ".."))) | ||
sys.path.insert(0, os.path.abspath(".")) | ||
parser = argparse.ArgumentParser( | ||
description="Management script for the project." | ||
) | ||
parser.add_argument("command", help="The command to run (test or shell)") | ||
|
||
args = parser.parse_args() | ||
|
||
from sqlalchemy_model_factory import ENGINE | ||
|
||
from article.models import Base | ||
|
||
Base.metadata.create_all(ENGINE) | ||
|
||
if args.command == "test": | ||
run_tests() | ||
elif args.command == "shell": | ||
IPython.embed() | ||
else: | ||
print("Unknown command. Use 'test' or 'shell'.") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
sqlalchemy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from fake import FactoryMethod, ModelFactory, SubFactory | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import scoped_session, sessionmaker | ||
|
||
DATABASE_URL = "sqlite:///test_database.db" | ||
|
||
ENGINE = create_engine(DATABASE_URL) | ||
SESSION = scoped_session(sessionmaker(bind=ENGINE)) | ||
|
||
__author__ = "Artur Barseghyan <artur.barseghyan@gmail.com>" | ||
__copyright__ = "2023 Artur Barseghyan" | ||
__license__ = "MIT" | ||
__all__ = ( | ||
"SQLAlchemyModelFactory", | ||
"ENGINE", | ||
"SESSION", | ||
) | ||
|
||
|
||
class SQLAlchemyModelFactory(ModelFactory): | ||
"""SQLAlchemy ModelFactory.""" | ||
|
||
@classmethod | ||
def get_session(cls): | ||
return SESSION() | ||
|
||
@classmethod | ||
def save(cls, instance): | ||
session = cls.get_session() | ||
session.add(instance) | ||
session.commit() | ||
|
||
@classmethod | ||
def create(cls, **kwargs): | ||
session = cls.get_session() | ||
|
||
model = cls.Meta.model | ||
unique_fields = cls._meta.get("get_or_create", ["id"]) | ||
|
||
# Check for existing instance | ||
if unique_fields: | ||
query_kwargs = {field: kwargs.get(field) for field in unique_fields} | ||
instance = session.query(model).filter_by(**query_kwargs).first() | ||
if instance: | ||
return instance | ||
|
||
# Construct model_data from class attributes | ||
model_data = { | ||
field: (value() if isinstance(value, FactoryMethod) else value) | ||
for field, value in cls.__dict__.items() | ||
if ( | ||
not field.startswith("_") | ||
and not field == "Meta" | ||
and not getattr(value, "is_trait", False) | ||
and not getattr(value, "is_pre_save", False) | ||
and not getattr(value, "is_post_save", False) | ||
) | ||
} | ||
|
||
# Separate nested attributes and direct attributes | ||
nested_attrs = {k: v for k, v in kwargs.items() if "__" in k} | ||
direct_attrs = {k: v for k, v in kwargs.items() if "__" not in k} | ||
|
||
# Update direct attributes with callable results | ||
for field, value in model_data.items(): | ||
if isinstance(value, (FactoryMethod, SubFactory)): | ||
model_data[field] = ( | ||
value() | ||
if field not in direct_attrs | ||
else direct_attrs[field] | ||
) | ||
|
||
# Create a new instance | ||
instance = model(**model_data) | ||
cls._apply_traits(instance, **kwargs) | ||
|
||
# Handle nested attributes | ||
for attr, value in nested_attrs.items(): | ||
field_name, nested_attr = attr.split("__", 1) | ||
if isinstance(getattr(cls, field_name, None), SubFactory): | ||
related_instance = getattr( | ||
cls, field_name | ||
).factory_class.create(**{nested_attr: value}) | ||
setattr(instance, field_name, related_instance) | ||
|
||
# Run pre-save hooks | ||
pre_save_hooks = [ | ||
method | ||
for method in dir(cls) | ||
if getattr(getattr(cls, method), "is_pre_save", False) | ||
] | ||
cls._run_hooks(pre_save_hooks, instance) | ||
|
||
# Save instance | ||
cls.save(instance) | ||
|
||
# Run post-save hooks | ||
post_save_hooks = [ | ||
method | ||
for method in dir(cls) | ||
if getattr(getattr(cls, method), "is_post_save", False) | ||
] | ||
cls._run_hooks(post_save_hooks, instance) | ||
|
||
return instance |
Oops, something went wrong.