diff --git a/src/bin/day15.rs b/src/bin/day15.rs index b5b92fc..b03b536 100644 --- a/src/bin/day15.rs +++ b/src/bin/day15.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; +use std::collections::BinaryHeap; use std::collections::HashMap; use std::collections::HashSet; use std::error::Error; @@ -28,6 +30,30 @@ fn get_risk(x: i32, y: i32, map: &Vec>) -> Option { } } +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +struct Point { + risk: u32, + x: i32, + y: i32, +} + +impl Ord for Point { + fn cmp(&self, other: &Self) -> Ordering { + other + .risk + .cmp(&self.risk) + .then_with(|| self.x.cmp(&other.x)) + .then_with(|| self.y.cmp(&other.y)) + } +} + +// `PartialOrd` needs to be implemented as well. +impl PartialOrd for Point { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + fn do_the_dijkstra( start_x: i32, start_y: i32, @@ -36,9 +62,16 @@ fn do_the_dijkstra( map: &Vec>, ) -> u32 { let mut unexplored: HashSet<(i32, i32)> = HashSet::new(); + let mut path_risk: HashMap<(i32, i32), u32> = HashMap::new(); path_risk.insert((start_x, start_y), 0); - //let mut predecessors: HashMap<(i32, i32), (i32, i32)> = HashMap::new(); + + let mut heap: BinaryHeap = BinaryHeap::new(); + heap.push(Point { + x: start_x, + y: start_y, + risk: 0, + }); for x in 0..target_x + 1 { for y in 0..target_y + 1 { @@ -47,47 +80,48 @@ fn do_the_dijkstra( } while !unexplored.is_empty() { - let mut cur_node: Option<(i32, i32)> = None; + let point = heap.pop().unwrap(); - let known_nodes: HashSet<(i32, i32)> = path_risk.keys().cloned().collect(); - let candidates: HashSet<(i32, i32)> = - unexplored.intersection(&known_nodes).cloned().collect(); + let cur_node = (point.x, point.y); + let cur_risk = point.risk; - for candidate in candidates.iter() { - if let Some(node) = cur_node { - if path_risk[candidate] < path_risk[&node] { - cur_node = Some(*candidate); - } - } else { - cur_node = Some(*candidate); - } + if cur_risk > path_risk[&cur_node] { + continue; } - let cur_node = cur_node.unwrap(); - println!("{:?} {}", cur_node, unexplored.len()); unexplored.remove(&cur_node); if cur_node.0 == target_x && cur_node.1 == target_y { println!("Found my target!"); - break; + return cur_risk; } for (dx, dy) in [(1, 0), (0, 1), (-1, 0), (0, -1)] { let neighboor = (cur_node.0 + dx, cur_node.1 + dy); - if let Some(enter_risk) = get_risk(neighboor.1, neighboor.0, &map) { + if let Some(enter_risk) = get_risk(neighboor.0, neighboor.1, &map) { let alt_risk = path_risk[&cur_node] + enter_risk; if !path_risk.contains_key(&neighboor) { path_risk.insert(neighboor, alt_risk); + heap.push(Point { + x: neighboor.0, + y: neighboor.1, + risk: alt_risk, + }); } else if alt_risk < path_risk[&neighboor] { path_risk.insert(neighboor, alt_risk); + heap.push(Point { + x: neighboor.0, + y: neighboor.1, + risk: alt_risk, + }); } } } } - path_risk[&(target_x, target_y)] + return 0; } fn main() -> Result<(), Box> {