Skip to content

Commit

Permalink
Improve compare for IntSet and IntMap (#1086)
Browse files Browse the repository at this point in the history
Compare the trees directly instead of converting to lists.
The implementation follows broadly the same approach as the previous
attempt in commit 7aff529.
  • Loading branch information
meooow25 authored Jan 30, 2025
1 parent a9e0297 commit 0d85628
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 6 deletions.
1 change: 1 addition & 0 deletions containers-tests/benchmarks/IntMap.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ main = do
, bench "split" $ whnf (M.split key_mid) m
, bench "splitLookup" $ whnf (M.splitLookup key_mid) m
, bench "eq" $ whnf (\m' -> m' == m') m -- worst case, compares everything
, bench "compare" $ whnf (\m' -> compare m' m') m -- worst case, compares everything
, bgroup "folds" $ foldBenchmarks M.foldr M.foldl M.foldr' M.foldl' foldMap m
, bgroup "folds with key" $
foldWithKeyBenchmarks M.foldrWithKey M.foldlWithKey M.foldrWithKey' M.foldlWithKey' M.foldMapWithKey m
Expand Down
2 changes: 2 additions & 0 deletions containers-tests/benchmarks/IntSet.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ main = do
, bench "splitMember:dense" $ whnf (IS.splitMember elem_mid) s
, bench "splitMember:sparse" $ whnf (IS.splitMember elem_sparse_mid) s_sparse
, bench "eq" $ whnf (\s' -> s' == s') s -- worst case, compares everything
, bench "compare:dense" $ whnf (\s' -> compare s' s') s -- worst case, compares everything
, bench "compare:sparse" $ whnf (\s' -> compare s' s') s_sparse -- worst case, compares everything
, bgroup "folds:dense" $ foldBenchmarks IS.foldr IS.foldl IS.foldr' IS.foldl' IS.foldMap s
, bgroup "folds:sparse" $ foldBenchmarks IS.foldr IS.foldl IS.foldr' IS.foldl' IS.foldMap s_sparse
]
Expand Down
6 changes: 5 additions & 1 deletion containers-tests/tests/intmap-properties.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import Test.Tasty
import Test.Tasty.HUnit
import Test.Tasty.QuickCheck
import Test.QuickCheck.Function (apply)
import Test.QuickCheck.Poly (A, B, C)
import Test.QuickCheck.Poly (A, B, C, OrdA)

default (Int)

Expand Down Expand Up @@ -247,6 +247,7 @@ main = defaultMain $ testGroup "intmap-properties"
, testProperty "mapAccumRWithKey" prop_mapAccumRWithKey
, testProperty "mapKeysWith" prop_mapKeysWith
, testProperty "mapKeysMonotonic" prop_mapKeysMonotonic
, testProperty "compare" prop_compare
]

{--------------------------------------------------------------------
Expand Down Expand Up @@ -1980,3 +1981,6 @@ prop_mapKeysMonotonic (Positive a) b m =
fromIntegral (minBound :: Int) <= y && y <= fromIntegral (maxBound :: Int)
where
y = fromIntegral a * fromIntegral x + fromIntegral b :: Integer

-- | 'compare' on two maps must give the same answer as 'compare' on the
-- lists produced by @toList@ for those maps.
prop_compare :: IntMap OrdA -> IntMap OrdA -> Property
prop_compare m1 m2 = viaMaps === viaLists
  where
    viaMaps = compare m1 m2
    viaLists = compare (toList m1) (toList m2)
2 changes: 2 additions & 0 deletions containers/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
including `insert` and `delete`, by inlining part of the balancing
routine. (Soumik Sarkar)

* Improved performance for `IntSet` and `IntMap`'s `Ord` instances.

## Unreleased with `@since` annotation for 0.7.1:

### Additions
Expand Down
69 changes: 66 additions & 3 deletions containers/src/Data/IntMap/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ import Data.IntSet.Internal.IntTreeCommons
, TreeTreeBranch(..)
, treeTreeBranch
, i2w
, Order(..)
)
import Utils.Containers.Internal.BitUtil (shiftLL, shiftRL, iShiftRL)
import Utils.Containers.Internal.StrictPair
Expand Down Expand Up @@ -3487,12 +3488,74 @@ instance Eq1 IntMap where
--------------------------------------------------------------------}

instance Ord a => Ord (IntMap a) where
  -- Orders maps by walking both trees with 'liftCmp' rather than by
  -- converting them to lists first. This must be the only equation for
  -- 'compare': a preceding catch-all equation would always be selected.
  compare m1 m2 = liftCmp compare m1 m2
  {-# INLINABLE compare #-}

-- | @since 0.5.9
instance Ord1 IntMap where
  -- Delegates to the tree-walking comparison 'liftCmp'. Only this single
  -- point-free equation is kept: mixing it with a three-argument equation
  -- would give 'liftCompare' clauses of different arity, which is rejected.
  liftCompare = liftCmp

-- | Order two maps by traversing both trees directly, using @cmp@ to order
-- the values of entries whose keys are equal.
--
-- Each map is first split by 'splitSign' into its negative-key part and its
-- non-negative-key part. The negative parts are compared first and the
-- non-negative parts only when the negative parts are equal.
liftCmp :: (a -> b -> Ordering) -> IntMap a -> IntMap b -> Ordering
liftCmp cmp m1 m2 = case (splitSign m1, splitSign m2) of
  ((neg1, nonNeg1), (neg2, nonNeg2)) -> case walk neg1 neg2 of
    A_LT_B -> LT
    A_GT_B -> GT
    -- The negative part of m1 is a proper prefix of that of m2: m1 is
    -- smaller only if it has no non-negative entries left to offer.
    A_Prefix_B
      | null nonNeg1 -> LT
      | otherwise -> GT
    -- Mirror image of the case above.
    B_Prefix_A
      | null nonNeg2 -> GT
      | otherwise -> LT
    A_EQ_B -> case walk nonNeg1 nonNeg2 of
      A_LT_B -> LT
      A_Prefix_B -> LT
      A_EQ_B -> EQ
      B_Prefix_A -> GT
      A_GT_B -> GT
  where
    walk ta@(Bin pa la ra) tb@(Bin pb lb rb) = case treeTreeBranch pa pb of
      ABL -> restOfA (walk la tb)
      ABR -> A_LT_B
      BAL -> restOfB (walk ta lb)
      BAR -> A_GT_B
      EQL -> case walk la lb of
        A_Prefix_B -> A_GT_B
        A_EQ_B -> walk ra rb
        B_Prefix_A -> A_LT_B
        other -> other
      NOM
        | unPrefix pa < unPrefix pb -> A_LT_B
        | otherwise -> A_GT_B
    walk (Bin _ la _) (Tip kb xb) = case lookupMinSure la of
      KeyValue ka xa -> entryOrder B_Prefix_A ka xa kb xb
    walk (Tip ka xa) (Bin _ lb _) = case lookupMinSure lb of
      KeyValue kb xb -> entryOrder A_Prefix_B ka xa kb xb
    walk (Tip ka xa) (Tip kb xb) = entryOrder A_EQ_B ka xa kb xb
    walk Nil Nil = A_EQ_B
    walk Nil _ = A_Prefix_B
    walk _ Nil = B_Prefix_A

    -- Only the left subtree of A was compared against B; reinterpret that
    -- result for the whole of A.
    restOfA o = case o of
      A_Prefix_B -> A_GT_B
      A_EQ_B -> B_Prefix_A
      _ -> o

    -- Only the left subtree of B was compared against A; reinterpret that
    -- result for the whole of B.
    restOfB o = case o of
      A_EQ_B -> A_Prefix_B
      B_Prefix_A -> A_LT_B
      _ -> o

    -- Order two entries by key and then by value; 'whenEqual' is the result
    -- to report when both agree.
    entryOrder whenEqual ka xa kb xb = case compare ka kb <> cmp xa xb of
      LT -> A_LT_B
      EQ -> whenEqual
      GT -> A_GT_B
{-# INLINE liftCmp #-}

-- Split a map into (entries with negative keys, entries with non-negative
-- keys). When the root branches on the sign, its right subtree is returned
-- as the negative part and its left subtree as the non-negative part.
-- Otherwise the whole tree goes on one side, chosen by the sign of its
-- prefix or of its single key.
splitSign :: IntMap a -> (IntMap a, IntMap a)
splitSign t = case t of
  Nil -> (Nil, Nil)
  Tip k _ -> bySign k
  Bin p l r
    | signBranch p -> (r, l)
    | otherwise -> bySign (unPrefix p)
  where
    bySign x
      | x < 0 = (t, Nil)
      | otherwise = (Nil, t)
{-# INLINE splitSign #-}

{--------------------------------------------------------------------
Functor
Expand Down
91 changes: 89 additions & 2 deletions containers/src/Data/IntSet/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ import Data.IntSet.Internal.IntTreeCommons
, TreeTreeBranch(..)
, treeTreeBranch
, i2w
, Order(..)
)

#if __GLASGOW_HASKELL__
Expand Down Expand Up @@ -1486,8 +1487,94 @@ equal _ _ = False
--------------------------------------------------------------------}

instance Ord IntSet where
  -- Orders sets by walking both trees with 'compareIntSets' rather than by
  -- converting them to ascending lists. Only this single point-free equation
  -- is kept: mixing it with a two-argument equation would give 'compare'
  -- clauses of different arity, which is rejected.
  compare = compareIntSets

-- | Order two sets by traversing both trees directly.
--
-- Each set is first split by 'splitSign' into its negative part and its
-- non-negative part. The negative parts are compared first and the
-- non-negative parts only when the negative parts are equal.
compareIntSets :: IntSet -> IntSet -> Ordering
compareIntSets s1 s2 = case (splitSign s1, splitSign s2) of
  ((neg1, nonNeg1), (neg2, nonNeg2)) -> case walk neg1 neg2 of
    A_LT_B -> LT
    A_GT_B -> GT
    -- The negative part of s1 is a proper prefix of that of s2: s1 is
    -- smaller only if it has no non-negative elements left to offer.
    A_Prefix_B
      | null nonNeg1 -> LT
      | otherwise -> GT
    -- Mirror image of the case above.
    B_Prefix_A
      | null nonNeg2 -> GT
      | otherwise -> LT
    A_EQ_B -> case walk nonNeg1 nonNeg2 of
      A_LT_B -> LT
      A_Prefix_B -> LT
      A_EQ_B -> EQ
      B_Prefix_A -> GT
      A_GT_B -> GT
  where
    walk ta@(Bin pa la ra) tb@(Bin pb lb rb) = case treeTreeBranch pa pb of
      ABL -> restOfA (walk la tb)
      ABR -> A_LT_B
      BAL -> restOfB (walk ta lb)
      BAR -> A_GT_B
      EQL -> case walk la lb of
        A_Prefix_B -> A_GT_B
        A_EQ_B -> walk ra rb
        B_Prefix_A -> A_LT_B
        other -> other
      NOM
        | unPrefix pa < unPrefix pb -> A_LT_B
        | otherwise -> A_GT_B
    walk (Bin _ la _) (Tip kb bmb) = case leftmostTipSure la of
      Tip' ka bma -> restOfA (orderTips ka bma kb bmb)
    walk (Tip ka bma) (Bin _ lb _) = case leftmostTipSure lb of
      Tip' kb bmb -> restOfB (orderTips ka bma kb bmb)
    walk (Tip ka bma) (Tip kb bmb) = orderTips ka bma kb bmb
    walk Nil Nil = A_EQ_B
    walk Nil _ = A_Prefix_B
    walk _ Nil = B_Prefix_A

    -- Only a left portion of A was compared against B; reinterpret that
    -- result for the whole of A.
    restOfA o = case o of
      A_Prefix_B -> A_GT_B
      A_EQ_B -> B_Prefix_A
      _ -> o

    -- Only a left portion of B was compared against A; reinterpret that
    -- result for the whole of B.
    restOfB o = case o of
      A_EQ_B -> A_Prefix_B
      B_Prefix_A -> A_LT_B
      _ -> o

-- Result type of leftmostTipSure: carries the Int and BitMap taken from a
-- Tip node, both strict and unpacked.
--
-- This type allows GHC to return unboxed ints from leftmostTipSure, as
-- $wleftmostTipSure :: IntSet -> (# Int#, Word# #)
-- On a modern enough GHC (>=9.4) this is unnecessary: we could use
-- StrictPair instead and get the same Core.
data Tip' = Tip' {-# UNPACK #-} !Int {-# UNPACK #-} !BitMap

-- Follow left children down to the first Tip and return its contents as a
-- Tip'. Calls 'error' when a Nil is reached, so callers must pass a tree in
-- which that cannot happen.
leftmostTipSure :: IntSet -> Tip'
leftmostTipSure t = case t of
  Bin _ l _ -> leftmostTipSure l
  Tip k bm -> Tip' k bm
  Nil -> error "leftmostTipSure: Nil"

-- Order two tips, each given as its key and its BitMap.
--
-- Differing keys decide the result immediately. For equal keys with
-- different BitMaps, the elements are compared lexicographically:
--   - find the lowest bit where the two BitMaps differ;
--   - the BitMap with a 0 in that position is a prefix of the other if it
--     has no set bits from that position upwards, and is greater otherwise.
orderTips :: Int -> BitMap -> Int -> BitMap -> Order
orderTips k1 bm1 k2 bm2
  | k1 < k2 = A_LT_B
  | k1 > k2 = A_GT_B
  | bm1 == bm2 = A_EQ_B
  | bm1 .&. lowestDiff == 0 =
      if bm1 .&. highMask == 0 then A_Prefix_B else A_GT_B
  | bm2 .&. highMask == 0 = B_Prefix_A
  | otherwise = A_LT_B
  where
    -- Bits in which the two BitMaps disagree.
    diff = bm1 `xor` bm2
    -- Only the lowest such bit.
    lowestDiff = diff .&. negate diff
    -- That bit together with every higher bit.
    highMask = negate lowestDiff
{-# INLINE orderTips #-}

-- Split a set into (negative elements, non-negative elements). When the
-- root branches on the sign, its right subtree is returned as the negative
-- part and its left subtree as the non-negative part. Otherwise the whole
-- tree goes on one side, chosen by the sign of its prefix or of its tip key.
splitSign :: IntSet -> (IntSet, IntSet)
splitSign t = case t of
  Nil -> (Nil, Nil)
  Tip k _ -> bySign k
  Bin p l r
    | signBranch p -> (r, l)
    | otherwise -> bySign (unPrefix p)
  where
    bySign x
      | x < 0 = (t, Nil)
      | otherwise = (Nil, t)
{-# INLINE splitSign #-}

{--------------------------------------------------------------------
Show
Expand Down
9 changes: 9 additions & 0 deletions containers/src/Data/IntSet/Internal/IntTreeCommons.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ module Data.IntSet.Internal.IntTreeCommons
, mask
, branchMask
, i2w
, Order(..)
) where

import Data.Bits (Bits(..), countLeadingZeros)
Expand Down Expand Up @@ -161,6 +162,14 @@ i2w :: Int -> Word
i2w = fromIntegral
{-# INLINE i2w #-}

-- Five-way outcome of comparing two sequences A and B lexicographically.
-- Used to compare IntSets and IntMaps. Unlike 'Ordering', it distinguishes
-- "one side is a proper prefix of the other" from "the two sides differ at
-- some position". Each example below is written as A followed by B.
data Order
  = A_LT_B     -- A is smaller at the first difference: [0,3,4] [0,3,5,1]
  | A_Prefix_B -- A is a proper prefix of B:            [0,3,4] [0,3,4,5]
  | A_EQ_B     -- A and B are equal:                    [0,3,4] [0,3,4]
  | B_Prefix_A -- B is a proper prefix of A:            [0,3,4] [0,3]
  | A_GT_B     -- A is larger at the first difference:  [0,3,4] [0,2,5]

{--------------------------------------------------------------------
Notes
--------------------------------------------------------------------}
Expand Down

0 comments on commit 0d85628

Please sign in to comment.