diff --git a/anta/cli/nrfu/__init__.py b/anta/cli/nrfu/__init__.py index 6a67e609f..27f8588e7 100644 --- a/anta/cli/nrfu/__init__.py +++ b/anta/cli/nrfu/__init__.py @@ -5,7 +5,6 @@ from __future__ import annotations -import asyncio from typing import TYPE_CHECKING, get_args import click @@ -13,11 +12,7 @@ from anta.cli.nrfu import commands from anta.cli.utils import AliasedGroup, catalog_options, inventory_options from anta.custom_types import TestStatus -from anta.models import AntaTest from anta.result_manager import ResultManager -from anta.runner import main - -from .utils import anta_progress_bar, print_settings if TYPE_CHECKING: from anta.catalog import AntaCatalog @@ -37,6 +32,7 @@ def parse_args(self, ctx: click.Context, args: list[str]) -> list[str]: """Ignore MissingParameter exception when parsing arguments if `--help` is present for a subcommand.""" # Adding a flag for potential callbacks ctx.ensure_object(dict) + ctx.obj["args"] = args if "--help" in args: ctx.obj["_anta_help"] = True @@ -125,29 +121,22 @@ def nrfu( # If help is invoke somewhere, skip the command if ctx.obj.get("_anta_help"): return + # We use ctx.obj to pass stuff to the next Click functions ctx.ensure_object(dict) ctx.obj["result_manager"] = ResultManager() ctx.obj["ignore_status"] = ignore_status ctx.obj["ignore_error"] = ignore_error ctx.obj["hide"] = set(hide) if hide else None - print_settings(inventory, catalog) - with anta_progress_bar() as AntaTest.progress: - asyncio.run( - main( - ctx.obj["result_manager"], - inventory, - catalog, - tags=tags, - devices=set(device) if device else None, - tests=set(test) if test else None, - dry_run=dry_run, - ) - ) - if dry_run: - return + ctx.obj["catalog"] = catalog + ctx.obj["inventory"] = inventory + ctx.obj["tags"] = tags + ctx.obj["device"] = device + ctx.obj["test"] = test + ctx.obj["dry_run"] = dry_run + # Invoke `anta nrfu table` if no command is passed - if ctx.invoked_subcommand is None: + if not ctx.invoked_subcommand: ctx.invoke(commands.table) diff --git a/anta/cli/nrfu/commands.py b/anta/cli/nrfu/commands.py index 4dd779b41..7581116c6 100644 --- a/anta/cli/nrfu/commands.py +++ b/anta/cli/nrfu/commands.py @@ -13,7 +13,7 @@ from anta.cli.utils import exit_with_code -from .utils import print_jinja, print_json, print_table, print_text +from .utils import print_jinja, print_json, print_table, print_text, run_tests logger = logging.getLogger(__name__) @@ -32,6 +32,7 @@ def table( group_by: Literal["device", "test"] | None, ) -> None: """ANTA command to check network states with table result.""" + run_tests(ctx) print_table(ctx, group_by=group_by) exit_with_code(ctx) @@ -48,6 +49,7 @@ def table( ) def json(ctx: click.Context, output: pathlib.Path | None) -> None: """ANTA command to check network state with JSON result.""" + run_tests(ctx) print_json(ctx, output=output) exit_with_code(ctx) @@ -56,6 +58,7 @@ def json(ctx: click.Context, output: pathlib.Path | None) -> None: @click.pass_context def text(ctx: click.Context) -> None: """ANTA command to check network states with text result.""" + run_tests(ctx) print_text(ctx) exit_with_code(ctx) @@ -80,5 +83,6 @@ def text(ctx: click.Context) -> None: ) def tpl_report(ctx: click.Context, template: pathlib.Path, output: pathlib.Path | None) -> None: """ANTA command to check network state with templated report.""" + run_tests(ctx) print_jinja(results=ctx.obj["result_manager"], template=template, output=output) exit_with_code(ctx) diff --git a/anta/cli/nrfu/utils.py b/anta/cli/nrfu/utils.py index 2eeeacb76..d4cd1317d 100644 --- a/anta/cli/nrfu/utils.py +++ b/anta/cli/nrfu/utils.py @@ -5,6 +5,7 @@ from __future__ import annotations +import asyncio import json import logging from typing import TYPE_CHECKING, Literal @@ -14,7 +15,9 @@ from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn from anta.cli.console import console +from anta.models import AntaTest from anta.reporter import ReportJinja, ReportTable +from anta.runner import main if TYPE_CHECKING: import pathlib @@ -28,6 +31,37 @@ logger = logging.getLogger(__name__) +def run_tests(ctx: click.Context) -> None: + """Run the tests.""" + # Digging up the parameters from the parent context + if ctx.parent is None: + ctx.exit() + nrfu_ctx_params = ctx.parent.params + tags = nrfu_ctx_params["tags"] + device = nrfu_ctx_params["device"] or None + test = nrfu_ctx_params["test"] or None + dry_run = nrfu_ctx_params["dry_run"] + + catalog = ctx.obj["catalog"] + inventory = ctx.obj["inventory"] + + print_settings(inventory, catalog) + with anta_progress_bar() as AntaTest.progress: + asyncio.run( + main( + ctx.obj["result_manager"], + inventory, + catalog, + tags=tags, + devices=set(device) if device else None, + tests=set(test) if test else None, + dry_run=dry_run, + ) + ) + if dry_run: + ctx.exit() + + def _get_result_manager(ctx: click.Context) -> ResultManager: """Get a ResultManager instance based on Click context.""" return ctx.obj["result_manager"].filter(ctx.obj.get("hide")) if ctx.obj.get("hide") is not None else ctx.obj["result_manager"]