use crate::path::Path; use core::{ cmp::Ordering::{self, Equal}, f32, fmt, hash::{BuildHasher, Hash}, }; use hashbrown::{HashMap, HashSet}; use std::collections::BinaryHeap; #[derive(Copy, Clone, Debug)] pub struct PathEntry { cost: f32, node: S, } impl PartialEq for PathEntry { fn eq(&self, other: &PathEntry) -> bool { self.node.eq(&other.node) } } impl Eq for PathEntry {} impl Ord for PathEntry { // This method implements reverse ordering, so that the lowest cost // will be ordered first fn cmp(&self, other: &PathEntry) -> Ordering { other.cost.partial_cmp(&self.cost).unwrap_or(Equal) } } impl PartialOrd for PathEntry { fn partial_cmp(&self, other: &PathEntry) -> Option { Some(self.cmp(other)) } } pub enum PathResult { None(Path), Exhausted(Path), Path(Path), Pending, } impl PathResult { pub fn into_path(self) -> Option> { match self { PathResult::Path(path) => Some(path), _ => None, } } } #[derive(Clone)] pub struct Astar { iter: usize, max_iters: usize, potential_nodes: BinaryHeap>, came_from: HashMap, cheapest_scores: HashMap, final_scores: HashMap, visited: HashSet, cheapest_node: Option, cheapest_cost: Option, } /// NOTE: Must manually derive since Hasher doesn't implement it. impl fmt::Debug for Astar { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Astar") .field("iter", &self.iter) .field("max_iters", &self.max_iters) .field("potential_nodes", &self.potential_nodes) .field("came_from", &self.came_from) .field("cheapest_scores", &self.cheapest_scores) .field("final_scores", &self.final_scores) .field("visited", &self.visited) .field("cheapest_node", &self.cheapest_node) .field("cheapest_cost", &self.cheapest_cost) .finish() } } impl Astar { pub fn new(max_iters: usize, start: S, heuristic: impl FnOnce(&S) -> f32, hasher: H) -> Self { Self { max_iters, iter: 0, potential_nodes: core::iter::once(PathEntry { cost: 0.0, node: start.clone(), }) .collect(), came_from: HashMap::with_hasher(hasher.clone()), cheapest_scores: { let mut h = HashMap::with_capacity_and_hasher(1, hasher.clone()); h.extend(core::iter::once((start.clone(), 0.0))); h }, final_scores: { let mut h = HashMap::with_capacity_and_hasher(1, hasher.clone()); h.extend(core::iter::once((start.clone(), heuristic(&start)))); h }, visited: { let mut s = HashSet::with_capacity_and_hasher(1, hasher); s.extend(core::iter::once(start)); s }, cheapest_node: None, cheapest_cost: None, } } pub fn poll( &mut self, iters: usize, mut heuristic: impl FnMut(&S) -> f32, mut neighbors: impl FnMut(&S) -> I, mut transition: impl FnMut(&S, &S) -> f32, mut satisfied: impl FnMut(&S) -> bool, ) -> PathResult where I: Iterator, { let iter_limit = self.max_iters.min(self.iter + iters); while self.iter < iter_limit { if let Some(PathEntry { node, .. }) = self.potential_nodes.pop() { if satisfied(&node) { return PathResult::Path(self.reconstruct_path_to(node)); } else { for neighbor in neighbors(&node) { let node_cheapest = self.cheapest_scores.get(&node).unwrap_or(&f32::MAX); let neighbor_cheapest = self.cheapest_scores.get(&neighbor).unwrap_or(&f32::MAX); let cost = node_cheapest + transition(&node, &neighbor); if cost < *neighbor_cheapest { self.came_from.insert(neighbor.clone(), node.clone()); self.cheapest_scores.insert(neighbor.clone(), cost); let h = heuristic(&neighbor); let neighbor_cost = cost + h; self.final_scores.insert(neighbor.clone(), neighbor_cost); if self.cheapest_cost.map(|cc| h < cc).unwrap_or(true) { self.cheapest_node = Some(node.clone()); self.cheapest_cost = Some(h); }; if self.visited.insert(neighbor.clone()) { self.potential_nodes.push(PathEntry { node: neighbor, cost: neighbor_cost, }); } } } } } else { return PathResult::None( self.cheapest_node .clone() .map(|lc| self.reconstruct_path_to(lc)) .unwrap_or_default(), ); } self.iter += 1 } if self.iter >= self.max_iters { PathResult::Exhausted( self.cheapest_node .clone() .map(|lc| self.reconstruct_path_to(lc)) .unwrap_or_default(), ) } else { PathResult::Pending } } pub fn get_cheapest_cost(&self) -> Option { self.cheapest_cost } fn reconstruct_path_to(&mut self, end: S) -> Path { let mut path = vec![end.clone()]; let mut cnode = &end; while let Some(node) = self.came_from.get(cnode) { path.push(node.clone()); cnode = node; } path.into_iter().rev().collect() } }