From dc70318e819efc0d0535a5d7bd35a0c7ab8e9106 Mon Sep 17 00:00:00 2001 From: Nathan Merrill Date: Fri, 17 Nov 2023 09:57:27 -0700 Subject: [PATCH] fix: replace panics with results & better option types (#437) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Replace panics with results & better option types * Apply suggestions from code review * Remove Reqwest from AWClient::new and return a Result instead * Run cargo fmt --------- Co-authored-by: NathanM <2955071+nathanmerrill@users.noreply.github.com> Co-authored-by: Erik Bjäreholt --- aw-client-rust/src/blocking.rs | 12 ++-- aw-client-rust/src/lib.rs | 22 ++++--- aw-client-rust/tests/test.rs | 6 +- aw-models/src/bucket.rs | 2 +- aw-sync/src/dirs.rs | 13 ++-- aw-sync/src/main.rs | 108 +++++++++++++-------------------- aw-sync/src/sync.rs | 22 ++++--- aw-sync/src/sync_wrapper.rs | 56 ++++++++--------- aw-sync/src/util.rs | 10 +-- 9 files changed, 114 insertions(+), 137 deletions(-) diff --git a/aw-client-rust/src/blocking.rs b/aw-client-rust/src/blocking.rs index f5aab1c7..3acd8eb9 100644 --- a/aw-client-rust/src/blocking.rs +++ b/aw-client-rust/src/blocking.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; use std::future::Future; use std::vec::Vec; +use std::{collections::HashMap, error::Error}; use chrono::{DateTime, Utc}; @@ -10,7 +10,7 @@ use super::AwClient as AsyncAwClient; pub struct AwClient { client: AsyncAwClient, - pub baseurl: String, + pub baseurl: reqwest::Url, pub name: String, pub hostname: String, } @@ -38,15 +38,15 @@ macro_rules! proxy_method } impl AwClient { - pub fn new(ip: &str, port: &str, name: &str) -> AwClient { - let async_client = AsyncAwClient::new(ip, port, name); + pub fn new(host: &str, port: u16, name: &str) -> Result> { + let async_client = AsyncAwClient::new(host, port, name)?; - AwClient { + Ok(AwClient { baseurl: async_client.baseurl.clone(), name: async_client.name.clone(), hostname: async_client.hostname.clone(), client: async_client, - } + }) } proxy_method!(get_bucket, Bucket, bucketname: &str); diff --git a/aw-client-rust/src/lib.rs b/aw-client-rust/src/lib.rs index 0c055f91..18fad49e 100644 --- a/aw-client-rust/src/lib.rs +++ b/aw-client-rust/src/lib.rs @@ -7,8 +7,8 @@ extern crate tokio; pub mod blocking; -use std::collections::HashMap; use std::vec::Vec; +use std::{collections::HashMap, error::Error}; use chrono::{DateTime, Utc}; use serde_json::Map; @@ -17,7 +17,7 @@ pub use aw_models::{Bucket, BucketMetadata, Event}; pub struct AwClient { client: reqwest::Client, - pub baseurl: String, + pub baseurl: reqwest::Url, pub name: String, pub hostname: String, } @@ -28,20 +28,24 @@ impl std::fmt::Debug for AwClient { } } +fn get_hostname() -> String { + return gethostname::gethostname().to_string_lossy().to_string(); +} + impl AwClient { - pub fn new(ip: &str, port: &str, name: &str) -> AwClient { - let baseurl = format!("http://{ip}:{port}"); + pub fn new(host: &str, port: u16, name: &str) -> Result> { + let baseurl = reqwest::Url::parse(&format!("http://{}:{}", host, port))?; + let hostname = get_hostname(); let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(120)) - .build() - .unwrap(); - let hostname = gethostname::gethostname().into_string().unwrap(); - AwClient { + .build()?; + + Ok(AwClient { client, baseurl, name: name.to_string(), hostname, - } + }) } pub async fn get_bucket(&self, bucketname: &str) -> Result { diff --git a/aw-client-rust/tests/test.rs b/aw-client-rust/tests/test.rs index 8fd21c1f..cce54a5f 100644 --- a/aw-client-rust/tests/test.rs +++ b/aw-client-rust/tests/test.rs @@ -60,10 +60,10 @@ mod test { #[test] fn test_full() { - let ip = "127.0.0.1"; - let port: String = PORT.to_string(); let clientname = "aw-client-rust-test"; - let client: AwClient = AwClient::new(ip, &port, clientname); + + let client: AwClient = + AwClient::new("127.0.0.1", PORT, clientname).expect("Client creation failed"); let shutdown_handler = setup_testserver(); diff --git a/aw-models/src/bucket.rs b/aw-models/src/bucket.rs index c0537d2b..e608006e 100644 --- a/aw-models/src/bucket.rs +++ b/aw-models/src/bucket.rs @@ -49,7 +49,7 @@ fn test_bucket() { id: "id".to_string(), _type: "type".to_string(), client: "client".to_string(), - hostname: "hostname".to_string(), + hostname: "hostname".into(), created: None, data: json_map! {}, metadata: BucketMetadata::default(), diff --git a/aw-sync/src/dirs.rs b/aw-sync/src/dirs.rs index 5e3faffd..878df033 100644 --- a/aw-sync/src/dirs.rs +++ b/aw-sync/src/dirs.rs @@ -1,14 +1,16 @@ use dirs::home_dir; +use std::error::Error; use std::fs; use std::path::PathBuf; // TODO: This could be refactored to share logic with aw-server/src/dirs.rs // TODO: add proper config support #[allow(dead_code)] -pub fn get_config_dir() -> Result { - let mut dir = appdirs::user_config_dir(Some("activitywatch"), None, false)?; +pub fn get_config_dir() -> Result> { + let mut dir = appdirs::user_config_dir(Some("activitywatch"), None, false) + .map_err(|_| "Unable to read user config dir")?; dir.push("aw-sync"); - fs::create_dir_all(dir.clone()).expect("Unable to create config dir"); + fs::create_dir_all(dir.clone())?; Ok(dir) } @@ -21,7 +23,8 @@ pub fn get_server_config_path(testing: bool) -> Result { })) } -pub fn get_sync_dir() -> Result { +pub fn get_sync_dir() -> Result> { // TODO: make this configurable - home_dir().map(|p| p.join("ActivityWatchSync")).ok_or(()) + let home_dir = home_dir().ok_or("Unable to read home_dir")?; + Ok(home_dir.join("ActivityWatchSync")) } diff --git a/aw-sync/src/main.rs b/aw-sync/src/main.rs index 83dacef5..bbbe9a17 100644 --- a/aw-sync/src/main.rs +++ b/aw-sync/src/main.rs @@ -14,10 +14,9 @@ extern crate serde; extern crate serde_json; use std::error::Error; -use std::path::Path; use std::path::PathBuf; -use chrono::{DateTime, Datelike, TimeZone, Utc}; +use chrono::{DateTime, Utc}; use clap::{Parser, Subcommand}; use aw_client_rust::blocking::AwClient; @@ -40,7 +39,7 @@ struct Opts { /// Port of instance to connect to. #[clap(long)] - port: Option, + port: Option, /// Convenience option for using the default testing host and port. #[clap(long)] @@ -58,8 +57,8 @@ enum Commands { /// Pulls remote buckets then pushes local buckets. Sync { /// Host(s) to pull from, comma separated. Will pull from all hosts if not specified. - #[clap(long)] - host: Option, + #[clap(long, value_parser=parse_list)] + host: Option>, }, /// Sync subcommand (advanced) @@ -73,57 +72,64 @@ enum Commands { /// If not specified, start from beginning. /// NOTE: might be unstable, as count cannot be used to verify integrity of sync. /// Format: YYYY-MM-DD - #[clap(long)] - start_date: Option, + #[clap(long, value_parser=parse_start_date)] + start_date: Option>, /// Specify buckets to sync using a comma-separated list. /// If not specified, all buckets will be synced. - #[clap(long)] - buckets: Option, + #[clap(long, value_parser=parse_list)] + buckets: Option>, /// Mode to sync in. Can be "push", "pull", or "both". /// Defaults to "both". #[clap(long, default_value = "both")] - mode: String, + mode: sync::SyncMode, /// Full path to sync directory. /// If not specified, exit. #[clap(long)] - sync_dir: String, + sync_dir: PathBuf, /// Full path to sync db file /// Useful for syncing buckets from a specific db file in the sync directory. /// Must be a valid absolute path to a file in the sync directory. #[clap(long)] - sync_db: Option, + sync_db: Option, }, /// List buckets and their sync status. List {}, } +fn parse_start_date(arg: &str) -> Result, chrono::ParseError> { + chrono::NaiveDate::parse_from_str(arg, "%Y-%m-%d") + .map(|nd| nd.and_time(chrono::NaiveTime::MIN).and_utc()) +} + +fn parse_list(arg: &str) -> Result, clap::Error> { + Ok(arg.split(',').map(|s| s.to_string()).collect()) +} + fn main() -> Result<(), Box> { let opts: Opts = Opts::parse(); let verbose = opts.verbose; info!("Started aw-sync..."); - aw_server::logging::setup_logger("aw-sync", opts.testing, verbose) - .expect("Failed to setup logging"); + aw_server::logging::setup_logger("aw-sync", opts.testing, verbose)?; let port = opts .port - .or_else(|| Some(crate::util::get_server_port(opts.testing).ok()?.to_string())) - .unwrap(); + .map(|a| Ok(a)) + .unwrap_or_else(|| util::get_server_port(opts.testing))?; - let client = AwClient::new(opts.host.as_str(), port.as_str(), "aw-sync"); + let client = AwClient::new(&opts.host, port, "aw-sync")?; - match &opts.command { + match opts.command { // Perform basic sync Commands::Sync { host } => { // Pull match host { - Some(host) => { - let hosts: Vec<&str> = host.split(',').collect(); + Some(hosts) => { for host in hosts.iter() { info!("Pulling from host: {}", host); sync_wrapper::pull(host, &client)?; @@ -137,8 +143,7 @@ fn main() -> Result<(), Box> { // Push info!("Pushing local data"); - sync_wrapper::push(&client)?; - Ok(()) + sync_wrapper::push(&client) } // Perform two-way sync Commands::SyncAdvanced { @@ -148,60 +153,31 @@ fn main() -> Result<(), Box> { sync_dir, sync_db, } => { - let sync_directory = if sync_dir.is_empty() { - error!("No sync directory specified, exiting..."); - std::process::exit(1); - } else { - Path::new(&sync_dir) - }; - info!("Using sync dir: {}", sync_directory.display()); - - if let Some(sync_db) = &sync_db { - info!("Using sync db: {}", sync_db); + if !sync_dir.is_absolute() { + Err("Sync dir must be absolute")? } - let start: Option> = start_date.as_ref().map(|date| { - println!("{}", date.clone()); - chrono::NaiveDate::parse_from_str(&date.clone(), "%Y-%m-%d") - .map(|nd| { - Utc.with_ymd_and_hms(nd.year(), nd.month(), nd.day(), 0, 0, 0) - .single() - .unwrap() - }) - .expect("Date was not on the format YYYY-MM-DD") - }); - - // Parse comma-separated list - let buckets_vec: Option> = buckets - .as_ref() - .map(|b| b.split(',').map(|s| s.to_string()).collect()); - - let sync_db: Option = sync_db.as_ref().map(|db| { - let db_path = Path::new(db); + info!("Using sync dir: {}", &sync_dir.display()); + + if let Some(db_path) = &sync_db { + info!("Using sync db: {}", &db_path.display()); + if !db_path.is_absolute() { - panic!("Sync db path must be absolute"); + Err("Sync db path must be absolute")? } - if !db_path.starts_with(sync_directory) { - panic!("Sync db path must be in sync directory"); + if !db_path.starts_with(&sync_dir) { + Err("Sync db path must be in sync directory")? } - db_path.to_path_buf() - }); + } let sync_spec = sync::SyncSpec { - path: sync_directory.to_path_buf(), + path: sync_dir, path_db: sync_db, - buckets: buckets_vec, - start, - }; - - let mode_enum = match mode.as_str() { - "push" => sync::SyncMode::Push, - "pull" => sync::SyncMode::Pull, - "both" => sync::SyncMode::Both, - _ => panic!("Invalid mode"), + buckets, + start: start_date, }; - sync::sync_run(&client, &sync_spec, mode_enum) + sync::sync_run(&client, &sync_spec, mode) } // List all buckets diff --git a/aw-sync/src/sync.rs b/aw-sync/src/sync.rs index c842f5a6..405118a9 100644 --- a/aw-sync/src/sync.rs +++ b/aw-sync/src/sync.rs @@ -9,6 +9,7 @@ extern crate chrono; extern crate reqwest; extern crate serde_json; +use std::error::Error; use std::fs; use std::path::{Path, PathBuf}; @@ -17,10 +18,11 @@ use chrono::{DateTime, Utc}; use aw_datastore::{Datastore, DatastoreError}; use aw_models::{Bucket, Event}; +use clap::ValueEnum; use crate::accessmethod::AccessMethod; -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Copy, Clone, ValueEnum)] pub enum SyncMode { Push, Pull, @@ -54,8 +56,12 @@ impl Default for SyncSpec { } /// Performs a single sync pass -pub fn sync_run(client: &AwClient, sync_spec: &SyncSpec, mode: SyncMode) -> Result<(), String> { - let info = client.get_info().map_err(|e| e.to_string())?; +pub fn sync_run( + client: &AwClient, + sync_spec: &SyncSpec, + mode: SyncMode, +) -> Result<(), Box> { + let info = client.get_info()?; // FIXME: Here it is assumed that the device_id for the local server is the one used by // aw-server-rust, which is not necessarily true (aw-server-python has seperate device_id). @@ -128,10 +134,10 @@ pub fn sync_run(client: &AwClient, sync_spec: &SyncSpec, mode: SyncMode) -> Resu } #[allow(dead_code)] -pub fn list_buckets(client: &AwClient) -> Result<(), String> { +pub fn list_buckets(client: &AwClient) -> Result<(), Box> { let sync_directory = crate::dirs::get_sync_dir().map_err(|_| "Could not get sync dir")?; let sync_directory = sync_directory.as_path(); - let info = client.get_info().map_err(|e| e.to_string())?; + let info = client.get_info()?; // FIXME: Incorrect device_id assumption? let device_id = info.device_id.as_str(); @@ -156,12 +162,12 @@ pub fn list_buckets(client: &AwClient) -> Result<(), String> { Ok(()) } -fn setup_local_remote(path: &Path, device_id: &str) -> Result { +fn setup_local_remote(path: &Path, device_id: &str) -> Result> { // FIXME: Don't run twice if already exists - fs::create_dir_all(path).unwrap(); + fs::create_dir_all(path)?; let remotedir = path.join(device_id); - fs::create_dir_all(&remotedir).unwrap(); + fs::create_dir_all(&remotedir)?; let dbfile = remotedir.join("test.db"); diff --git a/aw-sync/src/sync_wrapper.rs b/aw-sync/src/sync_wrapper.rs index 11daae50..4e786e4f 100644 --- a/aw-sync/src/sync_wrapper.rs +++ b/aw-sync/src/sync_wrapper.rs @@ -15,13 +15,13 @@ pub fn pull_all(client: &AwClient) -> Result<(), Box> { } pub fn pull(host: &str, client: &AwClient) -> Result<(), Box> { - // Check if server is running - let parts: Vec<&str> = client.baseurl.split("://").collect(); - let host_parts: Vec<&str> = parts[1].split(':').collect(); - let addr = host_parts[0]; - let port = host_parts[1].parse::().unwrap(); + let socket_addrs = client.baseurl.socket_addrs(|| None)?; + let socket_addr = socket_addrs + .get(0) + .ok_or("Unable to resolve baseurl into socket address")?; - if TcpStream::connect((addr, port)).is_err() { + // Check if server is running + if TcpStream::connect(socket_addr).is_err() { return Err(format!("Local server {} not running", &client.baseurl).into()); } @@ -44,7 +44,7 @@ pub fn pull(host: &str, client: &AwClient) -> Result<(), Box> { // filter out dbs that are smaller than 50kB (workaround for trying to sync empty database // files that are spuriously created somewhere) - let mut dbs = dbs + let dbs = dbs .into_iter() .filter(|entry| entry.metadata().map(|m| m.len() > 50_000).unwrap_or(false)) .collect::>(); @@ -55,44 +55,38 @@ pub fn pull(host: &str, client: &AwClient) -> Result<(), Box> { "More than one db found in sync folder for host, choosing largest db {:?}", dbs ); - dbs = vec![dbs - .into_iter() - .max_by_key(|entry| entry.metadata().map(|m| m.len()).unwrap_or(0)) - .unwrap()]; - } - // if no db, error - if dbs.is_empty() { - return Err(format!("No db found in sync folder {:?}", sync_dir).into()); } - for db in dbs { - let sync_spec = SyncSpec { - path: sync_dir.clone(), - path_db: Some(db.path().clone()), - buckets: Some(vec![ - format!("aw-watcher-window_{}", host), - format!("aw-watcher-afk_{}", host), - ]), - start: None, - }; - sync_run(client, &sync_spec, SyncMode::Pull)?; - } + let db = dbs + .into_iter() + .max_by_key(|entry| entry.metadata().map(|m| m.len()).unwrap_or(0)) + .ok_or_else(|| format!("No db found in sync folder {:?}", sync_dir))?; + + let sync_spec = SyncSpec { + path: sync_dir.clone(), + path_db: Some(db.path().clone()), + buckets: Some(vec![ + format!("aw-watcher-window_{}", host), + format!("aw-watcher-afk_{}", host), + ]), + start: None, + }; + sync_run(client, &sync_spec, SyncMode::Pull)?; Ok(()) } pub fn push(client: &AwClient) -> Result<(), Box> { - let hostname = crate::util::get_hostname()?; let sync_dir = crate::dirs::get_sync_dir() .map_err(|_| "Could not get sync dir")? - .join(&hostname); + .join(&client.hostname); let sync_spec = SyncSpec { path: sync_dir, path_db: None, buckets: Some(vec![ - format!("aw-watcher-window_{}", hostname), - format!("aw-watcher-afk_{}", hostname), + format!("aw-watcher-window_{}", client.hostname), + format!("aw-watcher-afk_{}", client.hostname), ]), start: None, }; diff --git a/aw-sync/src/util.rs b/aw-sync/src/util.rs index f2a38f0e..dde96555 100644 --- a/aw-sync/src/util.rs +++ b/aw-sync/src/util.rs @@ -6,13 +6,6 @@ use std::fs::File; use std::io::Read; use std::path::{Path, PathBuf}; -pub fn get_hostname() -> Result> { - let hostname = gethostname::gethostname() - .into_string() - .map_err(|_| "Failed to convert hostname to string")?; - Ok(hostname) -} - /// Returns the port of the local aw-server instance pub fn get_server_port(testing: bool) -> Result> { // TODO: get aw-server config more reliably @@ -67,7 +60,8 @@ fn contains_subdir_with_db_file(dir: &std::path::Path) -> bool { /// Only returns folders that match ./{host}/{device_id}/*.db // TODO: share logic with find_remotes and find_remotes_nonlocal pub fn get_remotes() -> Result, Box> { - let sync_root_dir = crate::dirs::get_sync_dir().map_err(|_| "Could not get sync dir")?; + let sync_root_dir = crate::dirs::get_sync_dir()?; + fs::create_dir_all(&sync_root_dir)?; let hostnames = fs::read_dir(sync_root_dir)? .filter_map(Result::ok) .filter(|entry| entry.path().is_dir() && contains_subdir_with_db_file(&entry.path()))