Skip to content

Commit f89c0b2

Browse files
authored
refactor: reduce watchmap memory size (#92)
1 parent bebb6b4 commit f89c0b2

File tree

5 files changed

+85
-70
lines changed

5 files changed

+85
-70
lines changed

src/internal/id.rs

+24-15
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use std::fmt::{Display, Formatter};
1+
use std::{
2+
fmt::{Display, Formatter},
3+
num::NonZeroU32,
4+
};
25

36
use crate::{internal::arena::ArenaId, Interner};
47

@@ -165,32 +168,24 @@ impl From<SolvableId> for u32 {
165168

166169
#[repr(transparent)]
167170
#[derive(Copy, Clone, PartialOrd, Ord, Eq, PartialEq, Debug, Hash)]
168-
pub(crate) struct ClauseId(u32);
171+
pub(crate) struct ClauseId(NonZeroU32);
169172

170173
impl ClauseId {
171-
/// There is a guarentee that ClauseId(0) will always be
174+
/// There is a guarentee that ClauseId(1) will always be
172175
/// "Clause::InstallRoot". This assumption is verified by the solver.
173176
pub(crate) fn install_root() -> Self {
174-
Self(0)
175-
}
176-
177-
pub(crate) fn is_null(self) -> bool {
178-
self.0 == u32::MAX
179-
}
180-
181-
pub(crate) fn null() -> ClauseId {
182-
ClauseId(u32::MAX)
177+
Self(unsafe { NonZeroU32::new_unchecked(1) })
183178
}
184179
}
185180

186181
impl ArenaId for ClauseId {
187182
fn from_usize(x: usize) -> Self {
188-
assert!(x < u32::MAX as usize, "clause id too big");
189-
Self(x as u32)
183+
// SAFETY: Safe because we always add 1 to the index
184+
Self(unsafe { NonZeroU32::new_unchecked((x + 1).try_into().expect("clause id too big")) })
190185
}
191186

192187
fn to_usize(self) -> usize {
193-
self.0 as usize
188+
(self.0.get() - 1) as usize
194189
}
195190
}
196191

@@ -236,3 +231,17 @@ impl ArenaId for DependenciesId {
236231
self.0 as usize
237232
}
238233
}
234+
235+
#[cfg(test)]
236+
mod tests {
237+
use super::*;
238+
239+
#[test]
240+
fn test_clause_id_size() {
241+
// Verify that the size of a ClauseId is the same as an Option<ClauseId>.
242+
assert_eq!(
243+
std::mem::size_of::<ClauseId>(),
244+
std::mem::size_of::<Option<ClauseId>>()
245+
);
246+
}
247+
}

src/internal/mapping.rs

+15
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,21 @@ impl<TId: ArenaId, TValue> Mapping<TId, TValue> {
6565
previous_value
6666
}
6767

68+
/// Unset a specific value in the mapping, returns the previous value.
69+
pub fn unset(&mut self, id: TId) -> Option<TValue> {
70+
let idx = id.to_usize();
71+
let (chunk, offset) = Self::chunk_and_offset(idx);
72+
if chunk >= self.chunks.len() {
73+
return None;
74+
}
75+
76+
let previous_value = self.chunks[chunk][offset].take();
77+
if previous_value.is_some() {
78+
self.len -= 1;
79+
}
80+
previous_value
81+
}
82+
6883
/// Get a specific value in the mapping with bound checks
6984
pub fn get(&self, id: TId) -> Option<&TValue> {
7085
let (chunk, offset) = Self::chunk_and_offset(id.to_usize());

src/solver/clause.rs

+31-31
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ pub(crate) struct ClauseState {
324324
// The ids of the solvables this clause is watching
325325
pub watched_literals: [Literal; 2],
326326
// The ids of the next clause in each linked list that this clause is part of
327-
pub(crate) next_watches: [ClauseId; 2],
327+
pub(crate) next_watches: [Option<ClauseId>; 2],
328328
}
329329

330330
impl ClauseState {
@@ -417,15 +417,15 @@ impl ClauseState {
417417

418418
let clause = Self {
419419
watched_literals,
420-
next_watches: [ClauseId::null(), ClauseId::null()],
420+
next_watches: [None, None],
421421
};
422422

423423
debug_assert!(!clause.has_watches() || watched_literals[0] != watched_literals[1]);
424424

425425
clause
426426
}
427427

428-
pub fn link_to_clause(&mut self, watch_index: usize, linked_clause: ClauseId) {
428+
pub fn link_to_clause(&mut self, watch_index: usize, linked_clause: Option<ClauseId>) {
429429
self.next_watches[watch_index] = linked_clause;
430430
}
431431

@@ -444,7 +444,7 @@ impl ClauseState {
444444
}
445445

446446
#[inline]
447-
pub fn next_watched_clause(&self, solvable_id: InternalSolvableId) -> ClauseId {
447+
pub fn next_watched_clause(&self, solvable_id: InternalSolvableId) -> Option<ClauseId> {
448448
if solvable_id == self.watched_literals[0].solvable_id() {
449449
self.next_watches[0]
450450
} else {
@@ -650,7 +650,7 @@ mod test {
650650
use super::*;
651651
use crate::{internal::arena::ArenaId, solver::decision::Decision};
652652

653-
fn clause(next_clauses: [ClauseId; 2], watch_literals: [Literal; 2]) -> ClauseState {
653+
fn clause(next_clauses: [Option<ClauseId>; 2], watch_literals: [Literal; 2]) -> ClauseState {
654654
ClauseState {
655655
watched_literals: watch_literals,
656656
next_watches: next_clauses,
@@ -691,21 +691,24 @@ mod test {
691691
#[test]
692692
fn test_unlink_clause_different() {
693693
let clause1 = clause(
694-
[ClauseId::from_usize(2), ClauseId::from_usize(3)],
694+
[
695+
ClauseId::from_usize(2).into(),
696+
ClauseId::from_usize(3).into(),
697+
],
695698
[
696699
InternalSolvableId::from_usize(1596).negative(),
697700
InternalSolvableId::from_usize(1211).negative(),
698701
],
699702
);
700703
let clause2 = clause(
701-
[ClauseId::null(), ClauseId::from_usize(3)],
704+
[None, ClauseId::from_usize(3).into()],
702705
[
703706
InternalSolvableId::from_usize(1596).negative(),
704707
InternalSolvableId::from_usize(1208).negative(),
705708
],
706709
);
707710
let clause3 = clause(
708-
[ClauseId::null(), ClauseId::null()],
711+
[None, None],
709712
[
710713
InternalSolvableId::from_usize(1211).negative(),
711714
InternalSolvableId::from_usize(42).negative(),
@@ -723,10 +726,7 @@ mod test {
723726
InternalSolvableId::from_usize(1211).negative()
724727
]
725728
);
726-
assert_eq!(
727-
clause1.next_watches,
728-
[ClauseId::null(), ClauseId::from_usize(3)]
729-
)
729+
assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(3).into()])
730730
}
731731

732732
// Unlink 1
@@ -740,24 +740,24 @@ mod test {
740740
InternalSolvableId::from_usize(1211).negative()
741741
]
742742
);
743-
assert_eq!(
744-
clause1.next_watches,
745-
[ClauseId::from_usize(2), ClauseId::null()]
746-
)
743+
assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None])
747744
}
748745
}
749746

750747
#[test]
751748
fn test_unlink_clause_same() {
752749
let clause1 = clause(
753-
[ClauseId::from_usize(2), ClauseId::from_usize(2)],
750+
[
751+
ClauseId::from_usize(2).into(),
752+
ClauseId::from_usize(2).into(),
753+
],
754754
[
755755
InternalSolvableId::from_usize(1596).negative(),
756756
InternalSolvableId::from_usize(1211).negative(),
757757
],
758758
);
759759
let clause2 = clause(
760-
[ClauseId::null(), ClauseId::null()],
760+
[None, None],
761761
[
762762
InternalSolvableId::from_usize(1596).negative(),
763763
InternalSolvableId::from_usize(1211).negative(),
@@ -775,10 +775,7 @@ mod test {
775775
InternalSolvableId::from_usize(1211).negative()
776776
]
777777
);
778-
assert_eq!(
779-
clause1.next_watches,
780-
[ClauseId::null(), ClauseId::from_usize(2)]
781-
)
778+
assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(2).into()])
782779
}
783780

784781
// Unlink 1
@@ -792,10 +789,7 @@ mod test {
792789
InternalSolvableId::from_usize(1211).negative()
793790
]
794791
);
795-
assert_eq!(
796-
clause1.next_watches,
797-
[ClauseId::from_usize(2), ClauseId::null()]
798-
)
792+
assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None])
799793
}
800794
}
801795

