Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix OpalServer-EvenBroadcaster integration & other OpalServer refactors #632

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
12 changes: 11 additions & 1 deletion packages/opal-common/opal_common/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,23 @@ def __init__(self):
self._tasks: List[asyncio.Task] = []

def _cleanup_task(self, done_task):
self._tasks.remove(done_task)
try:
self._tasks.remove(done_task)
except KeyError:
...

def add_task(self, f):
t = asyncio.create_task(f)
self._tasks.append(t)
t.add_done_callback(self._cleanup_task)

async def join(self, cancel=False):
if cancel:
for t in self._tasks:
t.cancel()
await asyncio.gather(*self._tasks, return_exceptions=True)
self._tasks.clear()


async def repeated_call(
func: Coroutine,
Expand Down
2 changes: 1 addition & 1 deletion packages/opal-common/opal_common/confi/confi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def wrapped_cast(value, *args, **kwargs):
return wrapped_cast


def load_conf_if_none(variable, conf):
def load_conf_if_none(variable: Any, conf: Any):
if variable is None:
return conf
else:
Expand Down
165 changes: 4 additions & 161 deletions packages/opal-common/opal_common/topics/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Any, Optional, Set

from ddtrace import tracer
from fastapi_websocket_pubsub import PubSubClient, PubSubEndpoint, Topic, TopicList
from fastapi_websocket_pubsub import PubSubEndpoint, Topic, TopicList
from opal_common.async_utils import TasksPool
from opal_common.logger import logger


Expand All @@ -12,8 +13,7 @@ class TopicPublisher:

def __init__(self):
"""inits the publisher's asyncio tasks list."""
self._tasks: Set[asyncio.Task] = set()
self._tasks_lock = asyncio.Lock()
self._pool = TasksPool()

async def publish(self, topics: TopicList, data: Any = None):
raise NotImplementedError()
Expand All @@ -29,95 +29,10 @@ def start(self):
"""starts the publisher."""
logger.debug("started topic publisher")

async def _add_task(self, task: asyncio.Task):
async with self._tasks_lock:
self._tasks.add(task)
task.add_done_callback(self._cleanup_task)

async def wait(self):
async with self._tasks_lock:
await asyncio.gather(*self._tasks, return_exceptions=True)
self._tasks.clear()

async def stop(self):
"""stops the publisher (cancels any running publishing tasks)"""
logger.debug("stopping topic publisher")
await self.wait()

def _cleanup_task(self, task: asyncio.Task):
try:
self._tasks.remove(task)
except KeyError:
...


class PeriodicPublisher:
"""Wrapper for a task that publishes to topic on fixed interval
periodically."""

def __init__(
self,
publisher: TopicPublisher,
time_interval: int,
topic: Topic,
message: Any = None,
task_name: str = "periodic publish task",
):
"""inits the publisher.

Args:
publisher (TopicPublisher): can publish messages on the pub/sub channel
interval (int): the time interval between publishing consecutive messages
topic (Topic): the topic to publish on
message (Any): the message to publish
"""
self._publisher = publisher
self._interval = time_interval
self._topic = topic
self._message = message
self._task_name = task_name
self._task: Optional[asyncio.Task] = None

async def __aenter__(self):
self.start()
return self

async def __aexit__(self, exc_type, exc, tb):
await self.stop()

def start(self):
"""starts the periodic publisher task."""
if self._task is not None:
logger.warning(f"{self._task_name} already started")
return

logger.info(
f"started {self._task_name}: topic is '{self._topic}', interval is {self._interval} seconds"
)
self._task = asyncio.create_task(self._publish_task())

async def stop(self):
"""stops the publisher (cancels any running publishing tasks)"""
if self._task is not None:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
self._task = None
logger.info(f"cancelled {self._task_name} to topic: {self._topic}")

async def wait_until_done(self):
await self._task

async def _publish_task(self):
while True:
await asyncio.sleep(self._interval)
logger.info(
f"{self._task_name}: publishing message on topic '{self._topic}', next publish is scheduled in {self._interval} seconds"
)
async with self._publisher:
await self._publisher.publish(topics=[self._topic], data=self._message)
await self._pool.join()


class ServerSideTopicPublisher(TopicPublisher):
Expand All @@ -132,77 +47,5 @@ def __init__(self, endpoint: PubSubEndpoint):
self._endpoint = endpoint
super().__init__()

async def _publish_impl(self, topics: TopicList, data: Any = None):
with tracer.trace("topic_publisher.publish", resource=str(topics)):
await self._endpoint.publish(topics=topics, data=data)

async def publish(self, topics: TopicList, data: Any = None):
await self._add_task(asyncio.create_task(self._publish_impl(topics, data)))


class ClientSideTopicPublisher(TopicPublisher):
"""A simple wrapper around a PubSubClient that exposes publish().

Provides start() and stop() shortcuts that helps treat this client
as a separate "process" or task that runs in the background.
"""

def __init__(self, client: PubSubClient, server_uri: str):
"""inits the publisher.

Args:
client (PubSubClient): a configured not-yet-started pub sub client
server_uri (str): the URI of the pub sub server we publish to
"""
self._client = client
self._server_uri = server_uri
super().__init__()

def start(self):
"""starts the pub/sub client as a background asyncio task.

the client will attempt to connect to the pubsub server until
successful.
"""
super().start()
self._client.start_client(f"{self._server_uri}")

async def stop(self):
"""stops the pubsub client, and cancels any publishing tasks."""
await self._client.disconnect()
await super().stop()

async def wait_until_done(self):
"""When the publisher is a used as a context manager, this method waits
until the client is done (i.e: terminated) to prevent exiting the
context."""
return await self._client.wait_until_done()

async def publish(self, topics: TopicList, data: Any = None):
"""publish a message by launching a background task on the event loop.

Args:
topics (TopicList): a list of topics to publish the message to
data (Any): optional data to publish as part of the message
"""
await self._add_task(
asyncio.create_task(self._publish(topics=topics, data=data))
)

async def _publish(self, topics: TopicList, data: Any = None) -> bool:
"""Do not trigger directly, must be triggered via publish() in order to
run as a monitored background asyncio task."""
await self._client.wait_until_ready()
logger.info("Publishing to topics: {topics}", topics=topics)
return await self._client.publish(topics, data)


class ScopedServerSideTopicPublisher(ServerSideTopicPublisher):
def __init__(self, endpoint: PubSubEndpoint, scope_id: str):
super().__init__(endpoint)
self._scope_id = scope_id

async def publish(self, topics: TopicList, data: Any = None):
scoped_topics = [f"{self._scope_id}:{topic}" for topic in topics]
logger.info("Publishing to topics: {topics}", topics=scoped_topics)
await super().publish(scoped_topics, data)
18 changes: 6 additions & 12 deletions packages/opal-server/opal_server/data/data_update_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,18 @@
import os
from typing import List

from fastapi_utils.tasks import repeat_every
from opal_common.logger import logger
from opal_common.schemas.data import (
DataSourceEntryWithPollingInterval,
DataUpdate,
ServerDataSourceConfig,
)
from opal_common.topics.publisher import TopicPublisher
from opal_common.schemas.data import DataUpdate
from opal_server.pubsub import PubSub
from opal_server.scopes.scoped_pubsub import ScopedPubSub

TOPIC_DELIMITER = "/"
PREFIX_DELIMITER = ":"


class DataUpdatePublisher:
def __init__(self, publisher: TopicPublisher) -> None:
self._publisher = publisher
def __init__(self, pubsub: PubSub | ScopedPubSub) -> None:
self._pubsub = pubsub

@staticmethod
def get_topic_combos(topic: str) -> List[str]:
Expand Down Expand Up @@ -108,6 +104,4 @@ async def publish_data_updates(self, update: DataUpdate):
entries=logged_entries,
)

