diff --git a/microbootstrap/bootstrappers/fastapi.py b/microbootstrap/bootstrappers/fastapi.py index a3e8cd9..b4df53b 100644 --- a/microbootstrap/bootstrappers/fastapi.py +++ b/microbootstrap/bootstrappers/fastapi.py @@ -1,3 +1,4 @@ +import contextlib import typing import fastapi @@ -29,11 +30,32 @@ class FastApiBootstrapper( application_config = FastApiConfig() application_type = fastapi.FastAPI + @contextlib.asynccontextmanager + async def _lifespan_manager(self, _: fastapi.FastAPI) -> typing.AsyncIterator[dict[str, typing.Any]]: + try: + self.console_writer.print_bootstrap_table() + yield {} + finally: + self.teardown() + + @contextlib.asynccontextmanager + async def _wrapped_lifespan_manager(self, app: fastapi.FastAPI) -> typing.AsyncIterator[dict[str, typing.Any]]: + assert self.application_config.lifespan # noqa: S101 + async with self._lifespan_manager(app), self.application_config.lifespan(app): + yield {} + + def _choose_lifespan_manager( + self, + ) -> typing.Callable[[fastapi.FastAPI], typing.AsyncContextManager[dict[str, typing.Any]]]: + if self.application_config.lifespan: + return self._wrapped_lifespan_manager + + return self._lifespan_manager + def bootstrap_before(self) -> dict[str, typing.Any]: return { "debug": self.settings.service_debug, - "on_shutdown": [self.teardown], - "on_startup": [self.console_writer.print_bootstrap_table], + "lifespan": self._choose_lifespan_manager(), } diff --git a/tests/bootstrappers/test_fastapi.py b/tests/bootstrappers/test_fastapi.py index 121d1e2..902ea6e 100644 --- a/tests/bootstrappers/test_fastapi.py +++ b/tests/bootstrappers/test_fastapi.py @@ -49,14 +49,9 @@ def test_fastapi_configure_application() -> None: assert application.title == test_title -def test_fastapi_configure_application_add_startup_event(magic_mock: MagicMock) -> None: - def test_startup() -> None: - magic_mock() - +def test_fastapi_configure_application_lifespan(magic_mock: MagicMock) -> None: application: typing.Final = ( - FastApiBootstrapper(FastApiSettings()) - .configure_application(FastApiConfig(on_startup=[test_startup])) - .bootstrap() + FastApiBootstrapper(FastApiSettings()).configure_application(FastApiConfig(lifespan=magic_mock)).bootstrap() ) with TestClient(app=application):