use crate::path::Path; use core::{ cmp::Ordering::{self, Equal}, fmt, hash::{BuildHasher, Hash}, }; use hashbrown::HashMap; use std::collections::BinaryHeap; #[derive(Copy, Clone, Debug)] pub struct PathEntry { // cost so far + heursitic cost_estimate: f32, node: S, } impl PartialEq for PathEntry { fn eq(&self, other: &PathEntry) -> bool { self.node.eq(&other.node) } } impl Eq for PathEntry {} impl Ord for PathEntry { // This method implements reverse ordering, so that the lowest cost // will be ordered first fn cmp(&self, other: &PathEntry) -> Ordering { other .cost_estimate .partial_cmp(&self.cost_estimate) .unwrap_or(Equal) } } impl PartialOrd for PathEntry { fn partial_cmp(&self, other: &PathEntry) -> Option { Some(self.cmp(other)) } // This is particularily hot in `BinaryHeap::pop`, so we provide this // implementation. // // NOTE: This probably doesn't handle edge cases like `NaNs` in a consistent // manner with `Ord`, but I don't think we need to care about that here(?) // // See note about reverse ordering above. fn le(&self, other: &PathEntry) -> bool { other.cost_estimate <= self.cost_estimate } } pub enum PathResult { None(Path), Exhausted(Path), // second field is cost Path(Path, f32), Pending, } impl PathResult { /// Returns `Some((path, cost))` if a path reaching the target was /// successfully found. pub fn into_path(self) -> Option<(Path, f32)> { match self { PathResult::Path(path, cost) => Some((path, cost)), _ => None, } } pub fn map(self, f: impl FnOnce(Path) -> Path) -> PathResult { match self { PathResult::None(p) => PathResult::None(f(p)), PathResult::Exhausted(p) => PathResult::Exhausted(f(p)), PathResult::Path(p, cost) => PathResult::Path(f(p), cost), PathResult::Pending => PathResult::Pending, } } } // If node entry exists, this was visited! #[derive(Clone, Debug)] struct NodeEntry { // if came_from == self this is the start node! came_from: S, cheapest_score: f32, } #[derive(Clone)] pub struct Astar { iter: usize, max_iters: usize, potential_nodes: BinaryHeap>, // cost, node pairs visited_nodes: HashMap, Hasher>, /// Node with the lowest heuristic value so far. /// /// (node, heuristic value) closest_node: Option<(S, f32)>, } /// NOTE: Must manually derive since Hasher doesn't implement it. impl fmt::Debug for Astar { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Astar") .field("iter", &self.iter) .field("max_iters", &self.max_iters) .field("potential_nodes", &self.potential_nodes) .field("visited_nodes", &self.visited_nodes) .field("closest_node", &self.closest_node) .finish() } } impl Astar { pub fn new(max_iters: usize, start: S, hasher: H) -> Self { Self { max_iters, iter: 0, potential_nodes: core::iter::once(PathEntry { cost_estimate: 0.0, node: start.clone(), }) .collect(), visited_nodes: { let mut s = HashMap::with_capacity_and_hasher(1, hasher); s.extend(core::iter::once((start.clone(), NodeEntry { came_from: start, cheapest_score: 0.0, }))); s }, closest_node: None, } } pub fn poll( &mut self, iters: usize, // Estimate how far we are from the target? but we are given two nodes... // (current, previous) mut heuristic: impl FnMut(&S, &S) -> f32, // get neighboring nodes mut neighbors: impl FnMut(&S) -> I, // have we reached target? mut satisfied: impl FnMut(&S) -> bool, ) -> PathResult where I: Iterator, // (node, transition cost) { let iter_limit = self.max_iters.min(self.iter + iters); while self.iter < iter_limit { if let Some(PathEntry { node, .. }) = self.potential_nodes.pop() { let (node_cheapest, came_from) = self .visited_nodes .get(&node) .map(|n| (n.cheapest_score, n.came_from.clone())) .expect(""); if satisfied(&node) { return PathResult::Path(self.reconstruct_path_to(node), node_cheapest); } else { for (neighbor, transition) in neighbors(&node) { if neighbor == came_from { continue; } let neighbor_cheapest = self .visited_nodes .get(&neighbor) .map_or(f32::MAX, |n| n.cheapest_score); // compute cost to traverse to each neighbor let cost = node_cheapest + transition; if cost < neighbor_cheapest { let previously_visited = self .visited_nodes .insert(neighbor.clone(), NodeEntry { came_from: node.clone(), cheapest_score: cost, }) .is_some(); let h = heuristic(&neighbor, &node); // note that cheapest_scores does not include the heuristic // priority queue does include heuristic let cost_estimate = cost + h; if self .closest_node .as_ref() .map(|&(_, ch)| h < ch) .unwrap_or(true) { self.closest_node = Some((node.clone(), h)); }; // TODO: I think the if here should be removed // if we hadn't already visited, add this to potential nodes, what about // its neighbors, wouldn't they need to be revisted??? if !previously_visited { self.potential_nodes.push(PathEntry { cost_estimate, node: neighbor, }); } } } } } else { return PathResult::None( self.closest_node .clone() .map(|(lc, _)| self.reconstruct_path_to(lc)) .unwrap_or_default(), ); } self.iter += 1 } if self.iter >= self.max_iters { PathResult::Exhausted( self.closest_node .clone() .map(|(lc, _)| self.reconstruct_path_to(lc)) .unwrap_or_default(), ) } else { PathResult::Pending } } fn reconstruct_path_to(&mut self, end: S) -> Path { let mut path = vec![end.clone()]; let mut cnode = &end; while let Some(node) = self .visited_nodes .get(cnode) .map(|n| &n.came_from) .filter(|n| *n != cnode) { path.push(node.clone()); cnode = node; } path.into_iter().rev().collect() } }