From 7b5747ee7ddb2abf8b6ed00914c7fe24c10ebb87 Mon Sep 17 00:00:00 2001 From: Jiyuan Zheng Date: Mon, 22 Apr 2024 14:30:36 +0800 Subject: [PATCH] Improve env handling via preprocessing templated config files (#162) --- Cargo.lock | 5 ++- Cargo.toml | 1 + README.md | 9 +++-- src/config/mod.rs | 99 +++++++++++++++++++++++++++-------------------- src/main.rs | 7 +--- 5 files changed, 69 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3c9bf09..57d940f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2597,9 +2597,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.3" +version = "1.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", @@ -3090,6 +3090,7 @@ dependencies = [ "opentelemetry_sdk", "pprof", "rand 0.8.5", + "regex", "serde", "serde_json", "serde_yaml", diff --git a/Cargo.toml b/Cargo.toml index 12f717f..ca0e77f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ opentelemetry-jaeger = { version = "0.20.0", features = ["rt-tokio"] } opentelemetry_sdk = { version = "0.21.1", features = ["rt-tokio", "trace"] } rand = "0.8.5" +regex = "1.10.4" serde = "1.0.152" serde_json = "1.0.92" serde_yaml = "0.9.17" diff --git a/README.md b/README.md index bbeaffc..aefb7cf 100644 --- a/README.md +++ b/README.md @@ -20,12 +20,12 @@ Run with `RUSTFLAGS="--cfg tokio_unstable"` to enable [tokio-console](https://gi - `RUST_LOG` - Log level. Default: `info`. -- `PORT` - - Override port configuration in config file. - `LOG_FORMAT` - Log format. Default: `full`. - Options: `full`, `pretty`, `json`, `compact` +In addition, you can refer env variables in `config.yml` by using `${SOME_ENV}` + ## Features Subway is build with middleware pattern. @@ -62,7 +62,7 @@ Subway is build with middleware pattern. - TODO: Limit batch size, request size and response size. - TODO: Metrics - Getting insights of the RPC calls and server performance. - + ## Benchmarks To run all benchmarks: @@ -82,14 +82,17 @@ cargo bench --bench bench ws_round_trip This middleware will intercept all method request/responses and compare the result directly with healthy endpoint responses. This is useful for debugging to make sure the returned values are as expected. You can enable validate middleware on your config file. + ```yml middlewares: methods: - validate ``` + NOTE: Keep in mind that if you place `validate` middleware before `inject_params` you may get false positive errors because the request will not be the same. Ignored methods can be defined in extension config: + ```yml extensions: validator: diff --git a/src/config/mod.rs b/src/config/mod.rs index 9c4d9ad..41fd6fa 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,3 +1,6 @@ +use anyhow::{bail, Context}; +use regex::{Captures, Regex}; +use std::env; use std::fs; use clap::Parser; @@ -147,43 +150,18 @@ impl From for Config { } // read config file specified in command line -pub fn read_config() -> Result { +pub fn read_config() -> Result { let cmd = Command::parse(); - let config = fs::File::open(cmd.config).map_err(|e| format!("Unable to open config file: {e}"))?; - let config: ParseConfig = - serde_yaml::from_reader(&config).map_err(|e| format!("Unable to parse config file: {e}"))?; - let mut config: Config = config.into(); - - if let Ok(endpoints) = std::env::var("ENDPOINTS") { - tracing::debug!("Override endpoints with env.ENDPOINTS"); - let endpoints = endpoints - .split(',') - .map(|x| x.trim().to_string()) - .collect::>(); - - config - .extensions - .client - .as_mut() - .expect("Client extension not configured") - .endpoints = endpoints; - } + let templated_config_str = + fs::read_to_string(&cmd.config).with_context(|| format!("Unable to read config file: {}", cmd.config))?; - if let Ok(env_port) = std::env::var("PORT") { - tracing::debug!("Override port with env.PORT"); - let port = env_port.parse::(); - if let Ok(port) = port { - config - .extensions - .server - .as_mut() - .expect("Server extension not configured") - .port = port; - } else { - return Err(format!("Invalid port: {}", env_port)); - } - } + let config_str = render_template(&templated_config_str) + .with_context(|| format!("Unable to preprocess config file: {}", cmd.config))?; + + let config: ParseConfig = + serde_yaml::from_str(&config_str).with_context(|| format!("Unable to parse config file: {}", cmd.config))?; + let config: Config = config.into(); // TODO: shouldn't need to do this here. Creating a server should validates everything validate_config(&config)?; @@ -191,19 +169,42 @@ pub fn read_config() -> Result { Ok(config) } -fn validate_config(config: &Config) -> Result<(), String> { +fn render_template(templated_config_str: &str) -> Result { + // match pattern: ${SOME_VAR} + let re = Regex::new(r"\$\{([^\}]+)\}").unwrap(); + + let mut config_str = String::with_capacity(templated_config_str.len()); + let mut last_match = 0; + // replace pattern: with env variables + let replacement = |caps: &Captures| -> Result { env::var(&caps[1]) }; + + // replace every matches with early return + // when encountering error + for caps in re.captures_iter(templated_config_str) { + let m = caps.get(0).expect("Matched pattern should have at least one capture"); + config_str.push_str(&templated_config_str[last_match..m.start()]); + config_str.push_str( + &replacement(&caps).with_context(|| format!("Unable to replace environment variable {}", &caps[1]))?, + ); + last_match = m.end(); + } + config_str.push_str(&templated_config_str[last_match..]); + Ok(config_str) +} + +fn validate_config(config: &Config) -> Result<(), anyhow::Error> { // TODO: validate logic should be in each individual extensions // validate endpoints for endpoint in &config.extensions.client.as_ref().unwrap().endpoints { if endpoint.parse::().is_err() { - return Err(format!("Invalid endpoint {}", endpoint)); + bail!("Invalid endpoint {}", endpoint); } } // ensure each method has only one param with inject=true for method in &config.rpcs.methods { if method.params.iter().filter(|x| x.inject).count() > 1 { - return Err(format!("Method {} has more than one inject param", method.method)); + bail!("Method {} has more than one inject param", method.method); } } @@ -214,13 +215,29 @@ fn validate_config(config: &Config) -> Result<(), String> { if param.optional { has_optional = true; } else if has_optional { - return Err(format!( - "Method {} has required param after optional param", - method.method - )); + bail!("Method {} has required param after optional param", method.method); } } } Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn render_template_basically_works() { + env::set_var("KEY", "value"); + env::set_var("ANOTHER_KEY", "another_value"); + let templated_config_str = "${KEY} ${ANOTHER_KEY}"; + let config_str = render_template(templated_config_str).unwrap(); + assert_eq!(config_str, "value another_value"); + + env::remove_var("KEY"); + let config_str = render_template(templated_config_str); + assert!(config_str.is_err()); + env::remove_var("ANOTHER_KEY"); + } +} diff --git a/src/main.rs b/src/main.rs index aab4d5b..26da8b4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,7 @@ #[tokio::main] async fn main() -> anyhow::Result<()> { // read config from file - let config = match subway::config::read_config() { - Ok(config) => config, - Err(e) => { - return Err(anyhow::anyhow!(e)); - } - }; + let config = subway::config::read_config()?; subway::logger::enable_logger(); tracing::trace!("{:#?}", config);