Skip to content

Commit

Permalink
Refactor database wrapper to use sqlalchemy engine instead of aiomysql
Browse files Browse the repository at this point in the history
  • Loading branch information
BlackYps committed Jun 4, 2024
1 parent 4c1b6cc commit fbfcd2a
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 87 deletions.
6 changes: 2 additions & 4 deletions service.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@ def signal_handler(sig: int, _frame):
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

database = db.FAFDatabase(loop)
await database.connect(
database = db.FAFDatabase(
host=config.DB_SERVER,
port=int(config.DB_PORT),
user=config.DB_LOGIN,
password=config.DB_PASSWORD,
maxsize=10,
db=config.DB_NAME,
db=config.DB_NAME
)
logger.info("Database connected.")

Expand Down
49 changes: 17 additions & 32 deletions service/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,27 @@
from aiomysql.sa import create_engine
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncEngine


class FAFDatabase:
def __init__(self, loop):
self._loop = loop
self.engine = None

async def connect(
def __init__(
self,
host="localhost",
port=3306,
user="root",
password="",
db="faf_test",
minsize=1,
maxsize=1,
host: str = "localhost",
port: int = 3306,
user: str = "root",
password: str = "",
db: str = "faf_test",
**kwargs
):
if self.engine is not None:
raise ValueError("DB is already connected!")
self.engine = await create_engine(
host=host,
port=port,
user=user,
password=password,
db=db,
autocommit=True,
loop=self._loop,
minsize=minsize,
maxsize=maxsize,
kwargs["future"] = True
sync_engine = create_engine(
f"mysql+aiomysql://{user}:{password}@{host}:{port}/{db}",
**kwargs
)

self.engine = AsyncEngine(sync_engine)

def acquire(self):
return self.engine.acquire()
return self.engine.begin()

async def close(self):
if self.engine is None:
return

self.engine.close()
await self.engine.wait_closed()
self.engine = None
await self.engine.dispose()
25 changes: 16 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def test_data(request):
await db.close()


async def global_database(request):
async def global_database(request) -> FAFDatabase:
def opt(val):
return request.config.getoption(val)

Expand All @@ -89,11 +89,13 @@ def opt(val):
opt("--mysql_database"),
opt("--mysql_port"),
)
db = FAFDatabase(asyncio.get_running_loop())

await db.connect(host=host, user=user, password=pw or None, port=port, db=name)

return db
return FAFDatabase(
host=host,
user=user,
password=pw or "",
port=port,
db=name
)


@pytest.fixture
Expand All @@ -108,9 +110,14 @@ def opt(val):
opt("--mysql_database"),
opt("--mysql_port"),
)
db = MockDatabase(event_loop)

await db.connect(host=host, user=user, password=pw or None, port=port, db=name)
db = MockDatabase(
host=host,
user=user,
password=pw or "",
port=port,
db=name
)
await db.connect()

yield db

Expand Down
63 changes: 21 additions & 42 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import functools
from asyncio import Event, Lock

import asynctest
from aiomysql.sa import create_engine

from service.db import FAFDatabase


# Copied over from PR #113 of pytest-asyncio. It will probably be available in
Expand Down Expand Up @@ -84,7 +84,7 @@ async def __aexit__(self, exc_type, exc, tb):
self._db._lock.release()


class MockDatabase:
class MockDatabase(FAFDatabase):
"""
This class mocks the FAFDatabase class, rolling back all transactions
performed during tests. To do that, it proxies the real db engine, giving
Expand All @@ -96,59 +96,38 @@ class MockDatabase:
Any future manual commit() calls should be mocked here as well.
"""

def __init__(self, loop):
self._loop = loop
self.engine = None
self._connection = None
self._conn_present = Event()
self._keep = None
self._lock = Lock()
self._done = Event()

async def connect(
def __init__(
self,
host="localhost",
port=3306,
user="root",
password="",
db="faf_test",
minsize=1,
maxsize=1,
host: str = "localhost",
port: int = 3306,
user: str = "root",
password: str = "",
db: str = "faf_test",
**kwargs
):
if self.engine is not None:
raise ValueError("DB is already connected!")
self.engine = await create_engine(
host=host,
port=port,
user=user,
password=password,
db=db,
autocommit=False,
loop=self._loop,
minsize=minsize,
maxsize=maxsize,
echo=True,
)
self._keep = self._loop.create_task(self._keep_connection())
super().__init__(host, port, user, password, db, **kwargs)
self._connection = None
self._conn_present = asyncio.Event()
self._lock = asyncio.Lock()
self._done = asyncio.Event()
self._keep = asyncio.create_task(self._keep_connection())

async def connect(self):
await self._conn_present.wait()

async def _keep_connection(self):
async with self.engine.acquire() as conn:
async with self.engine.begin() as conn:
self._connection = conn
self._conn_present.set()
await self._done.wait()
await conn.rollback()
self._connection = None

def acquire(self):
return MockConnectionContext(self)

async def close(self):
if self.engine is None:
return

async with self._lock:
self._done.set()
await self._keep
self.engine.close()
await self.engine.wait_closed()
self.engine = None
await self.engine.dispose()

0 comments on commit fbfcd2a

Please sign in to comment.