From fce1e9a8e267169b9325556d7148a96c67b93117 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 5 Dec 2023 14:01:06 +0000 Subject: [PATCH] Post-split work --- Cargo.toml | 82 +- edge-captive/Cargo.toml | 5 +- edge-captive/src/io.rs | 153 +++ edge-captive/src/lib.rs | 10 +- edge-captive/src/server.rs | 150 --- edge-dhcp/Cargo.toml | 19 +- edge-dhcp/src/asynch.rs | 569 --------- edge-dhcp/src/client.rs | 105 ++ edge-dhcp/src/io.rs | 20 + edge-dhcp/src/io/client.rs | 276 +++++ edge-dhcp/src/io/server.rs | 101 ++ edge-dhcp/src/lib.rs | 960 +-------------- edge-dhcp/src/server.rs | 258 ++++ edge-http/Cargo.toml | 28 +- edge-http/src/{asynch.rs => io.rs} | 627 +--------- edge-http/src/{ => io}/client.rs | 51 +- edge-http/src/{ => io}/server.rs | 11 +- edge-http/src/lib.rs | 1067 +++-------------- edge-mdns/Cargo.toml | 7 +- edge-mqtt/Cargo.toml | 10 +- edge-mqtt/src/io.rs | 181 +++ edge-mqtt/src/lib.rs | 183 +-- edge-raw/Cargo.toml | 14 + edge-raw/src/bytes.rs | 103 ++ edge-raw/src/io.rs | 192 +++ edge-raw/src/ip.rs | 214 ++++ edge-raw/src/lib.rs | 102 ++ edge-raw/src/udp.rs | 206 ++++ edge-std-nal-async/Cargo.toml | 17 + .../nal.rs => edge-std-nal-async/src/lib.rs | 118 +- edge-tcp/Cargo.toml | 19 - edge-tcp/README.md | 1 - edge-tcp/src/lib.rs | 204 ---- edge-ws/Cargo.toml | 11 +- edge-ws/src/io.rs | 210 ++++ edge-ws/src/lib.rs | 462 +------ embedded-nal-async-xtra/Cargo.toml | 11 + embedded-nal-async-xtra/src/lib.rs | 12 + embedded-nal-async-xtra/src/stack.rs | 5 + embedded-nal-async-xtra/src/stack/raw.rs | 61 + embedded-nal-async-xtra/src/stack/tcp.rs | 106 ++ src/asynch.rs | 90 -- src/lib.rs | 20 +- src/std.rs | 17 - 44 files changed, 2706 insertions(+), 4362 deletions(-) create mode 100644 edge-captive/src/io.rs delete mode 100644 edge-captive/src/server.rs delete mode 100644 edge-dhcp/src/asynch.rs create mode 100644 edge-dhcp/src/client.rs create mode 100644 edge-dhcp/src/io.rs create mode 100644 edge-dhcp/src/io/client.rs create mode 100644 edge-dhcp/src/io/server.rs create mode 100644 edge-dhcp/src/server.rs rename edge-http/src/{asynch.rs => io.rs} (54%) rename edge-http/src/{ => io}/client.rs (89%) rename edge-http/src/{ => io}/server.rs (98%) create mode 100644 edge-mqtt/src/io.rs create mode 100644 edge-raw/Cargo.toml create mode 100644 edge-raw/src/bytes.rs create mode 100644 edge-raw/src/io.rs create mode 100644 edge-raw/src/ip.rs create mode 100644 edge-raw/src/lib.rs create mode 100644 edge-raw/src/udp.rs create mode 100644 edge-std-nal-async/Cargo.toml rename src/std/nal.rs => edge-std-nal-async/src/lib.rs (90%) delete mode 100644 edge-tcp/Cargo.toml delete mode 100644 edge-tcp/README.md delete mode 100644 edge-tcp/src/lib.rs create mode 100644 edge-ws/src/io.rs create mode 100644 embedded-nal-async-xtra/Cargo.toml create mode 100644 embedded-nal-async-xtra/src/lib.rs create mode 100644 embedded-nal-async-xtra/src/stack.rs create mode 100644 embedded-nal-async-xtra/src/stack/raw.rs create mode 100644 embedded-nal-async-xtra/src/stack/tcp.rs delete mode 100644 src/asynch.rs delete mode 100644 src/std.rs diff --git a/Cargo.toml b/Cargo.toml index 62efed5..6f5d729 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "edge-net" -version = "0.5.0" +version = "0.6.0" authors = ["Ivan Markov "] edition = "2021" categories = ["embedded", "hardware-support"] @@ -12,46 +12,19 @@ readme = "README.md" rust-version = "1.71" [features] -std = [ - "alloc", - "embassy-sync/std", - "embedded-svc?/std", - "edge-http?/std", - "edge-captive?/std", - "dep:edge-tcp", - - "futures-lite/std", - "dep:async-io", - "dep:embedded-io-async", - "dep:embedded-nal-async", - "dep:libc" -] -alloc = ["edge-http?/alloc", "embedded-svc?/alloc"] -nightly = ["dep:edge-http", "embedded-svc?/nightly"] +std = ["edge-http/std", "edge-captive/std", "edge-mqtt/std", "edge-std-nal-async/std"] +nightly = ["edge-dhcp/nightly", "edge-http/nightly", "edge-mqtt/nightly", "edge-raw/nightly", "edge-ws/nightly", "edge-std-nal-async/nightly"] +# TODO: embedded-svc [dependencies] -embedded-svc = { version = "0.26", default-features = false, optional = true, features = ["embedded-io-async"] } - -futures-lite = { version = "1", default-features = false, optional = true } -async-io = { version = "2", default-features = false, optional = true } -embedded-io-async = { workspace = true, optional = true } -embedded-nal-async = { workspace = true, optional = true } -libc = { version = "0.2", default-features = false, optional = true } - -embassy-sync.workspace = true -log.workspace = true -heapless.workspace = true -no-std-net.workspace = true - -edge-captive = { workspace = true, optional = true } -edge-dhcp = { workspace = true, optional = true } -edge-http = { workspace = true, optional = true } -edge-mdns = { workspace = true, optional = true } -edge-mqtt = { workspace = true, optional = true } -edge-ws = { workspace = true, optional = true } - -# edge-tcp is an exception that only contains trait definitions -edge-tcp = { workspace = true, optional = true } +edge-captive = { workspace = true } +edge-dhcp = { workspace = true } +edge-http = { workspace = true } +edge-mdns = { workspace = true } +edge-mqtt = { workspace = true } +edge-raw = { workspace = true } +edge-ws = { workspace = true } +edge-std-nal-async = { workspace = true } [dev-dependencies] anyhow = "1" @@ -81,24 +54,29 @@ members = [ "edge-http", "edge-mdns", "edge-mqtt", - "edge-tcp", - "edge-ws" + "edge-raw", + "edge-ws", + "edge-std-nal-async", + "embedded-nal-async-xtra" ] [workspace.dependencies] -embassy-futures = "0.1" -embassy-sync = "0.3" -embedded-io = { version = "0.6", default-features = false } +embassy-futures = { version = "0.1", default-features = false } +embassy-sync = { version = "0.3", default-features = false } +embassy-time = { version = "0.1", default-features = false } embedded-io-async = { version = "0.6", default-features = false } -embedded-nal-async = "0.6" +embedded-nal-async = { version = "0.6", default-features = false } +embedded-svc = { version = "0.26", default-features = false, features = ["embedded-io-async"] } log = { version = "0.4", default-features = false } heapless = { version = "0.7", default-features = false } no-std-net = { version = "0.6", default-features = false } -edge-captive = { version = "0.1.0", path = "edge-captive" } -edge-dhcp = { version = "0.1.0", path = "edge-dhcp" } -edge-http = { version = "0.1.0", path = "edge-http" } -edge-mdns = { version = "0.1.0", path = "edge-mdns" } -edge-mqtt = { version = "0.1.0", path = "edge-mqtt" } -edge-tcp = { version = "0.1.0", path = "edge-tcp" } -edge-ws = { version = "0.1.0", path = "edge-ws" } +edge-captive = { version = "0.1.0", path = "edge-captive", default-features = false } +edge-dhcp = { version = "0.1.0", path = "edge-dhcp", default-features = false } +edge-http = { version = "0.1.0", path = "edge-http", default-features = false } +edge-mdns = { version = "0.1.0", path = "edge-mdns", default-features = false } +edge-mqtt = { version = "0.1.0", path = "edge-mqtt", default-features = false } +edge-raw = { version = "0.1.0", path = "edge-raw", default-features = false } +edge-ws = { version = "0.1.0", path = "edge-ws", default-features = false } +edge-std-nal-async = { version = "0.1.0", path = "edge-std-nal-async", default-features = false } +embedded-nal-async-xtra = { version = "0.1.0", path = "embedded-nal-async-xtra", default-features = false } diff --git a/edge-captive/Cargo.toml b/edge-captive/Cargo.toml index c613d6b..4965d1f 100644 --- a/edge-captive/Cargo.toml +++ b/edge-captive/Cargo.toml @@ -3,12 +3,9 @@ name = "edge-captive" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [features] std = ["domain/std"] [dependencies] -log.workspace = true - +log = { workspace = true } domain = { version = "0.7", default-features = false } diff --git a/edge-captive/src/io.rs b/edge-captive/src/io.rs new file mode 100644 index 0000000..a15e63d --- /dev/null +++ b/edge-captive/src/io.rs @@ -0,0 +1,153 @@ +#[cfg(feature = "std")] +pub mod server { + use std::{ + io, mem, + net::{Ipv4Addr, SocketAddrV4, UdpSocket}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + thread::{self, JoinHandle}, + time::Duration, + }; + + use log::*; + + #[derive(Clone, Debug)] + pub struct DnsConf { + pub bind_ip: Ipv4Addr, + pub bind_port: u16, + pub ip: Ipv4Addr, + pub ttl: Duration, + } + + impl DnsConf { + pub fn new(ip: Ipv4Addr) -> Self { + Self { + bind_ip: Ipv4Addr::new(0, 0, 0, 0), + bind_port: 53, + ip, + ttl: Duration::from_secs(60), + } + } + } + + #[derive(Debug)] + pub enum Status { + Stopped, + Started, + Error(io::Error), + } + + pub struct DnsServer { + conf: DnsConf, + status: Status, + running: Arc, + handle: Option>>, + } + + impl DnsServer { + pub fn new(conf: DnsConf) -> Self { + Self { + conf, + status: Status::Stopped, + running: Arc::new(AtomicBool::new(false)), + handle: None, + } + } + + pub fn get_status(&mut self) -> &Status { + self.cleanup(); + &self.status + } + + pub fn start(&mut self) -> Result<(), io::Error> { + if matches!(self.get_status(), Status::Started) { + return Ok(()); + } + let socket_address = SocketAddrV4::new(self.conf.bind_ip, self.conf.bind_port); + let running = self.running.clone(); + let ip = self.conf.ip; + let ttl = self.conf.ttl; + + self.running.store(true, Ordering::Relaxed); + self.handle = Some( + thread::Builder::new() + // default stack size is not enough + // 9000 was found via trial and error + .stack_size(9000) + .spawn(move || { + // Socket is not movable across thread bounds + // Otherwise we run into an assertion error here: https://github.com/espressif/esp-idf/blob/master/components/lwip/port/esp32/freertos/sys_arch.c#L103 + let socket = UdpSocket::bind(socket_address)?; + socket.set_read_timeout(Some(Duration::from_secs(1)))?; + let result = Self::run(&running, ip, ttl, socket); + + running.store(false, Ordering::Relaxed); + + result + }) + .unwrap(), + ); + + Ok(()) + } + + pub fn stop(&mut self) -> Result<(), io::Error> { + if matches!(self.get_status(), Status::Stopped) { + return Ok(()); + } + + self.running.store(false, Ordering::Relaxed); + self.cleanup(); + + let mut status = Status::Stopped; + mem::swap(&mut self.status, &mut status); + + match status { + Status::Error(e) => Err(e), + _ => Ok(()), + } + } + + fn cleanup(&mut self) { + if !self.running.load(Ordering::Relaxed) && self.handle.is_some() { + self.status = match mem::take(&mut self.handle).unwrap().join().unwrap() { + Ok(_) => Status::Stopped, + Err(e) => Status::Error(e), + }; + } + } + + fn run( + running: &AtomicBool, + ip: Ipv4Addr, + ttl: Duration, + socket: UdpSocket, + ) -> Result<(), io::Error> { + while running.load(Ordering::Relaxed) { + let mut request_arr = [0_u8; 512]; + debug!("Waiting for data"); + let (request_len, source_addr) = match socket.recv_from(&mut request_arr) { + Ok(value) => value, + Err(err) => match err.kind() { + std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut => continue, + _ => return Err(err), + }, + }; + + let request = &request_arr[..request_len]; + + debug!("Received {} bytes from {}", request.len(), source_addr); + let response = crate::process_dns_request(request, &ip.octets(), ttl) + .map_err(|_| io::ErrorKind::Other)?; + + socket.send_to(response.as_ref(), source_addr)?; + + debug!("Sent {} bytes to {}", response.as_ref().len(), source_addr); + } + + Ok(()) + } + } +} diff --git a/edge-captive/src/lib.rs b/edge-captive/src/lib.rs index 2127be0..9df73b6 100644 --- a/edge-captive/src/lib.rs +++ b/edge-captive/src/lib.rs @@ -14,11 +14,7 @@ use domain::{ rdata::A, }; -#[cfg(feature = "std")] -mod server; - -#[cfg(feature = "std")] -pub use server::*; +pub mod io; #[derive(Debug)] pub struct InnerError(T); @@ -54,11 +50,10 @@ impl From for DnsError { impl std::error::Error for DnsError {} pub fn process_dns_request( - request: impl AsRef<[u8]>, + request: &[u8], ip: &[u8; 4], ttl: Duration, ) -> Result, DnsError> { - let request = request.as_ref(); let response = Octets512::new(); let message = domain::base::Message::from_octets(request)?; @@ -106,5 +101,6 @@ pub fn process_dns_request( responseb.finish() }; + Ok(response) } diff --git a/edge-captive/src/server.rs b/edge-captive/src/server.rs deleted file mode 100644 index c56b274..0000000 --- a/edge-captive/src/server.rs +++ /dev/null @@ -1,150 +0,0 @@ -use std::{ - io, mem, - net::{Ipv4Addr, SocketAddrV4, UdpSocket}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - thread::{self, JoinHandle}, - time::Duration, -}; - -use log::*; - -#[derive(Clone, Debug)] -pub struct DnsConf { - pub bind_ip: Ipv4Addr, - pub bind_port: u16, - pub ip: Ipv4Addr, - pub ttl: Duration, -} - -impl DnsConf { - pub fn new(ip: Ipv4Addr) -> Self { - Self { - bind_ip: Ipv4Addr::new(0, 0, 0, 0), - bind_port: 53, - ip, - ttl: Duration::from_secs(60), - } - } -} - -#[derive(Debug)] -pub enum Status { - Stopped, - Started, - Error(io::Error), -} - -pub struct DnsServer { - conf: DnsConf, - status: Status, - running: Arc, - handle: Option>>, -} - -impl DnsServer { - pub fn new(conf: DnsConf) -> Self { - Self { - conf, - status: Status::Stopped, - running: Arc::new(AtomicBool::new(false)), - handle: None, - } - } - - pub fn get_status(&mut self) -> &Status { - self.cleanup(); - &self.status - } - - pub fn start(&mut self) -> Result<(), io::Error> { - if matches!(self.get_status(), Status::Started) { - return Ok(()); - } - let socket_address = SocketAddrV4::new(self.conf.bind_ip, self.conf.bind_port); - let running = self.running.clone(); - let ip = self.conf.ip; - let ttl = self.conf.ttl; - - self.running.store(true, Ordering::Relaxed); - self.handle = Some( - thread::Builder::new() - // default stack size is not enough - // 9000 was found via trial and error - .stack_size(9000) - .spawn(move || { - // Socket is not movable across thread bounds - // Otherwise we run into an assertion error here: https://github.com/espressif/esp-idf/blob/master/components/lwip/port/esp32/freertos/sys_arch.c#L103 - let socket = UdpSocket::bind(socket_address)?; - socket.set_read_timeout(Some(Duration::from_secs(1)))?; - let result = Self::run(&running, ip, ttl, socket); - - running.store(false, Ordering::Relaxed); - - result - }) - .unwrap(), - ); - - Ok(()) - } - - pub fn stop(&mut self) -> Result<(), io::Error> { - if matches!(self.get_status(), Status::Stopped) { - return Ok(()); - } - - self.running.store(false, Ordering::Relaxed); - self.cleanup(); - - let mut status = Status::Stopped; - mem::swap(&mut self.status, &mut status); - - match status { - Status::Error(e) => Err(e), - _ => Ok(()), - } - } - - fn cleanup(&mut self) { - if !self.running.load(Ordering::Relaxed) && self.handle.is_some() { - self.status = match mem::take(&mut self.handle).unwrap().join().unwrap() { - Ok(_) => Status::Stopped, - Err(e) => Status::Error(e), - }; - } - } - - fn run( - running: &AtomicBool, - ip: Ipv4Addr, - ttl: Duration, - socket: UdpSocket, - ) -> Result<(), io::Error> { - while running.load(Ordering::Relaxed) { - let mut request_arr = [0_u8; 512]; - debug!("Waiting for data"); - let (request_len, source_addr) = match socket.recv_from(&mut request_arr) { - Ok(value) => value, - Err(err) => match err.kind() { - std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut => continue, - _ => return Err(err), - }, - }; - - let request = &request_arr[..request_len]; - - debug!("Received {} bytes from {}", request.len(), source_addr); - let response = super::process_dns_request(request, &ip.octets(), ttl) - .map_err(|_| io::ErrorKind::Other)?; - - socket.send_to(response.as_ref(), source_addr)?; - - debug!("Sent {} bytes to {}", response.as_ref().len(), source_addr); - } - - Ok(()) - } -} diff --git a/edge-dhcp/Cargo.toml b/edge-dhcp/Cargo.toml index 5904e17..20b275f 100644 --- a/edge-dhcp/Cargo.toml +++ b/edge-dhcp/Cargo.toml @@ -3,24 +3,17 @@ name = "edge-dhcp" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [features] -default = [] -nightly = [ - "dep:embassy-futures", - "dep:edge-tcp", - "dep:embedded-nal-async", -] +default = ["nightly"] +nightly = ["embassy-futures", "embedded-nal-async"] [dependencies] -no-std-net.workspace = true +no-std-net = { workspace = true } heapless = { workspace = true } -log.workspace = true -embassy-time = { version = "0.1", default-features = false } +log = { workspace = true } rand_core = "0.6" - -edge-tcp = { workspace = true, optional = true } embassy-futures = { workspace = true, optional = true } +embassy-time = { workspace = true, default-features = false } # TODO: Make optional embedded-nal-async = { workspace = true, optional = true } num_enum = { version = "0.7", default-features = false } +edge-raw = { workspace = true, default-features = false } diff --git a/edge-dhcp/src/asynch.rs b/edge-dhcp/src/asynch.rs deleted file mode 100644 index a88a784..0000000 --- a/edge-dhcp/src/asynch.rs +++ /dev/null @@ -1,569 +0,0 @@ -use core::fmt::Debug; - -use embedded_nal_async::{SocketAddr, SocketAddrV4, UdpStack, UnconnectedUdp}; - -use crate as dhcp; - -use edge_tcp::{RawSocket, RawStack}; - -#[derive(Debug)] -pub enum Error { - Io(E), - Format(dhcp::Error), -} - -impl From for Error { - fn from(value: dhcp::Error) -> Self { - Self::Format(value) - } -} - -pub trait SocketFactory { - type Error: Debug; - - type Socket: Socket; - - fn raw_ports(&self) -> (Option, Option); - - async fn connect(&self) -> Result; -} - -impl SocketFactory for &T -where - T: SocketFactory, -{ - type Error = T::Error; - - type Socket = T::Socket; - - fn raw_ports(&self) -> (Option, Option) { - (*self).raw_ports() - } - - async fn connect(&self) -> Result { - (*self).connect().await - } -} - -impl SocketFactory for &mut T -where - T: SocketFactory, -{ - type Error = T::Error; - - type Socket = T::Socket; - - fn raw_ports(&self) -> (Option, Option) { - (**self).raw_ports() - } - - async fn connect(&self) -> Result { - (**self).connect().await - } -} - -pub trait Socket { - type Error: Debug; - - async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error>; - async fn recv(&mut self, buf: &mut [u8]) -> Result; -} - -// impl Socket for &mut T -// where -// T: Socket, -// { -// type Error = T::Error; - -// async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error> { -// (**self).send(data).await -// } - -// async fn recv(&mut self, buf: &mut [u8]) -> Result { -// (**self).recv(buf).await -// } -// } - -pub struct RawSocketFactory { - stack: R, - interface: Option, - local_port: Option, - remote_port: Option, -} - -impl RawSocketFactory -where - R: RawStack, -{ - pub const fn new( - stack: R, - interface: Option, - local_port: Option, - remote_port: Option, - ) -> Self { - if local_port.is_none() && remote_port.is_none() { - panic!("Either the local, or the remote port, or both should be specified"); - } - - Self { - stack, - interface, - local_port, - remote_port, - } - } -} - -impl SocketFactory for RawSocketFactory -where - R: RawStack, -{ - type Error = R::Error; - - type Socket = R::Socket; - - fn raw_ports(&self) -> (Option, Option) { - (self.local_port, self.remote_port) - } - - async fn connect(&self) -> Result { - self.stack.connect(self.interface).await - } -} - -impl Socket for S -where - S: RawSocket, -{ - type Error = S::Error; - - async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error> { - RawSocket::send(self, data).await - } - - async fn recv(&mut self, buf: &mut [u8]) -> Result { - RawSocket::receive_into(self, buf).await - } -} - -/// NOTE: This socket factory can only be used for the DHCP server -/// DHCP client *has* to run via raw sockets -pub struct UdpServerSocketFactory { - stack: U, - local: SocketAddrV4, -} - -impl UdpServerSocketFactory -where - U: UdpStack, -{ - pub const fn new(stack: U, local: SocketAddrV4) -> Self { - Self { stack, local } - } -} - -impl SocketFactory for UdpServerSocketFactory -where - U: UdpStack, -{ - type Error = U::Error; - - type Socket = UdpServerSocket; - - fn raw_ports(&self) -> (Option, Option) { - (None, None) - } - - async fn connect(&self) -> Result { - let (local, socket) = self.stack.bind_single(SocketAddr::V4(self.local)).await?; - - Ok(UdpServerSocket { - socket, - local, - remote: None, - }) - } -} - -pub struct UdpServerSocket { - socket: S, - local: SocketAddr, - remote: Option, -} - -impl Socket for UdpServerSocket -where - S: UnconnectedUdp, -{ - type Error = S::Error; - - async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error> { - let remote = self - .remote - .expect("Sending is possible only after receiving a datagram"); - - self.socket.send(self.local, remote, data).await - } - - async fn recv(&mut self, buf: &mut [u8]) -> Result { - let (len, local, remote) = self.socket.receive_into(buf).await?; - - self.local = local; - self.remote = Some(remote); - - Ok(len) - } -} - -pub mod client { - use core::fmt::Debug; - - use embassy_futures::select::{select, Either}; - use embassy_time::{Duration, Instant, Timer}; - - use embedded_nal_async::Ipv4Addr; - - use log::{info, warn}; - - use rand_core::RngCore; - - use self::dhcp::MessageType; - - pub use super::*; - - pub use crate::Settings; - - #[derive(Clone, Debug)] - pub struct Configuration { - pub mac: [u8; 6], - pub timeout: Duration, - } - - impl Configuration { - pub const fn new(mac: [u8; 6]) -> Self { - Self { - mac, - timeout: Duration::from_secs(10), - } - } - } - - /// A simple asynchronous DHCP client. - /// - /// The client takes a socket factory (either operating on raw sockets or UDP datagrams) and - /// then takes care of the all the negotiations with the DHCP server, as in discovering servers, - /// negotiating initial IP, and then keeping the lease of that IP up to date. - /// - /// Note that it is unlikely that a non-raw socket factory would actually even work, due to the peculiarities of the - /// DHCP protocol, where a lot of UDP packets are send (and often broadcasted) by the client before the client actually has an assigned IP. - pub struct Client { - rng: T, - mac: [u8; 6], - timeout: Duration, - settings: Option<(Settings, Instant)>, - } - - impl Client - where - T: RngCore, - { - pub fn new(rng: T, conf: &Configuration) -> Self { - info!("Creating DHCP client with configuration {conf:?}"); - - Self { - rng, - mac: conf.mac, - timeout: conf.timeout, - settings: None, - } - } - - /// Runs the DHCP client with the supplied socket factory, and takes care of - /// all aspects of negotiating an IP with the first DHCP server that replies to the discovery requests. - /// - /// From the POV of the user, this method will return only in two cases, which are exactly the cases where the user is expected to take an action: - /// - When an initial/new IP lease was negotiated; in that case, `Some(Settings)` is returned, and the user should assign the returned IP settings - /// to the network interface using platform-specific means - /// - When the IP lease was lost; in that case, `None` is returned, and the user should de-assign all IP settings from the network interface using - /// platform-specific means - /// - /// In both cases, user is expected to call `run` again, so that the IP lease is kept up to date / a new lease is re-negotiated - /// - /// Note that dropping this future is also safe in that it won't remove the current lease, so the user can renew - /// the operation of the client by just calling `run` later on. Of course, if the future is not polled, the client - /// would be unable - during that time - to check for lease timeout and the lease might not be renewed on time. - /// - /// But in any case, if the lease is expired or the DHCP server does not acknowledge the lease renewal, the client will - /// automatically restart the DHCP servers' discovery from the very beginning. - pub async fn run( - &mut self, - mut f: F, - buf: &mut [u8], - ) -> Result, Error> { - loop { - if let Some((settings, acquired)) = self.settings.as_ref() { - // Keep the lease - let now = Instant::now(); - - if now - *acquired - >= Duration::from_secs(settings.lease_time_secs.unwrap_or(7200) as u64 / 3) - { - info!("Renewing DHCP lease..."); - - if let Some(settings) = self - .request(&mut f, buf, settings.server_ip.unwrap(), settings.ip) - .await? - { - self.settings = Some((settings, Instant::now())); - } else { - // Lease was not renewed; let the user know - self.settings = None; - - return Ok(None); - } - } else { - Timer::after(Duration::from_secs(60)).await; - } - } else { - // Look for offers - let offer = self.discover(&mut f, buf).await?; - - if let Some(settings) = self - .request(&mut f, buf, offer.server_ip.unwrap(), offer.ip) - .await? - { - // IP acquired; let the user know - self.settings = Some((settings.clone(), Instant::now())); - - return Ok(Some(settings)); - } - } - } - } - - /// This method allows the user to inform the DHCP server that the currently leased IP (if any) is no longer used - /// by the client. - /// - /// Useful when the program runnuing the DHCP client is about to exit. - pub async fn release( - &mut self, - f: F, - buf: &mut [u8], - ) -> Result<(), Error> { - if let Some((settings, _)) = self.settings.as_ref().cloned() { - let mut socket = f.connect().await.map_err(Error::Io)?; - - let packet = self.client(&f).encode_release( - buf, - 0, - settings.server_ip.unwrap(), - settings.ip, - )?; - - socket.send(packet).await.map_err(Error::Io)?; - } - - self.settings = None; - - Ok(()) - } - - async fn discover( - &mut self, - f: &mut F, - buf: &mut [u8], - ) -> Result> { - info!("Discovering DHCP servers..."); - - let timeout = self.timeout; - let mut client = self.client(&f); - - let start = Instant::now(); - - loop { - let mut socket = f.connect().await.map_err(Error::Io)?; - - let (packet, xid) = - client.encode_discover(buf, (Instant::now() - start).as_secs() as _, None)?; - - socket.send(packet).await.map_err(Error::Io)?; - - let offer_start = Instant::now(); - - while Instant::now() - offer_start < timeout { - let timer = Timer::after(Duration::from_secs(3)); - - if let Either::First(result) = select(socket.recv(buf), timer).await { - let len = result.map_err(Error::Io)?; - let packet = &buf[..len]; - - if let Some(reply) = - client.decode_bootp_reply(packet, xid, Some(&[MessageType::Offer]))? - { - let settings = reply.settings().unwrap().1; - - info!( - "IP {} offered by DHCP server {}", - settings.ip, - settings.server_ip.unwrap() - ); - return Ok(settings); - } - } - } - - drop(socket); - - info!("No DHCP offers received, sleeping for a while..."); - - Timer::after(Duration::from_secs(3)).await; - } - } - - async fn request( - &mut self, - f: &mut F, - buf: &mut [u8], - server_ip: Ipv4Addr, - ip: Ipv4Addr, - ) -> Result, Error> { - let timeout = self.timeout; - let mut client = self.client(&f); - - for _ in 0..3 { - info!("Requesting IP {ip} from DHCP server {server_ip}"); - - let mut socket = f.connect().await.map_err(Error::Io)?; - - let start = Instant::now(); - - let (packet, xid) = client.encode_request( - buf, - (Instant::now() - start).as_secs() as _, - server_ip, - ip, - )?; - - socket.send(packet).await.map_err(Error::Io)?; - - let request_start = Instant::now(); - - while Instant::now() - request_start < timeout { - let timer = Timer::after(Duration::from_secs(10)); - - if let Either::First(result) = select(socket.recv(buf), timer).await { - let len = result.map_err(Error::Io)?; - let packet = &buf[..len]; - - if let Some(reply) = client.decode_bootp_reply( - packet, - xid, - Some(&[MessageType::Ack, MessageType::Nak]), - )? { - let (mt, settings) = reply.settings().unwrap(); - - let settings = if matches!(mt, MessageType::Ack) { - info!("IP {} leased successfully", settings.ip); - Some(settings) - } else { - info!("IP {} not acknowledged", settings.ip); - None - }; - - return Ok(settings); - } - } - } - - drop(socket); - } - - warn!("IP request was not replied"); - - Ok(None) - } - - fn client(&mut self, f: F) -> dhcp::client::Client<&mut T> { - dhcp::client::Client { - rng: &mut self.rng, - mac: self.mac, - rp_udp_client_port: f.raw_ports().0, - rp_udp_server_port: f.raw_ports().1, - } - } - } -} - -pub mod server { - use core::fmt::Debug; - - use embassy_time::Duration; - - use embedded_nal_async::Ipv4Addr; - - use log::info; - - pub use super::*; - - #[derive(Clone, Debug)] - pub struct Configuration { - pub ip: Ipv4Addr, - pub gateway: Option, - pub subnet: Option, - pub dns1: Option, - pub dns2: Option, - pub range_start: Ipv4Addr, - pub range_end: Ipv4Addr, - pub lease_duration_secs: u32, - } - - /// A simple asynchronous DHCP server. - /// - /// The client takes a socket factory (either operating on raw sockets or UDP datagrams) and - /// then processes all incoming BOOTP requests, by updating its internal simple database of leases, and issuing replies. - pub struct Server { - pub server: dhcp::server::Server, - } - - impl Server { - pub fn new(conf: &Configuration) -> Self { - info!("Creating DHCP server with configuration {conf:?}"); - - Self { - server: dhcp::server::Server { - ip: conf.ip, - gateways: conf.gateway.iter().cloned().collect(), - subnet: conf.subnet, - dns: conf.dns1.iter().chain(conf.dns2.iter()).cloned().collect(), - range_start: conf.range_start, - range_end: conf.range_end, - lease_duration: Duration::from_secs(conf.lease_duration_secs as _), - leases: heapless::LinearMap::new(), - }, - } - } - - /// Runs the DHCP server wth the supplied socket factory, processing incoming DHCP requests. - /// - /// Note that dropping this future is safe in that it won't remove the internal leases' database, - /// so users are free to drop the future in case they would like to take a snapshot of the leases or inspect them otherwise. - pub async fn run( - &mut self, - f: F, - buf: &mut [u8], - ) -> Result<(), Error> { - let mut socket = f.connect().await.map_err(Error::Io)?; - - loop { - let len = socket.recv(buf).await.map_err(Error::Io)?; - - if let Some(reply) = self - .server - .handle_bootp_request(f.raw_ports().0, buf, len)? - { - socket.send(reply).await.map_err(Error::Io)?; - } - } - } - } -} diff --git a/edge-dhcp/src/client.rs b/edge-dhcp/src/client.rs new file mode 100644 index 0000000..57e73b2 --- /dev/null +++ b/edge-dhcp/src/client.rs @@ -0,0 +1,105 @@ +use rand_core::RngCore; + +use super::*; + +/// A simple DHCP client. +/// The client is unaware of the IP/UDP transport layer and operates purely in terms of packets +/// represented as Rust slices. +/// +/// As such, the client can generate all BOOTP requests and parse BOOTP replies. +pub struct Client { + pub rng: T, + pub mac: [u8; 6], +} + +impl Client +where + T: RngCore, +{ + pub fn discover<'o>( + &mut self, + opt_buf: &'o mut [DhcpOption<'o>], + secs: u16, + ip: Option, + ) -> (Packet<'o>, u32) { + self.bootp_request(secs, None, Options::discover(ip, opt_buf)) + } + + pub fn request<'o>( + &mut self, + opt_buf: &'o mut [DhcpOption<'o>], + secs: u16, + ip: Ipv4Addr, + ) -> (Packet<'o>, u32) { + self.bootp_request(secs, None, Options::request(ip, opt_buf)) + } + + pub fn release<'o>( + &mut self, + opt_buf: &'o mut [DhcpOption<'o>], + secs: u16, + ip: Ipv4Addr, + ) -> Packet<'o> { + self.bootp_request(secs, Some(ip), Options::release(opt_buf)) + .0 + } + + pub fn decline<'o>( + &mut self, + opt_buf: &'o mut [DhcpOption<'o>], + secs: u16, + ip: Ipv4Addr, + ) -> Packet<'o> { + self.bootp_request(secs, Some(ip), Options::decline(opt_buf)) + .0 + } + + pub fn is_offer(&self, reply: &Packet<'_>, xid: u32) -> bool { + self.is_bootp_reply_for_us(reply, xid, Some(&[MessageType::Offer])) + } + + pub fn is_ack(&self, reply: &Packet<'_>, xid: u32) -> bool { + self.is_bootp_reply_for_us(reply, xid, Some(&[MessageType::Ack])) + } + + pub fn is_nak(&self, reply: &Packet<'_>, xid: u32) -> bool { + self.is_bootp_reply_for_us(reply, xid, Some(&[MessageType::Nak])) + } + + #[allow(clippy::too_many_arguments)] + pub fn bootp_request<'o>( + &mut self, + secs: u16, + ip: Option, + options: Options<'o>, + ) -> (Packet<'o>, u32) { + let xid = self.rng.next_u32(); + + (Packet::new_request(self.mac, xid, secs, ip, options), xid) + } + + pub fn is_bootp_reply_for_us( + &self, + reply: &Packet<'_>, + xid: u32, + expected_message_types: Option<&[MessageType]>, + ) -> bool { + if reply.reply && reply.is_for_us(&self.mac, xid) { + if let Some(expected_message_types) = expected_message_types { + let mt = reply.options.iter().find_map(|option| { + if let DhcpOption::MessageType(mt) = option { + Some(mt) + } else { + None + } + }); + + expected_message_types.iter().any(|emt| mt == Some(*emt)) + } else { + true + } + } else { + false + } + } +} diff --git a/edge-dhcp/src/io.rs b/edge-dhcp/src/io.rs new file mode 100644 index 0000000..e6761b9 --- /dev/null +++ b/edge-dhcp/src/io.rs @@ -0,0 +1,20 @@ +use core::fmt::Debug; + +use embedded_nal_async::{SocketAddr, SocketAddrV4, UdpStack, UnconnectedUdp}; + +use crate as dhcp; + +pub mod client; +pub mod server; + +#[derive(Debug)] +pub enum Error { + Io(E), + Format(dhcp::Error), +} + +impl From for Error { + fn from(value: dhcp::Error) -> Self { + Self::Format(value) + } +} diff --git a/edge-dhcp/src/io/client.rs b/edge-dhcp/src/io/client.rs new file mode 100644 index 0000000..12c581c --- /dev/null +++ b/edge-dhcp/src/io/client.rs @@ -0,0 +1,276 @@ +use core::fmt::Debug; + +use embassy_futures::select::{select, Either}; +use embassy_time::{Duration, Instant, Timer}; + +use embedded_nal_async::{ConnectedUdp, Ipv4Addr}; + +use log::{info, warn}; + +use rand_core::RngCore; + +pub use super::*; + +pub use crate::Settings; +use crate::{Options, Packet}; + +#[derive(Clone, Debug)] +pub struct Configuration { + pub socket: SocketAddrV4, + pub mac: [u8; 6], + pub timeout: Duration, +} + +impl Configuration { + pub const fn new(mac: [u8; 6]) -> Self { + Self { + socket: SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 68), + mac, + timeout: Duration::from_secs(10), + } + } +} + +/// A simple asynchronous DHCP client. +/// +/// The client takes a socket factory (either operating on raw sockets or UDP datagrams) and +/// then takes care of the all the negotiations with the DHCP server, as in discovering servers, +/// negotiating initial IP, and then keeping the lease of that IP up to date. +/// +/// Note that it is unlikely that a non-raw socket factory would actually even work, due to the peculiarities of the +/// DHCP protocol, where a lot of UDP packets are send (and often broadcasted) by the client before the client actually has an assigned IP. +pub struct Client<'a, T, F> { + stack: F, + buf: &'a mut [u8], + client: dhcp::client::Client, + socket: SocketAddrV4, + timeout: Duration, + pub settings: Option<(Settings, Instant)>, +} + +impl<'a, T, F> Client<'a, T, F> +where + T: RngCore, + F: UdpStack, +{ + pub fn new(stack: F, buf: &'a mut [u8], rng: T, conf: &Configuration) -> Self { + info!("Creating DHCP client with configuration {conf:?}"); + + Self { + stack, + buf, + client: dhcp::client::Client { rng, mac: conf.mac }, + socket: conf.socket, + timeout: conf.timeout, + settings: None, + } + } + + /// Runs the DHCP client with the supplied socket factory, and takes care of + /// all aspects of negotiating an IP with the first DHCP server that replies to the discovery requests. + /// + /// From the POV of the user, this method will return only in two cases, which are exactly the cases where the user is expected to take an action: + /// - When an initial/new IP lease was negotiated; in that case, `Some(Settings)` is returned, and the user should assign the returned IP settings + /// to the network interface using platform-specific means + /// - When the IP lease was lost; in that case, `None` is returned, and the user should de-assign all IP settings from the network interface using + /// platform-specific means + /// + /// In both cases, user is expected to call `run` again, so that the IP lease is kept up to date / a new lease is re-negotiated + /// + /// Note that dropping this future is also safe in that it won't remove the current lease, so the user can renew + /// the operation of the client by just calling `run` later on. Of course, if the future is not polled, the client + /// would be unable - during that time - to check for lease timeout and the lease might not be renewed on time. + /// + /// But in any case, if the lease is expired or the DHCP server does not acknowledge the lease renewal, the client will + /// automatically restart the DHCP servers' discovery from the very beginning. + pub async fn run(&mut self) -> Result, Error> { + loop { + if let Some((settings, acquired)) = self.settings.as_ref() { + // Keep the lease + let now = Instant::now(); + + if now - *acquired + >= Duration::from_secs(settings.lease_time_secs.unwrap_or(7200) as u64 / 3) + { + info!("Renewing DHCP lease..."); + + if let Some(settings) = self + .request(settings.server_ip.unwrap(), settings.ip) + .await? + { + self.settings = Some((settings, Instant::now())); + } else { + // Lease was not renewed; let the user know + self.settings = None; + + return Ok(None); + } + } else { + Timer::after(Duration::from_secs(60)).await; + } + } else { + // Look for offers + let offer = self.discover().await?; + + if let Some(settings) = self.request(offer.server_ip.unwrap(), offer.ip).await? { + // IP acquired; let the user know + self.settings = Some((settings.clone(), Instant::now())); + + return Ok(Some(settings)); + } + } + } + } + + /// This method allows the user to inform the DHCP server that the currently leased IP (if any) is no longer used + /// by the client. + /// + /// Useful when the program runnuing the DHCP client is about to exit. + pub async fn release(&mut self) -> Result<(), Error> { + if let Some((settings, _)) = self.settings.as_ref().cloned() { + let server_ip = settings.server_ip.unwrap(); + let (_, mut socket) = self + .stack + .connect_from( + SocketAddr::V4(self.socket), + SocketAddr::V4(SocketAddrV4::new(server_ip, self.socket.port())), + ) + .await + .map_err(Error::Io)?; + + let mut opt_buf = Options::buf(); + let request = self.client.release(&mut opt_buf, 0, settings.ip); + + socket + .send(request.encode(self.buf)?) + .await + .map_err(Error::Io)?; + } + + self.settings = None; + + Ok(()) + } + + async fn discover(&mut self) -> Result> { + info!("Discovering DHCP servers..."); + + let start = Instant::now(); + + loop { + let mut socket = self + .stack + .bind_multiple(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 68))) + .await + .map_err(Error::Io)?; + + let mut opt_buf = Options::buf(); + + let (request, xid) = + self.client + .discover(&mut opt_buf, (Instant::now() - start).as_secs() as _, None); + + socket + .send( + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 68)), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::BROADCAST, 67)), + request.encode(self.buf)?, + ) + .await + .map_err(Error::Io)?; + + let offer_start = Instant::now(); + + while Instant::now() - offer_start < self.timeout { + let timer = Timer::after(Duration::from_secs(3)); + + if let Either::First(result) = select(socket.receive_into(self.buf), timer).await { + let (len, _local, _remote) = result.map_err(Error::Io)?; + let reply = Packet::decode(&self.buf[..len])?; + + if self.client.is_offer(&reply, xid) { + let settings: Settings = (&reply).into(); + + info!( + "IP {} offered by DHCP server {}", + settings.ip, + settings.server_ip.unwrap() + ); + + return Ok(settings); + } + } + } + + drop(socket); + + info!("No DHCP offers received, sleeping for a while..."); + + Timer::after(Duration::from_secs(3)).await; + } + } + + async fn request( + &mut self, + server_ip: Ipv4Addr, + ip: Ipv4Addr, + ) -> Result, Error> { + for _ in 0..3 { + info!("Requesting IP {ip} from DHCP server {server_ip}"); + + let mut socket = self + .stack + .bind_multiple(SocketAddr::V4(SocketAddrV4::new(server_ip, 68))) + .await + .map_err(Error::Io)?; + + let start = Instant::now(); + + let mut opt_buf = Options::buf(); + + let (request, xid) = + self.client + .request(&mut opt_buf, (Instant::now() - start).as_secs() as _, ip); + + socket + .send( + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 68)), + SocketAddr::V4(SocketAddrV4::new(server_ip, 67)), + request.encode(self.buf)?, + ) + .await + .map_err(Error::Io)?; + + let request_start = Instant::now(); + + while Instant::now() - request_start < self.timeout { + let timer = Timer::after(Duration::from_secs(10)); + + if let Either::First(result) = select(socket.receive_into(self.buf), timer).await { + let (len, _local, _remote) = result.map_err(Error::Io)?; + let packet = &self.buf[..len]; + + let reply = Packet::decode(packet)?; + + if self.client.is_ack(&reply, xid) { + let settings = (&reply).into(); + + info!("IP {} leased successfully", ip); + + return Ok(Some(settings)); + } else if self.client.is_nak(&reply, xid) { + info!("IP {} not acknowledged", ip); + + return Ok(None); + } + } + } + + drop(socket); + } + + warn!("IP request was not replied"); + + Ok(None) + } +} diff --git a/edge-dhcp/src/io/server.rs b/edge-dhcp/src/io/server.rs new file mode 100644 index 0000000..89d9851 --- /dev/null +++ b/edge-dhcp/src/io/server.rs @@ -0,0 +1,101 @@ +use core::fmt::Debug; + +use embassy_time::Duration; + +use embedded_nal_async::Ipv4Addr; + +use log::info; + +use self::dhcp::{Options, Packet}; + +pub use super::*; + +#[derive(Clone, Debug)] +pub struct Configuration<'a> { + pub socket: SocketAddrV4, + pub ip: Ipv4Addr, + pub gateways: &'a [Ipv4Addr], + pub subnet: Option, + pub dns: &'a [Ipv4Addr], + pub range_start: Ipv4Addr, + pub range_end: Ipv4Addr, + pub lease_duration_secs: u32, +} + +/// A simple asynchronous DHCP server. +/// +/// The client takes a socket factory (either operating on raw sockets or UDP datagrams) and +/// then processes all incoming BOOTP requests, by updating its internal simple database of leases, and issuing replies. +pub struct Server<'a, const N: usize, F> { + stack: F, + buf: &'a mut [u8], + socket: SocketAddrV4, + server_options: dhcp::server::ServerOptions<'a>, + pub server: dhcp::server::Server, +} + +impl<'a, const N: usize, F> Server<'a, N, F> +where + F: UdpStack, +{ + pub fn new(stack: F, buf: &'a mut [u8], conf: &Configuration<'a>) -> Self { + info!("Creating DHCP server with configuration {conf:?}"); + + Self { + stack, + buf, + socket: conf.socket, + server_options: dhcp::server::ServerOptions { + ip: conf.ip, + gateways: conf.gateways, + subnet: conf.subnet, + dns: conf.dns, + lease_duration: Duration::from_secs(conf.lease_duration_secs as _), + }, + server: dhcp::server::Server { + range_start: conf.range_start, + range_end: conf.range_end, + leases: heapless::LinearMap::new(), + }, + } + } + + /// Runs the DHCP server wth the supplied socket factory, processing incoming DHCP requests. + /// + /// Note that dropping this future is safe in that it won't remove the internal leases' database, + /// so users are free to drop the future in case they would like to take a snapshot of the leases or inspect them otherwise. + pub async fn run(&mut self) -> Result<(), Error> { + let mut socket = self + .stack + .bind_multiple(SocketAddr::V4(self.socket)) + .await + .map_err(Error::Io)?; + + loop { + let (len, local, remote) = socket.receive_into(self.buf).await.map_err(Error::Io)?; + let packet = &self.buf[..len]; + + let request = Packet::decode(packet)?; + + let mut opt_buf = Options::buf(); + + if let Some(request) = + self.server + .handle_request(&mut opt_buf, &self.server_options, &request) + { + socket + .send( + local, + if request.broadcast { + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::BROADCAST, remote.port())) + } else { + remote + }, + request.encode(self.buf)?, + ) + .await + .map_err(Error::Io)?; + } + } + } +} diff --git a/edge-dhcp/src/lib.rs b/edge-dhcp/src/lib.rs index 715a92b..743f7bd 100644 --- a/edge-dhcp/src/lib.rs +++ b/edge-dhcp/src/lib.rs @@ -1,10 +1,3 @@ -#![cfg_attr(not(feature = "std"), no_std)] -#![allow(stable_features)] -#![allow(unknown_lints)] -#![feature(async_fn_in_trait)] -#![allow(async_fn_in_trait)] -#![feature(impl_trait_projections)] - /// This code is a `no_std` and no-alloc modification of https://github.com/krolaw/dhcp4r use core::str::Utf8Error; @@ -12,20 +5,33 @@ use no_std_net::Ipv4Addr; use num_enum::TryFromPrimitive; -#[cfg(feature = "nightly")] -pub mod asynch; +use edge_raw::bytes::{self, BytesIn, BytesOut}; -use self::raw_ip::{Ipv4PacketHeader, UdpPacketHeader}; +pub mod client; +pub mod server; + +#[cfg(feature = "nightly")] +pub mod io; #[derive(Debug)] pub enum Error { DataUnderflow, + BufferOverflow, + InvalidPacket, InvalidUtf8Str(Utf8Error), InvalidMessageType, MissingCookie, InvalidHlen, - BufferOverflow, - InvalidPacket, +} + +impl From for Error { + fn from(value: bytes::Error) -> Self { + match value { + bytes::Error::BufferOverflow => Self::BufferOverflow, + bytes::Error::DataUnderflow => Self::DataUnderflow, + bytes::Error::InvalidFormat => Self::InvalidPacket, + } + } } /// @@ -149,22 +155,6 @@ impl<'a> Packet<'a> { && self.reply } - pub fn settings(&self) -> Option<(MessageType, Settings)> { - if self.reply { - let mt = self.options.iter().find_map(|option| { - if let DhcpOption::MessageType(mt) = option { - Some(mt) - } else { - None - } - }); - - mt.map(|mt| (mt, self.into())) - } else { - None - } - } - /// Parses the packet from a byte slice pub fn decode(data: &'a [u8]) -> Result { let mut bytes = BytesIn::new(data); @@ -244,64 +234,6 @@ impl<'a> Packet<'a> { Ok(&buf[..len]) } - - /// Parses the packet from a byte slice that models a raw IP packet - /// Useful when working with raw sockets - pub fn decode_raw( - data: &'a [u8], - src_port: Option, - dst_port: Option, - ) -> Result, Error> { - if let Some((ip_hdr, ip_payload)) = Ipv4PacketHeader::decode_with_payload(data)? { - if ip_hdr.p == UdpPacketHeader::PROTO { - let (udp_hdr, udp_payload) = - UdpPacketHeader::decode_with_payload(ip_payload, &ip_hdr)?; - - if src_port.map(|p| p == udp_hdr.src).unwrap_or(true) - && dst_port.map(|p| p == udp_hdr.dst).unwrap_or(true) - { - return Ok(Some((ip_hdr, udp_hdr, Packet::decode(udp_payload)?))); - } - } - } - - Ok(None) - } - - /// Encodes the packet into the provided buf slice, together with a UDP and IPv$ headers - /// Useful when working with raw sockets - pub fn encode_raw<'o>( - &self, - src_ip: Option, - src_port: u16, - dst_ip: Option, - dst_port: u16, - buf: &'o mut [u8], - ) -> Result<&'o [u8], Error> { - if buf.len() < Ipv4PacketHeader::MIN_SIZE + UdpPacketHeader::SIZE { - Err(Error::BufferOverflow)?; - } - - let mut ip_hdr = Ipv4PacketHeader::new( - src_ip.unwrap_or(Ipv4Addr::UNSPECIFIED), - dst_ip.unwrap_or(Ipv4Addr::BROADCAST), - UdpPacketHeader::PROTO, - ); - - ip_hdr.encode_with_payload(buf, |buf, ip_hdr| { - let mut udp_hdr = UdpPacketHeader::new(src_port, dst_port); - - let len = udp_hdr - .encode_with_payload(buf, ip_hdr, |buf| { - let len = self.encode(buf)?.len(); - - Ok(len) - })? - .len(); - - Ok(len) - }) - } } #[derive(Clone, Debug)] @@ -712,514 +644,6 @@ where } } -struct BytesIn<'a> { - data: &'a [u8], - offset: usize, -} - -impl<'a> BytesIn<'a> { - pub const fn new(data: &'a [u8]) -> Self { - Self { data, offset: 0 } - } - - pub fn is_empty(&self) -> bool { - self.offset == self.data.len() - } - - pub fn offset(&self) -> usize { - self.offset - } - - pub fn byte(&mut self) -> Result { - self.arr::<1>().map(|arr| arr[0]) - } - - pub fn slice(&mut self, len: usize) -> Result<&'a [u8], Error> { - if len > self.data.len() - self.offset { - Err(Error::DataUnderflow) - } else { - let data = &self.data[self.offset..self.offset + len]; - self.offset += len; - - Ok(data) - } - } - - pub fn arr(&mut self) -> Result<[u8; N], Error> { - let slice = self.slice(N)?; - - let mut data = [0; N]; - data.copy_from_slice(slice); - - Ok(data) - } - - pub fn remaining(&mut self) -> &'a [u8] { - let data = self.slice(self.data.len() - self.offset).unwrap(); - - self.offset = self.data.len(); - - data - } - - pub fn remaining_byte(&mut self) -> Result { - Ok(self.remaining_arr::<1>()?[0]) - } - - pub fn remaining_arr(&mut self) -> Result<[u8; N], Error> { - if self.data.len() - self.offset > N { - Err(Error::InvalidHlen) // TODO - } else { - self.arr::() - } - } -} - -struct BytesOut<'a> { - buf: &'a mut [u8], - offset: usize, -} - -impl<'a> BytesOut<'a> { - pub fn new(buf: &'a mut [u8]) -> Self { - Self { buf, offset: 0 } - } - - pub fn len(&self) -> usize { - self.offset - } - - pub fn byte(&mut self, data: u8) -> Result<&mut Self, Error> { - self.push(&[data]) - } - - pub fn push(&mut self, data: &[u8]) -> Result<&mut Self, Error> { - if data.len() > self.buf.len() - self.offset { - Err(Error::BufferOverflow) - } else { - self.buf[self.offset..self.offset + data.len()].copy_from_slice(data); - self.offset += data.len(); - - Ok(self) - } - } -} - -pub mod client { - use log::trace; - - use rand_core::RngCore; - - use super::*; - - /// A simple DHCP client. - /// The client is unaware of the IP/UDP transport layer and operates purely in terms of packets - /// represented as Rust slices. - /// - /// As such, the client can generate all BOOTP requests and parse BOOTP replies. - /// - /// The client supports both raw IP as well as regular UDP payloads, where the raw payloads are - /// automatically prefixed/unprefixed with the IP and UDP header, which allows this client to be used with a raw sockets' transport layer. - /// - /// Note that it is unlikely that a non-raw socket transport would actually even work, due to the peculiarities of the - /// DHCP protocol, where a lot of UDP packets are send (and often broadcasted) by the client before the client actually has an assigned IP. - pub struct Client { - pub rng: T, - pub mac: [u8; 6], - pub rp_udp_client_port: Option, - pub rp_udp_server_port: Option, - } - - impl Client - where - T: RngCore, - { - pub fn encode_discover<'o>( - &mut self, - buf: &'o mut [u8], - secs: u16, - ip: Option, - ) -> Result<(&'o [u8], u32), Error> { - let mut opt_buf = Options::buf(); - - self.encode_bootp_request(buf, secs, None, None, Options::discover(ip, &mut opt_buf)) - } - - pub fn encode_request<'o>( - &mut self, - buf: &'o mut [u8], - secs: u16, - server_ip: Ipv4Addr, - our_ip: Ipv4Addr, - ) -> Result<(&'o [u8], u32), Error> { - let mut opt_buf = Options::buf(); - - self.encode_bootp_request( - buf, - secs, - Some(server_ip), - None, - Options::request(our_ip, &mut opt_buf), - ) - } - - pub fn encode_release<'o>( - &mut self, - buf: &'o mut [u8], - secs: u16, - server_ip: Ipv4Addr, - our_ip: Ipv4Addr, - ) -> Result<&'o [u8], Error> { - let mut opt_buf = Options::buf(); - - self.encode_bootp_request( - buf, - secs, - Some(server_ip), - Some(our_ip), - Options::release(&mut opt_buf), - ) - .map(|r| r.0) - } - - pub fn encode_decline<'o>( - &mut self, - buf: &'o mut [u8], - secs: u16, - server_ip: Ipv4Addr, - our_ip: Ipv4Addr, - ) -> Result<&'o [u8], Error> { - let mut opt_buf = Options::buf(); - - self.encode_bootp_request( - buf, - secs, - Some(server_ip), - Some(our_ip), - Options::decline(&mut opt_buf), - ) - .map(|r| r.0) - } - - #[allow(clippy::too_many_arguments)] - pub fn encode_bootp_request<'o>( - &mut self, - buf: &'o mut [u8], - secs: u16, - server_ip: Option, - our_ip: Option, - options: Options<'_>, - ) -> Result<(&'o [u8], u32), Error> { - let xid = self.rng.next_u32(); - - let request = Packet::new_request(self.mac, xid, secs, our_ip, options.clone()); - - let data = if self.rp_udp_server_port.is_some() || self.rp_udp_client_port.is_some() { - request.encode_raw( - our_ip, - self.rp_udp_client_port.unwrap_or(68), - server_ip, - self.rp_udp_server_port.unwrap_or(67), - buf, - )? - } else { - request.encode(buf)? - }; - - Ok((data, xid)) - } - - pub fn decode_bootp_reply<'o>( - &self, - data: &'o [u8], - xid: u32, - expected_message_types: Option<&[MessageType]>, - ) -> Result>, Error> { - let reply = if self.rp_udp_server_port.is_some() || self.rp_udp_client_port.is_some() { - Packet::decode_raw(data, self.rp_udp_server_port, self.rp_udp_client_port)? - .map(|r| r.2) - } else { - Some(Packet::decode(data)?) - }; - - trace!("DHCP packet decoded:\n{reply:?}"); - - Ok(reply.and_then(|reply| { - if reply.is_for_us(&self.mac, xid) { - if let Some(expected_message_types) = expected_message_types { - let (mt, _) = reply.settings().unwrap(); - - if expected_message_types.iter().any(|emt| mt == *emt) { - return Some(reply); - } - } else { - return Some(reply); - } - } - - None - })) - } - } -} - -pub mod server { - use core::fmt::Debug; - - use embassy_time::{Duration, Instant}; - - use log::{info, trace}; - - use super::*; - - #[derive(Clone, Debug)] - pub struct Lease { - mac: [u8; 16], - expires: Instant, - } - - /// A simple DHCP server. - /// The server is unaware of the IP/UDP transport layer and operates purely in terms of packets - /// represented as Rust slices. - /// - /// The server supports both raw IP as well as regular UDP payloads, where the raw payloads are - /// automatically prefixed/unprefixed with the IP and UDP header, which allows this server to be used with a raw sockets' transport layer. - #[derive(Clone, Debug)] - pub struct Server { - pub ip: Ipv4Addr, - pub gateways: heapless::Vec, - pub subnet: Option, - pub dns: heapless::Vec, - pub range_start: Ipv4Addr, - pub range_end: Ipv4Addr, - pub lease_duration: Duration, - pub leases: heapless::LinearMap, - } - - impl Server { - pub fn handle_bootp_request<'o>( - &mut self, - rp_udp_server_port: Option, - buf: &'o mut [u8], - incoming_len: usize, - ) -> Result, Error> { - let request = if let Some(port) = rp_udp_server_port { - Packet::decode_raw(&buf[..incoming_len], None, Some(port))? - .map(|(ip_hdr, udp_hdr, request)| (Some((ip_hdr, udp_hdr)), request)) - } else { - Some((None, Packet::decode(&buf[..incoming_len])?)) - }; - - if let Some((raw_hdrs, request)) = request { - trace!("Got packet {request:?}"); - - if !request.reply { - let mt = request.options.iter().find_map(|option| { - if let DhcpOption::MessageType(mt) = option { - Some(mt) - } else { - None - } - }); - - if let Some(mt) = mt { - let server_identifier = request.options.iter().find_map(|option| { - if let DhcpOption::ServerIdentifier(ip) = option { - Some(ip) - } else { - None - } - }); - - if server_identifier == Some(self.ip) - || server_identifier.is_none() && matches!(mt, MessageType::Discover) - { - info!("Packet is for us, will process, message type {mt:?}"); - - let mut opt_buf = Options::buf(); - - let reply = match mt { - MessageType::Discover => { - let requested_ip = request.options.iter().find_map(|option| { - if let DhcpOption::RequestedIpAddress(ip) = option { - Some(ip) - } else { - None - } - }); - - let ip = requested_ip - .and_then(|ip| { - self.is_available(&request.chaddr, ip).then_some(ip) - }) - .or_else(|| self.current_lease(&request.chaddr)) - .or_else(|| self.available()); - - ip.map(|ip| { - self.reply_to( - &request, - MessageType::Offer, - Some(ip), - &mut opt_buf, - ) - }) - } - MessageType::Request => { - let ip = request - .options - .iter() - .find_map(|option| { - if let DhcpOption::RequestedIpAddress(ip) = option { - Some(ip) - } else { - None - } - }) - .unwrap_or(request.ciaddr); - - Some( - if self.is_available(&request.chaddr, ip) - && self.add_lease( - ip, - request.chaddr, - Instant::now() + self.lease_duration, - ) - { - self.reply_to( - &request, - MessageType::Ack, - Some(ip), - &mut opt_buf, - ) - } else { - self.reply_to( - &request, - MessageType::Nak, - None, - &mut opt_buf, - ) - }, - ) - } - MessageType::Decline | MessageType::Release => { - self.remove_lease(&request.chaddr); - - None - } - _ => None, - }; - - if let Some(reply) = reply { - let packet = if let Some((ip_hdr, udp_hdr)) = raw_hdrs { - reply.encode_raw( - Some(self.ip), - udp_hdr.dst, - Some(ip_hdr.src), - udp_hdr.src, - buf, - )? - } else { - reply.encode(buf)? - }; - - return Ok(Some(packet)); - } - } - } - } - } - - Ok(None) - } - - fn reply_to<'a>( - &'a self, - request: &Packet<'_>, - mt: MessageType, - ip: Option, - buf: &'a mut [DhcpOption<'a>], - ) -> Packet<'a> { - let reply = request.new_reply( - ip, - request.options.reply( - mt, - self.ip, - self.lease_duration.as_secs() as _, - &self.gateways, - self.subnet, - &self.dns, - buf, - ), - ); - - info!("Reply: {reply:?}"); - - reply - } - - fn is_available(&self, mac: &[u8; 16], addr: Ipv4Addr) -> bool { - let pos: u32 = addr.into(); - - let start: u32 = self.range_start.into(); - let end: u32 = self.range_end.into(); - - pos >= start - && pos <= end - && match self.leases.get(&addr) { - Some(lease) => lease.mac == *mac || Instant::now() > lease.expires, - None => true, - } - } - - fn available(&mut self) -> Option { - let start: u32 = self.range_start.into(); - let end: u32 = self.range_end.into(); - - for pos in start..end + 1 { - let addr = pos.into(); - - if !self.leases.contains_key(&addr) { - return Some(addr); - } - } - - if let Some(addr) = self - .leases - .iter() - .find_map(|(addr, lease)| (Instant::now() > lease.expires).then_some(*addr)) - { - self.leases.remove(&addr); - - Some(addr) - } else { - None - } - } - - fn current_lease(&self, mac: &[u8; 16]) -> Option { - self.leases - .iter() - .find_map(|(addr, lease)| (lease.mac == *mac).then_some(*addr)) - } - - fn add_lease(&mut self, addr: Ipv4Addr, mac: [u8; 16], expires: Instant) -> bool { - self.remove_lease(&mac); - - self.leases.insert(addr, Lease { mac, expires }).is_ok() - } - - fn remove_lease(&mut self, mac: &[u8; 16]) -> bool { - if let Some(addr) = self.current_lease(mac) { - self.leases.remove(&addr); - - true - } else { - false - } - } - } -} - // DHCP Options const SUBNET_MASK: u8 = 1; const ROUTER: u8 = 3; @@ -1233,351 +657,3 @@ const DHCP_MESSAGE_TYPE: u8 = 53; const SERVER_IDENTIFIER: u8 = 54; const PARAMETER_REQUEST_LIST: u8 = 55; const MESSAGE: u8 = 56; - -// IP and UDP headers as well as utility functions for (de)serializing those, as well as computing their checksums -// -// Useful in the context of DHCP, as it operates in terms of raw sockets (particuarly the client) so (dis)assembling -// IP & UDP packets "by hand" is necessary. -pub mod raw_ip { - use log::trace; - - use no_std_net::Ipv4Addr; - - use super::{BytesIn, BytesOut, Error}; - - #[derive(Clone, Debug)] - pub struct Ipv4PacketHeader { - pub version: u8, // Version - pub hlen: u8, // Header length - pub tos: u8, // Type of service - pub len: u16, // Total length - pub id: u16, // Identification - pub off: u16, // Fragment offset field - pub ttl: u8, // Time to live - pub p: u8, // Protocol - pub sum: u16, // Checksum - pub src: Ipv4Addr, // Source address - pub dst: Ipv4Addr, // Dest address - } - - impl Ipv4PacketHeader { - pub const MIN_SIZE: usize = 20; - pub const CHECKSUM_WORD: usize = 5; - - pub const IP_DF: u16 = 0x4000; // Don't fragment flag - pub const IP_MF: u16 = 0x2000; // More fragments flag - - pub fn new(src: Ipv4Addr, dst: Ipv4Addr, proto: u8) -> Self { - Self { - version: 4, - hlen: Self::MIN_SIZE as _, - tos: 0, - len: Self::MIN_SIZE as _, - id: 0, - off: 0, - ttl: 64, - p: proto, - sum: 0, - src, - dst, - } - } - - /// Parses the packet from a byte slice - pub fn decode(data: &[u8]) -> Result { - let mut bytes = BytesIn::new(data); - - let vhl = bytes.byte()?; - - Ok(Self { - version: vhl >> 4, - hlen: (vhl & 0x0f) * 4, - tos: bytes.byte()?, - len: u16::from_be_bytes(bytes.arr()?), - id: u16::from_be_bytes(bytes.arr()?), - off: u16::from_be_bytes(bytes.arr()?), - ttl: bytes.byte()?, - p: bytes.byte()?, - sum: u16::from_be_bytes(bytes.arr()?), - src: u32::from_be_bytes(bytes.arr()?).into(), - dst: u32::from_be_bytes(bytes.arr()?).into(), - }) - } - - /// Encodes the packet into the provided buf slice - pub fn encode<'o>(&self, buf: &'o mut [u8]) -> Result<&'o [u8], Error> { - let mut bytes = BytesOut::new(buf); - - bytes - .byte( - (self.version << 4) | (self.hlen / 4 + (if self.hlen % 4 > 0 { 1 } else { 0 })), - )? - .byte(self.tos)? - .push(&u16::to_be_bytes(self.len))? - .push(&u16::to_be_bytes(self.id))? - .push(&u16::to_be_bytes(self.off))? - .byte(self.ttl)? - .byte(self.p)? - .push(&u16::to_be_bytes(self.sum))? - .push(&u32::to_be_bytes(self.src.into()))? - .push(&u32::to_be_bytes(self.dst.into()))?; - - let len = bytes.len(); - - Ok(&buf[..len]) - } - - pub fn encode_with_payload<'o, F>( - &mut self, - buf: &'o mut [u8], - encoder: F, - ) -> Result<&'o [u8], Error> - where - F: FnOnce(&mut [u8], &Self) -> Result, - { - let hdr_len = self.hlen as usize; - if hdr_len < Self::MIN_SIZE || buf.len() < hdr_len { - Err(Error::BufferOverflow)?; - } - - let (hdr_buf, payload_buf) = buf.split_at_mut(hdr_len); - - let payload_len = encoder(payload_buf, self)?; - - let len = hdr_len + payload_len; - self.len = len as _; - - let min_hdr_len = self.encode(hdr_buf)?.len(); - assert_eq!(min_hdr_len, Self::MIN_SIZE); - - hdr_buf[Self::MIN_SIZE..hdr_len].fill(0); - - let checksum = Self::checksum(hdr_buf); - self.sum = checksum; - - Self::inject_checksum(hdr_buf, checksum); - - Ok(&buf[..len]) - } - - pub fn decode_with_payload(packet: &[u8]) -> Result, Error> { - let hdr = Self::decode(packet)?; - if hdr.version == 4 { - // IPv4 - let len = hdr.len as usize; - if packet.len() < len { - Err(Error::DataUnderflow)?; - } - - let checksum = Self::checksum(&packet[..len]); - - trace!("IP header decoded, total_size={}, src={}, dst={}, hlen={}, size={}, checksum={}, ours={}", packet.len(), hdr.src, hdr.dst, hdr.hlen, hdr.len, hdr.sum, checksum); - - if checksum != hdr.sum { - Err(Error::InvalidPacket)?; - } - - let packet = &packet[..len]; - let hdr_len = hdr.hlen as usize; - if packet.len() < hdr_len { - Err(Error::DataUnderflow)?; - } - - Ok(Some((hdr, &packet[hdr_len..]))) - } else { - Ok(None) - } - } - - pub fn inject_checksum(packet: &mut [u8], checksum: u16) { - let checksum = checksum.to_be_bytes(); - - let offset = Self::CHECKSUM_WORD << 1; - packet[offset] = checksum[0]; - packet[offset + 1] = checksum[1]; - } - - pub fn checksum(packet: &[u8]) -> u16 { - let hlen = (packet[0] & 0x0f) as usize * 4; - - let sum = checksum_accumulate(&packet[..hlen], Self::CHECKSUM_WORD); - - checksum_finish(sum) - } - } - - #[derive(Clone, Debug)] - pub struct UdpPacketHeader { - pub src: u16, // Source port - pub dst: u16, // Destination port - pub len: u16, // UDP length - pub sum: u16, // UDP checksum - } - - impl UdpPacketHeader { - pub const PROTO: u8 = 17; - - pub const SIZE: usize = 8; - pub const CHECKSUM_WORD: usize = 3; - - pub fn new(src: u16, dst: u16) -> Self { - Self { - src, - dst, - len: 0, - sum: 0, - } - } - - /// Parses the packet header from a byte slice - pub fn decode(data: &[u8]) -> Result { - let mut bytes = BytesIn::new(data); - - Ok(Self { - src: u16::from_be_bytes(bytes.arr()?), - dst: u16::from_be_bytes(bytes.arr()?), - len: u16::from_be_bytes(bytes.arr()?), - sum: u16::from_be_bytes(bytes.arr()?), - }) - } - - /// Encodes the packet header into the provided buf slice - pub fn encode<'o>(&self, buf: &'o mut [u8]) -> Result<&'o [u8], Error> { - let mut bytes = BytesOut::new(buf); - - bytes - .push(&u16::to_be_bytes(self.src))? - .push(&u16::to_be_bytes(self.dst))? - .push(&u16::to_be_bytes(self.len))? - .push(&u16::to_be_bytes(self.sum))?; - - let len = bytes.len(); - - Ok(&buf[..len]) - } - - pub fn encode_with_payload<'o, F>( - &mut self, - buf: &'o mut [u8], - ip_hdr: &Ipv4PacketHeader, - encoder: F, - ) -> Result<&'o [u8], Error> - where - F: FnOnce(&mut [u8]) -> Result, - { - if buf.len() < Self::SIZE { - Err(Error::BufferOverflow)?; - } - - let (hdr_buf, payload_buf) = buf.split_at_mut(Self::SIZE); - - let payload_len = encoder(payload_buf)?; - - let len = Self::SIZE + payload_len; - self.len = len as _; - - let hdr_len = self.encode(hdr_buf)?.len(); - assert_eq!(Self::SIZE, hdr_len); - - let packet = &mut buf[..len]; - - let checksum = Self::checksum(packet, ip_hdr); - self.sum = checksum; - - Self::inject_checksum(packet, checksum); - - Ok(packet) - } - - pub fn decode_with_payload<'o>( - packet: &'o [u8], - ip_hdr: &Ipv4PacketHeader, - ) -> Result<(Self, &'o [u8]), Error> { - let hdr = Self::decode(packet)?; - - let len = hdr.len as usize; - if packet.len() < len { - Err(Error::DataUnderflow)?; - } - - let checksum = Self::checksum(&packet[..len], ip_hdr); - - trace!( - "UDP header decoded, src={}, dst={}, size={}, checksum={}, ours={}", - hdr.src, - hdr.dst, - hdr.len, - hdr.sum, - checksum - ); - - if checksum != hdr.sum { - Err(Error::InvalidPacket)?; - } - - let packet = &packet[..len]; - - let payload_data = &packet[Self::SIZE..]; - - Ok((hdr, payload_data)) - } - - pub fn inject_checksum(packet: &mut [u8], checksum: u16) { - let checksum = checksum.to_be_bytes(); - - let offset = Self::CHECKSUM_WORD << 1; - packet[offset] = checksum[0]; - packet[offset + 1] = checksum[1]; - } - - pub fn checksum(packet: &[u8], ip_hdr: &Ipv4PacketHeader) -> u16 { - let mut buf = [0; 12]; - - // Pseudo IP-header for UDP checksum calculation - let len = BytesOut::new(&mut buf) - .push(&u32::to_be_bytes(ip_hdr.src.into())) - .unwrap() - .push(&u32::to_be_bytes(ip_hdr.dst.into())) - .unwrap() - .byte(0) - .unwrap() - .byte(ip_hdr.p) - .unwrap() - .push(&u16::to_be_bytes(packet.len() as u16)) - .unwrap() - .len(); - - let sum = checksum_accumulate(&buf[..len], usize::MAX) - + checksum_accumulate(packet, Self::CHECKSUM_WORD); - - checksum_finish(sum) - } - } - - pub fn checksum_accumulate(bytes: &[u8], checksum_word: usize) -> u32 { - let mut bytes = BytesIn::new(bytes); - - let mut sum: u32 = 0; - while !bytes.is_empty() { - let skip = (bytes.offset() >> 1) == checksum_word; - let arr = bytes - .arr() - .ok() - .unwrap_or_else(|| [bytes.byte().unwrap(), 0]); - - let word = if skip { 0 } else { u16::from_be_bytes(arr) }; - - sum += word as u32; - } - - sum - } - - pub fn checksum_finish(mut sum: u32) -> u16 { - while sum >> 16 != 0 { - sum = (sum >> 16) + (sum & 0xffff); - } - - !sum as u16 - } -} diff --git a/edge-dhcp/src/server.rs b/edge-dhcp/src/server.rs new file mode 100644 index 0000000..93ff968 --- /dev/null +++ b/edge-dhcp/src/server.rs @@ -0,0 +1,258 @@ +use core::fmt::Debug; + +use embassy_time::{Duration, Instant}; + +use log::info; + +use super::*; + +#[derive(Clone, Debug)] +pub struct Lease { + mac: [u8; 16], + expires: Instant, +} + +#[derive(Clone, Debug)] +pub enum Action<'a> { + Discover(Option, &'a [u8; 16]), + Request(Ipv4Addr, &'a [u8; 16]), + Release(Ipv4Addr, &'a [u8; 16]), + Decline(Ipv4Addr, &'a [u8; 16]), +} + +pub struct ServerOptions<'a> { + pub ip: Ipv4Addr, + pub gateways: &'a [Ipv4Addr], + pub subnet: Option, + pub dns: &'a [Ipv4Addr], + pub lease_duration: Duration, +} + +impl<'a> ServerOptions<'a> { + pub fn process<'o>(&self, request: &'o Packet<'o>) -> Option> { + if request.reply { + return None; + } + + let mt = request.options.iter().find_map(|option| { + if let DhcpOption::MessageType(mt) = option { + Some(mt) + } else { + None + } + }); + + if let Some(mt) = mt { + let server_identifier = request.options.iter().find_map(|option| { + if let DhcpOption::ServerIdentifier(ip) = option { + Some(ip) + } else { + None + } + }); + + if server_identifier == Some(self.ip) + || server_identifier.is_none() && matches!(mt, MessageType::Discover) + { + info!("Packet is for us, will process, message type {mt:?}"); + + let request = match mt { + MessageType::Discover => { + let requested_ip = request.options.iter().find_map(|option| { + if let DhcpOption::RequestedIpAddress(ip) = option { + Some(ip) + } else { + None + } + }); + + Some(Action::Discover(requested_ip, &request.chaddr)) + } + MessageType::Request => { + let ip = request + .options + .iter() + .find_map(|option| { + if let DhcpOption::RequestedIpAddress(ip) = option { + Some(ip) + } else { + None + } + }) + .unwrap_or(request.ciaddr); + + Some(Action::Request(ip, &request.chaddr)) + } + MessageType::Release => Some(Action::Release(request.yiaddr, &request.chaddr)), + MessageType::Decline => Some(Action::Decline(request.yiaddr, &request.chaddr)), + _ => None, + }; + + return request; + } + } + + None + } + + pub fn offer( + &self, + request: &Packet, + ip: Ipv4Addr, + opt_buf: &'a mut [DhcpOption<'a>], + ) -> Packet<'a> { + self.reply(request, MessageType::Offer, Some(ip), opt_buf) + } + + pub fn ack_nack( + &self, + request: &Packet, + ip: Option, + opt_buf: &'a mut [DhcpOption<'a>], + ) -> Packet<'a> { + self.reply( + request, + if ip.is_some() { + MessageType::Ack + } else { + MessageType::Nak + }, + ip, + opt_buf, + ) + } + + fn reply( + &self, + request: &Packet, + mt: MessageType, + ip: Option, + buf: &'a mut [DhcpOption<'a>], + ) -> Packet<'a> { + let reply = request.new_reply( + ip, + request.options.reply( + mt, + self.ip, + self.lease_duration.as_secs() as _, + self.gateways, + self.subnet, + self.dns, + buf, + ), + ); + + info!("Reply: {reply:?}"); + + reply + } +} + +/// A simple DHCP server. +/// The server is unaware of the IP/UDP transport layer and operates purely in terms of packets +/// represented as Rust slices. +#[derive(Clone, Debug)] +pub struct Server { + pub range_start: Ipv4Addr, + pub range_end: Ipv4Addr, + pub leases: heapless::LinearMap, +} + +impl Server { + pub fn handle_request<'o>( + &mut self, + opt_buf: &'o mut [DhcpOption<'o>], + server_options: &'o ServerOptions, + request: &Packet, + ) -> Option> { + server_options + .process(request) + .and_then(|action| match action { + Action::Discover(requested_ip, mac) => { + let ip = requested_ip + .and_then(|ip| self.is_available(mac, ip).then_some(ip)) + .or_else(|| self.current_lease(mac)) + .or_else(|| self.available()); + + ip.map(|ip| server_options.offer(request, ip, opt_buf)) + } + Action::Request(ip, mac) => { + let ip = (self.is_available(mac, ip) + && self.add_lease( + ip, + request.chaddr, + Instant::now() + server_options.lease_duration, + )) + .then_some(ip); + + Some(server_options.ack_nack(request, ip, opt_buf)) + } + Action::Release(_ip, mac) | Action::Decline(_ip, mac) => { + self.remove_lease(mac); + + None + } + }) + } + + fn is_available(&self, mac: &[u8; 16], addr: Ipv4Addr) -> bool { + let pos: u32 = addr.into(); + + let start: u32 = self.range_start.into(); + let end: u32 = self.range_end.into(); + + pos >= start + && pos <= end + && match self.leases.get(&addr) { + Some(lease) => lease.mac == *mac || Instant::now() > lease.expires, + None => true, + } + } + + fn available(&mut self) -> Option { + let start: u32 = self.range_start.into(); + let end: u32 = self.range_end.into(); + + for pos in start..end + 1 { + let addr = pos.into(); + + if !self.leases.contains_key(&addr) { + return Some(addr); + } + } + + if let Some(addr) = self + .leases + .iter() + .find_map(|(addr, lease)| (Instant::now() > lease.expires).then_some(*addr)) + { + self.leases.remove(&addr); + + Some(addr) + } else { + None + } + } + + fn current_lease(&self, mac: &[u8; 16]) -> Option { + self.leases + .iter() + .find_map(|(addr, lease)| (lease.mac == *mac).then_some(*addr)) + } + + fn add_lease(&mut self, addr: Ipv4Addr, mac: [u8; 16], expires: Instant) -> bool { + self.remove_lease(&mac); + + self.leases.insert(addr, Lease { mac, expires }).is_ok() + } + + fn remove_lease(&mut self, mac: &[u8; 16]) -> bool { + if let Some(addr) = self.current_lease(mac) { + self.leases.remove(&addr); + + true + } else { + false + } + } +} diff --git a/edge-http/Cargo.toml b/edge-http/Cargo.toml index 617aa3d..266aae0 100644 --- a/edge-http/Cargo.toml +++ b/edge-http/Cargo.toml @@ -3,22 +3,20 @@ name = "edge-http" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [features] -std = ["httparse/std", "embedded-io/std", "embassy-sync/std"] -alloc = ["embedded-io/alloc", "embedded-io-async/alloc"] +std = [] +nightly = ["embedded-io-async", "embedded-nal-async", "embedded-nal-async-xtra/nightly", "embassy-sync", "embassy-futures", "embedded-svc?/nightly"] [dependencies] -embassy-sync = { workspace = true, features = ["nightly"] } -embedded-io.workspace = true -embedded-io-async.workspace = true -embedded-nal-async.workspace = true -heapless.workspace = true -log.workspace = true -no-std-net.workspace = true - -embassy-futures.workspace = true +embedded-io-async = { workspace = true, optional = true } +embedded-nal-async = { workspace = true, optional = true } +embedded-nal-async-xtra = { workspace = true, optional = true } +embedded-svc = { workspace = true, optional = true, default-features = false } +heapless = { workspace = true } +log = { workspace = true } +no-std-net = { workspace = true } +embassy-sync = { workspace = true, optional = true } +embassy-futures = { workspace = true, optional = true } httparse = { version = "1.7", default-features = false } - -edge-tcp = { version = "0.1.0", path = "../edge-tcp" } +base64 = { version = "0.13", default-features = false } +sha1_smol = { version = "1", default-features = false } diff --git a/edge-http/src/asynch.rs b/edge-http/src/io.rs similarity index 54% rename from edge-http/src/asynch.rs rename to edge-http/src/io.rs index 6611fb3..aee488d 100644 --- a/edge-http/src/asynch.rs +++ b/edge-http/src/io.rs @@ -2,18 +2,13 @@ use core::cmp::min; use core::fmt::{Display, Write as _}; use core::str; -use embedded_io::ErrorType; -use embedded_io_async::{Read, Write}; +use embedded_io_async::{ErrorType, Read, Write}; -use httparse::{Header, Status, EMPTY_HEADER}; +use httparse::Status; use log::trace; -#[allow(unused_imports)] -#[cfg(feature = "embedded-svc")] -pub use embedded_svc_compat::*; - -use super::ws::http::UpgradeError; +use crate::{BodyType, Headers, Method, RequestHeaders, ResponseHeaders}; pub mod client; pub mod server; @@ -46,14 +41,14 @@ impl From for Error { } } -impl embedded_io::Error for Error +impl embedded_io_async::Error for Error where - E: embedded_io::Error, + E: embedded_io_async::Error, { - fn kind(&self) -> embedded_io::ErrorKind { + fn kind(&self) -> embedded_io_async::ErrorKind { match self { Self::Io(e) => e.kind(), - _ => embedded_io::ErrorKind::Other, + _ => embedded_io_async::ErrorKind::Other, } } } @@ -80,162 +75,6 @@ where #[cfg(feature = "std")] impl std::error::Error for Error where E: std::error::Error {} -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "std", derive(Hash))] -pub enum Method { - Delete, - Get, - Head, - Post, - Put, - Connect, - Options, - Trace, - Copy, - Lock, - MkCol, - Move, - Propfind, - Proppatch, - Search, - Unlock, - Bind, - Rebind, - Unbind, - Acl, - Report, - MkActivity, - Checkout, - Merge, - MSearch, - Notify, - Subscribe, - Unsubscribe, - Patch, - Purge, - MkCalendar, - Link, - Unlink, -} - -impl Method { - pub fn new(method: &str) -> Option { - if method.eq_ignore_ascii_case("Delete") { - Some(Self::Delete) - } else if method.eq_ignore_ascii_case("Get") { - Some(Self::Get) - } else if method.eq_ignore_ascii_case("Head") { - Some(Self::Head) - } else if method.eq_ignore_ascii_case("Post") { - Some(Self::Post) - } else if method.eq_ignore_ascii_case("Put") { - Some(Self::Put) - } else if method.eq_ignore_ascii_case("Connect") { - Some(Self::Connect) - } else if method.eq_ignore_ascii_case("Options") { - Some(Self::Options) - } else if method.eq_ignore_ascii_case("Trace") { - Some(Self::Trace) - } else if method.eq_ignore_ascii_case("Copy") { - Some(Self::Copy) - } else if method.eq_ignore_ascii_case("Lock") { - Some(Self::Lock) - } else if method.eq_ignore_ascii_case("MkCol") { - Some(Self::MkCol) - } else if method.eq_ignore_ascii_case("Move") { - Some(Self::Move) - } else if method.eq_ignore_ascii_case("Propfind") { - Some(Self::Propfind) - } else if method.eq_ignore_ascii_case("Proppatch") { - Some(Self::Proppatch) - } else if method.eq_ignore_ascii_case("Search") { - Some(Self::Search) - } else if method.eq_ignore_ascii_case("Unlock") { - Some(Self::Unlock) - } else if method.eq_ignore_ascii_case("Bind") { - Some(Self::Bind) - } else if method.eq_ignore_ascii_case("Rebind") { - Some(Self::Rebind) - } else if method.eq_ignore_ascii_case("Unbind") { - Some(Self::Unbind) - } else if method.eq_ignore_ascii_case("Acl") { - Some(Self::Acl) - } else if method.eq_ignore_ascii_case("Report") { - Some(Self::Report) - } else if method.eq_ignore_ascii_case("MkActivity") { - Some(Self::MkActivity) - } else if method.eq_ignore_ascii_case("Checkout") { - Some(Self::Checkout) - } else if method.eq_ignore_ascii_case("Merge") { - Some(Self::Merge) - } else if method.eq_ignore_ascii_case("MSearch") { - Some(Self::MSearch) - } else if method.eq_ignore_ascii_case("Notify") { - Some(Self::Notify) - } else if method.eq_ignore_ascii_case("Subscribe") { - Some(Self::Subscribe) - } else if method.eq_ignore_ascii_case("Unsubscribe") { - Some(Self::Unsubscribe) - } else if method.eq_ignore_ascii_case("Patch") { - Some(Self::Patch) - } else if method.eq_ignore_ascii_case("Purge") { - Some(Self::Purge) - } else if method.eq_ignore_ascii_case("MkCalendar") { - Some(Self::MkCalendar) - } else if method.eq_ignore_ascii_case("Link") { - Some(Self::Link) - } else if method.eq_ignore_ascii_case("Unlink") { - Some(Self::Unlink) - } else { - None - } - } - - fn as_str(&self) -> &'static str { - match self { - Self::Delete => "DELETE", - Self::Get => "GET", - Self::Head => "HEAD", - Self::Post => "POST", - Self::Put => "PUT", - Self::Connect => "CONNECT", - Self::Options => "OPTIONS", - Self::Trace => "TRACE", - Self::Copy => "COPY", - Self::Lock => "LOCK", - Self::MkCol => "MKCOL", - Self::Move => "MOVE", - Self::Propfind => "PROPFIND", - Self::Proppatch => "PROPPATCH", - Self::Search => "SEARCH", - Self::Unlock => "UNLOCK", - Self::Bind => "BIND", - Self::Rebind => "REBIND", - Self::Unbind => "UNBIND", - Self::Acl => "ACL", - Self::Report => "REPORT", - Self::MkActivity => "MKACTIVITY", - Self::Checkout => "CHECKOUT", - Self::Merge => "MERGE", - Self::MSearch => "MSEARCH", - Self::Notify => "NOTIFY", - Self::Subscribe => "SUBSCRIBE", - Self::Unsubscribe => "UNSUBSCRIBE", - Self::Patch => "PATCH", - Self::Purge => "PURGE", - Self::MkCalendar => "MKCALENDAR", - Self::Link => "LINK", - Self::Unlink => "UNLINK", - } - } -} - -impl Display for Method { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}", self.as_str()) - } -} - pub async fn send_request( method: Option, path: Option<&str>, @@ -311,215 +150,7 @@ where output.write_all(b"\r\n").await.map_err(Error::Io) } -#[derive(Debug)] -pub struct Headers<'b, const N: usize = 64>([httparse::Header<'b>; N]); - impl<'b, const N: usize> Headers<'b, N> { - pub const fn new() -> Self { - Self([httparse::EMPTY_HEADER; N]) - } - - pub fn content_len(&self) -> Option { - self.get("Content-Length") - .map(|content_len_str| content_len_str.parse::().unwrap()) - } - - pub fn content_type(&self) -> Option<&str> { - self.get("Content-Type") - } - - pub fn content_encoding(&self) -> Option<&str> { - self.get("Content-Encoding") - } - - pub fn transfer_encoding(&self) -> Option<&str> { - self.get("Transfer-Encoding") - } - - pub fn host(&self) -> Option<&str> { - self.get("Host") - } - - pub fn connection(&self) -> Option<&str> { - self.get("Connection") - } - - pub fn cache_control(&self) -> Option<&str> { - self.get("Cache-Control") - } - - pub fn upgrade(&self) -> Option<&str> { - self.get("Upgrade") - } - - pub fn is_ws_upgrade_request(&self) -> bool { - crate::asynch::ws::http::is_upgrade_request(self.iter()) - } - - pub fn iter(&self) -> impl Iterator { - self.iter_raw() - .map(|(name, value)| (name, unsafe { str::from_utf8_unchecked(value) })) - } - - pub fn iter_raw(&self) -> impl Iterator { - self.0 - .iter() - .filter(|header| !header.name.is_empty()) - .map(|header| (header.name, header.value)) - } - - pub fn get(&self, name: &str) -> Option<&str> { - self.iter() - .find(|(hname, _)| name.eq_ignore_ascii_case(hname)) - .map(|(_, value)| value) - } - - pub fn get_raw(&self, name: &str) -> Option<&[u8]> { - self.iter_raw() - .find(|(hname, _)| name.eq_ignore_ascii_case(hname)) - .map(|(_, value)| value) - } - - pub fn set(&mut self, name: &'b str, value: &'b str) -> &mut Self { - self.set_raw(name, value.as_bytes()) - } - - pub fn set_raw(&mut self, name: &'b str, value: &'b [u8]) -> &mut Self { - if !name.is_empty() { - for header in &mut self.0 { - if header.name.is_empty() || header.name.eq_ignore_ascii_case(name) { - *header = Header { name, value }; - return self; - } - } - } - - panic!("No space left"); - } - - pub fn remove(&mut self, name: &str) -> &mut Self { - let index = self - .0 - .iter() - .enumerate() - .find(|(_, header)| header.name.eq_ignore_ascii_case(name)); - - if let Some((mut index, _)) = index { - while index < self.0.len() - 1 { - self.0[index] = self.0[index + 1]; - - index += 1; - } - - self.0[index] = EMPTY_HEADER; - } - - self - } - - pub fn set_content_len( - &mut self, - content_len: u64, - buf: &'b mut heapless::String<20>, - ) -> &mut Self { - *buf = heapless::String::<20>::from(content_len); - - self.set("Content-Length", buf.as_str()) - } - - pub fn set_content_type(&mut self, content_type: &'b str) -> &mut Self { - self.set("Content-Type", content_type) - } - - pub fn set_content_encoding(&mut self, content_encoding: &'b str) -> &mut Self { - self.set("Content-Encoding", content_encoding) - } - - pub fn set_transfer_encoding(&mut self, transfer_encoding: &'b str) -> &mut Self { - self.set("Transfer-Encoding", transfer_encoding) - } - - pub fn set_transfer_encoding_chunked(&mut self) -> &mut Self { - self.set_transfer_encoding("Chunked") - } - - pub fn set_host(&mut self, host: &'b str) -> &mut Self { - self.set("Host", host) - } - - pub fn set_connection(&mut self, connection: &'b str) -> &mut Self { - self.set("Connection", connection) - } - - pub fn set_connection_close(&mut self) -> &mut Self { - self.set_connection("Close") - } - - pub fn set_connection_keep_alive(&mut self) -> &mut Self { - self.set_connection("Keep-Alive") - } - - pub fn set_connection_upgrade(&mut self) -> &mut Self { - self.set_connection("Upgrade") - } - - pub fn set_cache_control(&mut self, cache: &'b str) -> &mut Self { - self.set("Cache-Control", cache) - } - - pub fn set_cache_control_no_cache(&mut self) -> &mut Self { - self.set_cache_control("No-Cache") - } - - pub fn set_upgrade(&mut self, upgrade: &'b str) -> &mut Self { - self.set("Upgrade", upgrade) - } - - pub fn set_upgrade_websocket(&mut self) -> &mut Self { - self.set_upgrade("websocket") - } - - pub fn set_ws_upgrade_request_headers( - &mut self, - host: Option<&'b str>, - origin: Option<&'b str>, - version: Option<&'b str>, - nonce: &[u8; crate::asynch::ws::http::NONCE_LEN], - nonce_base64_buf: &'b mut [u8; crate::asynch::ws::http::MAX_BASE64_KEY_LEN], - ) -> &mut Self { - for (name, value) in crate::asynch::ws::http::upgrade_request_headers( - host, - origin, - version, - nonce, - nonce_base64_buf, - ) { - self.set(name, value); - } - - self - } - - pub fn set_ws_upgrade_response_headers<'a, H>( - &mut self, - request_headers: H, - version: Option<&'a str>, - sec_key_response_base64_buf: &'b mut [u8; crate::asynch::ws::http::MAX_BASE64_KEY_RESPONSE_LEN], - ) -> Result<&mut Self, UpgradeError> - where - H: IntoIterator, - { - for (name, value) in crate::asynch::ws::http::upgrade_response_headers( - request_headers, - version, - sec_key_response_base64_buf, - )? { - self.set(name, value); - } - - Ok(self) - } - pub async fn send(&self, output: W) -> Result> where W: Write, @@ -528,51 +159,6 @@ impl<'b, const N: usize> Headers<'b, N> { } } -impl<'b, const N: usize> Default for Headers<'b, N> { - fn default() -> Self { - Self::new() - } -} - -#[derive(Copy, Clone, Eq, PartialEq, Debug)] -pub enum BodyType { - Chunked, - ContentLen(u64), - Close, - Unknown, -} - -impl BodyType { - pub fn from_header(name: &str, value: &str) -> Self { - if "Transfer-Encoding".eq_ignore_ascii_case(name) { - if value.eq_ignore_ascii_case("Chunked") { - return Self::Chunked; - } - } else if "Content-Length".eq_ignore_ascii_case(name) { - return Self::ContentLen(value.parse::().unwrap()); // TODO - } else if "Connection".eq_ignore_ascii_case(name) && value.eq_ignore_ascii_case("Close") { - return Self::Close; - } - - Self::Unknown - } - - pub fn from_headers<'a, H>(headers: H) -> Self - where - H: IntoIterator, - { - for (name, value) in headers { - let body = Self::from_header(name, value); - - if body != Self::Unknown { - return body; - } - } - - Self::Unknown - } -} - pub enum Body<'b, R> { Close(PartiallyRead<'b, R>), ContentLen(ContentLenRead>), @@ -1175,22 +761,7 @@ where } } -#[derive(Default, Debug)] -pub struct RequestHeaders<'b, const N: usize> { - pub method: Option, - pub path: Option<&'b str>, - pub headers: Headers<'b, N>, -} - impl<'b, const N: usize> RequestHeaders<'b, N> { - pub const fn new() -> Self { - Self { - method: None, - path: None, - headers: Headers::::new(), - } - } - pub async fn receive( &mut self, buf: &'b mut [u8], @@ -1241,44 +812,7 @@ impl<'b, const N: usize> RequestHeaders<'b, N> { } } -impl<'b, const N: usize> Display for RequestHeaders<'b, N> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - // if let Some(version) = self.version { - // writeln!(f, "Version {}", version)?; - // } - - if let Some(method) = self.method { - writeln!(f, "{} {}", method, self.path.unwrap_or(""))?; - } - - for (name, value) in self.headers.iter() { - if name.is_empty() { - break; - } - - writeln!(f, "{name}: {value}")?; - } - - Ok(()) - } -} - -#[derive(Default, Debug)] -pub struct ResponseHeaders<'b, const N: usize> { - pub code: Option, - pub reason: Option<&'b str>, - pub headers: Headers<'b, N>, -} - impl<'b, const N: usize> ResponseHeaders<'b, N> { - pub const fn new() -> Self { - Self { - code: None, - reason: None, - headers: Headers::::new(), - } - } - pub async fn receive( &mut self, buf: &'b mut [u8], @@ -1323,28 +857,6 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> { } } -impl<'b, const N: usize> Display for ResponseHeaders<'b, N> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - // if let Some(version) = self.version { - // writeln!(f, "Version {}", version)?; - // } - - if let Some(code) = self.code { - writeln!(f, "{} {}", code, self.reason.unwrap_or(""))?; - } - - for (name, value) in self.headers.iter() { - if name.is_empty() { - break; - } - - writeln!(f, "{name}: {value}")?; - } - - Ok(()) - } -} - async fn read_reply_buf( mut input: R, buf: &mut [u8], @@ -1432,128 +944,3 @@ where Ok(()) } - -#[cfg(feature = "embedded-svc")] -mod embedded_svc_compat { - use core::str; - - use embedded_svc::http::client::asynch::Method; - - impl From for super::Method { - fn from(method: Method) -> Self { - match method { - Method::Delete => super::Method::Delete, - Method::Get => super::Method::Get, - Method::Head => super::Method::Head, - Method::Post => super::Method::Post, - Method::Put => super::Method::Put, - Method::Connect => super::Method::Connect, - Method::Options => super::Method::Options, - Method::Trace => super::Method::Trace, - Method::Copy => super::Method::Copy, - Method::Lock => super::Method::Lock, - Method::MkCol => super::Method::MkCol, - Method::Move => super::Method::Move, - Method::Propfind => super::Method::Propfind, - Method::Proppatch => super::Method::Proppatch, - Method::Search => super::Method::Search, - Method::Unlock => super::Method::Unlock, - Method::Bind => super::Method::Bind, - Method::Rebind => super::Method::Rebind, - Method::Unbind => super::Method::Unbind, - Method::Acl => super::Method::Acl, - Method::Report => super::Method::Report, - Method::MkActivity => super::Method::MkActivity, - Method::Checkout => super::Method::Checkout, - Method::Merge => super::Method::Merge, - Method::MSearch => super::Method::MSearch, - Method::Notify => super::Method::Notify, - Method::Subscribe => super::Method::Subscribe, - Method::Unsubscribe => super::Method::Unsubscribe, - Method::Patch => super::Method::Patch, - Method::Purge => super::Method::Purge, - Method::MkCalendar => super::Method::MkCalendar, - Method::Link => super::Method::Link, - Method::Unlink => super::Method::Unlink, - } - } - } - - impl From for Method { - fn from(method: super::Method) -> Self { - match method { - super::Method::Delete => Method::Delete, - super::Method::Get => Method::Get, - super::Method::Head => Method::Head, - super::Method::Post => Method::Post, - super::Method::Put => Method::Put, - super::Method::Connect => Method::Connect, - super::Method::Options => Method::Options, - super::Method::Trace => Method::Trace, - super::Method::Copy => Method::Copy, - super::Method::Lock => Method::Lock, - super::Method::MkCol => Method::MkCol, - super::Method::Move => Method::Move, - super::Method::Propfind => Method::Propfind, - super::Method::Proppatch => Method::Proppatch, - super::Method::Search => Method::Search, - super::Method::Unlock => Method::Unlock, - super::Method::Bind => Method::Bind, - super::Method::Rebind => Method::Rebind, - super::Method::Unbind => Method::Unbind, - super::Method::Acl => Method::Acl, - super::Method::Report => Method::Report, - super::Method::MkActivity => Method::MkActivity, - super::Method::Checkout => Method::Checkout, - super::Method::Merge => Method::Merge, - super::Method::MSearch => Method::MSearch, - super::Method::Notify => Method::Notify, - super::Method::Subscribe => Method::Subscribe, - super::Method::Unsubscribe => Method::Unsubscribe, - super::Method::Patch => Method::Patch, - super::Method::Purge => Method::Purge, - super::Method::MkCalendar => Method::MkCalendar, - super::Method::Link => Method::Link, - super::Method::Unlink => Method::Unlink, - } - } - } - - impl<'b, const N: usize> embedded_svc::http::Query for super::RequestHeaders<'b, N> { - fn uri(&self) -> &'_ str { - self.path.unwrap_or("") - } - - fn method(&self) -> Method { - self.method.unwrap_or(super::Method::Get).into() - } - } - - impl<'b, const N: usize> embedded_svc::http::Headers for super::RequestHeaders<'b, N> { - fn header(&self, name: &str) -> Option<&'_ str> { - self.headers.get(name) - } - } - - impl<'b, const N: usize> embedded_svc::http::Status for super::ResponseHeaders<'b, N> { - fn status(&self) -> u16 { - self.code.unwrap_or(200) - } - - fn status_message(&self) -> Option<&'_ str> { - self.reason - } - } - - impl<'b, const N: usize> embedded_svc::http::Headers for super::ResponseHeaders<'b, N> { - fn header(&self, name: &str) -> Option<&'_ str> { - self.headers.get(name) - } - } - - impl<'b, const N: usize> embedded_svc::http::Headers for super::Headers<'b, N> { - fn header(&self, name: &str) -> Option<&'_ str> { - self.get(name) - } - } -} diff --git a/edge-http/src/client.rs b/edge-http/src/io/client.rs similarity index 89% rename from edge-http/src/client.rs rename to edge-http/src/io/client.rs index 85dd606..a37df3c 100644 --- a/edge-http/src/client.rs +++ b/edge-http/src/io/client.rs @@ -1,13 +1,14 @@ use core::{mem, str}; -use embedded_io::ErrorType; -use embedded_io_async::{Read, Write}; -use no_std_net::SocketAddr; +use embedded_io_async::{ErrorType, Read, Write}; -use crate::{ +use embedded_nal_async::{SocketAddr, TcpConnect}; + +use crate::ws::{upgrade_request_headers, MAX_BASE64_KEY_LEN, NONCE_LEN}; + +use super::{ send_headers, send_headers_end, send_request, Body, BodyType, Error, ResponseHeaders, SendBody, }; -use embedded_nal_async::TcpConnect; #[allow(unused_imports)] #[cfg(feature = "embedded-svc")] @@ -68,6 +69,46 @@ where matches!(self, Self::Response(_)) } + pub async fn initiate_ws_upgrade_request<'a>( + &'a mut self, + host: Option<&'a str>, + origin: Option<&'a str>, + uri: &'a str, + version: Option<&'a str>, + nonce: &[u8; NONCE_LEN], + ) -> Result<(), Error> + where + T: TcpConnect, + { + let mut nonce_base64_buf = [0_u8; MAX_BASE64_KEY_LEN]; + + let headers = upgrade_request_headers(host, origin, version, nonce, &mut nonce_base64_buf); + + self.initiate_request(Method::Get, uri, &headers).await + } + + pub fn is_ws_upgrade_accepted(&self, _nonce: &[u8; NONCE_LEN]) -> Result> + where + T: TcpConnect, + { + let headers = self.headers()?; + + let succeeded = matches!(headers.code, Some(101)) + && headers + .headers + .connection() + .map(|v| v.eq_ignore_ascii_case("Upgrade")) + .unwrap_or(false) + && headers + .headers + .upgrade() + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false) + && headers.headers.get("Sec-WebSocket-Accept").is_some(); + + Ok(succeeded) + } + #[allow(clippy::type_complexity)] pub fn split(&mut self) -> (&ResponseHeaders<'b, N>, &mut Body<'b, T::Connection<'b>>) { let response = self.response_mut().expect("Not in response mode"); diff --git a/edge-http/src/server.rs b/edge-http/src/io/server.rs similarity index 98% rename from edge-http/src/server.rs rename to edge-http/src/io/server.rs index 8dc4196..1872e28 100644 --- a/edge-http/src/server.rs +++ b/edge-http/src/io/server.rs @@ -2,12 +2,11 @@ use core::fmt::{self, Debug, Display, Write as _}; use core::future::Future; use core::mem; -use embedded_io::ErrorType; -use embedded_io_async::{Read, Write}; +use embedded_io_async::{ErrorType, Read, Write}; use log::{info, warn}; -use crate::{ +use super::{ send_headers, send_headers_end, send_status, Body, BodyType, Error, Method, RequestHeaders, SendBody, }; @@ -336,7 +335,7 @@ pub struct Server { impl Server where - A: edge_tcp::TcpAccept, + A: embedded_nal_async_xtra::TcpAccept, H: for<'b, 't> Handler<'b, N, &'b mut A::Connection<'t>>, { pub const fn new(acceptor: A, handler: H) -> Self { @@ -422,8 +421,8 @@ mod embedded_svc_compat { use embedded_svc::http::server::asynch::{Connection, Headers, Query}; use embedded_svc::utils::http::server::registration::{ChainHandler, ChainRoot}; - use crate::asynch::http::Method; - use crate::asynch::http::{Body, RequestHeaders}; + use crate::io::Body; + use crate::{Method, RequestHeaders}; use super::*; diff --git a/edge-http/src/lib.rs b/edge-http/src/lib.rs index 33db836..b28191e 100644 --- a/edge-http/src/lib.rs +++ b/edge-http/src/lib.rs @@ -1,93 +1,18 @@ #![cfg_attr(not(feature = "std"), no_std)] #![allow(stable_features)] #![allow(unknown_lints)] -#![feature(async_fn_in_trait)] -#![allow(async_fn_in_trait)] -#![feature(impl_trait_projections)] -#![feature(impl_trait_in_assoc_type)] +#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", allow(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", feature(impl_trait_projections))] +#![cfg_attr(feature = "nightly", feature(impl_trait_in_assoc_type))] -use core::cmp::min; -use core::fmt::{Display, Write as _}; +use core::fmt::Display; use core::str; -use embedded_io::ErrorType; -use embedded_io_async::{Read, Write}; - -use httparse::{Header, Status, EMPTY_HEADER}; - -use log::trace; - -#[allow(unused_imports)] -#[cfg(feature = "embedded-svc")] -pub use embedded_svc_compat::*; +use httparse::{Header, EMPTY_HEADER}; #[cfg(feature = "nightly")] -pub mod asynch; - -pub mod client; -pub mod server; - -/// An error in parsing the headers or the body. -#[derive(Debug)] -pub enum Error { - InvalidHeaders, - InvalidBody, - TooManyHeaders, - TooLongHeaders, - TooLongBody, - IncompleteHeaders, - IncompleteBody, - InvalidState, - Io(E), -} - -impl From for Error { - fn from(e: httparse::Error) -> Self { - match e { - httparse::Error::HeaderName => Self::InvalidHeaders, - httparse::Error::HeaderValue => Self::InvalidHeaders, - httparse::Error::NewLine => Self::InvalidHeaders, - httparse::Error::Status => Self::InvalidHeaders, - httparse::Error::Token => Self::InvalidHeaders, - httparse::Error::TooManyHeaders => Self::TooManyHeaders, - httparse::Error::Version => Self::InvalidHeaders, - } - } -} - -impl embedded_io::Error for Error -where - E: embedded_io::Error, -{ - fn kind(&self) -> embedded_io::ErrorKind { - match self { - Self::Io(e) => e.kind(), - _ => embedded_io::ErrorKind::Other, - } - } -} - -impl Display for Error -where - E: Display, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - Self::InvalidHeaders => write!(f, "Invalid HTTP headers or status line"), - Self::InvalidBody => write!(f, "Invalid HTTP body"), - Self::TooManyHeaders => write!(f, "Too many HTTP headers"), - Self::TooLongHeaders => write!(f, "HTTP headers section is too long"), - Self::TooLongBody => write!(f, "HTTP body is too long"), - Self::IncompleteHeaders => write!(f, "HTTP headers section is incomplete"), - Self::IncompleteBody => write!(f, "HTTP body is incomplete"), - Self::InvalidState => write!(f, "Connection is not in requested state"), - Self::Io(e) => write!(f, "{e}"), - } - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error where E: std::error::Error {} +pub mod io; #[derive(Copy, Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "std", derive(Hash))] @@ -245,81 +170,6 @@ impl Display for Method { } } -pub async fn send_request( - method: Option, - path: Option<&str>, - output: W, -) -> Result<(), Error> -where - W: Write, -{ - send_status_line(true, method.map(|method| method.as_str()), path, output).await -} - -pub async fn send_status( - status: Option, - reason: Option<&str>, - output: W, -) -> Result<(), Error> -where - W: Write, -{ - let status_str = status.map(heapless::String::<5>::from); - - send_status_line( - false, - status_str.as_ref().map(|status| status.as_str()), - reason, - output, - ) - .await -} - -pub async fn send_headers<'a, H, W>(headers: H, output: W) -> Result> -where - W: Write, - H: IntoIterator, -{ - send_raw_headers( - headers - .into_iter() - .map(|(name, value)| (*name, value.as_bytes())), - output, - ) - .await -} - -pub async fn send_raw_headers<'a, H, W>( - headers: H, - mut output: W, -) -> Result> -where - W: Write, - H: IntoIterator, -{ - let mut body = BodyType::Unknown; - - for (name, value) in headers.into_iter() { - if body == BodyType::Unknown { - body = BodyType::from_header(name, unsafe { str::from_utf8_unchecked(value) }); - } - - output.write_all(name.as_bytes()).await.map_err(Error::Io)?; - output.write_all(b": ").await.map_err(Error::Io)?; - output.write_all(value).await.map_err(Error::Io)?; - output.write_all(b"\r\n").await.map_err(Error::Io)?; - } - - Ok(body) -} - -pub async fn send_headers_end(mut output: W) -> Result<(), Error> -where - W: Write, -{ - output.write_all(b"\r\n").await.map_err(Error::Io) -} - #[derive(Debug)] pub struct Headers<'b, const N: usize = 64>([httparse::Header<'b>; N]); @@ -357,6 +207,10 @@ impl<'b, const N: usize> Headers<'b, N> { self.get("Cache-Control") } + pub fn is_ws_upgrade_request(&self) -> bool { + crate::ws::is_upgrade_request(self.iter()) + } + pub fn upgrade(&self) -> Option<&str> { self.get("Upgrade") } @@ -484,11 +338,39 @@ impl<'b, const N: usize> Headers<'b, N> { self.set_upgrade("websocket") } - pub async fn send(&self, output: W) -> Result> + pub fn set_ws_upgrade_request_headers( + &mut self, + host: Option<&'b str>, + origin: Option<&'b str>, + version: Option<&'b str>, + nonce: &[u8; ws::NONCE_LEN], + nonce_base64_buf: &'b mut [u8; ws::MAX_BASE64_KEY_LEN], + ) -> &mut Self { + for (name, value) in + ws::upgrade_request_headers(host, origin, version, nonce, nonce_base64_buf) + { + self.set(name, value); + } + + self + } + + pub fn set_ws_upgrade_response_headers<'a, H>( + &mut self, + request_headers: H, + version: Option<&'a str>, + sec_key_response_base64_buf: &'b mut [u8; ws::MAX_BASE64_KEY_RESPONSE_LEN], + ) -> Result<&mut Self, ws::UpgradeError> where - W: Write, + H: IntoIterator, { - send_raw_headers(self.iter_raw(), output).await + for (name, value) in + ws::upgrade_response_headers(request_headers, version, sec_key_response_base64_buf)? + { + self.set(name, value); + } + + Ok(self) } } @@ -537,608 +419,6 @@ impl BodyType { } } -pub enum Body<'b, R> { - Close(PartiallyRead<'b, R>), - ContentLen(ContentLenRead>), - Chunked(ChunkedRead<'b, PartiallyRead<'b, R>>), -} - -impl<'b, R> Body<'b, R> -where - R: Read, -{ - pub fn new(body_type: BodyType, buf: &'b mut [u8], read_len: usize, input: R) -> Self { - match body_type { - BodyType::Chunked => Body::Chunked(ChunkedRead::new( - PartiallyRead::new(&[], input), - buf, - read_len, - )), - BodyType::ContentLen(content_len) => Body::ContentLen(ContentLenRead::new( - content_len, - PartiallyRead::new(&buf[..read_len], input), - )), - BodyType::Close => Body::Close(PartiallyRead::new(&buf[..read_len], input)), - BodyType::Unknown => Body::ContentLen(ContentLenRead::new( - 0, - PartiallyRead::new(&buf[..read_len], input), - )), - } - } - - pub fn is_complete(&self) -> bool { - match self { - Self::Close(_) => true, - Self::ContentLen(r) => r.is_complete(), - Self::Chunked(r) => r.is_complete(), - } - } - - pub fn as_raw_reader(&mut self) -> &mut R { - match self { - Self::Close(r) => &mut r.input, - Self::ContentLen(r) => &mut r.input.input, - Self::Chunked(r) => &mut r.input.input, - } - } - - pub fn release(self) -> R { - match self { - Self::Close(r) => r.release(), - Self::ContentLen(r) => r.release().release(), - Self::Chunked(r) => r.release().release(), - } - } -} - -impl<'b, R> ErrorType for Body<'b, R> -where - R: ErrorType, -{ - type Error = Error; -} - -impl<'b, R> Read for Body<'b, R> -where - R: Read, -{ - async fn read(&mut self, buf: &mut [u8]) -> Result { - match self { - Self::Close(read) => Ok(read.read(buf).await.map_err(Error::Io)?), - Self::ContentLen(read) => Ok(read.read(buf).await?), - Self::Chunked(read) => Ok(read.read(buf).await?), - } - } -} - -pub struct PartiallyRead<'b, R> { - buf: &'b [u8], - read_len: usize, - input: R, -} - -impl<'b, R> PartiallyRead<'b, R> { - pub const fn new(buf: &'b [u8], input: R) -> Self { - Self { - buf, - read_len: 0, - input, - } - } - - pub fn buf_len(&self) -> usize { - self.buf.len() - } - - pub fn as_raw_reader(&mut self) -> &mut R { - &mut self.input - } - - pub fn release(self) -> R { - self.input - } -} - -impl<'b, R> ErrorType for PartiallyRead<'b, R> -where - R: ErrorType, -{ - type Error = R::Error; -} - -impl<'b, R> Read for PartiallyRead<'b, R> -where - R: Read, -{ - async fn read(&mut self, buf: &mut [u8]) -> Result { - if self.buf.len() > self.read_len { - let len = min(buf.len(), self.buf.len() - self.read_len); - buf[..len].copy_from_slice(&self.buf[self.read_len..self.read_len + len]); - - self.read_len += len; - - Ok(len) - } else { - Ok(self.input.read(buf).await?) - } - } -} - -pub struct ContentLenRead { - content_len: u64, - read_len: u64, - input: R, -} - -impl ContentLenRead { - pub const fn new(content_len: u64, input: R) -> Self { - Self { - content_len, - read_len: 0, - input, - } - } - - pub fn is_complete(&self) -> bool { - self.content_len == self.read_len - } - - pub fn release(self) -> R { - self.input - } -} - -impl ErrorType for ContentLenRead -where - R: ErrorType, -{ - type Error = Error; -} - -impl Read for ContentLenRead -where - R: Read, -{ - async fn read(&mut self, buf: &mut [u8]) -> Result { - let len = min(buf.len() as _, self.content_len - self.read_len); - if len > 0 { - let read = self - .input - .read(&mut buf[..len as _]) - .await - .map_err(Error::Io)?; - self.read_len += read as u64; - - Ok(read) - } else { - Ok(0) - } - } -} - -pub struct ChunkedRead<'b, R> { - buf: &'b mut [u8], - buf_offset: usize, - buf_len: usize, - input: R, - remain: u64, - complete: bool, -} - -impl<'b, R> ChunkedRead<'b, R> -where - R: Read, -{ - pub fn new(input: R, buf: &'b mut [u8], buf_len: usize) -> Self { - Self { - buf, - buf_offset: 0, - buf_len, - input, - remain: 0, - complete: false, - } - } - - pub fn is_complete(&self) -> bool { - self.complete - } - - pub fn release(self) -> R { - self.input - } - - // The elegant pull parser taken from here: - // https://github.com/kchmck/uhttp_chunked_bytes.rs/blob/master/src/lib.rs - // Changes: - // - Converted to async - // - Iterators removed - // - Simpler error handling - // - Consumption of trailer - async fn next(&mut self) -> Result, Error> { - if self.complete { - return Ok(None); - } - - if self.remain == 0 { - if let Some(size) = self.parse_size().await? { - // If chunk size is zero (final chunk), the stream is finished [RFC7230§4.1]. - if size == 0 { - self.consume_trailer().await?; - self.complete = true; - return Ok(None); - } - - self.remain = size; - } else { - self.complete = true; - return Ok(None); - } - } - - let next = self.input_fetch().await?; - self.remain -= 1; - - // If current chunk is finished, verify it ends with CRLF [RFC7230§4.1]. - if self.remain == 0 { - self.consume_multi(b"\r\n").await?; - } - - Ok(Some(next)) - } - - // Parse the number of bytes in the next chunk. - async fn parse_size(&mut self) -> Result, Error> { - let mut digits = [0_u8; 16]; - - let slice = match self.parse_digits(&mut digits[..]).await? { - // This is safe because the following call to `from_str_radix` does - // its own verification on the bytes. - Some(s) => unsafe { str::from_utf8_unchecked(s) }, - None => return Ok(None), - }; - - let size = u64::from_str_radix(slice, 16).map_err(|_| Error::InvalidBody)?; - - Ok(Some(size)) - } - - // Extract the hex digits for the current chunk size. - async fn parse_digits<'a>( - &'a mut self, - digits: &'a mut [u8], - ) -> Result, Error> { - // Number of hex digits that have been extracted. - let mut len = 0; - - loop { - let b = match self.input_next().await? { - Some(b) => b, - None => { - return if len == 0 { - // If EOF at the beginning of a new chunk, the stream is finished. - Ok(None) - } else { - Err(Error::IncompleteBody) - }; - } - }; - - match b { - b'\r' => { - self.consume(b'\n').await?; - break; - } - b';' => { - self.consume_ext().await?; - break; - } - _ => { - match digits.get_mut(len) { - Some(d) => *d = b, - None => return Err(Error::InvalidBody), - } - - len += 1; - } - } - } - - Ok(Some(&digits[..len])) - } - - // Consume and discard current chunk extension. - // This doesn't check whether the characters up to CRLF actually have correct syntax. - async fn consume_ext(&mut self) -> Result<(), Error> { - self.consume_header().await?; - - Ok(()) - } - - // Consume and discard the optional trailer following the last chunk. - async fn consume_trailer(&mut self) -> Result<(), Error> { - while self.consume_header().await? {} - - Ok(()) - } - - // Consume and discard each header in the optional trailer following the last chunk. - async fn consume_header(&mut self) -> Result> { - let mut first = self.input_fetch().await?; - let mut len = 1; - - loop { - let second = self.input_fetch().await?; - len += 1; - - if first == b'\r' && second == b'\n' { - return Ok(len > 2); - } - - first = second; - } - } - - // Verify the next bytes in the stream match the expectation. - async fn consume_multi(&mut self, bytes: &[u8]) -> Result<(), Error> { - for byte in bytes { - self.consume(*byte).await?; - } - - Ok(()) - } - - // Verify the next byte in the stream is matching the expectation. - async fn consume(&mut self, byte: u8) -> Result<(), Error> { - if self.input_fetch().await? == byte { - Ok(()) - } else { - Err(Error::InvalidBody) - } - } - - async fn input_fetch(&mut self) -> Result> { - self.input_next().await?.ok_or(Error::IncompleteBody) - } - - async fn input_next(&mut self) -> Result, Error> { - if self.buf_offset == self.buf_len { - self.buf_len = self.input.read(self.buf).await.map_err(Error::Io)?; - self.buf_offset = 0; - } - - if self.buf_len > 0 { - let byte = self.buf[self.buf_offset]; - self.buf_offset += 1; - - Ok(Some(byte)) - } else { - Ok(None) - } - } -} - -impl<'b, R> ErrorType for ChunkedRead<'b, R> -where - R: ErrorType, -{ - type Error = Error; -} - -impl<'b, R> Read for ChunkedRead<'b, R> -where - R: Read, -{ - async fn read(&mut self, buf: &mut [u8]) -> Result { - for (index, byte_pos) in buf.iter_mut().enumerate() { - if let Some(byte) = self.next().await? { - *byte_pos = byte; - } else { - return Ok(index); - } - } - - Ok(buf.len()) - } -} - -pub enum SendBody { - Close(W), - ContentLen(ContentLenWrite), - Chunked(ChunkedWrite), -} - -impl SendBody -where - W: Write, -{ - pub fn new(body_type: BodyType, output: W) -> SendBody { - match body_type { - BodyType::Chunked => SendBody::Chunked(ChunkedWrite::new(output)), - BodyType::ContentLen(content_len) => { - SendBody::ContentLen(ContentLenWrite::new(content_len, output)) - } - BodyType::Close => SendBody::Close(output), - BodyType::Unknown => SendBody::ContentLen(ContentLenWrite::new(0, output)), - } - } - - pub fn is_complete(&self) -> bool { - match self { - Self::ContentLen(w) => w.is_complete(), - _ => true, - } - } - - pub async fn finish(&mut self) -> Result<(), Error> - where - W: Write, - { - match self { - Self::Close(_) => (), - Self::ContentLen(_) => (), - Self::Chunked(w) => w.finish().await?, - } - - self.flush().await?; - - Ok(()) - } - - pub fn as_raw_writer(&mut self) -> &mut W { - match self { - Self::Close(w) => w, - Self::ContentLen(w) => &mut w.output, - Self::Chunked(w) => &mut w.output, - } - } - - pub fn release(self) -> W { - match self { - Self::Close(w) => w, - Self::ContentLen(w) => w.release(), - Self::Chunked(w) => w.release(), - } - } -} - -impl ErrorType for SendBody -where - W: ErrorType, -{ - type Error = Error; -} - -impl Write for SendBody -where - W: Write, -{ - async fn write(&mut self, buf: &[u8]) -> Result { - match self { - Self::Close(w) => Ok(w.write(buf).await.map_err(Error::Io)?), - Self::ContentLen(w) => Ok(w.write(buf).await?), - Self::Chunked(w) => Ok(w.write(buf).await?), - } - } - - async fn flush(&mut self) -> Result<(), Self::Error> { - match self { - Self::Close(w) => Ok(w.flush().await.map_err(Error::Io)?), - Self::ContentLen(w) => Ok(w.flush().await?), - Self::Chunked(w) => Ok(w.flush().await?), - } - } -} - -pub struct ContentLenWrite { - content_len: u64, - write_len: u64, - output: W, -} - -impl ContentLenWrite { - pub const fn new(content_len: u64, output: W) -> Self { - Self { - content_len, - write_len: 0, - output, - } - } - - pub fn is_complete(&self) -> bool { - self.content_len == self.write_len - } - - pub fn release(self) -> W { - self.output - } -} - -impl ErrorType for ContentLenWrite -where - W: ErrorType, -{ - type Error = Error; -} - -impl Write for ContentLenWrite -where - W: Write, -{ - async fn write(&mut self, buf: &[u8]) -> Result { - if self.content_len >= self.write_len + buf.len() as u64 { - let write = self.output.write(buf).await.map_err(Error::Io)?; - self.write_len += write as u64; - - Ok(write) - } else { - Err(Error::TooLongBody) - } - } - - async fn flush(&mut self) -> Result<(), Self::Error> { - self.output.flush().await.map_err(Error::Io) - } -} - -pub struct ChunkedWrite { - output: W, -} - -impl ChunkedWrite { - pub const fn new(output: W) -> Self { - Self { output } - } - - pub async fn finish(&mut self) -> Result<(), Error> - where - W: Write, - { - self.output.write_all(b"\r\n").await.map_err(Error::Io) - } - - pub fn release(self) -> W { - self.output - } -} - -impl ErrorType for ChunkedWrite -where - W: ErrorType, -{ - type Error = Error; -} - -impl Write for ChunkedWrite -where - W: Write, -{ - async fn write(&mut self, buf: &[u8]) -> Result { - if !buf.is_empty() { - let mut len_str = heapless::String::<10>::new(); - write!(&mut len_str, "{:X}\r\n", buf.len()).unwrap(); - self.output - .write_all(len_str.as_bytes()) - .await - .map_err(Error::Io)?; - - self.output.write_all(buf).await.map_err(Error::Io)?; - self.output - .write_all("\r\n".as_bytes()) - .await - .map_err(Error::Io)?; - - Ok(buf.len()) - } else { - Ok(0) - } - } - - async fn flush(&mut self) -> Result<(), Self::Error> { - self.output.flush().await.map_err(Error::Io) - } -} - #[derive(Default, Debug)] pub struct RequestHeaders<'b, const N: usize> { pub method: Option, @@ -1154,55 +434,6 @@ impl<'b, const N: usize> RequestHeaders<'b, N> { headers: Headers::::new(), } } - - pub async fn receive( - &mut self, - buf: &'b mut [u8], - mut input: R, - ) -> Result<(&'b mut [u8], usize), Error> - where - R: Read, - { - let (read_len, headers_len) = match read_reply_buf::(&mut input, buf, true).await { - Ok(read_len) => read_len, - Err(e) => return Err(e), - }; - - let mut parser = httparse::Request::new(&mut self.headers.0); - - let (headers_buf, body_buf) = buf.split_at_mut(headers_len); - - let status = match parser.parse(headers_buf) { - Ok(status) => status, - Err(e) => return Err(e.into()), - }; - - if let Status::Complete(headers_len2) = status { - if headers_len != headers_len2 { - unreachable!("Should not happen. HTTP header parsing is indeterminate.") - } - - self.method = parser.method.and_then(Method::new); - self.path = parser.path; - - trace!("Received:\n{}", self); - - Ok((body_buf, read_len - headers_len)) - } else { - unreachable!("Secondary parse of already loaded buffer failed.") - } - } - - pub async fn send(&self, mut output: W) -> Result> - where - W: Write, - { - send_request(self.method, self.path, &mut output).await?; - let body_type = self.headers.send(&mut output).await?; - send_headers_end(output).await?; - - Ok(body_type) - } } impl<'b, const N: usize> Display for RequestHeaders<'b, N> { @@ -1242,49 +473,6 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> { headers: Headers::::new(), } } - - pub async fn receive( - &mut self, - buf: &'b mut [u8], - mut input: R, - ) -> Result<(&'b mut [u8], usize), Error> - where - R: Read, - { - let (read_len, headers_len) = read_reply_buf::(&mut input, buf, false).await?; - - let mut parser = httparse::Response::new(&mut self.headers.0); - - let (headers_buf, body_buf) = buf.split_at_mut(headers_len); - - let status = parser.parse(headers_buf).map_err(Error::from)?; - - if let Status::Complete(headers_len2) = status { - if headers_len != headers_len2 { - unreachable!("Should not happen. HTTP header parsing is indeterminate.") - } - - self.code = parser.code; - self.reason = parser.reason; - - trace!("Received:\n{}", self); - - Ok((body_buf, read_len - headers_len)) - } else { - unreachable!("Secondary parse of already loaded buffer failed.") - } - } - - pub async fn send(&self, mut output: W) -> Result> - where - W: Write, - { - send_status(self.code, self.reason, &mut output).await?; - let body_type = self.headers.send(&mut output).await?; - send_headers_end(output).await?; - - Ok(body_type) - } } impl<'b, const N: usize> Display for ResponseHeaders<'b, N> { @@ -1309,92 +497,129 @@ impl<'b, const N: usize> Display for ResponseHeaders<'b, N> { } } -async fn read_reply_buf( - mut input: R, - buf: &mut [u8], - request: bool, -) -> Result<(usize, usize), Error> -where - R: Read, -{ - let mut offset = 0; - let mut size = 0; +pub mod ws { + pub const NONCE_LEN: usize = 16; + pub const MAX_BASE64_KEY_LEN: usize = 28; + pub const MAX_BASE64_KEY_RESPONSE_LEN: usize = 33; - while buf.len() > size { - let read = input.read(&mut buf[offset..]).await.map_err(Error::Io)?; + pub const UPGRADE_REQUEST_HEADERS_LEN: usize = 7; + pub const UPGRADE_RESPONSE_HEADERS_LEN: usize = 3; - offset += read; - size += read; + pub fn upgrade_request_headers<'a>( + host: Option<&'a str>, + origin: Option<&'a str>, + version: Option<&'a str>, + nonce: &[u8; NONCE_LEN], + nonce_base64_buf: &'a mut [u8; MAX_BASE64_KEY_LEN], + ) -> [(&'a str, &'a str); UPGRADE_REQUEST_HEADERS_LEN] { + let nonce_base64_len = + base64::encode_config_slice(nonce, base64::URL_SAFE, nonce_base64_buf); - let mut headers = [httparse::EMPTY_HEADER; N]; + let host = host.map(|host| ("Host", host)).unwrap_or(("", "")); + let origin = origin.map(|origin| ("Origin", origin)).unwrap_or(("", "")); - let status = if request { - httparse::Request::new(&mut headers).parse(&buf[..size])? - } else { - httparse::Response::new(&mut headers).parse(&buf[..size])? - }; + [ + host, + origin, + ("Content-Length", "0"), + ("Connection", "Upgrade"), + ("Upgrade", "websocket"), + ("Sec-WebSocket-Version", version.unwrap_or("13")), + ("Sec-WebSocket-Key", unsafe { + core::str::from_utf8_unchecked(&nonce_base64_buf[..nonce_base64_len]) + }), + ] + } - if let httparse::Status::Complete(headers_len) = status { - return Ok((size, headers_len)); + pub fn is_upgrade_request<'a, H>(request_headers: H) -> bool + where + H: IntoIterator, + { + let mut connection = false; + let mut upgrade = false; + + for (name, value) in request_headers { + if name.eq_ignore_ascii_case("Connection") { + connection = value.eq_ignore_ascii_case("Upgrade"); + } else if name.eq_ignore_ascii_case("Upgrade") { + upgrade = value.eq_ignore_ascii_case("websocket"); + } } - } - Err(Error::TooManyHeaders) -} + connection && upgrade + } -async fn send_status_line( - request: bool, - token: Option<&str>, - extra: Option<&str>, - mut output: W, -) -> Result<(), Error> -where - W: Write, -{ - let mut written = false; - - if !request { - output.write_all(b"HTTP/1.1").await.map_err(Error::Io)?; - written = true; + #[derive(Debug, Copy, Clone, Eq, PartialEq)] + pub enum UpgradeError { + NoVersion, + NoSecKey, + UnsupportedVersion, + SecKeyTooLong, } - if let Some(token) = token { - if written { - output.write_all(b" ").await.map_err(Error::Io)?; - } + pub fn upgrade_response_headers<'a, 'b, H>( + request_headers: H, + version: Option<&'a str>, + sec_key_response_base64_buf: &'b mut [u8; MAX_BASE64_KEY_RESPONSE_LEN], + ) -> Result<[(&'b str, &'b str); UPGRADE_RESPONSE_HEADERS_LEN], UpgradeError> + where + H: IntoIterator, + { + let mut version_ok = false; + let mut sec_key = None; - output - .write_all(token.as_bytes()) - .await - .map_err(Error::Io)?; + for (name, value) in request_headers { + if name.eq_ignore_ascii_case("Sec-WebSocket-Version") { + if !value.eq_ignore_ascii_case(version.unwrap_or("13")) { + return Err(UpgradeError::NoVersion); + } - written = true; - } + version_ok = true; + } else if name.eq_ignore_ascii_case("Sec-WebSocket-Key") { + const WS_MAGIC_GUUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - if let Some(extra) = extra { - if written { - output.write_all(b" ").await.map_err(Error::Io)?; - } + let mut buf = [0_u8; MAX_BASE64_KEY_LEN + WS_MAGIC_GUUID.len()]; - output - .write_all(extra.as_bytes()) - .await - .map_err(Error::Io)?; + let value_len = value.as_bytes().len(); - written = true; - } + if value_len > MAX_BASE64_KEY_LEN { + return Err(UpgradeError::SecKeyTooLong); + } - if request { - if written { - output.write_all(b" ").await.map_err(Error::Io)?; - } + buf[..value_len].copy_from_slice(value.as_bytes()); + buf[value_len..value_len + WS_MAGIC_GUUID.as_bytes().len()] + .copy_from_slice(WS_MAGIC_GUUID.as_bytes()); - output.write_all(b"HTTP/1.1").await.map_err(Error::Io)?; - } + let mut sha1 = sha1_smol::Sha1::new(); + + sha1.update(&buf[..value_len + WS_MAGIC_GUUID.as_bytes().len()]); + + let sec_key_len = base64::encode_config_slice( + sha1.digest().bytes(), + base64::STANDARD_NO_PAD, + sec_key_response_base64_buf, + ); - output.write_all(b"\r\n").await.map_err(Error::Io)?; + sec_key = Some(sec_key_len); + } + } - Ok(()) + if version_ok { + if let Some(sec_key_len) = sec_key { + Ok([ + ("Connection", "Upgrade"), + ("Upgrade", "websocket"), + ("Sec-WebSocket-Accept", unsafe { + core::str::from_utf8_unchecked(&sec_key_response_base64_buf[..sec_key_len]) + }), + ]) + } else { + Err(UpgradeError::NoSecKey) + } + } else { + Err(UpgradeError::NoVersion) + } + } } #[cfg(feature = "embedded-svc")] diff --git a/edge-mdns/Cargo.toml b/edge-mdns/Cargo.toml index d1e1833..267b624 100644 --- a/edge-mdns/Cargo.toml +++ b/edge-mdns/Cargo.toml @@ -3,10 +3,7 @@ name = "edge-mdns" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] -log.workspace = true -heapless.workspace = true - +log = { workspace = true } +heapless = { workspace = true } domain = { version = "0.7", default-features = false } diff --git a/edge-mqtt/Cargo.toml b/edge-mqtt/Cargo.toml index db6c145..820e294 100644 --- a/edge-mqtt/Cargo.toml +++ b/edge-mqtt/Cargo.toml @@ -3,10 +3,12 @@ name = "edge-mqtt" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [features] -std = [] +default = ["std"] +std = ["embedded-svc?/std", "rumqttc"] +nightly = ["embedded-svc?/nightly"] [dependencies] -rumqttc = { version = "0.19" } +rumqttc = { version = "0.19", optional = true } +log = { workspace = true } +embedded-svc = { workspace = true, optional = true, default-features = false } diff --git a/edge-mqtt/src/io.rs b/edge-mqtt/src/io.rs new file mode 100644 index 0000000..3991bc7 --- /dev/null +++ b/edge-mqtt/src/io.rs @@ -0,0 +1,181 @@ +pub use rumqttc::*; + +#[cfg(all(feature = "nightly", feature = "embedded-svc"))] +pub use embedded_svc_compat::*; + +#[cfg(all(feature = "nightly", feature = "embedded-svc"))] +mod embedded_svc_compat { + use core::fmt::{Debug, Display}; + use core::marker::PhantomData; + + use embedded_svc::mqtt::client::asynch::{ + Client, Connection, Details, ErrorType, Event, Message, MessageId, Publish, QoS, + }; + use embedded_svc::mqtt::client::MessageImpl; + + use log::trace; + + use rumqttc::{AsyncClient, ClientError, ConnectionError, EventLoop, PubAck, SubAck, UnsubAck}; + + #[derive(Debug)] + pub enum MqttError { + ClientError(ClientError), + ConnectionError(ConnectionError), + } + + impl Display for MqttError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> { + match self { + MqttError::ClientError(error) => write!(f, "ClientError: {error}"), + MqttError::ConnectionError(error) => write!(f, "ConnectionError: {error}"), + } + } + } + + #[cfg(feature = "std")] + impl std::error::Error for MqttError {} + + impl From for MqttError { + fn from(value: ClientError) -> Self { + Self::ClientError(value) + } + } + + impl From for MqttError { + fn from(value: ConnectionError) -> Self { + Self::ConnectionError(value) + } + } + + pub struct MqttClient(AsyncClient); + + impl MqttClient { + pub const fn new(client: AsyncClient) -> Self { + Self(client) + } + } + + impl ErrorType for MqttClient { + type Error = MqttError; + } + + impl Client for MqttClient { + async fn subscribe(&mut self, topic: &str, qos: QoS) -> Result { + self.0.subscribe(topic, to_qos(qos)).await?; + + Ok(0) + } + + async fn unsubscribe(&mut self, topic: &str) -> Result { + self.0.unsubscribe(topic).await?; + + Ok(0) + } + } + + impl Publish for MqttClient { + async fn publish( + &mut self, + topic: &str, + qos: embedded_svc::mqtt::client::QoS, + retain: bool, + payload: &[u8], + ) -> Result { + self.0.publish(topic, to_qos(qos), retain, payload).await?; + + Ok(0) + } + } + + pub struct MessageRef<'a>(&'a rumqttc::Publish); + + impl<'a> MessageRef<'a> { + pub fn into_message_impl(&self) -> Option { + Some(MessageImpl::new(self)) + } + } + + impl<'a> Message for MessageRef<'a> { + fn id(&self) -> MessageId { + self.0.pkid as _ + } + + fn topic(&self) -> Option<&'_ str> { + Some(&self.0.topic) + } + + fn data(&self) -> &'_ [u8] { + &self.0.payload + } + + fn details(&self) -> &Details { + &Details::Complete + } + } + + pub struct MqttConnection(EventLoop, F, PhantomData M>); + + impl MqttConnection { + pub const fn new(event_loop: EventLoop, message_converter: F) -> Self { + Self(event_loop, message_converter, PhantomData) + } + } + + impl ErrorType for MqttConnection { + type Error = MqttError; + } + + impl Connection for MqttConnection + where + F: FnMut(&MessageRef) -> Option + Send, + M: Send, + { + type Message<'a> = M where Self: 'a; + + async fn next(&mut self) -> Option>, Self::Error>> { + loop { + let event = self.0.poll().await; + trace!("Got event: {:?}", event); + + match event { + Ok(event) => { + let event = match event { + rumqttc::Event::Incoming(incoming) => match incoming { + rumqttc::Packet::Connect(_) => Some(Event::BeforeConnect), + rumqttc::Packet::ConnAck(_) => Some(Event::Connected(true)), + rumqttc::Packet::Disconnect => Some(Event::Disconnected), + rumqttc::Packet::PubAck(PubAck { pkid, .. }) => { + Some(Event::Published(pkid as _)) + } + rumqttc::Packet::SubAck(SubAck { pkid, .. }) => { + Some(Event::Subscribed(pkid as _)) + } + rumqttc::Packet::UnsubAck(UnsubAck { pkid, .. }) => { + Some(Event::Unsubscribed(pkid as _)) + } + rumqttc::Packet::Publish(publish) => { + (self.1)(&MessageRef(&publish)).map(Event::Received) + } + _ => None, + }, + rumqttc::Event::Outgoing(_) => None, + }; + + if let Some(event) = event { + return Some(Ok(event)); + } + } + Err(err) => return Some(Err(MqttError::ConnectionError(err))), + } + } + } + } + + fn to_qos(qos: QoS) -> rumqttc::QoS { + match qos { + QoS::AtMostOnce => rumqttc::QoS::AtMostOnce, + QoS::AtLeastOnce => rumqttc::QoS::AtLeastOnce, + QoS::ExactlyOnce => rumqttc::QoS::ExactlyOnce, + } + } +} diff --git a/edge-mqtt/src/lib.rs b/edge-mqtt/src/lib.rs index f8c6068..b101c08 100644 --- a/edge-mqtt/src/lib.rs +++ b/edge-mqtt/src/lib.rs @@ -1,183 +1,4 @@ #![cfg_attr(not(feature = "std"), no_std)] -pub use rumqttc::*; - -#[cfg(feature = "embedded-svc")] -pub use embedded_svc_compat::*; - -#[cfg(feature = "embedded-svc")] -mod embedded_svc_compat { - use core::fmt::{Debug, Display}; - use core::marker::PhantomData; - - use embedded_svc::mqtt::client::asynch::{ - Client, Connection, Details, ErrorType, Event, Message, MessageId, Publish, QoS, - }; - use embedded_svc::mqtt::client::MessageImpl; - - use log::trace; - - use rumqttc::{AsyncClient, ClientError, ConnectionError, EventLoop, PubAck, SubAck, UnsubAck}; - - #[derive(Debug)] - pub enum MqttError { - ClientError(ClientError), - ConnectionError(ConnectionError), - } - - impl Display for MqttError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> { - match self { - MqttError::ClientError(error) => write!(f, "ClientError: {error}"), - MqttError::ConnectionError(error) => write!(f, "ConnectionError: {error}"), - } - } - } - - #[cfg(feature = "std")] - impl std::error::Error for MqttError {} - - impl From for MqttError { - fn from(value: ClientError) -> Self { - Self::ClientError(value) - } - } - - impl From for MqttError { - fn from(value: ConnectionError) -> Self { - Self::ConnectionError(value) - } - } - - pub struct MqttClient(AsyncClient); - - impl MqttClient { - pub const fn new(client: AsyncClient) -> Self { - Self(client) - } - } - - impl ErrorType for MqttClient { - type Error = MqttError; - } - - impl Client for MqttClient { - async fn subscribe(&mut self, topic: &str, qos: QoS) -> Result { - self.0.subscribe(topic, to_qos(qos)).await?; - - Ok(0) - } - - async fn unsubscribe(&mut self, topic: &str) -> Result { - self.0.unsubscribe(topic).await?; - - Ok(0) - } - } - - impl Publish for MqttClient { - async fn publish( - &mut self, - topic: &str, - qos: embedded_svc::mqtt::client::QoS, - retain: bool, - payload: &[u8], - ) -> Result { - self.0.publish(topic, to_qos(qos), retain, payload).await?; - - Ok(0) - } - } - - pub struct MessageRef<'a>(&'a rumqttc::Publish); - - impl<'a> MessageRef<'a> { - pub fn into_message_impl(&self) -> Option { - Some(MessageImpl::new(self)) - } - } - - impl<'a> Message for MessageRef<'a> { - fn id(&self) -> MessageId { - self.0.pkid as _ - } - - fn topic(&self) -> Option<&'_ str> { - Some(&self.0.topic) - } - - fn data(&self) -> &'_ [u8] { - &self.0.payload - } - - fn details(&self) -> &Details { - &Details::Complete - } - } - - pub struct MqttConnection(EventLoop, F, PhantomData M>); - - impl MqttConnection { - pub const fn new(event_loop: EventLoop, message_converter: F) -> Self { - Self(event_loop, message_converter, PhantomData) - } - } - - impl ErrorType for MqttConnection { - type Error = MqttError; - } - - impl Connection for MqttConnection - where - F: FnMut(&MessageRef) -> Option + Send, - M: Send, - { - type Message<'a> = M where Self: 'a; - - async fn next(&mut self) -> Option>, Self::Error>> { - loop { - let event = self.0.poll().await; - trace!("Got event: {:?}", event); - - match event { - Ok(event) => { - let event = match event { - rumqttc::Event::Incoming(incoming) => match incoming { - rumqttc::Packet::Connect(_) => Some(Event::BeforeConnect), - rumqttc::Packet::ConnAck(_) => Some(Event::Connected(true)), - rumqttc::Packet::Disconnect => Some(Event::Disconnected), - rumqttc::Packet::PubAck(PubAck { pkid, .. }) => { - Some(Event::Published(pkid as _)) - } - rumqttc::Packet::SubAck(SubAck { pkid, .. }) => { - Some(Event::Subscribed(pkid as _)) - } - rumqttc::Packet::UnsubAck(UnsubAck { pkid, .. }) => { - Some(Event::Unsubscribed(pkid as _)) - } - rumqttc::Packet::Publish(publish) => { - (self.1)(&MessageRef(&publish)).map(Event::Received) - } - _ => None, - }, - rumqttc::Event::Outgoing(_) => None, - }; - - if let Some(event) = event { - return Some(Ok(event)); - } - } - Err(err) => return Some(Err(MqttError::ConnectionError(err))), - } - } - } - } - - fn to_qos(qos: QoS) -> rumqttc::QoS { - match qos { - QoS::AtMostOnce => rumqttc::QoS::AtMostOnce, - QoS::AtLeastOnce => rumqttc::QoS::AtLeastOnce, - QoS::ExactlyOnce => rumqttc::QoS::ExactlyOnce, - } - } -} +#[cfg(feature = "std")] +pub mod io; diff --git a/edge-raw/Cargo.toml b/edge-raw/Cargo.toml new file mode 100644 index 0000000..2b182f9 --- /dev/null +++ b/edge-raw/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "edge-raw" +version = "0.1.0" +edition = "2021" + +[features] +nightly = ["embedded-io-async", "embedded-nal-async", "embedded-nal-async-xtra"] + +[dependencies] +log = { workspace = true } +no-std-net = { workspace = true } +embedded-io-async = { workspace = true, default-features = false, optional = true } +embedded-nal-async = { workspace = true, default-features = false, optional = true } +embedded-nal-async-xtra = { workspace = true, default-features = false, optional = true } diff --git a/edge-raw/src/bytes.rs b/edge-raw/src/bytes.rs new file mode 100644 index 0000000..f7abbf3 --- /dev/null +++ b/edge-raw/src/bytes.rs @@ -0,0 +1,103 @@ +#[derive(Debug)] +pub enum Error { + BufferOverflow, + DataUnderflow, + InvalidFormat, +} + +pub struct BytesIn<'a> { + data: &'a [u8], + offset: usize, +} + +impl<'a> BytesIn<'a> { + pub const fn new(data: &'a [u8]) -> Self { + Self { data, offset: 0 } + } + + pub fn is_empty(&self) -> bool { + self.offset == self.data.len() + } + + pub fn offset(&self) -> usize { + self.offset + } + + pub fn byte(&mut self) -> Result { + self.arr::<1>().map(|arr| arr[0]) + } + + pub fn slice(&mut self, len: usize) -> Result<&'a [u8], Error> { + if len > self.data.len() - self.offset { + Err(Error::DataUnderflow) + } else { + let data = &self.data[self.offset..self.offset + len]; + self.offset += len; + + Ok(data) + } + } + + pub fn arr(&mut self) -> Result<[u8; N], Error> { + let slice = self.slice(N)?; + + let mut data = [0; N]; + data.copy_from_slice(slice); + + Ok(data) + } + + pub fn remaining(&mut self) -> &'a [u8] { + let data = self.slice(self.data.len() - self.offset).unwrap(); + + self.offset = self.data.len(); + + data + } + + pub fn remaining_byte(&mut self) -> Result { + Ok(self.remaining_arr::<1>()?[0]) + } + + pub fn remaining_arr(&mut self) -> Result<[u8; N], Error> { + if self.data.len() - self.offset > N { + Err(Error::InvalidFormat) + } else { + self.arr::() + } + } +} + +pub struct BytesOut<'a> { + buf: &'a mut [u8], + offset: usize, +} + +impl<'a> BytesOut<'a> { + pub fn new(buf: &'a mut [u8]) -> Self { + Self { buf, offset: 0 } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn len(&self) -> usize { + self.offset + } + + pub fn byte(&mut self, data: u8) -> Result<&mut Self, Error> { + self.push(&[data]) + } + + pub fn push(&mut self, data: &[u8]) -> Result<&mut Self, Error> { + if data.len() > self.buf.len() - self.offset { + Err(Error::BufferOverflow) + } else { + self.buf[self.offset..self.offset + data.len()].copy_from_slice(data); + self.offset += data.len(); + + Ok(self) + } + } +} diff --git a/edge-raw/src/io.rs b/edge-raw/src/io.rs new file mode 100644 index 0000000..40b7397 --- /dev/null +++ b/edge-raw/src/io.rs @@ -0,0 +1,192 @@ +use core::fmt::Debug; + +use embedded_io_async::ErrorKind; + +use embedded_nal_async::{ConnectedUdp, SocketAddr, SocketAddrV4, UdpStack, UnconnectedUdp}; + +use embedded_nal_async_xtra::{RawSocket, RawStack}; + +use crate as raw; + +#[derive(Debug)] +pub enum Error { + Io(E), + UnsupportedProtocol, + RawError(raw::Error), +} + +impl From for Error { + fn from(value: raw::Error) -> Self { + Self::RawError(value) + } +} + +impl embedded_io_async::Error for Error +where + E: embedded_io_async::Error, +{ + fn kind(&self) -> ErrorKind { + match self { + Self::Io(err) => err.kind(), + Self::UnsupportedProtocol => ErrorKind::InvalidInput, + Self::RawError(_) => ErrorKind::InvalidData, + } + } +} + +pub struct ConnectedUdp2RawSocket(T, SocketAddrV4, SocketAddrV4); + +impl ConnectedUdp for ConnectedUdp2RawSocket +where + T: RawSocket, +{ + type Error = Error; + + async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error> { + send( + &mut self.0, + SocketAddr::V4(self.1), + SocketAddr::V4(self.2), + data, + ) + .await + } + + async fn receive_into(&mut self, buffer: &mut [u8]) -> Result { + let (len, _, _) = receive_into(&mut self.0, Some(self.1), Some(self.2), buffer).await?; + + Ok(len) + } +} + +pub struct UnconnectedUdp2RawSocket(T, Option); + +impl UnconnectedUdp for UnconnectedUdp2RawSocket +where + T: RawSocket, +{ + type Error = Error; + + async fn send( + &mut self, + local: SocketAddr, + remote: SocketAddr, + data: &[u8], + ) -> Result<(), Self::Error> { + send(&mut self.0, local, remote, data).await + } + + async fn receive_into( + &mut self, + buffer: &mut [u8], + ) -> Result<(usize, SocketAddr, SocketAddr), Self::Error> { + receive_into(&mut self.0, None, self.1, buffer).await + } +} + +pub struct Udp2RawStack(T, T::Interface) +where + T: RawStack; + +impl UdpStack for Udp2RawStack +where + T: RawStack, +{ + type Error = Error; + + type Connected = ConnectedUdp2RawSocket; + + type UniquelyBound = UnconnectedUdp2RawSocket; + + type MultiplyBound = UnconnectedUdp2RawSocket; + + async fn connect_from( + &self, + local: SocketAddr, + remote: SocketAddr, + ) -> Result<(SocketAddr, Self::Connected), Self::Error> { + let (SocketAddr::V4(localv4), SocketAddr::V4(remotev4)) = (local, remote) else { + Err(Error::UnsupportedProtocol)? + }; + + let socket = self.0.bind(&self.1).await.map_err(Self::Error::Io)?; + + Ok((local, ConnectedUdp2RawSocket(socket, localv4, remotev4))) + } + + async fn bind_single( + &self, + local: SocketAddr, + ) -> Result<(SocketAddr, Self::UniquelyBound), Self::Error> { + let SocketAddr::V4(localv4) = local else { + Err(Error::UnsupportedProtocol)? + }; + + let socket = self.0.bind(&self.1).await.map_err(Self::Error::Io)?; + + Ok((local, UnconnectedUdp2RawSocket(socket, Some(localv4)))) + } + + async fn bind_multiple(&self, local: SocketAddr) -> Result { + let SocketAddr::V4(local) = local else { + Err(Error::UnsupportedProtocol)? + }; + + let socket = self.0.bind(&self.1).await.map_err(Self::Error::Io)?; + + Ok(UnconnectedUdp2RawSocket(socket, Some(local))) + } +} + +async fn send( + mut socket: T, + local: SocketAddr, + remote: SocketAddr, + data: &[u8], +) -> Result<(), Error> { + let (SocketAddr::V4(local), SocketAddr::V4(remote)) = (local, remote) else { + Err(Error::UnsupportedProtocol)? + }; + + let mut buf = [0; 1500]; + + let data = raw::ip_udp_encode(&mut buf, local, remote, |buf| { + if data.len() <= buf.len() { + buf[..data.len()].copy_from_slice(data); + + Ok(data.len()) + } else { + Err(raw::Error::BufferOverflow) + } + })?; + + socket.send(data).await.map_err(Error::Io) +} + +async fn receive_into( + mut socket: T, + filter_src: Option, + filter_dst: Option, + buffer: &mut [u8], +) -> Result<(usize, SocketAddr, SocketAddr), Error> { + let mut buf = [0; 1500]; + + let (local, remote, len) = loop { + let len = socket.receive_into(&mut buf).await.map_err(Error::Io)?; + + match raw::ip_udp_decode(&buf[..len], filter_src, filter_dst) { + Ok(Some((local, remote, data))) => break (local, remote, data.len()), + Ok(None) => continue, + Err(raw::Error::InvalidFormat) | Err(raw::Error::InvalidChecksum) => continue, + Err(other) => Err(other)?, + } + }; + + if len <= buffer.len() { + buffer[..len].copy_from_slice(&buf[..len]); + + Ok((len, SocketAddr::V4(local), SocketAddr::V4(remote))) + } else { + Err(raw::Error::BufferOverflow.into()) + } +} diff --git a/edge-raw/src/ip.rs b/edge-raw/src/ip.rs new file mode 100644 index 0000000..b1c2852 --- /dev/null +++ b/edge-raw/src/ip.rs @@ -0,0 +1,214 @@ +use log::trace; + +use no_std_net::Ipv4Addr; + +use super::bytes::{BytesIn, BytesOut}; + +use super::{checksum_accumulate, checksum_finish, Error}; + +#[allow(clippy::type_complexity)] +pub fn decode( + packet: &[u8], + filter_src: Ipv4Addr, + filter_dst: Ipv4Addr, + filter_proto: Option, +) -> Result, Error> { + let data = Ipv4PacketHeader::decode_with_payload(packet, filter_src, filter_dst, filter_proto)? + .map(|(hdr, payload)| (hdr.src, hdr.dst, hdr.p, payload)); + + Ok(data) +} + +pub fn encode( + buf: &mut [u8], + src: Ipv4Addr, + dst: Ipv4Addr, + proto: u8, + encoder: F, +) -> Result<&[u8], Error> +where + F: FnOnce(&mut [u8]) -> Result, +{ + let mut hdr = Ipv4PacketHeader::new(src, dst, proto); + + hdr.encode_with_payload(buf, encoder) +} + +#[derive(Clone, Debug)] +pub struct Ipv4PacketHeader { + pub version: u8, // Version + pub hlen: u8, // Header length + pub tos: u8, // Type of service + pub len: u16, // Total length + pub id: u16, // Identification + pub off: u16, // Fragment offset field + pub ttl: u8, // Time to live + pub p: u8, // Protocol + pub sum: u16, // Checksum + pub src: Ipv4Addr, // Source address + pub dst: Ipv4Addr, // Dest address +} + +impl Ipv4PacketHeader { + pub const MIN_SIZE: usize = 20; + pub const CHECKSUM_WORD: usize = 5; + + pub const IP_DF: u16 = 0x4000; // Don't fragment flag + pub const IP_MF: u16 = 0x2000; // More fragments flag + + pub fn new(src: Ipv4Addr, dst: Ipv4Addr, proto: u8) -> Self { + Self { + version: 4, + hlen: Self::MIN_SIZE as _, + tos: 0, + len: Self::MIN_SIZE as _, + id: 0, + off: 0, + ttl: 64, + p: proto, + sum: 0, + src, + dst, + } + } + + /// Parses the packet from a byte slice + pub fn decode(data: &[u8]) -> Result { + let mut bytes = BytesIn::new(data); + + let vhl = bytes.byte()?; + + Ok(Self { + version: vhl >> 4, + hlen: (vhl & 0x0f) * 4, + tos: bytes.byte()?, + len: u16::from_be_bytes(bytes.arr()?), + id: u16::from_be_bytes(bytes.arr()?), + off: u16::from_be_bytes(bytes.arr()?), + ttl: bytes.byte()?, + p: bytes.byte()?, + sum: u16::from_be_bytes(bytes.arr()?), + src: u32::from_be_bytes(bytes.arr()?).into(), + dst: u32::from_be_bytes(bytes.arr()?).into(), + }) + } + + /// Encodes the packet into the provided buf slice + pub fn encode<'o>(&self, buf: &'o mut [u8]) -> Result<&'o [u8], Error> { + let mut bytes = BytesOut::new(buf); + + bytes + .byte((self.version << 4) | (self.hlen / 4 + (if self.hlen % 4 > 0 { 1 } else { 0 })))? + .byte(self.tos)? + .push(&u16::to_be_bytes(self.len))? + .push(&u16::to_be_bytes(self.id))? + .push(&u16::to_be_bytes(self.off))? + .byte(self.ttl)? + .byte(self.p)? + .push(&u16::to_be_bytes(self.sum))? + .push(&u32::to_be_bytes(self.src.into()))? + .push(&u32::to_be_bytes(self.dst.into()))?; + + let len = bytes.len(); + + Ok(&buf[..len]) + } + + pub fn encode_with_payload<'o, F>( + &mut self, + buf: &'o mut [u8], + encoder: F, + ) -> Result<&'o [u8], Error> + where + F: FnOnce(&mut [u8]) -> Result, + { + let hdr_len = self.hlen as usize; + if hdr_len < Self::MIN_SIZE || buf.len() < hdr_len { + Err(Error::BufferOverflow)?; + } + + let (hdr_buf, payload_buf) = buf.split_at_mut(hdr_len); + + let payload_len = encoder(payload_buf)?; + + let len = hdr_len + payload_len; + self.len = len as _; + + let min_hdr_len = self.encode(hdr_buf)?.len(); + assert_eq!(min_hdr_len, Self::MIN_SIZE); + + hdr_buf[Self::MIN_SIZE..hdr_len].fill(0); + + let checksum = Self::checksum(hdr_buf); + self.sum = checksum; + + Self::inject_checksum(hdr_buf, checksum); + + Ok(&buf[..len]) + } + + pub fn decode_with_payload( + packet: &[u8], + filter_src: Ipv4Addr, + filter_dst: Ipv4Addr, + filter_proto: Option, + ) -> Result, Error> { + let hdr = Self::decode(packet)?; + if hdr.version == 4 { + // IPv4 + + if !filter_src.is_unspecified() && !hdr.src.is_broadcast() && filter_src != hdr.src { + return Ok(None); + } + + if !filter_dst.is_unspecified() && !hdr.dst.is_broadcast() && filter_dst != hdr.dst { + return Ok(None); + } + + if let Some(filter_proto) = filter_proto { + if filter_proto != hdr.p { + return Ok(None); + } + } + + let len = hdr.len as usize; + if packet.len() < len { + Err(Error::DataUnderflow)?; + } + + let checksum = Self::checksum(&packet[..len]); + + trace!("IP header decoded, total_size={}, src={}, dst={}, hlen={}, size={}, checksum={}, ours={}", packet.len(), hdr.src, hdr.dst, hdr.hlen, hdr.len, hdr.sum, checksum); + + if checksum != hdr.sum { + Err(Error::InvalidChecksum)?; + } + + let packet = &packet[..len]; + let hdr_len = hdr.hlen as usize; + if packet.len() < hdr_len { + Err(Error::DataUnderflow)?; + } + + Ok(Some((hdr, &packet[hdr_len..]))) + } else { + Err(Error::InvalidFormat) + } + } + + pub fn inject_checksum(packet: &mut [u8], checksum: u16) { + let checksum = checksum.to_be_bytes(); + + let offset = Self::CHECKSUM_WORD << 1; + packet[offset] = checksum[0]; + packet[offset + 1] = checksum[1]; + } + + pub fn checksum(packet: &[u8]) -> u16 { + let hlen = (packet[0] & 0x0f) as usize * 4; + + let sum = checksum_accumulate(&packet[..hlen], Self::CHECKSUM_WORD); + + checksum_finish(sum) + } +} diff --git a/edge-raw/src/lib.rs b/edge-raw/src/lib.rs new file mode 100644 index 0000000..20ff09e --- /dev/null +++ b/edge-raw/src/lib.rs @@ -0,0 +1,102 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![allow(stable_features)] +#![allow(unknown_lints)] +#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", allow(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", feature(impl_trait_projections))] + +use no_std_net::{Ipv4Addr, SocketAddrV4}; + +use self::udp::UdpPacketHeader; + +#[cfg(feature = "nightly")] +pub mod io; + +pub mod bytes; +pub mod ip; +pub mod udp; + +use bytes::BytesIn; + +#[derive(Debug)] +pub enum Error { + DataUnderflow, + BufferOverflow, + InvalidFormat, + InvalidChecksum, +} + +impl From for Error { + fn from(value: bytes::Error) -> Self { + match value { + bytes::Error::BufferOverflow => Self::BufferOverflow, + bytes::Error::DataUnderflow => Self::DataUnderflow, + bytes::Error::InvalidFormat => Self::InvalidFormat, + } + } +} + +#[allow(clippy::type_complexity)] +pub fn ip_udp_decode( + packet: &[u8], + filter_src: Option, + filter_dst: Option, +) -> Result, Error> { + if let Some((src, dst, _proto, udp_packet)) = ip::decode( + packet, + filter_src.map(|a| *a.ip()).unwrap_or(Ipv4Addr::UNSPECIFIED), + filter_dst.map(|a| *a.ip()).unwrap_or(Ipv4Addr::UNSPECIFIED), + Some(UdpPacketHeader::PROTO), + )? { + udp::decode( + src, + dst, + udp_packet, + filter_src.map(|a| a.port()), + filter_dst.map(|a| a.port()), + ) + } else { + Ok(None) + } +} + +pub fn ip_udp_encode( + buf: &mut [u8], + src: SocketAddrV4, + dst: SocketAddrV4, + encoder: F, +) -> Result<&[u8], Error> +where + F: FnOnce(&mut [u8]) -> Result, +{ + ip::encode(buf, *src.ip(), *dst.ip(), UdpPacketHeader::PROTO, |buf| { + Ok(udp::encode(buf, src, dst, encoder)?.len()) + }) +} + +pub fn checksum_accumulate(bytes: &[u8], checksum_word: usize) -> u32 { + let mut bytes = BytesIn::new(bytes); + + let mut sum: u32 = 0; + while !bytes.is_empty() { + let skip = (bytes.offset() >> 1) == checksum_word; + let arr = bytes + .arr() + .ok() + .unwrap_or_else(|| [bytes.byte().unwrap(), 0]); + + let word = if skip { 0 } else { u16::from_be_bytes(arr) }; + + sum += word as u32; + } + + sum +} + +pub fn checksum_finish(mut sum: u32) -> u16 { + while sum >> 16 != 0 { + sum = (sum >> 16) + (sum & 0xffff); + } + + !sum as u16 +} diff --git a/edge-raw/src/udp.rs b/edge-raw/src/udp.rs new file mode 100644 index 0000000..cbf7e08 --- /dev/null +++ b/edge-raw/src/udp.rs @@ -0,0 +1,206 @@ +use log::trace; + +use no_std_net::{Ipv4Addr, SocketAddrV4}; + +use super::bytes::{BytesIn, BytesOut}; + +use super::{checksum_accumulate, checksum_finish, Error}; + +#[allow(clippy::type_complexity)] +pub fn decode( + src: Ipv4Addr, + dst: Ipv4Addr, + packet: &[u8], + filter_src: Option, + filter_dst: Option, +) -> Result, Error> { + let data = UdpPacketHeader::decode_with_payload(packet, src, dst, filter_src, filter_dst)?.map( + |(hdr, payload)| { + ( + SocketAddrV4::new(src, hdr.src), + SocketAddrV4::new(dst, hdr.dst), + payload, + ) + }, + ); + + Ok(data) +} + +pub fn encode( + buf: &mut [u8], + src: SocketAddrV4, + dst: SocketAddrV4, + payload: F, +) -> Result<&[u8], Error> +where + F: FnOnce(&mut [u8]) -> Result, +{ + let mut hdr = UdpPacketHeader::new(src.port(), dst.port()); + + hdr.encode_with_payload(buf, *src.ip(), *dst.ip(), |buf| payload(buf)) +} + +#[derive(Clone, Debug)] +pub struct UdpPacketHeader { + pub src: u16, // Source port + pub dst: u16, // Destination port + pub len: u16, // UDP length + pub sum: u16, // UDP checksum +} + +impl UdpPacketHeader { + pub const PROTO: u8 = 17; + + pub const SIZE: usize = 8; + pub const CHECKSUM_WORD: usize = 3; + + pub fn new(src: u16, dst: u16) -> Self { + Self { + src, + dst, + len: 0, + sum: 0, + } + } + + /// Parses the packet header from a byte slice + pub fn decode(data: &[u8]) -> Result { + let mut bytes = BytesIn::new(data); + + Ok(Self { + src: u16::from_be_bytes(bytes.arr()?), + dst: u16::from_be_bytes(bytes.arr()?), + len: u16::from_be_bytes(bytes.arr()?), + sum: u16::from_be_bytes(bytes.arr()?), + }) + } + + /// Encodes the packet header into the provided buf slice + pub fn encode<'o>(&self, buf: &'o mut [u8]) -> Result<&'o [u8], Error> { + let mut bytes = BytesOut::new(buf); + + bytes + .push(&u16::to_be_bytes(self.src))? + .push(&u16::to_be_bytes(self.dst))? + .push(&u16::to_be_bytes(self.len))? + .push(&u16::to_be_bytes(self.sum))?; + + let len = bytes.len(); + + Ok(&buf[..len]) + } + + pub fn encode_with_payload<'o, F>( + &mut self, + buf: &'o mut [u8], + src: Ipv4Addr, + dst: Ipv4Addr, + encoder: F, + ) -> Result<&'o [u8], Error> + where + F: FnOnce(&mut [u8]) -> Result, + { + if buf.len() < Self::SIZE { + Err(Error::BufferOverflow)?; + } + + let (hdr_buf, payload_buf) = buf.split_at_mut(Self::SIZE); + + let payload_len = encoder(payload_buf)?; + + let len = Self::SIZE + payload_len; + self.len = len as _; + + let hdr_len = self.encode(hdr_buf)?.len(); + assert_eq!(Self::SIZE, hdr_len); + + let packet = &mut buf[..len]; + + let checksum = Self::checksum(packet, src, dst); + self.sum = checksum; + + Self::inject_checksum(packet, checksum); + + Ok(packet) + } + + pub fn decode_with_payload( + packet: &[u8], + src: Ipv4Addr, + dst: Ipv4Addr, + filter_src: Option, + filter_dst: Option, + ) -> Result, Error> { + let hdr = Self::decode(packet)?; + + if let Some(filter_src) = filter_src { + if filter_src != hdr.src { + return Ok(None); + } + } + + if let Some(filter_dst) = filter_dst { + if filter_dst != hdr.dst { + return Ok(None); + } + } + + let len = hdr.len as usize; + if packet.len() < len { + Err(Error::DataUnderflow)?; + } + + let checksum = Self::checksum(&packet[..len], src, dst); + + trace!( + "UDP header decoded, src={}, dst={}, size={}, checksum={}, ours={}", + hdr.src, + hdr.dst, + hdr.len, + hdr.sum, + checksum + ); + + if checksum != hdr.sum { + Err(Error::InvalidChecksum)?; + } + + let packet = &packet[..len]; + + let payload_data = &packet[Self::SIZE..]; + + Ok(Some((hdr, payload_data))) + } + + pub fn inject_checksum(packet: &mut [u8], checksum: u16) { + let checksum = checksum.to_be_bytes(); + + let offset = Self::CHECKSUM_WORD << 1; + packet[offset] = checksum[0]; + packet[offset + 1] = checksum[1]; + } + + pub fn checksum(packet: &[u8], src: Ipv4Addr, dst: Ipv4Addr) -> u16 { + let mut buf = [0; 12]; + + // Pseudo IP-header for UDP checksum calculation + let len = BytesOut::new(&mut buf) + .push(&u32::to_be_bytes(src.into())) + .unwrap() + .push(&u32::to_be_bytes(dst.into())) + .unwrap() + .byte(0) + .unwrap() + .byte(UdpPacketHeader::PROTO) + .unwrap() + .push(&u16::to_be_bytes(packet.len() as u16)) + .unwrap() + .len(); + + let sum = checksum_accumulate(&buf[..len], usize::MAX) + + checksum_accumulate(packet, Self::CHECKSUM_WORD); + + checksum_finish(sum) + } +} diff --git a/edge-std-nal-async/Cargo.toml b/edge-std-nal-async/Cargo.toml new file mode 100644 index 0000000..b9873ff --- /dev/null +++ b/edge-std-nal-async/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "edge-std-nal-async" +version = "0.1.0" +edition = "2021" + +[features] +std = ["async-io", "futures-lite", "libc"] +nightly = ["embedded-nal-async-xtra/nightly", "embedded-io-async", "embedded-nal-async", "embedded-nal-async-xtra"] + +[dependencies] +embedded-io-async = { workspace = true, features = ["std"], optional = true } +embedded-nal-async = { workspace = true, optional = true } +embedded-nal-async-xtra = { workspace = true, optional = true } +async-io = { version = "2", optional = true } +futures-lite = { version = "1", optional = true } +libc = { version = "0.2", optional = true } +heapless = { workspace = true } diff --git a/src/std/nal.rs b/edge-std-nal-async/src/lib.rs similarity index 90% rename from src/std/nal.rs rename to edge-std-nal-async/src/lib.rs index 4c969b8..6f78603 100644 --- a/src/std/nal.rs +++ b/edge-std-nal-async/src/lib.rs @@ -1,3 +1,10 @@ +#![allow(stable_features)] +#![allow(unknown_lints)] +#![allow(async_fn_in_trait)] +#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", feature(impl_trait_projections))] +#![cfg(all(feature = "nightly", feature = "std"))] + use std::io; use std::net::{self, TcpStream, ToSocketAddrs, UdpSocket}; use std::os::fd::{AsFd, AsRawFd}; @@ -6,13 +13,13 @@ use async_io::Async; use futures_lite::io::{AsyncReadExt, AsyncWriteExt}; use embedded_io_async::{ErrorType, Read, Write}; -use no_std_net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use embedded_nal_async::{ - AddrType, ConnectedUdp, Dns, IpAddr, TcpConnect, UdpStack, UnconnectedUdp, + AddrType, ConnectedUdp, Dns, IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6, + TcpConnect, UdpStack, UnconnectedUdp, }; -use edge_tcp::{RawSocket, RawStack, TcpAccept, TcpListen, TcpSplittableConnection}; +use embedded_nal_async_xtra::{RawSocket, RawStack, TcpAccept, TcpListen, TcpSplittableConnection}; pub struct StdTcpConnect(()); @@ -173,33 +180,7 @@ impl UdpStack for StdUdpStack { } async fn bind_multiple(&self, _local: SocketAddr) -> Result { - unimplemented!() - } -} - -fn cvt(res: T) -> io::Result -where - T: Into + Copy, -{ - let ires: i64 = res.into(); - - if ires == -1 { - Err(io::Error::last_os_error()) - } else { - Ok(res) - } -} - -fn cvti(res: T) -> io::Result -where - T: Into + Copy, -{ - let ires: isize = res.into(); - - if ires == -1 { - Err(io::Error::last_os_error()) - } else { - Ok(res) + unimplemented!() // TODO } } @@ -261,7 +242,9 @@ impl RawStack for StdRawStack { type Socket = StdRawSocket; - async fn connect(&self, interface: Option) -> Result { + type Interface = u32; + + async fn bind(&self, interface: &Self::Interface) -> Result { let socket = unsafe { libc::socket( libc::PF_PACKET, @@ -275,7 +258,7 @@ impl RawStack for StdRawStack { let sockaddr = libc::sockaddr_ll { sll_family: libc::AF_PACKET as _, sll_protocol: (libc::ETH_P_IP as u16).to_be() as _, - sll_ifindex: interface.unwrap_or(0) as _, + sll_ifindex: *interface as _, sll_hatype: 0, sll_pkttype: 0, sll_halen: 0, @@ -301,10 +284,7 @@ impl RawStack for StdRawStack { unsafe { std::net::UdpSocket::from_raw_fd(socket) } }; - Ok(StdRawSocket( - Async::new(socket)?, - interface.unwrap_or(0) as _, - )) + Ok(StdRawSocket(Async::new(socket)?, *interface as _)) // warn!("Before connect"); // let (addr, socket) = self.connect_from(local, remote).await?; @@ -341,7 +321,7 @@ impl ConnectedUdp for StdUdpSocket { loop { offset += self.0.send(&data[offset..]).await?; - if offset == 0 { + if offset == data.len() { break; } } @@ -370,7 +350,7 @@ impl UnconnectedUdp for StdUdpSocket { loop { offset += self.0.send_to(data, to_std_addr(remote)).await?; - if offset == 0 { + if offset == data.len() { break; } } @@ -392,18 +372,15 @@ impl UnconnectedUdp for StdUdpSocket { } } -pub struct StdDns(U); +pub struct StdDns(()); -impl StdDns { - pub const fn new(unblocker: U) -> Self { - Self(unblocker) +impl StdDns { + pub const fn new() -> Self { + Self(()) } } -impl Dns for StdDns -where - U: crate::asynch::Unblocker, -{ +impl Dns for StdDns { type Error = io::Error; async fn get_host_by_name( @@ -413,28 +390,7 @@ where ) -> Result { let host = host.to_string(); - self.0 - .unblock(move || dns_lookup_host(&host, addr_type)) - .await - } - - async fn get_host_by_address( - &self, - _addr: IpAddr, - ) -> Result, Self::Error> { - Err(io::ErrorKind::Unsupported.into()) - } -} - -impl Dns for StdDns<()> { - type Error = io::Error; - - async fn get_host_by_name( - &self, - host: &str, - addr_type: AddrType, - ) -> Result { - dns_lookup_host(host, addr_type) + dns_lookup_host(&host, addr_type) } async fn get_host_by_address( @@ -496,3 +452,29 @@ pub fn to_std_ipv4_addr(addr: Ipv4Addr) -> std::net::Ipv4Addr { pub fn to_nal_ipv4_addr(addr: std::net::Ipv4Addr) -> Ipv4Addr { addr.octets().into() } + +fn cvt(res: T) -> io::Result +where + T: Into + Copy, +{ + let ires: i64 = res.into(); + + if ires == -1 { + Err(io::Error::last_os_error()) + } else { + Ok(res) + } +} + +fn cvti(res: T) -> io::Result +where + T: Into + Copy, +{ + let ires: isize = res.into(); + + if ires == -1 { + Err(io::Error::last_os_error()) + } else { + Ok(res) + } +} diff --git a/edge-tcp/Cargo.toml b/edge-tcp/Cargo.toml deleted file mode 100644 index a76ddf8..0000000 --- a/edge-tcp/Cargo.toml +++ /dev/null @@ -1,19 +0,0 @@ -[package] -name = "edge-tcp" -version = "0.1.0" -authors = ["Ivan Markov "] -edition = "2021" -categories = ["embedded", "hardware-support"] -keywords = ["embedded", "svc", "network"] -description = "TCP traits for edge-net" -repository = "https://github.com/ivmarkov/edge-net" -license = "MIT OR Apache-2.0" -readme = "README.md" -rust-version = "1.71" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -embedded-io.workspace = true -embedded-io-async.workspace = true -no-std-net.workspace = true \ No newline at end of file diff --git a/edge-tcp/README.md b/edge-tcp/README.md deleted file mode 100644 index 7cef1ed..0000000 --- a/edge-tcp/README.md +++ /dev/null @@ -1 +0,0 @@ -# TCP traits for `edge-net` diff --git a/edge-tcp/src/lib.rs b/edge-tcp/src/lib.rs deleted file mode 100644 index 3f9538d..0000000 --- a/edge-tcp/src/lib.rs +++ /dev/null @@ -1,204 +0,0 @@ -#![no_std] -#![allow(stable_features)] -#![allow(unknown_lints)] -#![feature(async_fn_in_trait)] -#![allow(async_fn_in_trait)] -#![feature(impl_trait_projections)] -#![feature(impl_trait_in_assoc_type)] - -use core::fmt::Debug; -use no_std_net::SocketAddr; - -pub trait TcpSplittableConnection { - type Error: embedded_io::Error; - - type Read<'a>: embedded_io_async::Read - where - Self: 'a; - type Write<'a>: embedded_io_async::Write - where - Self: 'a; - - async fn split(&mut self) -> Result<(Self::Read<'_>, Self::Write<'_>), Self::Error>; -} - -impl<'t, T> TcpSplittableConnection for &'t mut T -where - T: TcpSplittableConnection + 't, -{ - type Error = T::Error; - - type Read<'a> = T::Read<'a> where Self: 'a; - - type Write<'a> = T::Write<'a> where Self: 'a; - - async fn split(&mut self) -> Result<(Self::Read<'_>, Self::Write<'_>), Self::Error> { - (**self).split().await - } -} - -pub trait TcpListen { - type Error: embedded_io::Error; - - type Acceptor<'m>: TcpAccept - where - Self: 'm; - - async fn listen(&self, remote: SocketAddr) -> Result, Self::Error>; -} - -impl TcpListen for &T -where - T: TcpListen, -{ - type Error = T::Error; - - type Acceptor<'m> = T::Acceptor<'m> - where Self: 'm; - - async fn listen(&self, remote: SocketAddr) -> Result, Self::Error> { - (*self).listen(remote).await - } -} - -impl TcpListen for &mut T -where - T: TcpListen, -{ - type Error = T::Error; - - type Acceptor<'m> = T::Acceptor<'m> - where Self: 'm; - - async fn listen(&self, remote: SocketAddr) -> Result, Self::Error> { - (**self).listen(remote).await - } -} - -pub trait TcpAccept { - type Error: embedded_io::Error; - - type Connection<'m>: embedded_io_async::Read - + embedded_io_async::Write - where - Self: 'm; - - async fn accept(&self) -> Result, Self::Error>; -} - -impl TcpAccept for &T -where - T: TcpAccept, -{ - type Error = T::Error; - - type Connection<'m> = T::Connection<'m> - where Self: 'm; - - async fn accept(&self) -> Result, Self::Error> { - (**self).accept().await - } -} - -impl TcpAccept for &mut T -where - T: TcpAccept, -{ - type Error = T::Error; - - type Connection<'m> = T::Connection<'m> - where Self: 'm; - - async fn accept(&self) -> Result, Self::Error> { - (**self).accept().await - } -} - -// TODO: Ideally should go to `embedded-nal-async` -pub trait RawSocket { - type Error: Debug + embedded_io_async::Error; - - async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error>; - async fn receive_into(&mut self, buffer: &mut [u8]) -> Result; -} - -// TODO: Ideally should go to `embedded-nal-async` -impl RawSocket for &mut T -where - T: RawSocket, -{ - type Error = T::Error; - - async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error> { - (**self).send(data).await - } - - async fn receive_into(&mut self, buffer: &mut [u8]) -> Result { - (**self).receive_into(buffer).await - } -} - -// TODO: Ideally should go to `embedded-nal-async` -pub trait RawStack { - type Error: Debug; - - type Socket: RawSocket; - - async fn connect(&self, interface: Option) -> Result; -} - -// TODO: Ideally should go to `embedded-nal-async` -impl RawStack for &T -where - T: RawStack, -{ - type Error = T::Error; - - type Socket = T::Socket; - - async fn connect(&self, interface: Option) -> Result { - (*self).connect(interface).await - } -} - -impl RawStack for &mut T -where - T: RawStack, -{ - type Error = T::Error; - - type Socket = T::Socket; - - async fn connect(&self, interface: Option) -> Result { - (**self).connect(interface).await - } -} - -// pub struct IO(pub T); - -// impl ErrorType for IO -// where -// T: RawSocket, -// { -// type Error = T::Error; -// } - -// impl Read for IO -// where -// T: RawSocket, -// { -// async fn read(&mut self, buf: &mut [u8]) -> Result { -// self.0.receive_into(buf).await -// } -// } - -// impl Write for IO -// where -// T: RawSocket, -// { -// async fn write(&mut self, buf: &[u8]) -> Result { -// self.0.send(buf).await?; - -// Ok(buf.len()) -// } -// } diff --git a/edge-ws/Cargo.toml b/edge-ws/Cargo.toml index dde441c..295e204 100644 --- a/edge-ws/Cargo.toml +++ b/edge-ws/Cargo.toml @@ -3,12 +3,9 @@ name = "edge-ws" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +nightly = ["embedded-io-async", "embedded-svc?/nightly"] [dependencies] -edge-http.workspace = true -embedded-io-async.workspace = true -embedded-nal-async.workspace = true - -base64 = { version = "0.13", default-features = false } -sha1_smol = { version = "1", default-features = false } \ No newline at end of file +embedded-io-async = { workspace = true, optional = true } +embedded-svc = { workspace = true, optional = true, default-features = false } diff --git a/edge-ws/src/io.rs b/edge-ws/src/io.rs new file mode 100644 index 0000000..858b6d3 --- /dev/null +++ b/edge-ws/src/io.rs @@ -0,0 +1,210 @@ +use core::cmp::min; + +use embedded_io_async::{self, Read, ReadExactError, Write}; + +use super::*; + +#[cfg(feature = "embedded-svc")] +pub use embedded_svc_compat::*; + +impl From> for Error { + fn from(e: ReadExactError) -> Self { + match e { + ReadExactError::UnexpectedEof => Error::Invalid, + ReadExactError::Other(e) => Error::Io(e), + } + } +} + +impl FrameHeader { + pub async fn recv(mut read: R) -> Result> + where + R: Read, + { + let mut header_buf = [0; FrameHeader::MAX_LEN]; + let mut read_offset = 0; + let mut read_end = FrameHeader::MIN_LEN; + + loop { + read.read_exact(&mut header_buf[read_offset..read_end]) + .await + .map_err(Error::from)?; + + match FrameHeader::deserialize(&header_buf[..read_end]) { + Ok((header, _)) => return Ok(header), + Err(Error::Incomplete(more)) => { + read_offset = read_end; + read_end += more; + } + Err(e) => return Err(e.recast()), + } + } + } + + pub async fn send(&self, mut write: W) -> Result<(), Error> + where + W: Write, + { + let mut header_buf = [0; FrameHeader::MAX_LEN]; + let header_len = self.serialize(&mut header_buf).unwrap(); + + write + .write_all(&header_buf[..header_len]) + .await + .map_err(Error::Io) + } + + pub async fn recv_payload<'a, R>( + &'a self, + mut read: R, + payload_buf: &'a mut [u8], + ) -> Result<(), Error> + where + R: Read, + { + if (payload_buf.len() as u64) < self.payload_len { + Err(Error::BufferOverflow) + } else if self.payload_len == 0 { + Ok(()) + } else { + let payload = &mut payload_buf[..self.payload_len as _]; + + read.read_exact(payload).await.map_err(Error::from)?; + + self.mask(payload, 0); + + Ok(()) + } + } + + pub async fn send_payload<'a, W>( + &'a self, + mut write: W, + payload: &'a [u8], + ) -> Result<(), Error> + where + W: Write, + { + let payload_buf_len = payload.len() as u64; + + if payload_buf_len != self.payload_len { + Err(Error::InvalidLen) + } else if payload.is_empty() { + Ok(()) + } else if self.mask_key.is_none() { + write.write_all(payload).await.map_err(Error::Io) + } else { + let mut buf = [0_u8; 64]; + + let mut offset = 0; + + while offset < payload.len() { + let len = min(buf.len(), payload.len() - offset); + + buf[..len].copy_from_slice(&payload[offset..offset + len]); + + self.mask(&mut buf, offset); + + write.write_all(&buf).await.map_err(Error::Io)?; + + offset += len; + } + + Ok(()) + } + } +} + +pub async fn recv( + mut read: R, + frame_data_buf: &mut [u8], +) -> Result<(FrameType, usize), Error> +where + R: Read, +{ + let header = FrameHeader::recv(&mut read).await?; + header.recv_payload(read, frame_data_buf).await?; + + Ok((header.frame_type, header.payload_len as _)) +} + +pub async fn send( + mut write: W, + frame_type: FrameType, + mask_key: Option, + frame_data_buf: &[u8], +) -> Result<(), Error> +where + W: Write, +{ + let header = FrameHeader { + frame_type, + payload_len: frame_data_buf.len() as _, + mask_key, + }; + + header.send(&mut write).await?; + header.send_payload(write, frame_data_buf).await +} + +#[cfg(feature = "embedded-svc")] +mod embedded_svc_compat { + use core::convert::TryInto; + + use embedded_io_async::{Read, Write}; + use embedded_svc::io::ErrorType as IoErrorType; + use embedded_svc::ws::asynch::Sender; + use embedded_svc::ws::ErrorType; + use embedded_svc::ws::{asynch::Receiver, FrameType}; + + use super::Error; + + pub struct WsConnection(T, M); + + impl WsConnection { + pub const fn new(connection: T, mask_gen: M) -> Self { + Self(connection, mask_gen) + } + } + + impl ErrorType for WsConnection + where + T: IoErrorType, + { + type Error = Error; + } + + impl Receiver for WsConnection + where + T: Read, + { + async fn recv( + &mut self, + frame_data_buf: &mut [u8], + ) -> Result<(FrameType, usize), Self::Error> { + super::recv(&mut self.0, frame_data_buf) + .await + .map(|(frame_type, payload_len)| (frame_type.into(), payload_len)) + } + } + + impl Sender for WsConnection + where + T: Write, + M: Fn() -> Option, + { + async fn send( + &mut self, + frame_type: FrameType, + frame_data: &[u8], + ) -> Result<(), Self::Error> { + super::send( + &mut self.0, + frame_type.try_into().unwrap(), + (self.1)(), + frame_data, + ) + .await + } + } +} diff --git a/edge-ws/src/lib.rs b/edge-ws/src/lib.rs index dbeb6d9..45d0ba9 100644 --- a/edge-ws/src/lib.rs +++ b/edge-ws/src/lib.rs @@ -1,12 +1,20 @@ #![cfg_attr(not(feature = "std"), no_std)] - -use core::cmp::min; - -use embedded_io_async::{Read, ReadExactError, Write}; +#![allow(stable_features)] +#![allow(unknown_lints)] +#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", allow(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", feature(impl_trait_projections))] pub type Fragmented = bool; pub type Final = bool; +#[allow(unused)] +#[cfg(feature = "embedded-svc")] +pub use embedded_svc_compat::*; + +#[cfg(feature = "nightly")] +pub mod io; + #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum FrameType { Text(Fragmented), @@ -56,15 +64,6 @@ impl Error<()> { } } -impl From> for Error { - fn from(e: ReadExactError) -> Self { - match e { - ReadExactError::UnexpectedEof => Error::Invalid, - ReadExactError::Other(e) => Error::Io(e), - } - } -} - #[derive(Clone, Debug)] pub struct FrameHeader { pub frame_type: FrameType, @@ -249,397 +248,13 @@ impl FrameHeader { } } } - - pub async fn recv(mut read: R) -> Result> - where - R: Read, - { - let mut header_buf = [0; FrameHeader::MAX_LEN]; - let mut read_offset = 0; - let mut read_end = FrameHeader::MIN_LEN; - - loop { - read.read_exact(&mut header_buf[read_offset..read_end]) - .await - .map_err(Error::from)?; - - match FrameHeader::deserialize(&header_buf[..read_end]) { - Ok((header, _)) => return Ok(header), - Err(Error::Incomplete(more)) => { - read_offset = read_end; - read_end += more; - } - Err(e) => return Err(e.recast()), - } - } - } - - pub async fn send(&self, mut write: W) -> Result<(), Error> - where - W: Write, - { - let mut header_buf = [0; FrameHeader::MAX_LEN]; - let header_len = self.serialize(&mut header_buf).unwrap(); - - write - .write_all(&header_buf[..header_len]) - .await - .map_err(Error::Io) - } - - pub async fn recv_payload<'a, R>( - &'a self, - mut read: R, - payload_buf: &'a mut [u8], - ) -> Result<(), Error> - where - R: Read, - { - if (payload_buf.len() as u64) < self.payload_len { - Err(Error::BufferOverflow) - } else if self.payload_len == 0 { - Ok(()) - } else { - let payload = &mut payload_buf[..self.payload_len as _]; - - read.read_exact(payload).await.map_err(Error::from)?; - - self.mask(payload, 0); - - Ok(()) - } - } - - pub async fn send_payload<'a, W>( - &'a self, - mut write: W, - payload: &'a [u8], - ) -> Result<(), Error> - where - W: Write, - { - let payload_buf_len = payload.len() as u64; - - if payload_buf_len != self.payload_len { - Err(Error::InvalidLen) - } else if payload.is_empty() { - Ok(()) - } else if self.mask_key.is_none() { - write.write_all(payload).await.map_err(Error::Io) - } else { - let mut buf = [0_u8; 64]; - - let mut offset = 0; - - while offset < payload.len() { - let len = min(buf.len(), payload.len() - offset); - - buf[..len].copy_from_slice(&payload[offset..offset + len]); - - self.mask(&mut buf, offset); - - write.write_all(&buf).await.map_err(Error::Io)?; - - offset += len; - } - - Ok(()) - } - } -} - -pub async fn recv( - mut read: R, - frame_data_buf: &mut [u8], -) -> Result<(FrameType, usize), Error> -where - R: Read, -{ - let header = FrameHeader::recv(&mut read).await?; - header.recv_payload(read, frame_data_buf).await?; - - Ok((header.frame_type, header.payload_len as _)) } -pub async fn send( - mut write: W, - frame_type: FrameType, - mask_key: Option, - frame_data_buf: &[u8], -) -> Result<(), Error> -where - W: Write, -{ - let header = FrameHeader { - frame_type, - payload_len: frame_data_buf.len() as _, - mask_key, - }; - - header.send(&mut write).await?; - header.send_payload(write, frame_data_buf).await -} - -pub mod http { - use edge_http::Headers; - - pub const NONCE_LEN: usize = 16; - pub const MAX_BASE64_KEY_LEN: usize = 28; - pub const MAX_BASE64_KEY_RESPONSE_LEN: usize = 33; - - pub const UPGRADE_REQUEST_HEADERS_LEN: usize = 7; - pub const UPGRADE_RESPONSE_HEADERS_LEN: usize = 3; - - pub fn upgrade_request_headers<'a>( - host: Option<&'a str>, - origin: Option<&'a str>, - version: Option<&'a str>, - nonce: &[u8; NONCE_LEN], - nonce_base64_buf: &'a mut [u8; MAX_BASE64_KEY_LEN], - ) -> [(&'a str, &'a str); UPGRADE_REQUEST_HEADERS_LEN] { - let nonce_base64_len = - base64::encode_config_slice(nonce, base64::URL_SAFE, nonce_base64_buf); - - let host = host.map(|host| ("Host", host)).unwrap_or(("", "")); - let origin = origin.map(|origin| ("Origin", origin)).unwrap_or(("", "")); - - [ - host, - origin, - ("Content-Length", "0"), - ("Connection", "Upgrade"), - ("Upgrade", "websocket"), - ("Sec-WebSocket-Version", version.unwrap_or("13")), - ("Sec-WebSocket-Key", unsafe { - core::str::from_utf8_unchecked(&nonce_base64_buf[..nonce_base64_len]) - }), - ] - } - - pub fn is_upgrade_request<'a, H>(request_headers: H) -> bool - where - H: IntoIterator, - { - let mut connection = false; - let mut upgrade = false; - - for (name, value) in request_headers { - if name.eq_ignore_ascii_case("Connection") { - connection = value.eq_ignore_ascii_case("Upgrade"); - } else if name.eq_ignore_ascii_case("Upgrade") { - upgrade = value.eq_ignore_ascii_case("websocket"); - } - } - - connection && upgrade - } - - #[derive(Debug, Copy, Clone, Eq, PartialEq)] - pub enum UpgradeError { - NoVersion, - NoSecKey, - UnsupportedVersion, - SecKeyTooLong, - } - - pub fn upgrade_response_headers<'a, 'b, H>( - request_headers: H, - version: Option<&'a str>, - sec_key_response_base64_buf: &'b mut [u8; MAX_BASE64_KEY_RESPONSE_LEN], - ) -> Result<[(&'b str, &'b str); UPGRADE_RESPONSE_HEADERS_LEN], UpgradeError> - where - H: IntoIterator, - { - let mut version_ok = false; - let mut sec_key = None; - - for (name, value) in request_headers { - if name.eq_ignore_ascii_case("Sec-WebSocket-Version") { - if !value.eq_ignore_ascii_case(version.unwrap_or("13")) { - return Err(UpgradeError::NoVersion); - } - - version_ok = true; - } else if name.eq_ignore_ascii_case("Sec-WebSocket-Key") { - const WS_MAGIC_GUUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - - let mut buf = [0_u8; MAX_BASE64_KEY_LEN + WS_MAGIC_GUUID.len()]; - - let value_len = value.as_bytes().len(); - - if value_len > MAX_BASE64_KEY_LEN { - return Err(UpgradeError::SecKeyTooLong); - } - - buf[..value_len].copy_from_slice(value.as_bytes()); - buf[value_len..value_len + WS_MAGIC_GUUID.as_bytes().len()] - .copy_from_slice(WS_MAGIC_GUUID.as_bytes()); - - let mut sha1 = sha1_smol::Sha1::new(); - - sha1.update(&buf[..value_len + WS_MAGIC_GUUID.as_bytes().len()]); - - let sec_key_len = base64::encode_config_slice( - sha1.digest().bytes(), - base64::STANDARD_NO_PAD, - sec_key_response_base64_buf, - ); - - sec_key = Some(sec_key_len); - } - } - - if version_ok { - if let Some(sec_key_len) = sec_key { - Ok([ - ("Connection", "Upgrade"), - ("Upgrade", "websocket"), - ("Sec-WebSocket-Accept", unsafe { - core::str::from_utf8_unchecked(&sec_key_response_base64_buf[..sec_key_len]) - }), - ]) - } else { - Err(UpgradeError::NoSecKey) - } - } else { - Err(UpgradeError::NoVersion) - } - } - - pub mod client { - use embedded_nal_async::TcpConnect; - - use edge_http::{client::ClientConnection, Error, Method}; - - use super::{upgrade_request_headers, MAX_BASE64_KEY_LEN, NONCE_LEN}; - - pub async fn initiate_ws_upgrade_request<'a, 'b, const N: usize, T>( - connection: &'a mut ClientConnection<'b, N, T>, - host: Option<&'a str>, - origin: Option<&'a str>, - uri: &'a str, - version: Option<&'a str>, - nonce: &[u8; NONCE_LEN], - ) -> Result<(), Error> - where - T: TcpConnect, - { - let mut nonce_base64_buf = [0_u8; MAX_BASE64_KEY_LEN]; - - let headers = - upgrade_request_headers(host, origin, version, nonce, &mut nonce_base64_buf); - - connection - .initiate_request(Method::Get, uri, &headers) - .await - } - - pub fn is_ws_upgrade_accepted<'a, const N: usize, T>( - connection: &'a ClientConnection<'_, N, T>, - _nonce: &'a [u8; NONCE_LEN], - ) -> Result> - where - T: TcpConnect, - { - let headers = connection.headers()?; - - let succeeded = matches!(headers.code, Some(101)) - && headers - .headers - .connection() - .map(|v| v.eq_ignore_ascii_case("Upgrade")) - .unwrap_or(false) - && headers - .headers - .upgrade() - .map(|v| v.eq_ignore_ascii_case("websocket")) - .unwrap_or(false) - && headers.headers.get("Sec-WebSocket-Accept").is_some(); - - Ok(succeeded) - } - } - - pub trait HeaderExt<'b> { - fn is_ws_upgrade_request(&self) -> bool; - - fn set_ws_upgrade_request_headers( - &mut self, - host: Option<&'b str>, - origin: Option<&'b str>, - version: Option<&'b str>, - nonce: &[u8; crate::http::NONCE_LEN], - nonce_base64_buf: &'b mut [u8; crate::http::MAX_BASE64_KEY_LEN], - ) -> &mut Self; - - fn set_ws_upgrade_response_headers<'a, H>( - &mut self, - request_headers: H, - version: Option<&'a str>, - sec_key_response_base64_buf: &'b mut [u8; crate::http::MAX_BASE64_KEY_RESPONSE_LEN], - ) -> Result<&mut Self, UpgradeError> - where - H: IntoIterator; - } - - impl<'b, const N: usize> HeaderExt<'b> for Headers<'b, N> { - fn is_ws_upgrade_request(&self) -> bool { - crate::http::is_upgrade_request(self.iter()) - } - - fn set_ws_upgrade_request_headers( - &mut self, - host: Option<&'b str>, - origin: Option<&'b str>, - version: Option<&'b str>, - nonce: &[u8; crate::http::NONCE_LEN], - nonce_base64_buf: &'b mut [u8; crate::http::MAX_BASE64_KEY_LEN], - ) -> &mut Self { - for (name, value) in - crate::http::upgrade_request_headers(host, origin, version, nonce, nonce_base64_buf) - { - self.set(name, value); - } - - self - } - - fn set_ws_upgrade_response_headers<'a, H>( - &mut self, - request_headers: H, - version: Option<&'a str>, - sec_key_response_base64_buf: &'b mut [u8; crate::http::MAX_BASE64_KEY_RESPONSE_LEN], - ) -> Result<&mut Self, UpgradeError> - where - H: IntoIterator, - { - for (name, value) in crate::http::upgrade_response_headers( - request_headers, - version, - sec_key_response_base64_buf, - )? { - self.set(name, value); - } - - Ok(self) - } - } -} - -#[cfg(feature = "embedded-svc")] -pub use embedded_svc_compat::*; - #[cfg(feature = "embedded-svc")] mod embedded_svc_compat { - use core::convert::{TryFrom, TryInto}; + use core::convert::TryFrom; - use embedded_io_async::{Read, Write}; - use embedded_svc::io::ErrorType as IoErrorType; - use embedded_svc::ws::asynch::Sender; - use embedded_svc::ws::ErrorType; - use embedded_svc::ws::{asynch::Receiver, FrameType}; - - use super::Error; + use embedded_svc::ws::FrameType; impl From for FrameType { fn from(frame_type: super::FrameType) -> Self { @@ -671,53 +286,4 @@ mod embedded_svc_compat { Ok(f) } } - - pub struct WsConnection(T, M); - - impl WsConnection { - pub const fn new(connection: T, mask_gen: M) -> Self { - Self(connection, mask_gen) - } - } - - impl ErrorType for WsConnection - where - T: IoErrorType, - { - type Error = Error; - } - - impl Receiver for WsConnection - where - T: Read, - { - async fn recv( - &mut self, - frame_data_buf: &mut [u8], - ) -> Result<(FrameType, usize), Self::Error> { - super::recv(&mut self.0, frame_data_buf) - .await - .map(|(frame_type, payload_len)| (frame_type.into(), payload_len)) - } - } - - impl Sender for WsConnection - where - T: Write, - M: Fn() -> Option, - { - async fn send( - &mut self, - frame_type: FrameType, - frame_data: &[u8], - ) -> Result<(), Self::Error> { - super::send( - &mut self.0, - frame_type.try_into().unwrap(), - (self.1)(), - frame_data, - ) - .await - } - } } diff --git a/embedded-nal-async-xtra/Cargo.toml b/embedded-nal-async-xtra/Cargo.toml new file mode 100644 index 0000000..3eb6e69 --- /dev/null +++ b/embedded-nal-async-xtra/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "embedded-nal-async-xtra" +version = "0.1.0" +edition = "2021" + +[features] +nightly = ["embedded-io-async", "embedded-nal-async"] + +[dependencies] +embedded-io-async = { workspace = true, optional = true } +embedded-nal-async = { workspace = true, optional = true } diff --git a/embedded-nal-async-xtra/src/lib.rs b/embedded-nal-async-xtra/src/lib.rs new file mode 100644 index 0000000..3673f59 --- /dev/null +++ b/embedded-nal-async-xtra/src/lib.rs @@ -0,0 +1,12 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![allow(stable_features)] +#![allow(unknown_lints)] +#![allow(async_fn_in_trait)] +#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", feature(impl_trait_projections))] + +#[cfg(feature = "nightly")] +pub use stack::*; + +#[cfg(feature = "nightly")] +mod stack; diff --git a/embedded-nal-async-xtra/src/stack.rs b/embedded-nal-async-xtra/src/stack.rs new file mode 100644 index 0000000..b4d2104 --- /dev/null +++ b/embedded-nal-async-xtra/src/stack.rs @@ -0,0 +1,5 @@ +pub use raw::*; +pub use tcp::*; + +mod raw; +mod tcp; diff --git a/embedded-nal-async-xtra/src/stack/raw.rs b/embedded-nal-async-xtra/src/stack/raw.rs new file mode 100644 index 0000000..31d78e8 --- /dev/null +++ b/embedded-nal-async-xtra/src/stack/raw.rs @@ -0,0 +1,61 @@ +pub trait RawSocket { + type Error: embedded_io_async::Error; + + async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error>; + async fn receive_into(&mut self, buffer: &mut [u8]) -> Result; +} + +impl RawSocket for &mut T +where + T: RawSocket, +{ + type Error = T::Error; + + async fn send(&mut self, data: &[u8]) -> Result<(), Self::Error> { + (**self).send(data).await + } + + async fn receive_into(&mut self, buffer: &mut [u8]) -> Result { + (**self).receive_into(buffer).await + } +} + +pub trait RawStack { + type Error: embedded_io_async::Error; + + type Socket: RawSocket; + + type Interface; + + async fn bind(&self, interface: &Self::Interface) -> Result; +} + +impl RawStack for &T +where + T: RawStack, +{ + type Error = T::Error; + + type Socket = T::Socket; + + type Interface = T::Interface; + + async fn bind(&self, interface: &Self::Interface) -> Result { + (*self).bind(interface).await + } +} + +impl RawStack for &mut T +where + T: RawStack, +{ + type Error = T::Error; + + type Socket = T::Socket; + + type Interface = T::Interface; + + async fn bind(&self, interface: &Self::Interface) -> Result { + (**self).bind(interface).await + } +} diff --git a/embedded-nal-async-xtra/src/stack/tcp.rs b/embedded-nal-async-xtra/src/stack/tcp.rs new file mode 100644 index 0000000..ccff6d6 --- /dev/null +++ b/embedded-nal-async-xtra/src/stack/tcp.rs @@ -0,0 +1,106 @@ +use embedded_nal_async::SocketAddr; + +pub trait TcpSplittableConnection { + type Error: embedded_io_async::Error; + + type Read<'a>: embedded_io_async::Read + where + Self: 'a; + type Write<'a>: embedded_io_async::Write + where + Self: 'a; + + async fn split(&mut self) -> Result<(Self::Read<'_>, Self::Write<'_>), Self::Error>; +} + +impl<'t, T> TcpSplittableConnection for &'t mut T +where + T: TcpSplittableConnection + 't, +{ + type Error = T::Error; + + type Read<'a> = T::Read<'a> where Self: 'a; + + type Write<'a> = T::Write<'a> where Self: 'a; + + async fn split(&mut self) -> Result<(Self::Read<'_>, Self::Write<'_>), Self::Error> { + (**self).split().await + } +} + +pub trait TcpListen { + type Error: embedded_io_async::Error; + + type Acceptor<'m>: TcpAccept + where + Self: 'm; + + async fn listen(&self, remote: SocketAddr) -> Result, Self::Error>; +} + +impl TcpListen for &T +where + T: TcpListen, +{ + type Error = T::Error; + + type Acceptor<'m> = T::Acceptor<'m> + where Self: 'm; + + async fn listen(&self, remote: SocketAddr) -> Result, Self::Error> { + (*self).listen(remote).await + } +} + +impl TcpListen for &mut T +where + T: TcpListen, +{ + type Error = T::Error; + + type Acceptor<'m> = T::Acceptor<'m> + where Self: 'm; + + async fn listen(&self, remote: SocketAddr) -> Result, Self::Error> { + (**self).listen(remote).await + } +} + +pub trait TcpAccept { + type Error: embedded_io_async::Error; + + type Connection<'m>: embedded_io_async::Read + + embedded_io_async::Write + where + Self: 'm; + + async fn accept(&self) -> Result, Self::Error>; +} + +impl TcpAccept for &T +where + T: TcpAccept, +{ + type Error = T::Error; + + type Connection<'m> = T::Connection<'m> + where Self: 'm; + + async fn accept(&self) -> Result, Self::Error> { + (**self).accept().await + } +} + +impl TcpAccept for &mut T +where + T: TcpAccept, +{ + type Error = T::Error; + + type Connection<'m> = T::Connection<'m> + where Self: 'm; + + async fn accept(&self) -> Result, Self::Error> { + (**self).accept().await + } +} diff --git a/src/asynch.rs b/src/asynch.rs deleted file mode 100644 index b08d473..0000000 --- a/src/asynch.rs +++ /dev/null @@ -1,90 +0,0 @@ -#[cfg(feature = "embedded-svc")] -pub use embedded_svc_compat::*; - -use core::future::Future; - -pub trait Unblocker { - type UnblockFuture<'a, F, T>: Future + Send - where - Self: 'a, - F: Send + 'a, - T: Send + 'a; - - fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> - where - F: FnOnce() -> T + Send + 'a, - T: Send + 'a; -} - -impl Unblocker for &U -where - U: Unblocker, -{ - type UnblockFuture<'a, F, T> - = U::UnblockFuture<'a, F, T> where Self: 'a, F: Send + 'a, T: Send + 'a; - - fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> - where - F: FnOnce() -> T + Send + 'a, - T: Send + 'a, - { - (*self).unblock(f) - } -} - -impl Unblocker for &mut U -where - U: Unblocker, -{ - type UnblockFuture<'a, F, T> - = U::UnblockFuture<'a, F, T> where Self: 'a, F: Send + 'a, T: Send + 'a; - - fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> - where - F: FnOnce() -> T + Send + 'a, - T: Send + 'a, - { - (**self).unblock(f) - } -} - -#[cfg(feature = "embedded-svc")] -mod embedded_svc_compat { - use core::future::Future; - - use super::Unblocker; - - pub struct UnblockerCompat(U); - - impl Unblocker for UnblockerCompat - where - U: embedded_svc::utils::asyncify::Unblocker, - { - type UnblockFuture<'a, F, T> = impl Future + Send - where Self: 'a, F: Send + 'a, T: Send + 'a; - - fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> - where - F: FnOnce() -> T + Send + 'a, - T: Send + 'a, - { - self.0.unblock(f) - } - } - - impl embedded_svc::utils::asyncify::Unblocker for UnblockerCompat - where - U: Unblocker, - { - type UnblockFuture<'a, F, T> = impl Future + Send - where Self: 'a, F: Send + 'a, T: Send + 'a; - - fn unblock<'a, F, T>(&'a self, f: F) -> Self::UnblockFuture<'a, F, T> - where - F: FnOnce() -> T + Send + 'a, - T: Send + 'a, - { - self.0.unblock(f) - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 49ef4de..75b21a7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,23 +1,15 @@ #![cfg_attr(not(feature = "std"), no_std)] #![allow(stable_features)] -#![cfg_attr(feature = "nightly", feature(impl_trait_in_assoc_type))] // Used in Unblocker +#![allow(unknown_lints)] +#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", allow(async_fn_in_trait))] +#![cfg_attr(feature = "nightly", feature(impl_trait_projections))] -// Re-export enabled sub-crates -#[cfg(feature = "edge-captive")] pub use edge_captive as captive; -#[cfg(feature = "edge-dhcp")] pub use edge_dhcp as dhcp; -#[cfg(feature = "edge-http")] pub use edge_http as http; -#[cfg(feature = "edge-mdns")] pub use edge_mdns as mdns; -#[cfg(feature = "edge-mqtt")] pub use edge_mqtt as mqtt; -#[cfg(feature = "edge-ws")] +pub use edge_raw as raw; +pub use edge_std_nal_async as std_nal; pub use edge_ws as ws; - -#[cfg(feature = "nightly")] -pub mod asynch; - -#[cfg(feature = "std")] -pub mod std; diff --git a/src/std.rs b/src/std.rs deleted file mode 100644 index 8061cf0..0000000 --- a/src/std.rs +++ /dev/null @@ -1,17 +0,0 @@ -#[cfg(feature = "nightly")] -pub mod nal; - -use embassy_sync::blocking_mutex::raw::RawMutex; - -pub struct StdRawMutex(std::sync::Mutex<()>); - -unsafe impl RawMutex for StdRawMutex { - #[allow(clippy::declare_interior_mutable_const)] - const INIT: Self = Self(std::sync::Mutex::new(())); - - fn lock(&self, f: impl FnOnce() -> R) -> R { - let _guard = self.0.lock().unwrap(); - - f() - } -}