From 7fa313276c5dbe46f038a6c752f08ad77b7fc50e Mon Sep 17 00:00:00 2001 From: Xiliang Chen Date: Tue, 14 May 2024 21:35:00 +1200 Subject: [PATCH] Refactor endpoint (#178) * refactor endpoint and more * update health * update logging * clippy --- src/extensions/api/tests.rs | 51 ++- src/extensions/client/endpoint.rs | 477 +++++++++++++++++------ src/extensions/client/health.rs | 103 +---- src/extensions/client/mod.rs | 87 +++-- src/extensions/client/tests.rs | 56 ++- src/middlewares/methods/block_tag.rs | 9 +- src/middlewares/methods/inject_params.rs | 9 +- src/tests/merge_subscription.rs | 11 +- src/tests/upstream.rs | 11 +- 9 files changed, 544 insertions(+), 270 deletions(-) diff --git a/src/extensions/api/tests.rs b/src/extensions/api/tests.rs index 1c0ff0c..a82b730 100644 --- a/src/extensions/api/tests.rs +++ b/src/extensions/api/tests.rs @@ -1,6 +1,6 @@ use jsonrpsee::server::ServerHandle; use serde_json::json; -use std::{net::SocketAddr, sync::Arc}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use tokio::sync::mpsc; use super::eth::EthApi; @@ -61,7 +61,14 @@ async fn create_client() -> ( ) { let (addr, server, head_rx, finalized_head_rx, block_hash_rx) = create_server().await; - let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); + let client = Client::new( + [format!("ws://{addr}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(); (client, server, head_rx, finalized_head_rx, block_hash_rx) } @@ -168,7 +175,14 @@ async fn rotate_endpoint_on_stale() { let (addr, server, mut head_rx, _, mut block_rx) = create_server().await; let (addr2, server2, mut head_rx2, _, mut block_rx2) = create_server().await; - let client = Client::with_endpoints([format!("ws://{addr}"), format!("ws://{addr2}")]).unwrap(); + let client = Client::new( + [format!("ws://{addr}"), format!("ws://{addr2}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(); let api = SubstrateApi::new(Arc::new(client), std::time::Duration::from_millis(100)); let head = api.get_head(); @@ -231,7 +245,14 @@ async fn rotate_endpoint_on_head_mismatch() { let (addr1, server1, mut head_rx1, mut finalized_head_rx1, mut block_rx1) = create_server().await; let (addr2, server2, mut head_rx2, mut finalized_head_rx2, mut block_rx2) = create_server().await; - let client = Client::with_endpoints([format!("ws://{addr1}"), format!("ws://{addr2}")]).unwrap(); + let client = Client::new( + [format!("ws://{addr1}"), format!("ws://{addr2}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(); let client = Arc::new(client); let api = SubstrateApi::new(client.clone(), std::time::Duration::from_millis(100)); @@ -332,7 +353,16 @@ async fn rotate_endpoint_on_head_mismatch() { #[tokio::test] async fn substrate_background_tasks_abort_on_drop() { let (addr, _server, mut head_rx, mut finalized_head_rx, _) = create_server().await; - let client = Arc::new(Client::with_endpoints([format!("ws://{addr}")]).unwrap()); + let client = Arc::new( + Client::new( + [format!("ws://{addr}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(), + ); let api = SubstrateApi::new(client, std::time::Duration::from_millis(100)); // background tasks started @@ -352,7 +382,16 @@ async fn substrate_background_tasks_abort_on_drop() { #[tokio::test] async fn eth_background_tasks_abort_on_drop() { let (addr, _server, mut subscription_rx, mut block_rx) = create_eth_server().await; - let client = Arc::new(Client::with_endpoints([format!("ws://{addr}")]).unwrap()); + let client = Arc::new( + Client::new( + [format!("ws://{addr}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(), + ); let api = EthApi::new(client, std::time::Duration::from_millis(100)); diff --git a/src/extensions/client/endpoint.rs b/src/extensions/client/endpoint.rs index f266d29..f675253 100644 --- a/src/extensions/client/endpoint.rs +++ b/src/extensions/client/endpoint.rs @@ -1,14 +1,13 @@ -use super::health::{Event, Health}; -use crate::{ - extensions::client::{get_backoff_time, HealthCheckConfig}, - utils::errors, -}; +use super::health::{self, Event, Health}; +use crate::extensions::client::{get_backoff_time, HealthCheckConfig}; use jsonrpsee::{ async_client::Client, core::client::{ClientT, Subscription, SubscriptionClientT}, + core::JsonValue, ws_client::WsClientBuilder, }; use std::{ + fmt::{Debug, Formatter}, sync::{ atomic::{AtomicU32, Ordering}, Arc, @@ -16,12 +15,68 @@ use std::{ time::Duration, }; +enum Message { + Request { + method: String, + params: Vec, + response: tokio::sync::oneshot::Sender>, + timeout: Duration, + }, + Subscribe { + subscribe: String, + params: Vec, + unsubscribe: String, + response: tokio::sync::oneshot::Sender, jsonrpsee::core::Error>>, + timeout: Duration, + }, + Reconnect, +} + +impl Debug for Message { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Message::Request { + method, + params, + response: _, + timeout, + } => write!(f, "Request({method}, {params:?}, _, {timeout:?})"), + Message::Subscribe { + subscribe, + params, + unsubscribe, + response: _, + timeout, + } => write!(f, "Subscribe({subscribe}, {params:?}, {unsubscribe}, _, {timeout:?})"), + Message::Reconnect => write!(f, "Reconnect"), + } + } +} + +enum State { + Initial, + OnError(health::Event), + Connect(Option), + HandleMessage(Arc, Message), + WaitForMessage(Arc), +} + +impl Debug for State { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + State::Initial => write!(f, "Initial"), + State::OnError(e) => write!(f, "OnError({e:?})"), + State::Connect(m) => write!(f, "Connect({m:?})"), + State::HandleMessage(_c, m) => write!(f, "HandleMessage(_, {m:?})"), + State::WaitForMessage(_c) => write!(f, "WaitForMessage(_)"), + } + } +} + pub struct Endpoint { url: String, health: Arc, - client_rx: tokio::sync::watch::Receiver>>, - reconnect_tx: tokio::sync::mpsc::Sender<()>, - on_client_ready: Arc, + message_tx: tokio::sync::mpsc::Sender, background_tasks: Vec>, connect_counter: Arc, } @@ -35,78 +90,279 @@ impl Drop for Endpoint { impl Endpoint { pub fn new( url: String, - request_timeout: Option, - connection_timeout: Option, + request_timeout: Duration, + connection_timeout: Duration, health_config: Option, ) -> Self { - let (client_tx, client_rx) = tokio::sync::watch::channel(None); - let (reconnect_tx, mut reconnect_rx) = tokio::sync::mpsc::channel(1); - let on_client_ready = Arc::new(tokio::sync::Notify::new()); - let health = Arc::new(Health::new(url.clone(), health_config)); + tracing::info!("New endpoint: {url}"); + + let health = Arc::new(Health::new(url.clone())); let connect_counter = Arc::new(AtomicU32::new(0)); + let (message_tx, message_rx) = tokio::sync::mpsc::channel::(4096); - let url_ = url.clone(); - let health_ = health.clone(); - let on_client_ready_ = on_client_ready.clone(); - let connect_counter_ = connect_counter.clone(); + let mut endpoint = Self { + url: url.clone(), + health: health.clone(), + message_tx, + background_tasks: vec![], + connect_counter: connect_counter.clone(), + }; - // This task will try to connect to the endpoint and keep the connection alive - let connection_task = tokio::spawn(async move { + endpoint.start_background_task( + url, + request_timeout, + connection_timeout, + connect_counter, + message_rx, + health, + ); + if let Some(config) = health_config { + endpoint.start_health_monitor_task(config); + } + + endpoint + } + + fn start_background_task( + &mut self, + url: String, + request_timeout: Duration, + connection_timeout: Duration, + connect_counter: Arc, + mut message_rx: tokio::sync::mpsc::Receiver, + health: Arc, + ) { + let handler = tokio::spawn(async move { let connect_backoff_counter = Arc::new(AtomicU32::new(0)); + let mut state = State::Initial; + loop { - tracing::info!("Connecting endpoint: {url_}"); - connect_counter_.fetch_add(1, Ordering::Relaxed); - - let client = WsClientBuilder::default() - .request_timeout(request_timeout.unwrap_or(Duration::from_secs(30))) - .connection_timeout(connection_timeout.unwrap_or(Duration::from_secs(30))) - .max_buffer_capacity_per_subscription(2048) - .max_concurrent_requests(2048) - .max_response_size(20 * 1024 * 1024) - .build(&url_); - - match client.await { - Ok(client) => { - let client = Arc::new(client); - health_.update(Event::ConnectionSuccessful); - _ = client_tx.send(Some(client.clone())); - on_client_ready_.notify_waiters(); - tracing::info!("Endpoint connected: {url_}"); - connect_backoff_counter.store(0, Ordering::Relaxed); + tracing::trace!("{url} {state:?}"); + + let new_state = match state { + State::Initial => { + connect_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + // wait for messages before connecting + let msg = match message_rx.recv().await { + Some(Message::Reconnect) => None, + Some(msg @ Message::Request { .. } | msg @ Message::Subscribe { .. }) => Some(msg), + None => { + let url = url.clone(); + // channel is closed? exit + tracing::debug!("Endpoint {url} channel closed"); + return; + } + }; + State::Connect(msg) + } + State::OnError(evt) => { + health.update(evt); + tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; + State::Initial + } + State::Connect(msg) => { + // TODO: make the params configurable + let client = WsClientBuilder::default() + .request_timeout(request_timeout) + .connection_timeout(connection_timeout) + .max_buffer_capacity_per_subscription(2048) + .max_concurrent_requests(2048) + .max_response_size(20 * 1024 * 1024) + .build(url.clone()) + .await; + + match client { + Ok(client) => { + connect_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); + health.update(Event::ConnectionSuccessful); + if let Some(msg) = msg { + State::HandleMessage(Arc::new(client), msg) + } else { + State::WaitForMessage(Arc::new(client)) + } + } + Err(err) => { + tracing::debug!("Endpoint {url} connection error: {err}"); + State::OnError(health::Event::ConnectionClosed) + } + } + } + State::HandleMessage(client, msg) => match msg { + Message::Request { + method, + params, + response, + timeout, + } => { + // don't block on making the request + let url = url.clone(); + let health = health.clone(); + let client2 = client.clone(); + tokio::spawn(async move { + let resp = match tokio::time::timeout( + timeout, + client2.request::>(&method, params), + ) + .await + { + Ok(resp) => resp, + Err(_) => { + tracing::warn!("Endpoint {url} request timeout: {method} timeout: {timeout:?}"); + health.update(Event::RequestTimeout); + Err(jsonrpsee::core::Error::RequestTimeout) + } + }; + if let Err(err) = &resp { + health.on_error(err); + } + + if response.send(resp).is_err() { + tracing::error!("Unable to send response to message channel"); + } + }); + + State::WaitForMessage(client) + } + Message::Subscribe { + subscribe, + params, + unsubscribe, + response, + timeout, + } => { + // don't block on making the request + let url = url.clone(); + let health = health.clone(); + let client2 = client.clone(); + tokio::spawn(async move { + let resp = match tokio::time::timeout( + timeout, + client2.subscribe::>( + &subscribe, + params, + &unsubscribe, + ), + ) + .await + { + Ok(resp) => resp, + Err(_) => { + tracing::warn!("Endpoint {url} subscription timeout: {subscribe}"); + health.update(Event::RequestTimeout); + Err(jsonrpsee::core::Error::RequestTimeout) + } + }; + if let Err(err) = &resp { + health.on_error(err); + } + + if response.send(resp).is_err() { + tracing::error!("Unable to send response to message channel"); + } + }); + + State::WaitForMessage(client) + } + Message::Reconnect => State::Initial, + }, + State::WaitForMessage(client) => { tokio::select! { - _ = reconnect_rx.recv() => { - tracing::debug!("Endpoint reconnect requested: {url_}"); + msg = message_rx.recv() => { + match msg { + Some(msg) => State::HandleMessage(client, msg), + None => { + // channel is closed? exit + tracing::debug!("Endpoint {url} channel closed"); + return + } + } + }, - _ = client.on_disconnect() => { - tracing::debug!("Endpoint disconnected: {url_}"); + () = client.on_disconnect() => { + tracing::debug!("Endpoint {url} disconnected"); + State::OnError(health::Event::ConnectionClosed) } } } + }; + + state = new_state; + } + }); + + self.background_tasks.push(handler); + } + + fn start_health_monitor_task(&mut self, config: HealthCheckConfig) { + let message_tx = self.message_tx.clone(); + let health = self.health.clone(); + let url = self.url.clone(); + + let handler = tokio::spawn(async move { + let health_response = config.response.clone(); + let interval = Duration::from_secs(config.interval_sec); + let healthy_response_time = Duration::from_millis(config.healthy_response_time_ms); + let max_response_time: Duration = Duration::from_millis(config.healthy_response_time_ms * 2); + + loop { + // Wait for the next interval + tokio::time::sleep(interval).await; + + let request_start = std::time::Instant::now(); + + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let res = message_tx + .send(Message::Request { + method: config.health_method.clone(), + params: vec![], + response: response_tx, + timeout: max_response_time, + }) + .await; + + if let Err(err) = res { + tracing::error!("{url} Unexpected error in message channel: {err}"); + } + + let res = match response_rx.await { + Ok(resp) => resp, Err(err) => { - health_.on_error(&err); - _ = client_tx.send(None); - tracing::warn!("Unable to connect to endpoint: {url_} error: {err}"); + tracing::error!("{url} Unexpected error in response channel: {err}"); + Err(jsonrpsee::core::Error::Custom("Internal server error".into())) + } + }; + + match res { + Ok(response) => { + let duration = request_start.elapsed(); + + // Check response + if let Some(ref health_response) = health_response { + if !health_response.validate(&response) { + health.update(Event::Unhealthy); + continue; + } + } + + // Check response time + if duration > healthy_response_time { + tracing::warn!("{url} response time is too long: {duration:?}"); + health.update(Event::SlowResponse); + continue; + } + + health.update(Event::ResponseOk); + } + Err(err) => { + health.on_error(&err); } } - // Wait a bit before trying to reconnect - tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; } }); - // This task will check the health of the endpoint and update the health score - let health_checker = Health::monitor(health.clone(), client_rx.clone(), on_client_ready.clone()); - - Self { - url, - health, - client_rx, - reconnect_tx, - on_client_ready, - background_tasks: vec![connection_task, health_checker], - connect_counter, - } + self.background_tasks.push(handler); } pub fn url(&self) -> &str { @@ -117,13 +373,6 @@ impl Endpoint { self.health.as_ref() } - pub async fn connected(&self) { - if self.client_rx.borrow().is_some() { - return; - } - self.on_client_ready.notified().await; - } - pub fn connect_counter(&self) -> u32 { self.connect_counter.load(Ordering::Relaxed) } @@ -134,28 +383,26 @@ impl Endpoint { params: Vec, timeout: Duration, ) -> Result { - match tokio::time::timeout(timeout, async { - self.connected().await; - let client = self - .client_rx - .borrow() - .clone() - .ok_or(errors::failed("client not connected"))?; - match client.request(method, params.clone()).await { - Ok(resp) => Ok(resp), - Err(err) => { - self.health.on_error(&err); - Err(err) - } - } - }) - .await - { - Ok(res) => res, - Err(_) => { - tracing::error!("request timed out method: {method} params: {params:?}"); - self.health.on_error(&jsonrpsee::core::Error::RequestTimeout); - Err(jsonrpsee::core::Error::RequestTimeout) + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let res = self + .message_tx + .send(Message::Request { + method: method.into(), + params, + response: response_tx, + timeout, + }) + .await; + + if let Err(err) = res { + tracing::error!("Unexpected error in message channel: {err}"); + } + + match response_rx.await { + Ok(resp) => resp, + Err(err) => { + tracing::error!("Unexpected error in response channel: {err}"); + Err(jsonrpsee::core::Error::Custom("Internal server error".into())) } } } @@ -167,37 +414,35 @@ impl Endpoint { unsubscribe_method: &str, timeout: Duration, ) -> Result, jsonrpsee::core::Error> { - match tokio::time::timeout(timeout, async { - self.connected().await; - let client = self - .client_rx - .borrow() - .clone() - .ok_or(errors::failed("client not connected"))?; - match client - .subscribe(subscribe_method, params.clone(), unsubscribe_method) - .await - { - Ok(resp) => Ok(resp), - Err(err) => { - self.health.on_error(&err); - Err(err) - } - } - }) - .await - { - Ok(res) => res, - Err(_) => { - tracing::error!("subscribe timed out subscribe: {subscribe_method} params: {params:?}"); - self.health.on_error(&jsonrpsee::core::Error::RequestTimeout); - Err(jsonrpsee::core::Error::RequestTimeout) + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let res = self + .message_tx + .send(Message::Subscribe { + subscribe: subscribe_method.into(), + params, + unsubscribe: unsubscribe_method.into(), + response: response_tx, + timeout, + }) + .await; + + if let Err(err) = res { + tracing::error!("Unexpected error in message channel: {err}"); + } + + match response_rx.await { + Ok(resp) => resp, + Err(err) => { + tracing::error!("Unexpected error in response channel: {err}"); + Err(jsonrpsee::core::Error::Custom("Internal server error".into())) } } } pub async fn reconnect(&self) { - // notify the client to reconnect - self.reconnect_tx.send(()).await.unwrap(); + let res = self.message_tx.send(Message::Reconnect).await; + if let Err(err) = res { + tracing::error!("Unexpected error in message channel: {err}"); + } } } diff --git a/src/extensions/client/health.rs b/src/extensions/client/health.rs index 36793c7..d7224f0 100644 --- a/src/extensions/client/health.rs +++ b/src/extensions/client/health.rs @@ -1,31 +1,29 @@ -use crate::extensions::client::HealthCheckConfig; -use jsonrpsee::{async_client::Client, core::client::ClientT}; -use std::{ - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, - }, - time::Duration, -}; +use std::sync::atomic::{AtomicU32, Ordering}; + +const MAX_SCORE: u32 = 100; +const THRESHOLD: u32 = 50; #[derive(Debug)] pub enum Event { ResponseOk, + ConnectionSuccessful, SlowResponse, RequestTimeout, - ConnectionSuccessful, ServerError, - StaleChain, + Unhealthy, + ConnectionClosed, } impl Event { pub fn update_score(&self, current: u32) -> u32 { u32::min( match self { + Event::ConnectionSuccessful => current.saturating_add(60), Event::ResponseOk => current.saturating_add(2), - Event::SlowResponse => current.saturating_sub(5), - Event::RequestTimeout | Event::ServerError | Event::StaleChain => 0, - Event::ConnectionSuccessful => MAX_SCORE / 5 * 4, // 80% of max score + Event::SlowResponse => current.saturating_sub(20), + Event::RequestTimeout => current.saturating_sub(40), + Event::ConnectionClosed => current.saturating_sub(30), + Event::ServerError | Event::Unhealthy => 0, }, MAX_SCORE, ) @@ -35,19 +33,14 @@ impl Event { #[derive(Debug, Default)] pub struct Health { url: String, - config: Option, score: AtomicU32, unhealthy: tokio::sync::Notify, } -const MAX_SCORE: u32 = 100; -const THRESHOLD: u32 = MAX_SCORE / 2; - impl Health { - pub fn new(url: String, config: Option) -> Self { + pub fn new(url: String) -> Self { Self { url, - config, score: AtomicU32::new(0), unhealthy: tokio::sync::Notify::new(), } @@ -65,13 +58,13 @@ impl Health { } self.score.store(new_score, Ordering::Relaxed); tracing::trace!( - "Endpoint {:?} score updated from: {current_score} to: {new_score}", + "{:?} score updated from: {current_score} to: {new_score} because {event:?}", self.url ); // Notify waiters if the score has dropped below the threshold if current_score >= THRESHOLD && new_score < THRESHOLD { - tracing::warn!("Endpoint {:?} became unhealthy", self.url); + tracing::warn!("{:?} became unhealthy", self.url); self.unhealthy.notify_waiters(); } } @@ -82,11 +75,11 @@ impl Health { // NOT SERVER ERROR } jsonrpsee::core::Error::RequestTimeout => { - tracing::warn!("Endpoint {:?} request timeout", self.url); + tracing::warn!("{:?} request timeout", self.url); self.update(Event::RequestTimeout); } _ => { - tracing::warn!("Endpoint {:?} responded with error: {err:?}", self.url); + tracing::warn!("{:?} responded with error: {err:?}", self.url); self.update(Event::ServerError); } }; @@ -96,65 +89,3 @@ impl Health { self.unhealthy.notified().await; } } - -impl Health { - pub fn monitor( - health: Arc, - client_rx_: tokio::sync::watch::Receiver>>, - on_client_ready: Arc, - ) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let config = match health.config { - Some(ref config) => config, - None => return, - }; - - // Wait for the client to be ready before starting the health check - on_client_ready.notified().await; - - let method_name = config.health_method.as_ref().expect("Invalid health config"); - let health_response = config.response.clone(); - let interval = Duration::from_secs(config.interval_sec); - let healthy_response_time = Duration::from_millis(config.healthy_response_time_ms); - - let client = match client_rx_.borrow().clone() { - Some(client) => client, - None => return, - }; - - loop { - // Wait for the next interval - tokio::time::sleep(interval).await; - - let request_start = std::time::Instant::now(); - match client - .request::>(method_name, vec![]) - .await - { - Ok(response) => { - let duration = request_start.elapsed(); - - // Check response - if let Some(ref health_response) = health_response { - if !health_response.validate(&response) { - health.update(Event::StaleChain); - continue; - } - } - - // Check response time - if duration > healthy_response_time { - health.update(Event::SlowResponse); - continue; - } - - health.update(Event::ResponseOk); - } - Err(err) => { - health.on_error(&err); - } - } - } - }) - } -} diff --git a/src/extensions/client/mod.rs b/src/extensions/client/mod.rs index 569d3e8..af0a93b 100644 --- a/src/extensions/client/mod.rs +++ b/src/extensions/client/mod.rs @@ -1,4 +1,5 @@ use std::{ + fmt::{Debug, Formatter}, sync::{atomic::AtomicU32, Arc}, time::Duration, }; @@ -114,23 +115,12 @@ pub struct HealthCheckConfig { pub interval_sec: u64, #[serde(default = "healthy_response_time_ms")] pub healthy_response_time_ms: u64, - pub health_method: Option, + pub health_method: String, pub response: Option, } -impl Default for HealthCheckConfig { - fn default() -> Self { - Self { - interval_sec: interval_sec(), - healthy_response_time_ms: healthy_response_time_ms(), - health_method: None, - response: None, - } - } -} - pub fn interval_sec() -> u64 { - 10 + 300 } pub fn healthy_response_time_ms() -> u64 { @@ -167,7 +157,6 @@ impl HealthResponse { } } -#[derive(Debug)] enum Message { Request { method: String, @@ -185,27 +174,57 @@ enum Message { RotateEndpoint, } +impl Debug for Message { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Message::Request { + method, + params, + response: _, + retries, + } => write!(f, "Request({method}, {params:?}, _, {retries})"), + Message::Subscribe { + subscribe, + params, + unsubscribe, + response: _, + retries, + } => write!(f, "Subscribe({subscribe}, {params:?}, {unsubscribe}, _, {retries})"), + Message::RotateEndpoint => write!(f, "RotateEndpoint"), + } + } +} + #[async_trait] impl Extension for Client { type Config = ClientConfig; async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { let health_check = config.health_check.clone(); - if config.shuffle_endpoints { + let endpoints = if config.shuffle_endpoints { let mut endpoints = config.endpoints.clone(); endpoints.shuffle(&mut thread_rng()); - Ok(Self::new(endpoints, None, None, None, health_check)?) + endpoints } else { - Ok(Self::new(config.endpoints.clone(), None, None, None, health_check)?) - } + config.endpoints.clone() + }; + + // TODO: make the params configurable + Ok(Self::new( + endpoints, + Duration::from_secs(30), + Duration::from_secs(30), + None, + health_check, + )?) } } impl Client { pub fn new( endpoints: impl IntoIterator>, - request_timeout: Option, - connection_timeout: Option, + request_timeout: Duration, + connection_timeout: Duration, retries: Option, health_config: Option, ) -> Result { @@ -252,14 +271,9 @@ impl Client { let select_healtiest = |endpoints: Vec>, current_idx: usize| async move { if endpoints.len() == 1 { let selected_endpoint = endpoints[0].clone(); - // Ensure it's connected - selected_endpoint.connected().await; return (selected_endpoint, 0); } - // wait for at least one endpoint to connect - futures::future::select_all(endpoints.iter().map(|x| x.connected().boxed())).await; - let (idx, endpoint) = endpoints .iter() .enumerate() @@ -282,13 +296,10 @@ impl Client { } }; - let handle_message = |message: Message, endpoint: Arc, rotation_notify: Arc| { + let handle_message = |message: Message, endpoint: Arc| { let tx = message_tx_bg.clone(); let request_backoff_counter = request_backoff_counter.clone(); - // total timeout for a request - let task_timeout = request_timeout.unwrap_or(Duration::from_secs(30)); - tokio::spawn(async move { match message { Message::Request { @@ -304,7 +315,7 @@ impl Client { return; } - match endpoint.request(&method, params.clone(), task_timeout).await { + match endpoint.request(&method, params.clone(), request_timeout).await { result @ Ok(_) => { request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); // make sure it's still connected @@ -320,9 +331,6 @@ impl Client { | Error::Transport(_) | Error::RestartNeeded(_) | Error::MaxSlotsExceeded => { - // Make sure endpoint is rotated - rotation_notify.notified().await; - tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; // make sure it's still connected @@ -367,7 +375,7 @@ impl Client { retries = retries.saturating_sub(1); match endpoint - .subscribe(&subscribe, params.clone(), &unsubscribe, task_timeout) + .subscribe(&subscribe, params.clone(), &unsubscribe, request_timeout) .await { result @ Ok(_) => { @@ -385,9 +393,6 @@ impl Client { | Error::Transport(_) | Error::RestartNeeded(_) | Error::MaxSlotsExceeded => { - // Make sure endpoint is rotated - rotation_notify.notified().await; - tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; // make sure it's still connected @@ -444,8 +449,8 @@ impl Client { tracing::warn!("Switch to endpoint: {new_url}", new_url=new_selected_endpoint.url()); selected_endpoint = new_selected_endpoint; current_endpoint_idx = new_current_endpoint_idx; - rotation_notify_bg.notify_waiters(); } + rotation_notify_bg.notify_waiters(); } message = message_rx.recv() => { tracing::trace!("Received message {message:?}"); @@ -455,7 +460,7 @@ impl Client { (selected_endpoint, current_endpoint_idx) = next_endpoint(current_endpoint_idx).await; rotation_notify_bg.notify_waiters(); } - Some(message) => handle_message(message, selected_endpoint.clone(), rotation_notify_bg.clone()), + Some(message) => handle_message(message, selected_endpoint.clone()), None => { tracing::debug!("Client dropped"); break; @@ -475,10 +480,6 @@ impl Client { }) } - pub fn with_endpoints(endpoints: impl IntoIterator>) -> Result { - Self::new(endpoints, None, None, None, None) - } - pub fn endpoints(&self) -> &Vec> { self.endpoints.as_ref() } diff --git a/src/extensions/client/tests.rs b/src/extensions/client/tests.rs index 4a7f126..cf229a8 100644 --- a/src/extensions/client/tests.rs +++ b/src/extensions/client/tests.rs @@ -11,7 +11,14 @@ use tokio::sync::mpsc; async fn basic_request() { let (addr, handle, mut rx, _) = dummy_server().await; - let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); + let client = Client::new( + [format!("ws://{addr}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(); let task = tokio::spawn(async move { let req = rx.recv().await.unwrap(); @@ -31,7 +38,14 @@ async fn basic_request() { async fn basic_subscription() { let (addr, handle, _, mut rx) = dummy_server().await; - let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); + let client = Client::new( + [format!("ws://{addr}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(); let task = tokio::spawn(async move { let sub = rx.recv().await.unwrap(); @@ -67,10 +81,15 @@ async fn multiple_endpoints() { format!("ws://{addr2}"), format!("ws://{addr3}"), ], + Duration::from_secs(1), + Duration::from_secs(1), None, - None, - None, - Some(Default::default()), + Some(HealthCheckConfig { + interval_sec: 1, + healthy_response_time_ms: 250, + health_method: "mock_rpc".into(), + response: None, + }), ) .unwrap(); @@ -122,7 +141,14 @@ async fn multiple_endpoints() { async fn concurrent_requests() { let (addr, handle, mut rx, _) = dummy_server().await; - let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); + let client = Client::new( + [format!("ws://{addr}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(); let task = tokio::spawn(async move { let req1 = rx.recv().await.unwrap(); @@ -158,8 +184,8 @@ async fn retry_requests_successful() { let client = Client::new( [format!("ws://{addr1}"), format!("ws://{addr2}")], - Some(Duration::from_millis(100)), - None, + Duration::from_millis(100), + Duration::from_millis(100), Some(2), None, ) @@ -196,8 +222,8 @@ async fn retry_requests_out_of_retries() { let client = Client::new( [format!("ws://{addr1}"), format!("ws://{addr2}")], - Some(Duration::from_millis(100)), - None, + Duration::from_millis(100), + Duration::from_millis(100), Some(2), None, ) @@ -260,13 +286,13 @@ async fn health_check_works() { let client = Client::new( [format!("ws://{addr1}"), format!("ws://{addr2}")], - None, - None, + Duration::from_secs(1), + Duration::from_secs(1), None, Some(HealthCheckConfig { interval_sec: 1, healthy_response_time_ms: 250, - health_method: Some("system_health".into()), + health_method: "system_health".into(), response: Some(HealthResponse::Contains(vec![( "isSyncing".to_string(), Box::new(HealthResponse::Eq(false.into())), @@ -307,8 +333,8 @@ async fn reconnect_on_disconnect() { let client = Client::new( [format!("ws://{addr1}"), format!("ws://{addr2}")], - Some(Duration::from_millis(100)), - None, + Duration::from_millis(100), + Duration::from_millis(100), Some(2), None, ) diff --git a/src/middlewares/methods/block_tag.rs b/src/middlewares/methods/block_tag.rs index ed8e657..2f28e26 100644 --- a/src/middlewares/methods/block_tag.rs +++ b/src/middlewares/methods/block_tag.rs @@ -167,7 +167,14 @@ mod tests { let (addr, _server) = builder.build().await; - let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); + let client = Client::new( + [format!("ws://{addr}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(); let api = EthApi::new(Arc::new(client), Duration::from_secs(100)); ( diff --git a/src/middlewares/methods/inject_params.rs b/src/middlewares/methods/inject_params.rs index 77feaed..a23d635 100644 --- a/src/middlewares/methods/inject_params.rs +++ b/src/middlewares/methods/inject_params.rs @@ -213,7 +213,14 @@ mod tests { let (addr, _server) = builder.build().await; - let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); + let client = Client::new( + [format!("ws://{addr}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(); let api = SubstrateApi::new(Arc::new(client), Duration::from_secs(100)); ExecutionContext { diff --git a/src/tests/merge_subscription.rs b/src/tests/merge_subscription.rs index 37a7c53..89054b0 100644 --- a/src/tests/merge_subscription.rs +++ b/src/tests/merge_subscription.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use serde_json::json; use crate::{ @@ -97,7 +99,14 @@ async fn merge_subscription_works() { let subway_server = server::build(config).await.unwrap(); let addr = subway_server.addr; - let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); + let client = Client::new( + [format!("ws://{addr}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(); let mut first_sub = client .subscribe(subscribe_mock, vec![], unsubscribe_mock) .await diff --git a/src/tests/upstream.rs b/src/tests/upstream.rs index 9c730da..ab63169 100644 --- a/src/tests/upstream.rs +++ b/src/tests/upstream.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use crate::{ config::{Config, MergeStrategy, MiddlewaresConfig, RpcDefinitions, RpcSubscription}, extensions::{ @@ -73,7 +75,14 @@ async fn upstream_error_propagate() { let subway_server = server::build(config).await.unwrap(); let addr = subway_server.addr; - let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); + let client = Client::new( + [format!("ws://{addr}")], + Duration::from_secs(1), + Duration::from_secs(1), + None, + None, + ) + .unwrap(); let result = client.subscribe(subscribe_mock, vec![], unsubscribe_mock).await; assert!(result