From e0545a8451eb71dcc2853489e713b749140241a2 Mon Sep 17 00:00:00 2001 From: vitthalmagadum <122079046+vitthalmagadum@users.noreply.github.com> Date: Tue, 21 Jan 2025 00:43:28 +0530 Subject: [PATCH] fix(anta.tests): AntaTest.Input subclasses using common input models should have validators for their required fields. (#1013) * Added validators for required fields * added unit tests for input validators * Fix docstrings --------- Co-authored-by: Carl Baillargeon --- anta/input_models/routing/bgp.py | 4 +- anta/input_models/routing/generic.py | 6 +-- anta/tests/interfaces.py | 27 +++++++++- anta/tests/routing/generic.py | 20 +++++-- .../input_models/routing/test_generic.py | 41 +++++++++++++++ tests/units/input_models/test_interfaces.py | 52 +++++++++++++++++++ 6 files changed, 139 insertions(+), 11 deletions(-) create mode 100644 tests/units/input_models/routing/test_generic.py diff --git a/anta/input_models/routing/bgp.py b/anta/input_models/routing/bgp.py index a34227a1b..8b1256e18 100644 --- a/anta/input_models/routing/bgp.py +++ b/anta/input_models/routing/bgp.py @@ -224,8 +224,8 @@ class BgpRoute(BaseModel): """The IPv4 network address.""" vrf: str = "default" """Optional VRF for the BGP peer. Defaults to `default`.""" - paths: list[BgpRoutePath] | None = None - """A list of paths for the BGP route. Required field in the `VerifyBGPRouteOrigin` test.""" + paths: list[BgpRoutePath] + """A list of paths for the BGP route.""" def __str__(self) -> str: """Return a human-readable string representation of the BgpRoute for reporting. diff --git a/anta/input_models/routing/generic.py b/anta/input_models/routing/generic.py index b683a4582..277611237 100644 --- a/anta/input_models/routing/generic.py +++ b/anta/input_models/routing/generic.py @@ -17,11 +17,11 @@ class IPv4Routes(BaseModel): model_config = ConfigDict(extra="forbid") prefix: IPv4Network - """The IPV4 network to validate the route type.""" + """IPv4 prefix in CIDR notation.""" vrf: str = "default" """VRF context. Defaults to `default` VRF.""" - route_type: IPv4RouteType - """List of IPV4 Route type to validate the valid rout type.""" + route_type: IPv4RouteType | None = None + """Expected route type. Required field in the `VerifyIPv4RouteType` test.""" def __str__(self) -> str: """Return a human-readable string representation of the IPv4RouteType for reporting.""" diff --git a/anta/tests/interfaces.py b/anta/tests/interfaces.py index 3191053fc..551c416b6 100644 --- a/anta/tests/interfaces.py +++ b/anta/tests/interfaces.py @@ -9,9 +9,9 @@ import re from ipaddress import IPv4Interface -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeVar -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from pydantic_extra_types.mac_address import MacAddress from anta import GITHUB_SUGGESTION @@ -23,6 +23,9 @@ BPS_GBPS_CONVERSIONS = 1000000000 +# Using a TypeVar for the InterfaceState model since mypy thinks it's a ClassVar and not a valid type when used in field validators +T = TypeVar("T", bound=InterfaceState) + class VerifyInterfaceUtilization(AntaTest): """Verifies that the utilization of interfaces is below a certain threshold. @@ -226,6 +229,16 @@ class Input(AntaTest.Input): """List of interfaces with their expected state.""" InterfaceState: ClassVar[type[InterfaceState]] = InterfaceState + @field_validator("interfaces") + @classmethod + def validate_interfaces(cls, interfaces: list[T]) -> list[T]: + """Validate that 'status' field is provided in each interface.""" + for interface in interfaces: + if interface.status is None: + msg = f"{interface} 'status' field missing in the input" + raise ValueError(msg) + return interfaces + @AntaTest.anta_test def test(self) -> None: """Main test function for VerifyInterfacesStatus.""" @@ -891,6 +904,16 @@ class Input(AntaTest.Input): """List of interfaces with their expected state.""" InterfaceState: ClassVar[type[InterfaceState]] = InterfaceState + @field_validator("interfaces") + @classmethod + def validate_interfaces(cls, interfaces: list[T]) -> list[T]: + """Validate that 'portchannel' field is provided in each interface.""" + for interface in interfaces: + if interface.portchannel is None: + msg = f"{interface} 'portchannel' field missing in the input" + raise ValueError(msg) + return interfaces + @AntaTest.anta_test def test(self) -> None: """Main test function for VerifyLACPInterfacesStatus.""" diff --git a/anta/tests/routing/generic.py b/anta/tests/routing/generic.py index 97709aa40..c3cd76008 100644 --- a/anta/tests/routing/generic.py +++ b/anta/tests/routing/generic.py @@ -11,7 +11,7 @@ from ipaddress import IPv4Address, IPv4Interface from typing import TYPE_CHECKING, ClassVar, Literal -from pydantic import model_validator +from pydantic import field_validator, model_validator from anta.custom_types import PositiveInteger from anta.input_models.routing.generic import IPv4Routes @@ -189,9 +189,10 @@ class VerifyIPv4RouteType(AntaTest): """Verifies the route-type of the IPv4 prefixes. This test performs the following checks for each IPv4 route: - 1. Verifies that the specified VRF is configured. - 2. Verifies that the specified IPv4 route is exists in the configuration. - 3. Verifies that the the specified IPv4 route is of the expected type. + + 1. Verifies that the specified VRF is configured. + 2. Verifies that the specified IPv4 route is exists in the configuration. + 3. Verifies that the the specified IPv4 route is of the expected type. Expected Results ---------------- @@ -230,6 +231,17 @@ class Input(AntaTest.Input): """Input model for the VerifyIPv4RouteType test.""" routes_entries: list[IPv4Routes] + """List of IPv4 route(s).""" + + @field_validator("routes_entries") + @classmethod + def validate_routes_entries(cls, routes_entries: list[IPv4Routes]) -> list[IPv4Routes]: + """Validate that 'route_type' field is provided in each BGP route entry.""" + for entry in routes_entries: + if entry.route_type is None: + msg = f"{entry} 'route_type' field missing in the input" + raise ValueError(msg) + return routes_entries @AntaTest.anta_test def test(self) -> None: diff --git a/tests/units/input_models/routing/test_generic.py b/tests/units/input_models/routing/test_generic.py new file mode 100644 index 000000000..59d069e08 --- /dev/null +++ b/tests/units/input_models/routing/test_generic.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023-2025 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Tests for anta.input_models.routing.generic.py.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from pydantic import ValidationError + +from anta.tests.routing.generic import VerifyIPv4RouteType + +if TYPE_CHECKING: + from anta.input_models.routing.generic import IPv4Routes + + +class TestVerifyIPv4RouteTypeInput: + """Test anta.tests.routing.bgp.VerifyIPv4RouteType.Input.""" + + @pytest.mark.parametrize( + ("routes_entries"), + [ + pytest.param([{"prefix": "192.168.0.0/24", "vrf": "default", "route_type": "eBGP"}], id="valid"), + ], + ) + def test_valid(self, routes_entries: list[IPv4Routes]) -> None: + """Test VerifyIPv4RouteType.Input valid inputs.""" + VerifyIPv4RouteType.Input(routes_entries=routes_entries) + + @pytest.mark.parametrize( + ("routes_entries"), + [ + pytest.param([{"prefix": "192.168.0.0/24", "vrf": "default"}], id="invalid"), + ], + ) + def test_invalid(self, routes_entries: list[IPv4Routes]) -> None: + """Test VerifyIPv4RouteType.Input invalid inputs.""" + with pytest.raises(ValidationError): + VerifyIPv4RouteType.Input(routes_entries=routes_entries) diff --git a/tests/units/input_models/test_interfaces.py b/tests/units/input_models/test_interfaces.py index ee850ee7a..aefa31941 100644 --- a/tests/units/input_models/test_interfaces.py +++ b/tests/units/input_models/test_interfaces.py @@ -9,8 +9,10 @@ from typing import TYPE_CHECKING import pytest +from pydantic import ValidationError from anta.input_models.interfaces import InterfaceState +from anta.tests.interfaces import VerifyInterfacesStatus, VerifyLACPInterfacesStatus if TYPE_CHECKING: from anta.custom_types import Interface, PortChannelInterface @@ -31,3 +33,53 @@ class TestInterfaceState: def test_valid__str__(self, name: Interface, portchannel: PortChannelInterface | None, expected: str) -> None: """Test InterfaceState __str__.""" assert str(InterfaceState(name=name, portchannel=portchannel)) == expected + + +class TestVerifyInterfacesStatusInput: + """Test anta.tests.interfaces.VerifyInterfacesStatus.Input.""" + + @pytest.mark.parametrize( + ("interfaces"), + [ + pytest.param([{"name": "Ethernet1", "status": "up"}], id="valid"), + ], + ) + def test_valid(self, interfaces: list[InterfaceState]) -> None: + """Test VerifyInterfacesStatus.Input valid inputs.""" + VerifyInterfacesStatus.Input(interfaces=interfaces) + + @pytest.mark.parametrize( + ("interfaces"), + [ + pytest.param([{"name": "Ethernet1"}], id="invalid"), + ], + ) + def test_invalid(self, interfaces: list[InterfaceState]) -> None: + """Test VerifyInterfacesStatus.Input invalid inputs.""" + with pytest.raises(ValidationError): + VerifyInterfacesStatus.Input(interfaces=interfaces) + + +class TestVerifyLACPInterfacesStatusInput: + """Test anta.tests.interfaces.VerifyLACPInterfacesStatus.Input.""" + + @pytest.mark.parametrize( + ("interfaces"), + [ + pytest.param([{"name": "Ethernet1", "portchannel": "Port-Channel100"}], id="valid"), + ], + ) + def test_valid(self, interfaces: list[InterfaceState]) -> None: + """Test VerifyLACPInterfacesStatus.Input valid inputs.""" + VerifyLACPInterfacesStatus.Input(interfaces=interfaces) + + @pytest.mark.parametrize( + ("interfaces"), + [ + pytest.param([{"name": "Ethernet1"}], id="invalid"), + ], + ) + def test_invalid(self, interfaces: list[InterfaceState]) -> None: + """Test VerifyLACPInterfacesStatus.Input invalid inputs.""" + with pytest.raises(ValidationError): + VerifyLACPInterfacesStatus.Input(interfaces=interfaces)