@@ -820,7 +814,10 @@ mod test {
820814

821815
// No conflict, still one candidate available
822816
decisions
823-
.try_add_decision(Decision::new(candidate1.into(), false, ClauseId::null()), 1)
817+
.try_add_decision(
818+
Decision::new(candidate1.into(), false, ClauseId::from_usize(0)),
819+
1,
820+
)
824821
.unwrap();
825822
let (clause, conflict, _kind) = ClauseState::requires(
826823
parent,
@@ -834,7 +831,10 @@ mod test {
834831

835832
// Conflict, no candidates available
836833
decisions
837-
.try_add_decision(Decision::new(candidate2.into(), false, ClauseId::null()), 1)
834+
.try_add_decision(
835+
Decision::new(candidate2.into(), false, ClauseId::install_root()),
836+
1,
837+
)
838838
.unwrap();
839839
let (clause, conflict, _kind) = ClauseState::requires(
840840
parent,
@@ -848,7 +848,7 @@ mod test {
848848

849849
// Panic
850850
decisions
851-
.try_add_decision(Decision::new(parent, false, ClauseId::null()), 1)
851+
.try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1)
852852
.unwrap();
853853
let panicked = std::panic::catch_unwind(|| {
854854
ClauseState::requires(
@@ -878,7 +878,7 @@ mod test {
878878

879879
// Conflict, forbidden package installed
880880
decisions
881-
.try_add_decision(Decision::new(forbidden, true, ClauseId::null()), 1)
881+
.try_add_decision(Decision::new(forbidden, true, ClauseId::install_root()), 1)
882882
.unwrap();
883883
let (clause, conflict, _kind) =
884884
ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions);
@@ -888,7 +888,7 @@ mod test {
888888

889889
// Panic
890890
decisions
891-
.try_add_decision(Decision::new(parent, false, ClauseId::null()), 1)
891+
.try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1)
892892
.unwrap();
893893
let panicked = std::panic::catch_unwind(|| {
894894
ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions)

src/solver/mod.rs

+8-12
Original file line numberDiff line numberDiff line change
@@ -1435,11 +1435,8 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
14351435
// solvable
14361436
let mut old_predecessor_clause_id: Option<ClauseId>;
14371437
let mut predecessor_clause_id: Option<ClauseId> = None;
1438-
let mut clause_id = self
1439-
.watches
1440-
.first_clause_watching_literal(watched_literal)
1441-
.unwrap_or(ClauseId::null());
1442-
while !clause_id.is_null() {
1438+
let mut next_clause_id = self.watches.first_clause_watching_literal(watched_literal);
1439+
while let Some(clause_id) = next_clause_id {
14431440
debug_assert!(
14441441
predecessor_clause_id != Some(clause_id),
14451442
"Linked list is circular!"
@@ -1466,8 +1463,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
14661463
predecessor_clause_id = Some(clause_id);
14671464

14681465
// Configure the next clause to visit
1469-
let this_clause_id = clause_id;
1470-
clause_id = clause_state.next_watched_clause(watched_literal.solvable_id());
1466+
next_clause_id = clause_state.next_watched_clause(watched_literal.solvable_id());
14711467

14721468
// Determine which watch turned false.
14731469
let (watch_index, other_watch_index) = if clause_state.watched_literals[0]
@@ -1492,7 +1488,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
14921488
// If the other watch is already true, we can simply skip
14931489
// this clause.
14941490
} else if let Some(variable) = clause_state.next_unwatched_literal(
1495-
&clauses[this_clause_id.to_usize()],
1491+
&clauses[clause_id.to_usize()],
14961492
&self.learnt_clauses,
14971493
&self.cache.requirement_to_sorted_candidates,
14981494
self.decision_tracker.map(),
@@ -1501,7 +1497,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
15011497
self.watches.update_watched(
15021498
predecessor_clause_state,
15031499
clause_state,
1504-
this_clause_id,
1500+
clause_id,
15051501
watch_index,
15061502
watched_literal,
15071503
variable,
@@ -1527,20 +1523,20 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
15271523
Decision::new(
15281524
remaining_watch.solvable_id(),
15291525
remaining_watch.satisfying_value(),
1530-
this_clause_id,
1526+
clause_id,
15311527
),
15321528
level,
15331529
)
15341530
.map_err(|_| {
15351531
PropagationError::Conflict(
15361532
remaining_watch.solvable_id(),
15371533
true,
1538-
this_clause_id,
1534+
clause_id,
15391535
)
15401536
})?;
15411537

15421538
if decided {
1543-
let clause = &clauses[this_clause_id.to_usize()];
1539+
let clause = &clauses[clause_id.to_usize()];
15441540
match clause {
15451541
// Skip logging for ForbidMultipleInstances, which is so noisy
15461542
Clause::ForbidMultipleInstances(..) => {}

src/solver/watch_map.rs

+7-12
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
use crate::solver::clause::Literal;
21
use crate::{
32
internal::{id::ClauseId, mapping::Mapping},
4-
solver::clause::ClauseState,
3+
solver::clause::{ClauseState, Literal},
54
};
65

76
/// A map from solvables to the clauses that are watching them
@@ -20,9 +19,7 @@ impl WatchMap {
2019

2120
pub(crate) fn start_watching(&mut self, clause: &mut ClauseState, clause_id: ClauseId) {
2221
for (watch_index, watched_literal) in clause.watched_literals.into_iter().enumerate() {
23-
let already_watching = self
24-
.first_clause_watching_literal(watched_literal)
25-
.unwrap_or(ClauseId::null());
22+
let already_watching = self.first_clause_watching_literal(watched_literal);
2623
clause.link_to_clause(watch_index, already_watching);
2724
self.watch_literal(watched_literal, clause_id);
2825
}
@@ -42,18 +39,16 @@ impl WatchMap {
4239
if let Some(predecessor_clause) = predecessor_clause {
4340
// Unlink the clause
4441
predecessor_clause.unlink_clause(clause, previous_watch.solvable_id(), watch_index);
45-
} else {
42+
} else if let Some(next_watch) = clause.next_watches[watch_index] {
4643
// This was the first clause in the chain
47-
self.map
48-
.insert(previous_watch, clause.next_watches[watch_index]);
44+
self.map.insert(previous_watch, next_watch);
45+
} else {
46+
self.map.unset(previous_watch);
4947
}
5048

5149
// Set the new watch
5250
clause.watched_literals[watch_index] = new_watch;
53-
let previous_clause_id = self
54-
.map
55-
.insert(new_watch, clause_id)
56-
.unwrap_or(ClauseId::null());
51+
let previous_clause_id = self.map.insert(new_watch, clause_id);
5752
clause.next_watches[watch_index] = previous_clause_id;
5853
}
5954

0 commit comments

Comments
 (0)