diff --git a/service.py b/service.py index 0a9bacc..f4e1ea0 100644 --- a/service.py +++ b/service.py @@ -23,14 +23,12 @@ def signal_handler(sig: int, _frame): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) - database = db.FAFDatabase(loop) - await database.connect( + database = db.FAFDatabase( host=config.DB_SERVER, port=int(config.DB_PORT), user=config.DB_LOGIN, password=config.DB_PASSWORD, - maxsize=10, - db=config.DB_NAME, + db=config.DB_NAME ) logger.info("Database connected.") diff --git a/service/db/__init__.py b/service/db/__init__.py index ec0f2d9..67269c5 100644 --- a/service/db/__init__.py +++ b/service/db/__init__.py @@ -1,42 +1,27 @@ -from aiomysql.sa import create_engine +from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import AsyncEngine class FAFDatabase: - def __init__(self, loop): - self._loop = loop - self.engine = None - - async def connect( + def __init__( self, - host="localhost", - port=3306, - user="root", - password="", - db="faf_test", - minsize=1, - maxsize=1, + host: str = "localhost", + port: int = 3306, + user: str = "root", + password: str = "", + db: str = "faf_test", + **kwargs ): - if self.engine is not None: - raise ValueError("DB is already connected!") - self.engine = await create_engine( - host=host, - port=port, - user=user, - password=password, - db=db, - autocommit=True, - loop=self._loop, - minsize=minsize, - maxsize=maxsize, + kwargs["future"] = True + sync_engine = create_engine( + f"mysql+aiomysql://{user}:{password}@{host}:{port}/{db}", + **kwargs ) + self.engine = AsyncEngine(sync_engine) + def acquire(self): - return self.engine.acquire() + return self.engine.begin() async def close(self): - if self.engine is None: - return - - self.engine.close() - await self.engine.wait_closed() - self.engine = None + await self.engine.dispose() \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 9fa3f6e..9c0660e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -78,7 +78,7 @@ async def test_data(request): await db.close() -async def global_database(request): +async def global_database(request) -> FAFDatabase: def opt(val): return request.config.getoption(val) @@ -89,11 +89,13 @@ def opt(val): opt("--mysql_database"), opt("--mysql_port"), ) - db = FAFDatabase(asyncio.get_running_loop()) - - await db.connect(host=host, user=user, password=pw or None, port=port, db=name) - - return db + return FAFDatabase( + host=host, + user=user, + password=pw or "", + port=port, + db=name + ) @pytest.fixture @@ -108,9 +110,14 @@ def opt(val): opt("--mysql_database"), opt("--mysql_port"), ) - db = MockDatabase(event_loop) - - await db.connect(host=host, user=user, password=pw or None, port=port, db=name) + db = MockDatabase( + host=host, + user=user, + password=pw or "", + port=port, + db=name + ) + await db.connect() yield db diff --git a/tests/utils.py b/tests/utils.py index b0fc814..d3c3ffb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,9 +1,9 @@ import asyncio import functools -from asyncio import Event, Lock import asynctest -from aiomysql.sa import create_engine + +from service.db import FAFDatabase # Copied over from PR #113 of pytest-asyncio. It will probably be available in @@ -84,7 +84,7 @@ async def __aexit__(self, exc_type, exc, tb): self._db._lock.release() -class MockDatabase: +class MockDatabase(FAFDatabase): """ This class mocks the FAFDatabase class, rolling back all transactions performed during tests. To do that, it proxies the real db engine, giving @@ -96,59 +96,38 @@ class MockDatabase: Any future manual commit() calls should be mocked here as well. """ - def __init__(self, loop): - self._loop = loop - self.engine = None - self._connection = None - self._conn_present = Event() - self._keep = None - self._lock = Lock() - self._done = Event() - - async def connect( + def __init__( self, - host="localhost", - port=3306, - user="root", - password="", - db="faf_test", - minsize=1, - maxsize=1, + host: str = "localhost", + port: int = 3306, + user: str = "root", + password: str = "", + db: str = "faf_test", + **kwargs ): - if self.engine is not None: - raise ValueError("DB is already connected!") - self.engine = await create_engine( - host=host, - port=port, - user=user, - password=password, - db=db, - autocommit=False, - loop=self._loop, - minsize=minsize, - maxsize=maxsize, - echo=True, - ) - self._keep = self._loop.create_task(self._keep_connection()) + super().__init__(host, port, user, password, db, **kwargs) + self._connection = None + self._conn_present = asyncio.Event() + self._lock = asyncio.Lock() + self._done = asyncio.Event() + self._keep = asyncio.create_task(self._keep_connection()) + + async def connect(self): await self._conn_present.wait() async def _keep_connection(self): - async with self.engine.acquire() as conn: + async with self.engine.begin() as conn: self._connection = conn self._conn_present.set() await self._done.wait() + await conn.rollback() self._connection = None def acquire(self): return MockConnectionContext(self) async def close(self): - if self.engine is None: - return - async with self._lock: self._done.set() await self._keep - self.engine.close() - await self.engine.wait_closed() - self.engine = None + await self.engine.dispose()