Skip to content

Commit

Permalink
SQLAlchemy initial
Browse files Browse the repository at this point in the history
  • Loading branch information
barseghyanartur committed Dec 10, 2023
1 parent e94f38c commit 591ac33
Show file tree
Hide file tree
Showing 11 changed files with 348 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ fake.py.egg-info
matyan.log*
db.sqlite3
sample_db.sqlite
test_database.db
local_settings.py
/prof/
*.cast
Expand All @@ -40,3 +41,4 @@ examples/pydantic/media/
examples/tortoise/media/
examples/django/media/
examples/dataclasses/media/
examples/sqlalchemy/media/
6 changes: 3 additions & 3 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,16 @@
"filename": "Makefile",
"hashed_secret": "ee783f2421477b5483c23f47eca1f69a1f2bf4fb",
"is_verified": true,
"line_number": 80
"line_number": 86
},
{
"type": "Secret Keyword",
"filename": "Makefile",
"hashed_secret": "1457a35245051927fac6fa556074300f4162ed66",
"is_verified": true,
"line_number": 83
"line_number": 89
}
]
},
"generated_at": "2023-12-07T21:38:21Z"
"generated_at": "2023-12-10T22:39:54Z"
}
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ hypothesis-test:
pydantic-test:
source $(VENV) && cd examples/pydantic/ && python manage.py test

sqlalchemy-test:
source $(VENV) && cd examples/sqlalchemy/ && python manage.py test

tortoise-test:
source $(VENV) && cd examples/tortoise/ && python manage.py test

Expand All @@ -73,6 +76,9 @@ django-shell:
pydantic-shell:
source $(VENV) && cd examples/pydantic/ && python manage.py shell

sqlalchemy-shell:
source $(VENV) && cd examples/sqlalchemy/ && python manage.py shell

tortoise-shell:
source $(VENV) && cd examples/tortoise/ && python manage.py shell

Expand Down
Empty file.
83 changes: 83 additions & 0 deletions examples/sqlalchemy/article/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from pathlib import Path

from fake import (
FACTORY,
FileSystemStorage,
SubFactory,
post_save,
pre_save,
trait,
)
from sqlalchemy_model_factory import SQLAlchemyModelFactory

from article.models import Article, User

__author__ = "Artur Barseghyan <artur.barseghyan@gmail.com>"
__copyright__ = "2023 Artur Barseghyan"
__license__ = "MIT"
__all__ = (
"ArticleFactory",
"UserFactory",
)

# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent
MEDIA_ROOT = BASE_DIR / "media"

STORAGE = FileSystemStorage(root_path=MEDIA_ROOT, rel_path="tmp")


class UserFactory(SQLAlchemyModelFactory):
"""User factory."""

username = FACTORY.username()
first_name = FACTORY.first_name()
last_name = FACTORY.last_name()
email = FACTORY.email()
last_login = FACTORY.date_time()
is_superuser = False
is_staff = False
is_active = FACTORY.pybool()
date_joined = FACTORY.date_time()

class Meta:
model = User
get_or_create = ("username",)

@trait
def is_admin_user(self, instance: User) -> None:
instance.is_superuser = True
instance.is_staff = True
instance.is_active = True

@pre_save
def _pre_save_method(self, instance):
instance.pre_save_called = True

@post_save
def _post_save_method(self, instance):
instance.post_save_called = True


class ArticleFactory(SQLAlchemyModelFactory):
"""Article factory."""

title = FACTORY.sentence()
slug = FACTORY.slug()
content = FACTORY.text()
image = FACTORY.png_file(storage=STORAGE)
pub_date = FACTORY.date()
safe_for_work = FACTORY.pybool()
minutes_to_read = FACTORY.pyint(min_value=1, max_value=10)
author = SubFactory(UserFactory)

class Meta:
model = Article

@pre_save
def _pre_save_method(self, instance):
instance.pre_save_called = True

@post_save
def _post_save_method(self, instance):
instance.post_save_called = True
71 changes: 71 additions & 0 deletions examples/sqlalchemy/article/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from datetime import datetime

from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
Text,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

__author__ = "Artur Barseghyan <artur.barseghyan@gmail.com>"
__copyright__ = "2023 Artur Barseghyan"
__license__ = "MIT"
__all__ = (
"Article",
"Base",
"User",
)


Base = declarative_base()


class User(Base):
"""User model."""

__tablename__ = "users"

id = Column(Integer, primary_key=True)
username = Column(String(255), unique=True)
first_name = Column(String(255))
last_name = Column(String(255))
email = Column(String(255))
password = Column(String(255), nullable=True)
last_login = Column(DateTime, nullable=True)
is_superuser = Column(Boolean, default=False)
is_staff = Column(Boolean, default=False)
is_active = Column(Boolean, default=True)
date_joined = Column(DateTime, nullable=True)

articles = relationship("Article", back_populates="author")

def __repr__(self):
return f"<User(username='{self.username}', email='{self.email}')>"


class Article(Base):
"""Article model."""

__tablename__ = "articles"

id = Column(Integer, primary_key=True)
title = Column(String(255))
slug = Column(String(255), unique=True)
content = Column(Text)
image = Column(Text, nullable=True)
pub_date = Column(DateTime, default=datetime.utcnow)
safe_for_work = Column(Boolean, default=False)
minutes_to_read = Column(Integer, default=5)
author_id = Column(Integer, ForeignKey("users.id"))

