diff --git a/src/headers.rs b/src/headers.rs new file mode 100644 index 0000000..91071e7 --- /dev/null +++ b/src/headers.rs @@ -0,0 +1,137 @@ +use ic_utils::interfaces::http_request::HeaderField; +use lazy_regex::regex_captures; + +const MAX_LOG_CERT_NAME_SIZE: usize = 100; +const MAX_LOG_CERT_B64_SIZE: usize = 2000; + +#[derive(Debug, PartialEq)] +pub struct HeadersData { + pub certificate: Option, ()>>, + pub tree: Option, ()>>, + pub encoding: Option, +} + +pub fn extract_headers_data(headers: &[HeaderField], logger: &slog::Logger) -> HeadersData { + let mut headers_data = HeadersData { + certificate: None, + tree: None, + encoding: None, + }; + + for HeaderField(name, value) in headers { + if name.eq_ignore_ascii_case("Ic-Certificate") { + for field in value.split(',') { + if let Some((_, name, b64_value)) = regex_captures!("^(.*)=:(.*):$", field.trim()) { + slog::trace!( + logger, + ">> certificate {:.l1$}: {:.l2$}", + name, + b64_value, + l1 = MAX_LOG_CERT_NAME_SIZE, + l2 = MAX_LOG_CERT_B64_SIZE + ); + let bytes = decode_hash_tree(name, Some(b64_value.to_string()), logger); + if name == "certificate" { + headers_data.certificate = Some(match (headers_data.certificate, bytes) { + (None, bytes) => bytes, + (Some(Ok(certificate)), Ok(bytes)) => { + slog::warn!(logger, "duplicate certificate field: {:?}", bytes); + Ok(certificate) + } + (Some(Ok(certificate)), Err(_)) => { + slog::warn!( + logger, + "duplicate certificate field (failed to decode)" + ); + Ok(certificate) + } + (Some(Err(_)), bytes) => { + slog::warn!( + logger, + "duplicate certificate field (failed to decode)" + ); + bytes + } + }); + } else if name == "tree" { + headers_data.tree = Some(match (headers_data.tree, bytes) { + (None, bytes) => bytes, + (Some(Ok(tree)), Ok(bytes)) => { + slog::warn!(logger, "duplicate tree field: {:?}", bytes); + Ok(tree) + } + (Some(Ok(tree)), Err(_)) => { + slog::warn!(logger, "duplicate tree field (failed to decode)"); + Ok(tree) + } + (Some(Err(_)), bytes) => { + slog::warn!(logger, "duplicate tree field (failed to decode)"); + bytes + } + }); + } + } + } + } else if name.eq_ignore_ascii_case("Content-Encoding") { + let enc = value.trim().to_string(); + headers_data.encoding = Some(enc); + } + } + + headers_data +} + +fn decode_hash_tree( + name: &str, + value: Option, + logger: &slog::Logger, +) -> Result, ()> { + match value { + Some(tree) => base64::decode(tree).map_err(|e| { + slog::warn!(logger, "Unable to decode {} from base64: {}", name, e); + }), + _ => Err(()), + } +} + +#[cfg(test)] +mod tests { + use ic_utils::interfaces::http_request::HeaderField; + use slog::o; + + use super::{extract_headers_data, HeadersData}; + + #[test] + fn extract_headers_data_simple() { + let logger = slog::Logger::root(slog::Discard, o!()); + let headers: Vec = vec![]; + + let out = extract_headers_data(&headers, &logger); + + assert_eq!( + out, + HeadersData { + certificate: None, + tree: None, + encoding: None, + } + ); + } + + #[test] + fn extract_headers_data_content_encoding() { + let logger = slog::Logger::root(slog::Discard, o!()); + let headers: Vec = vec![HeaderField("Content-Encoding".into(), "test".into())]; + + let out = extract_headers_data(&headers, &logger); + + assert_eq!( + out, + HeadersData { + certificate: None, + tree: None, + encoding: Some(String::from("test")), + } + ); + } +} diff --git a/src/main.rs b/src/main.rs index 9871389..8dac1d2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,5 @@ use axum::{handler::Handler, routing::get, Extension, Router}; use clap::{crate_authors, crate_version, Parser}; -use flate2::read::{DeflateDecoder, GzDecoder}; use futures::{future::OptionFuture, try_join, FutureExt, StreamExt}; use http_body::{LengthLimitError, Limited}; use hyper::{ @@ -12,9 +11,7 @@ use hyper::{ use ic_agent::{ agent::http_transport::{reqwest, ReqwestHttpReplicaV2Transport}, agent_error::HttpErrorPayload, - export::Principal, - ic_types::{hash_tree::LookupResult, HashTree}, - lookup_value, Agent, AgentError, Certificate, + Agent, AgentError, }; use ic_utils::{ call::AsyncCall, @@ -24,11 +21,9 @@ use ic_utils::{ StreamingCallbackHttpResponse, StreamingStrategy, Token, }, }; -use lazy_regex::regex_captures; -use opentelemetry::{sdk::Resource, KeyValue}; +use opentelemetry::{global, sdk::Resource, KeyValue}; use opentelemetry_prometheus::PrometheusExporter; use prometheus::{Encoder, TextEncoder}; -use sha2::{Digest, Sha256}; use slog::Drain; use std::{ convert::Infallible, @@ -46,9 +41,17 @@ use std::{ mod canister_id; mod config; +mod headers; mod logging; - -use crate::config::dns_canister_config::DnsCanisterConfig; +mod metrics; +mod validate; + +use crate::{ + config::dns_canister_config::DnsCanisterConfig, + headers::{extract_headers_data, HeadersData}, + metrics::{MetricParams, WithMetrics}, + validate::{Validate, Validator}, +}; type HttpResponseAny = HttpResponse; @@ -60,12 +63,6 @@ const STREAM_CALLBACK_BUFFFER: usize = 2; // The maximum length of a body we should log as tracing. const MAX_LOG_BODY_SIZE: usize = 100; -const MAX_LOG_CERT_NAME_SIZE: usize = 100; -const MAX_LOG_CERT_B64_SIZE: usize = 2000; - -// The limit of a buffer we should decompress ~10mb. -const MAX_CHUNK_SIZE_TO_DECOMPRESS: usize = 1024; -const MAX_CHUNKS_TO_DECOMPRESS: u64 = 10_240; const KB: usize = 1024; const MB: usize = 1024 * KB; @@ -180,99 +177,11 @@ pub(crate) struct Opts { metrics_addr: Option, } -fn decode_hash_tree( - name: &str, - value: Option, - logger: &slog::Logger, -) -> Result, ()> { - match value { - Some(tree) => base64::decode(tree).map_err(|e| { - slog::warn!(logger, "Unable to decode {} from base64: {}", name, e); - }), - _ => Err(()), - } -} - -struct HeadersData { - certificate: Option, ()>>, - tree: Option, ()>>, - encoding: Option, -} - -fn extract_headers_data(headers: &[HeaderField], logger: &slog::Logger) -> HeadersData { - let mut headers_data = HeadersData { - certificate: None, - tree: None, - encoding: None, - }; - - for HeaderField(name, value) in headers { - if name.eq_ignore_ascii_case("IC-CERTIFICATE") { - for field in value.split(',') { - if let Some((_, name, b64_value)) = regex_captures!("^(.*)=:(.*):$", field.trim()) { - slog::trace!( - logger, - ">> certificate {:.l1$}: {:.l2$}", - name, - b64_value, - l1 = MAX_LOG_CERT_NAME_SIZE, - l2 = MAX_LOG_CERT_B64_SIZE - ); - let bytes = decode_hash_tree(name, Some(b64_value.to_string()), logger); - if name == "certificate" { - headers_data.certificate = Some(match (headers_data.certificate, bytes) { - (None, bytes) => bytes, - (Some(Ok(certificate)), Ok(bytes)) => { - slog::warn!(logger, "duplicate certificate field: {:?}", bytes); - Ok(certificate) - } - (Some(Ok(certificate)), Err(_)) => { - slog::warn!( - logger, - "duplicate certificate field (failed to decode)" - ); - Ok(certificate) - } - (Some(Err(_)), bytes) => { - slog::warn!( - logger, - "duplicate certificate field (failed to decode)" - ); - bytes - } - }); - } else if name == "tree" { - headers_data.tree = Some(match (headers_data.tree, bytes) { - (None, bytes) => bytes, - (Some(Ok(tree)), Ok(bytes)) => { - slog::warn!(logger, "duplicate tree field: {:?}", bytes); - Ok(tree) - } - (Some(Ok(tree)), Err(_)) => { - slog::warn!(logger, "duplicate tree field (failed to decode)"); - Ok(tree) - } - (Some(Err(_)), bytes) => { - slog::warn!(logger, "duplicate tree field (failed to decode)"); - bytes - } - }); - } - } - } - } else if name.eq_ignore_ascii_case("CONTENT-ENCODING") { - let enc = value.trim().to_string(); - headers_data.encoding = Some(enc); - } - } - - headers_data -} - async fn forward_request( request: Request, agent: Arc, resolver: &dyn canister_id::Resolver, + validator: &dyn Validate, logger: slog::Logger, ) -> Result, Box> { let canister_id = match resolver.resolve(&request) { @@ -471,7 +380,7 @@ async fn forward_request( builder.body(body)? } else { - let body_valid = validate( + let body_valid = validator.validate( &headers_data, &canister_id, &agent, @@ -523,159 +432,6 @@ async fn forward_request( Ok(response) } -fn validate( - headers_data: &HeadersData, - canister_id: &Principal, - agent: &Agent, - uri: &Uri, - response_body: &[u8], - logger: slog::Logger, -) -> Result<(), String> { - let body_sha = if let Some(body_sha) = - decode_body_to_sha256(response_body, headers_data.encoding.clone()) - { - body_sha - } else { - return Err("Body could not be decoded".into()); - }; - - let body_valid = match ( - headers_data.certificate.as_ref(), - headers_data.tree.as_ref(), - ) { - (Some(Ok(certificate)), Some(Ok(tree))) => match validate_body( - Certificates { certificate, tree }, - canister_id, - agent, - uri, - &body_sha, - logger.clone(), - ) { - Ok(true) => Ok(()), - Ok(false) => Err("Body does not pass verification".to_string()), - Err(e) => Err(format!("Certificate validation failed: {}", e)), - }, - (Some(_), _) | (_, Some(_)) => Err("Body does not pass verification".to_string()), - - // TODO: Remove this (FOLLOW-483) - // Canisters don't have to provide certified variables - // This should change in the future, grandfathering in current implementations - (None, None) => Ok(()), - }; - - if body_valid.is_err() && !cfg!(feature = "skip_body_verification") { - return body_valid; - } - - Ok(()) -} - -fn decode_body_to_sha256(body: &[u8], encoding: Option) -> Option<[u8; 32]> { - let mut sha256 = Sha256::new(); - let mut decoded = [0u8; MAX_CHUNK_SIZE_TO_DECOMPRESS]; - match encoding.as_deref() { - Some("gzip") => { - let mut decoder = GzDecoder::new(body); - for _ in 0..MAX_CHUNKS_TO_DECOMPRESS { - let bytes = decoder.read(&mut decoded).ok()?; - if bytes == 0 { - return Some(sha256.finalize().into()); - } - sha256.update(&decoded[0..bytes]); - } - if decoder.bytes().next().is_some() { - return None; - } - } - Some("deflate") => { - let mut decoder = DeflateDecoder::new(body); - for _ in 0..MAX_CHUNKS_TO_DECOMPRESS { - let bytes = decoder.read(&mut decoded).ok()?; - if bytes == 0 { - return Some(sha256.finalize().into()); - } - sha256.update(&decoded[0..bytes]); - } - if decoder.bytes().next().is_some() { - return None; - } - } - _ => sha256.update(body), - }; - Some(sha256.finalize().into()) -} - -struct Certificates<'a> { - certificate: &'a Vec, - tree: &'a Vec, -} - -fn validate_body( - certificates: Certificates, - canister_id: &Principal, - agent: &Agent, - uri: &Uri, - body_sha: &[u8; 32], - logger: slog::Logger, -) -> anyhow::Result { - let cert: Certificate = - serde_cbor::from_slice(certificates.certificate).map_err(AgentError::InvalidCborData)?; - let tree: HashTree = - serde_cbor::from_slice(certificates.tree).map_err(AgentError::InvalidCborData)?; - - if let Err(e) = agent.verify(&cert, *canister_id, false) { - slog::trace!(logger, ">> certificate failed verification: {}", e); - return Ok(false); - } - - let certified_data_path = vec![ - "canister".into(), - canister_id.into(), - "certified_data".into(), - ]; - let witness = match lookup_value(&cert, certified_data_path) { - Ok(witness) => witness, - Err(e) => { - slog::trace!( - logger, - ">> Could not find certified data for this canister in the certificate: {}", - e - ); - return Ok(false); - } - }; - let digest = tree.digest(); - - if witness != digest { - slog::trace!( - logger, - ">> witness ({}) did not match digest ({})", - hex::encode(witness), - hex::encode(digest) - ); - - return Ok(false); - } - - let path = ["http_assets".into(), uri.path().into()]; - let tree_sha = match tree.lookup_path(&path) { - LookupResult::Found(v) => v, - _ => match tree.lookup_path(&["http_assets".into(), "/index.html".into()]) { - LookupResult::Found(v) => v, - _ => { - slog::trace!( - logger, - ">> Invalid Tree in the header. Does not contain path {:?}", - path - ); - return Ok(false); - } - }, - }; - - Ok(body_sha == tree_sha) -} - fn is_hop_header(name: &str) -> bool { name.to_ascii_lowercase() == "connection" || name.to_ascii_lowercase() == "keep-alive" @@ -766,6 +522,7 @@ struct HandleRequest { client: reqwest::Client, proxy_url: Option, resolver: Arc>, + validator: Arc, logger: slog::Logger, fetch_root_key: bool, debug: bool, @@ -779,6 +536,7 @@ async fn handle_request( client, proxy_url, resolver, + validator, logger, fetch_root_key, debug, @@ -823,7 +581,14 @@ async fn handle_request( if fetch_root_key && agent.fetch_root_key().await.is_err() { unable_to_fetch_root_key() } else { - forward_request(request, agent, resolver.as_ref(), logger.clone()).await + forward_request( + request, + agent, + resolver.as_ref(), + validator.as_ref(), + logger.clone(), + ) + .await } }; @@ -1033,10 +798,12 @@ fn main() -> Result<(), Box> { &opts.ssl_root_certificate, opts.replica_resolve, ); + // Setup metrics let exporter = opentelemetry_prometheus::exporter() .with_resource(Resource::new(vec![KeyValue::new("service", "prober")])) .init(); + let meter = global::meter("icx-proxy"); let metrics_addr = opts.metrics_addr; let create_metrics_server = move || { @@ -1057,6 +824,10 @@ fn main() -> Result<(), Box> { check_params: !opts.ignore_url_canister_param, }); + let validator = Validator::new(); + let validator = WithMetrics(validator, MetricParams::new(&meter, "validator")); + let validator = Arc::new(validator); + let counter = AtomicUsize::new(0); let debug = opts.debug; let proxy_url = opts.proxy.clone(); @@ -1066,6 +837,7 @@ fn main() -> Result<(), Box> { let ip_addr = socket.remote_addr(); let ip_addr = ip_addr.ip(); let resolver = resolver.clone(); + let validator = validator.clone(); let logger = logger.clone(); // Select an agent. @@ -1082,16 +854,15 @@ fn main() -> Result<(), Box> { async move { Ok::<_, Infallible>(service_fn(move |request| { - let logger = logger.clone(); - let resolver = resolver.clone(); handle_request(HandleRequest { ip_addr, request, replica_url: replica_url.clone(), client: client.clone(), proxy_url: proxy_url.clone(), - resolver, - logger, + resolver: resolver.clone(), + validator: validator.clone(), + logger: logger.clone(), fetch_root_key, debug, }) diff --git a/src/metrics.rs b/src/metrics.rs new file mode 100644 index 0000000..4824d12 --- /dev/null +++ b/src/metrics.rs @@ -0,0 +1,51 @@ +use opentelemetry::{ + metrics::{Counter, Meter}, + KeyValue, +}; + +use crate::validate::Validate; + +pub struct WithMetrics(pub T, pub MetricParams); + +pub struct MetricParams { + pub counter: Counter, +} + +impl MetricParams { + pub fn new(meter: &Meter, name: &str) -> Self { + Self { + counter: meter + .u64_counter(format!("{name}.total")) + .with_description(format!("Counts occurences of {name} calls")) + .init(), + } + } +} + +impl Validate for WithMetrics { + fn validate( + &self, + headers_data: &crate::headers::HeadersData, + canister_id: &candid::Principal, + agent: &ic_agent::Agent, + uri: &hyper::Uri, + response_body: &[u8], + logger: slog::Logger, + ) -> Result<(), String> { + let out = self + .0 + .validate(headers_data, canister_id, agent, uri, response_body, logger); + + let mut status = if out.is_ok() { "ok" } else { "fail" }; + if cfg!(feature = "skip_body_verification") { + status = "skip"; + } + + let labels = &[KeyValue::new("status", status)]; + + let MetricParams { counter } = &self.1; + counter.add(1, labels); + + out + } +} diff --git a/src/validate.rs b/src/validate.rs new file mode 100644 index 0000000..9be20ec --- /dev/null +++ b/src/validate.rs @@ -0,0 +1,228 @@ +use std::io::Read; + +use candid::Principal; +use flate2::read::{DeflateDecoder, GzDecoder}; +use hyper::Uri; +use ic_agent::{ + hash_tree::LookupResult, ic_types::HashTree, lookup_value, Agent, AgentError, Certificate, +}; +use sha2::{Digest, Sha256}; + +use crate::HeadersData; + +// The limit of a buffer we should decompress ~10mb. +const MAX_CHUNK_SIZE_TO_DECOMPRESS: usize = 1024; +const MAX_CHUNKS_TO_DECOMPRESS: u64 = 10_240; + +pub trait Validate: Sync + Send { + fn validate( + &self, + headers_data: &HeadersData, + canister_id: &Principal, + agent: &Agent, + uri: &Uri, + response_body: &[u8], + logger: slog::Logger, + ) -> Result<(), String>; +} + +pub struct Validator {} + +impl Validator { + pub fn new() -> Self { + Self {} + } +} + +impl Validate for Validator { + fn validate( + &self, + headers_data: &HeadersData, + canister_id: &Principal, + agent: &Agent, + uri: &Uri, + response_body: &[u8], + logger: slog::Logger, + ) -> Result<(), String> { + let body_sha = if let Some(body_sha) = + decode_body_to_sha256(response_body, headers_data.encoding.clone()) + { + body_sha + } else { + return Err("Body could not be decoded".into()); + }; + + let body_valid = match ( + headers_data.certificate.as_ref(), + headers_data.tree.as_ref(), + ) { + (Some(Ok(certificate)), Some(Ok(tree))) => match validate_body( + Certificates { certificate, tree }, + canister_id, + agent, + uri, + &body_sha, + logger.clone(), + ) { + Ok(true) => Ok(()), + Ok(false) => Err("Body does not pass verification".to_string()), + Err(e) => Err(format!("Certificate validation failed: {}", e)), + }, + (Some(_), _) | (_, Some(_)) => Err("Body does not pass verification".to_string()), + + // TODO: Remove this (FOLLOW-483) + // Canisters don't have to provide certified variables + // This should change in the future, grandfathering in current implementations + (None, None) => Ok(()), + }; + + if cfg!(feature = "skip_body_verification") { + return Ok(()); + } + + body_valid + } +} + +struct Certificates<'a> { + certificate: &'a Vec, + tree: &'a Vec, +} + +fn decode_body_to_sha256(body: &[u8], encoding: Option) -> Option<[u8; 32]> { + let mut sha256 = Sha256::new(); + let mut decoded = [0u8; MAX_CHUNK_SIZE_TO_DECOMPRESS]; + match encoding.as_deref() { + Some("gzip") => { + let mut decoder = GzDecoder::new(body); + for _ in 0..MAX_CHUNKS_TO_DECOMPRESS { + let bytes = decoder.read(&mut decoded).ok()?; + if bytes == 0 { + return Some(sha256.finalize().into()); + } + sha256.update(&decoded[0..bytes]); + } + if decoder.bytes().next().is_some() { + return None; + } + } + Some("deflate") => { + let mut decoder = DeflateDecoder::new(body); + for _ in 0..MAX_CHUNKS_TO_DECOMPRESS { + let bytes = decoder.read(&mut decoded).ok()?; + if bytes == 0 { + return Some(sha256.finalize().into()); + } + sha256.update(&decoded[0..bytes]); + } + if decoder.bytes().next().is_some() { + return None; + } + } + _ => sha256.update(body), + }; + Some(sha256.finalize().into()) +} + +fn validate_body( + certificates: Certificates, + canister_id: &Principal, + agent: &Agent, + uri: &Uri, + body_sha: &[u8; 32], + logger: slog::Logger, +) -> anyhow::Result { + let cert: Certificate = + serde_cbor::from_slice(certificates.certificate).map_err(AgentError::InvalidCborData)?; + let tree: HashTree = + serde_cbor::from_slice(certificates.tree).map_err(AgentError::InvalidCborData)?; + + if let Err(e) = agent.verify(&cert, *canister_id, false) { + slog::trace!(logger, ">> certificate failed verification: {}", e); + return Ok(false); + } + + let certified_data_path = vec![ + "canister".into(), + canister_id.into(), + "certified_data".into(), + ]; + let witness = match lookup_value(&cert, certified_data_path) { + Ok(witness) => witness, + Err(e) => { + slog::trace!( + logger, + ">> Could not find certified data for this canister in the certificate: {}", + e + ); + return Ok(false); + } + }; + let digest = tree.digest(); + + if witness != digest { + slog::trace!( + logger, + ">> witness ({}) did not match digest ({})", + hex::encode(witness), + hex::encode(digest) + ); + + return Ok(false); + } + + let path = ["http_assets".into(), uri.path().into()]; + let tree_sha = match tree.lookup_path(&path) { + LookupResult::Found(v) => v, + _ => match tree.lookup_path(&["http_assets".into(), "/index.html".into()]) { + LookupResult::Found(v) => v, + _ => { + slog::trace!( + logger, + ">> Invalid Tree in the header. Does not contain path {:?}", + path + ); + return Ok(false); + } + }, + }; + + Ok(body_sha == tree_sha) +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use candid::Principal; + use hyper::Uri; + use ic_agent::{agent::http_transport::ReqwestHttpReplicaV2Transport, Agent}; + use slog::o; + + use crate::{ + headers::HeadersData, + validate::{Validate, Validator}, + }; + + #[test] + fn validate_nop() { + let headers = HeadersData { + certificate: None, + encoding: None, + tree: None, + }; + + let canister_id = Principal::from_text("wwc2m-2qaaa-aaaac-qaaaa-cai").unwrap(); + let transport = ReqwestHttpReplicaV2Transport::create("http://www.example.com").unwrap(); + let agent = Agent::builder().with_transport(transport).build().unwrap(); + let uri = Uri::from_str("http://www.example.com").unwrap(); + let body = vec![]; + let logger = slog::Logger::root(slog::Discard, o!()); + + let validator = Validator::new(); + + let out = validator.validate(&headers, &canister_id, &agent, &uri, &body, logger); + + assert_eq!(out, Ok(())); + } +}