Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unix Domain Socket support #2

Merged
merged 2 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ version = "0.1.0"
edition = "2021"
build = "build.rs"

[features]
default = ["unix-domain-sockets"]
unix-domain-sockets = ["hyper/socket2"]

[dependencies]
axum = { version = "0.6.20", features = ["http2"] }
ctrlc = "3.4.1"
Expand Down
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Multiple Hyper servers are spawned on different endpoints to showcase the use of
and ports while reusing the same server components. A Hyper service is used to switch the incoming traffic based on the
`content-type` header and if `application/grpc` is detected, traffic is forwarded to the Tonic server; all other
cases forward to Axum. This allows for transparent use of HTTP/1.1 and HTTP/2 (prior knowledge), as well
as ALPN on the TLS-enabled ports.
as ALPN on the TLS-enabled ports. On Unixoids, Unix Domain Sockets are supported as well.

This project uses:

Expand All @@ -31,6 +31,12 @@ curl -v http://127.1.0.1:36849/
curl --http2-prior-knowledge --insecure -vv http://127.0.0.1:36849/
```

### Unix Domain Sockets

```shell
curl -v --unix-socket /tmp/cohosting.sock http://localhost:36849/
```

### TLS with ALPN

```shell
Expand Down Expand Up @@ -62,6 +68,15 @@ grpcurl --plaintext --use-reflection -d '{ "message": "World" }' 127.0.0.1:36849
grpcurl --insecure --use-reflection -d '{ "message": "World" }' 127.0.0.1:36850 example.YourService/YourMethod
```

### Unix Domain Sockets

For UDS to work with gRPC, the `:authority` header needs to be sent. In `grpcurl`, the `--authority=xyz` flag
is used for that:

```shell
grpcurl --unix --plaintext --use-reflection --authority localhost -d '{ "message": "World" }' /tmp/cohosting.sock example.YourService/YourMethod
```

## Recommended reads

The _Combining Axum, Hyper, Tonic, and Tower for hybrid web/gRPC apps_ series:
Expand Down
2 changes: 1 addition & 1 deletion src/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ fn is_grpc_request<Body>(req: &Request<Body>) -> bool {
}

// The content-type header needs to start with `application/grpc`
const EXPECTED: &'static [u8] = b"application/grpc";
const EXPECTED: &[u8] = b"application/grpc";
if let Some(content_type) = req
.headers()
.get(hyper::header::CONTENT_TYPE)
Expand Down
57 changes: 51 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
mod certs;
mod hybrid;

#[cfg(all(unix, feature = "unix-domain-sockets"))]
mod sockets;

use crate::hybrid::HybridMakeService;
#[cfg(all(unix, feature = "unix-domain-sockets"))]
use crate::sockets::UnixDomainSocket;
use axum::routing::IntoMakeService;
use axum::Router;
use futures_util::StreamExt;
Expand Down Expand Up @@ -62,27 +67,35 @@ async fn main() -> ExitCode {
// Combine web and gRPC into a hybrid service.
let service = HybridMakeService::new(axum_make_svc, grpc_service);

// Combine the server futures.
let mut futures = JoinSet::new();

// Bind first hyper HTTP server.
let socket_addr =
SocketAddr::from_str("127.0.0.1:36849").expect("failed to parse socket address");
let server_a = create_hyper_server(service.clone(), socket_addr, &shutdown_tx);
futures.spawn(server_a);

// Bind second hyper HTTP server.
let socket_addr =
SocketAddr::from_str("127.1.0.1:36849").expect("failed to parse socket address");
let server_b = create_hyper_server(service.clone(), socket_addr, &shutdown_tx);
futures.spawn(server_b);

// Bind third hyper HTTP server (using TLS).
let socket_addr =
SocketAddr::from_str("127.0.0.1:36850").expect("failed to parse socket address");
let server_c = create_hyper_server_tls(service, socket_addr, &shutdown_tx);

// Combine the server futures.
let mut futures = JoinSet::new();
futures.spawn(server_a);
futures.spawn(server_b);
let server_c = create_hyper_server_tls(service.clone(), socket_addr, &shutdown_tx);
futures.spawn(server_c);

// Bind fourth server to Unix Domain Socket.
#[cfg(all(unix, feature = "unix-domain-sockets"))]
{
let socket_addr = std::path::PathBuf::from("/tmp/cohosting.sock");
let server_d = create_hyper_server_uds(service, socket_addr, &shutdown_tx);
futures.spawn(server_d);
}

// Wait for all servers to stop.
info!("Starting servers");
while let Some(join_result) = futures.join_next().await {
Expand Down Expand Up @@ -136,6 +149,7 @@ fn create_hyper_server(
socket_addr: SocketAddr,
shutdown_tx: &Sender<()>,
) -> impl Future<Output = Result<(), hyper::Error>> + Send {
info!("Binding server to {}", socket_addr);
Server::try_bind(&socket_addr)
.map_err(|e| {
error!(
Expand All @@ -158,11 +172,42 @@ fn create_hyper_server(
})
}

#[cfg(all(unix, feature = "unix-domain-sockets"))]
fn create_hyper_server_uds(
service: HybridMakeService<IntoMakeService<Router>, Routes>,
socket_path: std::path::PathBuf,
shutdown_tx: &Sender<()>,
) -> impl Future<Output = Result<(), hyper::Error>> + Send {
info!("Binding server to {}", socket_path.display());

let incoming = match UnixDomainSocket::new(&socket_path) {
Ok(listener) => listener,
Err(e) => {
error!(
"Failed to bind to socket {addr}: {error}",
addr = socket_path.display(),
error = e
);
panic!("{}", e);
}
};
Server::builder(incoming)
.serve(service)
.with_graceful_shutdown({
let mut shutdown_rx = shutdown_tx.subscribe();
async move {
shutdown_rx.recv().await.ok();
info!("Graceful shutdown initiated on Hyper server")
}
})
}

fn create_hyper_server_tls(
service: HybridMakeService<IntoMakeService<Router>, Routes>,
socket_addr: SocketAddr,
shutdown_tx: &Sender<()>,
) -> impl Future<Output = Result<(), hyper::Error>> + Send {
info!("Binding server to {}", socket_addr);
let listener = AddrIncoming::bind(&socket_addr)
.map_err(|e| {
error!(
Expand Down
48 changes: 48 additions & 0 deletions src/sockets.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use hyper::server::accept::Accept;
use log::debug;
use std::fs;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::net::UnixListener;

pub struct UnixDomainSocket {
inner: UnixListener,
}

impl UnixDomainSocket {
pub fn new(path: &Path) -> std::io::Result<Self> {
let listener = UnixListener::bind(path)?;
Ok(Self { inner: listener })
}
}

impl Drop for UnixDomainSocket {
fn drop(&mut self) {
let addr = self
.inner
.local_addr()
.expect("failed to get local address from listener");
let path = addr
.as_pathname()
.expect("failed to get path name from local address");
debug!("Removing socket file {path}", path = path.display());
fs::remove_file(path).ok();
}
}

impl Accept for UnixDomainSocket {
type Conn = tokio::net::UnixStream;
type Error = std::io::Error;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
match self.inner.poll_accept(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok((socket, _addr))) => Poll::Ready(Some(Ok(socket))),
Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
}
}
}
Loading