Skip to content
This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Commit

Permalink
Multithreading Auth (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
PrzeG authored Oct 28, 2024
1 parent 7624812 commit a668f58
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 50 deletions.
49 changes: 49 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,54 @@ with create_apigw_session(
```
</details>

<details>
<summary> <b>Threading</b> <i>(click to expand)</i></summary>

```python
from threading import Thread
from catalystwan.session import ManagerSession
from catalystwan.vmanage_auth import vManageAuth
from copy import copy

def print_devices(manager: ManagerSession):
# using context manager (recommended)
with manager.login() as session:
print(session.api.devices.get())

if __name__ =="__main__":

# 1. Create shared authentication handler for user session
auth = vManageAuth(username="username", password="password")
# 2. Configure session with base url and attach authentication handler
manager = ManagerSession(base_url="https://url:port", auth=auth)

# 3. Make sure each thread gets own copy of ManagerSession object
t1 = Thread(target=print_devices, args=(manager,))
t2 = Thread(target=print_devices, args=(copy(manager),))
t3 = Thread(target=print_devices, args=(copy(manager),))

t1.start()
t2.start()
t3.start()

t1.join()
t2.join()
t3.join()

print("Done!")
```
Threading can be achieved by using a shared auth object with sessions in each thread. As `ManagerSession` is not guaranteed to be thread-safe, it is recommended to create one session per thread. `ManagerSession` also comes in with a default `RequestLimiter`, which limits the number of concurrent requests to 50. It keeps `ManagerSession` from overloading the server and avoids HTTP 503 and HTTP 429 errors.
If you wish to modify the limit, you can pass a modified `RequestLimiter` to `ManagerSession`:
```python
from catalystwan.session import ManagerSession
from catalystwan.vmanage_auth import vManageAuth
from catalystwan.request_limiter import RequestLimiter

auth = vManageAuth(username="username", password="password")
limiter = RequestLimiter(max_requests=30)
manager = ManagerSession(base_url="https://url:port", auth=auth, request_limiter=limiter)
```
</details>

## API usage examples
All examples below assumes `session` variable contains logged-in [Manager Session](#Manager-Session) instance.
Expand Down Expand Up @@ -413,6 +461,7 @@ migrate_task.wait_for_completed()
```
</details>


### Note:
To remove `InsecureRequestWarning`, you can include in your scripts (warning is suppressed when `catalystwan_devel` environment variable is set):
```Python
Expand Down
9 changes: 8 additions & 1 deletion catalystwan/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Protocol, Type, TypeVar

from packaging.version import Version # type: ignore
from requests import PreparedRequest

from catalystwan.typed_list import DataSequence
from catalystwan.utils.session_type import SessionType
Expand Down Expand Up @@ -66,5 +67,11 @@ class AuthProtocol(Protocol):
def logout(self, client: APIEndpointClient) -> None:
...

def clear(self) -> None:
def clear(self, last_request: Optional[PreparedRequest]) -> None:
...

def increase_session_count(self) -> None:
...

def decrease_session_count(self) -> None:
...
36 changes: 32 additions & 4 deletions catalystwan/apigw_auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from threading import RLock
from typing import Literal, Optional
from urllib.parse import urlparse

Expand Down Expand Up @@ -37,13 +38,16 @@ def __init__(self, login: ApiGwLogin, logger: Optional[logging.Logger] = None, v
self.token = ""
self.logger = logger or logging.getLogger(__name__)
self.verify = verify
self.session_count: int = 0
self.lock: RLock = RLock()

def __str__(self) -> str:
return f"ApiGatewayAuth(mode={self.login.mode})"

def __call__(self, request: PreparedRequest) -> PreparedRequest:
self.handle_auth(request)
self.build_digest_header(request)
with self.lock:
self.handle_auth(request)
self.build_digest_header(request)
return request

def handle_auth(self, request: PreparedRequest) -> None:
Expand Down Expand Up @@ -92,5 +96,29 @@ def get_token(
def logout(self, client: APIEndpointClient) -> None:
return None

def clear(self) -> None:
self.token = ""
def _clear(self) -> None:
with self.lock:
self.token = ""

def increase_session_count(self) -> None:
with self.lock:
self.session_count += 1

def decrease_session_count(self) -> None:
with self.lock:
self.session_count -= 1

def clear(self, last_request: Optional[PreparedRequest]) -> None:
with self.lock:
# extract previously used jsessionid
if last_request is None:
token = None
else:
token = last_request.headers.get("Authorization")

if self.token == "" or f"Bearer {self.token}" == token:
# used auth was up-to-date, clear state
return self._clear()
else:
# used auth was out-of-date, repeat the request with a new one
return
18 changes: 18 additions & 0 deletions catalystwan/request_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from contextlib import AbstractContextManager
from threading import Semaphore


class RequestLimiter(AbstractContextManager):
def __init__(self, max_requests: int = 49):
self._max_requests: int = max_requests
self._semaphore: Semaphore = Semaphore(value=self._max_requests)

def __enter__(self) -> RequestLimiter:
self._semaphore.acquire()
return self

def __exit__(self, *exc_info) -> None:
self._semaphore.release()
return
86 changes: 68 additions & 18 deletions catalystwan/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TenantSubdomainNotFound,
)
from catalystwan.models.tenant import Tenant
from catalystwan.request_limiter import RequestLimiter
from catalystwan.response import ManagerResponse, response_history_debug
from catalystwan.utils.session_type import SessionType
from catalystwan.version import NullVersion, parse_api_version
Expand Down Expand Up @@ -61,6 +62,8 @@ class ManagerSessionState(Enum):
WAIT_SERVER_READY_AFTER_RESTART = 1
LOGIN = 2
OPERATIVE = 3
LOGIN_IN_PROGRESS = 4
AUTH_SYNC = 5


def determine_session_type(
Expand Down Expand Up @@ -229,6 +232,7 @@ def __init__(
auth: Union[vManageAuth, ApiGwAuth],
subdomain: Optional[str] = None,
logger: Optional[logging.Logger] = None,
request_limiter: Optional[RequestLimiter] = None,
) -> None:
self.base_url = base_url
self.subdomain = subdomain
Expand All @@ -241,6 +245,7 @@ def __init__(
super(ManagerSession, self).__init__()
self.verify = False
self.headers.update({"User-Agent": USER_AGENT})
self._added_to_auth = False
self._auth = auth
self._platform_version: str = ""
self._api_version: Version = NullVersion # type: ignore
Expand All @@ -249,6 +254,8 @@ def __init__(
self.request_timeout: Optional[int] = None
self._validate_responses = True
self._state: ManagerSessionState = ManagerSessionState.OPERATIVE
self._last_request: Optional[PreparedRequest] = None
self._limiter: RequestLimiter = request_limiter or RequestLimiter()

@cached_property
def api(self) -> APIContainer:
Expand Down Expand Up @@ -286,8 +293,20 @@ def state(self, state: ManagerSessionState) -> None:
self.wait_server_ready(self.restart_timeout)
self.state = ManagerSessionState.LOGIN
elif state == ManagerSessionState.LOGIN:
self.login()
self.state = ManagerSessionState.LOGIN_IN_PROGRESS
self._sync_auth()
server_info = self._fetch_server_info()
self._finalize_login(server_info)
self.state = ManagerSessionState.OPERATIVE
elif state == ManagerSessionState.LOGIN_IN_PROGRESS:
# nothing to be done, continue to login
return
elif state == ManagerSessionState.AUTH_SYNC:
# this state can be reached when using an expired auth during the login (most likely
# to happen when multithreading). To avoid fetching server info multiple times, we will
# only authenticate here and then return to the previous login flow
self._sync_auth()
self.state = ManagerSessionState.LOGIN_IN_PROGRESS
return

def restart_imminent(self, restart_timeout_override: Optional[int] = None):
Expand All @@ -301,30 +320,31 @@ def restart_imminent(self, restart_timeout_override: Optional[int] = None):
self.restart_timeout = restart_timeout_override
self.state = ManagerSessionState.RESTART_IMMINENT

def login(self) -> ManagerSession:
"""Performs login to SDWAN Manager and fetches important server info to instance variables
Raises:
SessionNotCreatedError: indicates session configuration is not consistent
Returns:
ManagerSession: (self)
"""
def _sync_auth(self) -> None:
self.cookies.clear_session_cookies()
self._auth.clear()
if not self._added_to_auth:
self._auth.increase_session_count()
self._added_to_auth = True
self._auth.clear(self._last_request)
self.auth = self._auth

def _fetch_server_info(self) -> ServerInfo:
try:
server_info = self.server()
except DefaultPasswordError:
server_info = ServerInfo.model_construct(**{})

return server_info

def _finalize_login(self, server_info: ServerInfo) -> None:
self.server_name = server_info.server

tenancy_mode = server_info.tenancy_mode
user_mode = server_info.user_mode
view_mode = server_info.view_mode

self._session_type = determine_session_type(tenancy_mode, user_mode, view_mode)

if user_mode is UserMode.TENANT and self.subdomain:
raise SessionNotCreatedError(
f"Session not created. Subdomain {self.subdomain} passed to tenant session, "
Expand All @@ -339,6 +359,17 @@ def login(self) -> ManagerSession:
self.logger.info(
f"Logged to vManage({self.platform_version}) as {self.auth}. The session type is {self.session_type}"
)

def login(self) -> ManagerSession:
"""Performs login to SDWAN Manager and fetches important server info to instance variables
Raises:
SessionNotCreatedError: indicates session configuration is not consistent
Returns:
ManagerSession: (self)
"""
self.state = ManagerSessionState.LOGIN
return self

def wait_server_ready(self, timeout: int, poll_period: int = 10) -> None:
Expand Down Expand Up @@ -399,7 +430,8 @@ def request(self, method, url, *args, **kwargs) -> ManagerResponse:
if self.request_timeout is not None: # do not modify user provided kwargs unless property is set
_kwargs.update(timeout=self.request_timeout)
try:
response = super(ManagerSession, self).request(method, full_url, *args, **_kwargs)
with self._limiter:
response = super(ManagerSession, self).request(method, full_url, *args, **_kwargs)
self.logger.debug(self.response_trace(response, None))
if self.state == ManagerSessionState.RESTART_IMMINENT and response.status_code == 503:
self.state = ManagerSessionState.WAIT_SERVER_READY_AFTER_RESTART
Expand All @@ -411,14 +443,29 @@ def request(self, method, url, *args, **kwargs) -> ManagerResponse:
self.logger.debug(exception)
raise ManagerRequestException(*exception.args, request=exception.request, response=exception.response)

if response.jsessionid_expired and self.state == ManagerSessionState.OPERATIVE:
self.logger.warning("Logging to session. Reason: expired JSESSIONID detected in response headers")
self.state = ManagerSessionState.LOGIN
self._last_request = response.request
if response.jsessionid_expired and self.state in [
ManagerSessionState.OPERATIVE,
ManagerSessionState.LOGIN_IN_PROGRESS,
]:
# detected expired auth during login, resync
if self.state == ManagerSessionState.LOGIN_IN_PROGRESS:
self.state = ManagerSessionState.AUTH_SYNC
else:
self.logger.warning("Logging to session. Reason: expired JSESSIONID detected in response headers")
self.state = ManagerSessionState.LOGIN
return self.request(method, url, *args, **_kwargs)

if response.api_gw_unauthorized and self.state == ManagerSessionState.OPERATIVE:
self.logger.warning("Logging to API GW session. Reason: unauthorized detected in response headers")
self.state = ManagerSessionState.LOGIN
if response.api_gw_unauthorized and self.state in [
ManagerSessionState.OPERATIVE,
ManagerSessionState.LOGIN_IN_PROGRESS,
]:
# detected expired auth during login, resync
if self.state == ManagerSessionState.LOGIN_IN_PROGRESS:
self.state = ManagerSessionState.AUTH_SYNC
else:
self.logger.warning("Logging to API GW session. Reason: unauthorized detected in response headers")
self.state = ManagerSessionState.LOGIN
return self.request(method, url, *args, **_kwargs)

if response.request.url and "passwordReset.html" in response.request.url:
Expand Down Expand Up @@ -485,6 +532,8 @@ def get_tenant_id(self) -> str:
return tenant.tenant_id

def logout(self) -> None:
if self._added_to_auth:
self._auth.decrease_session_count()
self._auth.logout(self)

def close(self) -> None:
Expand Down Expand Up @@ -532,6 +581,7 @@ def __copy__(self) -> ManagerSession:
auth=self._auth,
subdomain=self.subdomain,
logger=self.logger,
request_limiter=self._limiter,
)

def __str__(self) -> str:
Expand Down
Loading

0 comments on commit a668f58

Please sign in to comment.