diff --git a/src/crawler.rs b/src/crawler.rs
index 20b63f4..3679d22 100644
--- a/src/crawler.rs
+++ b/src/crawler.rs
@@ -12,7 +12,7 @@ use url::Url;
 
 use crate::{
     config::CrawlerConfig,
-    serde_types::{Instance, LoadedData, Service},
+    serde_types::{HttpCodeRanges, Instance, LoadedData, Service},
 };
 
 fn default_headers() -> reqwest::header::HeaderMap {
@@ -172,27 +172,7 @@ impl Crawler {
             Ok(response) => {
                 let end = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
                 let status_code = response.status().as_u16();
-                let mut status_valid = false;
-                match status_code {
-                    200 => status_valid = true,
-                    300..=399 => {
-                        if service.allow_3xx {
-                            status_valid = true;
-                        }
-                    }
-                    400..=499 => {
-                        if service.allow_4xx {
-                            status_valid = true;
-                        }
-                    }
-                    500..=599 => {
-                        if service.allow_5xx {
-                            status_valid = true;
-                        }
-                    }
-                    _ => {}
-                }
-                if status_valid {
+                if service.allowed_http_codes.is_allowed(status_code) {
                     if let Some(search_string) = &service.search_string {
                         let body = response.text().await?;
                         if !body.contains(search_string) {
diff --git a/src/serde_types.rs b/src/serde_types.rs
index dfd2700..24bbf0a 100644
--- a/src/serde_types.rs
+++ b/src/serde_types.rs
@@ -1,6 +1,9 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, fmt, vec};
 
-use serde::{Deserialize, Serialize};
+use serde::{
+    de::{self, Visitor},
+    Deserialize, Deserializer, Serialize, Serializer,
+};
 use url::Url;
 
 #[derive(Deserialize, Serialize, Debug, Clone)]
@@ -30,6 +33,126 @@ pub struct RegexSearch {
 
 pub type Regexes = HashMap<String, Vec<RegexSearch>>;
 
+pub trait HttpCodeRanges {
+    fn is_allowed(&self, code: u16) -> bool;
+}
+
+#[derive(Debug, Clone)]
+pub struct AllowedHttpCodes {
+    pub codes: Vec<u16>,
+    pub inclusive_ranges: Vec<(u16, u16)>,
+    pub exclusive_ranges: Vec<(u16, u16)>,
+}
+
+impl HttpCodeRanges for AllowedHttpCodes {
+    fn is_allowed(&self, code: u16) -> bool {
+        if self.codes.contains(&code) {
+            return true;
+        }
+
+        for &(start, end) in &self.inclusive_ranges {
+            if code >= start && code <= end {
+                return true;
+            }
+        }
+
+        for &(start, end) in &self.exclusive_ranges {
+            if code >= start && code < end {
+                return true;
+            }
+        }
+
+        false
+    }
+}
+
+impl<'de> Deserialize<'de> for AllowedHttpCodes {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct AllowedHttpCodesVisitor;
+
+        impl<'de> Visitor<'de> for AllowedHttpCodesVisitor {
+            type Value = AllowedHttpCodes;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_str("a string representing allowed HTTP codes and ranges")
+            }
+
+            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+            where
+                E: de::Error,
+            {
+                let mut codes = Vec::new();
+                let mut inclusive_ranges = Vec::new();
+                let mut exclusive_ranges = Vec::new();
+
+                for part in value.split(',') {
+                    if part.contains("..=") {
+                        let mut split = part.split("..=");
+                        let start = split.next().unwrap().trim();
+                        let end = split.next().unwrap().trim();
+                        let start = start.parse::<u16>().map_err(de::Error::custom)?;
+                        let end = end.parse::<u16>().map_err(de::Error::custom)?;
+                        inclusive_ranges.push((start, end));
+                    } else if part.contains("..") {
+                        let mut split = part.split("..");
+                        let start = split.next().unwrap().trim();
+                        let end = split.next().unwrap().trim();
+                        let start = start.parse::<u16>().map_err(de::Error::custom)?;
+                        let end = end.parse::<u16>().map_err(de::Error::custom)?;
+                        exclusive_ranges.push((start, end));
+                    } else {
+                        let code = part.trim().parse::<u16>().map_err(de::Error::custom)?;
+                        codes.push(code);
+                    }
+                }
+
+                Ok(AllowedHttpCodes {
+                    codes,
+                    inclusive_ranges,
+                    exclusive_ranges,
+                })
+            }
+        }
+
+        deserializer.deserialize_str(AllowedHttpCodesVisitor)
+    }
+}
+
+impl Serialize for AllowedHttpCodes {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut parts = Vec::new();
+
+        for &code in &self.codes {
+            parts.push(code.to_string());
+        }
+
+        for &(start, end) in &self.inclusive_ranges {
+            parts.push(format!("{}..={}", start, end));
+        }
+
+        for &(start, end) in &self.exclusive_ranges {
+            parts.push(format!("{}..{}", start, end));
+        }
+
+        let result = parts.join(",");
+        serializer.serialize_str(&result)
+    }
+}
+
+fn default_allowed_http_codes() -> AllowedHttpCodes {
+    AllowedHttpCodes {
+        codes: vec![200],
+        inclusive_ranges: Vec::new(),
+        exclusive_ranges: Vec::new(),
+    }
+}
+
 #[derive(Deserialize, Serialize, Debug, Clone)]
 pub struct Service {
     #[serde(rename = "type")]
@@ -39,12 +162,8 @@ pub struct Service {
     pub service_type: String,
     pub fallback: Url,
     #[serde(default = "default_follow_redirects")]
     pub follow_redirects: bool,
-    #[serde(default)]
-    pub allow_3xx: bool,
-    #[serde(default)]
-    pub allow_4xx: bool,
-    #[serde(default)]
-    pub allow_5xx: bool,
+    #[serde(default = "default_allowed_http_codes")]
+    pub allowed_http_codes: AllowedHttpCodes,
     #[serde(default)]
     pub search_string: Option<String>,
     #[serde(default)]
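
Not part of the patch itself: a rough sketch of unit tests that could be appended to src/serde_types.rs to pin down the behaviour the new type introduces. It only uses items defined in this diff; the second test additionally assumes serde_json is available as a dev dependency purely as a convenient self-describing deserializer, which may not match how the crate is actually set up.

#[cfg(test)]
mod allowed_http_codes_tests {
    // Sketch only: exercises the `HttpCodeRanges` trait and the custom
    // string-based Deserialize impl added in this change.
    use super::*;

    #[test]
    fn is_allowed_distinguishes_codes_and_range_kinds() {
        let allowed = AllowedHttpCodes {
            codes: vec![204],
            inclusive_ranges: vec![(301, 302)],
            exclusive_ranges: vec![(400, 404)],
        };

        assert!(allowed.is_allowed(204)); // explicit code
        assert!(allowed.is_allowed(302)); // `..=` keeps the upper bound
        assert!(allowed.is_allowed(403)); // inside the `..` range
        assert!(!allowed.is_allowed(404)); // `..` drops the upper bound
        assert!(!allowed.is_allowed(200)); // the default 200 is gone once overridden
    }

    #[test]
    fn parses_a_comma_separated_string() {
        // serde_json is assumed here only to drive the visitor; the real
        // services config may use a different format.
        let parsed: AllowedHttpCodes =
            serde_json::from_str(r#""200, 300..=399, 400..500""#).expect("should parse");

        assert!(parsed.is_allowed(200));
        assert!(parsed.is_allowed(399)); // inclusive end
        assert!(parsed.is_allowed(499)); // last code before the exclusive end
        assert!(!parsed.is_allowed(500));
        assert!(!parsed.is_allowed(201));
    }
}

In a service definition, the field then takes a single string such as "200,300..=399" in place of the three removed allow_* booleans, and a bare 200 remains the default when the key is omitted.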