Skip to content

Commit

Permalink
Merge branch 'main' into issue_822
Browse files Browse the repository at this point in the history
  • Loading branch information
gmuloc authored Jan 21, 2025
2 parents a2fe431 + 995943a commit 161d0e3
Show file tree
Hide file tree
Showing 30 changed files with 1,429 additions and 313 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ repos:
- '<!--| ~| -->'

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.1
rev: v0.9.2
hooks:
- id: ruff
name: Run Ruff linter
Expand Down
8 changes: 8 additions & 0 deletions anta/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,11 @@
r"No source interface .*",
]
"""List of known EOS errors that should set a test status to 'failure' with the error message."""

UNSUPPORTED_PLATFORM_ERRORS = [
"not supported on this hardware platform",
"Invalid input (at token 2: 'trident')",
]
"""Error messages indicating platform or hardware unsupported commands.
Will set the test status to 'skipped'. Includes both general hardware
platform errors and specific ASIC family limitations."""
1 change: 1 addition & 0 deletions anta/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,4 @@ def validate_regex(value: str) -> str:
SnmpHashingAlgorithm = Literal["MD5", "SHA", "SHA-224", "SHA-256", "SHA-384", "SHA-512"]
SnmpEncryptionAlgorithm = Literal["AES-128", "AES-192", "AES-256", "DES"]
DynamicVlanSource = Literal["dmf", "dot1x", "dynvtep", "evpn", "mlag", "mlagsync", "mvpn", "swfwd", "vccbfd"]
LogSeverityLevel = Literal["alerts", "critical", "debugging", "emergencies", "errors", "informational", "notifications", "warnings"]
95 changes: 77 additions & 18 deletions anta/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from collections import OrderedDict, defaultdict
from time import monotonic
from typing import TYPE_CHECKING, Any, Literal

import asyncssh
import httpcore
from aiocache import Cache
from aiocache.plugins import HitMissRatioPlugin
from asyncssh import SSHClientConnection, SSHClientConnectionOptions
from httpx import ConnectError, HTTPError, TimeoutException

Expand All @@ -34,6 +33,67 @@
CLIENT_KEYS = asyncssh.public_key.load_default_keypairs()


class AntaCache:
"""Class to be used as cache.
Example
-------
```python
# Create cache
cache = AntaCache("device1")
with cache.locks[key]:
command_output = cache.get(key)
```
"""

def __init__(self, device: str, max_size: int = 128, ttl: int = 60) -> None:
"""Initialize the cache."""
self.device = device
self.cache: OrderedDict[str, Any] = OrderedDict()
self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self.max_size = max_size
self.ttl = ttl

# Stats
self.stats: dict[str, int] = {}
self._init_stats()

def _init_stats(self) -> None:
"""Initialize the stats."""
self.stats["hits"] = 0
self.stats["total"] = 0

async def get(self, key: str) -> Any: # noqa: ANN401
"""Return the cached entry for key."""
self.stats["total"] += 1
if key in self.cache:
timestamp, value = self.cache[key]
if monotonic() - timestamp < self.ttl:
# checking the value is still valid
self.cache.move_to_end(key)
self.stats["hits"] += 1
return value
# Time expired
del self.cache[key]
del self.locks[key]
return None

async def set(self, key: str, value: Any) -> bool: # noqa: ANN401
"""Set the cached entry for key to value."""
timestamp = monotonic()
if len(self.cache) > self.max_size:
self.cache.popitem(last=False)
self.cache[key] = timestamp, value
return True

def clear(self) -> None:
"""Empty the cache."""
logger.debug("Clearing cache for device %s", self.device)
self.cache = OrderedDict()
self._init_stats()


class AntaDevice(ABC):
"""Abstract class representing a device in ANTA.
Expand All @@ -52,10 +112,11 @@ class AntaDevice(ABC):
Hardware model of the device.
tags : set[str]
Tags for this device.
cache : Cache | None
In-memory cache from aiocache library for this device (None if cache is disabled).
cache : AntaCache | None
In-memory cache for this device (None if cache is disabled).
cache_locks : dict
Dictionary mapping keys to asyncio locks to guarantee exclusive access to the cache if not disabled.
Deprecated, will be removed in ANTA v2.0.0, use self.cache.locks instead.
"""

Expand All @@ -79,7 +140,8 @@ def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bo
self.tags.add(self.name)
self.is_online: bool = False
self.established: bool = False
self.cache: Cache | None = None
self.cache: AntaCache | None = None
# Keeping cache_locks for backward compatibility.
self.cache_locks: defaultdict[str, asyncio.Lock] | None = None

# Initialize cache if not disabled
Expand All @@ -101,17 +163,16 @@ def __hash__(self) -> int:

def _init_cache(self) -> None:
"""Initialize cache for the device, can be overridden by subclasses to manipulate how it works."""
self.cache = Cache(cache_class=Cache.MEMORY, ttl=60, namespace=self.name, plugins=[HitMissRatioPlugin()])
self.cache_locks = defaultdict(asyncio.Lock)
self.cache = AntaCache(device=self.name, ttl=60)
self.cache_locks = self.cache.locks

@property
def cache_statistics(self) -> dict[str, Any] | None:
"""Return the device cache statistics for logging purposes."""
# Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough
# https://github.com/pylint-dev/pylint/issues/7258
if self.cache is not None:
stats = getattr(self.cache, "hit_miss_ratio", {"total": 0, "hits": 0, "hit_ratio": 0})
return {"total_commands_sent": stats["total"], "cache_hits": stats["hits"], "cache_hit_ratio": f"{stats['hit_ratio'] * 100:.2f}%"}
stats = self.cache.stats
ratio = stats["hits"] / stats["total"] if stats["total"] > 0 else 0
return {"total_commands_sent": stats["total"], "cache_hits": stats["hits"], "cache_hit_ratio": f"{ratio * 100:.2f}%"}
return None

def __rich_repr__(self) -> Iterator[tuple[str, Any]]:
Expand Down Expand Up @@ -177,18 +238,16 @@ async def collect(self, command: AntaCommand, *, collection_id: str | None = Non
collection_id
An identifier used to build the eAPI request ID.
"""
# Need to ignore pylint no-member as Cache is a proxy class and pylint is not smart enough
# https://github.com/pylint-dev/pylint/issues/7258
if self.cache is not None and self.cache_locks is not None and command.use_cache:
async with self.cache_locks[command.uid]:
cached_output = await self.cache.get(command.uid) # pylint: disable=no-member
if self.cache is not None and command.use_cache:
async with self.cache.locks[command.uid]:
cached_output = await self.cache.get(command.uid)

if cached_output is not None:
logger.debug("Cache hit for %s on %s", command.command, self.name)
command.output = cached_output
else:
await self._collect(command=command, collection_id=collection_id)
await self.cache.set(command.uid, command.output) # pylint: disable=no-member
await self.cache.set(command.uid, command.output)
else:
await self._collect(command=command, collection_id=collection_id)

Expand Down
47 changes: 45 additions & 2 deletions anta/input_models/routing/bgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from ipaddress import IPv4Address, IPv4Network, IPv6Address
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
from warnings import warn

from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator
Expand Down Expand Up @@ -149,7 +149,7 @@ class BgpPeer(BaseModel):
received_routes: list[IPv4Network] | None = None
"""List of received routes in CIDR format. Required field in the `VerifyBGPExchangedRoutes` test."""
capabilities: list[MultiProtocolCaps] | None = None
"""List of BGP multiprotocol capabilities. Required field in the `VerifyBGPPeerMPCaps` test."""
"""List of BGP multiprotocol capabilities. Required field in the `VerifyBGPPeerMPCaps`, `VerifyBGPNlriAcceptance` tests."""
strict: bool = False
"""If True, requires exact match of the provided BGP multiprotocol capabilities.
Expand Down Expand Up @@ -211,3 +211,46 @@ class VxlanEndpoint(BaseModel):
def __str__(self) -> str:
"""Return a human-readable string representation of the VxlanEndpoint for reporting."""
return f"Address: {self.address} VNI: {self.vni}"


class BgpRoute(BaseModel):
"""Model representing BGP routes.
Only IPv4 prefixes are supported for now.
"""

model_config = ConfigDict(extra="forbid")
prefix: IPv4Network
"""The IPv4 network address."""
vrf: str = "default"
"""Optional VRF for the BGP peer. Defaults to `default`."""
paths: list[BgpRoutePath]
"""A list of paths for the BGP route."""

def __str__(self) -> str:
"""Return a human-readable string representation of the BgpRoute for reporting.
Examples
--------
- Prefix: 192.168.66.100/24 VRF: default
"""
return f"Prefix: {self.prefix} VRF: {self.vrf}"


class BgpRoutePath(BaseModel):
"""Model representing a BGP route path."""

model_config = ConfigDict(extra="forbid")
nexthop: IPv4Address
"""The next-hop IPv4 address for the path."""
origin: Literal["Igp", "Egp", "Incomplete"]
"""The BGP origin attribute of the route."""

def __str__(self) -> str:
"""Return a human-readable string representation of the RoutePath for reporting.
Examples
--------
- Next-hop: 192.168.66.101 Origin: Igp
"""
return f"Next-hop: {self.nexthop} Origin: {self.origin}"
6 changes: 3 additions & 3 deletions anta/input_models/routing/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ class IPv4Routes(BaseModel):

model_config = ConfigDict(extra="forbid")
prefix: IPv4Network
"""The IPV4 network to validate the route type."""
"""IPv4 prefix in CIDR notation."""
vrf: str = "default"
"""VRF context. Defaults to `default` VRF."""
route_type: IPv4RouteType
"""List of IPV4 Route type to validate the valid rout type."""
route_type: IPv4RouteType | None = None
"""Expected route type. Required field in the `VerifyIPv4RouteType` test."""

def __str__(self) -> str:
"""Return a human-readable string representation of the IPv4RouteType for reporting."""
Expand Down
117 changes: 114 additions & 3 deletions anta/input_models/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,20 @@
from __future__ import annotations

from ipaddress import IPv4Address
from typing import Any
from typing import TYPE_CHECKING, Any, ClassVar, get_args
from warnings import warn

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field, model_validator

from anta.custom_types import EcdsaKeySize, EncryptionAlgorithm, RsaKeySize

if TYPE_CHECKING:
import sys

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class IPSecPeer(BaseModel):
Expand Down Expand Up @@ -43,6 +53,107 @@ class IPSecConn(BaseModel):
"""The IPv4 address of the destination in the security connection."""


class APISSLCertificate(BaseModel):
"""Model for an API SSL certificate."""

model_config = ConfigDict(extra="forbid")
certificate_name: str
"""The name of the certificate to be verified."""
expiry_threshold: int
"""The expiry threshold of the certificate in days."""
common_name: str
"""The Common Name of the certificate."""
encryption_algorithm: EncryptionAlgorithm
"""The encryption algorithm used by the certificate."""
key_size: RsaKeySize | EcdsaKeySize
"""The key size (in bits) of the encryption algorithm."""

def __str__(self) -> str:
"""Return a human-readable string representation of the APISSLCertificate for reporting.
Examples
--------
- Certificate: SIGNING_CA.crt
"""
return f"Certificate: {self.certificate_name}"

@model_validator(mode="after")
def validate_inputs(self) -> Self:
"""Validate the key size provided to the APISSLCertificates class.
If encryption_algorithm is RSA then key_size should be in {2048, 3072, 4096}.
If encryption_algorithm is ECDSA then key_size should be in {256, 384, 521}.
"""
if self.encryption_algorithm == "RSA" and self.key_size not in get_args(RsaKeySize):
msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for RSA encryption. Allowed sizes are {get_args(RsaKeySize)}."
raise ValueError(msg)

if self.encryption_algorithm == "ECDSA" and self.key_size not in get_args(EcdsaKeySize):
msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for ECDSA encryption. Allowed sizes are {get_args(EcdsaKeySize)}."
raise ValueError(msg)

return self


class ACLEntry(BaseModel):
"""Model for an Access Control List (ACL) entry."""

model_config = ConfigDict(extra="forbid")
sequence: int = Field(ge=1, le=4294967295)
"""Sequence number of the ACL entry, used to define the order of processing. Must be between 1 and 4294967295."""
action: str
"""Action of the ACL entry. Example: `deny ip any any`."""

def __str__(self) -> str:
"""Return a human-readable string representation of the ACLEntry for reporting.
Examples
--------
- Sequence: 10
"""
return f"Sequence: {self.sequence}"


class ACL(BaseModel):
"""Model for an Access Control List (ACL)."""

model_config = ConfigDict(extra="forbid")
name: str
"""Name of the ACL."""
entries: list[ACLEntry]
"""List of the ACL entries."""
IPv4ACLEntry: ClassVar[type[ACLEntry]] = ACLEntry
"""To maintain backward compatibility."""

def __str__(self) -> str:
"""Return a human-readable string representation of the ACL for reporting.
Examples
--------
- ACL name: Test
"""
return f"ACL name: {self.name}"


class IPv4ACL(ACL): # pragma: no cover
"""Alias for the ACL model to maintain backward compatibility.
When initialized, it will emit a deprecation warning and call the ACL model.
TODO: Remove this class in ANTA v2.0.0.
"""

def __init__(self, **data: Any) -> None: # noqa: ANN401
"""Initialize the IPv4ACL class, emitting a deprecation warning."""
warn(
message="IPv4ACL model is deprecated and will be removed in ANTA v2.0.0. Use the ACL model instead.",
category=DeprecationWarning,
stacklevel=2,
)
super().__init__(**data)


class IPSecPeers(IPSecPeer): # pragma: no cover
"""Alias for the IPSecPeers model to maintain backward compatibility.
Expand All @@ -52,7 +163,7 @@ class IPSecPeers(IPSecPeer): # pragma: no cover
"""

def __init__(self, **data: Any) -> None: # noqa: ANN401
"""Initialize the IPSecPeer class, emitting a deprecation warning."""
"""Initialize the IPSecPeers class, emitting a deprecation warning."""
warn(
message="IPSecPeers model is deprecated and will be removed in ANTA v2.0.0. Use the IPSecPeer model instead.",
category=DeprecationWarning,
Expand Down
Loading

0 comments on commit 161d0e3

Please sign in to comment.