diff --git a/.gitignore b/.gitignore index 547b3dc1..21b8fe2b 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ # IDE stuff .idea/ +.vscode/ # Allow developers to use python pre-commit locally /.pre-commit-config.yaml diff --git a/Cargo.lock b/Cargo.lock index d33d90ed..7872515a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4859,7 +4859,7 @@ version = "0.0.0" source = "git+https://github.com/worldcoin/orb-messages?rev=787ab78581b705af0946bcfe3a0453b64af2193f#787ab78581b705af0946bcfe3a0453b64af2193f" dependencies = [ "prost 0.12.6", - "prost-build", + "prost-build 0.12.6", "thiserror 1.0.65", ] @@ -4869,7 +4869,7 @@ version = "0.0.0" source = "git+https://github.com/worldcoin/orb-messages?rev=c439077c7c1bc3a8eb6f224c32b5b4d60d094809#c439077c7c1bc3a8eb6f224c32b5b4d60d094809" dependencies = [ "prost 0.12.6", - "prost-build", + "prost-build 0.12.6", "thiserror 1.0.65", ] @@ -4884,6 +4884,38 @@ dependencies = [ "uuid 1.11.0", ] +[[package]] +name = "orb-relay-client" +version = "0.1.0" +dependencies = [ + "clap", + "eyre", + "orb-relay-messages", + "orb-security-utils", + "orb-telemetry", + "rand", + "secrecy", + "serde_json", + "sha2", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "orb-relay-messages" +version = "0.0.0" +source = "git+https://github.com/worldcoin/orb-relay-messages.git?rev=f1c73751200ea9df7f1712ec203c7882f30f60f4#f1c73751200ea9df7f1712ec203c7882f30f60f4" +dependencies = [ + "prost 0.13.3", + "prost-build 0.13.3", + "prost-types 0.13.3", + "tonic", + "tonic-build", +] + [[package]] name = "orb-rgb" version = "0.0.0" @@ -5107,7 +5139,7 @@ dependencies = [ "orb-telemetry", "polling 2.5.2", "prost 0.12.6", - "prost-build", + "prost-build 0.12.6", "reqwest 0.12.9", "semver", "serde", @@ -5732,6 +5764,27 @@ dependencies = [ "tempfile", ] +[[package]] +name = "prost-build" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" +dependencies = [ + "bytes", + "heck 0.4.1", + "itertools 0.10.5", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost 0.13.3", + "prost-types 0.13.3", + "regex", + "syn 2.0.90", + "tempfile", +] + [[package]] name = "prost-derive" version = "0.12.6" @@ -7713,8 +7766,10 @@ dependencies = [ "percent-encoding", "pin-project", "prost 0.13.3", + "rustls-pemfile 2.2.0", "socket2", "tokio", + "tokio-rustls 0.26.1", "tokio-stream", "tower", "tower-layer", @@ -7722,6 +7777,20 @@ dependencies = [ "tracing", ] +[[package]] +name = "tonic-build" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9557ce109ea773b399c9b9e5dca39294110b74f1f342cb347a80d1fce8c26a11" +dependencies = [ + "prettyplease", + "proc-macro2", + "prost-build 0.13.3", + "prost-types 0.13.3", + "quote", + "syn 2.0.90", +] + [[package]] name = "tower" version = "0.4.13" diff --git a/Cargo.toml b/Cargo.toml index 4e20b3c1..2b51c6bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ members = [ "mcu-interface", "mcu-util", "qr-link", + "relay-client", "security-utils", "seek-camera/sys", "seek-camera/wrapper", @@ -95,6 +96,7 @@ orb-build-info.path = "build-info" orb-const-concat.path = "const-concat" orb-header-parsing.path = "header-parsing" orb-mcu-interface.path = "mcu-interface" +orb-relay-client.path = "relay-client" orb-security-utils.path = "security-utils" orb-slot-ctrl.path = "slot-ctrl" orb-telemetry.path = "telemetry" @@ -106,6 +108,11 @@ 
seek-camera.path = "seek-camera/wrapper" git = "https://github.com/worldcoin/orb-messages" rev = "787ab78581b705af0946bcfe3a0453b64af2193f" +[workspace.dependencies.orb-relay-messages] +git = "https://github.com/worldcoin/orb-relay-messages.git" +rev = "f1c73751200ea9df7f1712ec203c7882f30f60f4" +features = ["client"] + [workspace.dependencies.nusb] git = "https://github.com/kevinmehall/nusb" rev = "3ec3508324cdd01ca288b91ddcb2f92fd6a6f813" diff --git a/relay-client/Cargo.toml b/relay-client/Cargo.toml new file mode 100644 index 00000000..7ef38395 --- /dev/null +++ b/relay-client/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "orb-relay-client" +version = "0.1.0" +publish = false +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +clap = { version = "4", features = ["derive"] } +eyre.workspace = true +orb-relay-messages.workspace = true +orb-security-utils = { workspace = true, features = ["reqwest"] } +orb-telemetry.workspace = true +rand = "0.8" +serde_json.workspace = true +secrecy.workspace = true +sha2 = "0.10" +tokio-stream.workspace = true +tokio-util.workspace = true +tokio.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true diff --git a/relay-client/src/bin/decode-msg.rs b/relay-client/src/bin/decode-msg.rs new file mode 100644 index 00000000..5a46b305 --- /dev/null +++ b/relay-client/src/bin/decode-msg.rs @@ -0,0 +1,39 @@ +use clap::Parser; +use eyre::{eyre, Result}; +use orb_relay_client::debug_any; +use orb_relay_messages::prost_types::Any; +use serde_json::Value; + +#[derive(Parser, Debug)] +struct Args { + #[arg()] + json: String, +} + +fn main() -> Result<()> { + let args = Args::parse(); + println!("{}", decode_payload(&args.json)?); + Ok(()) +} + +fn decode_payload(json: &str) -> Result { + println!("json: {}", json); + let v: Value = serde_json::from_str(json)?; + let any = Any { + type_url: v["type_url"] + .as_str() + .ok_or_else(|| eyre!("Invalid type_url"))? + .to_string(), + value: v["value"] + .as_array() + .ok_or_else(|| eyre!("Invalid value"))? 
+            .iter()
+            .map(|n| {
+                n.as_u64()
+                    .ok_or_else(|| eyre!("Invalid number"))
+                    .map(|n| n as u8)
+            })
+            .collect::<Result<Vec<u8>>>()?,
+    };
+    Ok(debug_any(&Some(any)))
+}
diff --git a/relay-client/src/bin/manual-test.rs b/relay-client/src/bin/manual-test.rs
new file mode 100644
index 00000000..8cf7fe79
--- /dev/null
+++ b/relay-client/src/bin/manual-test.rs
@@ -0,0 +1,720 @@
+use clap::Parser;
+use eyre::{Ok, Result};
+use orb_relay_client::{client::Client, debug_any, PayloadMatcher};
+use orb_relay_messages::{common, self_serve};
+use rand::{distributions::Alphanumeric, Rng};
+use std::{
+    env,
+    sync::LazyLock,
+    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
+};
+
+static BACKEND_URL: LazyLock<String> = LazyLock::new(|| {
+    let backend =
+        env::var("RELAY_TOOL_BACKEND").unwrap_or_else(|_| "stage".to_string());
+    match backend.as_str() {
+        "stage" => "https://relay.stage.orb.worldcoin.org",
+        "prod" => "https://relay.orb.worldcoin.org",
+        "local" => "http://127.0.0.1:8443",
+        _ => panic!("Invalid backend option"),
+    }
+    .to_string()
+});
+static APP_KEY: LazyLock<String> =
+    LazyLock::new(|| env::var("RELAY_TOOL_APP_KEY").unwrap_or_default());
+static ORB_KEY: LazyLock<String> =
+    LazyLock::new(|| env::var("RELAY_TOOL_ORB_KEY").unwrap_or_default());
+
+static ORB_ID: LazyLock<String> =
+    LazyLock::new(|| env::var("RELAY_TOOL_ORB_ID").unwrap_or_default());
+static SESSION_ID: LazyLock<String> =
+    LazyLock::new(|| env::var("RELAY_TOOL_SESSION_ID").unwrap_or_default());
+
+#[derive(Parser, Debug)]
+#[clap(author, version, about, long_about = None)]
+struct Args {
+    /// Run only the stage_consumer_app function
+    #[clap(short = 'c', long = "consume-only")]
+    consume_only: bool,
+    /// Run only the stage_producer_orb function
+    #[clap(short = 'p', long = "produce-only")]
+    produce_only: bool,
+    #[clap(short = 's', long = "start-orb-signup")]
+    start_orb_signup: bool,
+    #[clap(short = 'w', long = "slow-tests")]
+    slow_tests: bool,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    orb_telemetry::TelemetryConfig::new()
+        .with_journald("worldcoin-relay-client")
+        .init();
+
+    let args = Args::parse();
+
+    if args.consume_only {
+        stage_consumer_app().await?;
+    } else if args.start_orb_signup {
+        stage_producer_from_app_start_orb_signup().await?;
+    } else if args.produce_only {
+        stage_producer_orb().await?;
+    } else {
+        app_to_orb().await?;
+        orb_to_app().await?;
+        orb_to_app_with_state_request().await?;
+        orb_to_app_blocking_send().await?;
+        if args.slow_tests {
+            orb_to_app_with_clients_created_later_and_delay().await?;
+        }
+    }
+
+    Ok(())
+}
+
+async fn app_to_orb() -> Result<()> {
+    tracing::info!("== Running App to Orb ==");
+    let (orb_id, session_id) = get_ids();
+
+    let mut app_client = Client::new_as_app(
+        BACKEND_URL.to_string(),
+        APP_KEY.to_string(),
+        session_id.to_string(),
+        orb_id.to_string(),
+    );
+    let now = Instant::now();
+    app_client.connect().await?;
+    tracing::info!("Time took to app_connect: {}ms", now.elapsed().as_millis());
+
+    let mut orb_client = Client::new_as_orb(
+        BACKEND_URL.to_string(),
+        ORB_KEY.to_string(),
+        orb_id.to_string(),
+        session_id.to_string(),
+    );
+    let now = Instant::now();
+    orb_client.connect().await?;
+    tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis());
+
+    let now = Instant::now();
+    let time_now = time_now()?;
+    tracing::info!("Sending time now: {}", time_now);
+    app_client
+        .send(common::v1::AnnounceOrbId {
+            orb_id: time_now.clone(),
+            mode_type: common::v1::announce_orb_id::ModeType::SelfServe.into(),
+            hardware_type:
common::v1::announce_orb_id::HardwareType::Diamond.into(), + }) + .await?; + tracing::info!( + "Time took to send a message from the app: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + 'ext: loop { + #[expect(clippy::never_loop)] + for msg in orb_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + if let Some(common::v1::AnnounceOrbId { orb_id, .. }) = + common::v1::AnnounceOrbId::matches(msg.payload.as_ref().unwrap()) + { + assert!( + orb_id == time_now, + "Received orb_id is not the same as sent orb_id" + ); + break 'ext; + } + unreachable!("Received unexpected message: {msg:?}"); + } + } + tracing::info!( + "Time took to receive a message: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + app_client + .send(self_serve::orb::v1::SignupEnded { + success: true, + failure_feedback: [].to_vec(), + }) + .await?; + tracing::info!( + "Time took to send a second message: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + 'ext: loop { + #[expect(clippy::never_loop)] + for msg in orb_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + if let Some(self_serve::orb::v1::SignupEnded { success, .. }) = + self_serve::orb::v1::SignupEnded::matches(msg.payload.as_ref().unwrap()) + { + assert!(success, "Received: success is not true"); + break 'ext; + } + unreachable!("Received unexpected message: {msg:?}"); + } + } + tracing::info!( + "Time took to receive a second message: {}ms", + now.elapsed().as_millis() + ); + + orb_client + .graceful_shutdown(Duration::from_millis(500), Duration::from_millis(1000)) + .await; + app_client + .graceful_shutdown(Duration::from_millis(500), Duration::from_millis(1000)) + .await; + + Ok(()) +} + +async fn orb_to_app() -> Result<()> { + tracing::info!("== Running Orb to App =="); + let (orb_id, session_id) = get_ids(); + + let mut app_client = Client::new_as_app( + BACKEND_URL.to_string(), + APP_KEY.to_string(), + session_id.to_string(), + orb_id.to_string(), + ); + let now = Instant::now(); + app_client.connect().await?; + tracing::info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); + + let mut orb_client = Client::new_as_orb( + BACKEND_URL.to_string(), + ORB_KEY.to_string(), + orb_id.to_string(), + session_id.to_string(), + ); + let now = Instant::now(); + orb_client.connect().await?; + tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + + let now = Instant::now(); + let time_now = time_now()?; + tracing::info!("Sending time now: {}", time_now); + orb_client + .send(common::v1::AnnounceOrbId { + orb_id: time_now.clone(), + mode_type: common::v1::announce_orb_id::ModeType::SelfServe.into(), + hardware_type: common::v1::announce_orb_id::HardwareType::Diamond.into(), + }) + .await?; + tracing::info!( + "Time took to send a message from the app: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + 'ext: loop { + #[expect(clippy::never_loop)] + for msg in app_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + if let Some(common::v1::AnnounceOrbId { orb_id, .. 
}) = + common::v1::AnnounceOrbId::matches(msg.payload.as_ref().unwrap()) + { + assert!( + orb_id == time_now, + "Received orb_id is not the same as sent orb_id" + ); + break 'ext; + } + unreachable!("Received unexpected message: {msg:?}"); + } + } + tracing::info!( + "Time took to receive a message: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + orb_client + .send(self_serve::orb::v1::SignupEnded { + success: true, + failure_feedback: Vec::new(), + }) + .await?; + tracing::info!( + "Time took to send a second message: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + 'ext: loop { + #[expect(clippy::never_loop)] + for msg in app_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + if let Some(self_serve::orb::v1::SignupEnded { success, .. }) = + self_serve::orb::v1::SignupEnded::matches(msg.payload.as_ref().unwrap()) + { + assert!(success, "Received: success is not true"); + break 'ext; + } + unreachable!("Received unexpected message: {msg:?}"); + } + } + tracing::info!( + "Time took to receive a second message: {}ms", + now.elapsed().as_millis() + ); + + orb_client + .graceful_shutdown(Duration::from_millis(500), Duration::from_millis(1000)) + .await; + app_client + .graceful_shutdown(Duration::from_millis(500), Duration::from_millis(1000)) + .await; + + Ok(()) +} + +async fn orb_to_app_with_state_request() -> Result<()> { + tracing::info!("== Running Orb to App with state request =="); + let (orb_id, session_id) = get_ids(); + + let mut app_client = Client::new_as_app( + BACKEND_URL.to_string(), + APP_KEY.to_string(), + session_id.to_string(), + orb_id.to_string(), + ); + let now = Instant::now(); + app_client.connect().await?; + tracing::info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); + + let mut orb_client = Client::new_as_orb( + BACKEND_URL.to_string(), + ORB_KEY.to_string(), + orb_id.to_string(), + session_id.to_string(), + ); + let now = Instant::now(); + orb_client.connect().await?; + tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + + let now = Instant::now(); + app_client + .send(self_serve::app::v1::RequestState {}) + .await?; + tracing::info!( + "Time took to send RequestState from the app: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + 'ext: loop { + #[expect(clippy::never_loop)] + for msg in app_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + break 'ext; + } + } + tracing::info!( + "Time took to receive a message: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + let time_now = time_now()?; + tracing::info!("Sending time now: {}", time_now); + orb_client + .send(common::v1::AnnounceOrbId { + orb_id: time_now, + mode_type: common::v1::announce_orb_id::ModeType::SelfServe.into(), + hardware_type: common::v1::announce_orb_id::HardwareType::Diamond.into(), + }) + .await?; + tracing::info!( + "Time took to send a message from the app: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + 'ext: loop { + #[expect(clippy::never_loop)] + for msg in app_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + break 
'ext; + } + } + tracing::info!( + "Time took to receive a message: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + app_client + .send(self_serve::app::v1::RequestState {}) + .await?; + tracing::info!( + "Time took to send RequestState from the app: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + 'ext: loop { + #[expect(clippy::never_loop)] + for msg in app_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + break 'ext; + } + } + tracing::info!( + "Time took to receive a message: {}ms", + now.elapsed().as_millis() + ); + + orb_client + .graceful_shutdown(Duration::from_millis(500), Duration::from_millis(1000)) + .await; + app_client + .graceful_shutdown(Duration::from_millis(500), Duration::from_millis(1000)) + .await; + + Ok(()) +} + +async fn orb_to_app_blocking_send() -> Result<()> { + tracing::info!("== Running Orb to App blocking send =="); + let (orb_id, session_id) = get_ids(); + + let mut app_client = Client::new_as_app( + BACKEND_URL.to_string(), + APP_KEY.to_string(), + session_id.to_string(), + orb_id.to_string(), + ); + let now = Instant::now(); + app_client.connect().await?; + tracing::info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); + + let mut orb_client = Client::new_as_orb( + BACKEND_URL.to_string(), + ORB_KEY.to_string(), + orb_id.to_string(), + session_id.to_string(), + ); + let now = Instant::now(); + orb_client.connect().await?; + tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + + let now = Instant::now(); + let time_now = time_now()?; + tracing::info!("Sending time now: {}", time_now); + orb_client + .send_blocking( + common::v1::AnnounceOrbId { + orb_id: time_now.clone(), + mode_type: common::v1::announce_orb_id::ModeType::SelfServe.into(), + hardware_type: common::v1::announce_orb_id::HardwareType::Diamond + .into(), + }, + Duration::from_secs(5), + ) + .await?; + tracing::info!( + "Time took to send a message from the app: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + 'ext: loop { + #[expect(clippy::never_loop)] + for msg in app_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + if let Some(common::v1::AnnounceOrbId { orb_id, .. }) = + common::v1::AnnounceOrbId::matches(msg.payload.as_ref().unwrap()) + { + assert!( + orb_id == time_now, + "Received orb_id is not the same as sent orb_id" + ); + break 'ext; + } + unreachable!("Received unexpected message: {msg:?}"); + } + } + tracing::info!( + "Time took to receive a message: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + orb_client + .send_blocking( + self_serve::orb::v1::SignupEnded { + success: true, + failure_feedback: Vec::new(), + }, + Duration::from_secs(5), + ) + .await?; + tracing::info!( + "Time took to send a second message: {}ms", + now.elapsed().as_millis() + ); + + let now = Instant::now(); + 'ext: loop { + #[expect(clippy::never_loop)] + for msg in app_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + if let Some(self_serve::orb::v1::SignupEnded { success, .. 
}) = + self_serve::orb::v1::SignupEnded::matches(msg.payload.as_ref().unwrap()) + { + assert!(success, "Received: success is not true"); + break 'ext; + } + unreachable!("Received unexpected message: {msg:?}"); + } + } + tracing::info!( + "Time took to receive a second message: {}ms", + now.elapsed().as_millis() + ); + + orb_client + .graceful_shutdown(Duration::from_millis(500), Duration::from_millis(1000)) + .await; + app_client + .graceful_shutdown(Duration::from_millis(500), Duration::from_millis(1000)) + .await; + + Ok(()) +} + +async fn orb_to_app_with_clients_created_later_and_delay() -> Result<()> { + let (orb_id, session_id) = get_ids(); + + let mut orb_client = Client::new_as_orb( + BACKEND_URL.to_string(), + ORB_KEY.to_string(), + orb_id.to_string(), + session_id.to_string(), + ); + let now = Instant::now(); + orb_client.connect().await?; + tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + + let now = Instant::now(); + let time_now = time_now()?; + tracing::info!("Sending time now: {}", time_now); + orb_client + .send(common::v1::AnnounceOrbId { + orb_id: time_now, + mode_type: common::v1::announce_orb_id::ModeType::SelfServe.into(), + hardware_type: common::v1::announce_orb_id::HardwareType::Diamond.into(), + }) + .await?; + tracing::info!( + "Time took to send a message from the app: {}ms", + now.elapsed().as_millis() + ); + + tracing::info!("Waiting for 60 seconds..."); + tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; + + let mut app_client = Client::new_as_app( + BACKEND_URL.to_string(), + APP_KEY.to_string(), + session_id.to_string(), + orb_id.to_string(), + ); + let now = Instant::now(); + app_client.connect().await?; + tracing::info!("Time took to app_connect: {}ms", now.elapsed().as_millis()); + + let now = Instant::now(); + 'ext: loop { + #[expect(clippy::never_loop)] + for msg in app_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + break 'ext; + } + } + tracing::info!( + "Time took to receive a message: {}ms", + now.elapsed().as_millis() + ); + + orb_client + .graceful_shutdown(Duration::from_millis(500), Duration::from_millis(1000)) + .await; + app_client + .graceful_shutdown(Duration::from_millis(500), Duration::from_millis(1000)) + .await; + + Ok(()) +} + +fn get_ids() -> (String, String) { + let mut rng = rand::thread_rng(); + let orb_id: String = (&mut rng) + .sample_iter(Alphanumeric) + .take(10) + .map(char::from) + .collect(); + let session_id: String = (&mut rng) + .sample_iter(Alphanumeric) + .take(10) + .map(char::from) + .collect(); + tracing::info!("Orb ID: {orb_id}, Session ID: {session_id}"); + (orb_id, session_id) +} + +fn time_now() -> Result { + Ok(SystemTime::now() + .duration_since(UNIX_EPOCH)? 
+ .as_nanos() + .to_string()) +} + +async fn stage_consumer_app() -> Result<()> { + let mut app_client = Client::new_as_app( + BACKEND_URL.to_string(), + APP_KEY.to_string(), + SESSION_ID.to_string(), + ORB_ID.to_string(), + ); + let now = Instant::now(); + app_client.connect().await?; + tracing::info!("Time took to connect: {}ms", now.elapsed().as_millis()); + + loop { + for msg in app_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + } + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + } +} + +async fn stage_producer_orb() -> Result<()> { + let mut orb_client = Client::new_as_orb( + BACKEND_URL.to_string(), + ORB_KEY.to_string(), + ORB_ID.to_string(), + SESSION_ID.to_string(), + ); + let now = Instant::now(); + orb_client.connect().await?; + tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + + loop { + let time_now = time_now()?; + tracing::info!("Sending time now: {}", time_now); + orb_client + .send(common::v1::AnnounceOrbId { + orb_id: time_now, + mode_type: common::v1::announce_orb_id::ModeType::SelfServe.into(), + hardware_type: common::v1::announce_orb_id::HardwareType::Diamond + .into(), + }) + .await?; + tokio::time::sleep(tokio::time::Duration::from_secs(120)).await; + } +} + +async fn stage_producer_from_app_start_orb_signup() -> Result<()> { + let mut app_client = Client::new_as_app( + BACKEND_URL.to_string(), + APP_KEY.to_string(), + SESSION_ID.to_string(), + ORB_ID.to_string(), + ); + let now = Instant::now(); + app_client.connect().await?; + tracing::info!("Time took to orb_connect: {}ms", now.elapsed().as_millis()); + + tracing::info!("Sending StartCapture now"); + app_client + .send(self_serve::app::v1::StartCapture {}) + .await?; + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + + loop { + for msg in app_client.get_buffered_messages().await { + tracing::info!( + "Received message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + msg.src, + msg.dst, + msg.seq, + debug_any(&msg.payload) + ); + } + } +} diff --git a/relay-client/src/client.rs b/relay-client/src/client.rs new file mode 100644 index 00000000..9d74a445 --- /dev/null +++ b/relay-client/src/client.rs @@ -0,0 +1,757 @@ +//! 
Orb-Relay client
+use crate::{debug_any, IntoPayload, PayloadMatcher};
+use eyre::{Context, OptionExt, Result};
+use orb_relay_messages::{
+    common,
+    prost_types::Any,
+    relay::{
+        connect_request::AuthMethod, entity::EntityType, relay_connect_request,
+        relay_connect_response, relay_service_client::RelayServiceClient,
+        ConnectRequest, ConnectResponse, Entity, Heartbeat, RelayConnectRequest,
+        RelayConnectResponse, RelayPayload, ZkpAuthRequest,
+    },
+    self_serve,
+    tonic::{
+        transport::{Certificate, Channel, ClientTlsConfig},
+        Streaming,
+    },
+};
+use orb_security_utils::reqwest::{
+    AWS_ROOT_CA1_CERT, AWS_ROOT_CA2_CERT, AWS_ROOT_CA3_CERT, AWS_ROOT_CA4_CERT,
+    GTS_ROOT_R1_CERT, GTS_ROOT_R2_CERT, GTS_ROOT_R3_CERT, GTS_ROOT_R4_CERT,
+    SFS_ROOT_G2_CERT,
+};
+use secrecy::{ExposeSecret, SecretString};
+use std::{
+    any::type_name,
+    collections::{BTreeMap, VecDeque},
+    sync::Arc,
+};
+use tokio::{
+    sync::{
+        mpsc::{self, Sender},
+        oneshot, Mutex,
+    },
+    time::{self, Duration},
+};
+use tokio_stream::{wrappers::ReceiverStream, StreamExt};
+use tokio_util::sync::CancellationToken;
+
+#[derive(Debug, Clone)]
+pub struct TokenAuth {
+    token: SecretString,
+}
+
+#[derive(Debug, Clone)]
+pub struct ZkpAuth {
+    root: SecretString,
+    signal: SecretString,
+    nullifier_hash: SecretString,
+    proof: SecretString,
+}
+
+#[derive(Debug, Clone)]
+pub enum Auth {
+    Token(TokenAuth),
+    ZKP(ZkpAuth),
+}
+
+#[derive(Debug, Clone)]
+enum Mode {
+    Orb,
+    App,
+}
+
+#[derive(Debug, Clone)]
+struct Config {
+    src_id: String,
+    dst_id: String,
+    url: String,
+
+    auth: Auth,
+
+    // TODO: Maybe split this into a separate struct and a trait?
+    mode: Mode,
+
+    max_buffer_size: usize,
+    reconnect_delay: Duration,
+    keep_alive_interval: Duration,
+    keep_alive_timeout: Duration,
+    connect_timeout: Duration,
+    request_timeout: Duration,
+    heartbeat_interval: Duration,
+}
+
+enum Command {
+    ReplayPendingMessages,
+    GetPendingMessages(oneshot::Sender<usize>),
+    Reconnect,
+}
+
+enum OutgoingMessage {
+    Normal(Any),
+    Blocking(Any, oneshot::Sender<()>),
+}
+
+/// Client state
+pub struct Client {
+    message_buffer: Arc<Mutex<VecDeque<RelayPayload>>>,
+    outgoing_tx: Option<Sender<OutgoingMessage>>,
+    command_tx: Option<Sender<Command>>,
+    shutdown_token: Option<CancellationToken>,
+    shutdown_completed: Option<oneshot::Receiver<()>>,
+    config: Config,
+}
+
+impl Client {
+    fn no_state(&self) -> RelayConnectRequest {
+        let (src_t, dst_t) = match self.config.mode {
+            Mode::Orb => (EntityType::Orb as i32, EntityType::App as i32),
+            Mode::App => (EntityType::App as i32, EntityType::Orb as i32),
+        };
+        RelayPayload {
+            src: Some(Entity {
+                id: self.config.src_id.clone(),
+                entity_type: src_t,
+            }),
+            dst: Some(Entity {
+                id: self.config.dst_id.clone(),
+                entity_type: dst_t,
+            }),
+            payload: Some(common::v1::NoState::default().into_payload()),
+            seq: 0,
+        }
+        .into()
+    }
+
+    #[must_use]
+    fn new(
+        url: String,
+        auth: Auth,
+        src_id: String,
+        dst_id: String,
+        mode: Mode,
+    ) -> Self {
+        Self {
+            message_buffer: Arc::new(Mutex::new(VecDeque::new())),
+            outgoing_tx: None,
+            command_tx: None,
+            shutdown_token: None,
+            shutdown_completed: None,
+            config: Config {
+                src_id,
+                dst_id,
+                url,
+                auth,
+                mode,
+                max_buffer_size: 100,
+                reconnect_delay: Duration::from_secs(1),
+                keep_alive_interval: Duration::from_secs(5),
+                keep_alive_timeout: Duration::from_secs(10),
+                connect_timeout: Duration::from_secs(20),
+                request_timeout: Duration::from_secs(20),
+                heartbeat_interval: Duration::from_secs(15),
+            },
+        }
+    }
+
+    /// Create a new client that sends messages from an Orb to an App
+    #[must_use]
+    pub fn new_as_orb(
+        url: String,
+        token: String,
+        orb_id: String,
+        session_id: String,
+    ) -> Self {
+        Self::new(
+            url,
+            Auth::Token(TokenAuth {
+                token: token.into(),
+            }),
+            orb_id,
+            session_id,
+            Mode::Orb,
+        )
+    }
+
+    /// Create a new client that sends messages from an App to an Orb
+    #[must_use]
+    pub fn new_as_app(
+        url: String,
+        token: String,
+        session_id: String,
+        orb_id: String,
+    ) -> Self {
+        Self::new(
+            url,
+            Auth::Token(TokenAuth {
+                token: token.into(),
+            }),
+            session_id,
+            orb_id,
+            Mode::App,
+        )
+    }
+
+    /// Create a new client that sends messages from an App to an Orb (using ZKP as the auth method)
+    #[must_use]
+    pub fn new_as_app_zkp(
+        url: String,
+        root: String,
+        signal: String,
+        nullifier_hash: String,
+        proof: String,
+        session_id: String,
+        orb_id: String,
+    ) -> Self {
+        Self::new(
+            url,
+            Auth::ZKP(ZkpAuth {
+                root: root.into(),
+                signal: signal.into(),
+                nullifier_hash: nullifier_hash.into(),
+                proof: proof.into(),
+            }),
+            session_id,
+            orb_id,
+            Mode::App,
+        )
+    }
+
+    async fn check_for_msg<T: PayloadMatcher>(&self) -> Option<T::Output> {
+        for msg in self.get_buffered_messages().await {
+            if let Some(payload) = &msg.payload {
+                if let Some(specific_payload) = T::matches(payload) {
+                    return Some(specific_payload);
+                }
+                tracing::warn!(
+                    "While waiting for payload of type {:?}, we got: {:?}",
+                    type_name::<T>(),
+                    debug_any(&msg.payload)
+                );
+            }
+        }
+        None
+    }
+
+    /// Get buffered messages
+    pub async fn get_buffered_messages(&self) -> VecDeque<RelayPayload> {
+        let mut buffer = self.message_buffer.lock().await;
+        std::mem::take(&mut *buffer)
+    }
+
+    /// Connect to the Orb-Relay server
+    pub async fn connect(&mut self) -> Result<()> {
+        let shutdown_token = CancellationToken::new();
+        self.shutdown_token = Some(shutdown_token.clone());
+
+        let (connection_established_tx, connection_established_rx) = oneshot::channel();
+
+        let message_buffer = Arc::clone(&self.message_buffer);
+        // TODO: Make the buffer size configurable
+        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(32);
+        self.outgoing_tx = Some(outgoing_tx);
+        let (command_tx, mut command_rx) = mpsc::channel(32);
+        self.command_tx = Some(command_tx);
+        let (shutdown_completed_tx, shutdown_completed_rx) = oneshot::channel();
+        self.shutdown_completed = Some(shutdown_completed_rx);
+
+        let config = self.config.clone();
+        let no_state = self.no_state();
+
+        tracing::info!(
+            "Connecting with: src_id: {}, dst_id: {}",
+            config.src_id,
+            config.dst_id
+        );
+        tokio::spawn(async move {
+            let mut agent = PollerAgent {
+                config: &config,
+                pending_messages: Default::default(),
+                last_message: no_state,
+                seq: 0,
+            };
+            let mut connection_established_tx = Some(connection_established_tx);
+
+            loop {
+                if let Err(e) = agent
+                    .main_loop(
+                        &message_buffer,
+                        shutdown_token.clone(),
+                        &mut outgoing_rx,
+                        &mut command_rx,
+                        connection_established_tx.take(),
+                    )
+                    .await
+                {
+                    tracing::error!("Connection error: {e}");
+                }
+
+                if shutdown_token.is_cancelled() {
+                    tracing::info!("Connection shutdown");
+                    break;
+                }
+
+                tracing::info!(
+                    "Reconnecting in {}s ...",
+                    config.reconnect_delay.as_secs()
+                );
+                tokio::time::sleep(config.reconnect_delay).await;
+            }
+            shutdown_completed_tx.send(()).ok();
+        });
+
+        // Wait for the connection to be established. Note that if the first connection
+        // attempt fails, this surfaces an error, which is the expected behavior.
+        connection_established_rx
+            .await
+            .wrap_err("Failed to establish connection")?;
+
+        Ok(())
+    }
+
+    /// Wait for a specific message type
+    pub async fn wait_for_msg<T: PayloadMatcher>(
+        &self,
+        wait: Duration,
+    ) -> Result<T::Output> {
+        let start_time = tokio::time::Instant::now();
+        loop {
+            if let Some(payload) = self.check_for_msg::<T>().await {
+                return Ok(payload);
+            }
+            if start_time.elapsed() >= wait {
+                return Err(eyre::eyre!(
+                    "Timeout waiting for payload of type {:?}",
+                    std::any::type_name::<T>()
+                ));
+            }
+            tokio::time::sleep(Duration::from_millis(100)).await;
+        }
+    }
+
+    /// Send a message to the current session
+    pub async fn send<T: IntoPayload>(&mut self, msg: T) -> Result<()> {
+        self.send_internal(msg, None).await
+    }
+
+    /// Send a message and wait until the corresponding ack is received
+    pub async fn send_blocking<T: IntoPayload>(
+        &mut self,
+        msg: T,
+        timeout: Duration,
+    ) -> Result<()> {
+        let (ack_tx, ack_rx) = oneshot::channel();
+        self.send_internal(msg, Some(ack_tx)).await?;
+        match tokio::time::timeout(timeout, ack_rx).await {
+            Ok(Ok(())) => Ok(()),
+            Ok(Err(_)) => Err(eyre::eyre!("Failed to receive ack: sender dropped")),
+            Err(_) => Err(eyre::eyre!("Timeout waiting for ack")),
+        }
+    }
+
+    async fn send_internal<T: IntoPayload>(
+        &mut self,
+        msg: T,
+        ack_tx: Option<oneshot::Sender<()>>,
+    ) -> Result<()> {
+        let msg = match ack_tx {
+            Some(ack_tx) => OutgoingMessage::Blocking(msg.into_payload(), ack_tx),
+            None => OutgoingMessage::Normal(msg.into_payload()),
+        };
+        self.outgoing_tx
+            .as_ref()
+            .ok_or_eyre("client not connected")?
+            .send(msg)
+            .await
+            .inspect_err(|e| tracing::error!("Failed to send payload: {e}"))
+            .wrap_err("Failed to send payload")
+    }
+
+    /// Get the number of pending messages
+    pub async fn has_pending_messages(&self) -> Result<usize> {
+        let command_tx = self
+            .command_tx
+            .as_ref()
+            .ok_or_else(|| eyre::eyre!("Client not connected"))?;
+        let (reply_tx, reply_rx) = oneshot::channel();
+        command_tx
+            .send(Command::GetPendingMessages(reply_tx))
+            .await?;
+        let pending_count = reply_rx.await?;
+        Ok(pending_count)
+    }
+
+    /// Request to replay pending messages
+    pub async fn replay_pending_messages(&self) -> Result<()> {
+        let command_tx = self
+            .command_tx
+            .as_ref()
+            .ok_or_else(|| eyre::eyre!("Client not connected"))?;
+        command_tx.send(Command::ReplayPendingMessages).await?;
+        Ok(())
+    }
+
+    /// Reconnect the client. On restart, pending messages will be replayed.
+    pub async fn reconnect(&self) -> Result<()> {
+        let command_tx = self
+            .command_tx
+            .as_ref()
+            .ok_or_else(|| eyre::eyre!("Client not connected"))?;
+        command_tx.send(Command::Reconnect).await?;
+        Ok(())
+    }
+
+    pub async fn graceful_shutdown(
+        &mut self,
+        wait_for_pending_messages: Duration,
+        wait_for_shutdown: Duration,
+    ) {
+        // Let's wait for all acks to be received
+        if self.has_pending_messages().await.map_or(false, |n| n > 0) {
+            tracing::info!(
+                "Giving {}ms for pending messages to be acked",
+                wait_for_pending_messages.as_millis()
+            );
+            tokio::time::sleep(wait_for_pending_messages).await;
+        }
+        // If there are still pending messages, retry sending them
+        if self.has_pending_messages().await.map_or(false, |n| n > 0) {
+            tracing::info!("There are still pending messages, replaying...");
+            if let Ok(()) = self.replay_pending_messages().await {
+                tokio::time::sleep(wait_for_pending_messages).await;
+            }
+        }
+
+        // Eventually, there is not much more we can do, so we shut down the client
+        self.shutdown();
+
+        if let Some(shutdown_completed) = self.shutdown_completed.take() {
+            match tokio::time::timeout(wait_for_shutdown, shutdown_completed).await {
+                Ok(_) => tracing::info!("Shutdown completed successfully."),
+                Err(_) => tracing::warn!("Timed out waiting for shutdown to complete."),
+            }
+        }
+    }
+
+    /// Shutdown the client
+    pub fn shutdown(&mut self) {
+        tracing::info!("Shutting down requested");
+        if let Some(token) = self.shutdown_token.take() {
+            token.cancel();
+        }
+    }
+
+    pub async fn wait_for_msg_while_spamming<
+        T: PayloadMatcher,
+        S: IntoPayload + std::clone::Clone,
+    >(
+        &mut self,
+        wait: Duration,
+        spam: S,
+        spam_every: Duration,
+    ) -> Result<T::Output> {
+        let start_time = tokio::time::Instant::now();
+        let mut spam_time = tokio::time::Instant::now();
+        loop {
+            if let Some(payload) = self.check_for_msg::<T>().await {
+                return Ok(payload);
+            }
+
+            if spam_time.elapsed() >= spam_every {
+                let _ = self.send(spam.clone()).await;
+                spam_time = tokio::time::Instant::now();
+            }
+
+            if start_time.elapsed() >= wait {
+                return Err(eyre::eyre!(
+                    "Timeout waiting for payload of type {:?}",
+                    std::any::type_name::<T>()
+                ));
+            }
+            tokio::time::sleep(Duration::from_millis(100)).await;
+        }
+    }
+}
+
+impl Drop for Client {
+    fn drop(&mut self) {
+        self.shutdown();
+    }
+}
+
+struct PollerAgent<'a> {
+    config: &'a Config,
+    pending_messages: BTreeMap<u64, (RelayConnectRequest, Option<oneshot::Sender<()>>)>,
+    last_message: RelayConnectRequest,
+    seq: u64,
+}
+
+impl<'a> PollerAgent<'a> {
+    // TODO: We need to split auth and subscription. Maybe ideally we issue 1 connect and then a subscribe that will notify
+    // the server that we care about messages from a certain queue only. That will avoid multiplexing messages from
+    // different sources.
+    async fn main_loop(
+        &mut self,
+        message_buffer: &Arc<Mutex<VecDeque<RelayPayload>>>,
+        shutdown_token: CancellationToken,
+        outgoing_rx: &mut mpsc::Receiver<OutgoingMessage>,
+        command_rx: &mut mpsc::Receiver<Command>,
+        connection_established_tx: Option<oneshot::Sender<()>>,
+    ) -> Result<()> {
+        let (mut response_stream, sender_tx) = match self.connect().await {
+            Ok(ok) => ok,
+            Err(e) => return Err(e),
+        };
+
+        if let Some(tx) = connection_established_tx {
+            let _ = tx.send(());
+        }
+
+        self.replay_pending_messages(&sender_tx).await?;
+
+        let mut interval = time::interval(self.config.heartbeat_interval);
+
+        loop {
+            tokio::select!
{ + () = shutdown_token.cancelled() => { + tracing::info!("Shutting down connection"); + if !self.pending_messages.is_empty() { + tracing::warn!("Pending messages {}: {:?}", self.pending_messages.len(), self.pending_messages); + } + return Ok(()); + } + message = response_stream.next() => { + match message { + Some(Ok(RelayConnectResponse { + msg: + Some(relay_connect_response::Msg::Payload(RelayPayload { + src: Some(src), + dst, + seq, + payload: Some(payload), + })), + })) => { + if self_serve::app::v1::RequestState::matches(&payload).is_some() { + sender_tx + .send(self.last_message.clone()) + .await + .wrap_err("Failed to send outgoing message")?; + } else if src.id != self.config.dst_id { + tracing::error!( + "Skipping received message from unexpected source: {:?}: {payload:?}", + src.id + ); + } else { + self.handle_message( + RelayPayload { src: Some(src), dst, seq, payload: Some(payload) }, + message_buffer, + ) + .await?; + } + } + Some(Ok(RelayConnectResponse { msg: Some(relay_connect_response::Msg::Ack(ack)) })) => { + if let Some((_, Some(ack_tx))) = self.pending_messages.remove(&ack.seq) { + if ack_tx.send(()).is_err() { + // The receiver has been dropped, possibly due to a timeout. That means we + // need to increase the timeout at send_blocking(). + tracing::warn!( + "Failed to send ack back to send_blocking(): receiver dropped" + ); + } + } + } + Some(Err(e)) => { + tracing::error!("Error receiving message from tonic stream: {e:?}"); + return Err(e.into()); + } + None => { + tracing::info!("Stream ended"); + return Ok(()); + } + _ => { + tracing::error!("Received unexpected message: {message:?}"); + } + } + } + Some(outgoing_message) = outgoing_rx.recv() => { + self.seq = self.seq.wrapping_add(1); + let (payload, maybe_ack_tx) = match outgoing_message { + OutgoingMessage::Normal(payload) => (payload, None), + OutgoingMessage::Blocking(payload, ack_tx) => (payload, Some(ack_tx)), + }; + let (src_t, dst_t) = match self.config.mode { + Mode::Orb => (EntityType::Orb as i32, EntityType::App as i32), + Mode::App => (EntityType::App as i32, EntityType::Orb as i32), + }; + let relay_message = RelayPayload { + src: Some(Entity { id: self.config.src_id.clone(), entity_type: src_t }), + dst: Some(Entity { id: self.config.dst_id.clone(), entity_type: dst_t }), + seq: self.seq, + payload: Some(payload), + }; + + tracing::debug!("Sending message: from: {:?}, to: {:?}, seq: {:?}, payload: {:?}", + relay_message.src, relay_message.dst, relay_message.seq, debug_any(&relay_message.payload)); + + self.pending_messages.insert(self.seq, (relay_message.clone().into(), maybe_ack_tx)); + self.last_message = relay_message.clone().into(); + sender_tx.send(relay_message.into()).await.wrap_err("Failed to send outgoing message")?; + } + Some(command) = command_rx.recv() => { + match command { + Command::ReplayPendingMessages => { + self.replay_pending_messages(&sender_tx).await?; + } + Command::GetPendingMessages(reply_tx) => { + let _ = reply_tx.send(self.pending_messages.len()); + } + Command::Reconnect => { + tracing::info!("Reconnecting..."); + return Ok(()); + } + } + } + _ = interval.tick() => { + self.seq = self.seq.wrapping_add(1); + sender_tx + .send(Heartbeat { seq: self.seq }.into()) + .await + .wrap_err("Failed to send heartbeat")?; + }, + } + } + } + + async fn replay_pending_messages( + &mut self, + sender_tx: &Sender, + ) -> Result<()> { + if !self.pending_messages.is_empty() { + tracing::warn!("Replaying pending messages: {:?}", self.pending_messages); + for (_key, (msg, sender)) in 
self.pending_messages.iter_mut() {
+                sender_tx
+                    .send(msg.clone())
+                    .await
+                    .wrap_err("Failed to send pending message")?;
+                // If there's a sender, send a signal and set it to None. We are coming from a reconnect or a manual
+                // retry, so we don't care about the acks.
+                if let Some(tx) = sender.take() {
+                    let _ = tx.send(());
+                }
+            }
+        }
+        Ok(())
+    }
+
+    async fn connect(
+        &self,
+    ) -> Result<(Streaming<RelayConnectResponse>, Sender<RelayConnectRequest>)> {
+        let channel = Channel::from_shared(self.config.url.clone())?
+            .tls_config(Self::create_tls_config())?
+            .keep_alive_while_idle(true)
+            .http2_keep_alive_interval(self.config.keep_alive_interval)
+            .keep_alive_timeout(self.config.keep_alive_timeout)
+            .connect_timeout(self.config.connect_timeout)
+            .timeout(self.config.request_timeout)
+            .connect()
+            .await
+            .wrap_err("Failed to create gRPC channel")?;
+
+        // TODO: Make the buffer size configurable
+        let (sender_tx, sender_rx) = mpsc::channel(32);
+
+        let mut client = RelayServiceClient::new(channel);
+        let response = client.relay_connect(ReceiverStream::new(sender_rx));
+        self.send_connect_request(&sender_tx).await?;
+        let mut response_stream = response.await?.into_inner();
+
+        self.wait_for_connect_response(&mut response_stream).await?;
+        Ok((response_stream, sender_tx))
+    }
+
+    // TODO: See if we can move this setup into `orb-security-utils`.
+    fn create_tls_config() -> ClientTlsConfig {
+        ClientTlsConfig::new().ca_certificates(vec![
+            Certificate::from_pem(AWS_ROOT_CA1_CERT),
+            Certificate::from_pem(AWS_ROOT_CA2_CERT),
+            Certificate::from_pem(AWS_ROOT_CA3_CERT),
+            Certificate::from_pem(AWS_ROOT_CA4_CERT),
+            Certificate::from_pem(SFS_ROOT_G2_CERT),
+            Certificate::from_pem(GTS_ROOT_R1_CERT),
+            Certificate::from_pem(GTS_ROOT_R2_CERT),
+            Certificate::from_pem(GTS_ROOT_R3_CERT),
+            Certificate::from_pem(GTS_ROOT_R4_CERT),
+        ])
+    }
+
+    async fn send_connect_request(
+        &self,
+        tx: &mpsc::Sender<RelayConnectRequest>,
+    ) -> Result<()> {
+        tx.send(RelayConnectRequest {
+            msg: Some(relay_connect_request::Msg::ConnectRequest(ConnectRequest {
+                client_id: Some(Entity {
+                    id: self.config.src_id.clone(),
+                    entity_type: match self.config.mode {
+                        Mode::Orb => EntityType::Orb as i32,
+                        Mode::App => EntityType::App as i32,
+                    },
+                }),
+                auth_method: Some(match &self.config.auth {
+                    Auth::Token(t) => {
+                        AuthMethod::Token(t.token.expose_secret().to_string())
+                    }
+                    Auth::ZKP(z) => AuthMethod::ZkpAuthRequest(ZkpAuthRequest {
+                        root: z.root.expose_secret().to_string(),
+                        signal: z.signal.expose_secret().to_string(),
+                        nullifier_hash: z.nullifier_hash.expose_secret().to_string(),
+                        proof: z.proof.expose_secret().to_string(),
+                    }),
+                }),
+            })),
+        })
+        .await
+        .wrap_err("Failed to send connect request")
+    }
+
+    async fn wait_for_connect_response(
+        &self,
+        response_stream: &mut Streaming<RelayConnectResponse>,
+    ) -> Result<()> {
+        while let Some(message) = response_stream.next().await {
+            let message = message?.msg.ok_or_eyre("ConnectResponse msg is missing")?;
+            if let relay_connect_response::Msg::ConnectResponse(ConnectResponse {
+                success,
+                error,
+                ..
+            }) = message
+            {
+                return if success {
+                    tracing::info!("Successful connection");
+                    Ok(())
+                } else {
+                    Err(eyre::eyre!("Failed to establish connection: {error:?}"))
+                };
+            }
+        }
+        Err(eyre::eyre!(
+            "Connection stream ended before receiving ConnectResponse"
+        ))
+    }
+
+    async fn handle_message(
+        &self,
+        payload: RelayPayload,
+        message_buffer: &Arc<Mutex<VecDeque<RelayPayload>>>,
+    ) -> Result<()> {
+        let mut buffer = message_buffer.lock().await;
+        if buffer.len() >= self.config.max_buffer_size {
+            // Remove the oldest message to maintain the buffer size
+            let msg: Vec<RelayPayload> = buffer.drain(0..1).collect();
+            tracing::warn!("Buffer is full, removing oldest message: {msg:?}");
+        }
+        buffer.push_back(payload);
+        Ok(())
+    }
+}
diff --git a/relay-client/src/lib.rs b/relay-client/src/lib.rs
new file mode 100644
index 00000000..45af3150
--- /dev/null
+++ b/relay-client/src/lib.rs
@@ -0,0 +1,170 @@
+//! Orb-Relay crate
+use orb_relay_messages::{common, prost::Name, prost_types::Any, self_serve};
+
+pub mod client;
+
+pub trait PayloadMatcher {
+    type Output;
+    fn matches(payload: &Any) -> Option<Self::Output>;
+}
+
+fn unpack_any<T: Name + Default>(any: &Any) -> Option<T> {
+    if any.type_url != T::type_url() {
+        return None;
+    }
+    T::decode(any.value.as_slice()).ok()
+}
+
+impl PayloadMatcher for self_serve::app::v1::StartCapture {
+    type Output = self_serve::app::v1::StartCapture;
+
+    fn matches(payload: &Any) -> Option<Self::Output> {
+        if let Some(self_serve::app::v1::w::W::StartCapture(p)) =
+            unpack_any::<self_serve::app::v1::W>(payload)?.w
+        {
+            return Some(p);
+        }
+        unpack_any::<self_serve::app::v1::StartCapture>(payload)
+    }
+}
+
+impl PayloadMatcher for self_serve::app::v1::RequestState {
+    type Output = self_serve::app::v1::RequestState;
+
+    fn matches(payload: &Any) -> Option<Self::Output> {
+        if let Some(self_serve::app::v1::w::W::RequestState(p)) =
+            unpack_any::<self_serve::app::v1::W>(payload)?.w
+        {
+            return Some(p);
+        }
+        unpack_any::<self_serve::app::v1::RequestState>(payload)
+    }
+}
+
+impl PayloadMatcher for common::v1::AnnounceOrbId {
+    type Output = common::v1::AnnounceOrbId;
+
+    fn matches(payload: &Any) -> Option<Self::Output> {
+        if let Some(common::v1::w::W::AnnounceOrbId(p)) =
+            unpack_any::<common::v1::W>(payload)?.w
+        {
+            return Some(p);
+        }
+        unpack_any::<common::v1::AnnounceOrbId>(payload)
+    }
+}
+
+impl PayloadMatcher for self_serve::orb::v1::SignupEnded {
+    type Output = self_serve::orb::v1::SignupEnded;
+
+    fn matches(payload: &Any) -> Option<Self::Output> {
+        let w: self_serve::orb::v1::W = unpack_any(payload)?;
+        match w.w {
+            Some(self_serve::orb::v1::w::W::SignupEnded(p)) => Some(p),
+            _ => None,
+        }
+    }
+}
+
+pub trait IntoPayload {
+    fn into_payload(self) -> Any;
+}
+
+impl IntoPayload for self_serve::orb::v1::AgeVerificationRequiredFromOperator {
+    fn into_payload(self) -> Any {
+        Any::from_msg(&self_serve::orb::v1::W {
+            w: Some(
+                self_serve::orb::v1::w::W::AgeVerificationRequiredFromOperator(self),
+            ),
+        })
+        .unwrap()
+    }
+}
+
+impl IntoPayload for self_serve::orb::v1::CaptureStarted {
+    fn into_payload(self) -> Any {
+        Any::from_msg(&self_serve::orb::v1::W {
+            w: Some(self_serve::orb::v1::w::W::CaptureStarted(self)),
+        })
+        .unwrap()
+    }
+}
+
+impl IntoPayload for self_serve::orb::v1::CaptureEnded {
+    fn into_payload(self) -> Any {
+        Any::from_msg(&self_serve::orb::v1::W {
+            w: Some(self_serve::orb::v1::w::W::CaptureEnded(self)),
+        })
+        .unwrap()
+    }
+}
+
+impl IntoPayload for self_serve::orb::v1::CaptureTriggerTimeout {
+    fn into_payload(self) -> Any {
+        Any::from_msg(&self_serve::orb::v1::W {
+            w: Some(self_serve::orb::v1::w::W::CaptureTriggerTimeout(self)),
+        })
+        .unwrap()
+    }
+}
+
+impl IntoPayload for self_serve::orb::v1::SignupEnded {
+    fn into_payload(self) -> Any {
+        Any::from_msg(&self_serve::orb::v1::W {
+            w: Some(self_serve::orb::v1::w::W::SignupEnded(self)),
+        })
+        .unwrap()
+    }
+}
+
+impl IntoPayload for self_serve::app::v1::RequestState {
+    fn into_payload(self) -> Any {
+        Any::from_msg(&self_serve::app::v1::W {
+            w: Some(self_serve::app::v1::w::W::RequestState(self)),
+        })
+        .unwrap()
+    }
+}
+
+impl IntoPayload for self_serve::app::v1::StartCapture {
+    fn into_payload(self) -> Any {
+        Any::from_msg(&self_serve::app::v1::W {
+            w: Some(self_serve::app::v1::w::W::StartCapture(self)),
+        })
+        .unwrap()
+    }
+}
+
+impl IntoPayload for common::v1::AnnounceOrbId {
+    fn into_payload(self) -> Any {
+        Any::from_msg(&common::v1::W {
+            w: Some(common::v1::w::W::AnnounceOrbId(self)),
+        })
+        .unwrap()
+    }
+}
+
+impl IntoPayload for common::v1::NoState {
+    fn into_payload(self) -> Any {
+        Any::from_msg(&common::v1::W {
+            w: Some(common::v1::w::W::NoState(self)),
+        })
+        .unwrap()
+    }
+}
+
+/// Debug any message
+pub fn debug_any(any: &Option<Any>) -> String {
+    let Some(any) = any else {
+        return "None".to_string();
+    };
+    if let Some(w) = unpack_any::<self_serve::app::v1::W>(any) {
+        format!("{:?}", w)
+    } else if let Some(w) = unpack_any::<self_serve::orb::v1::W>(any) {
+        format!("{:?}", w)
+    } else if let Some(w) = unpack_any::<common::v1::W>(any) {
+        format!("{:?}", w)
+    } else {
+        "Error".to_string()
+    }
+}
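
A minimal usage sketch of the client API added by this patch (not part of the patch itself); the relay URL, token, and IDs below are illustrative placeholders, and error handling is reduced to `eyre::Result`:

use std::time::Duration;

use eyre::Result;
use orb_relay_client::client::Client;
use orb_relay_messages::{common, self_serve};

#[tokio::main]
async fn main() -> Result<()> {
    // Hypothetical values; a real orb would read these from its environment/provisioning.
    let mut orb = Client::new_as_orb(
        "https://relay.stage.orb.worldcoin.org".to_string(),
        "orb-token".to_string(),
        "orb-id".to_string(),
        "session-id".to_string(),
    );
    orb.connect().await?;

    // Fire-and-forget send; send_blocking() would additionally wait for the relay ack.
    orb.send(common::v1::AnnounceOrbId {
        orb_id: "orb-id".to_string(),
        mode_type: common::v1::announce_orb_id::ModeType::SelfServe.into(),
        hardware_type: common::v1::announce_orb_id::HardwareType::Diamond.into(),
    })
    .await?;

    // Block until the app asks the orb to start capturing, or time out after 30s.
    let _start = orb
        .wait_for_msg::<self_serve::app::v1::StartCapture>(Duration::from_secs(30))
        .await?;

    // Give pending acks 500ms, then up to 1s for the background task to stop.
    orb.graceful_shutdown(Duration::from_millis(500), Duration::from_secs(1))
        .await;
    Ok(())
}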