From a668f5820d4aae2f60f6fef1d18c4b44f96168e5 Mon Sep 17 00:00:00 2001 From: PrzeG <86780353+PrzeG@users.noreply.github.com> Date: Mon, 28 Oct 2024 13:18:45 +0100 Subject: [PATCH] Multithreading Auth (#842) --- README.md | 49 +++++++++++++++++ catalystwan/abstractions.py | 9 +++- catalystwan/apigw_auth.py | 36 +++++++++++-- catalystwan/request_limiter.py | 18 +++++++ catalystwan/session.py | 86 +++++++++++++++++++++++------- catalystwan/vmanage_auth.py | 97 +++++++++++++++++++++++++--------- pyproject.toml | 2 +- 7 files changed, 247 insertions(+), 50 deletions(-) create mode 100644 catalystwan/request_limiter.py diff --git a/README.md b/README.md index 0ced7d6d..932b9b8d 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,54 @@ with create_apigw_session( ``` +
+ Threading (click to expand) + +```python +from threading import Thread +from catalystwan.session import ManagerSession +from catalystwan.vmanage_auth import vManageAuth +from copy import copy + +def print_devices(manager: ManagerSession): + # using context manager (recommended) + with manager.login() as session: + print(session.api.devices.get()) + +if __name__ =="__main__": + + # 1. Create shared authentication handler for user session + auth = vManageAuth(username="username", password="password") + # 2. Configure session with base url and attach authentication handler + manager = ManagerSession(base_url="https://url:port", auth=auth) + + # 3. Make sure each thread gets own copy of ManagerSession object + t1 = Thread(target=print_devices, args=(manager,)) + t2 = Thread(target=print_devices, args=(copy(manager),)) + t3 = Thread(target=print_devices, args=(copy(manager),)) + + t1.start() + t2.start() + t3.start() + + t1.join() + t2.join() + t3.join() + + print("Done!") +``` +Threading can be achieved by using a shared auth object with sessions in each thread. As `ManagerSession` is not guaranteed to be thread-safe, it is recommended to create one session per thread. `ManagerSession` also comes in with a default `RequestLimiter`, which limits the number of concurrent requests to 50. It keeps `ManagerSession` from overloading the server and avoids HTTP 503 and HTTP 429 errors. +If you wish to modify the limit, you can pass a modified `RequestLimiter` to `ManagerSession`: +```python +from catalystwan.session import ManagerSession +from catalystwan.vmanage_auth import vManageAuth +from catalystwan.request_limiter import RequestLimiter + +auth = vManageAuth(username="username", password="password") +limiter = RequestLimiter(max_requests=30) +manager = ManagerSession(base_url="https://url:port", auth=auth, request_limiter=limiter) +``` +
## API usage examples All examples below assumes `session` variable contains logged-in [Manager Session](#Manager-Session) instance. @@ -413,6 +461,7 @@ migrate_task.wait_for_completed() ``` + ### Note: To remove `InsecureRequestWarning`, you can include in your scripts (warning is suppressed when `catalystwan_devel` environment variable is set): ```Python diff --git a/catalystwan/abstractions.py b/catalystwan/abstractions.py index 1f6a98ac..2e4f3f9b 100644 --- a/catalystwan/abstractions.py +++ b/catalystwan/abstractions.py @@ -3,6 +3,7 @@ from typing import Optional, Protocol, Type, TypeVar from packaging.version import Version # type: ignore +from requests import PreparedRequest from catalystwan.typed_list import DataSequence from catalystwan.utils.session_type import SessionType @@ -66,5 +67,11 @@ class AuthProtocol(Protocol): def logout(self, client: APIEndpointClient) -> None: ... - def clear(self) -> None: + def clear(self, last_request: Optional[PreparedRequest]) -> None: + ... + + def increase_session_count(self) -> None: + ... + + def decrease_session_count(self) -> None: ... diff --git a/catalystwan/apigw_auth.py b/catalystwan/apigw_auth.py index b8e19335..a8239e6b 100644 --- a/catalystwan/apigw_auth.py +++ b/catalystwan/apigw_auth.py @@ -1,4 +1,5 @@ import logging +from threading import RLock from typing import Literal, Optional from urllib.parse import urlparse @@ -37,13 +38,16 @@ def __init__(self, login: ApiGwLogin, logger: Optional[logging.Logger] = None, v self.token = "" self.logger = logger or logging.getLogger(__name__) self.verify = verify + self.session_count: int = 0 + self.lock: RLock = RLock() def __str__(self) -> str: return f"ApiGatewayAuth(mode={self.login.mode})" def __call__(self, request: PreparedRequest) -> PreparedRequest: - self.handle_auth(request) - self.build_digest_header(request) + with self.lock: + self.handle_auth(request) + self.build_digest_header(request) return request def handle_auth(self, request: PreparedRequest) -> None: @@ -92,5 +96,29 @@ def get_token( def logout(self, client: APIEndpointClient) -> None: return None - def clear(self) -> None: - self.token = "" + def _clear(self) -> None: + with self.lock: + self.token = "" + + def increase_session_count(self) -> None: + with self.lock: + self.session_count += 1 + + def decrease_session_count(self) -> None: + with self.lock: + self.session_count -= 1 + + def clear(self, last_request: Optional[PreparedRequest]) -> None: + with self.lock: + # extract previously used jsessionid + if last_request is None: + token = None + else: + token = last_request.headers.get("Authorization") + + if self.token == "" or f"Bearer {self.token}" == token: + # used auth was up-to-date, clear state + return self._clear() + else: + # used auth was out-of-date, repeat the request with a new one + return diff --git a/catalystwan/request_limiter.py b/catalystwan/request_limiter.py new file mode 100644 index 00000000..10a29fc2 --- /dev/null +++ b/catalystwan/request_limiter.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from contextlib import AbstractContextManager +from threading import Semaphore + + +class RequestLimiter(AbstractContextManager): + def __init__(self, max_requests: int = 49): + self._max_requests: int = max_requests + self._semaphore: Semaphore = Semaphore(value=self._max_requests) + + def __enter__(self) -> RequestLimiter: + self._semaphore.acquire() + return self + + def __exit__(self, *exc_info) -> None: + self._semaphore.release() + return diff --git a/catalystwan/session.py b/catalystwan/session.py index dcc502c3..95658cb8 100644 --- a/catalystwan/session.py +++ b/catalystwan/session.py @@ -27,6 +27,7 @@ TenantSubdomainNotFound, ) from catalystwan.models.tenant import Tenant +from catalystwan.request_limiter import RequestLimiter from catalystwan.response import ManagerResponse, response_history_debug from catalystwan.utils.session_type import SessionType from catalystwan.version import NullVersion, parse_api_version @@ -61,6 +62,8 @@ class ManagerSessionState(Enum): WAIT_SERVER_READY_AFTER_RESTART = 1 LOGIN = 2 OPERATIVE = 3 + LOGIN_IN_PROGRESS = 4 + AUTH_SYNC = 5 def determine_session_type( @@ -229,6 +232,7 @@ def __init__( auth: Union[vManageAuth, ApiGwAuth], subdomain: Optional[str] = None, logger: Optional[logging.Logger] = None, + request_limiter: Optional[RequestLimiter] = None, ) -> None: self.base_url = base_url self.subdomain = subdomain @@ -241,6 +245,7 @@ def __init__( super(ManagerSession, self).__init__() self.verify = False self.headers.update({"User-Agent": USER_AGENT}) + self._added_to_auth = False self._auth = auth self._platform_version: str = "" self._api_version: Version = NullVersion # type: ignore @@ -249,6 +254,8 @@ def __init__( self.request_timeout: Optional[int] = None self._validate_responses = True self._state: ManagerSessionState = ManagerSessionState.OPERATIVE + self._last_request: Optional[PreparedRequest] = None + self._limiter: RequestLimiter = request_limiter or RequestLimiter() @cached_property def api(self) -> APIContainer: @@ -286,8 +293,20 @@ def state(self, state: ManagerSessionState) -> None: self.wait_server_ready(self.restart_timeout) self.state = ManagerSessionState.LOGIN elif state == ManagerSessionState.LOGIN: - self.login() + self.state = ManagerSessionState.LOGIN_IN_PROGRESS + self._sync_auth() + server_info = self._fetch_server_info() + self._finalize_login(server_info) self.state = ManagerSessionState.OPERATIVE + elif state == ManagerSessionState.LOGIN_IN_PROGRESS: + # nothing to be done, continue to login + return + elif state == ManagerSessionState.AUTH_SYNC: + # this state can be reached when using an expired auth during the login (most likely + # to happen when multithreading). To avoid fetching server info multiple times, we will + # only authenticate here and then return to the previous login flow + self._sync_auth() + self.state = ManagerSessionState.LOGIN_IN_PROGRESS return def restart_imminent(self, restart_timeout_override: Optional[int] = None): @@ -301,23 +320,23 @@ def restart_imminent(self, restart_timeout_override: Optional[int] = None): self.restart_timeout = restart_timeout_override self.state = ManagerSessionState.RESTART_IMMINENT - def login(self) -> ManagerSession: - """Performs login to SDWAN Manager and fetches important server info to instance variables - - Raises: - SessionNotCreatedError: indicates session configuration is not consistent - - Returns: - ManagerSession: (self) - """ + def _sync_auth(self) -> None: self.cookies.clear_session_cookies() - self._auth.clear() + if not self._added_to_auth: + self._auth.increase_session_count() + self._added_to_auth = True + self._auth.clear(self._last_request) self.auth = self._auth + + def _fetch_server_info(self) -> ServerInfo: try: server_info = self.server() except DefaultPasswordError: server_info = ServerInfo.model_construct(**{}) + return server_info + + def _finalize_login(self, server_info: ServerInfo) -> None: self.server_name = server_info.server tenancy_mode = server_info.tenancy_mode @@ -325,6 +344,7 @@ def login(self) -> ManagerSession: view_mode = server_info.view_mode self._session_type = determine_session_type(tenancy_mode, user_mode, view_mode) + if user_mode is UserMode.TENANT and self.subdomain: raise SessionNotCreatedError( f"Session not created. Subdomain {self.subdomain} passed to tenant session, " @@ -339,6 +359,17 @@ def login(self) -> ManagerSession: self.logger.info( f"Logged to vManage({self.platform_version}) as {self.auth}. The session type is {self.session_type}" ) + + def login(self) -> ManagerSession: + """Performs login to SDWAN Manager and fetches important server info to instance variables + + Raises: + SessionNotCreatedError: indicates session configuration is not consistent + + Returns: + ManagerSession: (self) + """ + self.state = ManagerSessionState.LOGIN return self def wait_server_ready(self, timeout: int, poll_period: int = 10) -> None: @@ -399,7 +430,8 @@ def request(self, method, url, *args, **kwargs) -> ManagerResponse: if self.request_timeout is not None: # do not modify user provided kwargs unless property is set _kwargs.update(timeout=self.request_timeout) try: - response = super(ManagerSession, self).request(method, full_url, *args, **_kwargs) + with self._limiter: + response = super(ManagerSession, self).request(method, full_url, *args, **_kwargs) self.logger.debug(self.response_trace(response, None)) if self.state == ManagerSessionState.RESTART_IMMINENT and response.status_code == 503: self.state = ManagerSessionState.WAIT_SERVER_READY_AFTER_RESTART @@ -411,14 +443,29 @@ def request(self, method, url, *args, **kwargs) -> ManagerResponse: self.logger.debug(exception) raise ManagerRequestException(*exception.args, request=exception.request, response=exception.response) - if response.jsessionid_expired and self.state == ManagerSessionState.OPERATIVE: - self.logger.warning("Logging to session. Reason: expired JSESSIONID detected in response headers") - self.state = ManagerSessionState.LOGIN + self._last_request = response.request + if response.jsessionid_expired and self.state in [ + ManagerSessionState.OPERATIVE, + ManagerSessionState.LOGIN_IN_PROGRESS, + ]: + # detected expired auth during login, resync + if self.state == ManagerSessionState.LOGIN_IN_PROGRESS: + self.state = ManagerSessionState.AUTH_SYNC + else: + self.logger.warning("Logging to session. Reason: expired JSESSIONID detected in response headers") + self.state = ManagerSessionState.LOGIN return self.request(method, url, *args, **_kwargs) - if response.api_gw_unauthorized and self.state == ManagerSessionState.OPERATIVE: - self.logger.warning("Logging to API GW session. Reason: unauthorized detected in response headers") - self.state = ManagerSessionState.LOGIN + if response.api_gw_unauthorized and self.state in [ + ManagerSessionState.OPERATIVE, + ManagerSessionState.LOGIN_IN_PROGRESS, + ]: + # detected expired auth during login, resync + if self.state == ManagerSessionState.LOGIN_IN_PROGRESS: + self.state = ManagerSessionState.AUTH_SYNC + else: + self.logger.warning("Logging to API GW session. Reason: unauthorized detected in response headers") + self.state = ManagerSessionState.LOGIN return self.request(method, url, *args, **_kwargs) if response.request.url and "passwordReset.html" in response.request.url: @@ -485,6 +532,8 @@ def get_tenant_id(self) -> str: return tenant.tenant_id def logout(self) -> None: + if self._added_to_auth: + self._auth.decrease_session_count() self._auth.logout(self) def close(self) -> None: @@ -532,6 +581,7 @@ def __copy__(self) -> ManagerSession: auth=self._auth, subdomain=self.subdomain, logger=self.logger, + request_limiter=self._limiter, ) def __str__(self) -> str: diff --git a/catalystwan/vmanage_auth.py b/catalystwan/vmanage_auth.py index cfcfdd76..0813388c 100644 --- a/catalystwan/vmanage_auth.py +++ b/catalystwan/vmanage_auth.py @@ -2,6 +2,7 @@ import logging from http.cookies import SimpleCookie +from threading import RLock from typing import Optional from urllib.parse import urlparse @@ -78,14 +79,17 @@ def __init__(self, username: str, password: str, logger: Optional[logging.Logger self.logger = logger or logging.getLogger(__name__) self.cookies: RequestsCookieJar = RequestsCookieJar() self._base_url: str = "" + self.session_count: int = 0 + self.lock: RLock = RLock() def __str__(self) -> str: return f"vManageAuth(username={self.username})" def __call__(self, request: PreparedRequest) -> PreparedRequest: - self.handle_auth(request) - update_headers(request, self.jsessionid, self.xsrftoken) - return request + with self.lock: + self.handle_auth(request) + update_headers(request, self.jsessionid, self.xsrftoken) + return request def sync_cookies(self, cookies: RequestsCookieJar) -> None: self.cookies = merge_cookies(self.cookies, cookies) @@ -133,24 +137,63 @@ def authenticate(self, request: PreparedRequest): self.xsrftoken = self.get_xsrftoken() def logout(self, client: APIEndpointClient) -> None: - if isinstance((version := client.api_version), NullVersion): - self.logger.warning("Cannot perform logout without known api version.") - elif self._base_url is None: - self.logger.warning("Cannot perform logout without known base url") - else: - headers = {"x-xsrf-token": self.xsrftoken, "User-Agent": USER_AGENT} - if version >= Version("20.12"): - response = post(f"{self._base_url}/logout", headers=headers, cookies=self.cookies, verify=self.verify) + with self.lock: + if self.session_count > 1: + # Other sessions still use the auth, unregister and return + return + + # last session using the auth, logout + if isinstance((version := client.api_version), NullVersion): + self.logger.warning("Cannot perform logout without known api version.") + elif self._base_url is None: + self.logger.warning("Cannot perform logout without known base url") else: - response = get(f"{self._base_url}/logout", headers=headers, cookies=self.cookies, verify=self.verify) - self.logger.debug(auth_response_debug(response, str(self))) - if response.status_code != 200: - self.logger.error("Unsuccessfull logout") - self.clear() - - def clear(self) -> None: - self.cookies.clear_session_cookies() - self.xsrftoken = None + headers = {"x-xsrf-token": self.xsrftoken, "User-Agent": USER_AGENT} + if version >= Version("20.12"): + response = post( + f"{self._base_url}/logout", headers=headers, cookies=self.cookies, verify=self.verify + ) + else: + response = get( + f"{self._base_url}/logout", headers=headers, cookies=self.cookies, verify=self.verify + ) + self.logger.debug(auth_response_debug(response, str(self))) + if response.status_code != 200: + self.logger.error("Unsuccessfull logout") + self._clear() + + def _clear(self) -> None: + with self.lock: + self.cookies.clear_session_cookies() + self.xsrftoken = None + + def increase_session_count(self) -> None: + with self.lock: + self.session_count += 1 + + def decrease_session_count(self) -> None: + with self.lock: + self.session_count -= 1 + + def clear(self, last_request: Optional[PreparedRequest]) -> None: + with self.lock: + # extract previously used jsessionid + if last_request is None: + jsessionid = None + else: + cookie: SimpleCookie = SimpleCookie() + cookie.load(last_request.headers.get("Cookie", "")) + try: + jsessionid = cookie["JSESSIONID"].value + except KeyError: + jsessionid = None + + if self.jsessionid is None or self.jsessionid == jsessionid: + # used auth was up-to-date, clear state + return self._clear() + else: + # used auth was out-of-date, repeat the request with a new one + return class vSessionAuth(vManageAuth): @@ -170,9 +213,10 @@ def __str__(self) -> str: return f"vSessionAuth(username={self.username},subdomain={self.subdomain})" # noqa: E231 def __call__(self, request: PreparedRequest) -> PreparedRequest: - self.handle_auth(request) - update_headers(request, self.jsessionid, self.xsrftoken, self.vsessionid) - return request + with self.lock: + self.handle_auth(request) + update_headers(request, self.jsessionid, self.xsrftoken, self.vsessionid) + return request def authenticate(self, request: PreparedRequest): super().authenticate(request) @@ -209,9 +253,10 @@ def get_vsessionid(self, tenantid: str) -> str: self.logger.debug(auth_response_debug(response, str(self))) return response.json()["VSessionId"] - def clear(self) -> None: - super().clear() - self.vsessionid = None + def _clear(self) -> None: + with self.lock: + super()._clear() + self.vsessionid = None def create_vmanage_auth( diff --git a/pyproject.toml b/pyproject.toml index d992969c..739703f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "catalystwan" -version = "0.35.6" +version = "0.36.0" description = "Cisco Catalyst WAN SDK for Python" authors = ["kagorski "] readme = "README.md"