Skip to content

Commit

Permalink
Add optional encryption support to redis storage backend
Browse files Browse the repository at this point in the history
  • Loading branch information
khvzak committed Nov 6, 2023
1 parent 7dcff80 commit 1b4ed63
Show file tree
Hide file tree
Showing 6 changed files with 370 additions and 38 deletions.
8 changes: 6 additions & 2 deletions casper-server/src/lua/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ where
.raw_get("surrogate_keys")
.context("invalid `surrogate_keys`")?;
let ttl: f32 = item.raw_get("ttl").context("invalid `ttl`")?;
let encrypt: Option<bool> = item.raw_get("encrypt").unwrap_or_default();

// Read Response body (it's consumed and saved)
let body = lua_try!(resp.body_mut().buffer().await).unwrap_or_default();
Expand All @@ -189,6 +190,7 @@ where
body,
surrogate_keys,
ttl: Duration::from_secs_f32(ttl),
encrypt: encrypt.unwrap_or_default(),
})
.await;

Expand Down Expand Up @@ -228,6 +230,7 @@ where
let ttl: f32 = item
.raw_get("ttl")
.with_context(|_| format!("invalid `ttl` #{}", i + 1))?;
let encrypt: Option<bool> = item.raw_get("encrypt").unwrap_or_default();

// Read Response body (it's consumed and saved)
let body = resp.body_mut().buffer().await?.unwrap_or_default();
Expand All @@ -246,19 +249,20 @@ where
.map(|s| Key::copy_from_slice(s.as_bytes()))
.collect::<Vec<_>>();

items.push((i, key, resp, body, surrogate_keys, ttl));
items.push((i, key, resp, body, surrogate_keys, ttl, encrypt));
}

// Transform items elements from tuple to Item struct
let items = items
.iter()
.map(|(_, key, resp, body, surrogate_keys, ttl)| Item {
.map(|(_, key, resp, body, surrogate_keys, ttl, encrypt)| Item {
key: key.clone(),
status: resp.status(),
headers: Cow::Borrowed(resp.headers()),
body: body.clone(),
surrogate_keys: surrogate_keys.clone(),
ttl: Duration::from_secs_f32(*ttl),
encrypt: encrypt.unwrap_or_default(),
})
.collect::<Vec<_>>();

Expand Down
155 changes: 119 additions & 36 deletions casper-server/src/storage/backends/redis/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};

use anyhow::{Context, Result};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use base64::Engine as _;
use bitflags::bitflags;
Expand All @@ -28,6 +28,7 @@ use tokio::time::timeout;

use super::Config;
use crate::storage::{decode_headers, encode_headers, Item, ItemKey, Key, Storage};
use crate::utils::aes::{aes256_decrypt, aes256_encrypt, AESDecoder};
use crate::utils::zstd::{compress_with_zstd, decompress_with_zstd, ZstdDecoder};

// TODO: Define format version
Expand Down Expand Up @@ -70,12 +71,14 @@ bitflags! {
const EX_COMPRESSED = 0b00000001; // Deprecated
const HEADERS_COMPRESSED = 0b00000010; // Headers compression
const BODY_COMPRESSED = 0b00000100; // Body compression
const ENCRYPTED = 0b00001000;
}
}

const EX_COMPRESSED: Flags = Flags::EX_COMPRESSED; // Deprecated
const HEADERS_COMPRESSED: Flags = Flags::HEADERS_COMPRESSED;
const BODY_COMPRESSED: Flags = Flags::BODY_COMPRESSED;
const ENCRYPTED: Flags = Flags::ENCRYPTED;

