Skip to content

Commit

Permalink
Extract functions and add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sunsided committed Nov 6, 2023
1 parent 8d8b019 commit c434f09
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 31 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ This project uses:
```shell
curl -v http://127.0.0.1:36849/
curl -v http://127.1.0.1:36849/
curl --http2-prior-knowledge --insecure -vv https://127.0.0.1:36849/
curl --http2-prior-knowledge --insecure -vv http://127.0.0.1:36849/
```

### TLS with ALPN
Expand Down
115 changes: 95 additions & 20 deletions src/hybrid.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,59 @@
use axum::http::Response;
use hyper::body::HttpBody;
use hyper::service::Service;
use hyper::{Body, HeaderMap, Request};
use hyper::{Body, HeaderMap, Request, Version};
use pin_project::pin_project;
use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use std::task::Poll;

pub fn hybrid<Axum, Grpc>(make_web: Axum, grpc: Grpc) -> HybridMakeService<Axum, Grpc> {
HybridMakeService { make_web, grpc }
}

#[derive(Clone)]
pub struct HybridMakeService<Axum, Grpc> {
make_web: Axum,
pub struct HybridMakeService<Web, Grpc> {
make_web: Web,
grpc: Grpc,
}

impl<Web, Grpc> HybridMakeService<Web, Grpc> {
pub const fn new(make_web: Web, grpc: Grpc) -> Self {
HybridMakeService { make_web, grpc }
}
}

/// A Tower [`Service`] implementing the factory of a [`HybridService`].
impl<ConnInfo, MakeWeb, Grpc> Service<ConnInfo> for HybridMakeService<MakeWeb, Grpc>
where
MakeWeb: Service<ConnInfo>,
Grpc: Clone,
{
/// Responses given by the service.
type Response = HybridService<MakeWeb::Response, Grpc>;

/// Errors produced by the service.
type Error = MakeWeb::Error;

/// The future response value.
type Future = HybridMakeServiceFuture<MakeWeb::Future, Grpc>;

/// Returns `Poll::Ready(Ok(()))` when the service is able to process requests.
///
/// If the service is at capacity, then `Poll::Pending` is returned and the task
/// is notified when the service becomes ready again. This function is
/// expected to be called while on a task. Generally, this can be done with
/// a simple `futures::future::poll_fn` call.
///
/// If `Poll::Ready(Err(_))` is returned, the service is no longer able to service requests
/// and the caller should discard the service instance.
///
/// Once `poll_ready` returns `Poll::Ready(Ok(()))`, a request may be dispatched to the
/// service using `call`. Until a request is dispatched, repeated calls to
/// `poll_ready` must return either `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))`.
///
/// Note that `poll_ready` may reserve shared resources that are consumed in a subsequent
/// invocation of `call`. Thus, it is critical for implementations to not assume that `call`
/// will always be invoked and to ensure that such resources are released if the service is
/// dropped before `call` is invoked or the future returned by `call` is dropped before it
/// is polled.
fn poll_ready(&mut self, cx: &mut std::task::Context) -> Poll<Result<(), Self::Error>> {
self.make_web.poll_ready(cx)
}
Expand All @@ -38,6 +66,7 @@ where
}
}

/// The future returned by [`HybridMakeService`]. Returns a [`HybridService`].
#[pin_project]
pub struct HybridMakeServiceFuture<WebFuture, Grpc> {
#[pin]
Expand All @@ -64,6 +93,10 @@ where
}
}

