From b397ad8aa389bcc17aef17593302adb7af765c21 Mon Sep 17 00:00:00 2001 From: andres Date: Thu, 17 Oct 2024 14:21:21 -0500 Subject: [PATCH] feat: Adds find_path_one_to_many for directed graphs --- src/directed/mod.rs | 147 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 125 insertions(+), 22 deletions(-) diff --git a/src/directed/mod.rs b/src/directed/mod.rs index a33680a..ee47ceb 100644 --- a/src/directed/mod.rs +++ b/src/directed/mod.rs @@ -15,7 +15,28 @@ use crate::{ }, }; use fxhash::FxHashSet; -use std::{collections::VecDeque, ops::Not, rc::Rc}; +use std::{ + collections::{HashMap, VecDeque}, + ops::Not, + path, + rc::Rc, +}; + +// Helper function for constructing the path +fn construct_path(parents: &[(Sym, Sym)], start_id: Sym, goal_id: Sym, path: &mut Vec) { + let mut current_id = goal_id; + path.push(current_id); + while current_id != start_id { + if let Some(parent_pair) = parents.iter().find(|(node, _)| *node == current_id) { + current_id = parent_pair.1; + path.push(current_id); + } else { + break; // This should not happen if the path exists + } + } + + path.reverse(); // Reverse to get the path from start to goal +} #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct DirectedGraph { @@ -184,32 +205,81 @@ impl DirectedGraph { .collect() } - pub fn find_path( + pub fn find_path_one_to_many<'a>( &self, - from: impl AsRef, - to: impl AsRef, - ) -> GraphInteractionResult { - // Helper function for constructing the path - fn construct_path( - parents: &[(Sym, Sym)], - start_id: Sym, - goal_id: Sym, - path: &mut Vec, - ) { - let mut current_id = goal_id; - path.push(current_id); - while current_id != start_id { - if let Some(parent_pair) = parents.iter().find(|(node, _)| *node == current_id) { - current_id = parent_pair.1; - path.push(current_id); - } else { - break; // This should not happen if the path exists - } + from: impl AsRef + 'a, + to: impl IntoIterator + 'a>, + ) -> GraphInteractionResult> { + let mut paths = Vec::>::new(); + let mut path_cache = HashMap::<(Sym, Sym), Vec>::new(); + let from = self.get_internal(from)?; + + let queue = unsafe { self.u32x1_queue_0() }; + let visited = unsafe { self.u32x1_set_0() }; + let path_buf = unsafe { self.u32x1_vec_0() }; + let path_buf2 = unsafe { self.u32x1_vec_1() }; + let path = unsafe { self.u32x2_vec_0() }; // To track the path back to the start node + + 'to: for to in to { + path_buf.clear(); + path_buf2.clear(); + + let to = self.get_internal(to)?; + + if from == to { + path_cache.insert((from, to), vec![from]); + paths.push(vec![from]); + continue 'to; } - path.reverse(); // Reverse to get the path from start to goal + queue.clear(); + visited.clear(); + + // Initialize + queue.push_back(to); + visited.insert(to); + + while let Some(current) = queue.pop_front() { + if let Some(pre_calculated_path) = path_cache.get(&(from, current)) { + construct_path(path, current, to, path_buf); + pre_calculated_path.iter().for_each(|&s| path_buf2.push(s)); + for node in path_buf.drain(..).skip(1) { + path_buf2.push(node); + } + path_cache.insert((from, to), path_buf2.clone()); + paths.push(path_buf2.clone()); + continue 'to; + } + if let LazySet::Initialized(parents) = self.parent_map.get(current) { + for &parent in parents.iter() { + if visited.insert(parent) { + path.push((current, parent)); + if parent == from { + // Construct the path and place it in `path_buf` + construct_path(path, from, to, path_buf); + path_cache.insert((from, to), path_buf.clone()); + paths.push(path_buf.clone()); + continue 'to; + } + queue.push_back(parent); + } + } + } + } + paths.push(vec![]); } + Ok(paths + .into_iter() + .map(|path| self.resolve_mul_slice(&path)) + .collect()) + } + + pub fn find_path( + &self, + from: impl AsRef, + to: impl AsRef, + ) -> GraphInteractionResult { let from = self.get_internal(from)?; let to = self.get_internal(to)?; @@ -628,6 +698,39 @@ mod tests { assert_eq!(dg.children(["A"]).unwrap(), ["H", "B"]); } + #[test] + fn dg_find_path_one_to_many() { + let mut builder = DirectedGraphBuilder::new(); + // We put more than 8 children to + // test if SIMD actually workd + builder.add_path(["A", "B", "C", "D"]); + let dg = builder.clone().build_directed(); + assert_eq!( + dg.find_path_one_to_many("A", vec!["A", "B", "C", "D"]) + .unwrap(), + [ + vec!["A"], + vec!["A", "B"], + vec!["A", "B", "C"], + vec!["A", "B", "C", "D"], + ] + ); + + builder.add_path(["A", "H", "D"]); + let dg = builder.clone().build_directed(); + assert_eq!( + dg.find_path_one_to_many("A", vec!["A", "B", "C", "D"]) + .unwrap(), + [ + vec!["A"], + vec!["A", "B"], + vec!["A", "B", "C"], + vec!["A", "H", "D"], + ] + ); + assert_eq!(dg.children(["A"]).unwrap(), ["H", "B"]); + } + #[test] fn dg_find_least_common_parents() { let mut builder = DirectedGraphBuilder::new();