Skip to content

Commit

Permalink
fix(anta.tests): Cleaning up security tests module (VerifyAPISSLCerti…
Browse files Browse the repository at this point in the history
…ficate, VerifyIPv4ACL) (#957)

* refactor VerifyAPISSLCertificate, VerifyIPv4ACL tests for input model

* Updated test docstring

* updated the unit test for no acl found

* addressed review comments: updated docs

* addressed review comments: updated input model docstring

* Add previous models for backward compatibility

---------

Co-authored-by: Carl Baillargeon <carl.baillargeon@arista.com>
  • Loading branch information
vitthalmagadum and carl-baillargeon authored Jan 15, 2025
1 parent e82c1a5 commit d57b53e
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 176 deletions.
117 changes: 114 additions & 3 deletions anta/input_models/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,20 @@
from __future__ import annotations

from ipaddress import IPv4Address
from typing import Any
from typing import TYPE_CHECKING, Any, ClassVar, get_args
from warnings import warn

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field, model_validator

from anta.custom_types import EcdsaKeySize, EncryptionAlgorithm, RsaKeySize

if TYPE_CHECKING:
import sys

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class IPSecPeer(BaseModel):
Expand Down Expand Up @@ -43,6 +53,107 @@ class IPSecConn(BaseModel):
"""The IPv4 address of the destination in the security connection."""


class APISSLCertificate(BaseModel):
"""Model for an API SSL certificate."""

model_config = ConfigDict(extra="forbid")
certificate_name: str
"""The name of the certificate to be verified."""
expiry_threshold: int
"""The expiry threshold of the certificate in days."""
common_name: str
"""The Common Name of the certificate."""
encryption_algorithm: EncryptionAlgorithm
"""The encryption algorithm used by the certificate."""
key_size: RsaKeySize | EcdsaKeySize
"""The key size (in bits) of the encryption algorithm."""

def __str__(self) -> str:
"""Return a human-readable string representation of the APISSLCertificate for reporting.
Examples
--------
- Certificate: SIGNING_CA.crt
"""
return f"Certificate: {self.certificate_name}"

@model_validator(mode="after")
def validate_inputs(self) -> Self:
"""Validate the key size provided to the APISSLCertificates class.
If encryption_algorithm is RSA then key_size should be in {2048, 3072, 4096}.
If encryption_algorithm is ECDSA then key_size should be in {256, 384, 521}.
"""
if self.encryption_algorithm == "RSA" and self.key_size not in get_args(RsaKeySize):
msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for RSA encryption. Allowed sizes are {get_args(RsaKeySize)}."
raise ValueError(msg)

if self.encryption_algorithm == "ECDSA" and self.key_size not in get_args(EcdsaKeySize):
msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for ECDSA encryption. Allowed sizes are {get_args(EcdsaKeySize)}."
raise ValueError(msg)

return self


class ACLEntry(BaseModel):
"""Model for an Access Control List (ACL) entry."""

model_config = ConfigDict(extra="forbid")
sequence: int = Field(ge=1, le=4294967295)
"""Sequence number of the ACL entry, used to define the order of processing. Must be between 1 and 4294967295."""
action: str
"""Action of the ACL entry. Example: `deny ip any any`."""

def __str__(self) -> str:
"""Return a human-readable string representation of the ACLEntry for reporting.
Examples
--------
- Sequence: 10
"""
return f"Sequence: {self.sequence}"


class ACL(BaseModel):
"""Model for an Access Control List (ACL)."""

model_config = ConfigDict(extra="forbid")
name: str
"""Name of the ACL."""
entries: list[ACLEntry]
"""List of the ACL entries."""
IPv4ACLEntry: ClassVar[type[ACLEntry]] = ACLEntry
"""To maintain backward compatibility."""

def __str__(self) -> str:
"""Return a human-readable string representation of the ACL for reporting.
Examples
--------
- ACL name: Test
"""
return f"ACL name: {self.name}"


class IPv4ACL(ACL): # pragma: no cover
"""Alias for the ACL model to maintain backward compatibility.
When initialized, it will emit a deprecation warning and call the ACL model.
TODO: Remove this class in ANTA v2.0.0.
"""

def __init__(self, **data: Any) -> None: # noqa: ANN401
"""Initialize the IPv4ACL class, emitting a deprecation warning."""
warn(
message="IPv4ACL model is deprecated and will be removed in ANTA v2.0.0. Use the ACL model instead.",
category=DeprecationWarning,
stacklevel=2,
)
super().__init__(**data)


class IPSecPeers(IPSecPeer): # pragma: no cover
"""Alias for the IPSecPeers model to maintain backward compatibility.
Expand All @@ -52,7 +163,7 @@ class IPSecPeers(IPSecPeer): # pragma: no cover
"""

def __init__(self, **data: Any) -> None: # noqa: ANN401
"""Initialize the IPSecPeer class, emitting a deprecation warning."""
"""Initialize the IPSecPeers class, emitting a deprecation warning."""
warn(
message="IPSecPeers model is deprecated and will be removed in ANTA v2.0.0. Use the IPSecPeer model instead.",
category=DeprecationWarning,
Expand Down
181 changes: 68 additions & 113 deletions anta/tests/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,12 @@
# Mypy does not understand AntaTest.Input typing
# mypy: disable-error-code=attr-defined
from datetime import datetime, timezone
from typing import TYPE_CHECKING, ClassVar, get_args
from typing import ClassVar

from pydantic import BaseModel, Field, model_validator

from anta.custom_types import EcdsaKeySize, EncryptionAlgorithm, PositiveInteger, RsaKeySize
from anta.input_models.security import IPSecPeer, IPSecPeers
from anta.custom_types import PositiveInteger
from anta.input_models.security import ACL, APISSLCertificate, IPSecPeer, IPSecPeers
from anta.models import AntaCommand, AntaTemplate, AntaTest
from anta.tools import get_failed_logs, get_item, get_value

if TYPE_CHECKING:
import sys

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
from anta.tools import get_item, get_value


class VerifySSHStatus(AntaTest):
Expand Down Expand Up @@ -354,14 +344,27 @@ def test(self) -> None:


class VerifyAPISSLCertificate(AntaTest):
"""Verifies the eAPI SSL certificate expiry, common subject name, encryption algorithm and key size.
"""Verifies the eAPI SSL certificate.
This test performs the following checks for each certificate:
1. Validates that the certificate is not expired and meets the configured expiry threshold.
2. Validates that the certificate Common Name matches the expected one.
3. Ensures the certificate uses the specified encryption algorithm.
4. Verifies the certificate key matches the expected key size.
Expected Results
----------------
* Success: The test will pass if the certificate's expiry date is greater than the threshold,
and the certificate has the correct name, encryption algorithm, and key size.
* Failure: The test will fail if the certificate is expired or is going to expire,
or if the certificate has an incorrect name, encryption algorithm, or key size.
* Success: If all of the following occur:
- The certificate's expiry date exceeds the configured threshold.
- The certificate's Common Name matches the input configuration.
- The encryption algorithm used by the certificate is as expected.
- The key size of the certificate matches the input configuration.
* Failure: If any of the following occur:
- The certificate is expired or set to expire within the defined threshold.
- The certificate's common name does not match the expected input.
- The encryption algorithm is incorrect.
- The key size does not match the expected input.
Examples
--------
Expand Down Expand Up @@ -393,38 +396,7 @@ class Input(AntaTest.Input):

certificates: list[APISSLCertificate]
"""List of API SSL certificates."""

class APISSLCertificate(BaseModel):
"""Model for an API SSL certificate."""

certificate_name: str
"""The name of the certificate to be verified."""
expiry_threshold: int
"""The expiry threshold of the certificate in days."""
common_name: str
"""The common subject name of the certificate."""
encryption_algorithm: EncryptionAlgorithm
"""The encryption algorithm of the certificate."""
key_size: RsaKeySize | EcdsaKeySize
"""The encryption algorithm key size of the certificate."""

@model_validator(mode="after")
def validate_inputs(self) -> Self:
"""Validate the key size provided to the APISSLCertificates class.
If encryption_algorithm is RSA then key_size should be in {2048, 3072, 4096}.
If encryption_algorithm is ECDSA then key_size should be in {256, 384, 521}.
"""
if self.encryption_algorithm == "RSA" and self.key_size not in get_args(RsaKeySize):
msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for RSA encryption. Allowed sizes are {get_args(RsaKeySize)}."
raise ValueError(msg)

if self.encryption_algorithm == "ECDSA" and self.key_size not in get_args(EcdsaKeySize):
msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for ECDSA encryption. Allowed sizes are {get_args(EcdsaKeySize)}."
raise ValueError(msg)

return self
APISSLCertificate: ClassVar[type[APISSLCertificate]] = APISSLCertificate

@AntaTest.anta_test
def test(self) -> None:
Expand All @@ -442,32 +414,33 @@ def test(self) -> None:
# Collecting certificate expiry time and current EOS time.
# These times are used to calculate the number of days until the certificate expires.
if not (certificate_data := get_value(certificate_output, f"certificates..{certificate.certificate_name}", separator="..")):
self.result.is_failure(f"SSL certificate '{certificate.certificate_name}', is not configured.\n")
self.result.is_failure(f"{certificate} - Not found")
continue

expiry_time = certificate_data["notAfter"]
day_difference = (datetime.fromtimestamp(expiry_time, tz=timezone.utc) - datetime.fromtimestamp(current_timestamp, tz=timezone.utc)).days

# Verify certificate expiry
if 0 < day_difference < certificate.expiry_threshold:
self.result.is_failure(f"SSL certificate `{certificate.certificate_name}` is about to expire in {day_difference} days.\n")
self.result.is_failure(
f"{certificate} - set to expire within the threshold - Threshold: {certificate.expiry_threshold} days Actual: {day_difference} days"
)
elif day_difference < 0:
self.result.is_failure(f"SSL certificate `{certificate.certificate_name}` is expired.\n")
self.result.is_failure(f"{certificate} - certificate expired")

# Verify certificate common subject name, encryption algorithm and key size
keys_to_verify = ["subject.commonName", "publicKey.encryptionAlgorithm", "publicKey.size"]
actual_certificate_details = {key: get_value(certificate_data, key) for key in keys_to_verify}
common_name = get_value(certificate_data, "subject.commonName", default="Not found")
encryp_algo = get_value(certificate_data, "publicKey.encryptionAlgorithm", default="Not found")
key_size = get_value(certificate_data, "publicKey.size", default="Not found")

expected_certificate_details = {
"subject.commonName": certificate.common_name,
"publicKey.encryptionAlgorithm": certificate.encryption_algorithm,
"publicKey.size": certificate.key_size,
}
if common_name != certificate.common_name:
self.result.is_failure(f"{certificate} - incorrect common name - Expected: {certificate.common_name} Actual: {common_name}")

if encryp_algo != certificate.encryption_algorithm:
self.result.is_failure(f"{certificate} - incorrect encryption algorithm - Expected: {certificate.encryption_algorithm} Actual: {encryp_algo}")

if actual_certificate_details != expected_certificate_details:
failed_log = f"SSL certificate `{certificate.certificate_name}` is not configured properly:"
failed_log += get_failed_logs(expected_certificate_details, actual_certificate_details)
self.result.is_failure(f"{failed_log}\n")
if key_size != certificate.key_size:
self.result.is_failure(f"{certificate} - incorrect public key - Expected: {certificate.key_size} Actual: {key_size}")


class VerifyBannerLogin(AntaTest):
Expand Down Expand Up @@ -555,12 +528,22 @@ def test(self) -> None:


class VerifyIPv4ACL(AntaTest):
"""Verifies the configuration of IPv4 ACLs.
"""Verifies the IPv4 ACLs.
This test performs the following checks for each IPv4 ACL:
1. Validates that the IPv4 ACL is properly configured.
2. Validates that the sequence entries in the ACL are correctly ordered.
Expected Results
----------------
* Success: The test will pass if an IPv4 ACL is configured with the correct sequence entries.
* Failure: The test will fail if an IPv4 ACL is not configured or entries are not in sequence.
* Success: If all of the following occur:
- Any IPv4 ACL entry is not configured.
- The sequency entries are correctly configured.
* Failure: If any of the following occur:
- The IPv4 ACL is not configured.
- The any IPv4 ACL entry is not configured.
- The action for any entry does not match the expected input.
Examples
--------
Expand All @@ -586,65 +569,37 @@ class VerifyIPv4ACL(AntaTest):
"""

categories: ClassVar[list[str]] = ["security"]
commands: ClassVar[list[AntaCommand | AntaTemplate]] = [AntaTemplate(template="show ip access-lists {acl}", revision=1)]
commands: ClassVar[list[AntaCommand | AntaTemplate]] = [AntaCommand(command="show ip access-lists", revision=1)]

class Input(AntaTest.Input):
"""Input model for the VerifyIPv4ACL test."""

ipv4_access_lists: list[IPv4ACL]
ipv4_access_lists: list[ACL]
"""List of IPv4 ACLs to verify."""

class IPv4ACL(BaseModel):
"""Model for an IPv4 ACL."""

name: str
"""Name of IPv4 ACL."""

entries: list[IPv4ACLEntry]
"""List of IPv4 ACL entries."""

class IPv4ACLEntry(BaseModel):
"""Model for an IPv4 ACL entry."""

sequence: int = Field(ge=1, le=4294967295)
"""Sequence number of an ACL entry."""
action: str
"""Action of an ACL entry."""

def render(self, template: AntaTemplate) -> list[AntaCommand]:
"""Render the template for each input ACL."""
return [template.render(acl=acl.name) for acl in self.inputs.ipv4_access_lists]
IPv4ACL: ClassVar[type[ACL]] = ACL
"""To maintain backward compatibility."""

@AntaTest.anta_test
def test(self) -> None:
"""Main test function for VerifyIPv4ACL."""
self.result.is_success()
for command_output, acl in zip(self.instance_commands, self.inputs.ipv4_access_lists):
# Collecting input ACL details
acl_name = command_output.params.acl
# Retrieve the expected entries from the inputs
acl_entries = acl.entries

# Check if ACL is configured
ipv4_acl_list = command_output.json_output["aclList"]
if not ipv4_acl_list:
self.result.is_failure(f"{acl_name}: Not found")

if not (command_output := self.instance_commands[0].json_output["aclList"]):
self.result.is_failure("No Access Control List (ACL) configured")
return

for access_list in self.inputs.ipv4_access_lists:
if not (access_list_output := get_item(command_output, "name", access_list.name)):
self.result.is_failure(f"{access_list} - Not configured")
continue

# Check if the sequence number is configured and has the correct action applied
failed_log = f"{acl_name}:\n"
for acl_entry in acl_entries:
acl_seq = acl_entry.sequence
acl_action = acl_entry.action
if (actual_entry := get_item(ipv4_acl_list[0]["sequence"], "sequenceNumber", acl_seq)) is None:
failed_log += f"Sequence number `{acl_seq}` is not found.\n"
for entry in access_list.entries:
if not (actual_entry := get_item(access_list_output["sequence"], "sequenceNumber", entry.sequence)):
self.result.is_failure(f"{access_list} {entry} - Not configured")
continue

if actual_entry["text"] != acl_action:
failed_log += f"Expected `{acl_action}` as sequence number {acl_seq} action but found `{actual_entry['text']}` instead.\n"

if failed_log != f"{acl_name}:\n":
self.result.is_failure(f"{failed_log}")
if (act_action := actual_entry["text"]) != entry.action:
self.result.is_failure(f"{access_list} {entry} - action mismatch - Expected: {entry.action} Actual: {act_action}")


class VerifyIPSecConnHealth(AntaTest):
Expand Down
Loading

0 comments on commit d57b53e

Please sign in to comment.