diff --git a/anta/tests/bfd.py b/anta/tests/bfd.py index 861a6a2e4..2361a4221 100644 --- a/anta/tests/bfd.py +++ b/anta/tests/bfd.py @@ -8,9 +8,9 @@ from __future__ import annotations from datetime import datetime, timezone -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING, ClassVar, TypeVar -from pydantic import Field +from pydantic import Field, field_validator from anta.input_models.bfd import BFDPeer from anta.models import AntaCommand, AntaTest @@ -19,6 +19,9 @@ if TYPE_CHECKING: from anta.models import AntaTemplate +# Using a TypeVar for the BFDPeer model since mypy thinks it's a ClassVar and not a valid type when used in field validators +T = TypeVar("T", bound=BFDPeer) + class VerifyBFDSpecificPeers(AntaTest): """Verifies the state of IPv4 BFD peer sessions. @@ -143,6 +146,23 @@ class Input(AntaTest.Input): BFDPeer: ClassVar[type[BFDPeer]] = BFDPeer """To maintain backward compatibility""" + @field_validator("bfd_peers") + @classmethod + def validate_bfd_peers(cls, bfd_peers: list[T]) -> list[T]: + """Validate that 'tx_interval', 'rx_interval' and 'multiplier' fields are provided in each BFD peer.""" + for peer in bfd_peers: + missing_fileds = [] + if peer.tx_interval is None: + missing_fileds.append("tx_interval") + if peer.rx_interval is None: + missing_fileds.append("rx_interval") + if peer.multiplier is None: + missing_fileds.append("multiplier") + if missing_fileds: + msg = f"{peer} {', '.join(missing_fileds)} field(s) are missing in the input." + raise ValueError(msg) + return bfd_peers + @AntaTest.anta_test def test(self) -> None: """Main test function for VerifyBFDPeersIntervals.""" @@ -308,6 +328,16 @@ class Input(AntaTest.Input): BFDPeer: ClassVar[type[BFDPeer]] = BFDPeer """To maintain backward compatibility""" + @field_validator("bfd_peers") + @classmethod + def validate_bfd_peers(cls, bfd_peers: list[T]) -> list[T]: + """Validate that 'protocols' field is provided in each BFD peer.""" + for peer in bfd_peers: + if peer.protocols is None: + msg = f"{peer} 'protocols' field missing in the input." + raise ValueError(msg) + return bfd_peers + @AntaTest.anta_test def test(self) -> None: """Main test function for VerifyBFDPeersRegProtocols.""" diff --git a/tests/units/input_models/test_bfd.py b/tests/units/input_models/test_bfd.py new file mode 100644 index 000000000..e179f39fe --- /dev/null +++ b/tests/units/input_models/test_bfd.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023-2025 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Tests for anta.input_models.bfd.py.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from pydantic import ValidationError + +from anta.tests.bfd import VerifyBFDPeersIntervals, VerifyBFDPeersRegProtocols + +if TYPE_CHECKING: + from anta.input_models.bfd import BFDPeer + + +class TestVerifyBFDPeersIntervalsInput: + """Test anta.tests.bfd.VerifyBFDPeersIntervals.Input.""" + + @pytest.mark.parametrize( + ("bfd_peers"), + [ + pytest.param([{"peer_address": "10.0.0.1", "vrf": "default", "tx_interval": 1200, "rx_interval": 1200, "multiplier": 3}], id="valid"), + ], + ) + def test_valid(self, bfd_peers: list[BFDPeer]) -> None: + """Test VerifyBFDPeersIntervals.Input valid inputs.""" + VerifyBFDPeersIntervals.Input(bfd_peers=bfd_peers) + + @pytest.mark.parametrize( + ("bfd_peers"), + [ + pytest.param([{"peer_address": "10.0.0.1", "vrf": "default", "tx_interval": 1200}], id="invalid-tx-interval"), + pytest.param([{"peer_address": "10.0.0.1", "vrf": "default", "rx_interval": 1200}], id="invalid-rx-interval"), + pytest.param([{"peer_address": "10.0.0.1", "vrf": "default", "tx_interval": 1200, "rx_interval": 1200}], id="invalid-multiplier"), + ], + ) + def test_invalid(self, bfd_peers: list[BFDPeer]) -> None: + """Test VerifyBFDPeersIntervals.Input invalid inputs.""" + with pytest.raises(ValidationError): + VerifyBFDPeersIntervals.Input(bfd_peers=bfd_peers) + + +class TestVerifyBFDPeersRegProtocolsInput: + """Test anta.tests.bfd.VerifyBFDPeersRegProtocols.Input.""" + + @pytest.mark.parametrize( + ("bfd_peers"), + [ + pytest.param([{"peer_address": "10.0.0.1", "vrf": "default", "protocols": ["bgp"]}], id="valid"), + ], + ) + def test_valid(self, bfd_peers: list[BFDPeer]) -> None: + """Test VerifyBFDPeersRegProtocols.Input valid inputs.""" + VerifyBFDPeersRegProtocols.Input(bfd_peers=bfd_peers) + + @pytest.mark.parametrize( + ("bfd_peers"), + [ + pytest.param([{"peer_address": "10.0.0.1", "vrf": "default"}], id="invalid"), + ], + ) + def test_invalid(self, bfd_peers: list[BFDPeer]) -> None: + """Test VerifyBFDPeersRegProtocols.Input invalid inputs.""" + with pytest.raises(ValidationError): + VerifyBFDPeersRegProtocols.Input(bfd_peers=bfd_peers)