diff --git a/Cargo.lock b/Cargo.lock index 1434b7acdf4..bb1c97b392e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2670,6 +2670,8 @@ dependencies = [ "error-stack", "futures", "graph", + "harpc-codec", + "harpc-server", "hash-codec", "hash-graph-api", "hash-graph-authorization", @@ -2681,6 +2683,7 @@ dependencies = [ "hash-temporal-client", "hash-tracing", "mimalloc", + "multiaddr", "regex", "reqwest", "tarpc", diff --git a/apps/hash-graph/Cargo.toml b/apps/hash-graph/Cargo.toml index 8c0cf030291..470de8a83f1 100644 --- a/apps/hash-graph/Cargo.toml +++ b/apps/hash-graph/Cargo.toml @@ -48,6 +48,9 @@ tokio = { workspace = true } tokio-postgres = { workspace = true } tokio-util = { workspace = true, features = ["codec"] } tracing = { workspace = true } +harpc-server.workspace = true +multiaddr.workspace = true +harpc-codec = { workspace = true, features = ["json"] } [features] test-server = ["dep:hash-graph-test-server"] diff --git a/apps/hash-graph/src/main.rs b/apps/hash-graph/src/main.rs index 26689afafdc..5f188d4d8a1 100644 --- a/apps/hash-graph/src/main.rs +++ b/apps/hash-graph/src/main.rs @@ -1,4 +1,5 @@ #![forbid(unsafe_code)] +#![feature(async_closure)] #![expect( unreachable_pub, reason = "This is a binary but as we want to document this crate as well this should be a \ diff --git a/apps/hash-graph/src/subcommand/server.rs b/apps/hash-graph/src/subcommand/server.rs index c22ccbbb425..3290483f850 100644 --- a/apps/hash-graph/src/subcommand/server.rs +++ b/apps/hash-graph/src/subcommand/server.rs @@ -11,16 +11,22 @@ use error_stack::{Report, ResultExt as _}; use graph::{ ontology::domain_validator::DomainValidator, store::{ - DatabaseConnectionInfo, DatabasePoolConfig, FetchingPool, PostgresStorePool, StorePool as _, + DatabaseConnectionInfo, DatabasePoolConfig, FetchingPool, PostgresStorePool, StorePool, }, }; -use hash_graph_api::rest::{RestRouterDependencies, rest_api_router}; +use harpc_codec::json::JsonCodec; +use harpc_server::Server; +use hash_graph_api::{ + rest::{RestRouterDependencies, rest_api_router}, + rpc::Dependencies, +}; use hash_graph_authorization::{ - AuthorizationApi as _, NoAuthorization, + AuthorizationApi as _, AuthorizationApiPool, NoAuthorization, backend::{SpiceDbOpenApi, ZanzibarBackend as _}, zanzibar::ZanzibarClient, }; use hash_temporal_client::TemporalClientConfig; +use multiaddr::{Multiaddr, Protocol}; use regex::Regex; use reqwest::{Client, Url}; use tokio::{net::TcpListener, time::timeout}; @@ -59,6 +65,38 @@ impl TryFrom for SocketAddr { } } +#[derive(Debug, Clone, Parser)] +pub struct RpcAddress { + /// The host the RPC client is listening at. + #[clap(long, default_value = "127.0.0.1", env = "HASH_GRAPH_RPC_HOST")] + pub rpc_host: String, + + /// The port the RPC client is listening at. + #[clap(long, default_value_t = 25489, env = "HASH_GRAPH_RPC_PORT")] + pub rpc_port: u16, +} + +impl fmt::Display for RpcAddress { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{}:{}", self.rpc_host, self.rpc_port) + } +} + +impl TryFrom for SocketAddr { + type Error = Report; + + fn try_from(address: RpcAddress) -> Result> { + address + .to_string() + .parse::() + .attach_printable(address) + } +} + +#[expect( + clippy::struct_excessive_bools, + reason = "CLI arguments are boolean flags." +)] #[derive(Debug, Parser)] pub struct ServerArgs { #[clap(flatten)] @@ -67,10 +105,18 @@ pub struct ServerArgs { #[clap(flatten)] pub pool_config: DatabasePoolConfig, - /// The address the REST client is listening at. + /// The address the REST server is listening at. #[clap(flatten)] pub api_address: ApiAddress, + /// Enable the experimental RPC server. + #[clap(long, default_value_t = false, env = "HASH_GRAPH_RPC_ENABLED")] + pub rpc_enabled: bool, + + /// The address the RPC server is listening at. + #[clap(flatten)] + pub rpc_address: RpcAddress, + /// The address for the type fetcher RPC server is listening at. #[clap(flatten)] pub type_fetcher_address: TypeFetcherAddress, @@ -134,6 +180,54 @@ pub struct ServerArgs { pub temporal_port: u16, } +fn server_rpc( + address: RpcAddress, + dependencies: Dependencies, +) -> Result<(), Report> +where + S: StorePool + Send + Sync + 'static, + A: AuthorizationApiPool + Send + Sync + 'static, +{ + let server = Server::new(harpc_server::ServerConfig::default()).change_context(GraphError)?; + + let (router, task) = hash_graph_api::rpc::rpc_router( + Dependencies { + store: dependencies.store, + authorization_api: dependencies.authorization_api, + temporal_client: dependencies.temporal_client, + codec: JsonCodec, + }, + server.events(), + ); + + tokio::spawn(task.into_future()); + + let socket_address: SocketAddr = SocketAddr::try_from(address).change_context(GraphError)?; + let mut address = Multiaddr::empty(); + match socket_address { + SocketAddr::V4(v4) => { + address.push(Protocol::Ip4(*v4.ip())); + address.push(Protocol::Tcp(v4.port())); + } + SocketAddr::V6(v6) => { + address.push(Protocol::Ip6(*v6.ip())); + address.push(Protocol::Tcp(v6.port())); + } + } + + #[expect(clippy::significant_drop_tightening, reason = "false positive")] + tokio::spawn(async move { + let stream = server + .listen(address) + .await + .expect("server should be able to listen on address"); + + harpc_server::serve::serve(stream, router).await; + }); + + Ok(()) +} + pub async fn server(args: ServerArgs) -> Result<(), Report> { if args.healthcheck { return wait_healthcheck( @@ -186,24 +280,42 @@ pub async fn server(args: ServerArgs) -> Result<(), Report> { let mut zanzibar_client = ZanzibarClient::new(spicedb_client); zanzibar_client.seed().await.change_context(GraphError)?; - let router = rest_api_router(RestRouterDependencies { - store: Arc::new(pool), - authorization_api: Arc::new(zanzibar_client), - domain_regex: DomainValidator::new(args.allowed_url_domain), - temporal_client: if let Some(host) = args.temporal_host { - Some( - TemporalClientConfig::new( - Url::from_str(&format!("{}:{}", host, args.temporal_port)) - .change_context(GraphError)?, - ) - .change_context(GraphError)? - .await - .change_context(GraphError)?, + let temporal_client_fn = async |host: Option, port: u16| { + if let Some(host) = host { + TemporalClientConfig::new( + Url::from_str(&format!("{host}:{port}")).change_context(GraphError)?, ) + .change_context(GraphError)? + .await + .map(Some) + .change_context(GraphError) } else { - None - }, - }); + Ok(None) + } + }; + + let router = { + let dependencies = RestRouterDependencies { + store: Arc::new(pool), + authorization_api: Arc::new(zanzibar_client), + domain_regex: DomainValidator::new(args.allowed_url_domain), + temporal_client: temporal_client_fn(args.temporal_host.clone(), args.temporal_port) + .await?, + }; + + if args.rpc_enabled { + tracing::info!("Starting RPC server..."); + + server_rpc(args.rpc_address, Dependencies { + store: Arc::clone(&dependencies.store), + authorization_api: Arc::clone(&dependencies.authorization_api), + temporal_client: temporal_client_fn(args.temporal_host, args.temporal_port).await?, + codec: (), + })?; + } + + rest_api_router(dependencies) + }; tracing::info!("Listening on {}", args.api_address); axum::serve( diff --git a/libs/@local/graph/api/src/rpc/mod.rs b/libs/@local/graph/api/src/rpc/mod.rs index 2508d384869..b523cbb83cb 100644 --- a/libs/@local/graph/api/src/rpc/mod.rs +++ b/libs/@local/graph/api/src/rpc/mod.rs @@ -71,7 +71,10 @@ pub struct Dependencies { pub fn rpc_router( dependencies: Dependencies, notifications: N, -) -> (Router>, Task) +) -> ( + Router + Send>, + Task, +) where S: StorePool + Send + Sync + 'static, A: AuthorizationApiPool + Send + Sync + 'static, diff --git a/libs/@local/harpc/server/src/lib.rs b/libs/@local/harpc/server/src/lib.rs index ed068aa1d88..0a56313cb99 100644 --- a/libs/@local/harpc/server/src/lib.rs +++ b/libs/@local/harpc/server/src/lib.rs @@ -19,9 +19,10 @@ use core::{ use error_stack::{Report, ResultExt as _}; use futures::{Stream, StreamExt as _, stream::FusedStream}; +pub use harpc_net::{session::server::SessionConfig, transport::TransportConfig}; use harpc_net::{ - session::server::{EventStream, ListenStream, SessionConfig, SessionLayer, Transaction}, - transport::{TransportConfig, TransportLayer}, + session::server::{EventStream, ListenStream, SessionLayer, Transaction}, + transport::TransportLayer, }; use multiaddr::Multiaddr; use tokio_util::sync::{CancellationToken, DropGuard};