diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 7e0c2c4..ea637c1 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -24,9 +24,9 @@ jobs: uses: actions/checkout@v2 - name: Setup Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: '3.10' - name: Install dependencies run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b5ae95f..11e6687 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,9 +16,9 @@ jobs: - uses: actions/checkout@v2 - name: Setup Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: '3.10' - name: Install dependencies run: pip install isort @@ -46,7 +46,7 @@ jobs: MYSQL_PASSWORD: banana faf-rabbitmq: - image: rabbitmq:3.8-management-alpine + image: rabbitmq:3.9-alpine ports: - 5672:5672 options: >- @@ -65,9 +65,9 @@ jobs: - uses: actions/checkout@v2 - name: Setup Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: '3.10' - name: Setup RabbitMQ run: ./.github/workflows/scripts/init-rabbitmq.sh @@ -84,8 +84,9 @@ jobs: run: pipenv run tests --cov-report=xml - name: Report coverage - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v4 with: + token: ${{ secrets.CODECOV_TOKEN }} files: coverage.xml fail_ci_if_error: true diff --git a/Dockerfile b/Dockerfile index 717a88a..f0e011e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9-slim +FROM python:3.10-slim # Need git for installing aiomysql RUN apt-get update diff --git a/Pipfile b/Pipfile index 9c95ce1..639342a 100644 --- a/Pipfile +++ b/Pipfile @@ -11,25 +11,24 @@ verify_ssl = true [packages] trueskill = "*" -sqlalchemy = "==1.4.*" +sqlalchemy = "*" aiomysql = "*" aio_pika = "*" aiocron = "*" prometheus_client = "*" yoyo-migrations = "*" -pamqp = "==2.3.0" +pamqp = "*" python-dateutil = "*" [dev-packages] pytest = "*" -pytest-asyncio = "==0.12.0" +pytest-asyncio = "*" pytest-cov = "*" asynctest = "*" python-coveralls = "*" -mock = "*" pytest-mock = "*" vulture = "*" freezegun = "*" [requires] -python_version = "3.9" +python_version = "3.10" diff --git a/Pipfile.lock b/Pipfile.lock index da4d0ce..24e021e 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,11 +1,11 @@ { "_meta": { "hash": { - "sha256": "7dc52ec6a1536adbe15eb8e0587e48d72ef1c41d3fc66a0d7e545314794e0ab7" + "sha256": "c8dc2ed7a59ed9841196f0cd72591996d1f2b0689d51ecd1ef4862c93c991374" }, "pipfile-spec": 6, "requires": { - "python_version": "3.9" + "python_version": "3.10" }, "sources": [ { @@ -18,12 +18,12 @@ "default": { "aio-pika": { "hashes": [ - "sha256:4bf23e54bceb86b789d4b4a72ed65f2d83ede429d5f343de838ca72e54f00475", - "sha256:d89658148def0d8b8d795868a753fe2906f8d8fccee53e4a1b5093ddd3d2dc5c" + "sha256:2d4cec5a5d32005a71e31faaf5be58ee993fa29b8243a5d04b8f4a73df272fe8", + "sha256:d177404a27aa797a2ac935f6aa2bbb287fe1f0ddf5628ca20a237b7db3ea6324" ], "index": "pypi", - "markers": "python_version >= '3.5' and python_version < '4'", - "version": "==6.8.2" + "markers": "python_version >= '3.8' and python_version < '4.0'", + "version": "==9.4.1" }, "aiocron": { "hashes": [ @@ -44,11 +44,11 @@ }, "aiormq": { "hashes": [ - "sha256:8218dd9f7198d6e7935855468326bbacf0089f926c70baa8dd92944cb2496573", - "sha256:e584dac13a242589aaf42470fd3006cb0dc5aed6506cbd20357c7ec8bbe4a89e" + "sha256:198f9c7430feb7bc491016099a06266dc45880b6b1de3925d410fde6541a66fb", + "sha256:9a16174dcae4078c957a773d2f02d3dfd6c2fcf12c909dc244333a458f2aeab0" ], - "markers": "python_version >= '3.6'", - "version": "==3.3.1" + "markers": "python_version >= '3.8' and python_version < '4.0'", + "version": "==6.8.0" }, "croniter": { "hashes": [ @@ -119,7 +119,7 @@ "sha256:fd096eb7ffef17c456cfa587523c5f92321ae02427ff955bebe9e3c63bc9f0da", "sha256:fe754d231288e1e64323cfad462fcee8f0288654c10bdf4f603a39ed923bef33" ], - "markers": "python_version >= '3' and platform_machine == 'aarch64' or (platform_machine == 'ppc64le' or (platform_machine == 'x86_64' or (platform_machine == 'amd64' or (platform_machine == 'AMD64' or (platform_machine == 'win32' or platform_machine == 'WIN32')))))", + "markers": "platform_machine == 'aarch64' or (platform_machine == 'ppc64le' or (platform_machine == 'x86_64' or (platform_machine == 'amd64' or (platform_machine == 'AMD64' or (platform_machine == 'win32' or platform_machine == 'WIN32')))))", "version": "==3.0.3" }, "idna": { @@ -236,11 +236,12 @@ }, "pamqp": { "hashes": [ - "sha256:2f81b5c186f668a67f165193925b6bfd83db4363a6222f599517f29ecee60b02", - "sha256:5cd0f5a85e89f20d5f8e19285a1507788031cfca4a9ea6f067e3cf18f5e294e8" + "sha256:40b8795bd4efcf2b0f8821c1de83d12ca16d5760f4507836267fd7a02b06763b", + "sha256:c901a684794157ae39b52cbf700db8c9aae7a470f13528b9d7b4e5f7202f8eb0" ], "index": "pypi", - "version": "==2.3.0" + "markers": "python_version >= '3.7'", + "version": "==3.3.0" }, "prometheus-client": { "hashes": [ @@ -285,56 +286,59 @@ }, "sqlalchemy": { "hashes": [ - "sha256:1296f2cdd6db09b98ceb3c93025f0da4835303b8ac46c15c2136e27ee4d18d94", - "sha256:1e135fff2e84103bc15c07edd8569612ce317d64bdb391f49ce57124a73f45c5", - "sha256:1f8e1c6a6b7f8e9407ad9afc0ea41c1f65225ce505b79bc0342159de9c890782", - "sha256:24bb0f81fbbb13d737b7f76d1821ec0b117ce8cbb8ee5e8641ad2de41aa916d3", - "sha256:29d4247313abb2015f8979137fe65f4eaceead5247d39603cc4b4a610936cd2b", - "sha256:2c286fab42e49db23c46ab02479f328b8bdb837d3e281cae546cc4085c83b680", - "sha256:2f251af4c75a675ea42766880ff430ac33291c8d0057acca79710f9e5a77383d", - "sha256:346ed50cb2c30f5d7a03d888e25744154ceac6f0e6e1ab3bc7b5b77138d37710", - "sha256:3491c85df263a5c2157c594f54a1a9c72265b75d3777e61ee13c556d9e43ffc9", - "sha256:427988398d2902de042093d17f2b9619a5ebc605bf6372f7d70e29bde6736842", - "sha256:427c282dd0deba1f07bcbf499cbcc9fe9a626743f5d4989bfdfd3ed3513003dd", - "sha256:49e3772eb3380ac88d35495843daf3c03f094b713e66c7d017e322144a5c6b7c", - "sha256:4dae6001457d4497736e3bc422165f107ecdd70b0d651fab7f731276e8b9e12d", - "sha256:5b5de6af8852500d01398f5047d62ca3431d1e29a331d0b56c3e14cb03f8094c", - "sha256:5bbce5dd7c7735e01d24f5a60177f3e589078f83c8a29e124a6521b76d825b85", - "sha256:5bed4f8c3b69779de9d99eb03fd9ab67a850d74ab0243d1be9d4080e77b6af12", - "sha256:618827c1a1c243d2540314c6e100aee7af09a709bd005bae971686fab6723554", - "sha256:6ab773f9ad848118df7a9bbabca53e3f1002387cdbb6ee81693db808b82aaab0", - "sha256:6e41cb5cda641f3754568d2ed8962f772a7f2b59403b95c60c89f3e0bd25f15e", - "sha256:7027be7930a90d18a386b25ee8af30514c61f3852c7268899f23fdfbd3107181", - "sha256:763bd97c4ebc74136ecf3526b34808c58945023a59927b416acebcd68d1fc126", - "sha256:7d0dbc56cb6af5088f3658982d3d8c1d6a82691f31f7b0da682c7b98fa914e91", - "sha256:80e63bbdc5217dad3485059bdf6f65a7d43f33c8bde619df5c220edf03d87296", - "sha256:80e7f697bccc56ac6eac9e2df5c98b47de57e7006d2e46e1a3c17c546254f6ef", - "sha256:84e10772cfc333eb08d0b7ef808cd76e4a9a30a725fb62a0495877a57ee41d81", - "sha256:853fcfd1f54224ea7aabcf34b227d2b64a08cbac116ecf376907968b29b8e763", - "sha256:99224d621affbb3c1a4f72b631f8393045f4ce647dd3262f12fe3576918f8bf3", - "sha256:a251146b921725547ea1735b060a11e1be705017b568c9f8067ca61e6ef85f20", - "sha256:a551d5f3dc63f096ed41775ceec72fdf91462bb95abdc179010dc95a93957800", - "sha256:a5d2e08d79f5bf250afb4a61426b41026e448da446b55e4770c2afdc1e200fce", - "sha256:a752bff4796bf22803d052d4841ebc3c55c26fb65551f2c96e90ac7c62be763a", - "sha256:afb1672b57f58c0318ad2cff80b384e816735ffc7e848d8aa51e0b0fc2f4b7bb", - "sha256:bcdfb4b47fe04967669874fb1ce782a006756fdbebe7263f6a000e1db969120e", - "sha256:bdb7b4d889631a3b2a81a3347c4c3f031812eb4adeaa3ee4e6b0d028ad1852b5", - "sha256:c124912fd4e1bb9d1e7dc193ed482a9f812769cb1e69363ab68e01801e859821", - "sha256:c294ae4e6bbd060dd79e2bd5bba8b6274d08ffd65b58d106394cb6abbf35cf45", - "sha256:ca5ce82b11731492204cff8845c5e8ca1a4bd1ade85e3b8fcf86e7601bfc6a39", - "sha256:cb8f9e4c4718f111d7b530c4e6fb4d28f9f110eb82e7961412955b3875b66de0", - "sha256:d2de46f5d5396d5331127cfa71f837cca945f9a2b04f7cb5a01949cf676db7d1", - "sha256:d913f8953e098ca931ad7f58797f91deed26b435ec3756478b75c608aa80d139", - "sha256:de9acf369aaadb71a725b7e83a5ef40ca3de1cf4cdc93fa847df6b12d3cd924b", - "sha256:e93983cc0d2edae253b3f2141b0a3fb07e41c76cd79c2ad743fc27eb79c3f6db", - "sha256:f12aaf94f4d9679ca475975578739e12cc5b461172e04d66f7a3c39dd14ffc64", - "sha256:f68016f9a5713684c1507cc37133c28035f29925c75c0df2f9d0f7571e23720a", - "sha256:f7ea11727feb2861deaa293c7971a4df57ef1c90e42cb53f0da40c3468388000", - "sha256:f98dbb8fcc6d1c03ae8ec735d3c62110949a3b8bc6e215053aa27096857afb45" + "sha256:0094c5dc698a5f78d3d1539853e8ecec02516b62b8223c970c86d44e7a80f6c7", + "sha256:0138c5c16be3600923fa2169532205d18891b28afa817cb49b50e08f62198bb8", + "sha256:0a089e218654e740a41388893e090d2e2c22c29028c9d1353feb38638820bbeb", + "sha256:0b3f4c438e37d22b83e640f825ef0f37b95db9aa2d68203f2c9549375d0b2260", + "sha256:16863e2b132b761891d6c49f0a0f70030e0bcac4fd208117f6b7e053e68668d0", + "sha256:1f9a727312ff6ad5248a4367358e2cf7e625e98b1028b1d7ab7b806b7d757513", + "sha256:2383146973a15435e4717f94c7509982770e3e54974c71f76500a0136f22810b", + "sha256:2753743c2afd061bb95a61a51bbb6a1a11ac1c44292fad898f10c9839a7f75b2", + "sha256:296230899df0b77dec4eb799bcea6fbe39a43707ce7bb166519c97b583cfcab3", + "sha256:2a4f4da89c74435f2bc61878cd08f3646b699e7d2eba97144030d1be44e27584", + "sha256:2b1708916730f4830bc69d6f49d37f7698b5bd7530aca7f04f785f8849e95255", + "sha256:2ecabd9ccaa6e914e3dbb2aa46b76dede7eadc8cbf1b8083c94d936bcd5ffb49", + "sha256:311710f9a2ee235f1403537b10c7687214bb1f2b9ebb52702c5aa4a77f0b3af7", + "sha256:37a4b4fb0dd4d2669070fb05b8b8824afd0af57587393015baee1cf9890242d9", + "sha256:3a365eda439b7a00732638f11072907c1bc8e351c7665e7e5da91b169af794af", + "sha256:3b48154678e76445c7ded1896715ce05319f74b1e73cf82d4f8b59b46e9c0ddc", + "sha256:3b69e934f0f2b677ec111b4d83f92dc1a3210a779f69bf905273192cf4ed433e", + "sha256:3cb5a646930c5123f8461f6468901573f334c2c63c795b9af350063a736d0134", + "sha256:408f8b0e2c04677e9c93f40eef3ab22f550fecb3011b187f66a096395ff3d9fd", + "sha256:40ad017c672c00b9b663fcfcd5f0864a0a97828e2ee7ab0c140dc84058d194cf", + "sha256:5a79d65395ac5e6b0c2890935bad892eabb911c4aa8e8015067ddb37eea3d56c", + "sha256:5a8e3b0a7e09e94be7510d1661339d6b52daf202ed2f5b1f9f48ea34ee6f2d57", + "sha256:69c9db1ce00e59e8dd09d7bae852a9add716efdc070a3e2068377e6ff0d6fdaa", + "sha256:7108d569d3990c71e26a42f60474b4c02c8586c4681af5fd67e51a044fdea86a", + "sha256:77d2edb1f54aff37e3318f611637171e8ec71472f1fdc7348b41dcb226f93d90", + "sha256:7d74336c65705b986d12a7e337ba27ab2b9d819993851b140efdf029248e818e", + "sha256:8409de825f2c3b62ab15788635ccaec0c881c3f12a8af2b12ae4910a0a9aeef6", + "sha256:955991a09f0992c68a499791a753523f50f71a6885531568404fa0f231832aa0", + "sha256:99650e9f4cf3ad0d409fed3eec4f071fadd032e9a5edc7270cd646a26446feeb", + "sha256:9a5baf9267b752390252889f0c802ea13b52dfee5e369527da229189b8bd592e", + "sha256:a0ef36b28534f2a5771191be6edb44cc2673c7b2edf6deac6562400288664221", + "sha256:a1429a4b0f709f19ff3b0cf13675b2b9bfa8a7e79990003207a011c0db880a13", + "sha256:a7bfc726d167f425d4c16269a9a10fe8630ff6d14b683d588044dcef2d0f6be7", + "sha256:a943d297126c9230719c27fcbbeab57ecd5d15b0bd6bfd26e91bfcfe64220621", + "sha256:ae8c62fe2480dd61c532ccafdbce9b29dacc126fe8be0d9a927ca3e699b9491a", + "sha256:b60203c63e8f984df92035610c5fb76d941254cf5d19751faab7d33b21e5ddc0", + "sha256:b6bf767d14b77f6a18b6982cbbf29d71bede087edae495d11ab358280f304d8e", + "sha256:b6c7ec2b1f4969fc19b65b7059ed00497e25f54069407a8701091beb69e591a5", + "sha256:bba002a9447b291548e8d66fd8c96a6a7ed4f2def0bb155f4f0a1309fd2735d5", + "sha256:bc0c53579650a891f9b83fa3cecd4e00218e071d0ba00c4890f5be0c34887ed3", + "sha256:c4f61ada6979223013d9ab83a3ed003ded6959eae37d0d685db2c147e9143797", + "sha256:c62d401223f468eb4da32627bffc0c78ed516b03bb8a34a58be54d618b74d472", + "sha256:e42203d8d20dc704604862977b1470a122e4892791fe3ed165f041e4bf447a1b", + "sha256:edc16a50f5e1b7a06a2dcc1f2205b0b961074c123ed17ebda726f376a5ab0953", + "sha256:efedba7e13aa9a6c8407c48facfdfa108a5a4128e35f4c68f20c3407e4376aa9", + "sha256:f1dc3eabd8c0232ee8387fbe03e0a62220a6f089e278b1f0aaf5e2d6210741ad", + "sha256:f69e4c756ee2686767eb80f94c0125c8b0a0b87ede03eacc5c8ae3b54b99dc46", + "sha256:f7703c2010355dd28f53deb644a05fc30f796bd8598b43f0ba678878780b6e4c", + "sha256:fa561138a64f949f3e889eb9ab8c58e1504ab351d6cf55259dc4c248eaa19da6" ], "index": "pypi", - "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'", - "version": "==1.4.52" + "markers": "python_version >= '3.7'", + "version": "==2.0.30" }, "sqlparse": { "hashes": [ @@ -360,6 +364,14 @@ "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'", "version": "==0.4.5" }, + "typing-extensions": { + "hashes": [ + "sha256:6024b58b69089e5a89c347397254e35f1bf02a907728ec7fee9bf0fe837d203a", + "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1" + ], + "markers": "python_version >= '3.8'", + "version": "==4.12.1" + }, "tzlocal": { "hashes": [ "sha256:49816ef2fe65ea8ac19d19aa7a1ae0551c834303d5014c6d5a62e4cbda8047b8", @@ -689,15 +701,6 @@ "markers": "python_version >= '3.7'", "version": "==2.0.0" }, - "mock": { - "hashes": [ - "sha256:18c694e5ae8a208cdb3d2c20a993ca1a7b0efa258c247a1e565150f477f83744", - "sha256:5e96aad5ccda4718e0a229ed94b2024df75cc2d55575ba5762d31f5767b8767d" - ], - "index": "pypi", - "markers": "python_version >= '3.6'", - "version": "==5.1.0" - }, "packaging": { "hashes": [ "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5", @@ -716,20 +719,21 @@ }, "pytest": { "hashes": [ - "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd", - "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1" + "sha256:c434598117762e2bd304e526244f67bf66bbd7b5d6cf22138be51ff661980343", + "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977" ], "index": "pypi", "markers": "python_version >= '3.8'", - "version": "==8.2.1" + "version": "==8.2.2" }, "pytest-asyncio": { "hashes": [ - "sha256:475bd2f3dc0bc11d2463656b3cbaafdbec5a47b47508ea0b329ee693040eebd2" + "sha256:009b48127fbe44518a547bddd25611551b0e43ccdbf1e67d12479f569832c20b", + "sha256:5f5c72948f4c49e7db4f29f2521d4031f1c27f86e57b046126654083d4770268" ], "index": "pypi", - "markers": "python_version >= '3.5'", - "version": "==0.12.0" + "markers": "python_version >= '3.8'", + "version": "==0.23.7" }, "pytest-cov": { "hashes": [ diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..2f4c80e --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/scripts/local_rabbitmq.sh b/scripts/local_rabbitmq.sh index 323a257..d6a171b 100644 --- a/scripts/local_rabbitmq.sh +++ b/scripts/local_rabbitmq.sh @@ -9,7 +9,7 @@ RABBITMQ_LEAGUE_SERVICE_USER=faf-league-service RABBITMQ_LEAGUE_SERVICE_PASS=banana RABBITMQ_LEAGUE_SERVICE_VHOST=/faf-lobby -docker run -d -p 5672:5672 --restart unless-stopped --name faf-rabbitmq rabbitmq:3.8-management-alpine +docker run -d -p 5672:5672 --restart unless-stopped --name faf-rabbitmq rabbitmq:3.9-alpine # This doesn't seem to pick up the pid file docker exec faf-rabbitmq rabbitmqctl wait --timeout ${MAX_WAIT_SECONDS} "${RABBITMQ_PID_FILE}" diff --git a/service.py b/service.py index 0a9bacc..a5b9fc4 100644 --- a/service.py +++ b/service.py @@ -23,13 +23,11 @@ def signal_handler(sig: int, _frame): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) - database = db.FAFDatabase(loop) - await database.connect( + database = db.FAFDatabase( host=config.DB_SERVER, port=int(config.DB_PORT), user=config.DB_LOGIN, password=config.DB_PASSWORD, - maxsize=10, db=config.DB_NAME, ) logger.info("Database connected.") diff --git a/service/db/__init__.py b/service/db/__init__.py index ec0f2d9..60b2949 100644 --- a/service/db/__init__.py +++ b/service/db/__init__.py @@ -1,42 +1,66 @@ -from aiomysql.sa import create_engine +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import AsyncConnection as _AsyncConnection +from sqlalchemy.ext.asyncio import AsyncEngine as _AsyncEngine +from sqlalchemy.util import EMPTY_DICT class FAFDatabase: - def __init__(self, loop): - self._loop = loop - self.engine = None - - async def connect( + def __init__( self, - host="localhost", - port=3306, - user="root", - password="", - db="faf_test", - minsize=1, - maxsize=1, + host: str = "localhost", + port: int = 3306, + user: str = "root", + password: str = "", + db: str = "faf_test", + **kwargs ): - if self.engine is not None: - raise ValueError("DB is already connected!") - self.engine = await create_engine( - host=host, - port=port, - user=user, - password=password, - db=db, - autocommit=True, - loop=self._loop, - minsize=minsize, - maxsize=maxsize, + kwargs["future"] = True + sync_engine = create_engine( + f"mysql+aiomysql://{user}:{password}@{host}:{port}/{db}", + **kwargs ) + self.engine = AsyncEngine(sync_engine) + def acquire(self): - return self.engine.acquire() + return self.engine.begin() async def close(self): - if self.engine is None: - return + await self.engine.dispose() + + +class AsyncEngine(_AsyncEngine): + """ + For overriding the connection class used to execute statements. + + This could also be done by changing engine._connection_cls, however this + is undocumented and probably more fragile so we subclass instead. + """ + + def connect(self): + return AsyncConnection(self) - self.engine.close() - await self.engine.wait_closed() - self.engine = None + +class AsyncConnection(_AsyncConnection): + async def execute( + self, + statement, + parameters=None, + execution_options=EMPTY_DICT, + **kwargs + ): + """ + Wrap strings in the text type automatically and allows bindparams to be + passed via kwargs. + """ + if isinstance(statement, str): + statement = text(statement) + + if kwargs and parameters is None: + parameters = kwargs + + return await super().execute( + statement, + parameters=parameters, + execution_options=execution_options + ) diff --git a/service/league_service/league_service.py b/service/league_service/league_service.py index a272063..e34f1c7 100644 --- a/service/league_service/league_service.py +++ b/service/league_service/league_service.py @@ -1,7 +1,6 @@ import asyncio from collections import defaultdict from datetime import datetime -from typing import Dict import aiocron from aio_pika import IncomingMessage @@ -54,14 +53,11 @@ async def update_data(self): async with self._db.acquire() as conn: sql = ( select( - [ - league_season, - league, - league_season_division, - league_season_division_subdivision, - leaderboard, - ], - use_labels=True, + league_season, + league, + league_season_division, + league_season_division_subdivision, + leaderboard, ) .select_from( league_season_division_subdivision.outerjoin(league_season_division) @@ -72,7 +68,7 @@ async def update_data(self): .where(between(datetime.now(), league_season.c.start_date, league_season.c.end_date)) ) result = await conn.execute(sql) - division_rows = await result.fetchall() + division_rows = result.fetchall() # The concept of subdivisions exists only in the database and client, # but not in the rating service. We therefore treat every subdivision @@ -80,18 +76,18 @@ async def update_data(self): # (division_index, subdivision_index) indices. divisions_by_league = defaultdict(list) for row in division_rows: - divisions_by_league[row[league.c.technical_name]].append(row) + divisions_by_league[row.league_technical_name].append(row) self._leagues_by_rating_type = defaultdict(list) for league_name, division_list in divisions_by_league.items(): - rating_type = division_list[0][leaderboard.c.technical_name] - placement_games = division_list[0][league_season.c.placement_games] - placement_games_returning_player = division_list[0][league_season.c.placement_games_returning_player] + rating_type = division_list[0].leaderboard_technical_name + placement_games = division_list[0].placement_games + placement_games_returning_player = division_list[0].placement_games_returning_player division_list.sort( key=lambda row: ( - row[league_season_division.c.division_index], - row[league_season_division_subdivision.c.subdivision_index], - row[league_season_division.c.id], + row.division_index, + row.subdivision_index, + row.league_season_division_id, ) ) self._leagues_by_rating_type[rating_type].append( @@ -99,14 +95,14 @@ async def update_data(self): league_name, [ LeagueDivision( - row[league_season_division_subdivision.c.id], - row[league_season_division_subdivision.c.min_rating], - row[league_season_division_subdivision.c.max_rating], - row[league_season_division_subdivision.c.highest_score], + row.league_season_division_subdivision_id, + row.min_rating, + row.max_rating, + row.highest_score, ) for row in division_list ], - division_list[0][league_season.c.id], + division_list[0].league_season_id, placement_games, placement_games_returning_player, rating_type, @@ -149,29 +145,29 @@ async def _rate_single_league(self, league: League, request: LeagueRatingRequest async def _load_score(self, player_id: PlayerID, league: League) -> LeagueScore: async with self._db.acquire() as conn: - sql = select([league_season_score]).where( + sql = select(league_season_score).where( and_( league_season_score.c.login_id == player_id, league_season_score.c.league_season_id == league.current_season_id, ) ) result = await conn.execute(sql) - row = await result.fetchone() + row = result.fetchone() if row is None: returning_player = await self.is_returning_player(player_id, league.rating_type) return LeagueScore(None, None, 0, returning_player) return LeagueScore( - row[league_season_score.c.subdivision_id], - row[league_season_score.c.score], - row[league_season_score.c.game_count], - row[league_season_score.c.returning_player], + row.subdivision_id, + row.score, + row.game_count, + row.returning_player, ) async def is_returning_player(self, player_id: PlayerID, rating_type: str) -> bool: async with self._db.acquire() as conn: sql = ( - select([league_season_score]) + select(league_season_score) .select_from( league_season_score.outerjoin(league_season) .outerjoin(leaderboard)) @@ -183,8 +179,8 @@ async def is_returning_player(self, player_id: PlayerID, rating_type: str) -> bo ) ) result = await conn.execute(sql) - row = await result.fetchone() - if row is None or row[league_season_score.c.subdivision_id] is None: + row = result.fetchone() + if row is None or row.subdivision_id is None: return False else: return True @@ -204,7 +200,7 @@ async def _persist_score( raise InvalidScoreError("Missing score for non-null division.") select_season_id = ( - select([league_season_division.c.league_season_id]) + select(league_season_division.c.league_season_id) .select_from( league_season_division_subdivision.outerjoin( league_season_division @@ -215,8 +211,8 @@ async def _persist_score( ) ) result = await conn.execute(select_season_id) - row = await result.fetchone() - season_id_of_division = row.get("league_season_id") + row = result.fetchone() + season_id_of_division = row.league_season_id if season_id != season_id_of_division: raise InvalidScoreError("Division id did not match season id.") @@ -276,7 +272,7 @@ def handle_message(self, message: IncomingMessage): else: asyncio.create_task(self.enqueue(parsed_dict)) - async def enqueue(self, rating_change_message: Dict) -> None: + async def enqueue(self, rating_change_message: dict) -> None: if not self._accept_input: self._logger.warning("Dropped league request %s", rating_change_message) raise ServiceNotReadyError( diff --git a/service/league_service/typedefs.py b/service/league_service/typedefs.py index aed6525..f77c409 100644 --- a/service/league_service/typedefs.py +++ b/service/league_service/typedefs.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Callable, Dict, List, NamedTuple, NewType, Optional, Tuple +from typing import Callable, NamedTuple, NewType, Optional from ..decorators import with_logger @@ -25,7 +25,7 @@ class InvalidScoreError(LeagueServiceError): GameID = NewType("GameId", int) PlayerID = NewType("PlayerID", int) RatingType = str # e.g. "ladder_1v1" -Rating = Tuple[float, float] +Rating = tuple[float, float] class GameOutcome(Enum): @@ -47,7 +47,7 @@ class LeagueDivision(NamedTuple): @with_logger class League(NamedTuple): name: str - divisions: List[LeagueDivision] + divisions: list[LeagueDivision] current_season_id: int placement_games: int placement_games_returning_player: int @@ -123,7 +123,7 @@ class LeagueRatingRequest(NamedTuple): callback: Optional[Callable] @classmethod - def from_rating_change_dict(cls, message: Dict): + def from_rating_change_dict(cls, message: dict): return cls( message["game_id"], message["player_id"], diff --git a/service/message_queue_service.py b/service/message_queue_service.py index b503cbf..6670fc3 100644 --- a/service/message_queue_service.py +++ b/service/message_queue_service.py @@ -1,11 +1,10 @@ import asyncio import json -from typing import Dict import aio_pika from aio_pika import DeliveryMode, ExchangeType from aio_pika.exceptions import ProbableAuthenticationError -from pamqp import specification +from pamqp import commands from service import config from service.decorators import with_logger @@ -46,14 +45,14 @@ async def _connect(self) -> None: ), loop=asyncio.get_running_loop(), ) - except ConnectionError as e: - self._logger.warning("Unable to connect to RabbitMQ. Is it running?") - raise ConnectionAttemptFailed from e except ProbableAuthenticationError as e: self._logger.warning( "Unable to connect to RabbitMQ. Incorrect credentials?" ) raise ConnectionAttemptFailed from e + except ConnectionError as e: + self._logger.warning("Unable to connect to RabbitMQ. Is it running?") + raise ConnectionAttemptFailed from e except Exception as e: self._logger.warning( "Unable to connect to RabbitMQ due to unhandled excpetion %s. Incorrect vhost?", @@ -96,7 +95,7 @@ async def publish( self, exchange_name: str, routing: str, - payload: Dict, + payload: dict, delivery_mode: DeliveryMode = DeliveryMode.PERSISTENT, ) -> None: if self._connection is None: @@ -114,7 +113,7 @@ async def publish( ) confirmation = await exchange.publish(message, routing_key=routing) - if not isinstance(confirmation, specification.Basic.Ack): + if not isinstance(confirmation, commands.Basic.Ack): self._logger.warning( "Message could not be delivered to %s, received %s", routing, @@ -147,7 +146,7 @@ async def reconnect(self) -> None: ) -def message_to_dict(message: aio_pika.IncomingMessage) -> Dict: +def message_to_dict(message: aio_pika.IncomingMessage) -> dict: decoded_dict = json.loads(message.body.decode()) decoded_dict.update( { diff --git a/service/season_generator.py b/service/season_generator.py index d85d996..bcd395d 100644 --- a/service/season_generator.py +++ b/service/season_generator.py @@ -26,11 +26,11 @@ def initialize(self): async def check_season_end(self): self._logger.debug("Checking if latest season ends soon.") async with self._db.acquire() as conn: - sql = select([league_season]) + sql = select(league_season) result = await conn.execute(sql) - rows = await result.fetchall() + rows = result.fetchall() - max_date = max(row[league_season.c.end_date] for row in rows) + max_date = max(row.end_date for row in rows) if max_date < datetime.now() + timedelta(days=SEASON_GENERATION_DAYS_BEFORE_SEASON_END): try: @@ -43,9 +43,9 @@ async def check_season_end(self): async def generate_season(self): self._logger.info("Generating new season...") async with self._db.acquire() as conn: - sql = select([league]).where(league.c.enabled == True) + sql = select(league).where(league.c.enabled == True) result = await conn.execute(sql) - rows = await result.fetchall() + rows = result.fetchall() next_month = datetime.now() + relativedelta(months=1) # season starts and ends at noon, so that all timezones see the same date in the client @@ -57,31 +57,31 @@ async def generate_season(self): async def update_db(self, conn, league_row, start_date, end_date): season_sql = ( - select([league_season]) - .where(league_season.c.league_id == league_row[league.c.id]) + select(league_season) + .where(league_season.c.league_id == league_row.id) .order_by(desc(league_season.c.season_number)) .limit(1) ) result = await conn.execute(season_sql) - season_row = await result.fetchone() + season_row = result.fetchone() if season_row is None: self._logger.warning( "No season found for league %s. Skipping this league", - league_row[league.c.technical_name] + league_row.technical_name ) return - result = await conn.execute(select([func.max(league_season.c.id)])) - season_id = await result.scalar() + 1 - season_number = season_row[league_season.c.season_number] + 1 + result = await conn.execute(select(func.max(league_season.c.id))) + season_id = result.scalar() + 1 + season_number = season_row.season_number + 1 season_insert_sql = ( insert(league_season) .values( id=season_id, - league_id=season_row[league_season.c.league_id], - leaderboard_id=season_row[league_season.c.leaderboard_id], - placement_games=season_row[league_season.c.placement_games], + league_id=season_row.league_id, + leaderboard_id=season_row.leaderboard_id, + placement_games=season_row.placement_games, season_number=season_number, - name_key=season_row[league_season.c.name_key], + name_key=season_row.name_key, start_date=start_date, end_date=end_date, ) @@ -89,23 +89,23 @@ async def update_db(self, conn, league_row, start_date, end_date): await conn.execute(season_insert_sql) division_sql = ( - select([league_season_division]) - .where(league_season_division.c.league_season_id == season_row[league_season.c.id]) + select(league_season_division) + .where(league_season_division.c.league_season_id == season_row.id) ) result = await conn.execute(division_sql) - season_division_rows = await result.fetchall() + season_division_rows = result.fetchall() if not season_division_rows: self._logger.warning( "No divisions found for season id %s. No divisions could be created. " "Now season id %s has no divisions as well. This needs to be fixed manually", - season_row[league_season.c.id], + season_row.id, season_id ) return - result = await conn.execute(select([func.max(league_season_division.c.id)])) - division_id = await result.scalar() + result = await conn.execute(select(func.max(league_season_division.c.id))) + division_id = result.scalar() for division_row in season_division_rows: - division_index = division_row[league_season_division.c.division_index] + division_index = division_row.division_index division_id += 1 division_insert_sql = ( insert(league_season_division) @@ -114,43 +114,43 @@ async def update_db(self, conn, league_row, start_date, end_date): league_season_id=season_id, division_index=division_index, description_key=( - f"{season_row[league_season.c.name_key]}_{season_number}.division.{division_index}" + f"{season_row.name_key}_{season_number}.division.{division_index}" ), - name_key=division_row[league_season_division.c.name_key], + name_key=division_row.name_key, ) ) await conn.execute(division_insert_sql) subdivision_sql = ( - select([league_season_division_subdivision]) + select(league_season_division_subdivision) .where(league_season_division_subdivision.c.league_season_division_id == - division_row[league_season_division.c.id]) + division_row.id) ) result = await conn.execute(subdivision_sql) - subdivision_rows = await result.fetchall() + subdivision_rows = result.fetchall() if not subdivision_rows: self._logger.warning( "No subdivisions found for division id %s. No subdivisions could be created. " "Now division id %s has no subdivisions as well. This needs to be fixed manually", - division_row[league_season_division.c.id], + division_row.id, division_id ) return for subdivision_row in subdivision_rows: - subdivision_index = subdivision_row[league_season_division_subdivision.c.subdivision_index] + subdivision_index = subdivision_row.subdivision_index subdivision_insert_sql = ( insert(league_season_division_subdivision) .values( league_season_division_id=division_id, subdivision_index=subdivision_index, description_key=( - f"{season_row[league_season.c.name_key]}_{season_number}" + f"{season_row.name_key}_{season_number}" f".subdivision.{division_index}.{subdivision_index}" ), - name_key=subdivision_row[league_season_division_subdivision.c.name_key], - min_rating=subdivision_row[league_season_division_subdivision.c.min_rating], - max_rating=subdivision_row[league_season_division_subdivision.c.max_rating], - highest_score=subdivision_row[league_season_division_subdivision.c.highest_score], + name_key=subdivision_row.name_key, + min_rating=subdivision_row.min_rating, + max_rating=subdivision_row.max_rating, + highest_score=subdivision_row.highest_score, ) ) await conn.execute(subdivision_insert_sql) diff --git a/tests/conftest.py b/tests/conftest.py index 9fa3f6e..de7de85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,12 +73,12 @@ async def test_data(request): with open("tests/data/test-data.sql") as f: async with db.acquire() as conn: - await conn.execute(f.read()) + await conn.execute(f.read().replace(":", r"\:")) await db.close() -async def global_database(request): +async def global_database(request) -> FAFDatabase: def opt(val): return request.config.getoption(val) @@ -89,11 +89,13 @@ def opt(val): opt("--mysql_database"), opt("--mysql_port"), ) - db = FAFDatabase(asyncio.get_running_loop()) - - await db.connect(host=host, user=user, password=pw or None, port=port, db=name) - - return db + return FAFDatabase( + host=host, + user=user, + password=pw or "", + port=port, + db=name, + ) @pytest.fixture @@ -108,9 +110,14 @@ def opt(val): opt("--mysql_database"), opt("--mysql_port"), ) - db = MockDatabase(event_loop) - - await db.connect(host=host, user=user, password=pw or None, port=port, db=name) + db = MockDatabase( + host=host, + user=user, + password=pw or "", + port=port, + db=name, + ) + await db.connect() yield db diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 982b7af..0457763 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -1,5 +1,6 @@ +from unittest import mock + import aio_pika -import mock import pytest from service import config diff --git a/tests/integration_tests/test_league_request_rating.py b/tests/integration_tests/test_league_request_rating.py index b01a76c..e68e1ae 100644 --- a/tests/integration_tests/test_league_request_rating.py +++ b/tests/integration_tests/test_league_request_rating.py @@ -12,9 +12,6 @@ async def league_service(database, message_queue_service): service.kill() -pytestmark = pytest.mark.asyncio - - async def test_rate_new_player(league_service): new_player_id = 50 rating_type = "global" diff --git a/tests/integration_tests/test_message_queue_service.py b/tests/integration_tests/test_message_queue_service.py index 37943d9..0d3ba89 100644 --- a/tests/integration_tests/test_message_queue_service.py +++ b/tests/integration_tests/test_message_queue_service.py @@ -8,8 +8,6 @@ MessageQueueService, message_to_dict) -pytestmark = pytest.mark.asyncio - @pytest.fixture async def mq_service(): @@ -118,7 +116,7 @@ async def test_parse_incoming_message(mq_service, consumer): await asyncio.sleep(0.1) - received_message = consumer.received_messages[0] + received_message = consumer.received_messages[-1] parsed_message = message_to_dict(received_message) for key, value in payload.items(): diff --git a/tests/unit_tests/test_league_service.py b/tests/unit_tests/test_league_service.py index 05c3988..c12caf1 100644 --- a/tests/unit_tests/test_league_service.py +++ b/tests/unit_tests/test_league_service.py @@ -1,5 +1,6 @@ +from unittest import mock + import pytest -from asynctest import CoroutineMock from sqlalchemy import select from service.db.models import league_score_journal @@ -8,8 +9,6 @@ from service.league_service.typedefs import (InvalidScoreError, League, LeagueScore) -pytestmark = pytest.mark.asyncio - @pytest.fixture async def league_service(database, message_queue_service): @@ -41,7 +40,7 @@ async def test_enqueue_manual_initialization( ): service = uninitialized_service await service.initialize() - service._rate_single_league = CoroutineMock() + service._rate_single_league = mock.AsyncMock() await service.enqueue(rating_change_message) await service.shutdown() @@ -58,7 +57,7 @@ async def test_double_initialization_does_not_start_second_worker(league_service async def test_enqueue_initialized(league_service, rating_change_message): service = league_service - service._rate_single_league = CoroutineMock() + service._rate_single_league = mock.AsyncMock() await service.enqueue(rating_change_message) await service.shutdown() @@ -184,18 +183,18 @@ async def test_persist_score_new_player(league_service, database): assert loaded_score == new_score async with database.acquire() as conn: - result = await conn.execute(select([league_score_journal])) - rows = await result.fetchall() + result = await conn.execute(select(league_score_journal)) + rows = result.fetchall() assert len(rows) == 1 for row in rows: - assert row["game_id"] == 1 - assert row["login_id"] == 5 - assert row["league_season_id"] == 2 - assert row["subdivision_id_before"] == 3 - assert row["subdivision_id_after"] == 3 - assert row["score_before"] == 6 - assert row["score_after"] == 5 - assert row["game_count"] == 43 + assert row.game_id == 1 + assert row.login_id == 5 + assert row.league_season_id == 2 + assert row.subdivision_id_before == 3 + assert row.subdivision_id_after == 3 + assert row.score_before == 6 + assert row.score_after == 5 + assert row.game_count == 43 async def test_persist_score_old_player(league_service, database): @@ -216,18 +215,18 @@ async def test_persist_score_old_player(league_service, database): assert loaded_score == new_score async with database.acquire() as conn: - result = await conn.execute(select([league_score_journal])) - rows = await result.fetchall() + result = await conn.execute(select(league_score_journal)) + rows = result.fetchall() assert len(rows) == 1 for row in rows: - assert row["game_id"] == 10 - assert row["login_id"] == 1 - assert row["league_season_id"] == 2 - assert row["subdivision_id_before"] == 3 - assert row["subdivision_id_after"] == 3 - assert row["score_before"] == 6 - assert row["score_after"] == 5 - assert row["game_count"] == 43 + assert row.game_id == 10 + assert row.login_id == 1 + assert row.league_season_id == 2 + assert row.subdivision_id_before == 3 + assert row.subdivision_id_after == 3 + assert row.score_before == 6 + assert row.score_after == 5 + assert row.game_count == 43 async def test_persist_score_season_id_mismatch(league_service): diff --git a/tests/unit_tests/test_message_queue.py b/tests/unit_tests/test_message_queue.py index 5643a9a..3e9e155 100644 --- a/tests/unit_tests/test_message_queue.py +++ b/tests/unit_tests/test_message_queue.py @@ -4,8 +4,6 @@ from service.message_queue_service import (ConnectionAttemptFailed, MessageQueueService) -pytestmark = pytest.mark.asyncio - @pytest.fixture async def mq_service(): diff --git a/tests/unit_tests/test_season_generator.py b/tests/unit_tests/test_season_generator.py index 89fd7bb..72b9135 100644 --- a/tests/unit_tests/test_season_generator.py +++ b/tests/unit_tests/test_season_generator.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta +from unittest import mock -import mock import pytest from freezegun import freeze_time from sqlalchemy import select @@ -9,8 +9,6 @@ league_season_division_subdivision) from service.season_generator import SeasonGenerator -pytestmark = pytest.mark.asyncio - @pytest.fixture def season_generator(database): @@ -43,43 +41,42 @@ async def test_season_check_after_season_end(season_generator): async def test_generate_season(season_generator, database): await season_generator.generate_season() async with database.acquire() as conn: - seasons = await conn.execute(select([league_season])) - rows = await seasons.fetchall() + seasons = await conn.execute(select(league_season)) + rows = seasons.fetchall() assert len(rows) == 6 - assert max(row[league_season.c.season_number] for row in rows) == 4 + assert max(row.season_number for row in rows) == 4 divisions = await conn.execute(select(league_season_division)) - rows = await divisions.fetchall() + rows = divisions.fetchall() assert len(rows) == 9 new_division_one = await conn.execute(select(league_season_division).where(league_season_division.c.id == 8)) - row = await new_division_one.fetchone() - assert row[league_season_division.c.league_season_id] == 6 - assert row[league_season_division.c.division_index] == 1 - assert row[league_season_division.c.name_key] == "L3D1" - assert row[league_season_division.c.description_key] == "second_test_league.season.1_2.division.1" + row = new_division_one.fetchone() + assert row.league_season_id == 6 + assert row.division_index == 1 + assert row.name_key == "L3D1" + assert row.description_key == "second_test_league.season.1_2.division.1" subdivisions = await conn.execute(select(league_season_division_subdivision)) - rows = await subdivisions.fetchall() + rows = subdivisions.fetchall() assert len(rows) == 10 new_subdivision = await conn.execute( select(league_season_division_subdivision) .where(league_season_division_subdivision.c.league_season_division_id == 8) ) - row = await new_subdivision.fetchone() - assert row[league_season_division_subdivision.c.subdivision_index] == 1 - assert row[league_season_division_subdivision.c.name_key] == "L3D1S1" - assert (row[league_season_division_subdivision.c.description_key] == - "second_test_league.season.1_2.subdivision.1.1") - assert row[league_season_division_subdivision.c.min_rating] == 0 - assert row[league_season_division_subdivision.c.max_rating] == 3000 - assert row[league_season_division_subdivision.c.highest_score] == 20 + row = new_subdivision.fetchone() + assert row.subdivision_index == 1 + assert row.name_key == "L3D1S1" + assert row.description_key == "second_test_league.season.1_2.subdivision.1.1" + assert row.min_rating == 0 + assert row.max_rating == 3000 + assert row.highest_score == 20 async def test_generate_season_only_once(season_generator, database): await season_generator.check_season_end() await season_generator.check_season_end() async with database.acquire() as conn: - seasons = await conn.execute(select([league_season])) - rows = await seasons.fetchall() + seasons = await conn.execute(select(league_season)) + rows = seasons.fetchall() assert len(rows) == 6 - assert max(row[league_season.c.season_number] for row in rows) == 4 + assert max(row.season_number for row in rows) == 4 diff --git a/tests/utils.py b/tests/utils.py index b0fc814..beb4f22 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,9 +1,9 @@ import asyncio import functools -from asyncio import Event, Lock import asynctest -from aiomysql.sa import create_engine + +from service.db import FAFDatabase # Copied over from PR #113 of pytest-asyncio. It will probably be available in @@ -84,7 +84,7 @@ async def __aexit__(self, exc_type, exc, tb): self._db._lock.release() -class MockDatabase: +class MockDatabase(FAFDatabase): """ This class mocks the FAFDatabase class, rolling back all transactions performed during tests. To do that, it proxies the real db engine, giving @@ -96,59 +96,38 @@ class MockDatabase: Any future manual commit() calls should be mocked here as well. """ - def __init__(self, loop): - self._loop = loop - self.engine = None - self._connection = None - self._conn_present = Event() - self._keep = None - self._lock = Lock() - self._done = Event() - - async def connect( + def __init__( self, - host="localhost", - port=3306, - user="root", - password="", - db="faf_test", - minsize=1, - maxsize=1, + host: str = "localhost", + port: int = 3306, + user: str = "root", + password: str = "", + db: str = "faf_test", + **kwargs, ): - if self.engine is not None: - raise ValueError("DB is already connected!") - self.engine = await create_engine( - host=host, - port=port, - user=user, - password=password, - db=db, - autocommit=False, - loop=self._loop, - minsize=minsize, - maxsize=maxsize, - echo=True, - ) - self._keep = self._loop.create_task(self._keep_connection()) + super().__init__(host, port, user, password, db, **kwargs) + self._connection = None + self._conn_present = asyncio.Event() + self._lock = asyncio.Lock() + self._done = asyncio.Event() + self._keep = asyncio.create_task(self._keep_connection()) + + async def connect(self): await self._conn_present.wait() async def _keep_connection(self): - async with self.engine.acquire() as conn: + async with self.engine.begin() as conn: self._connection = conn self._conn_present.set() await self._done.wait() + await conn.rollback() self._connection = None def acquire(self): return MockConnectionContext(self) async def close(self): - if self.engine is None: - return - async with self._lock: self._done.set() await self._keep - self.engine.close() - await self.engine.wait_closed() - self.engine = None + await self.engine.dispose()