diff --git a/documentation/docs/getting-started/configuration.mdx b/documentation/docs/getting-started/configuration.mdx index 5cd2ed4c9..d64d88433 100644 --- a/documentation/docs/getting-started/configuration.mdx +++ b/documentation/docs/getting-started/configuration.mdx @@ -74,7 +74,6 @@ Please use this table as a reference. | POLICY_BUNDLE_TMP_PATH | Path for temp policy file. It needs to be writable. | | | POLICY_BUNDLE_GIT_ADD_PATTERN | File pattern to add files to all the git default files. | | | REPO_WATCHER_ENABLED | | | -| PUBLISHER_ENABLED | | | | BROADCAST_KEEPALIVE_INTERVAL | The time to wait between sending two consecutive broadcaster keepalive messages. | | | BROADCAST_KEEPALIVE_TOPIC | The topic on which we should send broadcaster keepalive messages. | | | MAX_CHANNELS_PER_CLIENT | Max number of records per client, after this number it will not be added to statistics, relevant only if `STATISTICS_ENABLED`. | | diff --git a/packages/opal-client/opal_client/tests/server_to_client_intergation_test.py b/packages/opal-client/opal_client/tests/server_to_client_intergation_test.py index a3372c56f..e86cc99c6 100644 --- a/packages/opal-client/opal_client/tests/server_to_client_intergation_test.py +++ b/packages/opal-client/opal_client/tests/server_to_client_intergation_test.py @@ -52,7 +52,6 @@ def setup_server(event): # Server without git watcher and with a test specific url for data, and without broadcasting server = OpalServer( init_policy_watcher=False, - init_publisher=False, data_sources_config=DATA_SOURCES_CONFIG, broadcaster_uri=None, enable_jwks_endpoint=False, diff --git a/packages/opal-common/opal_common/async_utils.py b/packages/opal-common/opal_common/async_utils.py index a2df90c69..4a8383ea2 100644 --- a/packages/opal-common/opal_common/async_utils.py +++ b/packages/opal-common/opal_common/async_utils.py @@ -97,13 +97,24 @@ def __init__(self): self._tasks: List[asyncio.Task] = [] def _cleanup_task(self, done_task): - self._tasks.remove(done_task) + try: + self._tasks.remove(done_task) + except KeyError: + ... def add_task(self, f): t = asyncio.create_task(f) self._tasks.append(t) t.add_done_callback(self._cleanup_task) + async def join(self, cancel=False): + if cancel: + for t in self._tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*self._tasks, return_exceptions=True) + self._tasks.clear() + async def repeated_call( func: Coroutine, diff --git a/packages/opal-common/opal_common/confi/confi.py b/packages/opal-common/opal_common/confi/confi.py index cbaa9a587..86ce21ae7 100644 --- a/packages/opal-common/opal_common/confi/confi.py +++ b/packages/opal-common/opal_common/confi/confi.py @@ -75,7 +75,7 @@ def wrapped_cast(value, *args, **kwargs): return wrapped_cast -def load_conf_if_none(variable, conf): +def load_conf_if_none(variable: Any, conf: Any): if variable is None: return conf else: diff --git a/packages/opal-common/opal_common/topics/publisher.py b/packages/opal-common/opal_common/topics/publisher.py deleted file mode 100644 index b7b75a24f..000000000 --- a/packages/opal-common/opal_common/topics/publisher.py +++ /dev/null @@ -1,208 +0,0 @@ -import asyncio -from typing import Any, Optional, Set - -from ddtrace import tracer -from fastapi_websocket_pubsub import PubSubClient, PubSubEndpoint, Topic, TopicList -from opal_common.logger import logger - - -class TopicPublisher: - """abstract publisher, base class for client side and server side - publisher.""" - - def __init__(self): - """inits the publisher's asyncio tasks list.""" - self._tasks: Set[asyncio.Task] = set() - self._tasks_lock = asyncio.Lock() - - async def publish(self, topics: TopicList, data: Any = None): - raise NotImplementedError() - - async def __aenter__(self): - self.start() - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.stop() - - def start(self): - """starts the publisher.""" - logger.debug("started topic publisher") - - async def _add_task(self, task: asyncio.Task): - async with self._tasks_lock: - self._tasks.add(task) - task.add_done_callback(self._cleanup_task) - - async def wait(self): - async with self._tasks_lock: - await asyncio.gather(*self._tasks, return_exceptions=True) - self._tasks.clear() - - async def stop(self): - """stops the publisher (cancels any running publishing tasks)""" - logger.debug("stopping topic publisher") - await self.wait() - - def _cleanup_task(self, task: asyncio.Task): - try: - self._tasks.remove(task) - except KeyError: - ... - - -class PeriodicPublisher: - """Wrapper for a task that publishes to topic on fixed interval - periodically.""" - - def __init__( - self, - publisher: TopicPublisher, - time_interval: int, - topic: Topic, - message: Any = None, - task_name: str = "periodic publish task", - ): - """inits the publisher. - - Args: - publisher (TopicPublisher): can publish messages on the pub/sub channel - interval (int): the time interval between publishing consecutive messages - topic (Topic): the topic to publish on - message (Any): the message to publish - """ - self._publisher = publisher - self._interval = time_interval - self._topic = topic - self._message = message - self._task_name = task_name - self._task: Optional[asyncio.Task] = None - - async def __aenter__(self): - self.start() - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.stop() - - def start(self): - """starts the periodic publisher task.""" - if self._task is not None: - logger.warning(f"{self._task_name} already started") - return - - logger.info( - f"started {self._task_name}: topic is '{self._topic}', interval is {self._interval} seconds" - ) - self._task = asyncio.create_task(self._publish_task()) - - async def stop(self): - """stops the publisher (cancels any running publishing tasks)""" - if self._task is not None: - self._task.cancel() - try: - await self._task - except asyncio.CancelledError: - pass - self._task = None - logger.info(f"cancelled {self._task_name} to topic: {self._topic}") - - async def wait_until_done(self): - await self._task - - async def _publish_task(self): - while True: - await asyncio.sleep(self._interval) - logger.info( - f"{self._task_name}: publishing message on topic '{self._topic}', next publish is scheduled in {self._interval} seconds" - ) - async with self._publisher: - await self._publisher.publish(topics=[self._topic], data=self._message) - - -class ServerSideTopicPublisher(TopicPublisher): - """A simple wrapper around a PubSubEndpoint that exposes publish().""" - - def __init__(self, endpoint: PubSubEndpoint): - """inits the publisher. - - Args: - endpoint (PubSubEndpoint): a pub/sub endpoint - """ - self._endpoint = endpoint - super().__init__() - - async def _publish_impl(self, topics: TopicList, data: Any = None): - with tracer.trace("topic_publisher.publish", resource=str(topics)): - await self._endpoint.publish(topics=topics, data=data) - - async def publish(self, topics: TopicList, data: Any = None): - await self._add_task(asyncio.create_task(self._publish_impl(topics, data))) - - -class ClientSideTopicPublisher(TopicPublisher): - """A simple wrapper around a PubSubClient that exposes publish(). - - Provides start() and stop() shortcuts that helps treat this client - as a separate "process" or task that runs in the background. - """ - - def __init__(self, client: PubSubClient, server_uri: str): - """inits the publisher. - - Args: - client (PubSubClient): a configured not-yet-started pub sub client - server_uri (str): the URI of the pub sub server we publish to - """ - self._client = client - self._server_uri = server_uri - super().__init__() - - def start(self): - """starts the pub/sub client as a background asyncio task. - - the client will attempt to connect to the pubsub server until - successful. - """ - super().start() - self._client.start_client(f"{self._server_uri}") - - async def stop(self): - """stops the pubsub client, and cancels any publishing tasks.""" - await self._client.disconnect() - await super().stop() - - async def wait_until_done(self): - """When the publisher is a used as a context manager, this method waits - until the client is done (i.e: terminated) to prevent exiting the - context.""" - return await self._client.wait_until_done() - - async def publish(self, topics: TopicList, data: Any = None): - """publish a message by launching a background task on the event loop. - - Args: - topics (TopicList): a list of topics to publish the message to - data (Any): optional data to publish as part of the message - """ - await self._add_task( - asyncio.create_task(self._publish(topics=topics, data=data)) - ) - - async def _publish(self, topics: TopicList, data: Any = None) -> bool: - """Do not trigger directly, must be triggered via publish() in order to - run as a monitored background asyncio task.""" - await self._client.wait_until_ready() - logger.info("Publishing to topics: {topics}", topics=topics) - return await self._client.publish(topics, data) - - -class ScopedServerSideTopicPublisher(ServerSideTopicPublisher): - def __init__(self, endpoint: PubSubEndpoint, scope_id: str): - super().__init__(endpoint) - self._scope_id = scope_id - - async def publish(self, topics: TopicList, data: Any = None): - scoped_topics = [f"{self._scope_id}:{topic}" for topic in topics] - logger.info("Publishing to topics: {topics}", topics=scoped_topics) - await super().publish(scoped_topics, data) diff --git a/packages/opal-server/opal_server/config.py b/packages/opal-server/opal_server/config.py index b272915ad..8fa53d51c 100644 --- a/packages/opal-server/opal_server/config.py +++ b/packages/opal-server/opal_server/config.py @@ -146,9 +146,6 @@ class OpalServerConfig(Confi): REPO_WATCHER_ENABLED = confi.bool("REPO_WATCHER_ENABLED", True) - # publisher - PUBLISHER_ENABLED = confi.bool("PUBLISHER_ENABLED", True) - # broadcaster keepalive BROADCAST_KEEPALIVE_INTERVAL = confi.int( "BROADCAST_KEEPALIVE_INTERVAL", diff --git a/packages/opal-server/opal_server/data/data_update_publisher.py b/packages/opal-server/opal_server/data/data_update_publisher.py index 64bc32bbe..bf330eaa2 100644 --- a/packages/opal-server/opal_server/data/data_update_publisher.py +++ b/packages/opal-server/opal_server/data/data_update_publisher.py @@ -1,23 +1,18 @@ -import asyncio import os -from typing import List +from typing import List, Union -from fastapi_utils.tasks import repeat_every from opal_common.logger import logger -from opal_common.schemas.data import ( - DataSourceEntryWithPollingInterval, - DataUpdate, - ServerDataSourceConfig, -) -from opal_common.topics.publisher import TopicPublisher +from opal_common.schemas.data import DataUpdate +from opal_server.pubsub import PubSub +from opal_server.scopes.scoped_pubsub import ScopedPubSub TOPIC_DELIMITER = "/" PREFIX_DELIMITER = ":" class DataUpdatePublisher: - def __init__(self, publisher: TopicPublisher) -> None: - self._publisher = publisher + def __init__(self, pubsub: Union[PubSub, ScopedPubSub]) -> None: + self._pubsub = pubsub @staticmethod def get_topic_combos(topic: str) -> List[str]: @@ -108,6 +103,4 @@ async def publish_data_updates(self, update: DataUpdate): entries=logged_entries, ) - await self._publisher.publish( - list(all_topic_combos), update.dict(by_alias=True) - ) + await self._pubsub.publish(list(all_topic_combos), update.dict(by_alias=True)) diff --git a/packages/opal-server/opal_server/policy/watcher/callbacks.py b/packages/opal-server/opal_server/policy/watcher/callbacks.py index 1b5f65590..c0b168afc 100644 --- a/packages/opal-server/opal_server/policy/watcher/callbacks.py +++ b/packages/opal-server/opal_server/policy/watcher/callbacks.py @@ -16,8 +16,8 @@ PolicyUpdateMessage, PolicyUpdateMessageNotification, ) -from opal_common.topics.publisher import TopicPublisher from opal_common.topics.utils import policy_topics +from opal_server.pubsub import PubSub async def create_update_all_directories_in_repo( @@ -104,7 +104,7 @@ def is_path_affected(path: Path) -> bool: async def publish_changed_directories( old_commit: Commit, new_commit: Commit, - publisher: TopicPublisher, + pubsub: PubSub, file_extensions: Optional[List[str]] = None, bundle_ignore: Optional[List[str]] = None, ): @@ -116,7 +116,4 @@ async def publish_changed_directories( ) if notification: - async with publisher: - await publisher.publish( - topics=notification.topics, data=notification.update.dict() - ) + await pubsub.publish_sync(notification.topics, notification.update.dict()) diff --git a/packages/opal-server/opal_server/policy/watcher/factory.py b/packages/opal-server/opal_server/policy/watcher/factory.py index 6d94d6fc4..5c667244a 100644 --- a/packages/opal-server/opal_server/policy/watcher/factory.py +++ b/packages/opal-server/opal_server/policy/watcher/factory.py @@ -1,33 +1,31 @@ from functools import partial -from typing import Any, List, Optional +from typing import List, Optional -from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint from opal_common.confi.confi import load_conf_if_none from opal_common.git_utils.repo_cloner import RepoClonePathFinder from opal_common.logger import logger from opal_common.sources.api_policy_source import ApiPolicySource from opal_common.sources.git_policy_source import GitPolicySource -from opal_common.topics.publisher import TopicPublisher from opal_server.config import PolicySourceTypes, opal_server_config from opal_server.policy.watcher.callbacks import publish_changed_directories from opal_server.policy.watcher.task import BasePolicyWatcherTask, PolicyWatcherTask +from opal_server.pubsub import PubSub from opal_server.scopes.task import ScopesPolicyWatcherTask def setup_watcher_task( - publisher: TopicPublisher, - pubsub_endpoint: PubSubEndpoint, - source_type: str = None, - remote_source_url: str = None, - clone_path_finder: RepoClonePathFinder = None, - branch_name: str = None, + pubsub: PubSub, + source_type: Optional[str] = None, + remote_source_url: Optional[str] = None, + clone_path_finder: Optional[RepoClonePathFinder] = None, + branch_name: Optional[str] = None, ssh_key: Optional[str] = None, - polling_interval: int = None, - request_timeout: int = None, - policy_bundle_token: str = None, - policy_bundle_token_id: str = None, - policy_bundle_server_type: str = None, - policy_bundle_aws_region: str = None, + polling_interval: Optional[int] = None, + request_timeout: Optional[int] = None, + policy_bundle_token: Optional[str] = None, + policy_bundle_token_id: Optional[str] = None, + policy_bundle_server_type: Optional[str] = None, + policy_bundle_aws_region: Optional[str] = None, extensions: Optional[List[str]] = None, bundle_ignore: Optional[List[str]] = None, ) -> BasePolicyWatcherTask: @@ -35,7 +33,7 @@ def setup_watcher_task( vars Load all the defaults from config if called without params. Args: - publisher(TopicPublisher): server side publisher to publish changes in policy + pubsub(PubSub): server side pubsub client to publish changes in policy source_type(str): policy source type, can be Git / Api to opa bundle server remote_source_url(str): the base address to request the policy from clone_path_finder(RepoClonePathFinder): from which the local dir path for the repo clone would be retrieved @@ -46,11 +44,11 @@ def setup_watcher_task( policy_bundle_token(int): auth token to include in connections to OPAL server. Defaults to POLICY_BUNDLE_SERVER_TOKEN. policy_bundle_token_id(int): id token to include in connections to OPAL server. Defaults to POLICY_BUNDLE_SERVER_TOKEN_ID. policy_bundle_server_type (str): type of policy bundle server (HTTP S3). Defaults to POLICY_BUNDLE_SERVER_TYPE - extensions(list(str), optional): list of extantions to check when new policy arrive default is FILTER_FILE_EXTENSIONS + extensions(list(str), optional): list of extensions to check when new policy arrive default is FILTER_FILE_EXTENSIONS bundle_ignore(list(str), optional): list of glob paths to use for excluding files from bundle default is OPA_BUNDLE_IGNORE """ if opal_server_config.SCOPES: - return ScopesPolicyWatcherTask(pubsub_endpoint) + return ScopesPolicyWatcherTask(pubsub) # load defaults source_type = load_conf_if_none(source_type, opal_server_config.POLICY_SOURCE_TYPE) @@ -135,9 +133,9 @@ def setup_watcher_task( watcher.add_on_new_policy_callback( partial( publish_changed_directories, - publisher=publisher, + pubsub=pubsub, file_extensions=extensions, bundle_ignore=bundle_ignore, ) ) - return PolicyWatcherTask(watcher, pubsub_endpoint) + return PolicyWatcherTask(watcher, pubsub) diff --git a/packages/opal-server/opal_server/policy/watcher/task.py b/packages/opal-server/opal_server/policy/watcher/task.py index a2ba57558..0f7d8b397 100644 --- a/packages/opal-server/opal_server/policy/watcher/task.py +++ b/packages/opal-server/opal_server/policy/watcher/task.py @@ -1,107 +1,59 @@ import asyncio import os import signal -from typing import Any, Coroutine, List, Optional +from typing import Any, List, Optional from fastapi_websocket_pubsub import Topic -from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint from opal_common.logger import logger from opal_common.sources.base_policy_source import BasePolicySource +from opal_common.async_utils import TasksPool from opal_server.config import opal_server_config +from opal_server.pubsub import PubSub class BasePolicyWatcherTask: """Manages the asyncio tasks of the policy watcher.""" - def __init__(self, pubsub_endpoint: PubSubEndpoint): - self._tasks: List[asyncio.Task] = [] + def __init__(self, pubsub: PubSub): + self._tasks = TasksPool() self._should_stop: Optional[asyncio.Event] = None - self._pubsub_endpoint = pubsub_endpoint - self._webhook_tasks: List[asyncio.Task] = [] - - async def __aenter__(self): - await self.start() - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.stop() + self._pubsub = pubsub async def _on_webhook(self, topic: Topic, data: Any): - logger.info(f"Webhook listener triggered ({len(self._webhook_tasks)})") - for task in self._webhook_tasks: - if task.done(): - # Clean references to finished tasks - self._webhook_tasks.remove(task) - - self._webhook_tasks.append(asyncio.create_task(self.trigger(topic, data))) + logger.info("Webhook listener triggered") + self._tasks.add_task(self.trigger(topic, data)) async def _listen_to_webhook_notifications(self): # Webhook api route can be hit randomly in all workers, so it publishes a message to the webhook topic. # This listener, running in the leader's context, would actually trigger the repo pull - - async def _subscribe_internal(): - logger.info( - "listening on webhook topic: '{topic}'", - topic=opal_server_config.POLICY_REPO_WEBHOOK_TOPIC, - ) - await self._pubsub_endpoint.subscribe( - [opal_server_config.POLICY_REPO_WEBHOOK_TOPIC], - self._on_webhook, - ) - - if self._pubsub_endpoint.broadcaster is not None: - async with self._pubsub_endpoint.broadcaster.get_listening_context(): - await _subscribe_internal() - await self._pubsub_endpoint.broadcaster.get_reader_task() - - # Stop the watcher if broadcaster disconnects - self.signal_stop() - else: - # If no broadcaster is configured, just subscribe, no need to wait on anything - await _subscribe_internal() + logger.info( + "listening on webhook topic: '{topic}'", + topic=opal_server_config.POLICY_REPO_WEBHOOK_TOPIC, + ) + await self._pubsub.subscribe( + [opal_server_config.POLICY_REPO_WEBHOOK_TOPIC], + self._on_webhook, + ) async def start(self): """starts the policy watcher and registers a failure callback to terminate gracefully.""" logger.info("Launching policy watcher") - self._tasks.append(asyncio.create_task(self._listen_to_webhook_notifications())) - self._init_should_stop() + await self._listen_to_webhook_notifications() async def stop(self): """stops all policy watcher tasks.""" logger.info("Stopping policy watcher") - for task in self._tasks + self._webhook_tasks: - if not task.done(): - task.cancel() - await asyncio.gather(*self._tasks, return_exceptions=True) + await self._tasks.join() async def trigger(self, topic: Topic, data: Any): """triggers the policy watcher from outside to check for changes (git pull)""" raise NotImplementedError() - def wait_until_should_stop(self) -> Coroutine: - """waits until self.signal_stop() is called on the watcher. - - allows us to keep the repo watcher context alive until signalled - to stop from outside. - """ - self._init_should_stop() - return self._should_stop.wait() - - def signal_stop(self): - """signal the repo watcher it should stop.""" - self._init_should_stop() - self._should_stop.set() - - def _init_should_stop(self): - if self._should_stop is None: - self._should_stop = asyncio.Event() - async def _fail(self, exc: Exception): """called when the watcher fails, and stops all tasks gracefully.""" logger.error("policy watcher failed with exception: {err}", err=repr(exc)) - self.signal_stop() # trigger uvicorn graceful shutdown os.kill(os.getpid(), signal.SIGTERM) @@ -114,7 +66,7 @@ def __init__(self, policy_source: BasePolicySource, *args, **kwargs): async def start(self): await super().start() self._watcher.add_on_failure_callback(self._fail) - self._tasks.append(asyncio.create_task(self._watcher.run())) + self._tasks.add_task(self._watcher.run()) async def stop(self): await self._watcher.stop() diff --git a/packages/opal-server/opal_server/policy/webhook/api.py b/packages/opal-server/opal_server/policy/webhook/api.py index c19595ad2..780c53deb 100644 --- a/packages/opal-server/opal_server/policy/webhook/api.py +++ b/packages/opal-server/opal_server/policy/webhook/api.py @@ -2,7 +2,6 @@ from urllib.parse import SplitResult, urlparse from fastapi import APIRouter, Depends, Request, status -from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint from opal_common.authentication.deps import JWTAuthenticator from opal_common.logger import logger from opal_common.schemas.webhook import GitWebhookRequestParams @@ -12,11 +11,10 @@ extracted_git_changes, validate_git_secret_or_throw, ) +from opal_server.pubsub import PubSub -def init_git_webhook_router( - pubsub_endpoint: PubSubEndpoint, authenticator: JWTAuthenticator -): +def init_git_webhook_router(pubsub: PubSub, authenticator: JWTAuthenticator): async def dummy_affected_repo_urls(request: Request) -> List[str]: return [] @@ -32,7 +30,7 @@ async def dummy_affected_repo_urls(request: Request) -> List[str]: [Depends(route_dependency)], Depends(func_dependency), source_type, - pubsub_endpoint.publish, + pubsub.publish_sync, ) diff --git a/packages/opal-server/opal_server/publisher.py b/packages/opal-server/opal_server/publisher.py index 7d22fd86c..fa93977cf 100644 --- a/packages/opal-server/opal_server/publisher.py +++ b/packages/opal-server/opal_server/publisher.py @@ -1,41 +1,91 @@ -from fastapi_websocket_pubsub import PubSubClient, Topic -from opal_common.confi.confi import load_conf_if_none -from opal_common.topics.publisher import ( - ClientSideTopicPublisher, - PeriodicPublisher, - ServerSideTopicPublisher, - TopicPublisher, -) -from opal_common.utils import get_authorization_header -from opal_server.config import opal_server_config - - -def setup_publisher_task( - server_uri: str = None, - server_token: str = None, -) -> TopicPublisher: - server_uri = load_conf_if_none( - server_uri, - opal_server_config.OPAL_WS_LOCAL_URL, - ) - server_token = load_conf_if_none( - server_token, - opal_server_config.OPAL_WS_TOKEN, - ) - return ClientSideTopicPublisher( - client=PubSubClient(extra_headers=[get_authorization_header(server_token)]), - server_uri=server_uri, - ) - - -def setup_broadcaster_keepalive_task( - publisher: ServerSideTopicPublisher, - time_interval: int, - topic: Topic = "__broadcast_session_keepalive__", -) -> PeriodicPublisher: - """a periodic publisher with the intent to trigger messages on the - broadcast channel, so that the session to the backbone won't become idle - and close on the backbone end.""" - return PeriodicPublisher( - publisher, time_interval, topic, task_name="broadcaster keepalive task" - ) +import asyncio +from typing import Any, Optional + +from fastapi_websocket_pubsub import Topic, TopicList +from opal_common.logger import logger + + +class Publisher: + """abstract publisher, base class for client side and server side + publisher.""" + + async def publish(self, topics: TopicList, data: Any = None): + raise NotImplementedError() + + async def publish_sync(self, topics: TopicList, data: Any = None): + raise NotImplementedError() + + +class PeriodicPublisher: + """Wrapper for a task that publishes to topic on fixed interval + periodically.""" + + def __init__( + self, + publisher: Publisher, + time_interval: int, + topic: Topic, + message: Any = None, + task_name: str = "periodic publish task", + ): + """inits the publisher. + + Args: + publisher (Publisher): can publish messages on the pub/sub channel + time_interval (int): the time interval between publishing consecutive messages + topic (Topic): the topic to publish on + message (Any): the message to publish + """ + self._publisher = publisher + self._interval = time_interval + self._topic = topic + self._message = message + self._task_name = task_name + self._task: Optional[asyncio.Task] = None + + async def __aenter__(self): + self.start() + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.stop() + + def start(self): + """starts the periodic publisher task.""" + if self._task is not None: + logger.warning(f"{self._task_name} already started") + return + + logger.info( + f"started {self._task_name}: topic is '{self._topic}', interval is {self._interval} seconds" + ) + self._task = asyncio.create_task(self._publish_task()) + + async def stop(self): + """stops the publisher (cancels any running publishing tasks)""" + if self._task is not None: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + self._task = None + logger.info(f"cancelled {self._task_name} to topic: {self._topic}") + + async def _publish_task(self): + while True: + await asyncio.sleep(self._interval) + logger.info( + f"{self._task_name}: publishing message on topic '{self._topic}', next publish is scheduled in {self._interval} seconds" + ) + try: + await self._publisher.publish_sync([self._topic], self._message) + except asyncio.CancelledError: + logger.debug( + f"{self._task_name} for topic '{self._topic}' was cancelled" + ) + break + except Exception as e: + logger.error( + f"failed to publish periodic message on topic '{self._topic}': {e}" + ) diff --git a/packages/opal-server/opal_server/pubsub.py b/packages/opal-server/opal_server/pubsub.py index 26d47c422..2262e42f3 100644 --- a/packages/opal-server/opal_server/pubsub.py +++ b/packages/opal-server/opal_server/pubsub.py @@ -1,15 +1,18 @@ +import asyncio import time from contextlib import contextmanager from contextvars import ContextVar from threading import Lock -from typing import Dict, Generator, List, Optional, Set, Tuple, Union, cast -from uuid import UUID, uuid4 +from typing import Any, Coroutine, Dict, Generator, Optional, Set, Union +from uuid import uuid4 +from ddtrace import tracer from fastapi import APIRouter, Depends, WebSocket from fastapi_websocket_pubsub import ( ALL_TOPICS, EventBroadcaster, PubSubEndpoint, + Topic, TopicList, ) from fastapi_websocket_pubsub.event_notifier import ( @@ -21,16 +24,18 @@ WebSocketRpcEventNotifier, ) from fastapi_websocket_rpc import RpcChannel +from opal_common.async_utils import TasksPool from opal_common.authentication.deps import WebsocketJWTAuthenticator from opal_common.authentication.signer import JWTSigner from opal_common.authentication.types import JWTClaims from opal_common.authentication.verifier import Unauthorized -from opal_common.confi.confi import load_conf_if_none from opal_common.config import opal_common_config from opal_common.logger import logger from opal_server.config import opal_server_config +from opal_server.publisher import PeriodicPublisher, Publisher from pydantic import BaseModel from starlette.datastructures import QueryParams +from tenacity import retry, wait_fixed OPAL_CLIENT_INFO_PARAM_PREFIX = "__opal_" OPAL_CLIENT_INFO_CLIENT_ID = f"{OPAL_CLIENT_INFO_PARAM_PREFIX}client_id" @@ -117,19 +122,37 @@ async def on_unsubscribe( client_info.subscribed_topics.difference_update(topics) -class PubSub: +def setup_broadcaster_keepalive_task( + pubsub: Publisher, + time_interval: int, + topic: Topic = "__broadcast_session_keepalive__", +) -> PeriodicPublisher: + """a periodic publisher with the intent to trigger messages on the + broadcast channel, so that the session to the backbone won't become idle + and close on the backbone end.""" + return PeriodicPublisher( + pubsub, time_interval, topic, task_name="broadcaster keepalive task" + ) + + +BROADCASTER_CONNECT_RETRY_INTERVAL = 2 + + +class PubSub(Publisher): """Wrapper for the Pub/Sub channel used for both policy and data updates.""" - def __init__(self, signer: JWTSigner, broadcaster_uri: str = None): + def __init__( + self, + signer: JWTSigner, + broadcaster_uri: str = None, + disconnect_callback: Coroutine = None, + ): """ Args: broadcaster_uri (str, optional): Which server/medium should the PubSub use for broadcasting. Defaults to BROADCAST_URI. None means no broadcasting. """ - broadcaster_uri = load_conf_if_none( - broadcaster_uri, opal_server_config.BROADCAST_URI - ) self.pubsub_router = APIRouter() self.api_router = APIRouter() # Pub/Sub Internals @@ -138,8 +161,8 @@ def __init__(self, signer: JWTSigner, broadcaster_uri: str = None): self.client_tracker = ClientTracker() self.notifier.register_subscribe_event(self.client_tracker.on_subscribe) self.notifier.register_unsubscribe_event(self.client_tracker.on_unsubscribe) + self._publish_pool = TasksPool() - self.broadcaster = None if broadcaster_uri is not None: logger.info(f"Initializing broadcaster for server<->server communication") self.broadcaster = EventBroadcaster( @@ -147,8 +170,22 @@ def __init__(self, signer: JWTSigner, broadcaster_uri: str = None): notifier=self.notifier, channel=opal_server_config.BROADCAST_CHANNEL_NAME, ) + if opal_server_config.BROADCAST_KEEPALIVE_INTERVAL > 0: + self.broadcast_keepalive = setup_broadcaster_keepalive_task( + self, + time_interval=opal_server_config.BROADCAST_KEEPALIVE_INTERVAL, + topic=opal_server_config.BROADCAST_KEEPALIVE_TOPIC, + ) + else: logger.info("Pub/Sub broadcaster is off") + self.broadcaster = None + self.broadcast_keepalive = None + + self._wait_for_broadcaster_closed: Optional[asyncio.Task] = None + self._disconnect_callbacks: Set[Coroutine] = set() + if disconnect_callback is not None: + self._disconnect_callbacks.add(disconnect_callback) # The server endpoint self.endpoint = PubSubEndpoint( @@ -202,6 +239,53 @@ async def websocket_rpc_endpoint( finally: await websocket.close() + async def start(self): + if self.broadcaster is not None: + logger.info("Waiting for successful broadcaster connection") + await retry(wait=wait_fixed(BROADCASTER_CONNECT_RETRY_INTERVAL))( + self.broadcaster.connect + )() + logger.info("Broadcaster connected") + self._wait_for_broadcaster_closed = asyncio.create_task( + self.wait_until_done() + ) + if self.broadcast_keepalive is not None: + self.broadcast_keepalive.start() + + async def stop(self): + stop_tasks = [self._publish_pool.join()] + if self.broadcast_keepalive is not None: + stop_tasks.append(self.broadcast_keepalive.stop()) + if self.broadcaster is not None: + self._wait_for_broadcaster_closed.cancel() + stop_tasks.append(self._wait_for_broadcaster_closed) + + await asyncio.gather(*stop_tasks, return_exceptions=True) + if self.broadcaster is not None: + await self.broadcaster.close() + self.broadcaster = None + + async def wait_until_done(self): + if self.broadcaster is not None: + await self.broadcaster.wait_until_done() + + for callback in self._disconnect_callbacks: + await callback + + async def publish_sync(self, topics: TopicList, data: Any = None): + with tracer.trace("topic_publisher.publish", resource=str(topics)): + await self.endpoint.publish(topics=topics, data=data) + + async def publish(self, topics: TopicList, data: Any = None): + self._publish_pool.add_task(self.publish_sync(topics, data)) + + async def subscribe( + self, + topics: Union[TopicList, ALL_TOPICS], + callback: EventCallback, + ) -> list[Subscription]: + return await self.endpoint.subscribe(topics, callback) + @staticmethod async def _verify_permitted_topics( topics: Union[TopicList, ALL_TOPICS], channel: RpcChannel diff --git a/packages/opal-server/opal_server/scopes/api.py b/packages/opal-server/opal_server/scopes/api.py index 95181866a..1ee3b0173 100644 --- a/packages/opal-server/opal_server/scopes/api.py +++ b/packages/opal-server/opal_server/scopes/api.py @@ -13,7 +13,6 @@ status, ) from fastapi.responses import RedirectResponse -from fastapi_websocket_pubsub import PubSubEndpoint from git import InvalidGitRepositoryError from opal_common.async_utils import run_sync from opal_common.authentication.authz import ( @@ -31,19 +30,17 @@ DataUpdate, ServerDataSourceConfig, ) -from opal_common.schemas.policy import PolicyBundle, PolicyUpdateMessageNotification +from opal_common.schemas.policy import PolicyBundle from opal_common.schemas.policy_source import GitPolicyScopeSource, SSHAuthData from opal_common.schemas.scopes import Scope from opal_common.schemas.security import PeerType -from opal_common.topics.publisher import ( - ScopedServerSideTopicPublisher, - ServerSideTopicPublisher, -) from opal_common.urls import set_url_query_param from opal_server.config import opal_server_config from opal_server.data.data_update_publisher import DataUpdatePublisher from opal_server.git_fetcher import GitPolicyFetcher +from opal_server.pubsub import PubSub from opal_server.scopes.scope_repository import ScopeNotFoundError, ScopeRepository +from opal_server.scopes.scoped_pubsub import ScopedPubSub def verify_private_key(private_key: str, key_format: EncryptionKeyFormat) -> bool: @@ -79,7 +76,7 @@ def verify_private_key_or_throw(scope_in: Scope): def init_scope_router( scopes: ScopeRepository, authenticator: JWTAuthenticator, - pubsub_endpoint: PubSubEndpoint, + pubsub: PubSub, ): router = APIRouter() @@ -117,7 +114,7 @@ async def put_scope( logger.info(f"Sync scope: {scope_in.scope_id}{force_fetch_str}") # All server replicas (leaders) should sync the scope. - await pubsub_endpoint.publish( + await pubsub.publish_sync( opal_server_config.POLICY_REPO_WEBHOOK_TOPIC, {"scope_id": scope_in.scope_id, "force_fetch": force_fetch}, ) @@ -203,7 +200,7 @@ async def refresh_scope( force_fetch = hinted_hash is None # All server replicas (leaders) should sync the scope. - await pubsub_endpoint.publish( + await pubsub.publish_sync( opal_server_config.POLICY_REPO_WEBHOOK_TOPIC, { "scope_id": scope_id, @@ -229,7 +226,7 @@ async def sync_all_scopes(claims: JWTClaims = Depends(authenticator)): raise # All server replicas (leaders) should sync all scopes. - await pubsub_endpoint.publish(opal_server_config.POLICY_REPO_WEBHOOK_TOPIC) + await pubsub.publish_sync(opal_server_config.POLICY_REPO_WEBHOOK_TOPIC) return Response(status_code=status.HTTP_200_OK) @@ -350,7 +347,7 @@ async def publish_data_update_event( entry.topics = [f"data:{topic}" for topic in entry.topics] await DataUpdatePublisher( - ScopedServerSideTopicPublisher(pubsub_endpoint, scope_id) + ScopedPubSub(pubsub, scope_id) ).publish_data_updates(update) except Unauthorized as ex: logger.error(f"Unauthorized to publish update: {repr(ex)}") diff --git a/packages/opal-server/opal_server/scopes/scoped_pubsub.py b/packages/opal-server/opal_server/scopes/scoped_pubsub.py new file mode 100644 index 000000000..ce55e062b --- /dev/null +++ b/packages/opal-server/opal_server/scopes/scoped_pubsub.py @@ -0,0 +1,22 @@ +from typing import Any + +from fastapi_websocket_pubsub import TopicList +from opal_common.logger import logger +from opal_server.pubsub import PubSub + + +class ScopedPubSub: + def __init__(self, pubsub: PubSub, scope_id: str): + self._pubsub = pubsub + self._scope_id = scope_id + + def scope_topics(self, topics: TopicList) -> TopicList: + topics = [f"{self._scope_id}:{topic}" for topic in topics] + logger.debug("Publishing to topics: {topics}", topics=topics) + return topics + + async def publish(self, topics: TopicList, data: Any = None): + await self._pubsub.publish(self.scope_topics(topics), data) + + async def publish_sync(self, topics: TopicList, data: Any = None): + await self._pubsub.publish_sync(self.scope_topics(topics), data) diff --git a/packages/opal-server/opal_server/scopes/service.py b/packages/opal-server/opal_server/scopes/service.py index f0104e7bf..fad2a5b5e 100644 --- a/packages/opal-server/opal_server/scopes/service.py +++ b/packages/opal-server/opal_server/scopes/service.py @@ -6,18 +6,18 @@ import git from ddtrace import tracer -from fastapi_websocket_pubsub import PubSubEndpoint from opal_common.git_utils.commit_viewer import VersionedFile from opal_common.logger import logger from opal_common.schemas.policy import PolicyUpdateMessageNotification from opal_common.schemas.policy_source import GitPolicyScopeSource -from opal_common.topics.publisher import ScopedServerSideTopicPublisher from opal_server.git_fetcher import GitPolicyFetcher, PolicyFetcherCallbacks from opal_server.policy.watcher.callbacks import ( create_policy_update, create_update_all_directories_in_repo, ) +from opal_server.pubsub import PubSub from opal_server.scopes.scope_repository import Scope, ScopeRepository +from opal_server.scopes.scoped_pubsub import ScopedPubSub def is_rego_source_file( @@ -41,12 +41,12 @@ def __init__( base_dir: Path, scope_id: str, source: GitPolicyScopeSource, - pubsub_endpoint: PubSubEndpoint, + pubsub: PubSub, ): self._scope_repo_dir = GitPolicyFetcher.repo_clone_path(base_dir, source) self._scope_id = scope_id self._source = source - self._pubsub_endpoint = pubsub_endpoint + self._pubsub = pubsub async def on_update(self, previous_head: str, head: str): if previous_head == head: @@ -93,10 +93,9 @@ async def trigger_notification(self, notification: PolicyUpdateMessageNotificati logger.info( f"Triggering policy update for scope {self._scope_id}: {notification.dict()}" ) - async with ScopedServerSideTopicPublisher( - self._pubsub_endpoint, self._scope_id - ) as publisher: - await publisher.publish(notification.topics, notification.update) + await ScopedPubSub(self._pubsub, self._scope_id).publish_sync( + notification.topics, notification.update + ) class ScopesService: @@ -104,11 +103,11 @@ def __init__( self, base_dir: Path, scopes: ScopeRepository, - pubsub_endpoint: PubSubEndpoint, + pubsub: Optional[PubSub], ): self._base_dir = base_dir self._scopes = scopes - self._pubsub_endpoint = pubsub_endpoint + self._pubsub = pubsub async def sync_scope( self, @@ -139,7 +138,7 @@ async def sync_scope( base_dir=self._base_dir, scope_id=scope.scope_id, source=source, - pubsub_endpoint=self._pubsub_endpoint, + pubsub=self._pubsub, ) fetcher = GitPolicyFetcher( diff --git a/packages/opal-server/opal_server/scopes/task.py b/packages/opal-server/opal_server/scopes/task.py index 83b2b10f0..868fd0617 100644 --- a/packages/opal-server/opal_server/scopes/task.py +++ b/packages/opal-server/opal_server/scopes/task.py @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs): self._service = ScopesService( base_dir=Path(opal_server_config.BASE_DIR), scopes=ScopeRepository(RedisDB(opal_server_config.REDIS_URL)), - pubsub_endpoint=self._pubsub_endpoint, + pubsub=self._pubsub, ) async def start(self): @@ -78,7 +78,7 @@ def preload_scopes(): service = ScopesService( base_dir=Path(opal_server_config.BASE_DIR), scopes=ScopeRepository(RedisDB(opal_server_config.REDIS_URL)), - pubsub_endpoint=None, + pubsub=None, ) asyncio.run(service.sync_scopes(notify_on_changes=False)) diff --git a/packages/opal-server/opal_server/server.py b/packages/opal-server/opal_server/server.py index 34d9905c3..ee690f189 100644 --- a/packages/opal-server/opal_server/server.py +++ b/packages/opal-server/opal_server/server.py @@ -2,12 +2,9 @@ import os import signal import sys -import traceback -from functools import partial from typing import List, Optional from fastapi import Depends, FastAPI -from fastapi_websocket_pubsub.event_broadcaster import EventBroadcasterContextManager from opal_common.authentication.deps import JWTAuthenticator, StaticBearerAuthenticator from opal_common.authentication.signer import JWTSigner from opal_common.confi.confi import load_conf_if_none @@ -17,11 +14,6 @@ from opal_common.monitoring import apm, metrics from opal_common.schemas.data import ServerDataSourceConfig from opal_common.synchronization.named_lock import NamedLock -from opal_common.topics.publisher import ( - PeriodicPublisher, - ServerSideTopicPublisher, - TopicPublisher, -) from opal_server.config import opal_server_config from opal_server.data.api import init_data_updates_router from opal_server.data.data_update_publisher import DataUpdatePublisher @@ -30,7 +22,6 @@ from opal_server.policy.watcher.factory import setup_watcher_task from opal_server.policy.watcher.task import PolicyWatcherTask from opal_server.policy.webhook.api import init_git_webhook_router -from opal_server.publisher import setup_broadcaster_keepalive_task from opal_server.pubsub import PubSub from opal_server.redis_utils import RedisDB from opal_server.scopes.api import init_scope_router @@ -46,7 +37,6 @@ def __init__( self, init_policy_watcher: bool = None, policy_remote_url: str = None, - init_publisher: bool = None, data_sources_config: Optional[ServerDataSourceConfig] = None, broadcaster_uri: str = None, signer: Optional[JWTSigner] = None, @@ -59,8 +49,6 @@ def __init__( """ Args: policy_remote_url (str, optional): the url of the repo watched by policy watcher. - init_publisher (bool, optional): whether or not to launch a publisher pub/sub client. - this publisher is used by the server processes to publish data to the client. data_sources_config (ServerDataSourceConfig, optional): base data configuration, that opal clients should get the data from. broadcaster_uri (str, optional): Which server/medium should the PubSub use for broadcasting. @@ -84,9 +72,6 @@ def __init__( will update the opal client via pubsub. """ # load defaults - init_publisher: bool = load_conf_if_none( - init_publisher, opal_server_config.PUBLISHER_ENABLED - ) broadcaster_uri: str = load_conf_if_none( broadcaster_uri, opal_server_config.BROADCAST_URI ) @@ -143,39 +128,17 @@ def __init__( else: self.jwks_endpoint = None - self.pubsub = PubSub(signer=self.signer, broadcaster_uri=broadcaster_uri) - - self.publisher: Optional[TopicPublisher] = None - self.broadcast_keepalive: Optional[PeriodicPublisher] = None - if init_publisher: - self.publisher = ServerSideTopicPublisher(self.pubsub.endpoint) - - if ( - opal_server_config.BROADCAST_KEEPALIVE_INTERVAL > 0 - and self.broadcaster_uri is not None - ): - self.broadcast_keepalive = setup_broadcaster_keepalive_task( - self.publisher, - time_interval=opal_server_config.BROADCAST_KEEPALIVE_INTERVAL, - topic=opal_server_config.BROADCAST_KEEPALIVE_TOPIC, - ) + self.pubsub = PubSub( + signer=self.signer, + broadcaster_uri=broadcaster_uri, + disconnect_callback=self._graceful_shutdown(), # TODO: a better approach might be to have each component (e.g statistics) register a callback on broadcaster reconnect that resets its own state (shouldn't rely on state if broadcaster was down) + ) if opal_common_config.STATISTICS_ENABLED: - self.opal_statistics = OpalStatistics(self.pubsub.endpoint) + self.opal_statistics = OpalStatistics(self.pubsub) else: self.opal_statistics = None - # if stats are enabled, the server workers must be listening on the broadcast - # channel for their own synchronization, not just for their clients. therefore - # we need a "global" listening context - self.broadcast_listening_context: Optional[ - EventBroadcasterContextManager - ] = None - if self.broadcaster_uri is not None and opal_common_config.STATISTICS_ENABLED: - self.broadcast_listening_context = ( - self.pubsub.endpoint.broadcaster.get_listening_context() - ) - self.watcher: PolicyWatcherTask = None self.leadership_lock: Optional[NamedLock] = None @@ -184,6 +147,8 @@ def __init__( self._scopes = ScopeRepository(self._redis_db) logger.info("OPAL Scopes: server is connected to scopes repository") + self._leadership_task: Optional[asyncio.Task] = None + # init fastapi app self.app: FastAPI = self._init_fast_api_app() @@ -221,15 +186,15 @@ def _configure_api_routes(self, app: FastAPI): """mounts the api routes on the app object.""" authenticator = JWTAuthenticator(self.signer) - data_update_publisher: Optional[DataUpdatePublisher] = None - if self.publisher is not None: - data_update_publisher = DataUpdatePublisher(self.publisher) + data_update_publisher: Optional[DataUpdatePublisher] = DataUpdatePublisher( + self.pubsub + ) # Init api routers with required dependencies data_updates_router = init_data_updates_router( data_update_publisher, self.data_sources_config, authenticator ) - webhook_router = init_git_webhook_router(self.pubsub.endpoint, authenticator) + webhook_router = init_git_webhook_router(self.pubsub, authenticator) security_router = init_security_router( self.signer, StaticBearerAuthenticator(self.master_token) ) @@ -264,7 +229,7 @@ def _configure_api_routes(self, app: FastAPI): if opal_server_config.SCOPES: app.include_router( - init_scope_router(self._scopes, authenticator, self.pubsub.endpoint), + init_scope_router(self._scopes, authenticator, self.pubsub), tags=["Scopes"], prefix="/scopes", ) @@ -294,12 +259,9 @@ async def startup_event(): logger.info("*** OPAL Server Startup ***") try: - self._task = asyncio.create_task(self.start_server_background_tasks()) - + await self.start_server_background_tasks() except Exception: - logger.critical("Exception while starting OPAL") - traceback.print_exc() - + logger.exception("Exception while starting OPAL") sys.exit(1) @app.on_event("shutdown") @@ -309,6 +271,27 @@ async def shutdown_event(): return app + async def _wait_for_leadership(self): + # We want only one worker to run repo watchers + # (otherwise for each new commit, we will publish multiple updates via pub/sub). + # leadership is determined by the first worker to obtain a lock + self.leadership_lock = NamedLock(opal_server_config.LEADER_LOCK_FILE_PATH) + await self.leadership_lock.acquire() + # only one worker gets here, the others block. in case the leader worker + # is terminated, another one will obtain the lock and become leader. + logger.info( + "leadership lock acquired, leader pid: {pid}", + pid=os.getpid(), + ) + + if opal_server_config.SCOPES: + await load_scopes(self._scopes) + + if self._init_policy_watcher: + # TODO: Should somehow refresh scopes if broadcaster connection was lost (maybe webhook msgs got lost?) + self.watcher = setup_watcher_task(self.pubsub) + await self.watcher.start() + async def start_server_background_tasks(self): """starts the background processes (as asyncio tasks) if such are configured. @@ -319,85 +302,36 @@ async def start_server_background_tasks(self): only the leader worker (first to obtain leadership lock) will start these tasks: - (repo) watcher: monitors the policy git repository for changes. """ - if self.publisher is not None: - async with self.publisher: - if self.opal_statistics is not None: - if self.broadcast_listening_context is not None: - logger.info( - "listening on broadcast channel for statistics events..." - ) - await self.broadcast_listening_context.__aenter__() - # if the broadcast channel is closed, we want to restart worker process because statistics can't be reliable anymore - self.broadcast_listening_context._event_broadcaster.get_reader_task().add_done_callback( - lambda _: self._graceful_shutdown() - ) - asyncio.create_task(self.opal_statistics.run()) - self.pubsub.endpoint.notifier.register_unsubscribe_event( - self.opal_statistics.remove_client - ) - - # We want only one worker to run repo watchers - # (otherwise for each new commit, we will publish multiple updates via pub/sub). - # leadership is determined by the first worker to obtain a lock - self.leadership_lock = NamedLock( - opal_server_config.LEADER_LOCK_FILE_PATH - ) - async with self.leadership_lock: - # only one worker gets here, the others block. in case the leader worker - # is terminated, another one will obtain the lock and become leader. - logger.info( - "leadership lock acquired, leader pid: {pid}", - pid=os.getpid(), - ) - - if opal_server_config.SCOPES: - await load_scopes(self._scopes) - - if self.broadcast_keepalive is not None: - self.broadcast_keepalive.start() - if not self._init_policy_watcher: - # Wait on keepalive instead to keep leadership lock acquired - await self.broadcast_keepalive.wait_until_done() - - if self._init_policy_watcher: - self.watcher = setup_watcher_task( - self.publisher, self.pubsub.endpoint - ) - # running the watcher, and waiting until it stops (until self.watcher.signal_stop() is called) - async with self.watcher: - await self.watcher.wait_until_should_stop() - - # Worker should restart when watcher stops - self._graceful_shutdown() - - if ( - self.opal_statistics is not None - and self.broadcast_listening_context is not None - ): - await self.broadcast_listening_context.__aexit__() - logger.info( - "stopped listening for statistics events on the broadcast channel" - ) + await self.pubsub.start() + + if self.opal_statistics is not None: + await self.opal_statistics.start() + + self._leadership_task = asyncio.create_task(self._wait_for_leadership()) async def stop_server_background_tasks(self): logger.info("stopping background tasks...") tasks: List[asyncio.Task] = [] - if self.watcher is not None: - tasks.append(asyncio.create_task(self.watcher.stop())) - if self.publisher is not None: - tasks.append(asyncio.create_task(self.publisher.stop())) - if self.broadcast_keepalive is not None: - tasks.append(asyncio.create_task(self.broadcast_keepalive.stop())) if self.opal_statistics is not None: tasks.append(asyncio.create_task(self.opal_statistics.stop())) + if self.watcher is not None: + tasks.append(asyncio.create_task(self.watcher.stop())) + if self._leadership_task is not None: + self._leadership_task.cancel() + tasks.append(self._leadership_task) + tasks.append(asyncio.create_task(self.pubsub.stop())) try: await asyncio.gather(*tasks) except Exception: logger.exception("exception while shutting down background tasks") - def _graceful_shutdown(self): + if self.leadership_lock.is_locked: + await self.leadership_lock.release() + + @staticmethod + async def _graceful_shutdown(): logger.info("Trigger worker graceful shutdown") os.kill(os.getpid(), signal.SIGTERM) diff --git a/packages/opal-server/opal_server/statistics.py b/packages/opal-server/opal_server/statistics.py index 14ea97f0a..49e4c49b7 100644 --- a/packages/opal-server/opal_server/statistics.py +++ b/packages/opal-server/opal_server/statistics.py @@ -10,12 +10,11 @@ import pydantic from fastapi import APIRouter, HTTPException, status from fastapi_websocket_pubsub.event_notifier import Subscription, TopicList -from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint from opal_common.async_utils import TasksPool from opal_common.config import opal_common_config from opal_common.logger import get_logger -from opal_common.topics.publisher import PeriodicPublisher from opal_server.config import opal_server_config +from opal_server.pubsub import PubSub from pydantic import BaseModel, Field @@ -75,8 +74,8 @@ class OpalStatistics: The pub/sub server endpoint that allows us to subscribe to the stats channel on the server side """ - def __init__(self, endpoint): - self._endpoint: PubSubEndpoint = endpoint + def __init__(self, pubsub): + self._pubsub: PubSub = pubsub self._uptime = datetime.utcnow() self._workers_count = (lambda envar: int(envar) if envar.isdigit() else 1)( os.environ.get("UVICORN_NUM_WORKERS", "1") @@ -102,7 +101,6 @@ def __init__(self, endpoint): self._lock = asyncio.Lock() self._synced_after_wakeup = asyncio.Event() self._received_sync_messages: Set[str] = set() - self._publish_tasks = TasksPool() self._seen_servers: Dict[str, datetime] = {} self._periodic_keepalive_task: asyncio.Task | None = None @@ -135,7 +133,7 @@ async def _periodic_server_keepalive(self): while True: try: await self._expire_old_servers() - self._publish( + await self._publish( opal_server_config.STATISTICS_SERVER_KEEPALIVE_CHANNEL, ServerKeepalive(worker_id=self._worker_id).dict(), ) @@ -147,33 +145,35 @@ async def _periodic_server_keepalive(self): return except Exception as e: logger.exception("Statistics: periodic server keepalive failed") - logger.exception("Statistics: periodic server keepalive failed") - def _publish(self, channel: str, message: Any): - self._publish_tasks.add_task(self._endpoint.publish([channel], message)) + async def _publish(self, channel: str, message: Any): + await self._pubsub.publish([channel], message) - async def run(self): + async def start(self): """subscribe to two channels to be able to sync add and delete of clients.""" - await self._endpoint.subscribe( + await self._pubsub.subscribe( [opal_server_config.STATISTICS_WAKEUP_CHANNEL], self._receive_other_worker_wakeup_message, ) - await self._endpoint.subscribe( + await self._pubsub.subscribe( [opal_server_config.STATISTICS_STATE_SYNC_CHANNEL], self._receive_other_worker_synced_state, ) - await self._endpoint.subscribe( + await self._pubsub.subscribe( [opal_server_config.STATISTICS_SERVER_KEEPALIVE_CHANNEL], self._receive_other_worker_keepalive_message, ) - await self._endpoint.subscribe( + await self._pubsub.subscribe( [opal_common_config.STATISTICS_ADD_CLIENT_CHANNEL], self._add_client ) - await self._endpoint.subscribe( + await self._pubsub.subscribe( [opal_common_config.STATISTICS_REMOVE_CLIENT_CHANNEL], self._sync_remove_client, ) + self._pubsub.endpoint.notifier.register_unsubscribe_event( + self.remove_client + ) # TODO: Should have a better way to handle this # wait before publishing the wakeup message, due to the fact we are # counting on the broadcaster to listen and to replicate the message @@ -183,7 +183,7 @@ async def run(self): await asyncio.sleep(SLEEP_TIME_FOR_BROADCASTER_READER_TO_START) # Let all the other opal servers know that new opal server started logger.info(f"sending stats wakeup message: {self._worker_id}") - self._publish( + await self._publish( opal_server_config.STATISTICS_WAKEUP_CHANNEL, SyncRequest(requesting_worker_id=self._worker_id).dict(), ) @@ -242,7 +242,7 @@ async def _receive_other_worker_wakeup_message( logger.info( f"[{request.requesting_worker_id}] respond with my own stats" ) - self._publish( + await self._publish( opal_server_config.STATISTICS_STATE_SYNC_CHANNEL, SyncResponse( requesting_worker_id=request.requesting_worker_id, @@ -363,7 +363,7 @@ async def remove_client(self, rpc_id: str, topics: TopicList, publish=True): "Publish rpc_id={rpc_id} to be removed from statistics", rpc_id=rpc_id, ) - self._publish( + await self._publish( opal_common_config.STATISTICS_REMOVE_CLIENT_CHANNEL, rpc_id, ) diff --git a/packages/requires.txt b/packages/requires.txt index 7c586c798..60b068bc9 100644 --- a/packages/requires.txt +++ b/packages/requires.txt @@ -1,7 +1,7 @@ idna>=3.3,<4 typer>=0.4.1,<1 fastapi>=0.109.1,<1 -fastapi_websocket_pubsub==0.3.7 +fastapi_websocket_pubsub==0.3.9 fastapi_websocket_rpc==0.1.27 websockets>=10.3,<14 gunicorn>=22.0.0,<23