From a668f5820d4aae2f60f6fef1d18c4b44f96168e5 Mon Sep 17 00:00:00 2001
From: PrzeG <86780353+PrzeG@users.noreply.github.com>
Date: Mon, 28 Oct 2024 13:18:45 +0100
Subject: [PATCH] Multithreading Auth (#842)
---
README.md | 49 +++++++++++++++++
catalystwan/abstractions.py | 9 +++-
catalystwan/apigw_auth.py | 36 +++++++++++--
catalystwan/request_limiter.py | 18 +++++++
catalystwan/session.py | 86 +++++++++++++++++++++++-------
catalystwan/vmanage_auth.py | 97 +++++++++++++++++++++++++---------
pyproject.toml | 2 +-
7 files changed, 247 insertions(+), 50 deletions(-)
create mode 100644 catalystwan/request_limiter.py
diff --git a/README.md b/README.md
index 0ced7d6d..932b9b8d 100644
--- a/README.md
+++ b/README.md
@@ -125,6 +125,54 @@ with create_apigw_session(
```
+
+ Threading (click to expand)
+
+```python
+from threading import Thread
+from catalystwan.session import ManagerSession
+from catalystwan.vmanage_auth import vManageAuth
+from copy import copy
+
+def print_devices(manager: ManagerSession):
+ # using context manager (recommended)
+ with manager.login() as session:
+ print(session.api.devices.get())
+
+if __name__ =="__main__":
+
+ # 1. Create shared authentication handler for user session
+ auth = vManageAuth(username="username", password="password")
+ # 2. Configure session with base url and attach authentication handler
+ manager = ManagerSession(base_url="https://url:port", auth=auth)
+
+ # 3. Make sure each thread gets own copy of ManagerSession object
+ t1 = Thread(target=print_devices, args=(manager,))
+ t2 = Thread(target=print_devices, args=(copy(manager),))
+ t3 = Thread(target=print_devices, args=(copy(manager),))
+
+ t1.start()
+ t2.start()
+ t3.start()
+
+ t1.join()
+ t2.join()
+ t3.join()
+
+ print("Done!")
+```
+Threading can be achieved by using a shared auth object with sessions in each thread. As `ManagerSession` is not guaranteed to be thread-safe, it is recommended to create one session per thread. `ManagerSession` also comes in with a default `RequestLimiter`, which limits the number of concurrent requests to 50. It keeps `ManagerSession` from overloading the server and avoids HTTP 503 and HTTP 429 errors.
+If you wish to modify the limit, you can pass a modified `RequestLimiter` to `ManagerSession`:
+```python
+from catalystwan.session import ManagerSession
+from catalystwan.vmanage_auth import vManageAuth
+from catalystwan.request_limiter import RequestLimiter
+
+auth = vManageAuth(username="username", password="password")
+limiter = RequestLimiter(max_requests=30)
+manager = ManagerSession(base_url="https://url:port", auth=auth, request_limiter=limiter)
+```
+
## API usage examples
All examples below assumes `session` variable contains logged-in [Manager Session](#Manager-Session) instance.
@@ -413,6 +461,7 @@ migrate_task.wait_for_completed()
```
+
### Note:
To remove `InsecureRequestWarning`, you can include in your scripts (warning is suppressed when `catalystwan_devel` environment variable is set):
```Python
diff --git a/catalystwan/abstractions.py b/catalystwan/abstractions.py
index 1f6a98ac..2e4f3f9b 100644
--- a/catalystwan/abstractions.py
+++ b/catalystwan/abstractions.py
@@ -3,6 +3,7 @@
from typing import Optional, Protocol, Type, TypeVar
from packaging.version import Version # type: ignore
+from requests import PreparedRequest
from catalystwan.typed_list import DataSequence
from catalystwan.utils.session_type import SessionType
@@ -66,5 +67,11 @@ class AuthProtocol(Protocol):
def logout(self, client: APIEndpointClient) -> None:
...
- def clear(self) -> None:
+ def clear(self, last_request: Optional[PreparedRequest]) -> None:
+ ...
+
+ def increase_session_count(self) -> None:
+ ...
+
+ def decrease_session_count(self) -> None:
...
diff --git a/catalystwan/apigw_auth.py b/catalystwan/apigw_auth.py
index b8e19335..a8239e6b 100644
--- a/catalystwan/apigw_auth.py
+++ b/catalystwan/apigw_auth.py
@@ -1,4 +1,5 @@
import logging
+from threading import RLock
from typing import Literal, Optional
from urllib.parse import urlparse
@@ -37,13 +38,16 @@ def __init__(self, login: ApiGwLogin, logger: Optional[logging.Logger] = None, v
self.token = ""
self.logger = logger or logging.getLogger(__name__)
self.verify = verify
+ self.session_count: int = 0
+ self.lock: RLock = RLock()
def __str__(self) -> str:
return f"ApiGatewayAuth(mode={self.login.mode})"
def __call__(self, request: PreparedRequest) -> PreparedRequest:
- self.handle_auth(request)
- self.build_digest_header(request)
+ with self.lock:
+ self.handle_auth(request)
+ self.build_digest_header(request)
return request
def handle_auth(self, request: PreparedRequest) -> None:
@@ -92,5 +96,29 @@ def get_token(
def logout(self, client: APIEndpointClient) -> None:
return None
- def clear(self) -> None:
- self.token = ""
+ def _clear(self) -> None:
+ with self.lock:
+ self.token = ""
+
+ def increase_session_count(self) -> None:
+ with self.lock:
+ self.session_count += 1
+
+ def decrease_session_count(self) -> None:
+ with self.lock:
+ self.session_count -= 1
+
+ def clear(self, last_request: Optional[PreparedRequest]) -> None:
+ with self.lock:
+ # extract previously used jsessionid
+ if last_request is None:
+ token = None
+ else:
+ token = last_request.headers.get("Authorization")
+
+ if self.token == "" or f"Bearer {self.token}" == token:
+ # used auth was up-to-date, clear state
+ return self._clear()
+ else:
+ # used auth was out-of-date, repeat the request with a new one
+ return
diff --git a/catalystwan/request_limiter.py b/catalystwan/request_limiter.py
new file mode 100644
index 00000000..10a29fc2
--- /dev/null
+++ b/catalystwan/request_limiter.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from contextlib import AbstractContextManager
+from threading import Semaphore
+
+
+class RequestLimiter(AbstractContextManager):
+ def __init__(self, max_requests: int = 49):
+ self._max_requests: int = max_requests
+ self._semaphore: Semaphore = Semaphore(value=self._max_requests)
+
+ def __enter__(self) -> RequestLimiter:
+ self._semaphore.acquire()
+ return self
+
+ def __exit__(self, *exc_info) -> None:
+ self._semaphore.release()
+ return
diff --git a/catalystwan/session.py b/catalystwan/session.py
index dcc502c3..95658cb8 100644
--- a/catalystwan/session.py
+++ b/catalystwan/session.py
@@ -27,6 +27,7 @@
TenantSubdomainNotFound,
)
from catalystwan.models.tenant import Tenant
+from catalystwan.request_limiter import RequestLimiter
from catalystwan.response import ManagerResponse, response_history_debug
from catalystwan.utils.session_type import SessionType
from catalystwan.version import NullVersion, parse_api_version
@@ -61,6 +62,8 @@ class ManagerSessionState(Enum):
WAIT_SERVER_READY_AFTER_RESTART = 1
LOGIN = 2
OPERATIVE = 3
+ LOGIN_IN_PROGRESS = 4
+ AUTH_SYNC = 5
def determine_session_type(
@@ -229,6 +232,7 @@ def __init__(
auth: Union[vManageAuth, ApiGwAuth],
subdomain: Optional[str] = None,
logger: Optional[logging.Logger] = None,
+ request_limiter: Optional[RequestLimiter] = None,
) -> None:
self.base_url = base_url
self.subdomain = subdomain
@@ -241,6 +245,7 @@ def __init__(
super(ManagerSession, self).__init__()
self.verify = False
self.headers.update({"User-Agent": USER_AGENT})
+ self._added_to_auth = False
self._auth = auth
self._platform_version: str = ""
self._api_version: Version = NullVersion # type: ignore
@@ -249,6 +254,8 @@ def __init__(
self.request_timeout: Optional[int] = None
self._validate_responses = True
self._state: ManagerSessionState = ManagerSessionState.OPERATIVE
+ self._last_request: Optional[PreparedRequest] = None
+ self._limiter: RequestLimiter = request_limiter or RequestLimiter()
@cached_property
def api(self) -> APIContainer:
@@ -286,8 +293,20 @@ def state(self, state: ManagerSessionState) -> None:
self.wait_server_ready(self.restart_timeout)
self.state = ManagerSessionState.LOGIN
elif state == ManagerSessionState.LOGIN:
- self.login()
+ self.state = ManagerSessionState.LOGIN_IN_PROGRESS
+ self._sync_auth()
+ server_info = self._fetch_server_info()
+ self._finalize_login(server_info)
self.state = ManagerSessionState.OPERATIVE
+ elif state == ManagerSessionState.LOGIN_IN_PROGRESS:
+ # nothing to be done, continue to login
+ return
+ elif state == ManagerSessionState.AUTH_SYNC:
+ # this state can be reached when using an expired auth during the login (most likely
+ # to happen when multithreading). To avoid fetching server info multiple times, we will
+ # only authenticate here and then return to the previous login flow
+ self._sync_auth()
+ self.state = ManagerSessionState.LOGIN_IN_PROGRESS
return
def restart_imminent(self, restart_timeout_override: Optional[int] = None):
@@ -301,23 +320,23 @@ def restart_imminent(self, restart_timeout_override: Optional[int] = None):
self.restart_timeout = restart_timeout_override
self.state = ManagerSessionState.RESTART_IMMINENT
- def login(self) -> ManagerSession:
- """Performs login to SDWAN Manager and fetches important server info to instance variables
-
- Raises:
- SessionNotCreatedError: indicates session configuration is not consistent
-
- Returns:
- ManagerSession: (self)
- """
+ def _sync_auth(self) -> None:
self.cookies.clear_session_cookies()
- self._auth.clear()
+ if not self._added_to_auth:
+ self._auth.increase_session_count()
+ self._added_to_auth = True
+ self._auth.clear(self._last_request)
self.auth = self._auth
+
+ def _fetch_server_info(self) -> ServerInfo:
try:
server_info = self.server()
except DefaultPasswordError:
server_info = ServerInfo.model_construct(**{})
+ return server_info
+
+ def _finalize_login(self, server_info: ServerInfo) -> None:
self.server_name = server_info.server
tenancy_mode = server_info.tenancy_mode
@@ -325,6 +344,7 @@ def login(self) -> ManagerSession:
view_mode = server_info.view_mode
self._session_type = determine_session_type(tenancy_mode, user_mode, view_mode)
+
if user_mode is UserMode.TENANT and self.subdomain:
raise SessionNotCreatedError(
f"Session not created. Subdomain {self.subdomain} passed to tenant session, "
@@ -339,6 +359,17 @@ def login(self) -> ManagerSession:
self.logger.info(
f"Logged to vManage({self.platform_version}) as {self.auth}. The session type is {self.session_type}"
)
+
+ def login(self) -> ManagerSession:
+ """Performs login to SDWAN Manager and fetches important server info to instance variables
+
+ Raises:
+ SessionNotCreatedError: indicates session configuration is not consistent
+
+ Returns:
+ ManagerSession: (self)
+ """
+ self.state = ManagerSessionState.LOGIN
return self
def wait_server_ready(self, timeout: int, poll_period: int = 10) -> None:
@@ -399,7 +430,8 @@ def request(self, method, url, *args, **kwargs) -> ManagerResponse:
if self.request_timeout is not None: # do not modify user provided kwargs unless property is set
_kwargs.update(timeout=self.request_timeout)
try:
- response = super(ManagerSession, self).request(method, full_url, *args, **_kwargs)
+ with self._limiter:
+ response = super(ManagerSession, self).request(method, full_url, *args, **_kwargs)
self.logger.debug(self.response_trace(response, None))
if self.state == ManagerSessionState.RESTART_IMMINENT and response.status_code == 503:
self.state = ManagerSessionState.WAIT_SERVER_READY_AFTER_RESTART
@@ -411,14 +443,29 @@ def request(self, method, url, *args, **kwargs) -> ManagerResponse:
self.logger.debug(exception)
raise ManagerRequestException(*exception.args, request=exception.request, response=exception.response)
- if response.jsessionid_expired and self.state == ManagerSessionState.OPERATIVE:
- self.logger.warning("Logging to session. Reason: expired JSESSIONID detected in response headers")
- self.state = ManagerSessionState.LOGIN
+ self._last_request = response.request
+ if response.jsessionid_expired and self.state in [
+ ManagerSessionState.OPERATIVE,
+ ManagerSessionState.LOGIN_IN_PROGRESS,
+ ]:
+ # detected expired auth during login, resync
+ if self.state == ManagerSessionState.LOGIN_IN_PROGRESS:
+ self.state = ManagerSessionState.AUTH_SYNC
+ else:
+ self.logger.warning("Logging to session. Reason: expired JSESSIONID detected in response headers")
+ self.state = ManagerSessionState.LOGIN
return self.request(method, url, *args, **_kwargs)
- if response.api_gw_unauthorized and self.state == ManagerSessionState.OPERATIVE:
- self.logger.warning("Logging to API GW session. Reason: unauthorized detected in response headers")
- self.state = ManagerSessionState.LOGIN
+ if response.api_gw_unauthorized and self.state in [
+ ManagerSessionState.OPERATIVE,
+ ManagerSessionState.LOGIN_IN_PROGRESS,
+ ]:
+ # detected expired auth during login, resync
+ if self.state == ManagerSessionState.LOGIN_IN_PROGRESS:
+ self.state = ManagerSessionState.AUTH_SYNC
+ else:
+ self.logger.warning("Logging to API GW session. Reason: unauthorized detected in response headers")
+ self.state = ManagerSessionState.LOGIN
return self.request(method, url, *args, **_kwargs)
if response.request.url and "passwordReset.html" in response.request.url:
@@ -485,6 +532,8 @@ def get_tenant_id(self) -> str:
return tenant.tenant_id
def logout(self) -> None:
+ if self._added_to_auth:
+ self._auth.decrease_session_count()
self._auth.logout(self)
def close(self) -> None:
@@ -532,6 +581,7 @@ def __copy__(self) -> ManagerSession:
auth=self._auth,
subdomain=self.subdomain,
logger=self.logger,
+ request_limiter=self._limiter,
)
def __str__(self) -> str:
diff --git a/catalystwan/vmanage_auth.py b/catalystwan/vmanage_auth.py
index cfcfdd76..0813388c 100644
--- a/catalystwan/vmanage_auth.py
+++ b/catalystwan/vmanage_auth.py
@@ -2,6 +2,7 @@
import logging
from http.cookies import SimpleCookie
+from threading import RLock
from typing import Optional
from urllib.parse import urlparse
@@ -78,14 +79,17 @@ def __init__(self, username: str, password: str, logger: Optional[logging.Logger
self.logger = logger or logging.getLogger(__name__)
self.cookies: RequestsCookieJar = RequestsCookieJar()
self._base_url: str = ""
+ self.session_count: int = 0
+ self.lock: RLock = RLock()
def __str__(self) -> str:
return f"vManageAuth(username={self.username})"
def __call__(self, request: PreparedRequest) -> PreparedRequest:
- self.handle_auth(request)
- update_headers(request, self.jsessionid, self.xsrftoken)
- return request
+ with self.lock:
+ self.handle_auth(request)
+ update_headers(request, self.jsessionid, self.xsrftoken)
+ return request
def sync_cookies(self, cookies: RequestsCookieJar) -> None:
self.cookies = merge_cookies(self.cookies, cookies)
@@ -133,24 +137,63 @@ def authenticate(self, request: PreparedRequest):
self.xsrftoken = self.get_xsrftoken()
def logout(self, client: APIEndpointClient) -> None:
- if isinstance((version := client.api_version), NullVersion):
- self.logger.warning("Cannot perform logout without known api version.")
- elif self._base_url is None:
- self.logger.warning("Cannot perform logout without known base url")
- else:
- headers = {"x-xsrf-token": self.xsrftoken, "User-Agent": USER_AGENT}
- if version >= Version("20.12"):
- response = post(f"{self._base_url}/logout", headers=headers, cookies=self.cookies, verify=self.verify)
+ with self.lock:
+ if self.session_count > 1:
+ # Other sessions still use the auth, unregister and return
+ return
+
+ # last session using the auth, logout
+ if isinstance((version := client.api_version), NullVersion):
+ self.logger.warning("Cannot perform logout without known api version.")
+ elif self._base_url is None:
+ self.logger.warning("Cannot perform logout without known base url")
else:
- response = get(f"{self._base_url}/logout", headers=headers, cookies=self.cookies, verify=self.verify)
- self.logger.debug(auth_response_debug(response, str(self)))
- if response.status_code != 200:
- self.logger.error("Unsuccessfull logout")
- self.clear()
-
- def clear(self) -> None:
- self.cookies.clear_session_cookies()
- self.xsrftoken = None
+ headers = {"x-xsrf-token": self.xsrftoken, "User-Agent": USER_AGENT}
+ if version >= Version("20.12"):
+ response = post(
+ f"{self._base_url}/logout", headers=headers, cookies=self.cookies, verify=self.verify
+ )
+ else:
+ response = get(
+ f"{self._base_url}/logout", headers=headers, cookies=self.cookies, verify=self.verify
+ )
+ self.logger.debug(auth_response_debug(response, str(self)))
+ if response.status_code != 200:
+ self.logger.error("Unsuccessfull logout")
+ self._clear()
+
+ def _clear(self) -> None:
+ with self.lock:
+ self.cookies.clear_session_cookies()
+ self.xsrftoken = None
+
+ def increase_session_count(self) -> None:
+ with self.lock:
+ self.session_count += 1
+
+ def decrease_session_count(self) -> None:
+ with self.lock:
+ self.session_count -= 1
+
+ def clear(self, last_request: Optional[PreparedRequest]) -> None:
+ with self.lock:
+ # extract previously used jsessionid
+ if last_request is None:
+ jsessionid = None
+ else:
+ cookie: SimpleCookie = SimpleCookie()
+ cookie.load(last_request.headers.get("Cookie", ""))
+ try:
+ jsessionid = cookie["JSESSIONID"].value
+ except KeyError:
+ jsessionid = None
+
+ if self.jsessionid is None or self.jsessionid == jsessionid:
+ # used auth was up-to-date, clear state
+ return self._clear()
+ else:
+ # used auth was out-of-date, repeat the request with a new one
+ return
class vSessionAuth(vManageAuth):
@@ -170,9 +213,10 @@ def __str__(self) -> str:
return f"vSessionAuth(username={self.username},subdomain={self.subdomain})" # noqa: E231
def __call__(self, request: PreparedRequest) -> PreparedRequest:
- self.handle_auth(request)
- update_headers(request, self.jsessionid, self.xsrftoken, self.vsessionid)
- return request
+ with self.lock:
+ self.handle_auth(request)
+ update_headers(request, self.jsessionid, self.xsrftoken, self.vsessionid)
+ return request
def authenticate(self, request: PreparedRequest):
super().authenticate(request)
@@ -209,9 +253,10 @@ def get_vsessionid(self, tenantid: str) -> str:
self.logger.debug(auth_response_debug(response, str(self)))
return response.json()["VSessionId"]
- def clear(self) -> None:
- super().clear()
- self.vsessionid = None
+ def _clear(self) -> None:
+ with self.lock:
+ super()._clear()
+ self.vsessionid = None
def create_vmanage_auth(
diff --git a/pyproject.toml b/pyproject.toml
index d992969c..739703f4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "catalystwan"
-version = "0.35.6"
+version = "0.36.0"
description = "Cisco Catalyst WAN SDK for Python"
authors = ["kagorski "]
readme = "README.md"