From d2b9f6c63ec7d28473df770604be889f7bbc3214 Mon Sep 17 00:00:00 2001 From: Szymon Basan Date: Fri, 29 Nov 2024 11:10:26 +0100 Subject: [PATCH] default retry and timeout on vManageAuth requests --- catalystwan/request.py | 35 ++++++++++++++++++++++++++ catalystwan/tests/test_vmanage_auth.py | 9 ++++++- catalystwan/vmanage_auth.py | 35 +++++++++++++++++++++++--- 3 files changed, 74 insertions(+), 5 deletions(-) create mode 100644 catalystwan/request.py diff --git a/catalystwan/request.py b/catalystwan/request.py new file mode 100644 index 00000000..90e827c6 --- /dev/null +++ b/catalystwan/request.py @@ -0,0 +1,35 @@ +# Copyright 2024 Cisco Systems, Inc. and its affiliates +from logging import getLogger +from typing import Callable, Tuple, Type, TypeVar + +from requests import delete, get, head, options, patch, post, put, request +from requests.exceptions import ConnectionError, Timeout +from typing_extensions import Concatenate, ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") +logger = getLogger(__name__) + + +def retry(function: Callable[P, T], catch: Tuple[Type[Exception], ...]) -> Callable[Concatenate[int, P], T]: + def decorator(retries: int, *args: P.args, **kwargs: P.kwargs) -> T: + for _ in range(retries): + try: + return function(*args, **kwargs) + except catch as e: + logger.warning(f"Retrying: {e}") + return function(*args, **kwargs) + + return decorator + + +# retry decorators for request methods, retries count added as first positional argument +catch = (ConnectionError, Timeout) +retry_request = retry(request, catch) +retry_get = retry(get, catch) +retry_options = retry(options, catch) +retry_head = retry(head, catch) +retry_post = retry(post, catch) +retry_put = retry(put, catch) +retry_patch = retry(patch, catch) +retry_delete = retry(delete, catch) diff --git a/catalystwan/tests/test_vmanage_auth.py b/catalystwan/tests/test_vmanage_auth.py index 3e917d5a..3860032b 100644 --- a/catalystwan/tests/test_vmanage_auth.py +++ b/catalystwan/tests/test_vmanage_auth.py @@ -85,10 +85,12 @@ def test_get_cookie(self, mock_post): # Assert mock_post.assert_called_with( + vmanage_auth.request_retries, url="https://1.1.1.1:1111/j_security_check", data=security_payload, verify=False, headers={"Content-Type": "application/x-www-form-urlencoded", "User-Agent": USER_AGENT}, + timeout=vmanage_auth.request_timeout, ) @mock.patch("catalystwan.vmanage_auth.post", side_effect=mock_request_j_security_check) @@ -99,16 +101,19 @@ def test_get_cookie_invalid_username(self, mock_post): "j_username": username, "j_password": self.password, } + vmanage_auth = vManageAuth(username, self.password) # Act with self.assertRaises(UnauthorizedAccessError): - vManageAuth(username, self.password).get_jsessionid() + vmanage_auth.get_jsessionid() # Assert mock_post.assert_called_with( + vmanage_auth.request_retries, url="/j_security_check", data=security_payload, verify=False, headers={"Content-Type": "application/x-www-form-urlencoded", "User-Agent": USER_AGENT}, + timeout=vmanage_auth.request_timeout, ) @mock.patch("catalystwan.vmanage_auth.get", side_effect=mock_valid_token) @@ -128,10 +133,12 @@ def test_fetch_token(self, mock_get): self.assertEqual(token, "valid-token") mock_get.assert_called_with( + vmanage_auth.request_retries, url=valid_url, verify=False, headers={"Content-Type": "application/json", "User-Agent": USER_AGENT}, cookies=cookies, + timeout=vmanage_auth.request_timeout, ) @mock.patch("catalystwan.vmanage_auth.get", side_effect=mock_invalid_token_status) diff --git a/catalystwan/vmanage_auth.py b/catalystwan/vmanage_auth.py index 0813388c..41441866 100644 --- a/catalystwan/vmanage_auth.py +++ b/catalystwan/vmanage_auth.py @@ -7,7 +7,7 @@ from urllib.parse import urlparse from packaging.version import Version # type: ignore -from requests import PreparedRequest, Response, get, post +from requests import PreparedRequest, Response from requests.auth import AuthBase from requests.cookies import RequestsCookieJar, merge_cookies @@ -15,6 +15,8 @@ from catalystwan.abstractions import APIEndpointClient, AuthProtocol from catalystwan.exceptions import CatalystwanException, TenantSubdomainNotFound from catalystwan.models.tenant import Tenant +from catalystwan.request import retry_get as get +from catalystwan.request import retry_post as post from catalystwan.response import ManagerResponse, auth_response_debug from catalystwan.version import NullVersion @@ -81,6 +83,8 @@ def __init__(self, username: str, password: str, logger: Optional[logging.Logger self._base_url: str = "" self.session_count: int = 0 self.lock: RLock = RLock() + self.request_retries = 1 + self.request_timeout = 10 def __str__(self) -> str: return f"vManageAuth(username={self.username})" @@ -109,7 +113,14 @@ def get_jsessionid(self) -> str: } url = self._base_url + "/j_security_check" headers = {"Content-Type": "application/x-www-form-urlencoded", "User-Agent": USER_AGENT} - response: Response = post(url=url, headers=headers, data=security_payload, verify=self.verify) + response: Response = post( + self.request_retries, + url=url, + headers=headers, + data=security_payload, + verify=self.verify, + timeout=self.request_timeout, + ) self.sync_cookies(response.cookies) self.logger.debug(auth_response_debug(response, str(self))) if response.text != "" or not isinstance(self.jsessionid, str) or self.jsessionid == "": @@ -120,10 +131,12 @@ def get_xsrftoken(self) -> str: url = self._base_url + "/dataservice/client/token" headers = {"Content-Type": "application/json", "User-Agent": USER_AGENT} response: Response = get( + self.request_retries, url=url, cookies=self.cookies, headers=headers, verify=self.verify, + timeout=self.request_timeout, ) self.sync_cookies(response.cookies) self.logger.debug(auth_response_debug(response, str(self))) @@ -151,11 +164,21 @@ def logout(self, client: APIEndpointClient) -> None: headers = {"x-xsrf-token": self.xsrftoken, "User-Agent": USER_AGENT} if version >= Version("20.12"): response = post( - f"{self._base_url}/logout", headers=headers, cookies=self.cookies, verify=self.verify + self.request_retries, + url=f"{self._base_url}/logout", + headers=headers, + cookies=self.cookies, + verify=self.verify, + timeout=self.request_timeout, ) else: response = get( - f"{self._base_url}/logout", headers=headers, cookies=self.cookies, verify=self.verify + self.request_retries, + url=f"{self._base_url}/logout", + headers=headers, + cookies=self.cookies, + verify=self.verify, + timeout=self.request_timeout, ) self.logger.debug(auth_response_debug(response, str(self))) if response.status_code != 200: @@ -227,10 +250,12 @@ def get_tenantid(self) -> str: url = self._base_url + "/dataservice/tenant" headers = {"Content-Type": "application/json", "User-Agent": USER_AGENT, "x-xsrf-token": self.xsrftoken} response: Response = get( + self.request_retries, url=url, cookies=self.cookies, headers=headers, verify=self.verify, + timeout=self.request_timeout, ) self.sync_cookies(response.cookies) self.logger.debug(auth_response_debug(response, str(self))) @@ -244,10 +269,12 @@ def get_vsessionid(self, tenantid: str) -> str: url = self._base_url + f"/dataservice/tenant/{tenantid}/vsessionid" headers = {"Content-Type": "application/json", "User-Agent": USER_AGENT, "x-xsrf-token": self.xsrftoken} response: Response = post( + self.request_retries, url=url, cookies=self.cookies, headers=headers, verify=self.verify, + timeout=self.request_timeout, ) self.sync_cookies(response.cookies) self.logger.debug(auth_response_debug(response, str(self)))