diff --git a/src/conflict.rs b/src/conflict.rs index 3d121b6..fe4b093 100644 --- a/src/conflict.rs +++ b/src/conflict.rs @@ -13,10 +13,7 @@ use petgraph::{ use crate::solver::variable_map::VariableOrigin; use crate::{ - internal::{ - arena::ArenaId, - id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VersionSetId}, - }, + internal::id::{ClauseId, SolvableId, SolvableOrRootId, StringId, VersionSetId}, runtime::AsyncRuntime, solver::{clause::Clause, Solver}, DependencyProvider, Interner, Requirement, @@ -48,6 +45,8 @@ impl Conflict { &self, solver: &Solver, ) -> ConflictGraph { + let state = &solver.state; + let mut graph = DiGraph::::default(); let mut nodes: HashMap = HashMap::default(); let mut excluded_nodes: HashMap = HashMap::default(); @@ -56,14 +55,14 @@ impl Conflict { let unresolved_node = graph.add_node(ConflictNode::UnresolvedDependency); let mut last_node_by_name = HashMap::default(); - for clause_id in &self.clauses { - let clause = &solver.clauses.kinds[clause_id.to_usize()]; + for &clause_id in &self.clauses { + let clause = &state.clauses.kinds[clause_id]; match clause { Clause::InstallRoot => (), Clause::Excluded(solvable, reason) => { tracing::trace!("{solvable:?} is excluded"); let solvable = solvable - .as_solvable(&solver.variable_map) + .as_solvable(&state.variable_map) .expect("only solvables can be excluded"); let package_node = Self::add_node(&mut graph, &mut nodes, solvable.into()); @@ -80,7 +79,7 @@ impl Conflict { Clause::Learnt(..) => unreachable!(), &Clause::Requires(package_id, version_set_id) => { let solvable = package_id - .as_solvable_or_root(&solver.variable_map) + .as_solvable_or_root(&state.variable_map) .expect("only solvables can be excluded"); let package_node = Self::add_node(&mut graph, &mut nodes, solvable); @@ -112,10 +111,10 @@ impl Conflict { } &Clause::Lock(locked, forbidden) => { let locked_solvable = locked - .as_solvable(&solver.variable_map) + .as_solvable(&state.variable_map) .expect("only solvables can be excluded"); let forbidden_solvable = forbidden - .as_solvable(&solver.variable_map) + .as_solvable(&state.variable_map) .expect("only solvables can be excluded"); let node2_id = Self::add_node(&mut graph, &mut nodes, forbidden_solvable.into()); @@ -124,12 +123,12 @@ impl Conflict { } &Clause::ForbidMultipleInstances(instance1_id, instance2_id, _) => { let solvable1 = instance1_id - .as_solvable_or_root(&solver.variable_map) + .as_solvable_or_root(&state.variable_map) .expect("only solvables can be excluded"); let node1_id = Self::add_node(&mut graph, &mut nodes, solvable1); let VariableOrigin::ForbidMultiple(name) = - solver.variable_map.origin(instance2_id.variable()) + state.variable_map.origin(instance2_id.variable()) else { unreachable!("expected only forbid variables") }; @@ -145,10 +144,10 @@ impl Conflict { } &Clause::Constrains(package_id, dep_id, version_set_id) => { let package_solvable = package_id - .as_solvable_or_root(&solver.variable_map) + .as_solvable_or_root(&state.variable_map) .expect("only solvables can be excluded"); let dependency_solvable = dep_id - .as_solvable_or_root(&solver.variable_map) + .as_solvable_or_root(&state.variable_map) .expect("only solvables can be excluded"); let package_node = Self::add_node(&mut graph, &mut nodes, package_solvable); @@ -629,42 +628,6 @@ impl Indenter { } } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_indenter_without_top_level_indent() { - let indenter = Indenter::new(false); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), ""); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), "└─ "); - } - - #[test] - fn test_indenter_with_multiple_siblings() { - let indenter = Indenter::new(true); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), "└─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); - assert_eq!(indenter.get_indent(), " ├─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), " │ └─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::Last); - assert_eq!(indenter.get_indent(), " │ └─ "); - - let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); - assert_eq!(indenter.get_indent(), " │ ├─ "); - } -} - /// A struct implementing [`fmt::Display`] that generates a user-friendly /// representation of a conflict graph pub struct DisplayUnsat<'i, I: Interner> { @@ -1052,3 +1015,39 @@ impl<'i, I: Interner> fmt::Display for DisplayUnsat<'i, I> { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_indenter_without_top_level_indent() { + let indenter = Indenter::new(false); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), ""); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), "└─ "); + } + + #[test] + fn test_indenter_with_multiple_siblings() { + let indenter = Indenter::new(true); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), "└─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); + assert_eq!(indenter.get_indent(), " ├─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), " │ └─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::Last); + assert_eq!(indenter.get_indent(), " │ └─ "); + + let indenter = indenter.push_level_with_order(ChildOrder::HasRemainingSiblings); + assert_eq!(indenter.get_indent(), " │ ├─ "); + } +} diff --git a/src/internal/mapping.rs b/src/internal/mapping.rs index d0e8ea7..1e7a676 100644 --- a/src/internal/mapping.rs +++ b/src/internal/mapping.rs @@ -15,7 +15,7 @@ pub struct Mapping { _phantom: PhantomData, } -impl Default for Mapping { +impl Default for Mapping { fn default() -> Self { Self::new() } @@ -124,7 +124,7 @@ impl Mapping { .get_unchecked(chunk) .get_unchecked(offset) .as_ref() - .unwrap() + .unwrap_unchecked() } /// Get a specific value in the mapping without bound checks @@ -139,7 +139,7 @@ impl Mapping { .get_unchecked_mut(chunk) .get_unchecked_mut(offset) .as_mut() - .unwrap() + .unwrap_unchecked() } /// Returns the number of mapped items diff --git a/src/internal/mod.rs b/src/internal/mod.rs index 08a55f4..7f11dea 100644 --- a/src/internal/mod.rs +++ b/src/internal/mod.rs @@ -3,6 +3,3 @@ pub mod frozen_copy_map; pub mod id; pub mod mapping; pub mod small_vec; -mod unwrap_unchecked; - -pub use unwrap_unchecked::debug_expect_unchecked; diff --git a/src/internal/unwrap_unchecked.rs b/src/internal/unwrap_unchecked.rs deleted file mode 100644 index eee6970..0000000 --- a/src/internal/unwrap_unchecked.rs +++ /dev/null @@ -1,13 +0,0 @@ -/// An unsafe method that unwraps an option without checking if it is `None` in -/// release mode but does check the value in debug mode. -#[track_caller] -pub unsafe fn debug_expect_unchecked(opt: Option, _msg: &str) -> T { - #[cfg(debug_assertions)] - { - opt.expect(_msg) - } - #[cfg(not(debug_assertions))] - { - opt.unwrap_unchecked() - } -} diff --git a/src/solver/clause.rs b/src/solver/clause.rs index f034130..4074e5b 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -632,7 +632,7 @@ mod test { #[test] fn test_literal_eval() { - let mut decision_map = DecisionMap::new(); + let mut decision_map = DecisionMap::default(); let literal = VariableId::root().positive(); let negated_literal = VariableId::root().negative(); @@ -653,7 +653,7 @@ mod test { #[test] fn test_requires_with_and_without_conflict() { - let mut decisions = DecisionTracker::new(); + let mut decisions = DecisionTracker::default(); let parent = VariableId::from_usize(1); let candidate1 = VariableId::from_usize(2); @@ -671,17 +671,11 @@ mod test { clause.as_ref().unwrap().watched_literals[0].variable(), parent ); - assert_eq!( - clause.unwrap().watched_literals[1].variable(), - candidate1.into() - ); + assert_eq!(clause.unwrap().watched_literals[1].variable(), candidate1); // No conflict, still one candidate available decisions - .try_add_decision( - Decision::new(candidate1.into(), false, ClauseId::from_usize(0)), - 1, - ) + .try_add_decision(Decision::new(candidate1, false, ClauseId::from_usize(0)), 1) .unwrap(); let (clause, conflict, _kind) = WatchedLiterals::requires( parent, @@ -696,13 +690,13 @@ mod test { ); assert_eq!( clause.as_ref().unwrap().watched_literals[1].variable(), - candidate2.into() + candidate2 ); // Conflict, no candidates available decisions .try_add_decision( - Decision::new(candidate2.into(), false, ClauseId::install_root()), + Decision::new(candidate2, false, ClauseId::install_root()), 1, ) .unwrap(); @@ -719,7 +713,7 @@ mod test { ); assert_eq!( clause.as_ref().unwrap().watched_literals[1].variable(), - candidate1.into() + candidate1 ); // Panic @@ -740,7 +734,7 @@ mod test { #[test] fn test_constrains_with_and_without_conflict() { - let mut decisions = DecisionTracker::new(); + let mut decisions = DecisionTracker::default(); let parent = VariableId::from_usize(1); let forbidden = VariableId::from_usize(2); diff --git a/src/solver/decision_map.rs b/src/solver/decision_map.rs index 7de521b..ccdc585 100644 --- a/src/solver/decision_map.rs +++ b/src/solver/decision_map.rs @@ -37,17 +37,12 @@ impl DecisionAndLevel { } /// A map of the assignments to solvables. +#[derive(Default)] pub(crate) struct DecisionMap { map: Vec, } impl DecisionMap { - pub fn new() -> Self { - Self { - map: Default::default(), - } - } - pub fn reset(&mut self, variable_id: VariableId) { let variable_id = variable_id.to_usize(); if variable_id < self.map.len() { diff --git a/src/solver/decision_tracker.rs b/src/solver/decision_tracker.rs index 8fcc8cd..82b1d4b 100644 --- a/src/solver/decision_tracker.rs +++ b/src/solver/decision_tracker.rs @@ -3,6 +3,7 @@ use crate::solver::{decision::Decision, decision_map::DecisionMap}; /// Tracks the assignments to solvables, keeping a log that can be used to backtrack, and a map that /// can be used to query the current value assigned +#[derive(Default)] pub(crate) struct DecisionTracker { map: DecisionMap, stack: Vec, @@ -10,18 +11,8 @@ pub(crate) struct DecisionTracker { } impl DecisionTracker { - pub(crate) fn new() -> Self { - Self { - map: DecisionMap::new(), - stack: Vec::new(), - propagate_index: 0, - } - } - pub(crate) fn clear(&mut self) { - self.map = DecisionMap::new(); - self.stack = Vec::new(); - self.propagate_index = 0; + *self = Default::default(); } #[inline(always)] diff --git a/src/solver/mod.rs b/src/solver/mod.rs index 8c0e026..92ddd93 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -128,15 +128,16 @@ impl> Problem { #[derive(Default)] pub(crate) struct Clauses { - pub(crate) kinds: Vec, - watched_literals: Vec>, + pub(crate) kinds: Arena, + watched_literals: Mapping, } impl Clauses { pub fn alloc(&mut self, watched_literals: Option, kind: Clause) -> ClauseId { - let id = ClauseId::from_usize(self.kinds.len()); - self.kinds.push(kind); - self.watched_literals.push(watched_literals); + let id = self.kinds.alloc(kind); + if let Some(watched_literals) = watched_literals { + self.watched_literals.insert(id, watched_literals); + } id } } @@ -145,9 +146,27 @@ type RequirementCandidateVariables = Vec>; /// Drives the SAT solving process. pub struct Solver { + /// The runtime to use for async operations. pub(crate) async_runtime: RT, + + /// A cache that stores request to the dependency provider. pub(crate) cache: SolverCache, + /// Holds the current state of the solver. + pub(crate) state: SolverState, + + /// The activity add factor. This is a value that is added to the activity + /// score of each package that is part of a conflict. + activity_add: f32, + + /// The activity decay factor. This is a value between 0 and 1 with which + /// the activity scores of each package are multiplied when a conflict is + /// detected. + activity_decay: f32, +} + +#[derive(Default)] +pub(crate) struct SolverState { pub(crate) clauses: Clauses, requires_clauses: IndexMap, ahash::RandomState>, watches: WatchMap, @@ -171,23 +190,14 @@ pub struct Solver { decision_tracker: DecisionTracker, + /// Activity score per package. + name_activity: Vec, + /// The [`Requirement`]s that must be installed as part of the solution. root_requirements: Vec, /// Additional constraints imposed by the root. root_constraints: Vec, - - /// Activity score per package. - name_activity: Vec, - - /// The activity add factor. This is a value that is added to the activity - /// score of each package that is part of a conflict. - activity_add: f32, - - /// The activity decay factor. This is a value between 0 and 1 with which - /// the activity scores of each package are multiplied when a conflict is - /// detected. - activity_decay: f32, } impl Solver { @@ -197,23 +207,7 @@ impl Solver { Self { cache: SolverCache::new(provider), async_runtime: NowOrNeverRuntime, - clauses: Clauses::default(), - variable_map: VariableMap::default(), - requires_clauses: Default::default(), - requirement_to_sorted_candidates: FrozenMap::default(), - watches: WatchMap::new(), - negative_assertions: Default::default(), - learnt_clauses: Arena::new(), - learnt_why: Mapping::new(), - learnt_clause_ids: Vec::new(), - decision_tracker: DecisionTracker::new(), - root_requirements: Default::default(), - root_constraints: Default::default(), - clauses_added_for_package: Default::default(), - clauses_added_for_solvable: Default::default(), - forbidden_clauses_added: Default::default(), - name_activity: Default::default(), - + state: SolverState::default(), activity_add: 1.0, activity_decay: 0.95, } @@ -277,24 +271,9 @@ impl Solver { Solver { async_runtime: runtime, cache: self.cache, - clauses: self.clauses, - variable_map: self.variable_map, - requires_clauses: self.requires_clauses, - requirement_to_sorted_candidates: self.requirement_to_sorted_candidates, - watches: self.watches, - negative_assertions: self.negative_assertions, - learnt_clauses: self.learnt_clauses, - learnt_why: self.learnt_why, - learnt_clause_ids: self.learnt_clause_ids, - clauses_added_for_package: self.clauses_added_for_package, - clauses_added_for_solvable: self.clauses_added_for_solvable, - forbidden_clauses_added: self.forbidden_clauses_added, - decision_tracker: self.decision_tracker, - root_requirements: self.root_requirements, - root_constraints: self.root_constraints, - name_activity: self.name_activity, - activity_add: self.activity_add, + state: self.state, activity_decay: self.activity_decay, + activity_add: self.activity_add, } } @@ -341,19 +320,18 @@ impl Solver { &mut self, problem: Problem>, ) -> Result, UnsolvableOrCancelled> { - self.decision_tracker.clear(); - self.negative_assertions.clear(); - self.learnt_clauses.clear(); - self.learnt_why = Mapping::new(); - self.clauses = Clauses::default(); - self.root_requirements = problem.requirements; - self.root_constraints = problem.constraints; + // Re-initialize the solver state. + self.state = SolverState { + root_requirements: problem.requirements, + root_constraints: problem.constraints, + ..SolverState::default() + }; // The first clause will always be the install root clause. Here we verify that // this is indeed the case. let root_clause = { let (state, kind) = WatchedLiterals::root(); - self.clauses.alloc(state, kind) + self.state.clauses.alloc(state, kind) }; assert_eq!(root_clause, ClauseId::install_root()); @@ -364,9 +342,10 @@ impl Solver { ); for additional in problem.soft_requirements { - let additional_var = self.variable_map.intern_solvable(additional); + let additional_var = self.state.variable_map.intern_solvable(additional); if self + .state .decision_tracker .assigned_value(additional_var) .is_none() @@ -375,20 +354,7 @@ impl Solver { } } - Ok(self.chosen_solvables().collect()) - } - - /// Returns the solvables that the solver has chosen to include in the - /// solution so far. - fn chosen_solvables(&self) -> impl Iterator + '_ { - self.decision_tracker.stack().filter_map(|d| { - if d.value { - d.variable.as_solvable(&self.variable_map) - } else { - // Ignore things that are set to false - None - } - }) + Ok(self.state.chosen_solvables().collect()) } /// Run the CDCL algorithm to solve the SAT problem @@ -426,10 +392,11 @@ impl Solver { /// returns [`UnsolvableOrCancelled::Cancelled`] as an `Err`. fn run_sat(&mut self, root_solvable: SolvableOrRootId) -> Result { let starting_level = self + .state .decision_tracker .stack() .next_back() - .map(|decision| self.decision_tracker.level(decision.variable)) + .map(|decision| self.state.decision_tracker.level(decision.variable)) .unwrap_or(0); let mut level = starting_level; @@ -456,10 +423,13 @@ impl Solver { "╤══ Install {} at level {level}", root_solvable.display(self.provider()) ); - self.decision_tracker + self.state + .decision_tracker .try_add_decision( Decision::new( - self.variable_map.intern_solvable_or_root(root_solvable), + self.state + .variable_map + .intern_solvable_or_root(root_solvable), true, ClauseId::install_root(), ), @@ -471,15 +441,7 @@ impl Solver { let output = self.async_runtime.block_on(add_clauses_for_solvables( [root_solvable], &self.cache, - &mut self.clauses, - &self.decision_tracker, - &mut self.variable_map, - &mut self.clauses_added_for_solvable, - &mut self.clauses_added_for_package, - &mut self.forbidden_clauses_added, - &mut self.requirement_to_sorted_candidates, - &self.root_requirements, - &self.root_constraints, + &mut self.state, ))?; if let Err(clause_id) = self.process_add_clause_output(output) { return self.run_sat_process_unsolvable( @@ -511,9 +473,9 @@ impl Solver { // The conflict was caused because new clauses have been added dynamically. // We need to start over. tracing::debug!("├─ added clause {clause} introduces a conflict which invalidates the partial solution", - clause=self.clauses.kinds[clause_id.to_usize()].display(&self.variable_map, self.provider())); + clause=self.state.clauses.kinds[clause_id].display(&self.state.variable_map, self.provider())); level = starting_level; - self.decision_tracker.undo_until(starting_level); + self.state.decision_tracker.undo_until(starting_level); continue; } } @@ -536,17 +498,22 @@ impl Solver { // get any dependencies. If we find any such solvable it means we // did not arrive at the full solution yet. let new_solvables: Vec<_> = self + .state .decision_tracker .stack() // Filter only decisions that led to a positive assignment .filter(|d| d.value) // Select solvables for which we do not yet have dependencies .filter(|d| { - let Some(solvable_or_root) = d.variable.as_solvable_or_root(&self.variable_map) + let Some(solvable_or_root) = + d.variable.as_solvable_or_root(&self.state.variable_map) else { return false; }; - !self.clauses_added_for_solvable.contains(&solvable_or_root) + !self + .state + .clauses_added_for_solvable + .contains(&solvable_or_root) }) .map(|d| (d.variable, d.derived_from)) .collect(); @@ -568,9 +535,9 @@ impl Solver { .copied() .format_with("\n- ", |(id, derived_from), f| f(&format_args!( "{} (derived from {})", - id.display(&self.variable_map, self.provider()), - self.clauses.kinds[derived_from.to_usize()] - .display(&self.variable_map, self.provider()), + id.display(&self.state.variable_map, self.provider()), + self.state.clauses.kinds[derived_from] + .display(&self.state.variable_map, self.provider()), ))) ); tracing::debug!("===="); @@ -580,32 +547,25 @@ impl Solver { new_solvables .iter() .filter_map(|(variable, _)| { - self.variable_map + self.state + .variable_map .origin(*variable) .as_solvable() .map(Into::into) }) .collect::>(), &self.cache, - &mut self.clauses, - &self.decision_tracker, - &mut self.variable_map, - &mut self.clauses_added_for_solvable, - &mut self.clauses_added_for_package, - &mut self.forbidden_clauses_added, - &mut self.requirement_to_sorted_candidates, - &self.root_requirements, - &self.root_constraints, + &mut self.state, ))?; // Serially process the outputs, to reduce the need for synchronization for &clause_id in &output.conflicting_clauses { tracing::debug!("├─ Added clause {clause} introduces a conflict which invalidates the partial solution", - clause=self.clauses.kinds[clause_id.to_usize()].display(&self.variable_map, self.provider())); + clause=self.state.clauses.kinds[clause_id].display(&self.state.variable_map, self.provider())); } if let Err(_first_conflicting_clause_id) = self.process_add_clause_output(output) { - self.decision_tracker.undo_until(starting_level); + self.state.decision_tracker.undo_until(starting_level); level = starting_level; } } @@ -630,11 +590,14 @@ impl Solver { self.analyze_unsolvable(clause_id), )) } else { - self.decision_tracker.undo_until(starting_level); - self.decision_tracker + self.state.decision_tracker.undo_until(starting_level); + self.state + .decision_tracker .try_add_decision( Decision::new( - self.variable_map.intern_solvable_or_root(solvable_or_root), + self.state + .variable_map + .intern_solvable_or_root(solvable_or_root), false, ClauseId::install_root(), ), @@ -646,21 +609,25 @@ impl Solver { } fn process_add_clause_output(&mut self, mut output: AddClauseOutput) -> Result<(), ClauseId> { - let watched_literals = &mut self.clauses.watched_literals; + let watched_literals = &mut self.state.clauses.watched_literals; for clause_id in output.clauses_to_watch { - let watched_literals = watched_literals[clause_id.to_usize()] - .as_mut() + let watched_literals = watched_literals + .get_mut(clause_id) .expect("attempting to watch a clause without watches!"); - self.watches.start_watching(watched_literals, clause_id); + self.state + .watches + .start_watching(watched_literals, clause_id); } for (solvable_id, requirement, clause_id) in output.new_requires_clauses { - self.requires_clauses + self.state + .requires_clauses .entry(solvable_id) .or_default() .push((requirement, clause_id)); } - self.negative_assertions + self.state + .negative_assertions .append(&mut output.negative_assertions); if let Some(max_name_idx) = output @@ -669,8 +636,8 @@ impl Solver { .map(|name_id| name_id.to_usize()) .max() { - if self.name_activity.len() <= max_name_idx { - self.name_activity.resize(max_name_idx + 1, 0.0); + if self.state.name_activity.len() <= max_name_idx { + self.state.name_activity.resize(max_name_idx + 1, 0.0); } } @@ -702,9 +669,9 @@ impl Solver { tracing::info!( "╒══ Install {} at level {level} (derived from {})", - candidate.display(&self.variable_map, self.provider()), - self.clauses.kinds[clause_id.to_usize()] - .display(&self.variable_map, self.provider()) + candidate.display(&self.state.variable_map, self.provider()), + self.state.clauses.kinds[clause_id] + .display(&self.state.variable_map, self.provider()) ); // Propagate the decision @@ -767,7 +734,7 @@ impl Solver { } let mut best_decision: Option = None; - for (&solvable_id, requirements) in self.requires_clauses.iter() { + for (&solvable_id, requirements) in self.state.requires_clauses.iter() { let is_explicit_requirement = solvable_id == VariableId::root(); if let Some(best_decision) = &best_decision { // If we already have an explicit requirement, there is no need to evaluate @@ -778,7 +745,7 @@ impl Solver { } // Consider only clauses in which we have decided to install the solvable - if self.decision_tracker.assigned_value(solvable_id) != Some(true) { + if self.state.decision_tracker.assigned_value(solvable_id) != Some(true) { continue; } @@ -786,7 +753,7 @@ impl Solver { let mut candidate = ControlFlow::Break(()); // Get the candidates for the individual version sets. - let version_set_candidates = &self.requirement_to_sorted_candidates[deps]; + let version_set_candidates = &self.state.requirement_to_sorted_candidates[deps]; // Iterate over all version sets in the requirement and find the first version // set that we can act on, or if a single candidate (from any version set) makes @@ -807,7 +774,8 @@ impl Solver { _ => None, }, |first_candidate, &candidate| { - let assigned_value = self.decision_tracker.assigned_value(candidate); + let assigned_value = + self.state.decision_tracker.assigned_value(candidate); ControlFlow::Continue(match assigned_value { Some(true) => { // This candidate has already been assigned so the clause is @@ -843,7 +811,7 @@ impl Solver { None => { // We found the first candidate that has not been assigned // yet. - let package_activity = self.name_activity[self + let package_activity = self.state.name_activity[self .provider() .version_set_name(version_set) .to_usize()]; @@ -924,9 +892,9 @@ impl Solver { { tracing::trace!( "deciding to assign {}, ({}, {} activity score, {} possible candidates)", - candidate.display(&self.variable_map, self.provider()), - self.clauses.kinds[clause_id.to_usize()] - .display(&self.variable_map, self.provider()), + candidate.display(&self.state.variable_map, self.provider()), + self.state.clauses.kinds[*clause_id] + .display(&self.state.variable_map, self.provider()), package_activity, candidate_count, ); @@ -964,7 +932,8 @@ impl Solver { ) -> Result { level += 1; - self.decision_tracker + self.state + .decision_tracker .try_add_decision(Decision::new(solvable, true, clause_id), level) .expect("bug: solvable was already decided!"); @@ -1006,31 +975,31 @@ impl Solver { { tracing::info!( "├┬ Propagation conflicted: could not set {solvable} to {attempted_value}", - solvable = conflicting_solvable.display(&self.variable_map, self.provider()), + solvable = conflicting_solvable.display(&self.state.variable_map, self.provider()), ); tracing::info!( "││ During unit propagation for clause: {}", - self.clauses.kinds[conflicting_clause.to_usize()] - .display(&self.variable_map, self.provider()) + self.state.clauses.kinds[conflicting_clause] + .display(&self.state.variable_map, self.provider()) ); tracing::info!( "││ Previously decided value: {}. Derived from: {}", !attempted_value, - self.clauses.kinds[self + self.state.clauses.kinds[self + .state .decision_tracker .find_clause_for_assignment(conflicting_solvable) - .unwrap() - .to_usize()] - .display(&self.variable_map, self.provider()), + .unwrap()] + .display(&self.state.variable_map, self.provider()), ); } if level == 1 { - for decision in self.decision_tracker.stack() { + for decision in self.state.decision_tracker.stack() { let clause_id = decision.derived_from; - let clause = self.clauses.kinds[clause_id.to_usize()]; - let level = self.decision_tracker.level(decision.variable); + let clause = self.state.clauses.kinds[clause_id]; + let level = self.state.decision_tracker.level(decision.variable); let action = if decision.value { "install" } else { "forbid" }; if let Clause::ForbidMultipleInstances(..) = clause { @@ -1042,9 +1011,9 @@ impl Solver { "* ({level}) {action} {}. Reason: {}", decision .variable - .display(&self.variable_map, self.provider()), - self.clauses.kinds[decision.derived_from.to_usize()] - .display(&self.variable_map, self.provider()), + .display(&self.state.variable_map, self.provider()), + self.state.clauses.kinds[decision.derived_from] + .display(&self.state.variable_map, self.provider()), ); } @@ -1059,7 +1028,8 @@ impl Solver { // Optimization: propagate right now, since we know that the clause is a unit // clause let decision = literal.satisfying_value(); - self.decision_tracker + self.state + .decision_tracker .try_add_decision( Decision::new(literal.variable(), decision, learned_clause_id), level, @@ -1069,7 +1039,7 @@ impl Solver { "│├ Propagate after learn: {} = {decision}", literal .variable() - .display(&self.variable_map, self.provider()), + .display(&self.state.variable_map, self.provider()), ); tracing::info!("│└ Backtracked from {old_level} -> {level}"); @@ -1109,24 +1079,25 @@ impl Solver { // an error is returned. let interner = self.cache.provider(); - let clause_kinds = &self.clauses.kinds; + let clause_kinds = &self.state.clauses.kinds; - while let Some(decision) = self.decision_tracker.next_unpropagated() { + while let Some(decision) = self.state.decision_tracker.next_unpropagated() { let watched_literal = Literal::new(decision.variable, decision.value); debug_assert!( - watched_literal.eval(self.decision_tracker.map()) == Some(false), + watched_literal.eval(self.state.decision_tracker.map()) == Some(false), "we are only watching literals that are turning false" ); // Propagate, iterating through the linked list of clauses that watch this // solvable let mut next_cursor = self + .state .watches - .cursor(&mut self.clauses.watched_literals, watched_literal); + .cursor(&mut self.state.clauses.watched_literals, watched_literal); while let Some(cursor) = next_cursor.take() { let clause_id = cursor.clause_id(); - let clause = &clause_kinds[clause_id.to_usize()]; + let clause = &clause_kinds[clause_id]; let watch_index = cursor.watch_index(); // If the other literal the current clause is watching is already true, we can @@ -1134,14 +1105,14 @@ impl Solver { let watched_literals = cursor.watched_literals(); let other_watched_literal = watched_literals.watched_literals[1 - cursor.watch_index()]; - if other_watched_literal.eval(self.decision_tracker.map()) == Some(true) { + if other_watched_literal.eval(self.state.decision_tracker.map()) == Some(true) { // Continue with the next clause in the linked list. next_cursor = cursor.next(); } else if let Some(literal) = watched_literals.next_unwatched_literal( clause, - &self.learnt_clauses, - &self.requirement_to_sorted_candidates, - self.decision_tracker.map(), + &self.state.learnt_clauses, + &self.state.requirement_to_sorted_candidates, + self.state.decision_tracker.map(), watch_index, ) { // Update the watch to point to the new literal @@ -1150,6 +1121,7 @@ impl Solver { // We could not find another literal to watch, which means the remaining // watched literal must be set to true. let decided = self + .state .decision_tracker .try_add_decision( Decision::new( @@ -1176,9 +1148,9 @@ impl Solver { "├ Propagate {} = {}. {}", other_watched_literal .variable() - .display(&self.variable_map, interner), + .display(&self.state.variable_map, interner), other_watched_literal.satisfying_value(), - clause.display(&self.variable_map, interner) + clause.display(&self.state.variable_map, interner) ); } } @@ -1197,9 +1169,10 @@ impl Solver { /// (assertions are clauses that consist of a single literal, and /// therefore do not have watches). fn decide_assertions(&mut self, level: u32) -> Result<(), PropagationError> { - for &(solvable_id, clause_id) in &self.negative_assertions { + for &(solvable_id, clause_id) in &self.state.negative_assertions { let value = false; let decided = self + .state .decision_tracker .try_add_decision(Decision::new(solvable_id, value, clause_id), level) .map_err(|_| PropagationError::Conflict(solvable_id, value, clause_id))?; @@ -1207,7 +1180,7 @@ impl Solver { if decided { tracing::trace!( "Negative assertions derived from other rules: Propagate assertion {} = {}", - solvable_id.display(&self.variable_map, self.provider()), + solvable_id.display(&self.state.variable_map, self.provider()), value ); } @@ -1218,14 +1191,14 @@ impl Solver { /// Add decisions derived from learnt clauses. fn decide_learned(&mut self, level: u32) -> Result<(), PropagationError> { // Assertions derived from learnt rules - for learn_clause_idx in 0..self.learnt_clause_ids.len() { - let clause_id = self.learnt_clause_ids[learn_clause_idx]; - let clause = self.clauses.kinds[clause_id.to_usize()]; + for learn_clause_idx in 0..self.state.learnt_clause_ids.len() { + let clause_id = self.state.learnt_clause_ids[learn_clause_idx]; + let clause = self.state.clauses.kinds[clause_id]; let Clause::Learnt(learnt_index) = clause else { unreachable!(); }; - let literals = &self.learnt_clauses[learnt_index]; + let literals = &self.state.learnt_clauses[learnt_index]; if literals.len() > 1 { continue; } @@ -1236,6 +1209,7 @@ impl Solver { let decision = literal.satisfying_value(); let decided = self + .state .decision_tracker .try_add_decision( Decision::new(literal.variable(), decision, clause_id), @@ -1248,7 +1222,7 @@ impl Solver { "├─ Propagate assertion {} = {}", literal .variable() - .display(&self.variable_map, self.provider()), + .display(&self.state.variable_map, self.provider()), decision ); } @@ -1261,13 +1235,13 @@ impl Solver { /// Because learnt clauses are not relevant for the user, they are not added /// to the [`Conflict`]. Instead, we report the clauses that caused them. fn analyze_unsolvable_clause( - clauses: &[Clause], + clauses: &Arena, learnt_why: &Mapping>, clause_id: ClauseId, conflict: &mut Conflict, seen: &mut HashSet, ) { - let clause = &clauses[clause_id.to_usize()]; + let clause = &clauses[clause_id]; match clause { Clause::Learnt(learnt_clause_id) => { if !seen.insert(clause_id) { @@ -1288,8 +1262,8 @@ impl Solver { /// Create a [`Conflict`] based on the id of the clause that triggered an /// unrecoverable conflict fn analyze_unsolvable(&mut self, clause_id: ClauseId) -> Conflict { - let last_decision = self.decision_tracker.stack().last().unwrap(); - let highest_level = self.decision_tracker.level(last_decision.variable); + let last_decision = self.state.decision_tracker.stack().last().unwrap(); + let highest_level = self.state.decision_tracker.level(last_decision.variable); debug_assert_eq!(highest_level, 1); let mut conflict = Conflict::default(); @@ -1297,9 +1271,9 @@ impl Solver { tracing::info!("=== ANALYZE UNSOLVABLE"); let mut involved = HashSet::default(); - self.clauses.kinds[clause_id.to_usize()].visit_literals( - &self.learnt_clauses, - &self.requirement_to_sorted_candidates, + self.state.clauses.kinds[clause_id].visit_literals( + &self.state.learnt_clauses, + &self.state.requirement_to_sorted_candidates, |literal| { involved.insert(literal.variable()); }, @@ -1307,14 +1281,14 @@ impl Solver { let mut seen = HashSet::default(); Self::analyze_unsolvable_clause( - &self.clauses.kinds, - &self.learnt_why, + &self.state.clauses.kinds, + &self.state.learnt_why, clause_id, &mut conflict, &mut seen, ); - for decision in self.decision_tracker.stack().rev() { + for decision in self.state.decision_tracker.stack().rev() { if decision.variable.is_root() { continue; } @@ -1328,18 +1302,18 @@ impl Solver { assert_ne!(why, ClauseId::install_root()); Self::analyze_unsolvable_clause( - &self.clauses.kinds, - &self.learnt_why, + &self.state.clauses.kinds, + &self.state.learnt_why, why, &mut conflict, &mut seen, ); - self.clauses.kinds[why.to_usize()].visit_literals( - &self.learnt_clauses, - &self.requirement_to_sorted_candidates, + self.state.clauses.kinds[why].visit_literals( + &self.state.learnt_clauses, + &self.state.requirement_to_sorted_candidates, |literal| { - if literal.eval(self.decision_tracker.map()) == Some(true) { + if literal.eval(self.state.decision_tracker.map()) == Some(true) { assert_eq!(literal.variable(), decision.variable); } else { involved.insert(literal.variable()); @@ -1378,13 +1352,13 @@ impl Solver { let mut s_value; let mut learnt_why = Vec::new(); let mut first_iteration = true; - let clause_kinds = &self.clauses.kinds; + let clause_kinds = &self.state.clauses.kinds; loop { learnt_why.push(clause_id); - clause_kinds[clause_id.to_usize()].visit_literals( - &self.learnt_clauses, - &self.requirement_to_sorted_candidates, + clause_kinds[clause_id].visit_literals( + &self.state.learnt_clauses, + &self.state.requirement_to_sorted_candidates, |literal| { if !first_iteration && literal.variable() == conflicting_solvable { // We are only interested in the causes of the conflict, so we ignore the @@ -1397,13 +1371,14 @@ impl Solver { return; } - let decision_level = self.decision_tracker.level(literal.variable()); + let decision_level = self.state.decision_tracker.level(literal.variable()); if decision_level == current_level { causes_at_current_level += 1; } else if current_level > 1 { let learnt_literal = Literal::new( literal.variable(), - self.decision_tracker + self.state + .decision_tracker .assigned_value(literal.variable()) .unwrap(), ); @@ -1419,7 +1394,7 @@ impl Solver { // Select next literal to look at loop { - let (last_decision, last_decision_level) = self.decision_tracker.undo_last(); + let (last_decision, last_decision_level) = self.state.decision_tracker.undo_last(); conflicting_solvable = last_decision.variable; s_value = last_decision.value; @@ -1447,23 +1422,24 @@ impl Solver { for literal in &learnt { let name_id = literal .variable() - .as_solvable(&self.variable_map) + .as_solvable(&self.state.variable_map) .map(|s| self.provider().solvable_name(s)); if let Some(name_id) = name_id { - self.name_activity[name_id.to_usize()] += self.activity_add; + self.state.name_activity[name_id.to_usize()] += self.activity_add; } } // Add the clause - let learnt_id = self.learnt_clauses.alloc(learnt.clone()); - self.learnt_why.insert(learnt_id, learnt_why); + let learnt_id = self.state.learnt_clauses.alloc(learnt.clone()); + self.state.learnt_why.insert(learnt_id, learnt_why); let (watched_literals, kind) = WatchedLiterals::learnt(learnt_id, &learnt); - let clause_id = self.clauses.alloc(watched_literals, kind); - self.learnt_clause_ids.push(clause_id); - if let Some(watched_literals) = self.clauses.watched_literals[clause_id.to_usize()].as_mut() - { - self.watches.start_watching(watched_literals, clause_id); + let clause_id = self.state.clauses.alloc(watched_literals, kind); + self.state.learnt_clause_ids.push(clause_id); + if let Some(watched_literals) = self.state.clauses.watched_literals.get_mut(clause_id) { + self.state + .watches + .start_watching(watched_literals, clause_id); } tracing::debug!("│├ Learnt disjunction:",); @@ -1471,13 +1447,14 @@ impl Solver { tracing::debug!( "││ - {}{}", if lit.negate() { "NOT " } else { "" }, - lit.variable().display(&self.variable_map, self.provider()), + lit.variable() + .display(&self.state.variable_map, self.provider()), ); } // Should revert at most to the root level let target_level = back_track_to.max(1); - self.decision_tracker.undo_until(target_level); + self.state.decision_tracker.undo_until(target_level); self.decay_activity_scores(); @@ -1487,7 +1464,7 @@ impl Solver { /// Decays the activity scores of all packages in the solver. This function /// is caleld after each conflict. fn decay_activity_scores(&mut self) { - for activity in &mut self.name_activity { + for activity in &mut self.state.name_activity { *activity *= self.activity_decay; } } @@ -1508,19 +1485,7 @@ impl Solver { async fn add_clauses_for_solvables( solvable_ids: impl IntoIterator, cache: &SolverCache, - clauses: &mut Clauses, - decision_tracker: &DecisionTracker, - variable_map: &mut VariableMap, - clauses_added_for_solvable: &mut HashSet, - clauses_added_for_package: &mut HashSet, - forbidden_clauses_added: &mut HashMap>, - requirement_to_sorted_candidates: &mut FrozenMap< - Requirement, - RequirementCandidateVariables, - ahash::RandomState, - >, - root_requirements: &[Requirement], - root_constraints: &[VersionSetId], + state: &mut SolverState, ) -> Result> { let mut output = AddClauseOutput::default(); @@ -1551,7 +1516,7 @@ async fn add_clauses_for_solvables( let mut pending_solvables = vec![]; { for solvable_id in solvable_ids { - if clauses_added_for_solvable.insert(solvable_id) { + if state.clauses_added_for_solvable.insert(solvable_id) { pending_solvables.push(solvable_id); } } @@ -1583,8 +1548,8 @@ async fn add_clauses_for_solvables( ready(Ok(TaskResult::Dependencies { solvable_id: solvable_or_root, dependencies: Dependencies::Known(KnownDependencies { - requirements: root_requirements.to_vec(), - constrains: root_constraints.to_vec(), + requirements: state.root_requirements.to_vec(), + constrains: state.root_constraints.to_vec(), }), })) .right_future() @@ -1611,8 +1576,8 @@ async fn add_clauses_for_solvables( // Allocate a variable for the solvable let variable = match solvable_id.solvable() { - Some(solvable_id) => variable_map.intern_solvable(solvable_id), - None => variable_map.root(), + Some(solvable_id) => state.variable_map.intern_solvable(solvable_id), + None => state.variable_map.root(), }; let (requirements, constrains) = match dependencies { @@ -1621,15 +1586,15 @@ async fn add_clauses_for_solvables( // There is no information about the solvable's dependencies, so we add // an exclusion clause for it - let (state, kind) = WatchedLiterals::exclude(variable, reason); - let clause_id = clauses.alloc(state, kind); + let (watched_literals, kind) = WatchedLiterals::exclude(variable, reason); + let clause_id = state.clauses.alloc(watched_literals, kind); // Exclusions are negative assertions, tracked outside the watcher // system output.negative_assertions.push((variable, clause_id)); // There might be a conflict now - if decision_tracker.assigned_value(variable) == Some(true) { + if state.decision_tracker.assigned_value(variable) == Some(true) { output.conflicting_clauses.push(clause_id); } @@ -1643,7 +1608,7 @@ async fn add_clauses_for_solvables( .chain(constrains.iter().copied()) { let dependency_name = cache.provider().version_set_name(version_set_id); - if clauses_added_for_package.insert(dependency_name) { + if state.clauses_added_for_package.insert(dependency_name) { tracing::trace!( "┝━ Adding clauses for package '{}'", cache.provider().display_name(dependency_name), @@ -1721,15 +1686,17 @@ async fn add_clauses_for_solvables( // If there is a locked solvable, forbid other solvables. if let Some(locked_solvable_id) = package_candidates.locked { - let locked_solvable_var = variable_map.intern_solvable(locked_solvable_id); + let locked_solvable_var = + state.variable_map.intern_solvable(locked_solvable_id); for &other_candidate in candidates { if other_candidate != locked_solvable_id { - let other_candidate_var = variable_map.intern_solvable(other_candidate); + let other_candidate_var = + state.variable_map.intern_solvable(other_candidate); let (watched_literals, kind) = WatchedLiterals::lock(locked_solvable_var, other_candidate_var); - let clause_id = clauses.alloc(watched_literals, kind); + let clause_id = state.clauses.alloc(watched_literals, kind); - debug_assert!(clauses.watched_literals[clause_id.to_usize()].is_some()); + debug_assert!(state.clauses.watched_literals.get(clause_id).is_some()); output.clauses_to_watch.push(clause_id); } } @@ -1737,15 +1704,17 @@ async fn add_clauses_for_solvables( // Add a clause for solvables that are externally excluded. for (solvable, reason) in package_candidates.excluded.iter().copied() { - let solvable_var = variable_map.intern_solvable(solvable); + let solvable_var = state.variable_map.intern_solvable(solvable); let (watched_literals, kind) = WatchedLiterals::exclude(solvable_var, reason); - let clause_id = clauses.alloc(watched_literals, kind); + let clause_id = state.clauses.alloc(watched_literals, kind); // Exclusions are negative assertions, tracked outside the watcher system output.negative_assertions.push((solvable_var, clause_id)); // Conflicts should be impossible here - debug_assert!(decision_tracker.assigned_value(solvable_var) != Some(true)); + debug_assert!( + state.decision_tracker.assigned_value(solvable_var) != Some(true) + ); } } TaskResult::SortedCandidates { @@ -1760,19 +1729,19 @@ async fn add_clauses_for_solvables( // Allocate a variable for the solvable let variable = match solvable_id.solvable() { - Some(solvable_id) => variable_map.intern_solvable(solvable_id), - None => variable_map.root(), + Some(solvable_id) => state.variable_map.intern_solvable(solvable_id), + None => state.variable_map.root(), }; // Intern all the solvables of the candidates. - let version_set_variables = requirement_to_sorted_candidates.insert( + let version_set_variables = state.requirement_to_sorted_candidates.insert( requirement, candidates .iter() .map(|&candidates| { candidates .iter() - .map(|&var| variable_map.intern_solvable(var)) + .map(|&var| state.variable_map.intern_solvable(var)) .collect() }) .collect(), @@ -1794,7 +1763,7 @@ async fn add_clauses_for_solvables( // If the dependencies are already available for the // candidate, queue the candidate for processing. if cache.are_dependencies_available_for(candidate) - && clauses_added_for_solvable.insert(candidate.into()) + && state.clauses_added_for_solvable.insert(candidate.into()) { pending_solvables.push(candidate.into()); } @@ -1803,7 +1772,7 @@ async fn add_clauses_for_solvables( // solvables that have been visited already for the same // version set name. let name_id = cache.provider().solvable_name(candidate); - let other_solvables = forbidden_clauses_added.entry(name_id).or_default(); + let other_solvables = state.forbidden_clauses_added.entry(name_id).or_default(); other_solvables.add( candidate_var, |a, b, positive| { @@ -1812,11 +1781,11 @@ async fn add_clauses_for_solvables( if positive { b.positive() } else { b.negative() }, name_id, ); - let clause_id = clauses.alloc(watched_literals, kind); - debug_assert!(clauses.watched_literals[clause_id.to_usize()].is_some()); + let clause_id = state.clauses.alloc(watched_literals, kind); + debug_assert!(state.clauses.watched_literals.get(clause_id).is_some()); output.clauses_to_watch.push(clause_id); }, - || variable_map.alloc_forbid_multiple_variable(name_id), + || state.variable_map.alloc_forbid_multiple_variable(name_id), ); } @@ -1826,10 +1795,10 @@ async fn add_clauses_for_solvables( variable, requirement, version_set_variables.iter().flatten().copied(), - decision_tracker, + &state.decision_tracker, ); let has_watches = watched_literals.is_some(); - let clause_id = clauses.alloc(watched_literals, kind); + let clause_id = state.clauses.alloc(watched_literals, kind); if has_watches { output.clauses_to_watch.push(clause_id); @@ -1861,21 +1830,22 @@ async fn add_clauses_for_solvables( // Allocate a variable for the solvable let variable = match solvable_id.solvable() { - Some(solvable_id) => variable_map.intern_solvable(solvable_id), - None => variable_map.root(), + Some(solvable_id) => state.variable_map.intern_solvable(solvable_id), + None => state.variable_map.root(), }; // Add forbidden clauses for the candidates for &forbidden_candidate in non_matching_candidates { - let forbidden_candidate_var = variable_map.intern_solvable(forbidden_candidate); - let (state, conflict, kind) = WatchedLiterals::constrains( + let forbidden_candidate_var = + state.variable_map.intern_solvable(forbidden_candidate); + let (watched_literals, conflict, kind) = WatchedLiterals::constrains( variable, forbidden_candidate_var, version_set_id, - decision_tracker, + &state.decision_tracker, ); - let clause_id = clauses.alloc(state, kind); + let clause_id = state.clauses.alloc(watched_literals, kind); output.clauses_to_watch.push(clause_id); if conflict { @@ -1890,3 +1860,18 @@ async fn add_clauses_for_solvables( Ok(output) } + +impl SolverState { + /// Returns the solvables that the solver has chosen to include in the + /// solution so far. + fn chosen_solvables(&self) -> impl Iterator + '_ { + self.decision_tracker.stack().filter_map(|d| { + if d.value { + d.variable.as_solvable(&self.variable_map) + } else { + // Ignore things that are set to false + None + } + }) + } +} diff --git a/src/solver/watch_map.rs b/src/solver/watch_map.rs index adc5c9e..6e9b073 100644 --- a/src/solver/watch_map.rs +++ b/src/solver/watch_map.rs @@ -1,10 +1,11 @@ use crate::{ - internal::{arena::ArenaId, debug_expect_unchecked, id::ClauseId, mapping::Mapping}, + internal::{id::ClauseId, mapping::Mapping}, solver::clause::{Literal, WatchedLiterals}, }; /// A map from literals to the clauses that are watching them. Each literal /// forms a linked list of clauses that are all watching that literal. +#[derive(Default)] pub(crate) struct WatchMap { // Note: the map is to a single clause, but clauses form a linked list, so // it is possible to go from one to the next @@ -12,12 +13,6 @@ pub(crate) struct WatchMap { } impl WatchMap { - pub(crate) fn new() -> Self { - Self { - map: Mapping::new(), - } - } - /// Add the clause to the linked list of the literals that the clause is /// watching. pub(crate) fn start_watching(&mut self, clause: &mut WatchedLiterals, clause_id: ClauseId) { @@ -36,13 +31,11 @@ impl WatchMap { /// literal. pub fn cursor<'a>( &'a mut self, - watches: &'a mut [Option], + watches: &'a mut Mapping, literal: Literal, ) -> Option> { let clause_id = *self.map.get(literal)?; - let watched_literal = watches[clause_id.to_usize()] - .as_ref() - .expect("no watches found for clause"); + let watched_literal = watches.get(clause_id).expect("no watches found for clause"); let watch_index = if watched_literal.watched_literals[0] == literal { 0 } else { @@ -88,7 +81,7 @@ pub struct WatchMapCursor<'a> { watch_map: &'a mut WatchMap, /// The nodes of the linked list. - watches: &'a mut [Option], + watches: &'a mut Mapping, /// The literal who's linked list is being navigated. literal: Literal, @@ -116,8 +109,9 @@ impl<'a> WatchMapCursor<'a> { fn next_node(&self) -> Option { let current_watch = self.watched_literals(); let next_clause_id = current_watch.next_watches[self.current.watch_index]?; - let next_watch = self.watches[next_clause_id.to_usize()] - .as_ref() + let next_watch = self + .watches + .get(next_clause_id) .expect("watches are missing"); let next_clause_watch_index = if next_watch.watched_literals[0] == self.literal { 0 @@ -143,12 +137,7 @@ impl<'a> WatchMapCursor<'a> { /// Returns the watches of the current clause. pub fn watched_literals(&self) -> &WatchedLiterals { // SAFETY: Within the cursor, the current clause is always watching literals. - unsafe { - debug_expect_unchecked( - self.watches[self.current.clause_id.to_usize()].as_ref(), - "clause is not watching literals", - ) - } + unsafe { self.watches.get_unchecked(self.current.clause_id) } } /// Returns the index of the current watch in the current clause. @@ -167,7 +156,7 @@ impl<'a> WatchMapCursor<'a> { "cannot update watch to the same literal" ); - let clause_idx = self.current.clause_id.to_usize(); + let clause_id = self.current.clause_id; let next_node = self.next_node(); // Update the previous node to point to the next node in the linked list @@ -176,12 +165,7 @@ impl<'a> WatchMapCursor<'a> { // If there is a previous node we update that node to point to the next. // SAFETY: Within the cursor, the watches are never unset, so if we have a // previous index there will also be watch literals for that clause. - let previous_watches = unsafe { - debug_expect_unchecked( - self.watches[previous.clause_id.to_usize()].as_mut(), - "previous clause has no watches", - ) - }; + let previous_watches = unsafe { self.watches.get_unchecked_mut(previous.clause_id) }; previous_watches.next_watches[previous.watch_index] = next_node.as_ref().map(|node| node.clause_id); } else if let Some(next_clause_id) = next_node.as_ref().map(|node| node.clause_id) { @@ -192,12 +176,7 @@ impl<'a> WatchMapCursor<'a> { } // Set the new watch for the current clause. - let watch = unsafe { - debug_expect_unchecked( - self.watches[clause_idx].as_mut(), - "clause is not watching literals", - ) - }; + let watch = unsafe { self.watches.get_unchecked_mut(clause_id) }; watch.watched_literals[self.current.watch_index] = new_watch; let previous_clause_id = self.watch_map.map.insert(new_watch, self.current.clause_id); watch.next_watches[self.current.watch_index] = previous_clause_id; diff --git a/tests/solver.rs b/tests/solver.rs index de15d8a..16b84b6 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -123,9 +123,7 @@ impl Spec { pub fn parse_union( spec: &str, ) -> impl Iterator::Err>> + '_ { - spec.split('|') - .map(str::trim) - .map(|dep| Spec::from_str(dep)) + spec.split('|').map(str::trim).map(Spec::from_str) } } @@ -386,7 +384,7 @@ impl DependencyProvider for BundleBoxProvider { candidates .iter() .copied() - .filter(|s| range.contains(&self.pool.resolve_solvable(*s).record) == !inverse) + .filter(|s| range.contains(&self.pool.resolve_solvable(*s).record) != inverse) .collect() } @@ -538,7 +536,7 @@ impl DependencyProvider for BundleBoxProvider { } /// Create a string from a [`Transaction`] -fn transaction_to_string(interner: &impl Interner, solvables: &Vec) -> String { +fn transaction_to_string(interner: &impl Interner, solvables: &[SolvableId]) -> String { use std::fmt::Write; let mut buf = String::new(); for solvable in solvables