Skip to content

Commit

Permalink
feat: Change health connection to SSE
Browse files Browse the repository at this point in the history
  • Loading branch information
Threated committed Jan 21, 2025
1 parent 607d7bc commit 8cd208c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 40 deletions.
36 changes: 22 additions & 14 deletions broker/src/serve_health.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::{sync::Arc, time::{Duration, SystemTime}};
use std::{convert::Infallible, marker::PhantomData, sync::Arc, time::{Duration, SystemTime}};

use axum::{extract::{State, Path}, http::StatusCode, routing::get, Json, Router, response::Response};
use axum::{extract::{Path, State}, http::StatusCode, response::{sse::{Event, KeepAlive}, Response, Sse}, routing::get, Json, Router};
use axum_extra::{headers::{authorization::Basic, Authorization}, TypedHeader};
use beam_lib::ProxyId;
use futures_core::Stream;
use serde::{Serialize, Deserialize};
use shared::{crypto_jwt::Authorized, Msg, config::CONFIG_CENTRAL};
use tokio::sync::RwLock;
Expand Down Expand Up @@ -46,7 +47,7 @@ async fn handler(
}

async fn get_all_proxies(State(state): State<Arc<RwLock<Health>>>) -> Json<Vec<ProxyId>> {
Json(state.read().await.proxies.keys().cloned().collect())
Json(state.read().await.proxies.iter().filter(|(_, v)| v.online()).map(|(k, _)| k).cloned().collect())
}

async fn proxy_health(
Expand Down Expand Up @@ -76,25 +77,32 @@ async fn proxy_health(
async fn get_control_tasks(
State(state): State<Arc<RwLock<Health>>>,
proxy_auth: Authorized,
) -> StatusCode {
) -> Sse<ForeverStream> {
let proxy_id = proxy_auth.get_from().proxy_id();
// Once this is freed the connection will be removed from the map of connected proxies again
// This ensures that when the connection is dropped and therefore this response future the status of this proxy will be updated
let _connection_remover = ConnectedGuard::connect(&proxy_id, &state).await;
let connect_guard = ConnectedGuard::connect(proxy_id, state).await;

// In the future, this will wait for control tasks for the given proxy
tokio::time::sleep(Duration::from_secs(60 * 60)).await;
Sse::new(ForeverStream(connect_guard)).keep_alive(KeepAlive::new())
}

struct ForeverStream(#[allow(dead_code)] ConnectedGuard);

StatusCode::OK
impl Stream for ForeverStream {
type Item = Result<Event, Infallible>;

fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
std::task::Poll::Pending
}
}

struct ConnectedGuard<'a> {
proxy: &'a ProxyId,
state: &'a Arc<RwLock<Health>>
struct ConnectedGuard {
proxy: ProxyId,
state: Arc<RwLock<Health>>
}

impl<'a> ConnectedGuard<'a> {
async fn connect(proxy: &'a ProxyId, state: &'a Arc<RwLock<Health>>) -> ConnectedGuard<'a> {
impl ConnectedGuard {
async fn connect(proxy: ProxyId, state: Arc<RwLock<Health>>) -> ConnectedGuard {
{
state.write().await.proxies
.entry(proxy.clone())
Expand All @@ -105,7 +113,7 @@ impl<'a> ConnectedGuard<'a> {
}
}

impl<'a> Drop for ConnectedGuard<'a> {
impl Drop for ConnectedGuard {
fn drop(&mut self) {
let proxy_id = self.proxy.clone();
let map = self.state.clone();
Expand Down
65 changes: 39 additions & 26 deletions proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::time::Duration;
use axum::http::{header, HeaderValue, StatusCode};
use beam_lib::AppOrProxyId;
use futures::future::Ready;
use futures::{StreamExt, TryStreamExt};
use shared::{reqwest, EncryptedMessage, MsgEmpty, PlainMessage};
use shared::crypto::CryptoPublicPortion;
use shared::errors::SamplyBeamError;
Expand Down Expand Up @@ -132,8 +133,12 @@ fn spawn_controller_polling(client: SamplyHttpClient, config: Config) {
const RETRY_INTERVAL: Duration = Duration::from_secs(60);
tokio::spawn(async move {
let mut retries_this_min = 0;
let mut reset_interval = std::pin::pin!(tokio::time::sleep(Duration::from_secs(60)));
let mut reset_interval = Instant::now() + RETRY_INTERVAL;
loop {
if reset_interval < Instant::now() {
retries_this_min = 0;
reset_interval = Instant::now() + RETRY_INTERVAL;
}
let body = EncryptedMessage::MsgEmpty(MsgEmpty {
from: AppOrProxyId::Proxy(config.proxy_id.clone()),
});
Expand All @@ -145,39 +150,47 @@ fn spawn_controller_polling(client: SamplyHttpClient, config: Config) {

let req = sign_request(body, parts, &config, None).await.expect("Unable to sign request; this should always work");
// In the future this will poll actual control related tasks
match client.execute(req).await {
Ok(res) => {
match res.status() {
StatusCode::OK => {
// Process control task
},
status @ (StatusCode::GATEWAY_TIMEOUT | StatusCode::BAD_GATEWAY) => {
if retries_this_min < 10 {
retries_this_min += 1;
debug!("Connection to broker timed out; retrying.");
} else {
warn!("Retried more then 10 times in one minute getting status code: {status}");
tokio::time::sleep(RETRY_INTERVAL).await;
continue;
}
},
other => {
warn!("Got unexpected status getting control tasks from broker: {other}");
tokio::time::sleep(RETRY_INTERVAL).await;
}
};
},
let res = match client.execute(req).await {
Ok(res) if res.status() != StatusCode::OK => {
if retries_this_min < 10 {
retries_this_min += 1;
debug!("Unexpected status code getting control tasks from broker: {}", res.status());
} else {
warn!("Retried more then 10 times in one minute getting status code: {}", res.status());
tokio::time::sleep(RETRY_INTERVAL).await;
}
continue;
}
Ok(res) => res,
Err(e) if e.is_timeout() => {
debug!("Connection to broker timed out; retrying: {e}");
continue;
},
Err(e) => {
warn!("Error getting control tasks from broker; retrying in {}s: {e}", RETRY_INTERVAL.as_secs());
tokio::time::sleep(RETRY_INTERVAL).await;
continue;
}
};
if reset_interval.is_elapsed() {
retries_this_min = 0;
reset_interval.as_mut().reset(Instant::now() + Duration::from_secs(60));
let incoming = res
.bytes_stream()
.map(|result| result.map_err(|error| {
let kind = error.is_timeout().then_some(std::io::ErrorKind::TimedOut).unwrap_or(std::io::ErrorKind::Other);
std::io::Error::new(kind, format!("IO Error: {error}"))
}))
.into_async_read();
let mut reader = async_sse::decode(incoming);
while let Some(ev) = reader.next().await {
match ev {
Ok(_)=> (),
Err(e) if e.downcast_ref::<std::io::Error>().unwrap().kind() == std::io::ErrorKind::TimedOut => {
debug!("SSE connection timed out");
break;
},
Err(err) => {
error!("Got error reading SSE stream: {err}");
}
};
}
}
});
Expand Down

0 comments on commit 8cd208c

Please sign in to comment.