From 76b4159d452e0c9295603a3db99ebd4b282d085e Mon Sep 17 00:00:00 2001
From: Daniel Bloom <82895745+Daniel-Bloom-dfinity@users.noreply.github.com>
Date: Fri, 17 Jun 2022 09:56:06 -0700
Subject: [PATCH] fix: use stream api instead of channel (#40)

Spawning a task has some slightly different async properties compared to
the stream api, and in this case, we should prefer the stream api. This
allows us to buffer a preset amount and ties the cancellation of the
downstream request more closely to the upstream request.
---
 Cargo.lock  |  4 +--
 src/main.rs | 75 ++++++++++++++++++++++++-----------------------------
 2 files changed, 36 insertions(+), 43 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 9c954cd..28f2643 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -826,9 +826,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
 
 [[package]]
 name = "hyper"
-version = "0.14.17"
+version = "0.14.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "043f0e083e9901b6cc658a77d1eb86f4fc650bbb977a4337dd63192826aa85dd"
+checksum = "b26ae0a80afebe130861d90abf98e3814a4f28a4c6ffeb5ab8ebb2be311e0ef2"
 dependencies = [
  "bytes",
  "futures-channel",
diff --git a/src/main.rs b/src/main.rs
index 62393be..66fb828 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,11 +1,10 @@
 use axum::{handler::Handler, routing::get, Extension, Router};
 use clap::{crate_authors, crate_version, Parser};
 use flate2::read::{DeflateDecoder, GzDecoder};
-use futures::{future::OptionFuture, try_join, FutureExt};
+use futures::{future::OptionFuture, try_join, FutureExt, StreamExt};
 use http_body::{LengthLimitError, Limited};
 use hyper::{
     body,
-    body::Bytes,
     http::{header::CONTENT_TYPE, uri::Parts},
     service::{make_service_fn, service_fn},
     Body, Client, Request, Response, Server, StatusCode, Uri,
@@ -54,7 +53,10 @@ use crate::config::dns_canister_config::DnsCanisterConfig;
 type HttpResponseAny = HttpResponse;
 
 // Limit the total number of calls to an HTTP Request loop to 1000 for now.
-const MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT: i32 = 1000;
+const MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT: usize = 1000;
+
+// Limit the number of Stream Callbacks buffered
+const STREAM_CALLBACK_BUFFFER: usize = 2;
 
 // The maximum length of a body we should log as tracing.
 const MAX_LOG_BODY_SIZE: usize = 100;
@@ -419,53 +421,44 @@ async fn forward_request(
     };
     let is_streaming = http_response.streaming_strategy.is_some();
     let response = if let Some(streaming_strategy) = http_response.streaming_strategy {
-        let (mut sender, body) = body::Body::channel();
-        let agent = agent.as_ref().clone();
-        sender.send_data(Bytes::from(http_response.body)).await?;
-
-        match streaming_strategy {
-            StreamingStrategy::Callback(callback) => {
-                let streaming_canister_id = callback.callback.0.principal;
-                let method_name = callback.callback.0.method;
-                let mut callback_token = callback.token;
-                let logger = logger.clone();
-                tokio::spawn(async move {
-                    let canister = HttpRequestCanister::create(&agent, streaming_canister_id);
-                    // We have not yet called http_request_stream_callback.
-                    let mut count = 0;
-                    loop {
-                        count += 1;
-                        if count > MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT {
-                            sender.abort();
-                            break;
-                        }
-
+        let body = http_response.body;
+        let body = futures::stream::once(async move { Ok(body) });
+        let body = match streaming_strategy {
+            StreamingStrategy::Callback(callback) => body::Body::wrap_stream(
+                body.chain(futures::stream::try_unfold(
+                    (
+                        logger.clone(),
+                        agent,
+                        callback.callback.0,
+                        Some(callback.token),
+                    ),
+                    move |(logger, agent, callback, callback_token)| async move {
+                        let callback_token = match callback_token {
+                            Some(callback_token) => callback_token,
+                            None => return Ok(None),
+                        };
+
+                        let canister = HttpRequestCanister::create(&agent, callback.principal);
                         match canister
-                            .http_request_stream_callback(&method_name, callback_token)
+                            .http_request_stream_callback(&callback.method, callback_token)
                             .call()
                             .await
                         {
                             Ok((StreamingCallbackHttpResponse { body, token },)) => {
-                                if sender.send_data(Bytes::from(body)).await.is_err() {
-                                    sender.abort();
-                                    break;
-                                }
-                                if let Some(next_token) = token {
-                                    callback_token = next_token;
-                                } else {
-                                    break;
-                                }
+                                Ok(Some((body, (logger, agent, callback, token))))
                             }
                             Err(e) => {
-                                slog::debug!(logger, "Error happened during streaming: {}", e);
-                                sender.abort();
-                                break;
+                                slog::warn!(logger, "Error happened during streaming: {}", e);
+                                Err(e)
                             }
                         }
-                    }
-                });
-            }
-        }
+                    },
+                ))
+                .take(MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT)
+                .map(|x| async move { x })
+                .buffered(STREAM_CALLBACK_BUFFFER),
+            ),
+        };
 
         builder.body(body)?
     } else {
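
For reference, the diff replaces the spawned task and body::Body::channel with a
stream built from futures::stream::try_unfold, capped with take and read ahead
via map + buffered. The following is a minimal, self-contained sketch of that
pattern, not code from this crate: fetch_chunk, the u32 token, and the literal
limits are hypothetical stand-ins for the canister's
http_request_stream_callback call, its token type, and the crate's constants.

use futures::{Stream, StreamExt};

// Hypothetical stand-in for the canister callback: returns the next chunk
// plus the follow-up token, or None when there is nothing left to fetch.
async fn fetch_chunk(token: u32) -> Result<(Vec<u8>, Option<u32>), std::io::Error> {
    Ok((vec![token as u8], if token < 3 { Some(token + 1) } else { None }))
}

// Build a body stream: the first chunk up front, then lazily fetched chunks.
fn chunk_stream(first: Vec<u8>) -> impl Stream<Item = Result<Vec<u8>, std::io::Error>> {
    futures::stream::once(async move { Ok(first) })
        .chain(futures::stream::try_unfold(Some(0u32), |token| async move {
            // None means the previous response carried no follow-up token.
            let token = match token {
                Some(token) => token,
                None => return Ok(None),
            };
            let (chunk, next) = fetch_chunk(token).await?;
            // Yield the chunk; `next` becomes the state for the next call.
            Ok(Some((chunk, next)))
        }))
        .take(1000) // hard cap on callback calls
        .map(|x| async move { x })
        .buffered(2) // read ahead at most 2 chunks
}

// Assumes the tokio runtime already used elsewhere in this crate.
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
    let mut chunks = Box::pin(chunk_stream(b"first".to_vec()));
    while let Some(chunk) = chunks.next().await {
        println!("{} bytes", chunk?.len());
    }
    Ok(())
}

Because try_unfold only issues the next callback when the consumer polls the
stream, dropping the body (for example, when the downstream client disconnects)
stops the upstream calls, while buffered(2) bounds how far the proxy reads
ahead of the consumer, matching the intent described in the commit message.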