/// Handles a request and returns a [`HybridFuture`] returning a [`HybridBody`].
///
/// This service switches between gRPC and regular HTTP traffic and forwards the request
/// to either the `Web` (e.g. Axum) or `Grpc` (e.g. Tonic) [`Service`].
pub struct HybridService<Web, Grpc> {
web: Web,
grpc: Grpc,
Expand All @@ -73,13 +106,37 @@ impl<Web, Grpc, WebBody, GrpcBody> Service<Request<Body>> for HybridService<Web,
where
Web: Service<Request<Body>, Response = Response<WebBody>>,
Grpc: Service<Request<Body>, Response = Response<GrpcBody>>,
Web::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
Grpc::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
Web::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
Grpc::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
{
/// Responses given by the service.
type Response = Response<HybridBody<WebBody, GrpcBody>>;
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Errors produced by the service.
type Error = Box<dyn Error + Send + Sync + 'static>;

/// The future response value.
type Future = HybridFuture<Web::Future, Grpc::Future>;

/// Returns `Poll::Ready(Ok(()))` when the service is able to process requests.
///
/// If the service is at capacity, then `Poll::Pending` is returned and the task
/// is notified when the service becomes ready again. This function is
/// expected to be called while on a task. Generally, this can be done with
/// a simple `futures::future::poll_fn` call.
///
/// If `Poll::Ready(Err(_))` is returned, the service is no longer able to service requests
/// and the caller should discard the service instance.
///
/// Once `poll_ready` returns `Poll::Ready(Ok(()))`, a request may be dispatched to the
/// service using `call`. Until a request is dispatched, repeated calls to
/// `poll_ready` must return either `Poll::Ready(Ok(()))` or `Poll::Ready(Err(_))`.
///
/// Note that `poll_ready` may reserve shared resources that are consumed in a subsequent
/// invocation of `call`. Thus, it is critical for implementations to not assume that `call`
/// will always be invoked and to ensure that such resources are released if the service is
/// dropped before `call` is invoked or the future returned by `call` is dropped before it
/// is polled.
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.web.poll_ready(cx) {
Poll::Ready(Ok(())) => match self.grpc.poll_ready(cx) {
Expand All @@ -93,14 +150,33 @@ where
}

fn call(&mut self, req: Request<Body>) -> Self::Future {
if req.headers().get("content-type").map(|x| x.as_bytes()) == Some(b"application/grpc") {
if is_grpc_request(&req) {
HybridFuture::Grpc(self.grpc.call(req))
} else {
HybridFuture::Web(self.web.call(req))
}
}
}

fn is_grpc_request<Body>(req: &Request<Body>) -> bool {
// gRPC uses HTTP/2 frames.
if req.version() != Version::HTTP_2 {
return false;
}

// The content-type header needs to start with `application/grpc`
const EXPECTED: &'static [u8] = b"application/grpc";
if let Some(content_type) = req
.headers()
.get(hyper::header::CONTENT_TYPE)
.map(|x| x.as_bytes())
{
content_type.len() >= EXPECTED.len() && content_type[..EXPECTED.len()] == *EXPECTED
} else {
false
}
}

#[pin_project(project = HybridBodyProj)]
pub enum HybridBody<WebBody, GrpcBody> {
Web(#[pin] WebBody),
Expand All @@ -111,11 +187,11 @@ impl<WebBody, GrpcBody> HttpBody for HybridBody<WebBody, GrpcBody>
where
WebBody: HttpBody + Send + Unpin,
GrpcBody: HttpBody<Data = WebBody::Data> + Send + Unpin,
WebBody::Error: std::error::Error + Send + Sync + 'static,
GrpcBody::Error: std::error::Error + Send + Sync + 'static,
WebBody::Error: Error + Send + Sync + 'static,
GrpcBody::Error: Error + Send + Sync + 'static,
{
type Data = WebBody::Data;
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
type Error = Box<dyn Error + Send + Sync + 'static>;

fn poll_data(
self: Pin<&mut Self>,
Expand Down Expand Up @@ -145,6 +221,7 @@ where
}
}

/// Future returned by [`HybridService`].
#[pin_project(project = HybridFutureProj)]
pub enum HybridFuture<WebFuture, GrpcFuture> {
Web(#[pin] WebFuture),
Expand All @@ -156,13 +233,11 @@ impl<WebFuture, GrpcFuture, WebBody, GrpcBody, WebError, GrpcError> Future
where
WebFuture: Future<Output = Result<Response<WebBody>, WebError>>,
GrpcFuture: Future<Output = Result<Response<GrpcBody>, GrpcError>>,
WebError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
GrpcError: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
WebError: Into<Box<dyn Error + Send + Sync + 'static>>,
GrpcError: Into<Box<dyn Error + Send + Sync + 'static>>,
{
type Output = Result<
Response<HybridBody<WebBody, GrpcBody>>,
Box<dyn std::error::Error + Send + Sync + 'static>,
>;
type Output =
Result<Response<HybridBody<WebBody, GrpcBody>>, Box<dyn Error + Send + Sync + 'static>>;

fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
match self.project() {
Expand Down
28 changes: 18 additions & 10 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod certs;
mod hybrid;

use crate::hybrid::{hybrid, HybridMakeService};
use crate::hybrid::HybridMakeService;
use axum::routing::IntoMakeService;
use axum::Router;
use futures_util::StreamExt;
Expand Down Expand Up @@ -60,7 +60,7 @@ async fn main() -> ExitCode {
let axum_make_svc = build_web_service();

// Combine web and gRPC into a hybrid service.
let service = hybrid(axum_make_svc, grpc_service);
let service = HybridMakeService::new(axum_make_svc, grpc_service);

// Bind first hyper HTTP server.
let socket_addr =
Expand Down Expand Up @@ -176,16 +176,10 @@ fn create_hyper_server_tls(
})
.expect("failed to bind Hyper server"); // TODO: Actually return error

// Create a TLS listener and filter out all invalid connections.
let incoming = TlsListener::new(certs::tls_acceptor(), listener)
.connections()
.filter(|conn| {
if let Err(err) = conn {
error!("Error: {:?}", err);
std::future::ready(false)
} else {
std::future::ready(true)
}
});
.filter(handle_tls_connect_error);

Server::builder(hyper::server::accept::from_stream(incoming))
.serve(service)
Expand Down Expand Up @@ -239,3 +233,17 @@ where
LoggingStyle::Json => formatter.json().init(),
}
}

fn handle_tls_connect_error<AddrStream>(
result: &Result<
tokio_rustls::server::TlsStream<AddrStream>,
tls_listener::Error<std::io::Error, std::io::Error, SocketAddr>,
>,
) -> std::future::Ready<bool> {
if let Err(err) = result {
error!("Error: {:?}", err);
std::future::ready(false)
} else {
std::future::ready(true)
}
}

0 comments on commit c434f09

Please sign in to comment.