Skip to content
This repository has been archived by the owner on Nov 21, 2024. It is now read-only.

Multithreading Auth #842

Merged
merged 8 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,54 @@ with create_apigw_session(
```
</details>

<details>
<summary> <b>Threading</b> <i>(click to expand)</i></summary>

```python
from threading import Thread
from catalystwan.session import ManagerSession
from catalystwan.vmanage_auth import vManageAuth
from copy import copy

def print_devices(manager: ManagerSession):
# using context manager (recommended)
with manager.login() as session:
print(session.api.devices.get())

if __name__ =="__main__":

# 1. Create shared authentication handler for user session
auth = vManageAuth(username="username", password="password")
sbasan marked this conversation as resolved.
Show resolved Hide resolved
# 2. Configure session with base url and attach authentication handler
manager = ManagerSession(base_url="https://url:port", auth=auth)

sbasan marked this conversation as resolved.
Show resolved Hide resolved
# 3. Make sure each thread gets own copy of ManagerSession object
t1 = Thread(target=print_devices, args=(manager,))
t2 = Thread(target=print_devices, args=(copy(manager),))
t3 = Thread(target=print_devices, args=(copy(manager),))

t1.start()
t2.start()
t3.start()

t1.join()
t2.join()
t3.join()

print("Done!")
```
Threading can be achieved by using a shared auth object with sessions in each thread. As `ManagerSession` is not guaranteed to be thread-safe, it is recommended to create one session per thread. `ManagerSession` also comes in with a default `RequestLimiter`, which limits the number of concurrent requests to 50. It keeps `ManagerSession` from overloading the server and avoids HTTP 503 and HTTP 429 errors.
If you wish to modify the limit, you can pass a modified `RequestLimiter` to `ManagerSession`:
```python
from catalystwan.session import ManagerSession
from catalystwan.vmanage_auth import vManageAuth
from catalystwan.request_limiter import RequestLimiter

auth = vManageAuth(username="username", password="password")
limiter = RequestLimiter(max_requests=30)
manager = ManagerSession(base_url="https://url:port", auth=auth, request_limiter=limiter)
```
</details>

## API usage examples
All examples below assumes `session` variable contains logged-in [Manager Session](#Manager-Session) instance.
Expand Down Expand Up @@ -413,6 +461,7 @@ migrate_task.wait_for_completed()
```
</details>


### Note:
To remove `InsecureRequestWarning`, you can include in your scripts (warning is suppressed when `catalystwan_devel` environment variable is set):
```Python
Expand Down
9 changes: 8 additions & 1 deletion catalystwan/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Protocol, Type, TypeVar

from packaging.version import Version # type: ignore
from requests import PreparedRequest

from catalystwan.typed_list import DataSequence
from catalystwan.utils.session_type import SessionType
Expand Down Expand Up @@ -66,5 +67,11 @@ class AuthProtocol(Protocol):
def logout(self, client: APIEndpointClient) -> None:
...

def clear(self) -> None:
def clear(self, last_request: Optional[PreparedRequest]) -> None:
...

def increase_session_count(self) -> None:
...

def decrease_session_count(self) -> None:
...
36 changes: 32 additions & 4 deletions catalystwan/apigw_auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from threading import RLock
from typing import Literal, Optional
from urllib.parse import urlparse

Expand Down Expand Up @@ -37,13 +38,16 @@ def __init__(self, login: ApiGwLogin, logger: Optional[logging.Logger] = None, v
self.token = ""
self.logger = logger or logging.getLogger(__name__)
self.verify = verify
self.session_count: int = 0
self.lock: RLock = RLock()

def __str__(self) -> str:
return f"ApiGatewayAuth(mode={self.login.mode})"

def __call__(self, request: PreparedRequest) -> PreparedRequest:
self.handle_auth(request)
self.build_digest_header(request)
with self.lock:
self.handle_auth(request)
self.build_digest_header(request)
return request

def handle_auth(self, request: PreparedRequest) -> None:
Expand Down Expand Up @@ -92,5 +96,29 @@ def get_token(
def logout(self, client: APIEndpointClient) -> None:
return None

def clear(self) -> None:
self.token = ""
def _clear(self) -> None:
with self.lock:
self.token = ""

def increase_session_count(self) -> None:
with self.lock:
self.session_count += 1

def decrease_session_count(self) -> None:
with self.lock:
self.session_count -= 1

def clear(self, last_request: Optional[PreparedRequest]) -> None:
with self.lock:
# extract previously used jsessionid
if last_request is None:
token = None
else:
token = last_request.headers.get("Authorization")

if self.token == "" or f"Bearer {self.token}" == token:
# used auth was up-to-date, clear state
return self._clear()
else:
# used auth was out-of-date, repeat the request with a new one
return
18 changes: 18 additions & 0 deletions catalystwan/request_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from contextlib import AbstractContextManager
from threading import Semaphore


class RequestLimiter(AbstractContextManager):
def __init__(self, max_requests: int = 49):
self._max_requests: int = max_requests
self._semaphore: Semaphore = Semaphore(value=self._max_requests)

def __enter__(self) -> RequestLimiter:
self._semaphore.acquire()
return self

def __exit__(self, *exc_info) -> None:
self._semaphore.release()
return
86 changes: 68 additions & 18 deletions catalystwan/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TenantSubdomainNotFound,
)
from catalystwan.models.tenant import Tenant
from catalystwan.request_limiter import RequestLimiter
from catalystwan.response import ManagerResponse, response_history_debug
from catalystwan.utils.session_type import SessionType
from catalystwan.version import NullVersion, parse_api_version
Expand Down Expand Up @@ -61,6 +62,8 @@ class ManagerSessionState(Enum):
WAIT_SERVER_READY_AFTER_RESTART = 1
LOGIN = 2
OPERATIVE = 3
LOGIN_IN_PROGRESS = 4
AUTH_SYNC = 5


def determine_session_type(
Expand Down Expand Up @@ -229,6 +232,7 @@ def __init__(
auth: Union[vManageAuth, ApiGwAuth],
subdomain: Optional[str] = None,
logger: Optional[logging.Logger] = None,
request_limiter: Optional[RequestLimiter] = None,
) -> None:
self.base_url = base_url
self.subdomain = subdomain
Expand All @@ -241,6 +245,7 @@ def __init__(
super(ManagerSession, self).__init__()
self.verify = False
self.headers.update({"User-Agent": USER_AGENT})
self._added_to_auth = False
self._auth = auth
self._platform_version: str = ""
self._api_version: Version = NullVersion # type: ignore
Expand All @@ -249,6 +254,8 @@ def __init__(
self.request_timeout: Optional[int] = None
self._validate_responses = True
self._state: ManagerSessionState = ManagerSessionState.OPERATIVE
self._last_request: Optional[PreparedRequest] = None
self._limiter: RequestLimiter = request_limiter or RequestLimiter()

@cached_property
def api(self) -> APIContainer:
Expand Down Expand Up @@ -286,8 +293,20 @@ def state(self, state: ManagerSessionState) -> None:
self.wait_server_ready(self.restart_timeout)
self.state = ManagerSessionState.LOGIN
elif state == ManagerSessionState.LOGIN:
self.login()
self.state = ManagerSessionState.LOGIN_IN_PROGRESS
self._sync_auth()
server_info = self._fetch_server_info()
self._finalize_login(server_info)
self.state = ManagerSessionState.OPERATIVE
elif state == ManagerSessionState.LOGIN_IN_PROGRESS:
# nothing to be done, continue to login
return
elif state == ManagerSessionState.AUTH_SYNC:
# this state can be reached when using an expired auth during the login (most likely
# to happen when multithreading). To avoid fetching server info multiple times, we will
# only authenticate here and then return to the previous login flow
self._sync_auth()
self.state = ManagerSessionState.LOGIN_IN_PROGRESS
return

def restart_imminent(self, restart_timeout_override: Optional[int] = None):
Expand All @@ -301,30 +320,31 @@ def restart_imminent(self, restart_timeout_override: Optional[int] = None):
self.restart_timeout = restart_timeout_override
self.state = ManagerSessionState.RESTART_IMMINENT

def login(self) -> ManagerSession:
"""Performs login to SDWAN Manager and fetches important server info to instance variables

Raises:
SessionNotCreatedError: indicates session configuration is not consistent

Returns:
ManagerSession: (self)
"""
def _sync_auth(self) -> None:
self.cookies.clear_session_cookies()
self._auth.clear()
if not self._added_to_auth:
self._auth.increase_session_count()
self._added_to_auth = True
self._auth.clear(self._last_request)
self.auth = self._auth
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this line. Why we want make auth public after login. Do we need later?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we need it. Auth is called with every request and requests library expects it in that place.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need it for requests :/


def _fetch_server_info(self) -> ServerInfo:
try:
server_info = self.server()
except DefaultPasswordError:
server_info = ServerInfo.model_construct(**{})

return server_info

def _finalize_login(self, server_info: ServerInfo) -> None:
self.server_name = server_info.server

tenancy_mode = server_info.tenancy_mode
user_mode = server_info.user_mode
view_mode = server_info.view_mode

self._session_type = determine_session_type(tenancy_mode, user_mode, view_mode)

if user_mode is UserMode.TENANT and self.subdomain:
raise SessionNotCreatedError(
f"Session not created. Subdomain {self.subdomain} passed to tenant session, "
Expand All @@ -339,6 +359,17 @@ def login(self) -> ManagerSession:
self.logger.info(
f"Logged to vManage({self.platform_version}) as {self.auth}. The session type is {self.session_type}"
)

def login(self) -> ManagerSession:
"""Performs login to SDWAN Manager and fetches important server info to instance variables

Raises:
SessionNotCreatedError: indicates session configuration is not consistent

Returns:
ManagerSession: (self)
"""
self.state = ManagerSessionState.LOGIN
return self

def wait_server_ready(self, timeout: int, poll_period: int = 10) -> None:
Expand Down Expand Up @@ -399,7 +430,8 @@ def request(self, method, url, *args, **kwargs) -> ManagerResponse:
if self.request_timeout is not None: # do not modify user provided kwargs unless property is set
_kwargs.update(timeout=self.request_timeout)
try:
response = super(ManagerSession, self).request(method, full_url, *args, **_kwargs)
with self._limiter:
response = super(ManagerSession, self).request(method, full_url, *args, **_kwargs)
self.logger.debug(self.response_trace(response, None))
if self.state == ManagerSessionState.RESTART_IMMINENT and response.status_code == 503:
self.state = ManagerSessionState.WAIT_SERVER_READY_AFTER_RESTART
Expand All @@ -411,14 +443,29 @@ def request(self, method, url, *args, **kwargs) -> ManagerResponse:
self.logger.debug(exception)
raise ManagerRequestException(*exception.args, request=exception.request, response=exception.response)

if response.jsessionid_expired and self.state == ManagerSessionState.OPERATIVE:
self.logger.warning("Logging to session. Reason: expired JSESSIONID detected in response headers")
self.state = ManagerSessionState.LOGIN
self._last_request = response.request
if response.jsessionid_expired and self.state in [
ManagerSessionState.OPERATIVE,
ManagerSessionState.LOGIN_IN_PROGRESS,
]:
# detected expired auth during login, resync
if self.state == ManagerSessionState.LOGIN_IN_PROGRESS:
self.state = ManagerSessionState.AUTH_SYNC
else:
self.logger.warning("Logging to session. Reason: expired JSESSIONID detected in response headers")
self.state = ManagerSessionState.LOGIN
return self.request(method, url, *args, **_kwargs)

if response.api_gw_unauthorized and self.state == ManagerSessionState.OPERATIVE:
self.logger.warning("Logging to API GW session. Reason: unauthorized detected in response headers")
self.state = ManagerSessionState.LOGIN
if response.api_gw_unauthorized and self.state in [
ManagerSessionState.OPERATIVE,
ManagerSessionState.LOGIN_IN_PROGRESS,
]:
# detected expired auth during login, resync
if self.state == ManagerSessionState.LOGIN_IN_PROGRESS:
self.state = ManagerSessionState.AUTH_SYNC
else:
self.logger.warning("Logging to API GW session. Reason: unauthorized detected in response headers")
self.state = ManagerSessionState.LOGIN
return self.request(method, url, *args, **_kwargs)

if response.request.url and "passwordReset.html" in response.request.url:
Expand Down Expand Up @@ -485,6 +532,8 @@ def get_tenant_id(self) -> str:
return tenant.tenant_id

def logout(self) -> None:
if self._added_to_auth:
self._auth.decrease_session_count()
self._auth.logout(self)

def close(self) -> None:
Expand Down Expand Up @@ -532,6 +581,7 @@ def __copy__(self) -> ManagerSession:
auth=self._auth,
subdomain=self.subdomain,
logger=self.logger,
request_limiter=self._limiter,
)

def __str__(self) -> str:
Expand Down
Loading