From 71b6a621ea3b36e4bc934693119d103621dbbbe7 Mon Sep 17 00:00:00 2001 From: Jakob Schlyter Date: Mon, 13 Jan 2025 15:13:15 +0100 Subject: [PATCH] Allow node search with tags (#64) --- nodeman/nodes.py | 14 +++++++++++--- tests/test_api.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/nodeman/nodes.py b/nodeman/nodes.py index c18fc3e..6f0196b 100644 --- a/nodeman/nodes.py +++ b/nodeman/nodes.py @@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status from jwcrypto.jwk import JWK from jwcrypto.jws import JWS, InvalidJWSSignature +from mongoengine import Q from opentelemetry import metrics, trace from pydantic_core import ValidationError @@ -205,11 +206,18 @@ def get_node_information(name: str, username: Annotated[str, Depends(get_current status.HTTP_200_OK: {"model": NodeCollection}, }, tags=["backend"], + response_model_exclude_none=False, ) -def get_all_nodes(username: Annotated[str, Depends(get_current_username)]) -> NodeCollection: +def get_all_nodes(username: Annotated[str, Depends(get_current_username)], tags: str | None = None) -> NodeCollection: """Get all nodes""" - logging.info("%s queried for all nodes", username, extra={"username": username}) - return NodeCollection(nodes=[NodeInformation.from_db_model(node) for node in TapirNode.objects(deleted=None)]) + query = Q(deleted=None) + if tags: + query_tags = sorted(set(tags.split(","))) + logging.info("%s queried for nodes with tags %s", username, query_tags, extra={"username": username}) + query &= Q(tags__all=sorted(set(query_tags))) + else: + logging.info("%s queried for all nodes", username, extra={"username": username}) + return NodeCollection(nodes=[NodeInformation.from_db_model(node) for node in TapirNode.objects(query)]) @router.get( diff --git a/tests/test_api.py b/tests/test_api.py index d1c9376..85eb7c7 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -26,6 +26,7 @@ from nodeman.x509 import RSA_EXPONENT, CertificateAuthorityClient, generate_ca_certificate, generate_x509_csr ADMIN_TEST_NODE_COUNT = 100 +ADMIN_TEST_NODE_COUNT_TAGS = 10 BACKEND_CREDENTIALS = ("username", "password") PrivateKey = ec.EllipticCurvePrivateKey | rsa.RSAPublicKey | Ed25519PrivateKey | Ed448PrivateKey @@ -407,6 +408,35 @@ def test_admin() -> None: assert response.status_code == status.HTTP_204_NO_CONTENT +def test_admin_tags() -> None: + client = get_test_client() + client.auth = BACKEND_CREDENTIALS + + server = "" + + for node_number in range(ADMIN_TEST_NODE_COUNT_TAGS): + tags = ["odd"] if node_number % 2 else ["even"] + if node_number == 0: + tags.append("zero") + response = client.post(urljoin(server, "/api/v1/node"), json={"tags": tags}) + assert response.status_code == status.HTTP_201_CREATED + + # half of the nodes should have tag even + response = client.get(urljoin(server, "/api/v1/nodes"), params={"tags": "even"}) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["nodes"]) == ADMIN_TEST_NODE_COUNT_TAGS // 2 + + # exactly one node should have tags even & zero + response = client.get(urljoin(server, "/api/v1/nodes"), params={"tags": "even,zero"}) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["nodes"]) == 1 + + # no nodes should have both tags even & odd + response = client.get(urljoin(server, "/api/v1/nodes"), params={"tags": "even,odd"}) + assert response.status_code == status.HTTP_200_OK + assert len(response.json()["nodes"]) == 0 + + def test_backend_authentication() -> None: client = get_test_client() server = ""