From 7624053a50e1683c48757ee5dacbb454921e39ef Mon Sep 17 00:00:00 2001 From: Daniel Bloom <82895745+Daniel-Bloom-dfinity@users.noreply.github.com> Date: Wed, 10 Aug 2022 10:30:27 -0700 Subject: [PATCH] refactor: from `slog`+`hyper` to `tracing`+`axum` (#45) This is a massive refactor which will enable us to collect better logs. - Moves code out of `main.rs` - agent proxy code is now in `proxy/agent.rs` - forwarding proxy code is now in `proxy/forward.rs` - temporary `src/http_transport.rs` pending on dfinity/agent-rs#373 - Adds a new log format switch with `json` support. The exact format may need some fine tuning - Completely gets rid of `reqwest` --- Cargo.lock | 411 ++++--------- Cargo.toml | 36 +- src/canister_id.rs | 72 ++- src/config/dns_canister_config.rs | 6 +- src/headers.rs | 41 +- src/http_client.rs | 287 ++++++++++ src/http_transport.rs | 312 ++++++++++ src/logging.rs | 196 ++++--- src/main.rs | 922 ++---------------------------- src/metrics.rs | 91 ++- src/proxy/agent.rs | 366 ++++++++++++ src/proxy/forward.rs | 128 +++++ src/proxy/mod.rs | 185 ++++++ src/validate.rs | 42 +- 14 files changed, 1786 insertions(+), 1309 deletions(-) create mode 100644 src/http_client.rs create mode 100644 src/http_transport.rs create mode 100644 src/proxy/agent.rs create mode 100644 src/proxy/forward.rs create mode 100644 src/proxy/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 26f349b..da2ff75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anyhow" version = "1.0.56" @@ -68,9 +77,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "axum" -version = "0.5.3" +version = "0.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f523b4e98ba6897ae90994bc18423d9877c54f9047b06a00ddc8122a957b1c70" +checksum = "6b9496f0c1d1afb7a2af4338bbe1d969cddfead41d87a9fb3aaa6d0bbc7af648" dependencies = [ "async-trait", "axum-core", @@ -99,9 +108,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.2.2" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3ddbd16eabff8b45f21b98671fddcc93daaa7ac4c84f8473693437226040de5" +checksum = "e4f44a0e6200e9d11a1cdc989e4b358f6e3d354fbf48478f345a17f4e43f8635" dependencies = [ "async-trait", "bytes", @@ -506,22 +515,13 @@ dependencies = [ [[package]] name = "encoding_rs" -version = "0.8.30" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dc8abb250ffdda33912550faa54c88ec8b998dec0b2c55ab224921ce11df" +checksum = "9852635589dc9f9ea1b6fe9f05b50ef208c85c834a562f0c6abb1c475736ec2b" dependencies = [ "cfg-if", ] -[[package]] -name = "fastrand" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3fcf0cee53519c866c09b5de1f6c56ff9d647101f81c1964fa632e148896cdf" -dependencies = [ - "instant", -] - [[package]] name = "ff" version = "0.12.0" @@ -556,21 +556,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "form_urlencoded" version = "1.0.1" @@ -691,13 +676,13 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.5" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d39cd93900197114fa1fcb7ae84ca742095eed9442088988ae74fa744e930e77" +checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" dependencies = [ "cfg-if", "libc", - "wasi 0.10.2+wasi-snapshot-preview1", + "wasi", ] [[package]] @@ -713,9 +698,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.12" +version = "0.3.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62eeb471aa3e3c9197aa4bfeabfe02982f6dc96f750486c0bb0009ac58b26d2b" +checksum = "37a82c6d637fc9515a4694bbf1cb2457b79d81ce52b3108bdeea58b07dd34a57" dependencies = [ "bytes", "fnv", @@ -726,7 +711,7 @@ dependencies = [ "indexmap", "slab", "tokio", - "tokio-util 0.6.9", + "tokio-util", "tracing", ] @@ -774,9 +759,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.6" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31f4c6746584866f0feabcc69893c5b51beef3831656a968ed7ae254cdc4fd03" +checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399" dependencies = [ "bytes", "fnv", @@ -852,19 +837,6 @@ dependencies = [ "webpki-roots", ] -[[package]] -name = "hyper-tls" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" -dependencies = [ - "bytes", - "hyper", - "native-tls", - "tokio", - "tokio-native-tls", -] - [[package]] name = "ic-agent" version = "0.20.0" @@ -940,7 +912,7 @@ dependencies = [ [[package]] name = "icx-proxy" -version = "0.9.0" +version = "0.10.0" dependencies = [ "anyhow", "axum", @@ -948,31 +920,31 @@ dependencies = [ "candid", "clap", "flate2", + "form_urlencoded", "futures", "garcon", "hex", "http-body", "hyper", "hyper-rustls", - "hyper-tls", "ic-agent", "ic-utils", + "itertools", "lazy-regex", "opentelemetry", "opentelemetry-prometheus", "prometheus", - "reqwest", "rustls", - "rustls-pemfile 1.0.0", + "rustls-pemfile", "serde", "serde_cbor", "serde_json", "sha2", - "slog", - "slog-async", - "slog-term", "tokio", - "url", + "tower", + "tower-http", + "tracing", + "tracing-subscriber", "webpki-roots", ] @@ -997,20 +969,11 @@ dependencies = [ "hashbrown", ] -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - [[package]] name = "ipnet" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e70ee094dc02fd9c13fdad4940090f22dbd6ac7c9e7094a46cf0232a50bc7c" +checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b" [[package]] name = "itertools" @@ -1207,7 +1170,7 @@ dependencies = [ "log", "miow", "ntapi", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "winapi", ] @@ -1220,24 +1183,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "native-tls" -version = "0.2.8" 
-source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48ba9f7719b5a0f42f338907614285fb5fd70e53858141f69898a1fb7203b24d" -dependencies = [ - "lazy_static", - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "new_debug_unreachable" version = "1.0.4" @@ -1329,39 +1274,12 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "87f3e037eac156d1775da914196f0f37741a274155e34a0b7e427c35d2a2ecb9" -[[package]] -name = "openssl" -version = "0.10.38" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7ae222234c30df141154f159066c5093ff73b63204dcda7121eb082fc56a95" -dependencies = [ - "bitflags", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-sys", -] - [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" -[[package]] -name = "openssl-sys" -version = "0.9.72" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e46109c383602735fa0a2e48dd2b7c892b048e1bf69e5c3b1d804b7d9c203cb" -dependencies = [ - "autocfg", - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "opentelemetry" version = "0.17.0" @@ -1403,17 +1321,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.5", -] - [[package]] name = "parking_lot" version = "0.12.0" @@ -1421,21 +1328,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58" dependencies = [ "lock_api", - "parking_lot_core 0.9.1", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall", - "smallvec", - "winapi", + "parking_lot_core", ] [[package]] @@ -1459,9 +1352,9 @@ checksum = "0744126afe1a6dd7f394cb50a716dbe086cb06e255e53d8d0185d82828358fb5" [[package]] name = "pem" -version = "1.0.2" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9a3b09a20e374558580a4914d3b7d89bd61b954a5a5e1dcbea98753addb1947" +checksum = "03c64931a1a212348ec4f3b4362585eca7159d0d09cbdf4a7f74f02173596fd4" dependencies = [ "base64", ] @@ -1548,12 +1441,6 @@ dependencies = [ "spki", ] -[[package]] -name = "pkg-config" -version = "0.3.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58893f751c9b0412871a09abd62ecd2a00298c6c83befa223ef98c52aef40cbe" - [[package]] name = "ppv-lite86" version = "0.2.16" @@ -1621,15 +1508,15 @@ dependencies = [ [[package]] name = "prometheus" -version = "0.13.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7f64969ffd5dd8f39bd57a68ac53c163a095ed9d0fb707146da1b27025a3504" +checksum = "cface98dfa6d645ea4c789839f176e4b072265d085bfcc48eaa8d137f58d3c39" dependencies = [ "cfg-if", "fnv", "lazy_static", "memchr", - "parking_lot 0.11.2", + "parking_lot", "protobuf", "thiserror", ] @@ 
-1724,20 +1611,11 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" -[[package]] -name = "remove_dir_all" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7" -dependencies = [ - "winapi", -] - [[package]] name = "reqwest" -version = "0.11.10" +version = "0.11.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46a1f7aa4f35e5e8b4160449f51afc758f0ce6454315a9fa7d0d113e958c41eb" +checksum = "b75aa69a3f06bbcc66ede33af2af253c6f7a86b1ca0033f60c580a27074fbf92" dependencies = [ "base64", "bytes", @@ -1749,24 +1627,22 @@ dependencies = [ "http-body", "hyper", "hyper-rustls", - "hyper-tls", "ipnet", "js-sys", "lazy_static", "log", "mime", - "native-tls", "percent-encoding", "pin-project-lite", "rustls", - "rustls-pemfile 0.3.0", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", "tokio", - "tokio-native-tls", "tokio-rustls", - "tokio-util 0.6.9", + "tokio-util", + "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", @@ -1815,34 +1691,16 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.6.1" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca9ebdfa27d3fc180e42879037b5338ab1c040c06affd00d8338598e7800943" +checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" dependencies = [ "openssl-probe", - "rustls-pemfile 0.2.1", + "rustls-pemfile", "schannel", "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eebeaeb360c87bfb72e84abdb3447159c0eaececf1bef2aecd65a8be949d1c9" -dependencies = [ - "base64", -] - -[[package]] -name = "rustls-pemfile" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ee86d63972a7c661d1536fefe8c3c8407321c3df668891286de28abcd087360" -dependencies = [ - "base64", -] - [[package]] name = "rustls-pemfile" version = "1.0.0" @@ -2006,6 +1864,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +dependencies = [ + "lazy_static", +] + [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -2049,37 +1916,6 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5" -[[package]] -name = "slog" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8347046d4ebd943127157b94d63abb990fcf729dc4e9978927fdf4ac3c998d06" - -[[package]] -name = "slog-async" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "766c59b252e62a34651412870ff55d8c4e6d04df19b43eecb2703e417b097ffe" -dependencies = [ - "crossbeam-channel", - "slog", - "take_mut", - "thread_local", -] - -[[package]] -name = "slog-term" -version = "2.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87d29185c55b7b258b4f120eab00f48557d4d9bc814f41713f449d35b0f8977c" -dependencies = [ - "atty", - "slog", - "term", - "thread_local", - "time", -] - [[package]] name = "smallvec" version = "1.8.0" @@ -2114,13 +1950,13 @@ 
dependencies = [ [[package]] name = "string_cache" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33994d0838dc2d152d17a62adf608a869b5e846b65b389af7f3dbc1de45c5b26" +checksum = "213494b7a2b503146286049378ce02b482200519accc31872ee8be91fa820a08" dependencies = [ - "lazy_static", "new_debug_unreachable", - "parking_lot 0.11.2", + "once_cell", + "parking_lot", "phf_shared", "precomputed-hash", ] @@ -2173,26 +2009,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" -[[package]] -name = "take_mut" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f764005d11ee5f36500a149ace24e00e3da98b0158b3e2d53a7495660d3f4d60" - -[[package]] -name = "tempfile" -version = "3.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4" -dependencies = [ - "cfg-if", - "fastrand", - "libc", - "redox_syscall", - "remove_dir_all", - "winapi", -] - [[package]] name = "term" version = "0.7.0" @@ -2293,17 +2109,18 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.17.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2af73ac49756f3f7c01172e34a23e5d0216f6c32333757c2c61feb2bbff5a5ee" +checksum = "7a8325f63a7d4774dd041e363b2409ed1c5cbbd0f867795e661df066b2b0a581" dependencies = [ + "autocfg", "bytes", "libc", "memchr", "mio", "num_cpus", "once_cell", - "parking_lot 0.12.0", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", @@ -2322,16 +2139,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tokio-native-tls" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d995660bd2b7f8c1568414c1126076c13fbb725c40112dc0120b78eb9b717b" -dependencies = [ - "native-tls", - "tokio", -] - [[package]] name = "tokio-rustls" version = "0.23.2" @@ -2345,29 +2152,16 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e99e1983e5d376cd8eb4b66604d2e99e79f5bd988c3055891dcd8c9e2604cc0" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "log", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "tokio-util" -version = "0.7.1" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0edfdeb067411dba2044da6d1cb2df793dd35add7888d73c16e3381ded401764" +checksum = "cc463cd8deddc3770d20f9852143d50bf6094e640b485cb2e189a2099085ff45" dependencies = [ "bytes", "futures-core", "futures-sink", "pin-project-lite", "tokio", + "tracing", ] [[package]] @@ -2390,7 +2184,6 @@ dependencies = [ "pin-project", "pin-project-lite", "tokio", - "tokio-util 0.7.1", "tower-layer", "tower-service", "tracing", @@ -2398,9 +2191,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.2.5" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aba3f3efabf7fb41fae8534fc20a817013dd1c12cb45441efb6c82e6556b4cd8" +checksum = "3c530c8675c1dbf98facee631536fa116b5fb6382d7dd6dc1b118d970eafe3ba" dependencies = [ "bitflags", "bytes", @@ -2413,6 +2206,7 @@ dependencies = [ "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -2436,9 +2230,21 @@ dependencies = [ 
"cfg-if", "log", "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.23" @@ -2446,6 +2252,45 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa31669fa42c09c34d94d8165dd2012e8ff3c66aca50f3bb226b68f216f2706c" dependencies = [ "lazy_static", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +dependencies = [ + "lazy_static", + "log", + "tracing-core", +] + +[[package]] +name = "tracing-serde" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1" +dependencies = [ + "serde", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bc28f93baff38037f64e6f43d34cfa1605f27a49c34e8a04c5e78b0babf2596" +dependencies = [ + "ansi_term", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", + "tracing-serde", ] [[package]] @@ -2518,10 +2363,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ae116fef2b7fea257ed6440d3cfcff7f190865f170cdad00bb6465bf18ecba" [[package]] -name = "vcpkg" -version = "0.2.15" +name = "valuable" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" [[package]] name = "version_check" @@ -2539,12 +2384,6 @@ dependencies = [ "try-lock", ] -[[package]] -name = "wasi" -version = "0.10.2+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 356b3e7..f13e357 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "icx-proxy" -version = "0.9.0" +version = "0.10.0" authors = ["DFINITY Stiftung "] edition = "2018" description = "CLI tool to create an HTTP proxy to the Internet Computer." 
@@ -18,36 +18,36 @@ path = "src/main.rs" [dependencies] anyhow = "1" -axum = "0.5.3" +axum = "0.5" base64 = "0.13" candid = { version = "0.7", features = ["mute_warnings"] } clap = { version = "3", features = ["cargo", "derive"] } -flate2 = "1.0.0" -futures = "0.3.21" +flate2 = "1" +form_urlencoded = "1" +futures = "0.3" garcon = { version = "0.2", features = ["async"] } hex = "0.4" -http-body = "0.4.5" -hyper = { version = "0.14", features = ["full"] } -hyper-rustls = { version = "0.23", features = [ "webpki-roots" ] } -hyper-tls = "0.5" -ic-agent = "0.20" +http-body = "0.4" +hyper = { version = "0.14.11", features = ["client", "http2", "http1"] } +hyper-rustls = { version = "0.23", features = [ "webpki-roots", "http2" ] } +itertools = "0.10" +ic-agent = { version = "0.20", default-features = false } ic-utils = { version = "0.20", features = ["raw"] } lazy-regex = "2" -opentelemetry = "0.17.0" -opentelemetry-prometheus = "0.10.0" -prometheus = "0.13.0" -reqwest = { version = "0.11", features = ["rustls-tls-webpki-roots"] } -rustls = "0.20" +opentelemetry = "0.17" +opentelemetry-prometheus = "0.10" +prometheus = "0.13" +rustls = { version = "0.20", features = ["dangerous_configuration"] } rustls-pemfile = "1" +tower = "0.4" +tower-http = { version = "0.3", features = ["trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["json"]} serde = "1" serde_cbor = "0.11" serde_json = "1" sha2 = "0.10" -slog = { version = "2", features = ["max_level_trace"] } -slog-async = "2" -slog-term = "2" tokio = { version = "1", features = ["full"] } -url = "2" webpki-roots = "0.22" [features] diff --git a/src/canister_id.rs b/src/canister_id.rs index 9a86c7d..b3e4cf9 100644 --- a/src/canister_id.rs +++ b/src/canister_id.rs @@ -1,8 +1,29 @@ -use hyper::{header::HOST, Request, Uri}; +use crate::http_transport::hyper::{header::HOST, http::request::Parts, Uri}; +use anyhow::Context; +use clap::Args; use ic_agent::export::Principal; +use tracing::error; use crate::config::dns_canister_config::DnsCanisterConfig; +/// The options for the canister resolver +#[derive(Args)] +pub struct Opts { + /// A map of domain names to canister IDs. + /// Format: domain.name:canister-id + #[clap(long)] + dns_alias: Vec, + + /// A list of domain name suffixes. If found, the next (to the left) subdomain + /// is used as the Principal, if it parses as a Principal. + #[clap(long, default_value = "localhost")] + dns_suffix: Vec, + + /// Whether or not to ignore `canisterId=` when locating the canister. + #[clap(long)] + ignore_url_canister_param: bool, +} + /// A resolver for `Principal`s from a `Uri`. trait UriResolver: Sync + Send { fn resolve(&self, uri: &Uri) -> Option; @@ -13,12 +34,11 @@ impl UriResolver for &T { T::resolve(self, uri) } } - struct UriParameterResolver; impl UriResolver for UriParameterResolver { fn resolve(&self, uri: &Uri) -> Option { - url::form_urlencoded::parse(uri.query()?.as_bytes()) + form_urlencoded::parse(uri.query()?.as_bytes()) .find(|(name, _)| name == "canisterId") .and_then(|(_, canister_id)| Principal::from_text(canister_id.as_ref()).ok()) } @@ -31,31 +51,31 @@ impl UriResolver for DnsCanisterConfig { } /// A resolver for `Principal`s from a `Request`. 
-pub trait Resolver: Sync + Send { - fn resolve(&self, request: &Request) -> Option; +pub trait Resolver: Sync + Send { + fn resolve(&self, request: &Parts) -> Option; } -impl> Resolver for &T { - fn resolve(&self, request: &Request) -> Option { +impl Resolver for &T { + fn resolve(&self, request: &Parts) -> Option { T::resolve(self, request) } } struct RequestUriResolver(pub T); -impl Resolver for RequestUriResolver { - fn resolve(&self, request: &Request) -> Option { - self.0.resolve(request.uri()) +impl Resolver for RequestUriResolver { + fn resolve(&self, request: &Parts) -> Option { + self.0.resolve(&request.uri) } } struct RequestHostResolver(pub T); -impl Resolver for RequestHostResolver { - fn resolve(&self, request: &Request) -> Option { +impl Resolver for RequestHostResolver { + fn resolve(&self, request: &Parts) -> Option { self.0.resolve( &Uri::builder() - .authority(request.headers().get(HOST)?.as_bytes()) + .authority(request.headers.get(HOST)?.as_bytes()) .build() .ok()?, ) @@ -68,8 +88,8 @@ pub struct DefaultResolver { pub check_params: bool, } -impl Resolver for DefaultResolver { - fn resolve(&self, request: &Request) -> Option { +impl Resolver for DefaultResolver { + fn resolve(&self, request: &Parts) -> Option { if let Some(v) = RequestHostResolver(&self.dns).resolve(request) { return Some(v); } @@ -85,13 +105,29 @@ impl Resolver for DefaultResolver { } } +pub fn setup(opts: Opts) -> Result { + let dns = DnsCanisterConfig::new(&opts.dns_alias, &opts.dns_suffix) + .context("Failed to configure canister resolver DNS"); + let dns = match dns { + Err(e) => { + error!("{e}"); + Err(e) + } + Ok(v) => Ok(v), + }?; + Ok(DefaultResolver { + dns, + check_params: !opts.ignore_url_canister_param, + }) +} + #[cfg(test)] mod tests { - use hyper::{header::HOST, Request}; use ic_agent::export::Principal; use super::{DefaultResolver, Resolver}; use crate::config::dns_canister_config::DnsCanisterConfig; + use crate::http_transport::hyper::{header::HOST, http::request::Parts, Request}; #[test] fn simple_resolve() { @@ -229,7 +265,7 @@ mod tests { DnsCanisterConfig::new(&aliases, &suffixes).unwrap() } - fn build_req(host: Option<&str>, uri: &str) -> Request<()> { + fn build_req(host: Option<&str>, uri: &str) -> Parts { let req = Request::builder().uri(uri); if let Some(host) = host { req.header(HOST, host) @@ -238,6 +274,8 @@ mod tests { } .body(()) .unwrap() + .into_parts() + .0 } fn principal(v: &str) -> Principal { diff --git a/src/config/dns_canister_config.rs b/src/config/dns_canister_config.rs index cb5fce5..5a08501 100644 --- a/src/config/dns_canister_config.rs +++ b/src/config/dns_canister_config.rs @@ -1,7 +1,9 @@ -use crate::config::dns_canister_rule::DnsCanisterRule; -use ic_agent::ic_types::Principal; use std::cmp::Reverse; +use ic_agent::ic_types::Principal; + +use crate::config::dns_canister_rule::DnsCanisterRule; + /// Configuration for determination of Domain Name to Principal #[derive(Clone, Debug)] pub struct DnsCanisterConfig { diff --git a/src/headers.rs b/src/headers.rs index 91071e7..c8297f9 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -1,5 +1,6 @@ use ic_utils::interfaces::http_request::HeaderField; use lazy_regex::regex_captures; +use tracing::{trace, warn}; const MAX_LOG_CERT_NAME_SIZE: usize = 100; const MAX_LOG_CERT_B64_SIZE: usize = 2000; @@ -11,7 +12,7 @@ pub struct HeadersData { pub encoding: Option, } -pub fn extract_headers_data(headers: &[HeaderField], logger: &slog::Logger) -> HeadersData { +pub fn extract_headers_data(headers: &[HeaderField]) -> 
HeadersData { let mut headers_data = HeadersData { certificate: None, tree: None, @@ -22,34 +23,27 @@ pub fn extract_headers_data(headers: &[HeaderField], logger: &slog::Logger) -> H if name.eq_ignore_ascii_case("Ic-Certificate") { for field in value.split(',') { if let Some((_, name, b64_value)) = regex_captures!("^(.*)=:(.*):$", field.trim()) { - slog::trace!( - logger, + trace!( ">> certificate {:.l1$}: {:.l2$}", name, b64_value, l1 = MAX_LOG_CERT_NAME_SIZE, l2 = MAX_LOG_CERT_B64_SIZE ); - let bytes = decode_hash_tree(name, Some(b64_value.to_string()), logger); + let bytes = decode_hash_tree(name, Some(b64_value.to_string())); if name == "certificate" { headers_data.certificate = Some(match (headers_data.certificate, bytes) { (None, bytes) => bytes, (Some(Ok(certificate)), Ok(bytes)) => { - slog::warn!(logger, "duplicate certificate field: {:?}", bytes); + warn!("duplicate certificate field: {:?}", bytes); Ok(certificate) } (Some(Ok(certificate)), Err(_)) => { - slog::warn!( - logger, - "duplicate certificate field (failed to decode)" - ); + warn!("duplicate certificate field (failed to decode)"); Ok(certificate) } (Some(Err(_)), bytes) => { - slog::warn!( - logger, - "duplicate certificate field (failed to decode)" - ); + warn!("duplicate certificate field (failed to decode)"); bytes } }); @@ -57,15 +51,15 @@ pub fn extract_headers_data(headers: &[HeaderField], logger: &slog::Logger) -> H headers_data.tree = Some(match (headers_data.tree, bytes) { (None, bytes) => bytes, (Some(Ok(tree)), Ok(bytes)) => { - slog::warn!(logger, "duplicate tree field: {:?}", bytes); + warn!("duplicate tree field: {:?}", bytes); Ok(tree) } (Some(Ok(tree)), Err(_)) => { - slog::warn!(logger, "duplicate tree field (failed to decode)"); + warn!("duplicate tree field (failed to decode)"); Ok(tree) } (Some(Err(_)), bytes) => { - slog::warn!(logger, "duplicate tree field (failed to decode)"); + warn!("duplicate tree field (failed to decode)"); bytes } }); @@ -81,14 +75,10 @@ pub fn extract_headers_data(headers: &[HeaderField], logger: &slog::Logger) -> H headers_data } -fn decode_hash_tree( - name: &str, - value: Option, - logger: &slog::Logger, -) -> Result, ()> { +fn decode_hash_tree(name: &str, value: Option) -> Result, ()> { match value { Some(tree) => base64::decode(tree).map_err(|e| { - slog::warn!(logger, "Unable to decode {} from base64: {}", name, e); + warn!("Unable to decode {} from base64: {}", name, e); }), _ => Err(()), } @@ -97,16 +87,14 @@ fn decode_hash_tree( #[cfg(test)] mod tests { use ic_utils::interfaces::http_request::HeaderField; - use slog::o; use super::{extract_headers_data, HeadersData}; #[test] fn extract_headers_data_simple() { - let logger = slog::Logger::root(slog::Discard, o!()); let headers: Vec = vec![]; - let out = extract_headers_data(&headers, &logger); + let out = extract_headers_data(&headers); assert_eq!( out, @@ -120,10 +108,9 @@ mod tests { #[test] fn extract_headers_data_content_encoding() { - let logger = slog::Logger::root(slog::Discard, o!()); let headers: Vec = vec![HeaderField("Content-Encoding".into(), "test".into())]; - let out = extract_headers_data(&headers, &logger); + let out = extract_headers_data(&headers); assert_eq!( out, diff --git a/src/http_client.rs b/src/http_client.rs new file mode 100644 index 0000000..251a02a --- /dev/null +++ b/src/http_client.rs @@ -0,0 +1,287 @@ +use std::{ + borrow::Cow, + collections::HashMap, + fs::File, + hash::{Hash, Hasher}, + io::{Cursor, Read}, + iter, + net::SocketAddr, + path::PathBuf, + str::FromStr, + sync::Arc, +}; 
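For orientation, the `--replica-resolve` override values handled by this file parse as `domain=socket-addr` pairs via the `FromStr` impl defined just below; a hypothetical test-style sketch (same module, so the private fields are visible):

    // Sketch: how a `--replica-resolve` value is parsed.
    #[test]
    fn parses_replica_resolve() {
        use std::str::FromStr;
        // "ic0.app=[::1]:9090" splits on the first '=' into a domain and a SocketAddr.
        let opt = OptResolve::from_str("ic0.app=[::1]:9090").unwrap();
        assert_eq!(opt.domain, "ic0.app");
        assert_eq!(opt.addr.port(), 9090);
    }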
+ +use anyhow::Context; +use clap::Args; +use hyper_rustls::HttpsConnectorBuilder; +use itertools::Either; +use tracing::error; + +use crate::http_transport::{ + self, + hyper::{ + self, + body::Bytes, + client::{ + connect::dns::{GaiResolver, Name}, + HttpConnector, + }, + service::Service, + Client, + }, +}; + +/// DNS resolve overrides +/// `ic0.app=[::1]:9090` +struct OptResolve { + domain: String, + addr: SocketAddr, +} + +impl FromStr for OptResolve { + type Err = anyhow::Error; + fn from_str(s: &str) -> Result<Self, Self::Err> { + let (domain, addr) = s + .split_once('=') + .ok_or_else(|| anyhow::Error::msg("missing '='"))?; + Ok(OptResolve { + domain: domain.into(), + addr: addr.parse()?, + }) + } +} + +/// The options for the HTTP client +#[derive(Args)] +pub struct Opts { + /// The list of custom root HTTPS certificates to use to talk to the replica. This can be used + /// to connect to an IC that has a self-signed certificate, for example. Do not use this when + /// talking to the Internet Computer blockchain mainnet as it is insecure. + #[clap(long)] + ssl_root_certificate: Vec<PathBuf>, + + /// Allows HTTPS connection to replicas with invalid HTTPS certificates. This can be used to + /// connect to an IC that has a self-signed certificate, for example. Do not use this when + /// talking to the Internet Computer blockchain mainnet as it is *VERY* insecure. + #[clap(long)] + danger_accept_invalid_ssl: bool, + + /// Override DNS resolution for specific replica domains to particular IP addresses. + /// Examples: ic0.app=[::1]:9090 + #[clap(long, value_name("DOMAIN=IP_PORT"))] + replica_resolve: Vec<OptResolve>, +} + +pub type Body = hyper::Body; + +pub trait HyperBody: + http_transport::HyperBody + + From<&'static [u8]> + + From<&'static str> + + From<Bytes> + + From<Cow<'static, [u8]>> + + From<Cow<'static, str>> + + From<String> + + From<Vec<u8>> + + Into<Body> +{ +} + +impl<B> HyperBody for B where + B: http_transport::HyperBody + + From<&'static [u8]> + + From<&'static str> + + From<Bytes> + + From<Cow<'static, [u8]>> + + From<Cow<'static, str>> + + From<String> + + From<Vec<u8>> + + Into<Body> +{ +} + +/// Trait representing the constraints on [`Service`] that [`HyperReplicaV2Transport`] requires. +pub trait HyperService<B1: HyperBody>: + http_transport::HyperService<B1, ResponseBody = Self::ResponseBody2> +{ + /// Values yielded in the `Body` of the `Response`.
+ type ResponseBody2: HyperBody; +} + +impl HyperService for S +where + B1: HyperBody, + B2: HyperBody, + S: http_transport::HyperService, +{ + type ResponseBody2 = B2; +} + +pub fn setup(opts: Opts) -> Result, anyhow::Error> { + let Opts { + danger_accept_invalid_ssl, + ssl_root_certificate, + replica_resolve, + } = opts; + let builder = rustls::ClientConfig::builder().with_safe_defaults(); + let tls_config = if !danger_accept_invalid_ssl { + use rustls::{Certificate, RootCertStore}; + + let mut root_cert_store = RootCertStore::empty(); + for cert_path in ssl_root_certificate { + let mut buf = Vec::new(); + if let Err(e) = File::open(&cert_path).and_then(|mut v| v.read_to_end(&mut buf)) { + tracing::warn!("Could not load cert `{}`: {}", cert_path.display(), e); + continue; + } + match cert_path.extension() { + Some(v) if v == "pem" => { + tracing::info!( + "adding PEM cert `{}` to root certificates", + cert_path.display() + ); + let mut pem = Cursor::new(buf); + let certs = match rustls_pemfile::certs(&mut pem) { + Ok(v) => v, + Err(e) => { + tracing::warn!( + "No valid certificate was found `{}`: {}", + cert_path.display(), + e + ); + continue; + } + }; + for c in certs { + if let Err(e) = root_cert_store.add(&rustls::Certificate(c)) { + tracing::warn!( + "Could not add part of cert `{}`: {}", + cert_path.display(), + e + ); + } + } + } + Some(v) if v == "der" => { + tracing::info!( + "adding DER cert `{}` to root certificates", + cert_path.display() + ); + if let Err(e) = root_cert_store.add(&Certificate(buf)) { + tracing::warn!("Could not add cert `{}`: {}", cert_path.display(), e); + } + } + _ => tracing::warn!( + "Could not load cert `{}`: unknown extension", + cert_path.display() + ), + } + } + + use rustls::OwnedTrustAnchor; + let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|trust_anchor| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + trust_anchor.subject, + trust_anchor.spki, + trust_anchor.name_constraints, + ) + }); + root_cert_store.add_server_trust_anchors(trust_anchors); + + builder + .with_root_certificates(root_cert_store) + .with_no_client_auth() + } else { + use rustls::{ + client::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier, ServerName}, + internal::msgs::handshake::DigitallySignedStruct, + }; + + tracing::warn!("Allowing invalid certs. 
THIS VERY IS INSECURE."); + struct NoVerifier; + + impl ServerCertVerifier for NoVerifier { + fn verify_server_cert( + &self, + _end_entity: &rustls::Certificate, + _intermediates: &[rustls::Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: std::time::SystemTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::Certificate, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::Certificate, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + } + builder + .with_custom_certificate_verifier(Arc::new(NoVerifier)) + .with_no_client_auth() + }; + + // Advertise support for HTTP/2 + //tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + #[derive(Debug, Eq)] + struct Uncased(Name); + impl PartialEq for Uncased { + fn eq(&self, v: &Uncased) -> bool { + self.0.as_str().eq_ignore_ascii_case(v.0.as_str()) + } + } + impl Hash for Uncased { + fn hash(&self, state: &mut H) { + self.0.as_str().len().hash(state); + for b in self.0.as_str().as_bytes() { + state.write_u8(b.to_ascii_lowercase()); + } + } + } + + let mapped = replica_resolve + .into_iter() + .map(|v| Ok((Uncased(Name::from_str(&v.domain)?), v.addr))) + .collect::, anyhow::Error>>() + .context("Invalid domain in `replica-resolve` flag"); + // TODO: inspect_err + let _ = mapped.as_ref().map_err(|e| error!("{}", e)); + let mapped = Arc::new(mapped?); + let resolver = tower::service_fn(move |name: Name| { + let mapped = mapped.clone(); + async move { + let name = Uncased(name); + if let Some(v) = mapped.get(&name) { + Ok(Either::Left(iter::once(*v))) + } else { + GaiResolver::new().call(name.0).await.map(Either::Right) + } + } + }); + let mut connector = HttpConnector::new_with_resolver(resolver); + connector.enforce_http(false); + let connector = HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_or_http() + .enable_http1() + .enable_http2() + .wrap_connector(connector); + let client: Client<_, Body> = Client::builder().build(connector); + Ok(client) +} diff --git a/src/http_transport.rs b/src/http_transport.rs new file mode 100644 index 0000000..d2f1c7d --- /dev/null +++ b/src/http_transport.rs @@ -0,0 +1,312 @@ +pub use hyper; + +use std::{ + any, error::Error, future::Future, marker::PhantomData, pin::Pin, sync::atomic::AtomicPtr, +}; + +use http_body::{LengthLimitError, Limited}; +use hyper::{ + body::{Bytes, HttpBody}, + client::HttpConnector, + header::CONTENT_TYPE, + http::uri::{Authority, PathAndQuery}, + service::Service, + Client, Method, Request, Response, Uri, +}; +use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; +use ic_agent::{ + agent::{agent_error::HttpErrorPayload, ReplicaV2Transport}, + ic_types::Principal, + AgentError, RequestId, +}; + +#[allow(dead_code)] +const IC0_DOMAIN: &str = "ic0.app"; +#[allow(dead_code)] +const IC0_SUB_DOMAIN: &str = ".ic0.app"; + +type AgentFuture<'a, V> = Pin> + Send + 'a>>; + +/// A [ReplicaV2Transport] using [hyper] to make HTTP calls to the internet computer. 
+ +/// A [ReplicaV2Transport] using [hyper] to make HTTP calls to the Internet Computer. +#[derive(Debug)] +pub struct HyperReplicaV2Transport<B1, S = Client<HttpsConnector<HttpConnector>, B1>> { + _marker: PhantomData<AtomicPtr<B1>>, + url: Uri, + max_response_body_size: Option<usize>, + service: S, +} + +/// Trait representing the constraints on [`HttpBody`] that [`HyperReplicaV2Transport`] requires +pub trait HyperBody: + HttpBody<Data = Self::BodyData, Error = Self::BodyError> + Send + From<Vec<u8>> + 'static +{ + /// Values yielded by the `Body`. + type BodyData: Send; + /// The error type this `Body` might generate. + type BodyError: Error + Send + Sync + 'static; +} + +impl<B> HyperBody for B +where + B: HttpBody + Send + From<Vec<u8>> + 'static, + B::Data: Send, + B::Error: Error + Send + Sync + 'static, +{ + type BodyData = B::Data; + type BodyError = B::Error; +} + +/// Trait representing the constraints on [`Service`] that [`HyperReplicaV2Transport`] requires. +pub trait HyperService<B1: HyperBody>: + Send + + Sync + + Clone + + Service< + Request<B1>, + Response = Response<Self::ResponseBody>, + Error = hyper::Error, + Future = Self::ServiceFuture, + > +{ + /// Values yielded in the `Body` of the `Response`. + type ResponseBody: HyperBody; + /// The future response value. + type ServiceFuture: Send + Future<Output = Result<Response<Self::ResponseBody>, hyper::Error>>; +} + +impl<B1, B2, S> HyperService<B1> for S +where + B1: HyperBody, + B2: HyperBody, + S: Send + Sync + Clone + Service<Request<B1>, Response = Response<B2>, Error = hyper::Error>, + S::Future: Send, +{ + type ResponseBody = B2; + type ServiceFuture = S::Future; +} + +impl<B1: HyperBody> HyperReplicaV2Transport<B1> { + /// Creates a replica transport from an HTTP URL. + #[allow(dead_code)] + pub fn create<U: Into<Uri>>(url: U) -> Result<Self, AgentError> { + let connector = HttpsConnectorBuilder::new() + .with_webpki_roots() + .https_or_http() + .enable_http1() + .enable_http2() + .build(); + Self::create_with_service(url, Client::builder().build(connector)) + } +} + +impl<B1, S> HyperReplicaV2Transport<B1, S> +where + B1: HyperBody, + S: HyperService<B1>, +{ + /// Creates a replica transport from an HTTP URL and a [`HyperService`].
+ pub fn create_with_service>(url: U, service: S) -> Result { + // Parse the url + let url = url.into(); + let mut parts = url.clone().into_parts(); + parts.authority = parts + .authority + .map(|v| { + let host = v.host(); + let host = match host.len().checked_sub(IC0_SUB_DOMAIN.len()) { + None => host, + Some(start) if host[start..].eq_ignore_ascii_case(IC0_SUB_DOMAIN) => IC0_DOMAIN, + Some(_) => host, + }; + let port = v.port(); + let (colon, port) = match port.as_ref() { + Some(v) => (":", v.as_str()), + None => ("", ""), + }; + Authority::from_maybe_shared(Bytes::from(format!("{host}{colon}{port}"))) + }) + .transpose() + .map_err(|_| AgentError::InvalidReplicaUrl(format!("{url}")))?; + parts.path_and_query = Some( + parts + .path_and_query + .map_or(Ok(PathAndQuery::from_static("/api/v2")), |v| { + let mut found = false; + fn replace(a: T, b: &mut T) -> T { + std::mem::replace(b, a) + } + let v = v + .path() + .trim_end_matches(|c| !replace(found || c == '/', &mut found)); + PathAndQuery::from_maybe_shared(Bytes::from(format!("{v}/api/v2"))) + }) + .map_err(|_| AgentError::InvalidReplicaUrl(format!("{url}")))?, + ); + let url = + Uri::from_parts(parts).map_err(|_| AgentError::InvalidReplicaUrl(format!("{url}")))?; + + Ok(Self { + _marker: PhantomData, + url, + service, + max_response_body_size: None, + }) + } + + /// Sets a max response body size limit + pub fn with_max_response_body_size(self, max_response_body_size: usize) -> Self { + Self { + max_response_body_size: Some(max_response_body_size), + ..self + } + } + + async fn request( + &self, + method: Method, + url: String, + body: Option>, + ) -> Result, AgentError> { + let http_request = Request::builder() + .method(method) + .uri(url) + .header(CONTENT_TYPE, "application/cbor") + .body(body.unwrap_or_default().into()) + .map_err(|err| AgentError::TransportError(Box::new(err)))?; + + fn map_error(err: E) -> AgentError { + if any::TypeId::of::() == any::TypeId::of::() { + // Store the value in an `Option` so we can `take` + // it after casting to `&mut dyn Any`. + let mut slot = Some(err); + + // Re-write the `$val` ident with the downcasted value. + let val = (&mut slot as &mut dyn any::Any) + .downcast_mut::>() + .unwrap() + .take() + .unwrap(); + + // Run the $body in scope of the replaced val. + return val; + } + AgentError::TransportError(Box::new(err)) + } + let response = self + .service + .clone() + .call(http_request) + .await + .map_err(map_error)?; + + let (parts, body) = response.into_parts(); + let body = if let Some(limit) = self.max_response_body_size { + hyper::body::to_bytes(Limited::new(body, limit)) + .await + .map_err(|err| { + if err.downcast_ref::().is_some() { + AgentError::ResponseSizeExceededLimit() + } else { + AgentError::TransportError(err) + } + })? + } else { + hyper::body::to_bytes(body) + .await + .map_err(|err| AgentError::TransportError(Box::new(err)))? 
+ }; + + let (status, headers, body) = (parts.status, parts.headers, body.to_vec()); + if status.is_client_error() || status.is_server_error() { + Err(AgentError::HttpError(HttpErrorPayload { + status: status.into(), + content_type: headers + .get(CONTENT_TYPE) + .and_then(|value| value.to_str().ok()) + .map(|x| x.to_string()), + content: body, + })) + } else { + Ok(body) + } + } +} + +impl ReplicaV2Transport for HyperReplicaV2Transport +where + B1: HyperBody, + S: HyperService, +{ + fn call( + &self, + effective_canister_id: Principal, + envelope: Vec, + _request_id: RequestId, + ) -> AgentFuture<()> { + Box::pin(async move { + let url = format!("{}/canister/{effective_canister_id}/call", self.url); + self.request(Method::POST, url, Some(envelope)).await?; + Ok(()) + }) + } + + fn read_state( + &self, + effective_canister_id: Principal, + envelope: Vec, + ) -> AgentFuture> { + Box::pin(async move { + let url = format!("{}/canister/{effective_canister_id}/read_state", self.url); + self.request(Method::POST, url, Some(envelope)).await + }) + } + + fn query(&self, effective_canister_id: Principal, envelope: Vec) -> AgentFuture> { + Box::pin(async move { + let url = format!("{}/canister/{effective_canister_id}/query", self.url); + self.request(Method::POST, url, Some(envelope)).await + }) + } + + fn status(&self) -> AgentFuture> { + Box::pin(async move { + let url = format!("{}/status", self.url); + self.request(Method::GET, url, None).await + }) + } +} + +#[cfg(test)] +mod test { + use super::{ + hyper::{Client, Uri}, + HyperReplicaV2Transport, + }; + + #[test] + fn redirect() { + fn test(base: &str, result: &str) { + let client: Client<_> = Client::builder().build_http(); + let uri: Uri = base.parse().unwrap(); + let t = HyperReplicaV2Transport::create_with_service(uri, client).unwrap(); + assert_eq!(t.url, result, "{}", base); + } + + test("https://ic0.app", "https://ic0.app/api/v2"); + test("https://IC0.app", "https://ic0.app/api/v2"); + test("https://foo.ic0.app", "https://ic0.app/api/v2"); + test("https://foo.IC0.app", "https://ic0.app/api/v2"); + test("https://foo.Ic0.app", "https://ic0.app/api/v2"); + test("https://foo.iC0.app", "https://ic0.app/api/v2"); + test("https://foo.bar.ic0.app", "https://ic0.app/api/v2"); + test("https://ic0.app/foo/", "https://ic0.app/foo/api/v2"); + test("https://foo.ic0.app/foo/", "https://ic0.app/foo/api/v2"); + + test("https://ic1.app", "https://ic1.app/api/v2"); + test("https://foo.ic1.app", "https://foo.ic1.app/api/v2"); + test("https://ic0.app.ic1.app", "https://ic0.app.ic1.app/api/v2"); + + test("https://fooic0.app", "https://fooic0.app/api/v2"); + test("https://fooic0.app.ic0.app", "https://ic0.app/api/v2"); + } +} diff --git a/src/logging.rs b/src/logging.rs index d04e1b2..3360f00 100644 --- a/src/logging.rs +++ b/src/logging.rs @@ -1,79 +1,145 @@ -use crate::Opts; -use slog::{Drain, Level, LevelFilter, Logger}; -use std::{fs::File, path::PathBuf}; +use std::{fs::File, io::stderr, path::PathBuf}; -/// The logging mode to use. -enum LoggingMode { - /// The default mode for logging; output without any decoration, to STDERR. - Stderr, +use axum::Router; +use clap::{crate_version, ArgEnum, Args}; +use tower_http::trace::TraceLayer; +use tracing::{ + info, info_span, level_filters::LevelFilter, span::EnteredSpan, subscriber::set_global_default, + Span, +}; +use tracing_subscriber::{fmt::layer, layer::SubscriberExt, Registry}; - /// Tee logging to a file (in addition to STDERR). This mimics the verbose flag. - /// So it would be similar to `dfx ... 
|& tee /some/file.txt - Tee(PathBuf), +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ArgEnum, Debug)] +pub(crate) enum OptMode { + StdErr, + Tee, + File, +} - /// Output Debug logs and up to a file, regardless of verbosity, keep the STDERR output - /// the same (with verbosity). - File(PathBuf), +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ArgEnum, Debug)] +pub(crate) enum OptFormat { + Default, + Compact, + Full, + Json, } -fn create_drain(mode: LoggingMode) -> Logger { - match mode { - LoggingMode::File(out) => { - let file = File::create(out).expect("Couldn't open log file"); - let decorator = slog_term::PlainDecorator::new(file); - let drain = slog_term::FullFormat::new(decorator).build().fuse(); - Logger::root(slog_async::Async::new(drain).build().fuse(), slog::o!()) - } - // A Tee mode is basically 2 drains duplicated. - LoggingMode::Tee(out) => Logger::root( - slog::Duplicate::new( - create_drain(LoggingMode::Stderr), - create_drain(LoggingMode::File(out)), - ) - .fuse(), - slog::o!(), - ), - LoggingMode::Stderr => { - let decorator = slog_term::PlainDecorator::new(std::io::stderr()); - let drain = slog_term::CompactFormat::new(decorator).build().fuse(); - Logger::root(slog_async::Async::new(drain).build().fuse(), slog::o!()) - } - } +/// The options for logging +#[derive(Args)] +pub struct Opts { + /// Verbose level. By default, INFO will be used. Add a single `-v` to upgrade to + /// DEBUG, and another `-v` to upgrade to TRACE. + #[clap(long, short('v'), parse(from_occurrences))] + verbose: u64, + + /// Quiet level. The opposite of verbose. A single `-q` will drop the logging to + /// WARN only, then another one to ERR, and finally another one for FATAL. Another + /// `-q` will silence ALL logs. + #[clap(long, short('q'), parse(from_occurrences))] + quiet: u64, + + /// Mode to use the logging. "stderr" will output logs in STDERR, "file" will output + /// logs in a file, and "tee" will do both. + #[clap(arg_enum, long("log"), default_value_t = OptMode::StdErr)] + logmode: OptMode, + + /// Formatting to use the logging. "stderr" will output logs in STDERR, "file" will output + /// logs in a file, and "tee" will do both. + #[clap(arg_enum, long("logformat"), default_value_t = OptFormat::Default)] + logformat: OptFormat, + + /// File to output the log to, when using logmode=tee or logmode=file. + #[clap(long)] + logfile: Option, } -pub(crate) fn setup_logging(opts: &Opts) -> Logger { - // Create a logger with our argument matches. - let verbose_level = opts.verbose as i64 - opts.quiet as i64; - let logfile = opts.logfile.clone().unwrap_or_else(|| "log.txt".into()); +/// A helper to add tracing with nice spans to `Router`s +pub fn add_trace_layer(r: Router) -> Router { + r.layer(TraceLayer::new_for_http().make_span_with(Span::current())) +} - let mode = match opts.logmode.as_str() { - "tee" => LoggingMode::Tee(logfile), - "file" => LoggingMode::File(logfile), - "stderr" => LoggingMode::Stderr, - _ => unreachable!("unhandled logmode"), +pub fn setup(opts: Opts) -> EnteredSpan { + let filter = match opts.verbose as i64 - opts.quiet as i64 { + -2 => LevelFilter::ERROR, + -1 => LevelFilter::WARN, + 0 => LevelFilter::INFO, + 1 => LevelFilter::DEBUG, + x if x >= 2 => LevelFilter::TRACE, + // Silent. + _ => LevelFilter::OFF, }; - let log_level = match verbose_level { - -3 => Level::Critical, - -2 => Level::Error, - -1 => Level::Warning, - 0 => Level::Info, - 1 => Level::Debug, - 2 => Level::Trace, - x => { - if x > 0 { - Level::Trace - } else { - // Silent. 
- let log_level = match verbose_level { - -3 => Level::Critical, - -2 => Level::Error, - -1 => Level::Warning, - 0 => Level::Info, - 1 => Level::Debug, - 2 => Level::Trace, - x => { - if x > 0 { - Level::Trace - } else { - // Silent. - return Logger::root(slog::Discard, slog::o!()); - } + fn create_file(path: Option<PathBuf>) -> File { + File::create(path.unwrap_or_else(|| "log.txt".into())).expect("Couldn't open log file") + } + + // The `layer_format` macro is used to uniformly customize the format-specific options for a layer + // (e.g. all json should be flattened) + macro_rules! layer_format { + (json, $writer:expr) => { + layer() + .json() + .flatten_event(true) + .with_current_span(false) + .with_writer($writer) + }; + (full, $writer:expr) => { + layer().with_writer($writer) + }; + (compact, $writer:expr) => { + layer().compact().with_writer($writer) + }; + } + // The `writer` macro is used to uniformly customize the writer-specific options for a layer + // (e.g. files don't use ANSI terminal colors) + macro_rules! writer { + (file, $format:ident) => { + layer_format!($format, create_file(opts.logfile)).with_ansi(false) + }; + (stderr, $format:ident) => { + layer_format!($format, stderr) + }; + } + // The `layer` macro is used to uniformly customize the writer-format-specific options for a layer + // (e.g. file-json includes the current span [we don't actually do this, it's just a hypothetical example]) + macro_rules! layer { + ($writer:ident, $format:ident) => { + writer!($writer, $format) + }; + } + + // The `install` macro filters to the specified level and adds all the layers to the global subscriber + macro_rules! install { + ($($layer:expr),+) => { + set_global_default(Registry::default().with(filter)$(.with($layer))+) } - }; + } - let drain = LevelFilter::new(create_drain(mode), log_level).fuse(); - let drain = slog_async::Async::new(drain).build().fuse(); + match (opts.logmode, opts.logformat) { + (OptMode::Tee, OptFormat::Default) => { + install!(layer!(stderr, compact), layer!(file, full)) + } + (OptMode::Tee, OptFormat::Compact) => { + install!(layer!(stderr, compact), layer!(file, compact)) + } + (OptMode::Tee, OptFormat::Full) => install!(layer!(stderr, full), layer!(file, full)), + (OptMode::Tee, OptFormat::Json) => install!(layer!(stderr, json), layer!(file, json)), + (OptMode::File, OptFormat::Default | OptFormat::Full) => { + install!(layer!(file, full)) + } + (OptMode::File, OptFormat::Compact) => { + install!(layer!(file, compact)) + } + (OptMode::File, OptFormat::Json) => install!(layer!(file, json)), + (OptMode::StdErr, OptFormat::Default | OptFormat::Compact) => { + install!(layer!(stderr, compact)) + } + (OptMode::StdErr, OptFormat::Full) => install!(layer!(stderr, full)), + (OptMode::StdErr, OptFormat::Json) => install!(layer!(stderr, json)), + } + .expect("Failed to setup tracing."); - let root = Logger::root(drain, slog::o!("version" => clap::crate_version!())); - slog::info!(root, "Log Level: {}", log_level); - root + let span = info_span!(target: "icx_proxy", "icx-proxy", version = crate_version!()).entered(); + info!(target: "icx_proxy", "Log Level: {filter}"); + span } diff --git a/src/main.rs b/src/main.rs index 0c69915..a705fa7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,903 +1,101 @@ -use axum::{handler::Handler, routing::get, Extension, Router}; use clap::{crate_authors, crate_version, Parser}; -use futures::{future::OptionFuture, try_join, FutureExt, StreamExt}; -use http_body::{LengthLimitError, Limited}; -use hyper::{ - body, - http::{header::CONTENT_TYPE, uri::Parts}, - service::{make_service_fn, service_fn}, - Body, Client, Request, Response, Server, StatusCode, Uri, -}; -use ic_agent::{ - agent::http_transport::{reqwest, ReqwestHttpReplicaV2Transport},
agent_error::HttpErrorPayload, - Agent, AgentError, -}; -use ic_utils::{ - call::AsyncCall, - call::SyncCall, - interfaces::http_request::{ - HeaderField, HttpRequestCanister, HttpRequestStreamingCallbackAny, HttpResponse, - StreamingCallbackHttpResponse, StreamingStrategy, Token, - }, -}; -use opentelemetry::{global, sdk::Resource, KeyValue}; -use opentelemetry_prometheus::PrometheusExporter; -use prometheus::{Encoder, TextEncoder}; -use slog::Drain; -use std::{ - convert::Infallible, - error::Error, - fs::File, - io::{Cursor, Read}, - net::{IpAddr, SocketAddr}, - path::PathBuf, - str::FromStr, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, Mutex, - }, -}; +use futures::try_join; +use tracing::{error, Instrument}; mod canister_id; mod config; mod headers; +mod http_client; mod logging; mod metrics; +mod proxy; mod validate; +// TODO: remove pending dfinity/agent-rs#373 +mod http_transport; + use crate::{ - config::dns_canister_config::DnsCanisterConfig, - headers::{extract_headers_data, HeadersData}, metrics::{MetricParams, WithMetrics}, - validate::{Validate, Validator}, + validate::Validator, }; -type HttpResponseAny = HttpResponse; - -// Limit the total number of calls to an HTTP Request loop to 1000 for now. -const MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT: usize = 1000; - -// Limit the number of Stream Callbacks buffered -const STREAM_CALLBACK_BUFFFER: usize = 2; - -// The maximum length of a body we should log as tracing. -const MAX_LOG_BODY_SIZE: usize = 100; - -const KB: usize = 1024; -const MB: usize = 1024 * KB; - -const REQUEST_BODY_SIZE_LIMIT: usize = 10 * MB; -const RESPONSE_BODY_SIZE_LIMIT: usize = 10 * MB; - -/// https://internetcomputer.org/docs/current/references/ic-interface-spec#reject-codes -struct ReplicaErrorCodes; -impl ReplicaErrorCodes { - const DESTINATION_INVALID: u64 = 3; -} - -/// Resolve overrides for [`reqwest::ClientBuilder::resolve()`] -/// `ic0.app=[::1]:9090` -pub(crate) struct OptResolve { - domain: String, - addr: SocketAddr, -} - -impl FromStr for OptResolve { - type Err = anyhow::Error; - fn from_str(s: &str) -> Result { - let (domain, addr) = s - .split_once('=') - .ok_or_else(|| anyhow::Error::msg("missing '='"))?; - Ok(OptResolve { - domain: domain.into(), - addr: addr.parse()?, - }) - } -} - #[derive(Parser)] #[clap( version = crate_version!(), author = crate_authors!(), propagate_version = true, )] -pub(crate) struct Opts { - /// Verbose level. By default, INFO will be used. Add a single `-v` to upgrade to - /// DEBUG, and another `-v` to upgrade to TRACE. - #[clap(long, short('v'), parse(from_occurrences))] - verbose: u64, - - /// Quiet level. The opposite of verbose. A single `-q` will drop the logging to - /// WARN only, then another one to ERR, and finally another one for FATAL. Another - /// `-q` will silence ALL logs. - #[clap(long, short('q'), parse(from_occurrences))] - quiet: u64, +struct Opts { + /// The options for logging + #[clap(flatten)] + log: logging::Opts, - /// Mode to use the logging. "stderr" will output logs in STDERR, "file" will output - /// logs in a file, and "tee" will do both. - #[clap(long("log"), default_value("stderr"), possible_values(&["stderr", "tee", "file"]))] - logmode: String, + /// The options for the HTTP client + #[clap(flatten)] + http_client: http_client::Opts, - /// File to output the log to, when using logmode=tee or logmode=file. - #[clap(long)] - logfile: Option, + /// The options for metrics + #[clap(flatten)] + metrics: metrics::Opts, - /// The address to bind to. 
- #[clap(long, default_value = "127.0.0.1:3000")] - address: SocketAddr, + /// The options for the canister resolver + #[clap(flatten)] + canister_id: canister_id::Opts, - /// A replica to use as backend. Locally, this should be a local instance or the - /// boundary node. Multiple replicas can be passed and they'll be used round-robin. - #[clap(long, default_value = "http://localhost:8000/")] - replica: Vec, - - /// Override DNS resolution for specific replica domains to particular IP addresses. - /// Examples: ic0.app=[::1]:9090 - #[clap(long, value_name("DOMAIN=IP_PORT"))] - replica_resolve: Vec, - - /// An address to forward any requests from /_/ - #[clap(long)] - proxy: Option, - - /// Whether or not this is run in a debug context (e.g. errors returned in responses - /// should show full stack and error details). - #[clap(long)] - debug: bool, - - /// Whether or not to fetch the root key from the replica back end. Do not use this when - /// talking to the Internet Computer blockchain mainnet as it is unsecure. - #[clap(long)] - fetch_root_key: bool, - - /// The list of custom root HTTPS certificates to use to talk to the replica. This can be used - /// to connect to an IC that has a self-signed certificate, for example. Do not use this when - /// talking to the Internet Computer blockchain mainnet as it is unsecure. - #[clap(long)] - ssl_root_certificate: Vec, - - /// Allows HTTPS connection to replicas with invalid HTTPS certificates. This can be used to - /// connect to an IC that has a self-signed certificate, for example. Do not use this when - /// talking to the Internet Computer blockchain mainnet as it is *VERY* unsecure. - #[clap(long)] - danger_accept_invalid_ssl: bool, - - /// A map of domain names to canister IDs. - /// Format: domain.name:canister-id - #[clap(long)] - dns_alias: Vec, - - /// A list of domain name suffixes. If found, the next (to the left) subdomain - /// is used as the Principal, if it parses as a Principal. - #[clap(long, default_value = "localhost")] - dns_suffix: Vec, - - /// Whether or not to ignore `canisterId=` when locating the canister. - #[clap(long)] - ignore_url_canister_param: bool, - - /// Address to expose Prometheus metrics on - /// Examples: 127.0.0.1:9090, [::1]:9090 - #[clap(long)] - metrics_addr: Option, + /// The options for the proxy server + #[clap(flatten)] + proxy: proxy::Opts, } -async fn forward_request( - request: Request, - agent: Arc, - resolver: &dyn canister_id::Resolver, - validator: &dyn Validate, - logger: slog::Logger, -) -> Result, Box> { - let canister_id = match resolver.resolve(&request) { - None => { - return Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .body("Could not find a canister id to forward to.".into()) - .unwrap()) - } - Some(x) => x, - }; - - slog::trace!( - logger, - "<< {} {} {:?}", - request.method(), - request.uri(), - &request.version() - ); - - let (parts, body) = request.into_parts(); - let method = parts.method; - let uri = parts.uri.to_string(); - let headers = parts - .headers - .iter() - .filter_map(|(name, value)| { - Some(HeaderField( - name.as_str().into(), - value.to_str().ok()?.into(), - )) - }) - .inspect(|HeaderField(name, value)| { - slog::trace!(logger, "<< {}: {}", name, value); - }) - .collect::>(); +fn main() -> Result<(), anyhow::Error> { + let Opts { + log, + http_client, + metrics, + canister_id, + proxy, + .. 
+ } = Opts::parse(); - // Limit request body size - let body = Limited::new(body, REQUEST_BODY_SIZE_LIMIT); - let entire_body = match hyper::body::to_bytes(body).await { - Ok(data) => data, - Err(err) => { - if err.downcast_ref::().is_some() { - return Ok(Response::builder() - .status(StatusCode::PAYLOAD_TOO_LARGE) - .body(Body::from("Request size exceeds limit"))?); - } - return Err(err); - } - } - .to_vec(); - - slog::trace!(logger, "<<"); - if logger.is_trace_enabled() { - let body = String::from_utf8_lossy( - &entire_body[0..usize::min(entire_body.len(), MAX_LOG_BODY_SIZE)], - ); - slog::trace!( - logger, - "<< \"{}\"{}", - &body.escape_default(), - if body.len() > MAX_LOG_BODY_SIZE { - format!("... {} bytes total", body.len()) - } else { - String::new() - } - ); - } + let _span = logging::setup(log); - let canister = HttpRequestCanister::create(agent.as_ref(), canister_id); - let query_result = canister - .http_request_custom( - method.as_str(), - uri.as_str(), - headers.iter().cloned(), - &entire_body, - ) - .call() - .await; + let client = http_client::setup(http_client)?; - fn handle_result( - result: Result<(HttpResponseAny,), AgentError>, - ) -> Result, Box>> { - // If the result is a Replica error, returns the 500 code and message. There is no information - // leak here because a user could use `dfx` to get the same reply. - match result { - Ok((http_response,)) => Ok(http_response), + let (meter, metrics) = metrics::setup(metrics); - Err(AgentError::ReplicaError { - reject_code: ReplicaErrorCodes::DESTINATION_INVALID, - reject_message, - }) => Err(Ok(Response::builder() - .status(StatusCode::NOT_FOUND) - .body(reject_message.into()) - .unwrap())), - - Err(AgentError::ReplicaError { - reject_code, - reject_message, - }) => Err(Ok(Response::builder() - .status(StatusCode::BAD_GATEWAY) - .body(format!(r#"Replica Error ({}): "{}""#, reject_code, reject_message).into()) - .unwrap())), - - Err(AgentError::HttpError(HttpErrorPayload { - status: 451, - content_type, - content, - })) => Err(Ok(content_type - .into_iter() - .fold(Response::builder(), |r, c| r.header(CONTENT_TYPE, c)) - .status(451) - .body(content.into()) - .unwrap())), - - Err(AgentError::ResponseSizeExceededLimit()) => Err(Ok(Response::builder() - .status(StatusCode::INSUFFICIENT_STORAGE) - .body("Response size exceeds limit".into()) - .unwrap())), - - Err(e) => Err(Err(e.into())), - } - } - - let http_response = match handle_result(query_result) { - Ok(http_response) => http_response, - Err(response_or_error) => return response_or_error, - }; - - let http_response = if http_response.upgrade == Some(true) { - let waiter = garcon::Delay::builder() - .throttle(std::time::Duration::from_millis(500)) - .timeout(std::time::Duration::from_secs(15)) - .build(); - let update_result = canister - .http_request_update_custom( - method.as_str(), - uri.as_str(), - headers.iter().cloned(), - &entire_body, - ) - .call_and_wait(waiter) - .await; - let http_response = match handle_result(update_result) { - Ok(http_response) => http_response, - Err(response_or_error) => return response_or_error, - }; - http_response - } else { - http_response - }; - - let mut builder = Response::builder().status(StatusCode::from_u16(http_response.status_code)?); - for HeaderField(name, value) in &http_response.headers { - builder = builder.header(name.as_ref(), value.as_ref()); - } - - let headers_data = extract_headers_data(&http_response.headers, &logger); - let body = if logger.is_trace_enabled() { - Some(http_response.body.clone()) - } else { - None 
- }; - let is_streaming = http_response.streaming_strategy.is_some(); - let response = if let Some(streaming_strategy) = http_response.streaming_strategy { - let body = http_response.body; - let body = futures::stream::once(async move { Ok(body) }); - let body = match streaming_strategy { - StreamingStrategy::Callback(callback) => body::Body::wrap_stream( - body.chain(futures::stream::try_unfold( - ( - logger.clone(), - agent, - callback.callback.0, - Some(callback.token), - ), - move |(logger, agent, callback, callback_token)| async move { - let callback_token = match callback_token { - Some(callback_token) => callback_token, - None => return Ok(None), - }; - - let canister = HttpRequestCanister::create(&agent, callback.principal); - match canister - .http_request_stream_callback(&callback.method, callback_token) - .call() - .await - { - Ok((StreamingCallbackHttpResponse { body, token },)) => { - Ok(Some((body, (logger, agent, callback, token)))) - } - Err(e) => { - slog::warn!(logger, "Error happened during streaming: {}", e); - Err(e) - } - } - }, - )) - .take(MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT) - .map(|x| async move { x }) - .buffered(STREAM_CALLBACK_BUFFFER), - ), - }; - - builder.body(body)? - } else { - let body_valid = validator.validate( - &headers_data, - &canister_id, - &agent, - &parts.uri, - &http_response.body, - logger.clone(), - ); - if body_valid.is_err() { - return Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(body_valid.unwrap_err().into()) - .unwrap()); - } - builder.body(http_response.body.into())? - }; - - if logger.is_trace_enabled() { - slog::trace!( - logger, - ">> {:?} {} {}", - &response.version(), - response.status().as_u16(), - response.status().to_string() - ); - - for (name, value) in response.headers() { - let value = String::from_utf8_lossy(value.as_bytes()); - slog::trace!(logger, ">> {}: {}", name, value); - } - - let body = body.unwrap_or_else(|| b"... streaming ...".to_vec()); - - slog::trace!(logger, ">>"); - slog::trace!( - logger, - ">> \"{}\"{}", - String::from_utf8_lossy(&body[..usize::min(MAX_LOG_BODY_SIZE, body.len())]) - .escape_default(), - if is_streaming { - "... streaming".to_string() - } else if body.len() > MAX_LOG_BODY_SIZE { - format!("... {} bytes total", body.len()) - } else { - String::new() - } - ); - } - - Ok(response) -} - -fn is_hop_header(name: &str) -> bool { - name.to_ascii_lowercase() == "connection" - || name.to_ascii_lowercase() == "keep-alive" - || name.to_ascii_lowercase() == "proxy-authenticate" - || name.to_ascii_lowercase() == "proxy-authorization" - || name.to_ascii_lowercase() == "te" - || name.to_ascii_lowercase() == "trailers" - || name.to_ascii_lowercase() == "transfer-encoding" - || name.to_ascii_lowercase() == "upgrade" -} - -/// Returns a clone of the headers without the [hop-by-hop headers]. -/// -/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html -fn remove_hop_headers( - headers: &hyper::header::HeaderMap, -) -> hyper::header::HeaderMap { - let mut result = hyper::HeaderMap::new(); - for (k, v) in headers.iter() { - if !is_hop_header(k.as_str()) { - result.insert(k.clone(), v.clone()); - } - } - result -} - -fn forward_uri(forward_url: &str, req: &Request) -> Result> { - let uri = Uri::from_str(forward_url)?; - let mut parts = Parts::from(uri); - parts.path_and_query = req.uri().path_and_query().cloned(); - - Ok(Uri::from_parts(parts)?) 
-} - -fn create_proxied_request( - client_ip: &IpAddr, - forward_url: &str, - mut request: Request, -) -> Result, Box> { - *request.headers_mut() = remove_hop_headers(request.headers()); - *request.uri_mut() = forward_uri(forward_url, &request)?; - - let x_forwarded_for_header_name = "x-forwarded-for"; - - // Add forwarding information in the headers - match request.headers_mut().entry(x_forwarded_for_header_name) { - hyper::header::Entry::Vacant(entry) => { - entry.insert(client_ip.to_string().parse()?); - } - - hyper::header::Entry::Occupied(mut entry) => { - let addr = format!("{}, {}", entry.get().to_str()?, client_ip); - entry.insert(addr.parse()?); - } - } - - Ok(request) -} - -async fn forward_api( - ip_addr: &IpAddr, - request: Request, - replica_url: &str, -) -> Result, Box> { - let proxied_request = create_proxied_request(ip_addr, replica_url, request)?; - - let client = Client::builder().build(hyper_tls::HttpsConnector::new()); - let response = client.request(proxied_request).await?; - Ok(response) -} - -fn not_found() -> Result, Box> { - Ok(Response::builder() - .status(StatusCode::NOT_FOUND) - .body("Not found".into())?) -} - -fn unable_to_fetch_root_key() -> Result, Box> { - Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body("Unable to fetch root key".into())?) -} - -struct HandleRequest { - ip_addr: IpAddr, - request: Request, - replica_url: String, - client: reqwest::Client, - proxy_url: Option, - resolver: Arc>, - validator: Arc, - logger: slog::Logger, - fetch_root_key: bool, - debug: bool, -} - -async fn handle_request( - HandleRequest { - ip_addr, - request, - replica_url, - client, - proxy_url, - resolver, - validator, - logger, - fetch_root_key, - debug, - }: HandleRequest, -) -> Result, Infallible> { - let request_uri_path = request.uri().path(); - let result = if request_uri_path.starts_with("/api/") { - slog::debug!( - logger, - "URI Request to path '{}' being forwarded to Replica", - &request.uri().path() - ); - forward_api(&ip_addr, request, &replica_url).await - } else if request_uri_path.starts_with("/_/") && !request_uri_path.starts_with("/_/raw") { - if let Some(proxy_url) = proxy_url { - slog::debug!( - logger, - "URI Request to path '{}' being forwarded to proxy", - &request.uri().path(), - ); - forward_api(&ip_addr, request, &proxy_url).await - } else { - slog::warn!( - logger, - "Unable to proxy {} because no --proxy is configured", - &request.uri().path() - ); - not_found() - } - } else { - let transport = ReqwestHttpReplicaV2Transport::create_with_client(replica_url, client) - .expect("failed to create transport") - .with_max_response_body_size(RESPONSE_BODY_SIZE_LIMIT); - - let agent = Arc::new( - ic_agent::Agent::builder() - .with_transport(transport) - .build() - .expect("Could not create agent..."), - ); - - if fetch_root_key && agent.fetch_root_key().await.is_err() { - unable_to_fetch_root_key() - } else { - forward_request( - request, - agent, - resolver.as_ref(), - validator.as_ref(), - logger.clone(), - ) - .await - } - }; - - match result { - Err(err) => { - slog::warn!(logger, "Internal Error during request:\n{:#?}", err); - - Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(if debug { - format!("Internal Error: {:?}", err).into() - } else { - "Internal Server Error".into() - }) - .unwrap()) - } - Ok(x) => Ok::<_, Infallible>(x), - } -} - -fn setup_http_client( - logger: &slog::Logger, - danger_accept_invalid_certs: bool, - root_certificates: &[PathBuf], - addr_mappings: Vec, -) -> 
reqwest::Client { - let builder = rustls::ClientConfig::builder().with_safe_defaults(); - let mut tls_config = if !danger_accept_invalid_certs { - use rustls::Certificate; - use rustls::RootCertStore; - - let mut root_cert_store = RootCertStore::empty(); - for cert_path in root_certificates { - let mut buf = Vec::new(); - if let Err(e) = File::open(cert_path).and_then(|mut v| v.read_to_end(&mut buf)) { - slog::warn!( - logger, - "Could not load cert `{}`: {}", - cert_path.display(), - e - ); - continue; - } - match cert_path.extension() { - Some(v) if v == "pem" => { - slog::info!( - logger, - "adding PEM cert `{}` to root certificates", - cert_path.display() - ); - let mut pem = Cursor::new(buf); - let certs = match rustls_pemfile::certs(&mut pem) { - Ok(v) => v, - Err(e) => { - slog::warn!( - logger, - "No valid certificate was found `{}`: {}", - cert_path.display(), - e - ); - continue; - } - }; - for c in certs { - if let Err(e) = root_cert_store.add(&rustls::Certificate(c)) { - slog::warn!( - logger, - "Could not add part of cert `{}`: {}", - cert_path.display(), - e - ); - } - } - } - Some(v) if v == "der" => { - slog::info!( - logger, - "adding DER cert `{}` to root certificates", - cert_path.display() - ); - if let Err(e) = root_cert_store.add(&Certificate(buf)) { - slog::warn!( - logger, - "Could not add cert `{}`: {}", - cert_path.display(), - e - ); - } - } - _ => slog::warn!( - logger, - "Could not load cert `{}`: unknown extension", - cert_path.display() - ), - } - } - - use rustls::OwnedTrustAnchor; - let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|trust_anchor| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - trust_anchor.subject, - trust_anchor.spki, - trust_anchor.name_constraints, - ) - }); - root_cert_store.add_server_trust_anchors(trust_anchors); - - builder - .with_root_certificates(root_cert_store) - .with_no_client_auth() - } else { - use rustls::client::HandshakeSignatureValid; - use rustls::client::ServerCertVerified; - use rustls::client::ServerCertVerifier; - use rustls::client::ServerName; - use rustls::internal::msgs::handshake::DigitallySignedStruct; - - slog::warn!(logger, "Allowing invalid certs. 
THIS VERY IS INSECURE."); - struct NoVerifier; - - impl ServerCertVerifier for NoVerifier { - fn verify_server_cert( - &self, - _end_entity: &rustls::Certificate, - _intermediates: &[rustls::Certificate], - _server_name: &ServerName, - _scts: &mut dyn Iterator, - _ocsp_response: &[u8], - _now: std::time::SystemTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &rustls::Certificate, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &rustls::Certificate, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - } - builder - .with_custom_certificate_verifier(Arc::new(NoVerifier)) - .with_no_client_auth() - }; - - // Advertise support for HTTP/2 - tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - let builder = reqwest::Client::builder().use_preconfigured_tls(tls_config); - - // Setup DNS - let builder = addr_mappings - .into_iter() - .fold(builder, |builder, OptResolve { domain, addr }| { - builder.resolve(&domain, addr) - }); - - builder.build().expect("Could not create HTTP client.") -} - -#[derive(Clone)] -struct MetricsHandlerArgs { - exporter: PrometheusExporter, -} - -async fn metrics_handler( - Extension(MetricsHandlerArgs { exporter }): Extension, - _: Request, -) -> Response { - let metric_families = exporter.registry().gather(); - - let encoder = TextEncoder::new(); - - let mut metrics_text = Vec::new(); - if encoder.encode(&metric_families, &mut metrics_text).is_err() { - return Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body("Internal Server Error".into()) - .unwrap(); - }; - - Response::builder() - .status(200) - .body(metrics_text.into()) - .unwrap() -} - -fn main() -> Result<(), Box> { - let opts: Opts = Opts::parse(); - - let logger = logging::setup_logging(&opts); - - let client = setup_http_client( - &logger, - opts.danger_accept_invalid_ssl, - &opts.ssl_root_certificate, - opts.replica_resolve, - ); - - // Setup metrics - let exporter = opentelemetry_prometheus::exporter() - .with_resource(Resource::new(vec![KeyValue::new("service", "prober")])) - .init(); - let meter = global::meter("icx-proxy"); - - let metrics_addr = opts.metrics_addr; - let create_metrics_server = move || { - OptionFuture::from(metrics_addr.map(|metrics_addr| { - let metrics_handler = metrics_handler.layer(Extension(MetricsHandlerArgs { exporter })); - let metrics_router = Router::new().route("/metrics", get(metrics_handler)); - - axum::Server::bind(&metrics_addr).serve(metrics_router.into_make_service()) - })) - }; - - // Prepare a list of agents for each backend replicas. 
-    let replicas = Mutex::new(opts.replica.clone());
-
-    let dns = DnsCanisterConfig::new(&opts.dns_alias, &opts.dns_suffix)?;
-    let resolver = Arc::new(canister_id::DefaultResolver {
-        dns,
-        check_params: !opts.ignore_url_canister_param,
-    });
+    let resolver = canister_id::setup(canister_id)?;
 
     let validator = Validator::new();
     let validator = WithMetrics(validator, MetricParams::new(&meter, "validator"));
-    let validator = Arc::new(validator);
-
-    let counter = AtomicUsize::new(0);
-    let debug = opts.debug;
-    let proxy_url = opts.proxy.clone();
-    let fetch_root_key = opts.fetch_root_key;
-
-    let service = make_service_fn(|socket: &hyper::server::conn::AddrStream| {
-        let ip_addr = socket.remote_addr();
-        let ip_addr = ip_addr.ip();
-        let resolver = resolver.clone();
-        let validator = validator.clone();
-        let logger = logger.clone();
-
-        // Select an agent.
-        let replica_url_array = replicas.lock().unwrap();
-        let count = counter.fetch_add(1, Ordering::SeqCst);
-        let replica_url = replica_url_array
-            .get(count % replica_url_array.len())
-            .unwrap_or_else(|| unreachable!());
-        let replica_url = replica_url.clone();
-        slog::debug!(logger, "Replica URL: {}", replica_url);
-        let proxy_url = proxy_url.clone();
-        let client = client.clone();
-
-        async move {
-            Ok::<_, Infallible>(service_fn(move |request| {
-                handle_request(HandleRequest {
-                    ip_addr,
-                    request,
-                    replica_url: replica_url.clone(),
-                    client: client.clone(),
-                    proxy_url: proxy_url.clone(),
-                    resolver: resolver.clone(),
-                    validator: validator.clone(),
-                    logger: logger.clone(),
-                    fetch_root_key,
-                    debug,
-                })
-            }))
-        }
-    });
-
-    let address = opts.address;
-    slog::info!(logger, "Starting server. Listening on http://{}/", address);
+    let proxy = proxy::setup(
+        proxy::SetupArgs {
+            resolver,
+            validator,
+            client,
+        },
+        proxy,
+    )?;
 
     let rt = tokio::runtime::Builder::new_multi_thread()
         .worker_threads(10)
         .enable_all()
         .build()?;
 
-    rt.block_on(async {
-        try_join!(
-            create_metrics_server().map(|v| v.transpose()), // metrics
-            Server::bind(&address).serve(service),          // icx
-        )?;
+    rt.block_on(
+        async move {
+            let v = try_join!(
+                metrics.run().in_current_span(),
+                proxy.run().in_current_span(),
+            );
+            if let Err(v) = v {
+                error!("Runtime crashed: {v}");
+                return Err(v);
+            }
+            Ok(())
+        }
+        .in_current_span(),
+    )?;
 
-        Ok(())
-    })
+    Ok(())
 }
diff --git a/src/metrics.rs b/src/metrics.rs
index 4824d12..37b7aed 100644
--- a/src/metrics.rs
+++ b/src/metrics.rs
@@ -1,9 +1,29 @@
+use std::{future::Future, net::SocketAddr};
+
+use axum::{handler::Handler, routing::get, Extension, Router};
+use clap::Args;
+use futures::{future::OptionFuture, FutureExt};
+use ic_agent::{ic_types::Principal, Agent};
 use opentelemetry::{
+    global,
     metrics::{Counter, Meter},
+    sdk::Resource,
     KeyValue,
 };
+use opentelemetry_prometheus::PrometheusExporter;
+use prometheus::{Encoder, TextEncoder};
+
+use crate::http_transport::hyper::{self, Body, Request, Response, StatusCode, Uri};
+use crate::{headers::HeadersData, logging::add_trace_layer, validate::Validate};
-use crate::validate::Validate;
 
+/// The options for metrics
+#[derive(Args)]
+pub struct Opts {
+    /// Address to expose Prometheus metrics on
+    /// Examples: 127.0.0.1:9090, [::1]:9090
+    #[clap(long)]
+    metrics_addr: Option<SocketAddr>,
+}
 
 pub struct WithMetrics<T>(pub T, pub MetricParams);
 
@@ -25,16 +45,15 @@ impl MetricParams {
 impl<T: Validate> Validate for WithMetrics<T> {
     fn validate(
         &self,
-        headers_data: &crate::headers::HeadersData,
-        canister_id: &candid::Principal,
-        agent: &ic_agent::Agent,
-        uri: &hyper::Uri,
+        headers_data: &HeadersData,
+        canister_id: &Principal,
+        agent: &Agent,
+        uri: &Uri,
         response_body: &[u8],
-        logger: slog::Logger,
     ) -> Result<(), String> {
         let out = self
             .0
-            .validate(headers_data, canister_id, agent, uri, response_body, logger);
+            .validate(headers_data, canister_id, agent, uri, response_body);
 
         let mut status = if out.is_ok() { "ok" } else { "fail" };
         if cfg!(feature = "skip_body_verification") {
@@ -49,3 +68,61 @@ impl<T: Validate> Validate for WithMetrics<T> {
         out
     }
 }
+
+#[derive(Clone)]
+struct HandlerArgs {
+    exporter: PrometheusExporter,
+}
+
+async fn metrics_handler(
+    Extension(HandlerArgs { exporter }): Extension<HandlerArgs>,
+    _: Request<Body>,
+) -> Response<Body> {
+    let metric_families = exporter.registry().gather();
+
+    let encoder = TextEncoder::new();
+
+    let mut metrics_text = Vec::new();
+    if encoder.encode(&metric_families, &mut metrics_text).is_err() {
+        return Response::builder()
+            .status(StatusCode::INTERNAL_SERVER_ERROR)
+            .body("Internal Server Error".into())
+            .unwrap();
+    };
+
+    Response::builder()
+        .status(200)
+        .body(metrics_text.into())
+        .unwrap()
+}
+pub fn setup(opts: Opts) -> (Meter, Runner) {
+    let exporter = opentelemetry_prometheus::exporter()
+        .with_resource(Resource::new(vec![KeyValue::new("service", "prober")]))
+        .init();
+    (
+        global::meter("icx-proxy"),
+        Runner {
+            exporter,
+            metrics_addr: opts.metrics_addr,
+        },
+    )
+}
+
+pub struct Runner {
+    exporter: PrometheusExporter,
+    metrics_addr: Option<SocketAddr>,
+}
+
+impl Runner {
+    pub fn run(self) -> impl Future<Output = Result<Option<()>, hyper::Error>> {
+        let exporter = self.exporter;
+        OptionFuture::from(self.metrics_addr.map(|metrics_addr| {
+            let metrics_handler = metrics_handler.layer(Extension(HandlerArgs { exporter }));
+            let metrics_router = Router::new().route("/metrics", get(metrics_handler));
+
+            axum::Server::bind(&metrics_addr)
+                .serve(add_trace_layer(metrics_router).into_make_service())
+        }))
+        .map(|v| v.transpose())
+    }
+}
diff --git a/src/proxy/agent.rs b/src/proxy/agent.rs
new file mode 100644
index 0000000..5b5ddd8
--- /dev/null
+++ b/src/proxy/agent.rs
@@ -0,0 +1,366 @@
+use std::{
+    net::SocketAddr,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+    time::Duration,
+};
+
+use anyhow::bail;
+use axum::{extract::ConnectInfo, Extension};
+use futures::StreamExt;
+use http_body::{LengthLimitError, Limited};
+use ic_agent::{agent_error::HttpErrorPayload, Agent, AgentError};
+use ic_utils::{
+    call::{AsyncCall, SyncCall},
+    interfaces::http_request::{
+        HeaderField, HttpRequestCanister, HttpRequestStreamingCallbackAny, HttpResponse,
+        StreamingCallbackHttpResponse, StreamingStrategy, Token,
+    },
+};
+use tracing::{enabled, instrument, trace, warn, Level};
+
+use crate::http_transport::hyper::{
+    body, http::header::CONTENT_TYPE, Body, Request, Response, StatusCode, Uri,
+};
+use crate::{
+    canister_id::Resolver as CanisterIdResolver,
+    headers::extract_headers_data,
+    proxy::{HandleError, REQUEST_BODY_SIZE_LIMIT},
+    validate::Validate,
+};
+
+type HttpResponseAny = HttpResponse<Token, HttpRequestStreamingCallbackAny>;
+
+// Limit the total number of calls to an HTTP Request loop to 1000 for now.
+const MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT: usize = 1000;
+
+// Limit the number of Stream Callbacks buffered
+const STREAM_CALLBACK_BUFFFER: usize = 2;
+
+// The maximum length of a body we should log as tracing.
+const MAX_LOG_BODY_SIZE: usize = 100;
+
+/// https://internetcomputer.org/docs/current/references/ic-interface-spec#reject-codes
+struct ReplicaErrorCodes;
+impl ReplicaErrorCodes {
+    const DESTINATION_INVALID: u64 = 3;
+}
+
+pub struct ArgsInner {
+    pub validator: Box<dyn Validate>,
+    pub resolver: Box<dyn CanisterIdResolver>,
+    pub counter: AtomicUsize,
+    pub replicas: Vec<(Agent, Uri)>,
+    pub debug: bool,
+    pub fetch_root_key: bool,
+}
+
+pub struct Args {
+    args: Arc<ArgsInner>,
+    current: usize,
+}
+
+impl Clone for Args {
+    fn clone(&self) -> Self {
+        let args = self.args.clone();
+        Args {
+            current: args.counter.fetch_add(1, Ordering::Relaxed) % args.replicas.len(),
+            args,
+        }
+    }
+}
+
+impl From<ArgsInner> for Args {
+    fn from(args: ArgsInner) -> Self {
+        Args {
+            args: Arc::new(args),
+            current: 0,
+        }
+    }
+}
+impl Args {
+    fn replica(&self) -> (&Agent, &Uri) {
+        let v = &self.args.replicas[self.current];
+        (&v.0, &v.1)
+    }
+}
+
+#[instrument(level = "info", skip_all, fields(addr = display(addr), replica = display(args.replica().1)))]
+pub async fn handler(
+    Extension(args): Extension<Args>,
+    ConnectInfo(addr): ConnectInfo<SocketAddr>,
+    request: Request<Body>,
+) -> Response<Body> {
+    let agent = args.replica().0;
+    let args = &args.args;
+    async {
+        if args.fetch_root_key && agent.fetch_root_key().await.is_err() {
+            unable_to_fetch_root_key()
+        } else {
+            process_request_inner(
+                request,
+                agent,
+                args.resolver.as_ref(),
+                args.validator.as_ref(),
+            )
+            .await
+        }
+    }
+    .await
+    .handle_error(args.debug)
+}
+
+fn unable_to_fetch_root_key() -> Result<Response<Body>, anyhow::Error> {
+    Ok(Response::builder()
+        .status(StatusCode::INTERNAL_SERVER_ERROR)
+        .body("Unable to fetch root key".into())?)
+}
+
+async fn process_request_inner(
+    request: Request<Body>,
+    agent: &Agent,
+    resolver: &dyn CanisterIdResolver,
+    validator: &dyn Validate,
+) -> Result<Response<Body>, anyhow::Error> {
+    let (parts, body) = request.into_parts();
+
+    let canister_id = match resolver.resolve(&parts) {
+        None => {
+            return Ok(Response::builder()
+                .status(StatusCode::BAD_REQUEST)
+                .body("Could not find a canister id to forward to.".into())
+                .unwrap())
+        }
+        Some(x) => x,
+    };
+
+    trace!("<< {} {} {:?}", parts.method, parts.uri, parts.version);
+
+    let method = parts.method;
+    let uri = parts.uri.to_string();
+    let headers = parts
+        .headers
+        .iter()
+        .filter_map(|(name, value)| {
+            Some(HeaderField(
+                name.as_str().into(),
+                value.to_str().ok()?.into(),
+            ))
+        })
+        .inspect(|HeaderField(name, value)| {
+            trace!("<< {}: {}", name, value);
+        })
+        .collect::<Vec<_>>();
+
+    // Limit request body size
+    let body = Limited::new(body, REQUEST_BODY_SIZE_LIMIT);
+    let entire_body = match body::to_bytes(body).await {
+        Ok(data) => data,
+        Err(err) => {
+            if err.downcast_ref::<LengthLimitError>().is_some() {
+                return Ok(Response::builder()
+                    .status(StatusCode::PAYLOAD_TOO_LARGE)
+                    .body(Body::from("Request size exceeds limit"))?);
+            }
+            bail!("Failed to read body: {err}");
+        }
+    }
+    .to_vec();
+
+    trace!("<<");
+    if enabled!(Level::TRACE) {
+        let body = String::from_utf8_lossy(
+            &entire_body[0..usize::min(entire_body.len(), MAX_LOG_BODY_SIZE)],
+        );
+        trace!(
+            "<< \"{}\"{}",
+            &body.escape_default(),
+            if body.len() > MAX_LOG_BODY_SIZE {
+                format!("... {} bytes total", body.len())
+            } else {
+                String::new()
+            }
+        );
+    }
+
+    let canister = HttpRequestCanister::create(agent, canister_id);
+    let query_result = canister
+        .http_request_custom(
+            method.as_str(),
+            uri.as_str(),
+            headers.iter().cloned(),
+            &entire_body,
+        )
+        .call()
+        .await;
+
+    fn handle_result(
+        result: Result<(HttpResponseAny,), AgentError>,
+    ) -> Result<HttpResponseAny, Result<Response<Body>, anyhow::Error>> {
+        // If the result is a Replica error, returns the 500 code and message. There is no information
+        // leak here because a user could use `dfx` to get the same reply.
+        match result {
+            Ok((http_response,)) => Ok(http_response),
+
+            Err(AgentError::ReplicaError {
+                reject_code: ReplicaErrorCodes::DESTINATION_INVALID,
+                reject_message,
+            }) => Err(Ok(Response::builder()
+                .status(StatusCode::NOT_FOUND)
+                .body(reject_message.into())
+                .unwrap())),
+
+            Err(AgentError::ReplicaError {
+                reject_code,
+                reject_message,
+            }) => Err(Ok(Response::builder()
+                .status(StatusCode::BAD_GATEWAY)
+                .body(format!(r#"Replica Error ({}): "{}""#, reject_code, reject_message).into())
+                .unwrap())),
+
+            Err(AgentError::HttpError(HttpErrorPayload {
+                status: 451,
+                content_type,
+                content,
+            })) => Err(Ok(content_type
+                .into_iter()
+                .fold(Response::builder(), |r, c| r.header(CONTENT_TYPE, c))
+                .status(451)
+                .body(content.into())
+                .unwrap())),
+
+            Err(AgentError::ResponseSizeExceededLimit()) => Err(Ok(Response::builder()
+                .status(StatusCode::INSUFFICIENT_STORAGE)
+                .body("Response size exceeds limit".into())
+                .unwrap())),
+
+            Err(e) => Err(Err(e.into())),
+        }
+    }
+
+    let http_response = match handle_result(query_result) {
+        Ok(http_response) => http_response,
+        Err(response_or_error) => return response_or_error,
+    };
+
+    let http_response = if http_response.upgrade == Some(true) {
+        let waiter = garcon::Delay::builder()
+            .throttle(Duration::from_millis(500))
+            .timeout(Duration::from_secs(15))
+            .build();
+        let update_result = canister
+            .http_request_update_custom(
+                method.as_str(),
+                uri.as_str(),
+                headers.iter().cloned(),
+                &entire_body,
+            )
+            .call_and_wait(waiter)
+            .await;
+        match handle_result(update_result) {
+            Ok(http_response) => http_response,
+            Err(response_or_error) => return response_or_error,
+        }
+    } else {
+        http_response
+    };
+
+    let mut builder = Response::builder().status(StatusCode::from_u16(http_response.status_code)?);
+    for HeaderField(name, value) in &http_response.headers {
+        builder = builder.header(name.as_ref(), value.as_ref());
+    }
+
+    let headers_data = extract_headers_data(&http_response.headers);
+    let body = if enabled!(Level::TRACE) {
+        Some(http_response.body.clone())
+    } else {
+        None
+    };
+    let is_streaming = http_response.streaming_strategy.is_some();
+    let response = if let Some(streaming_strategy) = http_response.streaming_strategy {
+        let body = http_response.body;
+        let body = futures::stream::once(async move { Ok(body) });
+        let body = match streaming_strategy {
+            StreamingStrategy::Callback(callback) => body::Body::wrap_stream(
+                body.chain(futures::stream::try_unfold(
+                    (agent.clone(), callback.callback.0, Some(callback.token)),
+                    move |(agent, callback, callback_token)| async move {
+                        let callback_token = match callback_token {
+                            Some(callback_token) => callback_token,
+                            None => return Ok(None),
+                        };
+
+                        let canister = HttpRequestCanister::create(&agent, callback.principal);
+                        match canister
+                            .http_request_stream_callback(&callback.method, callback_token)
+                            .call()
+                            .await
+                        {
+                            Ok((StreamingCallbackHttpResponse { body, token },)) => {
+                                Ok(Some((body, (agent, callback, token))))
+                            }
+                            Err(e) => {
+                                warn!("Error happened during streaming: {}", e);
+                                Err(e)
+                            }
+                        }
+                    },
+                ))
+                .take(MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT)
+                .map(|x| async move { x })
+                .buffered(STREAM_CALLBACK_BUFFFER),
+            ),
+        };
+
+        builder.body(body)?
+    } else {
+        let body_valid = validator.validate(
+            &headers_data,
+            &canister_id,
+            agent,
+            &parts.uri,
+            &http_response.body,
+        );
+        if body_valid.is_err() {
+            return Ok(Response::builder()
+                .status(StatusCode::INTERNAL_SERVER_ERROR)
+                .body(body_valid.unwrap_err().into())
+                .unwrap());
+        }
+        builder.body(http_response.body.into())?
+    };
+
+    if enabled!(Level::TRACE) {
+        trace!(
+            ">> {:?} {} {}",
+            &response.version(),
+            response.status().as_u16(),
+            response.status().to_string()
+        );
+
+        for (name, value) in response.headers() {
+            let value = String::from_utf8_lossy(value.as_bytes());
+            trace!(">> {}: {}", name, value);
+        }
+
+        let body = body.unwrap_or_else(|| b"... streaming ...".to_vec());
+
+        trace!(">>");
+        trace!(
+            ">> \"{}\"{}",
+            String::from_utf8_lossy(&body[..usize::min(MAX_LOG_BODY_SIZE, body.len())])
+                .escape_default(),
+            if is_streaming {
+                "... streaming".to_string()
+            } else if body.len() > MAX_LOG_BODY_SIZE {
+                format!("... {} bytes total", body.len())
+            } else {
+                String::new()
+            }
+        );
+    }
+
+    Ok(response)
+}
diff --git a/src/proxy/forward.rs b/src/proxy/forward.rs
new file mode 100644
index 0000000..ea7590f
--- /dev/null
+++ b/src/proxy/forward.rs
@@ -0,0 +1,128 @@
+use std::{
+    net::{IpAddr, SocketAddr},
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+};
+
+use axum::{extract::ConnectInfo, Extension};
+use tracing::{info, instrument};
+
+use crate::http_transport::hyper::{
+    header::{Entry, HeaderValue},
+    http::uri::Parts,
+    HeaderMap, Request, Response, Uri,
+};
+use crate::{
+    http_client::{Body, HyperService},
+    proxy::HandleError,
+};
+
+pub struct ArgsInner<C> {
+    pub debug: bool,
+    pub counter: AtomicUsize,
+    pub proxy_urls: Vec<Uri>,
+    pub client: C,
+}
+pub struct Args<C> {
+    args: Arc<ArgsInner<C>>,
+    current: usize,
+}
+impl<C> From<ArgsInner<C>> for Args<C> {
+    fn from(args: ArgsInner<C>) -> Self {
+        Args {
+            args: Arc::new(args),
+            current: 0,
+        }
+    }
+}
+impl<C> Clone for Args<C> {
+    fn clone(&self) -> Self {
+        let args = self.args.clone();
+        Args {
+            current: args.counter.fetch_add(1, Ordering::Relaxed) % args.proxy_urls.len(),
+            args,
+        }
+    }
+}
+impl<C> Args<C> {
+    fn proxy_url(&self) -> &Uri {
+        &self.args.proxy_urls[self.current]
+    }
+}
+
+#[instrument(level = "info", skip_all, fields(addr = display(addr)))]
+pub async fn handler<C: HyperService<Body>>(
+    Extension(args): Extension<Arc<Args<C>>>,
+    ConnectInfo(addr): ConnectInfo<SocketAddr>,
+    request: Request<Body>,
+) -> Response<Body> {
+    let proxy_url = args.proxy_url();
+    let args = &args.args;
+
+    async {
+        info!("forwarding");
+        let proxied_request = create_proxied_request(&addr.ip(), proxy_url.clone(), request)?;
+        let response = args.client.clone().call(proxied_request).await?;
+        Ok(response)
+    }
+    .await
+    .handle_error(args.debug)
+    .map(|b| b.into())
+}
+
+fn create_proxied_request(
+    client_ip: &IpAddr,
+    proxy_url: Uri,
+    mut request: Request<Body>,
+) -> Result<Request<Body>, anyhow::Error> {
+    *request.headers_mut() = remove_hop_headers(request.headers());
+    *request.uri_mut() = forward_uri(proxy_url, &request)?;
+
+    let x_forwarded_for_header_name = "x-forwarded-for";
+
+    // Add forwarding information in the headers
+    match request.headers_mut().entry(x_forwarded_for_header_name) {
+        Entry::Vacant(entry) => {
+            entry.insert(client_ip.to_string().parse()?);
+        }
+
+        Entry::Occupied(mut entry) => {
+            let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
+            entry.insert(addr.parse()?);
+        }
+    }
+
+    Ok(request)
+}
+
+fn is_hop_header(name: &str) -> bool {
+    name.eq_ignore_ascii_case("connection")
+        || name.eq_ignore_ascii_case("keep-alive")
+        || name.eq_ignore_ascii_case("proxy-authenticate")
+        || name.eq_ignore_ascii_case("proxy-authorization")
+        || name.eq_ignore_ascii_case("te")
+        || name.eq_ignore_ascii_case("trailers")
+        || name.eq_ignore_ascii_case("transfer-encoding")
+        || name.eq_ignore_ascii_case("upgrade")
+}
+
+/// Returns a clone of the headers without the [hop-by-hop headers].
+///
+/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
+fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
+    let mut result = HeaderMap::new();
+    for (k, v) in headers.iter() {
+        if !is_hop_header(k.as_str()) {
+            result.insert(k.clone(), v.clone());
+        }
+    }
+    result
+}
+
+fn forward_uri(proxy_url: Uri, req: &Request<Body>) -> Result<Uri, anyhow::Error> {
+    let mut parts = Parts::from(proxy_url);
+    parts.path_and_query = req.uri().path_and_query().cloned();
+    Ok(Uri::from_parts(parts)?)
+}
diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs
new file mode 100644
index 0000000..317e70d
--- /dev/null
+++ b/src/proxy/mod.rs
@@ -0,0 +1,185 @@
+use std::{
+    future::Future,
+    net::SocketAddr,
+    sync::{atomic::AtomicUsize, Arc},
+};
+
+use anyhow::{bail, Context};
+use axum::{handler::Handler, routing::any, Extension, Router};
+use clap::Args;
+use ic_agent::Agent;
+use tracing::{error, info};
+
+use crate::http_transport::{
+    hyper::{self, Response, StatusCode, Uri},
+    HyperReplicaV2Transport,
+};
+use crate::{
+    canister_id::Resolver as CanisterIdResolver,
+    http_client::{Body, HyperService},
+    logging::add_trace_layer,
+    validate::Validate,
+};
+
+const KB: usize = 1024;
+const MB: usize = 1024 * KB;
+
+const REQUEST_BODY_SIZE_LIMIT: usize = 10 * MB;
+const RESPONSE_BODY_SIZE_LIMIT: usize = 10 * MB;
+
+/// The options for the proxy server
+#[derive(Args)]
+pub struct Opts {
+    /// The address to bind to.
+    #[clap(long, default_value = "127.0.0.1:3000")]
+    address: SocketAddr,
+
+    /// A replica to use as backend. Locally, this should be a local instance or the
+    /// boundary node. Multiple replicas can be passed and they'll be used round-robin.
+    #[clap(long, default_value = "http://localhost:8000/")]
+    replica: Vec<Uri>,
+
+    /// An address to forward any requests from /_/
+    #[clap(long)]
+    proxy: Option<Uri>,
+
+    /// Whether or not this is run in a debug context (e.g. errors returned in responses
+    /// should show full stack and error details).
+    #[clap(long)]
+    debug: bool,
+
+    /// Whether or not to fetch the root key from the replica back end. Do not use this when
+    /// talking to the Internet Computer blockchain mainnet as it is insecure.
+    #[clap(long)]
+    fetch_root_key: bool,
+}
+
+mod agent;
+mod forward;
+
+use agent::{handler as agent_handler, Args as AgentArgs, ArgsInner as AgentArgsInner};
+use forward::{handler as forward_handler, Args as ForwardArgs, ArgsInner as ForwardArgsInner};
+
+trait HandleError {
+    type B;
+    fn handle_error(self, debug: bool) -> Response<Self::B>;
+}
+impl<B> HandleError for Result<Response<B>, anyhow::Error>
+where
+    String: Into<B>,
+    &'static str: Into<B>,
+{
+    type B = B;
+    fn handle_error(self, debug: bool) -> Response<B> {
+        match self {
+            Err(err) => {
+                error!("Internal Error during request:\n{}", err);
+                Response::builder()
+                    .status(StatusCode::INTERNAL_SERVER_ERROR)
+                    .body(if debug {
+                        format!("Internal Error: {:?}", err).into()
+                    } else {
+                        "Internal Server Error".into()
+                    })
+                    .unwrap()
+            }
+            Ok(v) => v,
+        }
+    }
+}
+
+pub struct SetupArgs<V, R, C> {
+    pub validator: V,
+    pub resolver: R,
+    pub client: C,
+}
+
+pub fn setup<V: Validate + 'static, R: CanisterIdResolver + 'static, C: 'static + HyperService<Body>>(
+    args: SetupArgs<V, R, C>,
+    opts: Opts,
+) -> Result<Runner, anyhow::Error> {
+    let client = args.client;
+
+    let agent_args = Extension(AgentArgs::from(AgentArgsInner {
+        validator: Box::new(args.validator),
+        resolver: Box::new(args.resolver),
+        counter: AtomicUsize::new(0),
+        replicas: opts
+            .replica
+            .iter()
+            .map(|replica_url| {
+                let transport = HyperReplicaV2Transport::create_with_service(
+                    replica_url.clone(),
+                    client.clone(),
+                )
+                .context("failed to create transport")?
+                .with_max_response_body_size(RESPONSE_BODY_SIZE_LIMIT);
+
+                let agent = Agent::builder()
+                    .with_transport(transport)
+                    .build()
+                    .context("Could not create agent...")?;
+                Ok((agent, replica_url.clone()))
+            })
+            .collect::<Result<_, _>>()?,
+        debug: opts.debug,
+        fetch_root_key: opts.fetch_root_key,
+    }));
+
+    let agent_service = agent_handler.layer(agent_args).into_service();
+
+    let router = Router::new();
+    // Set up the `/_/` proxy for dfx if requested
+    let router = if let Some(proxy_url) = opts.proxy {
+        info!("Setting up `/_/` proxy to `{proxy_url}`");
+        if proxy_url.scheme().is_none() {
+            bail!("No scheme found on `proxy_url`");
+        }
+        let forward_args = Extension(Arc::new(ForwardArgs::from(ForwardArgsInner {
+            client: client.clone(),
+            counter: AtomicUsize::new(0),
+            proxy_urls: vec![proxy_url],
+            debug: opts.debug,
+        })));
+        let forward_to_replica = Extension(Arc::new(ForwardArgs::from(ForwardArgsInner {
+            client,
+            counter: AtomicUsize::new(0),
+            proxy_urls: opts.replica,
+            debug: opts.debug,
+        })));
+        let forward_service = any(forward_handler::<C>.layer(forward_args));
+        let forward_to_replica_service = any(forward_handler::<C>.layer(forward_to_replica));
+        router
+            // Exclude `/_/raw` from the proxy
+            .route("/_/raw", agent_service.clone())
+            .route("/_/raw/*path", agent_service.clone())
+            // Proxy `/api` to the replica
+            .route("/api", forward_to_replica_service.clone())
+            .route("/api/*path", forward_to_replica_service)
+            // Proxy everything else under `/_` to the `proxy_url`
+            .route("/_", forward_service.clone())
+            .route("/_/", forward_service.clone())
+            .route("/_/:not_raw", forward_service.clone())
+            .route("/_/:not_raw/*path", forward_service)
+    } else {
+        router
+    };
+    Ok(Runner {
+        router: add_trace_layer(router.fallback(agent_service)),
+        address: opts.address,
+    })
+}
+
+pub struct Runner {
+    router: Router,
+    address: SocketAddr,
+}
+impl Runner {
+    pub fn run(self) -> impl Future<Output = Result<(), hyper::Error>> {
+        info!("Starting server. Listening on http://{}/", self.address);
+        axum::Server::bind(&self.address).serve(
+            self.router
+                .into_make_service_with_connect_info::<SocketAddr>(),
+        )
+    }
+}
diff --git a/src/validate.rs b/src/validate.rs
index 9be20ec..1fa7e0d 100644
--- a/src/validate.rs
+++ b/src/validate.rs
@@ -1,14 +1,15 @@
 use std::io::Read;
 
-use candid::Principal;
 use flate2::read::{DeflateDecoder, GzDecoder};
-use hyper::Uri;
 use ic_agent::{
-    hash_tree::LookupResult, ic_types::HashTree, lookup_value, Agent, AgentError, Certificate,
+    hash_tree::LookupResult, ic_types::HashTree, ic_types::Principal, lookup_value, Agent,
+    AgentError, Certificate,
 };
 use sha2::{Digest, Sha256};
+use tracing::trace;
 
-use crate::HeadersData;
+use crate::headers::HeadersData;
+use crate::http_transport::hyper::Uri;
 
 // The limit of a buffer we should decompress ~10mb.
 const MAX_CHUNK_SIZE_TO_DECOMPRESS: usize = 1024;
@@ -22,7 +23,6 @@ pub trait Validate: Sync + Send {
         agent: &Agent,
         uri: &Uri,
         response_body: &[u8],
-        logger: slog::Logger,
     ) -> Result<(), String>;
 }
 
@@ -42,7 +42,6 @@ impl Validate for Validator {
         agent: &Agent,
         uri: &Uri,
         response_body: &[u8],
-        logger: slog::Logger,
    ) -> Result<(), String> {
        let body_sha = if let Some(body_sha) =
            decode_body_to_sha256(response_body, headers_data.encoding.clone())
@@ -62,7 +61,6 @@ impl Validate for Validator {
             agent,
             uri,
             &body_sha,
-            logger.clone(),
         ) {
             Ok(true) => Ok(()),
             Ok(false) => Err("Body does not pass verification".to_string()),
@@ -130,7 +128,6 @@ fn validate_body(
     agent: &Agent,
     uri: &Uri,
     body_sha: &[u8; 32],
-    logger: slog::Logger,
 ) -> anyhow::Result<bool> {
     let cert: Certificate =
         serde_cbor::from_slice(certificates.certificate).map_err(AgentError::InvalidCborData)?;
@@ -138,7 +135,7 @@ fn validate_body(
         serde_cbor::from_slice(certificates.tree).map_err(AgentError::InvalidCborData)?;
 
     if let Err(e) = agent.verify(&cert, *canister_id, false) {
-        slog::trace!(logger, ">> certificate failed verification: {}", e);
+        trace!(">> certificate failed verification: {}", e);
         return Ok(false);
     }
 
@@ -150,8 +147,7 @@ fn validate_body(
     let witness = match lookup_value(&cert, certified_data_path) {
         Ok(witness) => witness,
         Err(e) => {
-            slog::trace!(
-                logger,
+            trace!(
                 ">> Could not find certified data for this canister in the certificate: {}",
                 e
             );
@@ -161,8 +157,7 @@ fn validate_body(
     let digest = tree.digest();
 
     if witness != digest {
-        slog::trace!(
-            logger,
+        trace!(
             ">> witness ({}) did not match digest ({})",
             hex::encode(witness),
             hex::encode(digest)
@@ -177,8 +172,7 @@ fn validate_body(
         _ => match tree.lookup_path(&["http_assets".into(), "/index.html".into()]) {
             LookupResult::Found(v) => v,
             _ => {
-                slog::trace!(
-                    logger,
+                trace!(
                     ">> Invalid Tree in the header. Does not contain path {:?}",
                     path
                 );
@@ -192,13 +186,12 @@
 #[cfg(test)]
 mod tests {
-    use std::str::FromStr;
-
-    use candid::Principal;
-    use hyper::Uri;
-    use ic_agent::{agent::http_transport::ReqwestHttpReplicaV2Transport, Agent};
-    use slog::o;
+    use ic_agent::{ic_types::Principal, Agent};
 
+    use crate::http_transport::{
+        hyper::{Body, Uri},
+        HyperReplicaV2Transport,
+    };
     use crate::{
         headers::HeadersData,
         validate::{Validate, Validator},
     };
@@ -213,15 +206,14 @@
         };
         let canister_id = Principal::from_text("wwc2m-2qaaa-aaaac-qaaaa-cai").unwrap();
-        let transport = ReqwestHttpReplicaV2Transport::create("http://www.example.com").unwrap();
+        let uri = Uri::from_static("http://www.example.com");
+        let transport = HyperReplicaV2Transport::<Body>::create(uri.clone()).unwrap();
         let agent = Agent::builder().with_transport(transport).build().unwrap();
-        let uri = Uri::from_str("http://www.example.com").unwrap();
         let body = vec![];
-        let logger = slog::Logger::root(slog::Discard, o!());
 
         let validator = Validator::new();
 
-        let out = validator.validate(&headers, &canister_id, &agent, &uri, &body, logger);
+        let out = validator.validate(&headers, &canister_id, &agent, &uri, &body);
         assert_eq!(out, Ok(()));
     }
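
Note for reviewers unfamiliar with `tracing-subscriber`: the `install!`/`layer!` macros in `src/logging.rs` boil down to composing `fmt` layers onto a `Registry` and installing the result as the global subscriber. A minimal sketch of roughly what the `(Tee, Default)` arm expands to; the level filter and log path here are assumed placeholders (the real values come from `logging::Opts`), not code from this patch:

use std::{fs::File, io::stderr};

use tracing::subscriber::set_global_default;
use tracing_subscriber::{filter::LevelFilter, fmt::layer, layer::SubscriberExt, Registry};

fn install_tee_default() {
    // Assumed for illustration; the patch derives these from the CLI options.
    let filter = LevelFilter::INFO;
    let file = File::create("log.txt").expect("Couldn't open log file");

    set_global_default(
        Registry::default()
            .with(filter)
            // stderr gets the compact format.
            .with(layer().compact().with_writer(stderr))
            // The file gets the full format, without ANSI color codes.
            .with(layer().with_writer(file).with_ansi(false)),
    )
    .expect("Failed to set up tracing.");
}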
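The round-robin selection in the proxy handlers (see `Args` in `proxy/agent.rs`) is easy to misread: no request handler ever advances the counter explicitly. Instead, axum's `Extension` layer clones the value once per request, and the `Clone` impl is where the rotation happens. A self-contained sketch of just that pattern; all names here are illustrative, not taken from this patch:

use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

struct Shared {
    counter: AtomicUsize,
    backends: Vec<String>,
}

struct PerRequest {
    shared: Arc<Shared>,
    current: usize,
}

impl Clone for PerRequest {
    // Cloning picks the next backend, so handing one clone to each
    // incoming request rotates through `backends` with no extra bookkeeping.
    fn clone(&self) -> Self {
        let shared = self.shared.clone();
        PerRequest {
            current: shared.counter.fetch_add(1, Ordering::Relaxed) % shared.backends.len(),
            shared,
        }
    }
}

impl PerRequest {
    fn backend(&self) -> &str {
        &self.shared.backends[self.current]
    }
}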