author = relationship("User", back_populates="articles")

def __repr__(self):
return (
f"<Article(title='{self.title}', author='{self.author.username}')>"
)
32 changes: 32 additions & 0 deletions examples/sqlalchemy/article/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import unittest

from sqlalchemy_model_factory import SESSION

from article.factories import ArticleFactory, UserFactory

__author__ = "Artur Barseghyan <artur.barseghyan@gmail.com>"
__copyright__ = "2023 Artur Barseghyan"
__license__ = "MIT"
__all__ = ("TestFactories",)


class TestFactories(unittest.TestCase):
def setUp(self):
# Set up database session, if needed
self.session = SESSION()

def tearDown(self):
# Clean up the session after each test
self.session.rollback()
self.session.close()

def test_user_creation(self):
user = UserFactory(username="testuser")
self.assertIsNotNone(user.id)
self.assertEqual(user.username, "testuser")

def test_article_creation(self):
user = UserFactory(username="authoruser")
article = ArticleFactory(title="Test Article", author=user)
self.assertIsNotNone(article.id)
self.assertEqual(article.author.username, "authoruser")
44 changes: 44 additions & 0 deletions examples/sqlalchemy/manage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python
import argparse
import os
import sys
import unittest

import IPython


def run_tests():
"""Function to run tests in the article directory."""
loader = unittest.TestLoader()
suite = loader.discover(start_dir=".", pattern="tests.py")
runner = unittest.TextTestRunner()
runner.run(suite)


def main():
"""Run administrative tasks based on command line arguments."""
sys.path.insert(0, os.path.abspath(os.path.join("..", "..")))
sys.path.insert(0, os.path.abspath("."))
parser = argparse.ArgumentParser(
description="Management script for the project."
)
parser.add_argument("command", help="The command to run (test or shell)")

args = parser.parse_args()

from sqlalchemy_model_factory import ENGINE

from article.models import Base

Base.metadata.create_all(ENGINE)

if args.command == "test":
run_tests()
elif args.command == "shell":
IPython.embed()
else:
print("Unknown command. Use 'test' or 'shell'.")


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions examples/sqlalchemy/requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sqlalchemy
105 changes: 105 additions & 0 deletions examples/sqlalchemy/sqlalchemy_model_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from fake import FactoryMethod, ModelFactory, SubFactory
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

DATABASE_URL = "sqlite:///test_database.db"

ENGINE = create_engine(DATABASE_URL)
SESSION = scoped_session(sessionmaker(bind=ENGINE))

__author__ = "Artur Barseghyan <artur.barseghyan@gmail.com>"
__copyright__ = "2023 Artur Barseghyan"
__license__ = "MIT"
__all__ = (
"SQLAlchemyModelFactory",
"ENGINE",
"SESSION",
)


class SQLAlchemyModelFactory(ModelFactory):
"""SQLAlchemy ModelFactory."""

@classmethod
def get_session(cls):
return SESSION()

@classmethod
def save(cls, instance):
session = cls.get_session()
session.add(instance)
session.commit()

@classmethod
def create(cls, **kwargs):
session = cls.get_session()

model = cls.Meta.model
unique_fields = cls._meta.get("get_or_create", ["id"])

# Check for existing instance
if unique_fields:
query_kwargs = {field: kwargs.get(field) for field in unique_fields}
instance = session.query(model).filter_by(**query_kwargs).first()
if instance:
return instance

# Construct model_data from class attributes
model_data = {
field: (value() if isinstance(value, FactoryMethod) else value)
for field, value in cls.__dict__.items()
if (
not field.startswith("_")
and not field == "Meta"
and not getattr(value, "is_trait", False)
and not getattr(value, "is_pre_save", False)
and not getattr(value, "is_post_save", False)
)
}

# Separate nested attributes and direct attributes
nested_attrs = {k: v for k, v in kwargs.items() if "__" in k}
direct_attrs = {k: v for k, v in kwargs.items() if "__" not in k}

# Update direct attributes with callable results
for field, value in model_data.items():
if isinstance(value, (FactoryMethod, SubFactory)):
model_data[field] = (
value()
if field not in direct_attrs
else direct_attrs[field]
)

# Create a new instance
instance = model(**model_data)
cls._apply_traits(instance, **kwargs)

# Handle nested attributes
for attr, value in nested_attrs.items():
field_name, nested_attr = attr.split("__", 1)
if isinstance(getattr(cls, field_name, None), SubFactory):
related_instance = getattr(
cls, field_name
).factory_class.create(**{nested_attr: value})
setattr(instance, field_name, related_instance)

# Run pre-save hooks
pre_save_hooks = [
method
for method in dir(cls)
if getattr(getattr(cls, method), "is_pre_save", False)
]
cls._run_hooks(pre_save_hooks, instance)

# Save instance
cls.save(instance)

# Run post-save hooks
post_save_hooks = [
method
for method in dir(cls)
if getattr(getattr(cls, method), "is_post_save", False)
]
cls._run_hooks(post_save_hooks, instance)

return instance
Loading

0 comments on commit 591ac33

Please sign in to comment.