await self._publisher.publish(
list(all_topic_combos), update.dict(by_alias=True)
)
await self._pubsub.publish(list(all_topic_combos), update.dict(by_alias=True))
9 changes: 3 additions & 6 deletions packages/opal-server/opal_server/policy/watcher/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
PolicyUpdateMessage,
PolicyUpdateMessageNotification,
)
from opal_common.topics.publisher import TopicPublisher
from opal_common.topics.utils import policy_topics
from opal_server.pubsub import PubSub


async def create_update_all_directories_in_repo(
Expand Down Expand Up @@ -104,7 +104,7 @@ def is_path_affected(path: Path) -> bool:
async def publish_changed_directories(
old_commit: Commit,
new_commit: Commit,
publisher: TopicPublisher,
pubsub: PubSub,
file_extensions: Optional[List[str]] = None,
bundle_ignore: Optional[List[str]] = None,
):
Expand All @@ -116,7 +116,4 @@ async def publish_changed_directories(
)

if notification:
async with publisher:
await publisher.publish(
topics=notification.topics, data=notification.update.dict()
)
await pubsub.publish_sync(notification.topics, notification.update.dict())
38 changes: 18 additions & 20 deletions packages/opal-server/opal_server/policy/watcher/factory.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,39 @@
from functools import partial
from typing import Any, List, Optional
from typing import List, Optional

from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint
from opal_common.confi.confi import load_conf_if_none
from opal_common.git_utils.repo_cloner import RepoClonePathFinder
from opal_common.logger import logger
from opal_common.sources.api_policy_source import ApiPolicySource
from opal_common.sources.git_policy_source import GitPolicySource
from opal_common.topics.publisher import TopicPublisher
from opal_server.config import PolicySourceTypes, opal_server_config
from opal_server.policy.watcher.callbacks import publish_changed_directories
from opal_server.policy.watcher.task import BasePolicyWatcherTask, PolicyWatcherTask
from opal_server.pubsub import PubSub
from opal_server.scopes.task import ScopesPolicyWatcherTask


def setup_watcher_task(
publisher: TopicPublisher,
pubsub_endpoint: PubSubEndpoint,
source_type: str = None,
remote_source_url: str = None,
clone_path_finder: RepoClonePathFinder = None,
branch_name: str = None,
pubsub: PubSub,
source_type: Optional[str] = None,
remote_source_url: Optional[str] = None,
clone_path_finder: Optional[RepoClonePathFinder] = None,
branch_name: Optional[str] = None,
ssh_key: Optional[str] = None,
polling_interval: int = None,
request_timeout: int = None,
policy_bundle_token: str = None,
policy_bundle_token_id: str = None,
policy_bundle_server_type: str = None,
policy_bundle_aws_region: str = None,
polling_interval: Optional[int] = None,
request_timeout: Optional[int] = None,
policy_bundle_token: Optional[str] = None,
policy_bundle_token_id: Optional[str] = None,
policy_bundle_server_type: Optional[str] = None,
policy_bundle_aws_region: Optional[str] = None,
extensions: Optional[List[str]] = None,
bundle_ignore: Optional[List[str]] = None,
) -> BasePolicyWatcherTask:
"""Create a PolicyWatcherTask with Git / API policy source defined by env
vars Load all the defaults from config if called without params.

Args:
publisher(TopicPublisher): server side publisher to publish changes in policy
pubsub(PubSub): server side pubsub client to publish changes in policy
source_type(str): policy source type, can be Git / Api to opa bundle server
remote_source_url(str): the base address to request the policy from
clone_path_finder(RepoClonePathFinder): from which the local dir path for the repo clone would be retrieved
Expand All @@ -46,11 +44,11 @@ def setup_watcher_task(
policy_bundle_token(int): auth token to include in connections to OPAL server. Defaults to POLICY_BUNDLE_SERVER_TOKEN.
policy_bundle_token_id(int): id token to include in connections to OPAL server. Defaults to POLICY_BUNDLE_SERVER_TOKEN_ID.
policy_bundle_server_type (str): type of policy bundle server (HTTP S3). Defaults to POLICY_BUNDLE_SERVER_TYPE
extensions(list(str), optional): list of extantions to check when new policy arrive default is FILTER_FILE_EXTENSIONS
extensions(list(str), optional): list of extensions to check when new policy arrive default is FILTER_FILE_EXTENSIONS
bundle_ignore(list(str), optional): list of glob paths to use for excluding files from bundle default is OPA_BUNDLE_IGNORE
"""
if opal_server_config.SCOPES:
return ScopesPolicyWatcherTask(pubsub_endpoint)
return ScopesPolicyWatcherTask(pubsub)

# load defaults
source_type = load_conf_if_none(source_type, opal_server_config.POLICY_SOURCE_TYPE)
Expand Down Expand Up @@ -135,9 +133,9 @@ def setup_watcher_task(
watcher.add_on_new_policy_callback(
partial(
publish_changed_directories,
publisher=publisher,
pubsub=pubsub,
file_extensions=extensions,
bundle_ignore=bundle_ignore,
)
)
return PolicyWatcherTask(watcher, pubsub_endpoint)
return PolicyWatcherTask(watcher, pubsub)
Loading
Loading