diff --git a/dnstapir/key_resolver.py b/dnstapir/key_resolver.py index 577b9eb..807c214 100644 --- a/dnstapir/key_resolver.py +++ b/dnstapir/key_resolver.py @@ -76,7 +76,7 @@ class UrlKeyResolver(CacheKeyResolver): def __init__(self, client_database_base_url: str, key_cache: KeyCache | None = None): super().__init__(key_cache=key_cache) self.client_database_base_url = client_database_base_url - self.httpx_client = httpx.Client() + self._httpx_client: httpx.Client | None = None def get_public_key_pem(self, key_id: str) -> bytes: with tracer.start_as_current_span("get_public_key_pem_from_url"): @@ -87,3 +87,26 @@ def get_public_key_pem(self, key_id: str) -> bytes: return response.content except httpx.HTTPError as exc: raise KeyError(key_id) from exc + + @property + def httpx_client(self) -> httpx.Client: + if self._httpx_client is None: + self._httpx_client = httpx.Client() + return self._httpx_client + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def __del__(self): + self.close() + + def close(self): + """Explicitly close the client and free resources.""" + if self._httpx_client is not None: + try: + self._httpx_client.close() + finally: + self._httpx_client = None diff --git a/tests/test_key_resolver.py b/tests/test_key_resolver.py index 656c7c6..48528d9 100644 --- a/tests/test_key_resolver.py +++ b/tests/test_key_resolver.py @@ -15,6 +15,19 @@ def test_url_key_resolver(httpx_mock: HTTPXMock): httpx_mock.add_response(url=f"https://keys/{key_id}.pem", content=public_key_pem) resolver = UrlKeyResolver(client_database_base_url="https://keys") - res = resolver.resolve_public_key(key_id) assert res == public_key + + +def test_url_key_resolver_contextlib(httpx_mock: HTTPXMock): + key_id = "xyzzy" + public_key = ed25519.Ed25519PrivateKey.generate().public_key() + public_key_pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + httpx_mock.add_response(url=f"https://keys/{key_id}.pem", content=public_key_pem) + + with UrlKeyResolver(client_database_base_url="https://keys") as resolver: + res = resolver.resolve_public_key(key_id) + assert res == public_key