struct RedisMetrics {
pub internal_cache_counter: Counter<u64>,
Expand Down Expand Up @@ -242,18 +245,37 @@ impl RedisBackend {

let status = StatusCode::from_u16(response_item.status_code)?;
let flags = response_item.flags;
let mut raw_headers = response_item.headers;

let headers = match flags.contains(HEADERS_COMPRESSED) || flags.contains(EX_COMPRESSED) {
true => decode_headers(&decompress_with_zstd(&response_item.headers)?)?,
false => decode_headers(&response_item.headers)?,
// Decrypt headers if required
let encryption_key = self.config.encryption_key.as_ref();
raw_headers = match (flags.contains(ENCRYPTED), encryption_key) {
(true, Some(key)) => {
aes256_decrypt(&raw_headers, key).context("failed to decrypt headers")?
}
(true, None) => return Err(anyhow!("response is encrypted")),
(false, _) => raw_headers,
};
// Decompress headers if required
if flags.contains(HEADERS_COMPRESSED) || flags.contains(EX_COMPRESSED) {
raw_headers =
decompress_with_zstd(&raw_headers).context("failed to decompress headers")?;
}

// Decode them
let headers = decode_headers(&raw_headers).context("failed to decode headers")?;

// If we have only one chunk, decode it in-place
if response_item.num_chunks == 1 {
let body = match flags.contains(BODY_COMPRESSED) || flags.contains(EX_COMPRESSED) {
true => decompress_with_zstd(&response_item.body)?,
false => response_item.body,
};
let mut body = response_item.body;
// Decrypt body
if flags.contains(ENCRYPTED) {
body = aes256_decrypt(&body, encryption_key.unwrap())?;
}
// Decompress body
if flags.contains(BODY_COMPRESSED) || flags.contains(EX_COMPRESSED) {
body = decompress_with_zstd(&body)?;
}

// Construct a new Response object
let mut resp = Response::with_body(status, Body::Bytes(body));
Expand All @@ -276,21 +298,42 @@ impl RedisBackend {
Some(data) => Ok(Bytes::from(data)),
None => Err(io::Error::new(
io::ErrorKind::NotFound,
format!("Cannot find chunk {}/{}", i + 2, num_chunks),
format!("cannot find chunk {}/{}", i + 2, num_chunks),
)),
}
});
let body_stream = stream::iter(vec![Ok(response_item.body)]).chain(chunks_stream);

// Decompress the body if required
// Decrypt and/or decompress the body if required
let body_size = response_item.body_length as u64;
let body = if flags.contains(BODY_COMPRESSED) || flags.contains(EX_COMPRESSED) {
let body_stream =
ZstdDecoder::new(body_stream).map_err(|err| Box::new(err) as Box<dyn StdError>);
Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream))))
} else {
let body_stream = body_stream.map_err(|err| Box::new(err) as Box<dyn StdError>);
Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream))))
let body = match (
flags.contains(ENCRYPTED),
flags.contains(BODY_COMPRESSED) || flags.contains(EX_COMPRESSED),
) {
(true, true) => {
// Decrypt and decompress
let body_stream = AESDecoder::new(body_stream, encryption_key.unwrap().clone());
let body_stream =
ZstdDecoder::new(body_stream).map_err(|err| Box::new(err) as Box<dyn StdError>);
Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream))))
}
(true, false) => {
// Decrypt only
let body_stream = AESDecoder::new(body_stream, encryption_key.unwrap().clone())
.map_err(|err| Box::new(err) as Box<dyn StdError>);
Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream))))
}
(false, true) => {
// Decompress only
let body_stream =
ZstdDecoder::new(body_stream).map_err(|err| Box::new(err) as Box<dyn StdError>);
Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream))))
}
(false, false) => {
// Do nothing
let body_stream = body_stream.map_err(|err| Box::new(err) as Box<dyn StdError>);
Body::Message(Box::new(SizedStream::new(body_size, Box::pin(body_stream))))
}
};

// Construct a new Response object
Expand Down Expand Up @@ -364,6 +407,16 @@ impl RedisBackend {
}
}

// If encryption is enabled, encrypt the body and headers and update flags
if let (true, Some(key)) = (item.encrypt, &self.config.encryption_key) {
(headers, body) = try_join(
aes256_encrypt(headers, key.clone()),
aes256_encrypt(body, key.clone()),
)
.await?;
flags.insert(ENCRYPTED);
}

// Split body to chunks and save chunks first
let max_chunk_size = self.config.max_body_chunk_size;
let mut num_chunks = 1;
Expand Down Expand Up @@ -616,34 +669,39 @@ mod tests {
}

