Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adds find_path_one_to_many for directed graphs #20

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 125 additions & 22 deletions src/directed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,28 @@ use crate::{
},
};
use fxhash::FxHashSet;
use std::{collections::VecDeque, ops::Not, rc::Rc};
use std::{
collections::{HashMap, VecDeque},
ops::Not,
path,
rc::Rc,
};

// Helper function for constructing the path
fn construct_path(parents: &[(Sym, Sym)], start_id: Sym, goal_id: Sym, path: &mut Vec<Sym>) {
let mut current_id = goal_id;
path.push(current_id);
while current_id != start_id {
if let Some(parent_pair) = parents.iter().find(|(node, _)| *node == current_id) {
current_id = parent_pair.1;
path.push(current_id);
} else {
break; // This should not happen if the path exists
}
}

path.reverse(); // Reverse to get the path from start to goal
}

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DirectedGraph {
Expand Down Expand Up @@ -184,32 +205,81 @@ impl DirectedGraph {
.collect()
}

pub fn find_path(
pub fn find_path_one_to_many<'a>(
&self,
from: impl AsRef<str>,
to: impl AsRef<str>,
) -> GraphInteractionResult<NodeVec> {
// Helper function for constructing the path
fn construct_path(
parents: &[(Sym, Sym)],
start_id: Sym,
goal_id: Sym,
path: &mut Vec<Sym>,
) {
let mut current_id = goal_id;
path.push(current_id);
while current_id != start_id {
if let Some(parent_pair) = parents.iter().find(|(node, _)| *node == current_id) {
current_id = parent_pair.1;
path.push(current_id);
} else {
break; // This should not happen if the path exists
}
from: impl AsRef<str> + 'a,
to: impl IntoIterator<Item = impl AsRef<str> + 'a>,
) -> GraphInteractionResult<Vec<NodeVec>> {
let mut paths = Vec::<Vec<Sym>>::new();
let mut path_cache = HashMap::<(Sym, Sym), Vec<Sym>>::new();
let from = self.get_internal(from)?;

let queue = unsafe { self.u32x1_queue_0() };
let visited = unsafe { self.u32x1_set_0() };
let path_buf = unsafe { self.u32x1_vec_0() };
let path_buf2 = unsafe { self.u32x1_vec_1() };
let path = unsafe { self.u32x2_vec_0() }; // To track the path back to the start node

'to: for to in to {
path_buf.clear();
path_buf2.clear();

let to = self.get_internal(to)?;

if from == to {
path_cache.insert((from, to), vec![from]);
paths.push(vec![from]);
continue 'to;
}

path.reverse(); // Reverse to get the path from start to goal
queue.clear();
visited.clear();

// Initialize
queue.push_back(to);
visited.insert(to);

while let Some(current) = queue.pop_front() {
if let Some(pre_calculated_path) = path_cache.get(&(from, current)) {
construct_path(path, current, to, path_buf);
pre_calculated_path.iter().for_each(|&s| path_buf2.push(s));
for node in path_buf.drain(..).skip(1) {
path_buf2.push(node);
}
path_cache.insert((from, to), path_buf2.clone());
paths.push(path_buf2.clone());
continue 'to;
}
if let LazySet::Initialized(parents) = self.parent_map.get(current) {
for &parent in parents.iter() {
if visited.insert(parent) {
path.push((current, parent));
if parent == from {
// Construct the path and place it in `path_buf`
construct_path(path, from, to, path_buf);
path_cache.insert((from, to), path_buf.clone());
paths.push(path_buf.clone());
continue 'to;
}
queue.push_back(parent);
}
}
}
}
paths.push(vec![]);
}

Ok(paths
.into_iter()
.map(|path| self.resolve_mul_slice(&path))
.collect())
}

pub fn find_path(
&self,
from: impl AsRef<str>,
to: impl AsRef<str>,
) -> GraphInteractionResult<NodeVec> {
let from = self.get_internal(from)?;
let to = self.get_internal(to)?;

Expand Down Expand Up @@ -628,6 +698,39 @@ mod tests {
assert_eq!(dg.children(["A"]).unwrap(), ["H", "B"]);
}

#[test]
fn dg_find_path_one_to_many() {
let mut builder = DirectedGraphBuilder::new();
// We put more than 8 children to
// test if SIMD actually workd
builder.add_path(["A", "B", "C", "D"]);
let dg = builder.clone().build_directed();
assert_eq!(
dg.find_path_one_to_many("A", vec!["A", "B", "C", "D"])
.unwrap(),
[
vec!["A"],
vec!["A", "B"],
vec!["A", "B", "C"],
vec!["A", "B", "C", "D"],
]
);

builder.add_path(["A", "H", "D"]);
let dg = builder.clone().build_directed();
assert_eq!(
dg.find_path_one_to_many("A", vec!["A", "B", "C", "D"])
.unwrap(),
[
vec!["A"],
vec!["A", "B"],
vec!["A", "B", "C"],
vec!["A", "H", "D"],
]
);
assert_eq!(dg.children(["A"]).unwrap(), ["H", "B"]);
}

#[test]
fn dg_find_least_common_parents() {
let mut builder = DirectedGraphBuilder::new();
Expand Down
Loading