From 1b4ed63615ac592e9fee47019f478d550c7e50bd Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Mon, 6 Nov 2023 11:48:05 +0000 Subject: [PATCH] Add optional encryption support to redis storage backend --- casper-server/src/lua/storage.rs | 8 +- .../src/storage/backends/redis/client.rs | 155 +++++++++--- .../src/storage/backends/redis/config.rs | 5 + casper-server/src/storage/mod.rs | 3 + casper-server/src/utils/aes.rs | 236 ++++++++++++++++++ casper-server/src/utils/mod.rs | 1 + 6 files changed, 370 insertions(+), 38 deletions(-) create mode 100644 casper-server/src/utils/aes.rs diff --git a/casper-server/src/lua/storage.rs b/casper-server/src/lua/storage.rs index 258dcbaa..7d05adc2 100644 --- a/casper-server/src/lua/storage.rs +++ b/casper-server/src/lua/storage.rs @@ -166,6 +166,7 @@ where .raw_get("surrogate_keys") .context("invalid `surrogate_keys`")?; let ttl: f32 = item.raw_get("ttl").context("invalid `ttl`")?; + let encrypt: Option = item.raw_get("encrypt").unwrap_or_default(); // Read Response body (it's consumed and saved) let body = lua_try!(resp.body_mut().buffer().await).unwrap_or_default(); @@ -189,6 +190,7 @@ where body, surrogate_keys, ttl: Duration::from_secs_f32(ttl), + encrypt: encrypt.unwrap_or_default(), }) .await; @@ -228,6 +230,7 @@ where let ttl: f32 = item .raw_get("ttl") .with_context(|_| format!("invalid `ttl` #{}", i + 1))?; + let encrypt: Option = item.raw_get("encrypt").unwrap_or_default(); // Read Response body (it's consumed and saved) let body = resp.body_mut().buffer().await?.unwrap_or_default(); @@ -246,19 +249,20 @@ where .map(|s| Key::copy_from_slice(s.as_bytes())) .collect::>(); - items.push((i, key, resp, body, surrogate_keys, ttl)); + items.push((i, key, resp, body, surrogate_keys, ttl, encrypt)); } // Transform items elements from tuple to Item struct let items = items .iter() - .map(|(_, key, resp, body, surrogate_keys, ttl)| Item { + .map(|(_, key, resp, body, surrogate_keys, ttl, encrypt)| Item { key: key.clone(), status: resp.status(), headers: Cow::Borrowed(resp.headers()), body: body.clone(), surrogate_keys: surrogate_keys.clone(), ttl: Duration::from_secs_f32(*ttl), + encrypt: encrypt.unwrap_or_default(), }) .collect::>(); diff --git a/casper-server/src/storage/backends/redis/client.rs b/casper-server/src/storage/backends/redis/client.rs index f169cb82..bae5f0ce 100644 --- a/casper-server/src/storage/backends/redis/client.rs +++ b/casper-server/src/storage/backends/redis/client.rs @@ -5,7 +5,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime}; -use anyhow::{Context, Result}; +use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use base64::Engine as _; use bitflags::bitflags; @@ -28,6 +28,7 @@ use tokio::time::timeout; use super::Config; use crate::storage::{decode_headers, encode_headers, Item, ItemKey, Key, Storage}; +use crate::utils::aes::{aes256_decrypt, aes256_encrypt, AESDecoder}; use crate::utils::zstd::{compress_with_zstd, decompress_with_zstd, ZstdDecoder}; // TODO: Define format version @@ -70,12 +71,14 @@ bitflags! { const EX_COMPRESSED = 0b00000001; // Deprecated const HEADERS_COMPRESSED = 0b00000010; // Headers compression const BODY_COMPRESSED = 0b00000100; // Body compression + const ENCRYPTED = 0b00001000; } } const EX_COMPRESSED: Flags = Flags::EX_COMPRESSED; // Deprecated const HEADERS_COMPRESSED: Flags = Flags::HEADERS_COMPRESSED; const BODY_COMPRESSED: Flags = Flags::BODY_COMPRESSED; +const ENCRYPTED: Flags = Flags::ENCRYPTED; struct RedisMetrics { pub internal_cache_counter: Counter, @@ -242,18 +245,37 @@ impl RedisBackend { let status = StatusCode::from_u16(response_item.status_code)?; let flags = response_item.flags; + let mut raw_headers = response_item.headers; - let headers = match flags.contains(HEADERS_COMPRESSED) || flags.contains(EX_COMPRESSED) { - true => decode_headers(&decompress_with_zstd(&response_item.headers)?)?, - false => decode_headers(&response_item.headers)?, + // Decrypt headers if required + let encryption_key = self.config.encryption_key.as_ref(); + raw_headers = match (flags.contains(ENCRYPTED), encryption_key) { + (true, Some(key)) => { + aes256_decrypt(&raw_headers, key).context("failed to decrypt headers")? + } + (true, None) => return Err(anyhow!("response is encrypted")), + (false, _) => raw_headers, }; + // Decompress headers if required + if flags.contains(HEADERS_COMPRESSED) || flags.contains(EX_COMPRESSED) { + raw_headers = + decompress_with_zstd(&raw_headers).context("failed to decompress headers")?; + } + + // Decode them + let headers = decode_headers(&raw_headers).context("failed to decode headers")?; // If we have only one chunk, decode it in-place if response_item.num_chunks == 1 { - let body = match flags.contains(BODY_COMPRESSED) || flags.contains(EX_COMPRESSED) { - true => decompress_with_zstd(&response_item.body)?, - false => response_item.body, - }; + let mut body = response_item.body; + // Decrypt body + if flags.contains(ENCRYPTED) { + body = aes256_decrypt(&body, encryption_key.unwrap())?; + } + // Decompress body + if flags.contains(BODY_COMPRESSED) || flags.contains(EX_COMPRESSED) { + body = decompress_with_zstd(&body)?; + } // Construct a new Response object let mut resp = Response::with_body(status, Body::Bytes(body)); @@ -276,21 +298,42 @@ impl RedisBackend { Some(data) => Ok(Bytes::from(data)), None => Err(io::Error::new( io::ErrorKind::NotFound, - format!("Cannot find chunk {}/{}", i + 2, num_chunks), + format!("cannot find chunk {}/{}", i + 2, num_chunks), )), } }); let body_stream = stream::iter(vec![Ok(response_item.body)]).chain(chunks_stream); - // Decompress the body if required + // Decrypt and/or decompress the body if required let body_size = response_item.body_length as u64; - let body = if flags.contains(BODY_COMPRESSED) || flags.contains(EX_COMPRESSED) { - let body_stream = - ZstdDecoder::new(body_stream).map_err(|err| Box::new(err) as Box); - Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream)))) - } else { - let body_stream = body_stream.map_err(|err| Box::new(err) as Box); - Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream)))) + let body = match ( + flags.contains(ENCRYPTED), + flags.contains(BODY_COMPRESSED) || flags.contains(EX_COMPRESSED), + ) { + (true, true) => { + // Decrypt and decompress + let body_stream = AESDecoder::new(body_stream, encryption_key.unwrap().clone()); + let body_stream = + ZstdDecoder::new(body_stream).map_err(|err| Box::new(err) as Box); + Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream)))) + } + (true, false) => { + // Decrypt only + let body_stream = AESDecoder::new(body_stream, encryption_key.unwrap().clone()) + .map_err(|err| Box::new(err) as Box); + Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream)))) + } + (false, true) => { + // Decompress only + let body_stream = + ZstdDecoder::new(body_stream).map_err(|err| Box::new(err) as Box); + Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream)))) + } + (false, false) => { + // Do nothing + let body_stream = body_stream.map_err(|err| Box::new(err) as Box); + Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream)))) + } }; // Construct a new Response object @@ -364,6 +407,16 @@ impl RedisBackend { } } + // If encryption is enabled, encrypt the body and headers and update flags + if let (true, Some(key)) = (item.encrypt, &self.config.encryption_key) { + (headers, body) = try_join( + aes256_encrypt(headers, key.clone()), + aes256_encrypt(body, key.clone()), + ) + .await?; + flags.insert(ENCRYPTED); + } + // Split body to chunks and save chunks first let max_chunk_size = self.config.max_body_chunk_size; let mut num_chunks = 1; @@ -616,19 +669,20 @@ mod tests { } #[ntex::test] - async fn test_chunked_compressed_body() { - // Same as the above test, but with compression enabled + async fn test_compression() { let mut config = Config::default(); - config.max_body_chunk_size = 2; // Set max chunk size to 2 bytes - config.compression_level = Some(0); + config.compression_level = Some(22); let backend = RedisBackend::new(config, None).unwrap(); backend.connect().await.unwrap(); let key = make_uniq_key(); // Cache response - let resp = make_response("hello, world"); - + let mut resp = make_response("hello, world"); // body is too small to be compressed + resp.headers_mut().insert( + HeaderName::from_static("hello-world-header"), + HeaderValue::from_static("Hello world header data"), + ); backend .store_response(Item::new(key.clone(), resp, Duration::from_secs(3))) .await @@ -636,14 +690,18 @@ mod tests { // Fetch it back let mut resp = backend.get_response(key.clone()).await.unwrap().unwrap(); + assert_eq!( + resp.headers().get("Hello-World-Header").unwrap(), + "Hello world header data" + ); let body = buffer_body(resp.take_body()).await.unwrap().to_vec(); assert_eq!(String::from_utf8(body).unwrap(), "hello, world"); } #[ntex::test] - async fn test_compressed_headers() { + async fn test_encryption() { let mut config = Config::default(); - config.compression_level = Some(22); + config.encryption_key = Some(Bytes::from_static(&[16; 32])); let backend = RedisBackend::new(config, None).unwrap(); backend.connect().await.unwrap(); @@ -652,21 +710,46 @@ mod tests { // Cache response let mut resp = make_response("hello, world"); resp.headers_mut().insert( - HeaderName::from_static("hello-world-header"), - HeaderValue::from_static("Hello world header data"), + HeaderName::from_static("x-header"), + HeaderValue::from_static("value"), ); - - backend - .store_response(Item::new(key.clone(), resp, Duration::from_secs(3))) - .await - .unwrap(); + let mut item = Item::new(key.clone(), resp, Duration::from_secs(3)); + item.encrypt = true; + backend.store_response(item).await.unwrap(); // Fetch it back - let resp = backend.get_response(key.clone()).await.unwrap().unwrap(); - assert_eq!( - resp.headers().get("Hello-World-Header").unwrap(), - "Hello world header data" + let mut resp = backend.get_response(key.clone()).await.unwrap().unwrap(); + assert_eq!(resp.headers().get("X-Header").unwrap(), "value".as_bytes()); + let body = buffer_body(resp.take_body()).await.unwrap().to_vec(); + assert_eq!(String::from_utf8(body).unwrap(), "hello, world"); + } + + #[ntex::test] + async fn test_chunked_compression_encryption() { + let mut config = Config::default(); + config.max_body_chunk_size = 2; // Set max chunk size to 2 bytes + config.compression_level = Some(0); // Use zstd default compression level + config.encryption_key = Some(Bytes::from_static(&[16; 32])); + let backend = RedisBackend::new(config, None).unwrap(); + backend.connect().await.unwrap(); + + let key = make_uniq_key(); + + // Cache response + let mut resp = make_response("hello, world!".repeat(10)); + resp.headers_mut().insert( + HeaderName::from_static("x-header"), + HeaderValue::from_static("value"), ); + let mut item = Item::new(key.clone(), resp, Duration::from_secs(3)); + item.encrypt = true; + backend.store_response(item).await.unwrap(); + + // Fetch it back + let mut resp = backend.get_response(key.clone()).await.unwrap().unwrap(); + assert_eq!(resp.headers().get("X-Header").unwrap(), "value".as_bytes()); + let body = buffer_body(resp.take_body()).await.unwrap().to_vec(); + assert_eq!(String::from_utf8(body).unwrap(), "hello, world!".repeat(10)); } #[ntex::test] diff --git a/casper-server/src/storage/backends/redis/config.rs b/casper-server/src/storage/backends/redis/config.rs index 4d8052eb..53bd5348 100644 --- a/casper-server/src/storage/backends/redis/config.rs +++ b/casper-server/src/storage/backends/redis/config.rs @@ -2,6 +2,7 @@ use std::time::Duration; use anyhow::{anyhow, bail, Context, Result}; use fred::types::{ConnectionConfig, RedisConfig, TcpConfig, TlsConnector}; +use ntex::util::Bytes; use serde::Deserialize; /// Redis backend configuration @@ -35,6 +36,9 @@ pub struct Config { pub internal_cache_size: usize, #[serde(default = "Config::default_internal_cache_ttl")] pub internal_cache_ttl: f64, + + // Optional encryption key + pub encryption_key: Option, } #[derive(Clone, Debug, Deserialize)] @@ -129,6 +133,7 @@ impl Default for Config { lazy: false, internal_cache_size: Config::default_internal_cache_size(), internal_cache_ttl: Config::default_internal_cache_ttl(), + encryption_key: None, } } } diff --git a/casper-server/src/storage/mod.rs b/casper-server/src/storage/mod.rs index 21276c33..b6105f6b 100644 --- a/casper-server/src/storage/mod.rs +++ b/casper-server/src/storage/mod.rs @@ -23,6 +23,7 @@ pub struct Item<'a> { pub body: Bytes, pub surrogate_keys: Vec, pub ttl: Duration, + pub encrypt: bool, } impl Item<'static> { @@ -36,6 +37,7 @@ impl Item<'static> { body: body.as_ref().unwrap().clone(), surrogate_keys: Vec::new(), ttl, + encrypt: false, } } @@ -54,6 +56,7 @@ impl Item<'static> { body: body.as_ref().unwrap().clone(), surrogate_keys: surrogate_keys.into_iter().map(|sk| sk.into()).collect(), ttl, + encrypt: false, } } } diff --git a/casper-server/src/utils/aes.rs b/casper-server/src/utils/aes.rs new file mode 100644 index 00000000..8960eb73 --- /dev/null +++ b/casper-server/src/utils/aes.rs @@ -0,0 +1,236 @@ +use std::borrow::Cow; +use std::io::{Error as IoError, ErrorKind as IoErrorKind}; +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::{ready, stream::Stream}; +use ntex::util::{Bytes, BytesMut}; +use openssl::symm::{decrypt_aead, encrypt_aead, Cipher, Crypter, Mode}; +use pin_project_lite::pin_project; +use rand::{thread_rng, RngCore}; + +/// Size of the tag used by AES256-GCM. +const TAG_SIZE: usize = 16; + +/// Encrypts data with AES256-GCM using the key. +/// +/// It does not block when performing the encryption. +pub async fn aes256_encrypt(data: B, key: Bytes) -> Result +where + B: AsRef<[u8]> + Send + 'static, +{ + tokio::task::spawn_blocking(move || { + let data = data.as_ref(); + let cipher = Cipher::aes_256_gcm(); + + let mut iv = vec![0; cipher.iv_len().unwrap()]; + thread_rng().fill_bytes(&mut iv); + let mut tag = [0; TAG_SIZE]; + + let key = normalize_key(&key, cipher.key_len()); + encrypt_aead(cipher, &key, Some(&iv), &[], data, &mut tag) + .map(|mut data| { + // Prepend iv and tag to the beginning of the encrypted data: + // + data.reverse(); + tag.reverse(); + data.extend_from_slice(&tag); + iv.reverse(); + data.extend_from_slice(&iv); + data.reverse(); + Bytes::from(data) + }) + .map_err(|_| IoError::new(IoErrorKind::Other, "failed to encrypt data")) + }) + .await? +} + +/// Decrypts data with AES256-GCM using the key. +/// +/// The data must be encrypted with `aes256_encrypt` and contain the iv and tag. +pub fn aes256_decrypt(data: &[u8], key: &[u8]) -> Result { + let cipher = Cipher::aes_256_gcm(); + + let (iv_tag, data) = data.split_at(cipher.iv_len().unwrap() + TAG_SIZE); + let (iv, tag) = iv_tag.split_at(cipher.iv_len().unwrap()); + + let key = normalize_key(key, cipher.key_len()); + decrypt_aead(cipher, &key, Some(iv), &[], data, tag) + .map(Into::into) + .map_err(|_| IoError::new(IoErrorKind::Other, "failed to decrypt data")) +} + +/// Normalizes the key to the required length. +fn normalize_key(key: &[u8], required_len: usize) -> Cow<[u8]> { + match key.len() { + len if len > required_len => Cow::Borrowed(&key[..required_len]), + len if len < required_len => { + let mut new_key = vec![0; required_len]; + new_key[..len].copy_from_slice(key); + Cow::Owned(new_key) + } + _ => Cow::Borrowed(key), + } +} + +pin_project! { + pub struct AESDecoder { + #[pin] + stream: S, + key: Bytes, + cipher: Option, + decrypter: Option, + iv_len: usize, + block_size: usize, + state: State, + input: Bytes, + buffer: BytesMut, + } +} + +enum State { + Init, + Reading, + Decoding, + Flushing, + Done, +} + +impl AESDecoder { + pub fn new(stream: S, key: Bytes) -> Self { + let cipher = Cipher::aes_256_gcm(); + let iv_len = cipher.iv_len().unwrap(); + let block_size = cipher.block_size(); + + Self { + stream, + key, + cipher: Some(cipher), + decrypter: None, + iv_len, + block_size, + state: State::Init, + input: Bytes::new(), + buffer: BytesMut::new(), + } + } +} + +impl Stream for AESDecoder +where + S: Stream>, +{ + type Item = ::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + loop { + match *this.state { + State::Init => { + // Fetch iv and tag to init decrypter + if let Some(chunk) = ready!(this.stream.as_mut().poll_next(cx)) { + this.buffer.extend_from_slice(&chunk?); + if this.buffer.len() < TAG_SIZE + *this.iv_len { + // Not enough data, continue + continue; + } + + // Init decrypter + let iv_tag = this.buffer.split_to(*this.iv_len + TAG_SIZE); + let (iv, tag) = iv_tag.split_at(*this.iv_len); + let cipher = this.cipher.take().unwrap(); + let key = normalize_key(this.key, cipher.key_len()); + let mut decrypter = Crypter::new(cipher, Mode::Decrypt, &key, Some(iv)) + .map_err(|_| { + IoError::new(IoErrorKind::Other, "failed to create decrypter") + })?; + decrypter.set_tag(tag).unwrap(); + *this.decrypter = Some(decrypter); + *this.input = mem::take(this.buffer).freeze(); // Consume the rest of the buffer + *this.state = State::Decoding; + } else { + *this.state = State::Done; + } + } + + State::Reading => { + if let Some(chunk) = ready!(this.stream.as_mut().poll_next(cx)) { + *this.input = chunk?; + *this.state = State::Decoding; + } else { + *this.state = State::Flushing; + } + } + + State::Decoding => { + if this.input.is_empty() { + *this.state = State::Reading; + continue; + } + + let decrypter = this.decrypter.as_mut().unwrap(); + this.buffer.resize(this.input.len() + *this.block_size, 0); + let count = decrypter.update(this.input, this.buffer).map_err(|_| { + IoError::new(IoErrorKind::InvalidData, "failed to decrypt chunk") + })?; + *this.state = State::Reading; + if count > 0 { + break Poll::Ready(Some(Ok(Bytes::copy_from_slice(&this.buffer[..count])))); + } + } + + State::Flushing => { + let decrypter = this.decrypter.as_mut().unwrap(); + this.buffer.resize(*this.block_size, 0); + let count = decrypter.finalize(this.buffer).map_err(|_| { + IoError::new(IoErrorKind::InvalidData, "failed to finalize decryption") + })?; + *this.state = State::Done; + if count > 0 { + break Poll::Ready(Some(Ok(this.buffer.split_to(count).freeze()))); + } + } + + State::Done => { + break Poll::Ready(None); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::stream::{self, TryStreamExt}; + + #[ntex::test] + async fn test_encrypt_decrypt() { + let key = Bytes::from_static(b"some key"); + + let data = Bytes::from_static(b"hello"); + let encrypted = aes256_encrypt(data.clone(), key.clone()).await.unwrap(); + + let decrypted = aes256_decrypt(&encrypted, &key).unwrap(); + assert_eq!(decrypted, data); + } + + #[ntex::test] + async fn test_decrypt_stream() { + let key = Bytes::from_static(b"some key"); + + let data = b"hello world, this is a long string that will be encrypted and decrypted"; + let encrypted = aes256_encrypt(data, key.clone()).await.unwrap(); + + let stream = stream::iter( + encrypted + .chunks(8) + .map(|chunk| Ok(Bytes::copy_from_slice(chunk))), + ); + let decoder = AESDecoder::new(stream, key.clone()); + let decoded_chunks = decoder.try_collect::>().await.unwrap(); + assert_eq!(decoded_chunks.len(), 10); + assert_eq!(decoded_chunks.concat(), data); + } +} diff --git a/casper-server/src/utils/mod.rs b/casper-server/src/utils/mod.rs index eaec0214..8dd69d3d 100644 --- a/casper-server/src/utils/mod.rs +++ b/casper-server/src/utils/mod.rs @@ -29,6 +29,7 @@ pub fn random_string(len: usize, charset: Option<&str>) -> String { } } +pub mod aes; pub mod zstd; #[cfg(test)]