diff --git a/tap_aggregator/Cargo.toml b/tap_aggregator/Cargo.toml index 3e3ae63..aca71e5 100644 --- a/tap_aggregator/Cargo.toml +++ b/tap_aggregator/Cargo.toml @@ -35,7 +35,14 @@ axum = { version = "0.7.5", features = [ futures-util = "0.3.28" lazy_static = "1.4.0" ruint = "1.10.1" -tower = { version = "0.4", features = ["util"] } +tower = { version = "0.4", features = ["util", "steer"] } +tonic = { version = "0.12.3", features = ["transport", "zstd"] } +prost = "0.13.3" +hyper = { version = "1", features = ["full"] } + +[build-dependencies] +tonic-build = "0.12.3" + [dev-dependencies] jsonrpsee = { workspace = true, features = ["http-client", "jsonrpsee-core"] } diff --git a/tap_aggregator/build.rs b/tap_aggregator/build.rs new file mode 100644 index 0000000..9e7a986 --- /dev/null +++ b/tap_aggregator/build.rs @@ -0,0 +1,11 @@ +// Copyright 2023-, Semiotic AI, Inc. +// SPDX-License-Identifier: Apache-2.0 + +fn main() -> Result<(), Box> { + println!("Running build.rs..."); + let out_dir = std::env::var("OUT_DIR").expect("OUT_DIR not set by Cargo"); + println!("OUT_DIR: {}", out_dir); // This should print the output directory + + tonic_build::compile_protos("./proto/tap_aggregator.proto")?; + Ok(()) +} diff --git a/tap_aggregator/proto/tap_aggregator.proto b/tap_aggregator/proto/tap_aggregator.proto index bdc3c85..4267936 100644 --- a/tap_aggregator/proto/tap_aggregator.proto +++ b/tap_aggregator/proto/tap_aggregator.proto @@ -1,3 +1,6 @@ +// Copyright 2023-, Semiotic AI, Inc. +// SPDX-License-Identifier: Apache-2.0 + syntax = "proto3"; package tap_aggregator.v1; @@ -42,4 +45,4 @@ message Uint128 { uint64 high = 1; // Lowest 64 bits of a 128 bit number. uint64 low = 2; -} \ No newline at end of file +} diff --git a/tap_aggregator/src/tap_aggregator.rs b/tap_aggregator/src/grpc.rs similarity index 68% rename from tap_aggregator/src/tap_aggregator.rs rename to tap_aggregator/src/grpc.rs index db3ae77..bcc7e27 100644 --- a/tap_aggregator/src/tap_aggregator.rs +++ b/tap_aggregator/src/grpc.rs @@ -1,3 +1,6 @@ +// Copyright 2023-, Semiotic AI, Inc. +// SPDX-License-Identifier: Apache-2.0 + use anyhow::anyhow; use tap_core::signed_message::EIP712SignedMessage; @@ -28,6 +31,26 @@ impl TryFrom for tap_core::receipt::SignedReceipt { } } +impl From for Receipt { + fn from(value: tap_core::receipt::Receipt) -> Self { + Self { + allocation_id: value.allocation_id.as_slice().to_vec(), + timestamp_ns: value.timestamp_ns, + nonce: value.nonce, + value: Some(value.value.into()), + } + } +} + +impl From for SignedReceipt { + fn from(value: tap_core::receipt::SignedReceipt) -> Self { + Self { + message: Some(value.message.into()), + signature: value.signature.as_bytes().to_vec(), + } + } +} + impl TryFrom for EIP712SignedMessage { type Error = anyhow::Error; fn try_from(voucher: SignedRav) -> Result { @@ -87,3 +110,26 @@ impl From for Uint128 { Self { high, low } } } + +impl RavRequest { + pub fn new( + receipts: Vec, + previous_rav: Option, + ) -> Self { + Self { + receipts: receipts.into_iter().map(Into::into).collect(), + previous_rav: previous_rav.map(Into::into), + } + } +} + +impl RavResponse { + pub fn signed_rav(mut self) -> anyhow::Result { + let signed_rav: tap_core::rav::SignedRAV = self + .rav + .take() + .ok_or(anyhow!("Couldn't find rav"))? + .try_into()?; + Ok(signed_rav) + } +} diff --git a/tap_aggregator/src/lib.rs b/tap_aggregator/src/lib.rs index e929a1e..6746f3a 100644 --- a/tap_aggregator/src/lib.rs +++ b/tap_aggregator/src/lib.rs @@ -4,6 +4,7 @@ pub mod aggregator; pub mod api_versioning; pub mod error_codes; +pub mod grpc; pub mod jsonrpsee_helpers; pub mod metrics; pub mod server; diff --git a/tap_aggregator/src/main.rs b/tap_aggregator/src/main.rs index 6550c33..556caba 100644 --- a/tap_aggregator/src/main.rs +++ b/tap_aggregator/src/main.rs @@ -3,22 +3,15 @@ #![doc = include_str!("../README.md")] -use std::borrow::Cow; -use std::collections::HashSet; -use std::str::FromStr; - -use alloy::dyn_abi::Eip712Domain; -use alloy::primitives::Address; -use alloy::primitives::FixedBytes; -use alloy::signers::local::PrivateKeySigner; +use std::{collections::HashSet, str::FromStr}; + +use alloy::{dyn_abi::Eip712Domain, primitives::Address, signers::local::PrivateKeySigner}; use anyhow::Result; use clap::Parser; -use ruint::aliases::U256; -use tokio::signal::unix::{signal, SignalKind}; - use log::{debug, info}; -use tap_aggregator::metrics; -use tap_aggregator::server; +use tap_core::tap_eip712_domain; + +use tap_aggregator::{metrics, server}; #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] @@ -126,22 +119,10 @@ async fn main() -> Result<()> { .await?; info!("Server started. Listening on port {}.", args.port); - // Have tokio wait for SIGTERM or SIGINT. - let mut signal_sigint = signal(SignalKind::interrupt())?; - let mut signal_sigterm = signal(SignalKind::terminate())?; - tokio::select! { - _ = signal_sigint.recv() => debug!("Received SIGINT."), - _ = signal_sigterm.recv() => debug!("Received SIGTERM."), - } + let _ = handle.await; // If we're here, we've received a signal to exit. info!("Shutting down..."); - - // Stop the server and wait for it to finish gracefully. - handle.stop()?; - handle.stopped().await; - - debug!("Goodbye!"); Ok(()) } @@ -149,14 +130,11 @@ fn create_eip712_domain(args: &Args) -> Result { // Transfrom the args into the types expected by Eip712Domain::new(). // Transform optional strings into optional Cow. - let name = args.domain_name.clone().map(Cow::Owned); - let version = args.domain_version.clone().map(Cow::Owned); - // Transform optional strings into optional U256. if args.domain_chain_id.is_some() { debug!("Parsing domain chain ID..."); } - let chain_id: Option = args + let chain_id: Option = args .domain_chain_id .as_ref() .map(|s| s.parse()) @@ -165,17 +143,13 @@ fn create_eip712_domain(args: &Args) -> Result { if args.domain_salt.is_some() { debug!("Parsing domain salt..."); } - let salt: Option> = args.domain_salt.as_ref().map(|s| s.parse()).transpose()?; // Transform optional strings into optional Address. let verifying_contract: Option
= args.domain_verifying_contract; // Create the EIP-712 domain separator. - Ok(Eip712Domain::new( - name, - version, - chain_id, - verifying_contract, - salt, + Ok(tap_eip712_domain( + chain_id.unwrap_or(1), + verifying_contract.unwrap_or_default(), )) } diff --git a/tap_aggregator/src/server.rs b/tap_aggregator/src/server.rs index 0ff03d3..8e82fd5 100644 --- a/tap_aggregator/src/server.rs +++ b/tap_aggregator/src/server.rs @@ -3,23 +3,38 @@ use std::{collections::HashSet, str::FromStr}; -use alloy::dyn_abi::Eip712Domain; -use alloy::primitives::Address; -use alloy::signers::local::PrivateKeySigner; +use alloy::{dyn_abi::Eip712Domain, primitives::Address, signers::local::PrivateKeySigner}; use anyhow::Result; -use jsonrpsee::{proc_macros::rpc, server::ServerBuilder, server::ServerHandle}; +use axum::{error_handling::HandleError, routing::post_service, BoxError, Router}; +use hyper::StatusCode; +use jsonrpsee::{ + proc_macros::rpc, + server::{ServerBuilder, ServerHandle, TowerService}, +}; use lazy_static::lazy_static; +use log::info; use prometheus::{register_counter, register_int_counter, Counter, IntCounter}; - -use crate::aggregator::check_and_aggregate_receipts; -use crate::api_versioning::{ - tap_rpc_api_versions_info, TapRpcApiVersion, TapRpcApiVersionsInfo, - TAP_RPC_API_VERSIONS_DEPRECATED, -}; -use crate::error_codes::{JsonRpcErrorCode, JsonRpcWarningCode}; -use crate::jsonrpsee_helpers::{JsonRpcError, JsonRpcResponse, JsonRpcResult, JsonRpcWarning}; use tap_core::{ - rav::ReceiptAggregateVoucher, receipt::Receipt, signed_message::EIP712SignedMessage, + rav::ReceiptAggregateVoucher, + receipt::{Receipt, SignedReceipt}, + signed_message::EIP712SignedMessage, +}; +use tokio::{net::TcpListener, signal, task::JoinHandle}; +use tonic::{codec::CompressionEncoding, service::Routes, Request, Response, Status}; +use tower::{layer::util::Identity, make::Shared}; + +use crate::{ + aggregator::check_and_aggregate_receipts, + api_versioning::{ + tap_rpc_api_versions_info, TapRpcApiVersion, TapRpcApiVersionsInfo, + TAP_RPC_API_VERSIONS_DEPRECATED, + }, + error_codes::{JsonRpcErrorCode, JsonRpcWarningCode}, + grpc::{ + tap_aggregator_server::{TapAggregator, TapAggregatorServer}, + RavRequest, RavResponse, + }, + jsonrpsee_helpers::{JsonRpcError, JsonRpcResponse, JsonRpcResult, JsonRpcWarning}, }; // Register the metrics into the global metrics registry. @@ -29,37 +44,27 @@ lazy_static! { "Number of successful receipt aggregation requests." ) .unwrap(); -} -lazy_static! { static ref AGGREGATION_FAILURE_COUNTER: IntCounter = register_int_counter!( "aggregation_failure_count", "Number of failed receipt aggregation requests (for any reason)." ) .unwrap(); -} -lazy_static! { static ref DEPRECATION_WARNING_COUNT: IntCounter = register_int_counter!( "deprecation_warning_count", "Number of deprecation warnings sent to clients." ) .unwrap(); -} -lazy_static! { static ref VERSION_ERROR_COUNT: IntCounter = register_int_counter!( "version_error_count", "Number of API version errors sent to clients." ) .unwrap(); -} -lazy_static! { static ref TOTAL_AGGREGATED_RECEIPTS: IntCounter = register_int_counter!( "total_aggregated_receipts", "Total number of receipts successfully aggregated." ) .unwrap(); -} // Using float for the GRT value because it can somewhat easily exceed the maximum value of int64. -lazy_static! { static ref TOTAL_GRT_AGGREGATED: Counter = register_counter!( "total_aggregated_grt", "Total successfully aggregated GRT value (wei)." @@ -90,6 +95,7 @@ pub trait Rpc { ) -> JsonRpcResult>; } +#[derive(Clone)] struct RpcImpl { wallet: PrivateKeySigner, accepted_addresses: HashSet
, @@ -171,6 +177,54 @@ fn aggregate_receipts_( } } +#[tonic::async_trait] +impl TapAggregator for RpcImpl { + async fn aggregate_receipts( + &self, + request: Request, + ) -> Result, Status> { + let rav_request = request.into_inner(); + let receipts: Vec = rav_request + .receipts + .into_iter() + .map(TryFrom::try_from) + .collect::>() + .map_err(|_| Status::invalid_argument("Error while getting list of signed_receipts"))?; + + let previous_rav = rav_request + .previous_rav + .map(TryFrom::try_from) + .transpose() + .map_err(|_| Status::invalid_argument("Error while getting previous rav"))?; + + let receipts_grt: u128 = receipts.iter().map(|r| r.message.value).sum(); + let receipts_count: u64 = receipts.len() as u64; + + match check_and_aggregate_receipts( + &self.domain_separator, + receipts.as_slice(), + previous_rav, + &self.wallet, + &self.accepted_addresses, + ) { + Ok(res) => { + TOTAL_GRT_AGGREGATED.inc_by(receipts_grt as f64); + TOTAL_AGGREGATED_RECEIPTS.inc_by(receipts_count); + AGGREGATION_SUCCESS_COUNTER.inc(); + + let response = RavResponse { + rav: Some(res.into()), + }; + Ok(Response::new(response)) + } + Err(e) => { + AGGREGATION_FAILURE_COUNTER.inc(); + Err(Status::failed_precondition(e.to_string())) + } + } + } +} + impl RpcServer for RpcImpl { fn api_versions(&self) -> JsonRpcResult { Ok(JsonRpcResponse::ok(tap_rpc_api_versions_info())) @@ -216,27 +270,121 @@ pub async fn run_server( max_request_body_size: u32, max_response_body_size: u32, max_concurrent_connections: u32, -) -> Result<(ServerHandle, std::net::SocketAddr)> { +) -> Result<(JoinHandle<()>, std::net::SocketAddr)> { // Setting up the JSON RPC server - println!("Starting server..."); - let server = ServerBuilder::new() - .max_request_body_size(max_request_body_size) - .max_response_body_size(max_response_body_size) - .max_connections(max_concurrent_connections) - .http_only() - .build(format!("0.0.0.0:{}", port)) - .await?; - let addr = server.local_addr()?; - println!("Listening on: {}", addr); let rpc_impl = RpcImpl { wallet, accepted_addresses, domain_separator, }; - let handle = server.start(rpc_impl.into_rpc()); + let (json_rpc_service, _) = create_json_rpc_service( + rpc_impl.clone(), + max_request_body_size, + max_response_body_size, + max_concurrent_connections, + )?; + + async fn handle_anyhow_error(err: BoxError) -> (StatusCode, String) { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Something went wrong: {err}"), + ) + } + let json_rpc_router = Router::new().route_service( + "/", + HandleError::new(post_service(json_rpc_service), handle_anyhow_error), + ); + + let grpc_service = create_grpc_service(rpc_impl)?; + + let service = tower::steer::Steer::new( + [json_rpc_router, grpc_service.into_axum_router()], + |req: &hyper::Request<_>, _services: &[_]| { + if req + .headers() + .get(hyper::header::CONTENT_TYPE) + .map(|content_type| content_type.as_bytes()) + .filter(|content_type| content_type.starts_with(b"application/grpc")) + .is_some() + { + // route to the gRPC service (second service element) when the + // header is set + 1 + } else { + // otherwise route to the REST service + 0 + } + }, + ); + + // Create a `TcpListener` using tokio. + let listener = TcpListener::bind(&format!("0.0.0.0:{}", port)) + .await + .expect("Failed to bind to tap-aggregator port"); + + let addr = listener.local_addr()?; + let handle = tokio::spawn(async move { + if let Err(e) = axum::serve(listener, Shared::new(service)) + .with_graceful_shutdown(shutdown_handler()) + .await + { + log::error!("Tap Aggregator error: {e}"); + } + }); + Ok((handle, addr)) } +/// Graceful shutdown handler +async fn shutdown_handler() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("Failed to install Ctrl+C handler"); + }; + + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("Failed to install signal handler") + .recv() + .await; + }; + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + + info!("Signal received, starting graceful shutdown"); +} + +fn create_grpc_service(rpc_impl: RpcImpl) -> Result { + let grpc_service = Routes::new( + TapAggregatorServer::new(rpc_impl).accept_compressed(CompressionEncoding::Zstd), + ) + .prepare(); + + Ok(grpc_service) +} + +fn create_json_rpc_service( + rpc_impl: RpcImpl, + max_request_body_size: u32, + max_response_body_size: u32, + max_concurrent_connections: u32, +) -> Result<(TowerService, ServerHandle)> { + let service_builder = ServerBuilder::new() + .max_request_body_size(max_request_body_size) + .max_response_body_size(max_response_body_size) + .max_connections(max_concurrent_connections) + .http_only() + .to_service_builder(); + use jsonrpsee::server::stop_channel; + let (stop_handle, server_handle) = stop_channel(); + let handle = service_builder.build(rpc_impl.into_rpc(), stop_handle); + Ok((handle, server_handle)) +} + #[cfg(test)] #[allow(clippy::too_many_arguments)] mod tests { @@ -330,8 +478,7 @@ mod tests { .await .unwrap(); - handle.stop().unwrap(); - handle.stopped().await; + handle.abort(); } #[rstest] @@ -411,8 +558,7 @@ mod tests { assert!(remote_rav.recover_signer(&domain_separator).unwrap() == keys_main.address); - handle.stop().unwrap(); - handle.stopped().await; + handle.abort(); } #[rstest] @@ -501,8 +647,7 @@ mod tests { assert!(rav.recover_signer(&domain_separator).unwrap() == keys_main.address); - handle.stop().unwrap(); - handle.stopped().await; + handle.abort(); } #[rstest] @@ -576,8 +721,7 @@ mod tests { _ => panic!("Expected data in error"), } - handle.stop().unwrap(); - handle.stopped().await; + handle.abort(); } /// Test that the server returns an error when the request size exceeds the limit. @@ -674,7 +818,6 @@ mod tests { // Make sure the error is a HTTP 413 Content Too Large assert!(res.unwrap_err().to_string().contains("413")); - handle.stop().unwrap(); - handle.stopped().await; + handle.abort(); } } diff --git a/tap_aggregator/tests/aggregate_test.rs b/tap_aggregator/tests/aggregate_test.rs new file mode 100644 index 0000000..e752197 --- /dev/null +++ b/tap_aggregator/tests/aggregate_test.rs @@ -0,0 +1,88 @@ +// Copyright 2023-, Semiotic AI, Inc. +// SPDX-License-Identifier: Apache-2.0 + +use std::{collections::HashSet, str::FromStr}; + +use alloy::{primitives::Address, signers::local::PrivateKeySigner}; + +use jsonrpsee::{core::client::ClientT, http_client::HttpClientBuilder, rpc_params}; +use tap_aggregator::{ + grpc::{tap_aggregator_client::TapAggregatorClient, RavRequest}, + jsonrpsee_helpers::JsonRpcResponse, + server, +}; +use tap_core::{ + rav::ReceiptAggregateVoucher, receipt::Receipt, signed_message::EIP712SignedMessage, + tap_eip712_domain, +}; +use tonic::codec::CompressionEncoding; + +#[tokio::test] +async fn aggregation_test() { + let domain_separator = tap_eip712_domain(1, Address::ZERO); + + let wallet = PrivateKeySigner::random(); + + let max_request_body_size = 1024 * 100; + let max_response_body_size = 1024 * 100; + let max_concurrent_connections = 1; + + let accepted_addresses = HashSet::from([wallet.address()]); + + let (join_handle, local_addr) = server::run_server( + 0, + wallet.clone(), + accepted_addresses, + domain_separator.clone(), + max_request_body_size, + max_response_body_size, + max_concurrent_connections, + ) + .await + .unwrap(); + + let endpoint = format!("http://127.0.0.1:{}", local_addr.port()); + + let mut client = TapAggregatorClient::connect(endpoint.clone()) + .await + .unwrap() + .send_compressed(CompressionEncoding::Zstd); + + let allocation_id = Address::from_str("0xabababababababababababababababababababab").unwrap(); + + // Create receipts + let mut receipts = Vec::new(); + for value in 50..60 { + receipts.push( + EIP712SignedMessage::new( + &domain_separator, + Receipt::new(allocation_id, value).unwrap(), + &wallet, + ) + .unwrap(), + ); + } + + let rav_request = RavRequest::new(receipts.clone(), None); + let res = client.aggregate_receipts(rav_request).await.unwrap(); + let signed_rav: tap_core::rav::SignedRAV = res.into_inner().signed_rav().unwrap(); + + let sender_aggregator = HttpClientBuilder::default().build(&endpoint).unwrap(); + + let previous_rav: Option = None; + + let response: JsonRpcResponse> = sender_aggregator + .request( + "aggregate_receipts", + rpc_params!( + "0.0", // TODO: Set the version in a smarter place. + receipts, + previous_rav + ), + ) + .await + .unwrap(); + let response = response.data; + assert_eq!(signed_rav, response); + join_handle.abort(); +} diff --git a/tap_integration_tests/tests/showcase.rs b/tap_integration_tests/tests/showcase.rs index 19210b3..6c89f05 100644 --- a/tap_integration_tests/tests/showcase.rs +++ b/tap_integration_tests/tests/showcase.rs @@ -35,6 +35,7 @@ use tap_core::{ signed_message::{EIP712SignedMessage, MessageId}, tap_eip712_domain, }; +use tokio::task::JoinHandle; use crate::indexer_mock; @@ -345,7 +346,7 @@ async fn single_indexer_test_server( indexer_1_context: ContextFixture, available_escrow: u128, receipt_threshold_1: u64, -) -> Result<(ServerHandle, SocketAddr, ServerHandle, SocketAddr)> { +) -> Result<(ServerHandle, SocketAddr, JoinHandle<()>, SocketAddr)> { let sender_id = keys_sender.address(); let (sender_aggregator_handle, sender_aggregator_addr) = start_sender_aggregator( keys_sender, @@ -390,7 +391,7 @@ async fn two_indexers_test_servers( SocketAddr, ServerHandle, SocketAddr, - ServerHandle, + JoinHandle<()>, SocketAddr, )> { let sender_id = keys_sender.address(); @@ -454,7 +455,7 @@ async fn single_indexer_wrong_sender_test_server( indexer_1_context: ContextFixture, available_escrow: u128, receipt_threshold_1: u64, -) -> Result<(ServerHandle, SocketAddr, ServerHandle, SocketAddr)> { +) -> Result<(ServerHandle, SocketAddr, JoinHandle<()>, SocketAddr)> { let sender_id = wrong_keys_sender.address(); let (sender_aggregator_handle, sender_aggregator_addr) = start_sender_aggregator( wrong_keys_sender, @@ -491,7 +492,7 @@ async fn single_indexer_wrong_sender_test_server( #[tokio::test] async fn test_manager_one_indexer( #[future] single_indexer_test_server: Result< - (ServerHandle, SocketAddr, ServerHandle, SocketAddr), + (ServerHandle, SocketAddr, JoinHandle<()>, SocketAddr), Error, >, requests_1: Vec>, @@ -522,7 +523,7 @@ async fn test_manager_two_indexers( SocketAddr, ServerHandle, SocketAddr, - ServerHandle, + JoinHandle<()>, SocketAddr, ), Error, @@ -559,7 +560,7 @@ async fn test_manager_two_indexers( #[tokio::test] async fn test_manager_wrong_aggregator_keys( #[future] single_indexer_wrong_sender_test_server: Result< - (ServerHandle, SocketAddr, ServerHandle, SocketAddr), + (ServerHandle, SocketAddr, JoinHandle<()>, SocketAddr), Error, >, requests_1: Vec>, @@ -601,7 +602,7 @@ async fn test_manager_wrong_aggregator_keys( #[tokio::test] async fn test_manager_wrong_requestor_keys( #[future] single_indexer_test_server: Result< - (ServerHandle, SocketAddr, ServerHandle, SocketAddr), + (ServerHandle, SocketAddr, JoinHandle<()>, SocketAddr), Error, >, wrong_requests: Vec>, @@ -631,7 +632,7 @@ async fn test_tap_manager_rav_timestamp_cuttoff( SocketAddr, ServerHandle, SocketAddr, - ServerHandle, + JoinHandle<()>, SocketAddr, ), Error, @@ -765,7 +766,7 @@ async fn test_tap_aggregator_rav_timestamp_cuttoff( } assert!(expected_value == second_rav_response.data.message.valueAggregate); - sender_handle.stop()?; + sender_handle.abort(); Ok(()) } @@ -833,7 +834,7 @@ async fn start_sender_aggregator( http_request_size_limit: u32, http_response_size_limit: u32, http_max_concurrent_connections: u32, -) -> Result<(ServerHandle, SocketAddr)> { +) -> Result<(JoinHandle<()>, SocketAddr)> { let http_port = { let listener = TcpListener::bind("127.0.0.1:0")?; listener.local_addr()?.port()