#[ntex::test]
async fn test_chunked_compressed_body() {
// Same as the above test, but with compression enabled
async fn test_compression() {
let mut config = Config::default();
config.max_body_chunk_size = 2; // Set max chunk size to 2 bytes
config.compression_level = Some(0);
config.compression_level = Some(22);
let backend = RedisBackend::new(config, None).unwrap();
backend.connect().await.unwrap();

let key = make_uniq_key();

// Cache response
let resp = make_response("hello, world");

let mut resp = make_response("hello, world"); // body is too small to be compressed
resp.headers_mut().insert(
HeaderName::from_static("hello-world-header"),
HeaderValue::from_static("Hello world header data"),
);
backend
.store_response(Item::new(key.clone(), resp, Duration::from_secs(3)))
.await
.unwrap();

// Fetch it back
let mut resp = backend.get_response(key.clone()).await.unwrap().unwrap();
assert_eq!(
resp.headers().get("Hello-World-Header").unwrap(),
"Hello world header data"
);
let body = buffer_body(resp.take_body()).await.unwrap().to_vec();
assert_eq!(String::from_utf8(body).unwrap(), "hello, world");
}

#[ntex::test]
async fn test_compressed_headers() {
async fn test_encryption() {
let mut config = Config::default();
config.compression_level = Some(22);
config.encryption_key = Some(Bytes::from_static(&[16; 32]));
let backend = RedisBackend::new(config, None).unwrap();
backend.connect().await.unwrap();

Expand All @@ -652,21 +710,46 @@ mod tests {
// Cache response
let mut resp = make_response("hello, world");
resp.headers_mut().insert(
HeaderName::from_static("hello-world-header"),
HeaderValue::from_static("Hello world header data"),
HeaderName::from_static("x-header"),
HeaderValue::from_static("value"),
);

backend
.store_response(Item::new(key.clone(), resp, Duration::from_secs(3)))
.await
.unwrap();
let mut item = Item::new(key.clone(), resp, Duration::from_secs(3));
item.encrypt = true;
backend.store_response(item).await.unwrap();

// Fetch it back
let resp = backend.get_response(key.clone()).await.unwrap().unwrap();
assert_eq!(
resp.headers().get("Hello-World-Header").unwrap(),
"Hello world header data"
let mut resp = backend.get_response(key.clone()).await.unwrap().unwrap();
assert_eq!(resp.headers().get("X-Header").unwrap(), "value".as_bytes());
let body = buffer_body(resp.take_body()).await.unwrap().to_vec();
assert_eq!(String::from_utf8(body).unwrap(), "hello, world");
}

#[ntex::test]
async fn test_chunked_compression_encryption() {
let mut config = Config::default();
config.max_body_chunk_size = 2; // Set max chunk size to 2 bytes
config.compression_level = Some(0); // Use zstd default compression level
config.encryption_key = Some(Bytes::from_static(&[16; 32]));
let backend = RedisBackend::new(config, None).unwrap();
backend.connect().await.unwrap();

let key = make_uniq_key();

// Cache response
let mut resp = make_response("hello, world!".repeat(10));
resp.headers_mut().insert(
HeaderName::from_static("x-header"),
HeaderValue::from_static("value"),
);
let mut item = Item::new(key.clone(), resp, Duration::from_secs(3));
item.encrypt = true;
backend.store_response(item).await.unwrap();

// Fetch it back
let mut resp = backend.get_response(key.clone()).await.unwrap().unwrap();
assert_eq!(resp.headers().get("X-Header").unwrap(), "value".as_bytes());
let body = buffer_body(resp.take_body()).await.unwrap().to_vec();
assert_eq!(String::from_utf8(body).unwrap(), "hello, world!".repeat(10));
}

#[ntex::test]
Expand Down
5 changes: 5 additions & 0 deletions casper-server/src/storage/backends/redis/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::time::Duration;

use anyhow::{anyhow, bail, Context, Result};
use fred::types::{ConnectionConfig, RedisConfig, TcpConfig, TlsConnector};
use ntex::util::Bytes;
use serde::Deserialize;

/// Redis backend configuration
Expand Down Expand Up @@ -35,6 +36,9 @@ pub struct Config {
pub internal_cache_size: usize,
#[serde(default = "Config::default_internal_cache_ttl")]
pub internal_cache_ttl: f64,

// Optional encryption key
pub encryption_key: Option<Bytes>,
}

#[derive(Clone, Debug, Deserialize)]
Expand Down Expand Up @@ -129,6 +133,7 @@ impl Default for Config {
lazy: false,
internal_cache_size: Config::default_internal_cache_size(),
internal_cache_ttl: Config::default_internal_cache_ttl(),
encryption_key: None,
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions casper-server/src/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub struct Item<'a> {
pub body: Bytes,
pub surrogate_keys: Vec<Key>,
pub ttl: Duration,
pub encrypt: bool,
}

impl Item<'static> {
Expand All @@ -36,6 +37,7 @@ impl Item<'static> {
body: body.as_ref().unwrap().clone(),
surrogate_keys: Vec::new(),
ttl,
encrypt: false,
}
}

Expand All @@ -54,6 +56,7 @@ impl Item<'static> {
body: body.as_ref().unwrap().clone(),
surrogate_keys: surrogate_keys.into_iter().map(|sk| sk.into()).collect(),
ttl,
encrypt: false,
}
}
}
Expand Down
Loading

0 comments on commit 1b4ed63

Please sign in to comment.