Skip to content

Commit

Permalink
Forward existing Authencation headers
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Jan 10, 2025
1 parent f3767e1 commit e9406e5
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/blueapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class TiledConfig(BaseModel):
Config for connecting to a tiled instance
"""

uri: str
api_key: str
host: str
port: int


class WorkerEventConfig(BlueapiBaseModel):
Expand Down
23 changes: 5 additions & 18 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
from functools import cache
from typing import Any

from bluesky.callbacks.tiled_writer import TiledWriter
from bluesky_stomp.messaging import StompClient
from bluesky_stomp.models import Broker, DestinationBase, MessageTopic
from tiled.client import from_uri

from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig, TiledConfig
from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig
from blueapi.core.context import BlueskyContext
from blueapi.core.event import EventStream
from blueapi.service.model import DeviceModel, PlanModel, WorkerTask
from blueapi.worker.event import TaskStatusEnum, WorkerState
from blueapi.worker.task import Task
from blueapi.worker.task_worker import TaskWorker, TrackableTask
from blueapi.worker.tiled import TiledConnection

"""This module provides interface between web application and underlying Bluesky
context and worker"""
Expand Down Expand Up @@ -42,27 +41,16 @@ def context() -> BlueskyContext:

@cache
def worker() -> TaskWorker:
conf = config()
worker = TaskWorker(
context(),
broadcast_statuses=config().env.events.broadcast_status_events,
broadcast_statuses=conf.env.events.broadcast_status_events,
tiled_inserter=TiledConnection(conf.tiled) if conf.tiled else None,
)
worker.start()
return worker


@cache
def tiled_inserter():
tiled_config: TiledConfig | None = config().tiled
if tiled_config is not None:
client = from_uri(tiled_config.uri, api_key=tiled_config.api_key)

ctx = context()
ctx.run_engine.subscribe(TiledWriter(client))
return client
else:
return None


@cache
def stomp_client() -> StompClient | None:
stomp_config: StompConfig | None = config().stomp
Expand Down Expand Up @@ -101,7 +89,6 @@ def setup(config: ApplicationConfig) -> None:
logging.basicConfig(format="%(asctime)s - %(message)s", level=config.logging.level)
worker()
stomp_client()
tiled_inserter()


def teardown() -> None:
Expand Down
31 changes: 25 additions & 6 deletions src/blueapi/worker/task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Generic, TypeVar

from bluesky.protocols import Status
from httpx import Headers
from observability_utils.tracing import (
add_span_attributes,
get_tracer,
Expand All @@ -32,6 +33,7 @@
from blueapi.core.bluesky_event_loop import configure_bluesky_event_loop
from blueapi.utils.base_model import BlueapiBaseModel
from blueapi.utils.thread_exception import handle_all_exceptions
from blueapi.worker.tiled import TiledConnection

from .event import (
ProgressEvent,
Expand Down Expand Up @@ -112,9 +114,11 @@ def __init__(
ctx: BlueskyContext,
start_stop_timeout: float = DEFAULT_START_STOP_TIMEOUT,
broadcast_statuses: bool = True,
tiled_inserter: TiledConnection | None = None,
) -> None:
self._ctx = ctx
self._start_stop_timeout = start_stop_timeout
self._tiled_inserter = tiled_inserter

self._tasks = {}

Expand Down Expand Up @@ -194,13 +198,25 @@ def get_active_task(self) -> TrackableTask[Task] | None:
return current

@start_as_current_span(TRACER, "task_id")
def begin_task(self, task_id: str) -> None:
def begin_task(self, task_id: str, headers: Headers | None) -> None:
task = self._tasks.get(task_id)
data_subs: list[int] = []
if task is not None:
self._submit_trackable_task(task)
if self._tiled_inserter:
data_subs.append(self._authorize_running_task(headers))
self._submit_trackable_task(task, data_subs)

else:
raise KeyError(f"No pending task with ID {task_id}")

def _authorize_running_task(self, headers: Headers | None) -> int:
assert self._tiled_inserter
# https://github.com/DiamondLightSource/blueapi/issues/774
# If users should only be able to run their own scans, pass headers
# as part of submitting a task, cache in TrackableTask field and check
# that token belongs to same user (but may be newer token!)
return self.data_events.subscribe(self._tiled_inserter(headers))

@start_as_current_span(TRACER, "task.name", "task.params")
def submit_task(self, task: Task) -> str:
task.prepare_params(self._ctx) # Will raise if parameters are invalid
Expand All @@ -218,7 +234,9 @@ def submit_task(self, task: Task) -> str:
"trackable_task.task.name",
"trackable_task.task.params",
)
def _submit_trackable_task(self, trackable_task: TrackableTask) -> None:
def _submit_trackable_task(
self, trackable_task: TrackableTask, data_subs: list[int] | None = None
) -> None:
if self.state is not WorkerState.IDLE:
raise WorkerBusyError(f"Worker is in state {self.state}")

Expand All @@ -235,17 +253,18 @@ def mark_task_as_started(event: WorkerEvent, _: str | None) -> None:
sub = self.worker_events.subscribe(mark_task_as_started)
try:
self._current_task_otel_context = get_current()
sub = self.worker_events.subscribe(mark_task_as_started)
""" Cache the current trace context as the one for this task id """
self._task_channel.put_nowait(trackable_task)
task_started.wait(timeout=5.0)
if not task_started.is_set():
if not task_started.wait(timeout=5.0):
raise TimeoutError("Failed to start plan within timeout")
except Full as f:
LOGGER.error("Cannot submit task while another is running")
raise WorkerBusyError("Cannot submit task while another is running") from f
finally:
self.worker_events.unsubscribe(sub)
if data_subs:
for data_sub in data_subs:
self.data_events.unsubscribe(data_sub)

@start_as_current_span(TRACER)
def start(self) -> None:
Expand Down
23 changes: 23 additions & 0 deletions src/blueapi/worker/tiled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from bluesky.callbacks.tiled_writer import TiledWriter
from httpx import Headers
from tiled.client import from_context
from tiled.client.context import Context as TiledContext

from blueapi.config import TiledConfig
from blueapi.core.bluesky_types import DataEvent


class TiledConverter:
def __init__(self, tiled_context: TiledContext):
self._writer: TiledWriter = TiledWriter(from_context(tiled_context))

def __call__(self, data: DataEvent, _: str | None = None) -> None:
self._writer(data.name, data.doc)


class TiledConnection:
def __init__(self, config: TiledConfig):
self.uri = f"{config.host}:{config.port}"

def __call__(self, headers: Headers | None):
return TiledConverter(TiledContext(self.uri, headers=headers))

0 comments on commit e9406e5

Please sign in to comment.