diff --git a/reopt.cabal b/reopt.cabal index e64f6d8c..39ae9585 100644 --- a/reopt.cabal +++ b/reopt.cabal @@ -54,6 +54,7 @@ library attoparsec-aeson, base64, bytestring, + composition-extra, containers, directory, elf-edit >= 0.40, diff --git a/reopt/Main_reopt.hs b/reopt/Main_reopt.hs index dd6c6104..a3366624 100644 --- a/reopt/Main_reopt.hs +++ b/reopt/Main_reopt.hs @@ -207,6 +207,9 @@ data Args = Args -- ^ Trace unification of the type inference solver , traceConstraintOrigins :: Bool -- ^ Trace the origin of constraints in the type inference solver + , maxRestarts :: Maybe Int + -- ^ How many time the type constraint solver should restart before giving up. `Nothing` means + -- infinite restarts. } deriving (Generic) @@ -247,6 +250,7 @@ defaultArgs = , performRecovery = False , traceTypeUnification = False , traceConstraintOrigins = False + , maxRestarts = Nothing } ------------------------------------------------------------------------ @@ -299,6 +303,13 @@ traceConstraintOriginsP = long "trace-constraint-origins" <> help "Trace the origins of constraints in the type inference engine" +maxRestartsP :: Parser (Maybe Int) +maxRestartsP = + optional $ option auto $ + long "max-restarts" + <> metavar "NUMBER" + <> help "Number of times the type constraint solver should restart before giving up" + outputPathP :: Parser String outputPathP = strOption $ @@ -623,6 +634,7 @@ arguments = <*> performRecoveryP <*> traceTypeUnificationP <*> traceConstraintOriginsP + <*> maxRestartsP -- | Parser to set the path to the binary to analyze. programPathP :: Parser String @@ -668,6 +680,7 @@ argsReoptOptions args = do , roDiscoveryOptions = args ^. #discOpts , roDynDepPaths = dynDepPath args , roDynDepDebugPaths = dynDepDebugPath args ++ gdbDebugDirs + , roMaxRestarts = maxRestarts args , roTraceUnification = traceTypeUnification args , roTraceConstraintOrigins = traceConstraintOrigins args } @@ -722,7 +735,8 @@ showConstraints args elfPath = do doRecoverX86 funPrefix sysp symAddrMap debugTypeMap discState Map.empty let recMod = recoveredModule recoverX86Output - pure $ genModuleConstraints recMod (Macaw.memory discState) (traceTypeUnification args) (traceConstraintOrigins args) + pure $ genModuleConstraints recMod (Macaw.memory discState) + (maxRestarts args) (traceTypeUnification args) (traceConstraintOrigins args) mc <- handleEitherWithExit mr diff --git a/src/Reopt.hs b/src/Reopt.hs index 5475d11c..11fceb63 100644 --- a/src/Reopt.hs +++ b/src/Reopt.hs @@ -481,6 +481,7 @@ data ReoptOptions = ReoptOptions , roDynDepPaths :: ![FilePath] -- ^ Additional paths to search for dynamic dependencies. , roDynDepDebugPaths :: ![FilePath] + , roMaxRestarts :: Maybe Int -- ^ Additional paths to search for debug versions of dynamic dependencies. , roTraceUnification :: !Bool -- ^ Trace unification in the solver @@ -499,6 +500,7 @@ defaultReoptOptions = , roDiscoveryOptions = reoptDefaultDiscoveryOptions , roDynDepPaths = [] , roDynDepDebugPaths = [] + , roMaxRestarts = Nothing , roTraceUnification = False , roTraceConstraintOrigins = False } @@ -2745,6 +2747,7 @@ reoptRecoveryLoop symAddrMap rOpts funPrefix sysp debugTypeMap firstDiscState = genModuleConstraints recMod (Macaw.memory discState') + (roMaxRestarts rOpts) (roTraceUnification rOpts) (roTraceConstraintOrigins rOpts) diff --git a/src/Reopt/TypeInference/ConstraintGen.hs b/src/Reopt/TypeInference/ConstraintGen.hs index a1bc18e3..f857c9d1 100644 --- a/src/Reopt/TypeInference/ConstraintGen.hs +++ b/src/Reopt/TypeInference/ConstraintGen.hs @@ -114,6 +114,7 @@ import Reopt.TypeInference.Solver.Constraints ( ConstraintProvenance (..), FnRepProvenance (..), ) +import Reopt.TypeInference.Solver.Monad (ConstraintSolvingReader (..)) import Reopt.TypeInference.Solver.Types (TyF (..)) -- This algorithm proceeds in stages: @@ -241,20 +242,29 @@ inSolverM = CGenM . lift . lift runCGenM :: Memory (ArchAddrWidth arch) -> + Maybe Int -> Bool -> Bool -> CGenM CGenGlobalContext arch a -> a -runCGenM mem traceWanted orig (CGenM m) = runSolverM traceWanted orig ptrWidth $ do - let segs = memSegments mem - -- Allocate a row variable for each memory segment - memRows <- Map.fromList <$> mapM (\seg -> (,) seg <$> S.freshRowVar) segs - let ctxt0 = - CGenGlobalContext - { _cgenMemory = mem - , _cgenMemoryRegions = memRows +runCGenM mem maxRestarts traceWanted orig (CGenM m) = do + let initReader = + ConstraintSolvingReader + { rMaxNumberOfRestarts = maxRestarts + , rPtrWidth = ptrWidth + , rTraceConstraintOrigins = orig + , rTraceUnification = traceWanted } - evalStateT (Reader.runReaderT m ctxt0) st0 + runSolverM initReader $ do + let segs = memSegments mem + -- Allocate a row variable for each memory segment + memRows <- Map.fromList <$> mapM (\seg -> (,) seg <$> S.freshRowVar) segs + let ctxt0 = + CGenGlobalContext + { _cgenMemory = mem + , _cgenMemoryRegions = memRows + } + evalStateT (Reader.runReaderT m ctxt0) st0 where ptrWidth = widthVal (memWidth mem) @@ -981,10 +991,11 @@ genModuleConstraints :: FoldableFC (ArchFn arch) => RecoveredModule arch -> Memory (ArchAddrWidth arch) -> + Maybe Int -> Bool -> Bool -> ModuleConstraints arch -genModuleConstraints m mem traceWanted orig = runCGenM mem traceWanted orig $ do +genModuleConstraints m mem maxRestarts traceWanted orig = runCGenM mem maxRestarts traceWanted orig $ do -- allocate type variables for functions without types -- FIXME: we currently ignore hints diff --git a/src/Reopt/TypeInference/Solver/Monad.hs b/src/Reopt/TypeInference/Solver/Monad.hs index d0add9f4..4387efdb 100644 --- a/src/Reopt/TypeInference/Solver/Monad.hs +++ b/src/Reopt/TypeInference/Solver/Monad.hs @@ -14,6 +14,7 @@ module Reopt.TypeInference.Solver.Monad ( Conditional (..), Conditional', Conjunction (..), + ConstraintSolvingReader (..), ConstraintSolvingState (..), defineRowVar, defineTyVar, @@ -43,15 +44,16 @@ module Reopt.TypeInference.Solver.Monad ( withFresh, ) where -import Control.Lens (Lens', use, (%%=), (%=), (<<+=)) -import Control.Monad.State (MonadState, State, evalState) +import Control.Lens (Lens', view, (%%=), (%=), (<<+=)) import Data.Foldable (asum) +import Data.Function.Slip (slipr) import Data.Generics.Labels () import Data.Map.Strict (Map) import Data.Map.Strict qualified as Map import GHC.Generics (Generic) import Prettyprinter qualified as PP +import Control.Monad.RWS.Strict import Reopt.TypeInference.Solver.Constraints ( ConstraintProvenance, EqC (EqC), @@ -81,6 +83,16 @@ import Reopt.TypeInference.Solver.UnionFindMap qualified as UM type Conditional' = Conditional ([EqC], [EqRowC]) +data ConstraintSolvingReader = ConstraintSolvingReader + { rMaxNumberOfRestarts :: Maybe Int + , rPtrWidth :: Int + -- ^ The width of a pointer, in bits. This can go away when tyvars have an associated size, it is + -- only used for PtrAddC solving. + , rTraceUnification :: Bool + , rTraceConstraintOrigins :: Bool + } + deriving (Generic) + data ConstraintSolvingState = ConstraintSolvingState { ctxEqCs :: [EqC] , ctxEqRowCs :: [EqRowC] @@ -90,22 +102,16 @@ data ConstraintSolvingState = ConstraintSolvingState , nextTraceId :: Int , nextRowVar :: Int , nextTyVar :: Int - , ptrWidth :: Int - -- ^ The width of a pointer, in bits. This can go away when - -- tyvars have an associated size, it is only used for PtrAddC - -- solving. , ctxTyVars :: UnionFindMap TyVar TyVar ITy' -- ^ The union-find data-structure mapping each tyvar onto its -- representative tv. If no mapping exists, it is a self-mapping. , ctxRowVars :: UnionFindMap RowVar RowInfo (FieldMap TyVar) - , -- Debugging - ctxTraceUnification :: Bool - , ctxTraceConstraintOrigins :: Bool + , ctxNumberOfRestarts :: Int } deriving (Generic) -emptyContext :: Int -> Bool -> Bool -> ConstraintSolvingState -emptyContext w trace orig = +emptyConstraintSolvingState :: ConstraintSolvingState +emptyConstraintSolvingState = ConstraintSolvingState { ctxEqCs = [] , ctxEqRowCs = [] @@ -115,20 +121,29 @@ emptyContext w trace orig = , nextTraceId = 0 , nextRowVar = 0 , nextTyVar = 0 - , ptrWidth = w , ctxTyVars = UM.empty , ctxRowVars = UM.empty - , ctxTraceUnification = trace - , ctxTraceConstraintOrigins = orig + , ctxNumberOfRestarts = 0 } newtype SolverM a = SolverM - { getSolverM :: State ConstraintSolvingState a + { getSolverM :: RWS ConstraintSolvingReader () ConstraintSolvingState a } - deriving (Applicative, Functor, Monad, MonadState ConstraintSolvingState) - -runSolverM :: Bool -> Bool -> Int -> SolverM a -> a -runSolverM b o w = flip evalState (emptyContext w b o) . getSolverM + deriving + ( Applicative + , Functor + , Monad + , MonadState ConstraintSolvingState + , MonadReader ConstraintSolvingReader + ) + +runSolverM :: + ConstraintSolvingReader -> + SolverM a -> + a +runSolverM initReader = fst . slipr evalRWS initReader initState . getSolverM + where + initState = emptyConstraintSolvingState -------------------------------------------------------------------------------- -- Adding constraints @@ -272,13 +287,13 @@ unsafeUnifyTyVars root leaf = #ctxTyVars %= UM.unify root leaf -- Other stuff ptrWidthNumTy :: SolverM ITy' -ptrWidthNumTy = NumTy <$> use #ptrWidth +ptrWidthNumTy = NumTy <$> view #rPtrWidth traceUnification :: SolverM Bool -traceUnification = use #ctxTraceUnification +traceUnification = view #rTraceUnification traceConstraintOrigins :: SolverM Bool -traceConstraintOrigins = use #ctxTraceConstraintOrigins +traceConstraintOrigins = view #rTraceConstraintOrigins -------------------------------------------------------------------------------- -- Conditional constraints diff --git a/src/Reopt/TypeInference/Solver/Solver.hs b/src/Reopt/TypeInference/Solver/Solver.hs index c3a0ed69..1abe6a54 100644 --- a/src/Reopt/TypeInference/Solver/Solver.hs +++ b/src/Reopt/TypeInference/Solver/Solver.hs @@ -27,6 +27,7 @@ import Data.Set qualified as Set import Debug.Trace (trace) import Prettyprinter qualified as PP +import Control.Monad.RWS (asks) import Control.Monad.Trans (lift) import Reopt.TypeInference.Solver.Constraints ( ConstraintProvenance (..), @@ -43,6 +44,7 @@ import Reopt.TypeInference.Solver.Finalize ( import Reopt.TypeInference.Solver.Monad ( Conditional (..), Conditional', + ConstraintSolvingReader (rMaxNumberOfRestarts, rPtrWidth), ConstraintSolvingState (..), SolverM, addEqC, @@ -180,17 +182,19 @@ solveHeadReset fld doit = do put resetSt -- Forget everything we know in resetSt about the eqvs for tv - let eqs = eqvClasses (ctxTyVars oldSt) - eqsTv = Map.findWithDefault [] tv eqs + let + eqs = eqvClasses (ctxTyVars oldSt) + eqsTv = Map.findWithDefault [] tv eqs traverse_ undefineTyVar eqsTv -- FIXME: gross - defineTyVar tv (ConflictTy (ptrWidth resetSt)) + ptrWidth <- asks rPtrWidth + defineTyVar tv (ConflictTy ptrWidth) -- FIXME: this could cause problems if we allocate tyvars after -- we start solving. Because we don't, this should work. mapM_ (addTyVarEq' ConflictProv tv) eqsTv -- retain eqv class for conflict var. get - put resetSt' + put $ resetSt'{ctxNumberOfRestarts = ctxNumberOfRestarts resetSt + 1} solveFirst :: Lens' ConstraintSolvingState [a] -> @@ -227,8 +231,9 @@ _solveAll fld solve = do go acc progd [] = restore acc $> progd -- finished here, we didn't so anything. go acc progd (c : cs) = do (retain, progress) <- solve c - let acc' = if retain == Retain then c : acc else acc - progd' = progd || madeProgress progress + let + acc' = if retain == Retain then c : acc else acc + progd' = progd || madeProgress progress go acc' progd' cs -- | @preprocess l f# just pre-processes the element at @l@, and so @@ -246,11 +251,17 @@ preprocess fld f = fld %= (<> r) solverLoop :: SolverM () -solverLoop = evalStateT go =<< get +solverLoop = do + maxRestarts <- asks rMaxNumberOfRestarts + evalStateT (go (exceeds maxRestarts)) =<< get where - go = do + exceeds (Just maxRestarts) r = r > maxRestarts + exceeds Nothing _ = False + + go isTooMany = do + tooManyRestarts <- gets $ isTooMany . ctxNumberOfRestarts keepGoing <- orM solvers - when keepGoing go + when (keepGoing && not tooManyRestarts) (go isTooMany) solvers = [ solveHeadReset #ctxEqCs solveEqC @@ -371,14 +382,16 @@ solveConditional c = traceContext' "solveConditional" c $ do solveEqRowC :: EqRowC -> SolverM () solveEqRowC eqc = traceContext' "solveEqRowC" eqc $ do (le, m_lfm) <- lookupRowExpr (eqRowLHS eqc) - let lo = rowExprShift le - lv = rowExprVar le - lfm = fromMaybe emptyFieldMap m_lfm + let + lo = rowExprShift le + lv = rowExprVar le + lfm = fromMaybe emptyFieldMap m_lfm (re, m_rfm) <- lookupRowExpr (eqRowRHS eqc) - let ro = rowExprShift re - rv = rowExprVar re - rfm = fromMaybe emptyFieldMap m_rfm + let + ro = rowExprShift re + rv = rowExprVar re + rfm = fromMaybe emptyFieldMap m_rfm case () of _ @@ -390,8 +403,9 @@ solveEqRowC eqc = traceContext' "solveEqRowC" eqc $ do unify delta lowv lowfm highv highfm = do undefineRowVar highv unsafeUnifyRowVars (RowExprShift delta lowv) highv - let highfm' = shiftFieldMap delta highfm - (lowfm', newEqs) = unifyFieldMaps lowfm highfm' + let + highfm' = shiftFieldMap delta highfm + (lowfm', newEqs) = unifyFieldMaps lowfm highfm' defineRowVar lowv lowfm' traverse_ (uncurry (addTyVarEq' FromEqRowCProv)) newEqs @@ -437,7 +451,7 @@ solveEqC eqc = do -- when the type variable is conflicted. unifyTypes :: TyVar -> ITy' -> ITy' -> SolverM (Maybe TyVar) unifyTypes tv ty1 ty2 = do - pW <- gets ptrWidth + pW <- asks rPtrWidth case (ty1, ty2) of _ | ty1 == ty2 -> pure Nothing (NumTy i, NumTy i') diff --git a/tests/TyConstraintTests.hs b/tests/TyConstraintTests.hs index 16e6c8ac..bcaea307 100644 --- a/tests/TyConstraintTests.hs +++ b/tests/TyConstraintTests.hs @@ -1,46 +1,64 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RankNTypes #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} -module TyConstraintTests - ( constraintTests +module TyConstraintTests ( + constraintTests, -- utility - , ghciTest - ) + ghciTest, +) where -import Data.Bifunctor (bimap) -import Data.Foldable (traverse_) -import qualified Data.Map as Map -import qualified Prettyprinter as PP -import qualified Test.Tasty as T -import qualified Test.Tasty.HUnit as T - -import Reopt.TypeInference.Solver - (TyVar, - -- RowVar(..), - Ty(..), - Ty, FTy, - unifyConstraints, - eqTC, SolverM, runSolverM, freshTyVar, numTy, varTy, ConstraintSolution (..) - , pattern FPtrTy, pattern FNumTy, ptrTy, pattern FUnknownTy - , pattern FStructTy, pattern FNamedStruct - , pattern FConflictTy - , ptrAddTC, ptrSubTC, OperandClass (OCOffset), ptrTC, ptrTy') +import Data.Bifunctor (bimap) +import Data.Foldable (traverse_) +import Data.Map qualified as Map +import Prettyprinter qualified as PP +import Test.Tasty qualified as T +import Test.Tasty.HUnit qualified as T + +import Reopt.TypeInference.Solver ( + -- RowVar(..), + + ConstraintSolution (..), + FTy, + OperandClass (OCOffset), + SolverM, + Ty (..), + TyVar, + eqTC, + freshTyVar, + numTy, + ptrAddTC, + ptrSubTC, + ptrTC, + ptrTy, + ptrTy', + runSolverM, + unifyConstraints, + varTy, + pattern FConflictTy, + pattern FNamedStruct, + pattern FNumTy, + pattern FPtrTy, + pattern FStructTy, + pattern FUnknownTy, + ) import Reopt.TypeInference.Solver.Constraints (ConstraintProvenance (..)) -import Reopt.TypeInference.Solver.RowVariables (Offset, singletonFieldMap, FieldMap, fieldMapFromList) -import Reopt.TypeInference.Solver.Monad (withFresh) +import Reopt.TypeInference.Solver.Monad (ConstraintSolvingReader (..), withFresh) +import Reopt.TypeInference.Solver.RowVariables (FieldMap, Offset, fieldMapFromList, singletonFieldMap) constraintTests :: T.TestTree -constraintTests = T.testGroup "Type Constraint Tests" - [ eqCTests - , ptrCTests - , recursiveTests - , conflictTests - ] +constraintTests = + T.testGroup + "Type Constraint Tests" + [ eqCTests + , ptrCTests + , recursiveTests + , conflictTests + ] num8, num32, num64 :: Ty num8 = numTy 8 @@ -58,7 +76,6 @@ tv = freshTyVar Nothing Nothing tvEq :: TyVar -> TyVar -> SolverM () tvEq v v' = eqTC prov (varTy v) (varTy v') - tvEqs :: [(TyVar, TyVar)] -> SolverM () tvEqs = traverse_ (uncurry tvEq) @@ -73,260 +90,298 @@ frecTy = FStructTy . fieldMapFromList -- Simple tests having to do with equality constraints eqCTests :: T.TestTree -eqCTests = T.testGroup "Equality Constraint Tests" - [ mkTest "Single eqTC var left" (do { x0 <- tv; eqTC prov (varTy x0) num64; pure [(x0, fnum64)] }) - , mkTest "Single eqTC var right" (do { x0 <- tv; eqTC prov num64 (varTy x0); pure [(x0, fnum64)] }) - , mkTest "Multiple eqTCs" $ do - x0 <- tv - x1 <- tv - eqTC prov (varTy x0) num64 - eqTC prov (varTy x1) (ptrTy' num64) - pure [(x0, fnum64), (x1, fptrTy' fnum64)] - - , mkTest "eqTC simple transitivity 1" $ do - x0 <- tv - x1 <- tv - tvEq x0 x1 - eqTC prov (varTy x1) num64 - pure [(x0, fnum64), (x1, fnum64)] - - , mkTest "eqTC simple Transitivity 2" $ do - x0 <- tv - x1 <- tv - eqTC prov (varTy x1) num64 - tvEq x0 x1 - pure [(x0, fnum64), (x1, fnum64)] - - , mkTest "eqTC simple transitivity 3" $ do - x0 <- tv; x1 <- tv; x2 <- tv - tvEqs [(x1, x2), (x0, x1)] - eqTC prov (varTy x2) num64 - pure [(x0, fnum64), (x1, fnum64), (x2, fnum64)] - - , mkTest "eqTC with pointers 1" $ do - x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv - tvEqs [(x1, x2), (x0, x1) ] - eqTC prov (varTy x2) num64 - eqTC prov (varTy x3) (ptrTy' num64) - pure [(x0, fnum64), (x1, fnum64), (x2, fnum64), (x3, fptrTy' fnum64)] - - , mkTest "eqTC with pointers 2" $ do - x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv; x4 <- tv - tvEqs [(x0, x2)] - eqTC prov (varTy x1) (ptrTy' (varTy x2)) - eqTC prov (varTy x2) num64 - eqTC prov (varTy x3) num64 - eqTC prov (varTy x4) (ptrTy' (varTy x3)) - - pure [(x0, fnum64), (x1, fptrTy' fnum64), (x2, fnum64), (x3, fnum64), (x4, fptrTy' fnum64)] - - , mkTest "eqTC with pointers 3" $ do - x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv; x4 <- tv - tvEqs [(x0, x2)] - eqTC prov (varTy x1) (ptrTy' (varTy x2)) - eqTC prov (varTy x3) num64 - eqTC prov (varTy x4) (ptrTy' (varTy x3)) - - pure [(x0, FUnknownTy), (x1, fptrTy' FUnknownTy), (x2, FUnknownTy), (x3, fnum64), (x4, fptrTy' fnum64)] - - , mkTest "eqTC with records 1" $ do - x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv - tvEqs [ (x1, x2), (x0, x1) ] - eqTC prov (varTy x2) num64 - eqTC prov (varTy x3) (ptrTy' (varTy x0)) - pure [(x0, fnum64), (x1, fnum64), (x2, fnum64), (x3, FPtrTy $ frecTy [(0, fnum64)])] - - , mkTest "eqTC with records 2" $ do - x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv - tvEqs [ (x1, x2) ] - eqTC prov (varTy x0) (ptrTy' (varTy x1)) - eqTC prov (varTy x2) num64 - eqTC prov (varTy x3) (ptrTy $ recTy [(0, varTy x0), (8, varTy x1)]) - pure [ (x0, fptrTy' fnum64), (x1, fnum64), (x2, fnum64) - , (x3, FPtrTy $ frecTy [(0, fptrTy' fnum64), (8, fnum64)]) - ] - - -- These next tests check that record constraints are unified properly - -- during constraint solving when there are possible unknown other fields - -- (i.e., the row variables). This arises when we want to combine - -- different atomic facts describing offsets from a single memory - -- location. E.g., if we separately learn (1) at `p` there is a `num` and - -- (2) at `p+8` there is an `ptr(num)`, these statements about `p` can be - -- described via the following two atomic constraints: `p = {0 : num|ρ}` - -- and `p = {8 : ptr(num)|ρ'}`. Our unification should then combine these - -- constraints on `p` into `p = {0 : num, 8 : ptr(num)}`. - - , mkTest "eqTC with records+rows 1" $ do - x0 <- tv - let x0Ty = varTy x0 - eqTC prov x0Ty (ptrTy $ recTy [(0, num64)]) - eqTC prov x0Ty (ptrTy $ recTy [(8, ptrTy' num64)]) - pure [(x0, FPtrTy $ frecTy [(0, fnum64), (8, fptrTy' fnum64)])] - - -- , mkTest "eqTC with records+rows 2" $ do - -- x0 <- tv; x1 <- tv; x2 <- tv - -- let x0Ty = varTy x0 - -- x1Ty = varTy x1 - -- x2Ty = varTy x2 - - -- r0 <- freshRowVar - -- r1 <- freshRowVar - - -- eqTC x0Ty (recTy [(0, num64)] r0) - -- eqTC x0Ty (recTy [(8, ptrTy num64)] r1) - -- eqTC x1Ty (recTy mempty r0) - -- eqTC x2Ty (recTy mempty r1) - -- pure [ (x0, frecTy [(0, fnum64), (8, FPtrTy fnum64)]) - -- , (x1, frecTy [(8, FPtrTy fnum64)]) - -- , (x2, frecTy [(0, fnum64)]) - -- ] - - -- , mkTest "eqTC with records+rows 3" $ do - -- x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv - -- let x0Ty = varTy x0 - -- x1Ty = varTy x1 - -- x2Ty = varTy x2 - -- x3Ty = varTy x3 - - -- r0 <- freshRowVar - -- r1 <- freshRowVar - - -- eqTC x0Ty (recTy [(0, num64), (8, ptrTy x1Ty)] r0) - -- eqTC x0Ty (recTy [(8, ptrTy num64), (16, num64)] r1) - -- eqTC x2Ty (recTy mempty r0) - -- eqTC x3Ty (recTy mempty r1) - -- pure [ (x0, frecTy [(0, fnum64), (8, FPtrTy fnum64), (16, fnum64)]) - -- , (x1, fnum64) - -- , (x2, frecTy [(16, fnum64)]) - -- , (x3, frecTy [(0, fnum64)])] - ] +eqCTests = + T.testGroup + "Equality Constraint Tests" + [ mkTest "Single eqTC var left" (do x0 <- tv; eqTC prov (varTy x0) num64; pure [(x0, fnum64)]) + , mkTest "Single eqTC var right" (do x0 <- tv; eqTC prov num64 (varTy x0); pure [(x0, fnum64)]) + , mkTest "Multiple eqTCs" $ do + x0 <- tv + x1 <- tv + eqTC prov (varTy x0) num64 + eqTC prov (varTy x1) (ptrTy' num64) + pure [(x0, fnum64), (x1, fptrTy' fnum64)] + , mkTest "eqTC simple transitivity 1" $ do + x0 <- tv + x1 <- tv + tvEq x0 x1 + eqTC prov (varTy x1) num64 + pure [(x0, fnum64), (x1, fnum64)] + , mkTest "eqTC simple Transitivity 2" $ do + x0 <- tv + x1 <- tv + eqTC prov (varTy x1) num64 + tvEq x0 x1 + pure [(x0, fnum64), (x1, fnum64)] + , mkTest "eqTC simple transitivity 3" $ do + x0 <- tv + x1 <- tv + x2 <- tv + tvEqs [(x1, x2), (x0, x1)] + eqTC prov (varTy x2) num64 + pure [(x0, fnum64), (x1, fnum64), (x2, fnum64)] + , mkTest "eqTC with pointers 1" $ do + x0 <- tv + x1 <- tv + x2 <- tv + x3 <- tv + tvEqs [(x1, x2), (x0, x1)] + eqTC prov (varTy x2) num64 + eqTC prov (varTy x3) (ptrTy' num64) + pure [(x0, fnum64), (x1, fnum64), (x2, fnum64), (x3, fptrTy' fnum64)] + , mkTest "eqTC with pointers 2" $ do + x0 <- tv + x1 <- tv + x2 <- tv + x3 <- tv + x4 <- tv + tvEqs [(x0, x2)] + eqTC prov (varTy x1) (ptrTy' (varTy x2)) + eqTC prov (varTy x2) num64 + eqTC prov (varTy x3) num64 + eqTC prov (varTy x4) (ptrTy' (varTy x3)) + + pure [(x0, fnum64), (x1, fptrTy' fnum64), (x2, fnum64), (x3, fnum64), (x4, fptrTy' fnum64)] + , mkTest "eqTC with pointers 3" $ do + x0 <- tv + x1 <- tv + x2 <- tv + x3 <- tv + x4 <- tv + tvEqs [(x0, x2)] + eqTC prov (varTy x1) (ptrTy' (varTy x2)) + eqTC prov (varTy x3) num64 + eqTC prov (varTy x4) (ptrTy' (varTy x3)) + + pure [(x0, FUnknownTy), (x1, fptrTy' FUnknownTy), (x2, FUnknownTy), (x3, fnum64), (x4, fptrTy' fnum64)] + , mkTest "eqTC with records 1" $ do + x0 <- tv + x1 <- tv + x2 <- tv + x3 <- tv + tvEqs [(x1, x2), (x0, x1)] + eqTC prov (varTy x2) num64 + eqTC prov (varTy x3) (ptrTy' (varTy x0)) + pure [(x0, fnum64), (x1, fnum64), (x2, fnum64), (x3, FPtrTy $ frecTy [(0, fnum64)])] + , mkTest "eqTC with records 2" $ do + x0 <- tv + x1 <- tv + x2 <- tv + x3 <- tv + tvEqs [(x1, x2)] + eqTC prov (varTy x0) (ptrTy' (varTy x1)) + eqTC prov (varTy x2) num64 + eqTC prov (varTy x3) (ptrTy $ recTy [(0, varTy x0), (8, varTy x1)]) + pure + [ (x0, fptrTy' fnum64) + , (x1, fnum64) + , (x2, fnum64) + , (x3, FPtrTy $ frecTy [(0, fptrTy' fnum64), (8, fnum64)]) + ] + , -- These next tests check that record constraints are unified properly + -- during constraint solving when there are possible unknown other fields + -- (i.e., the row variables). This arises when we want to combine + -- different atomic facts describing offsets from a single memory + -- location. E.g., if we separately learn (1) at `p` there is a `num` and + -- (2) at `p+8` there is an `ptr(num)`, these statements about `p` can be + -- described via the following two atomic constraints: `p = {0 : num|ρ}` + -- and `p = {8 : ptr(num)|ρ'}`. Our unification should then combine these + -- constraints on `p` into `p = {0 : num, 8 : ptr(num)}`. + + mkTest "eqTC with records+rows 1" $ do + x0 <- tv + let x0Ty = varTy x0 + eqTC prov x0Ty (ptrTy $ recTy [(0, num64)]) + eqTC prov x0Ty (ptrTy $ recTy [(8, ptrTy' num64)]) + pure [(x0, FPtrTy $ frecTy [(0, fnum64), (8, fptrTy' fnum64)])] + + -- , mkTest "eqTC with records+rows 2" $ do + -- x0 <- tv; x1 <- tv; x2 <- tv + -- let x0Ty = varTy x0 + -- x1Ty = varTy x1 + -- x2Ty = varTy x2 + + -- r0 <- freshRowVar + -- r1 <- freshRowVar + + -- eqTC x0Ty (recTy [(0, num64)] r0) + -- eqTC x0Ty (recTy [(8, ptrTy num64)] r1) + -- eqTC x1Ty (recTy mempty r0) + -- eqTC x2Ty (recTy mempty r1) + -- pure [ (x0, frecTy [(0, fnum64), (8, FPtrTy fnum64)]) + -- , (x1, frecTy [(8, FPtrTy fnum64)]) + -- , (x2, frecTy [(0, fnum64)]) + -- ] + + -- , mkTest "eqTC with records+rows 3" $ do + -- x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv + -- let x0Ty = varTy x0 + -- x1Ty = varTy x1 + -- x2Ty = varTy x2 + -- x3Ty = varTy x3 + + -- r0 <- freshRowVar + -- r1 <- freshRowVar + + -- eqTC x0Ty (recTy [(0, num64), (8, ptrTy x1Ty)] r0) + -- eqTC x0Ty (recTy [(8, ptrTy num64), (16, num64)] r1) + -- eqTC x2Ty (recTy mempty r0) + -- eqTC x3Ty (recTy mempty r1) + -- pure [ (x0, frecTy [(0, fnum64), (8, FPtrTy fnum64), (16, fnum64)]) + -- , (x1, fnum64) + -- , (x2, frecTy [(16, fnum64)]) + -- , (x3, frecTy [(0, fnum64)])] + ] ptrCTests :: T.TestTree -ptrCTests = T.testGroup "Pointer Arith Constraint Tests" - [ mkTest "Constrained by arg" $ do - x0 <- tv; x1 <- tv; x2 <- tv - let x0Ty = varTy x0 +ptrCTests = + T.testGroup + "Pointer Arith Constraint Tests" + [ mkTest "Constrained by arg" $ do + x0 <- tv + x1 <- tv + x2 <- tv + let + x0Ty = varTy x0 x1Ty = varTy x1 x2Ty = varTy x2 - eqTC prov x1Ty (ptrTy $ recTy [(0, num64), (8, ptrTy' num64)]) - ptrAddTC prov x0Ty x1Ty x2Ty (OCOffset 8) - pure [(x0, fptrTy' (fptrTy' fnum64))] - - , mkTest "Constrained by result 1" $ do - x0 <- tv; x1 <- tv; x2 <- tv - let x0Ty = varTy x0 + eqTC prov x1Ty (ptrTy $ recTy [(0, num64), (8, ptrTy' num64)]) + ptrAddTC prov x0Ty x1Ty x2Ty (OCOffset 8) + pure [(x0, fptrTy' (fptrTy' fnum64))] + , mkTest "Constrained by result 1" $ do + x0 <- tv + x1 <- tv + x2 <- tv + let + x0Ty = varTy x0 x1Ty = varTy x1 x2Ty = varTy x2 - eqTC prov x0Ty (ptrTy $ recTy [(0, num64), (8, ptrTy' num64)]) + eqTC prov x0Ty (ptrTy $ recTy [(0, num64), (8, ptrTy' num64)]) - ptrAddTC prov x0Ty x1Ty x2Ty (OCOffset 8) - pure [(x1, FPtrTy $ frecTy [(8, fnum64), (16, fptrTy' fnum64)])] - , mkTest "Constrained by result 2" $ do - x0 <- tv; x1 <- tv; x2 <- tv - let x0Ty = varTy x0 + ptrAddTC prov x0Ty x1Ty x2Ty (OCOffset 8) + pure [(x1, FPtrTy $ frecTy [(8, fnum64), (16, fptrTy' fnum64)])] + , mkTest "Constrained by result 2" $ do + x0 <- tv + x1 <- tv + x2 <- tv + let + x0Ty = varTy x0 x1Ty = varTy x1 x2Ty = varTy x2 - eqTC prov x0Ty (ptrTy $ recTy [(0, num64), (8, ptrTy' num64)]) - eqTC prov x1Ty (ptrTy $ recTy [(0, ptrTy' num64)]) - - ptrAddTC prov x0Ty x1Ty x2Ty (OCOffset 8) - pure [(x1, FPtrTy $ frecTy [(0, fptrTy' fnum64), (8, fnum64), (16, fptrTy' fnum64)])] - - , mkTest "Accessing pointer members from a struct pointer" $ do - x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv - let x0Ty = varTy x0 + eqTC prov x0Ty (ptrTy $ recTy [(0, num64), (8, ptrTy' num64)]) + eqTC prov x1Ty (ptrTy $ recTy [(0, ptrTy' num64)]) + + ptrAddTC prov x0Ty x1Ty x2Ty (OCOffset 8) + pure [(x1, FPtrTy $ frecTy [(0, fptrTy' fnum64), (8, fnum64), (16, fptrTy' fnum64)])] + , mkTest "Accessing pointer members from a struct pointer" $ do + x0 <- tv + x1 <- tv + x2 <- tv + x3 <- tv + let + x0Ty = varTy x0 x1Ty = varTy x1 x2Ty = varTy x2 x3Ty = varTy x3 - eqTC prov x0Ty (ptrTy (recTy [(0, num8)])) - eqTC prov x1Ty (ptrTy (recTy [(0, num64)])) - eqTC prov x2Ty (ptrTy (recTy [(0, num32)])) - eqTC prov x3Ty (ptrTy (recTy [])) - ptrAddTC prov x0Ty x3Ty num64 (OCOffset 0) - ptrAddTC prov x1Ty x3Ty num64 (OCOffset 8) - ptrAddTC prov x2Ty x3Ty num64 (OCOffset 72) - pure [ (x0, FPtrTy (frecTy [(0, fnum8), (8, fnum64), (72, fnum32)])) - , (x1, FPtrTy (frecTy [(0, fnum64), (64, fnum32)])) - , (x2, FPtrTy (frecTy [(0, fnum32)])) - , (x3, FPtrTy (frecTy [(0, fnum8), (8, fnum64), (72, fnum32)]))] - - - , mkTest "Simple cycle test (liveness)" $ do - x0 <- tv; x1 <- tv; x2 <- tv - let x0Ty = varTy x0 + eqTC prov x0Ty (ptrTy (recTy [(0, num8)])) + eqTC prov x1Ty (ptrTy (recTy [(0, num64)])) + eqTC prov x2Ty (ptrTy (recTy [(0, num32)])) + eqTC prov x3Ty (ptrTy (recTy [])) + ptrAddTC prov x0Ty x3Ty num64 (OCOffset 0) + ptrAddTC prov x1Ty x3Ty num64 (OCOffset 8) + ptrAddTC prov x2Ty x3Ty num64 (OCOffset 72) + pure + [ (x0, FPtrTy (frecTy [(0, fnum8), (8, fnum64), (72, fnum32)])) + , (x1, FPtrTy (frecTy [(0, fnum64), (64, fnum32)])) + , (x2, FPtrTy (frecTy [(0, fnum32)])) + , (x3, FPtrTy (frecTy [(0, fnum8), (8, fnum64), (72, fnum32)])) + ] + , mkTest "Simple cycle test (liveness)" $ do + x0 <- tv + x1 <- tv + x2 <- tv + let + x0Ty = varTy x0 x1Ty = varTy x1 x2Ty = varTy x2 - -- x0 = ptr (recTy {0 -> x1} (freshRow)) - ptrTC prov x1Ty x0Ty - - -- x1 is a byte - eqTC prov x1Ty (numTy 8) - - eqTC prov x2Ty num64 - ptrAddTC prov x0Ty x0Ty x2Ty (OCOffset 1) - pure [] -- Liveness - - , mkTest "Nested Cycle test (liveness)" $ do - x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv - let x0Ty = varTy x0 + -- x0 = ptr (recTy {0 -> x1} (freshRow)) + ptrTC prov x1Ty x0Ty + + -- x1 is a byte + eqTC prov x1Ty (numTy 8) + + eqTC prov x2Ty num64 + ptrAddTC prov x0Ty x0Ty x2Ty (OCOffset 1) + pure [] -- Liveness + , mkTest "Nested Cycle test (liveness)" $ do + x0 <- tv + x1 <- tv + x2 <- tv + x3 <- tv + let + x0Ty = varTy x0 x1Ty = varTy x1 x2Ty = varTy x2 x3Ty = varTy x3 - -- x0 = ptr (recTy {0 -> x1} (freshRow)) - ptrTC prov x1Ty x0Ty - ptrTC prov x1Ty x3Ty - - -- x1 is a byte - eqTC prov x1Ty (numTy 8) - - eqTC prov x2Ty num64 - ptrAddTC prov x3Ty x0Ty x2Ty (OCOffset 1) - ptrAddTC prov x0Ty x3Ty x2Ty (OCOffset 1) - -- If we get here, then we have succeeded, although we may want - -- to return a value once array stride detection produces a - -- reasonable type. - pure [] - - , mkTest "Nested cycle test 2 (liveness)" $ do - x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv - let x0Ty = varTy x0 + -- x0 = ptr (recTy {0 -> x1} (freshRow)) + ptrTC prov x1Ty x0Ty + ptrTC prov x1Ty x3Ty + + -- x1 is a byte + eqTC prov x1Ty (numTy 8) + + eqTC prov x2Ty num64 + ptrAddTC prov x3Ty x0Ty x2Ty (OCOffset 1) + ptrAddTC prov x0Ty x3Ty x2Ty (OCOffset 1) + -- If we get here, then we have succeeded, although we may want + -- to return a value once array stride detection produces a + -- reasonable type. + pure [] + , mkTest "Nested cycle test 2 (liveness)" $ do + x0 <- tv + x1 <- tv + x2 <- tv + x3 <- tv + let + x0Ty = varTy x0 x1Ty = varTy x1 x2Ty = varTy x2 x3Ty = varTy x3 - -- x0 = ptr (recTy {0 -> x1} (freshRow)) - ptrTC prov x1Ty x0Ty - ptrTC prov x1Ty x3Ty - - -- x1 is a byte - eqTC prov x1Ty (numTy 8) - - eqTC prov x2Ty num64 - -- Inverted constraint order from above - ptrAddTC prov x0Ty x3Ty x2Ty (OCOffset 1) - ptrAddTC prov x3Ty x0Ty x2Ty (OCOffset 1) - -- If we get here, then we have succeeded, although we may want - -- to return a value once array stride detection produces a - -- reasonable type. - pure [] - - - , mkTest "Pointer offset subtraction" $ do - x0 <- tv; x1 <- tv; x2 <- tv - let x0Ty = varTy x0 + -- x0 = ptr (recTy {0 -> x1} (freshRow)) + ptrTC prov x1Ty x0Ty + ptrTC prov x1Ty x3Ty + + -- x1 is a byte + eqTC prov x1Ty (numTy 8) + + eqTC prov x2Ty num64 + -- Inverted constraint order from above + ptrAddTC prov x0Ty x3Ty x2Ty (OCOffset 1) + ptrAddTC prov x3Ty x0Ty x2Ty (OCOffset 1) + -- If we get here, then we have succeeded, although we may want + -- to return a value once array stride detection produces a + -- reasonable type. + pure [] + , mkTest "Pointer offset subtraction" $ do + x0 <- tv + x1 <- tv + x2 <- tv + let + x0Ty = varTy x0 x1Ty = varTy x1 x2Ty = varTy x2 - eqTC prov x1Ty (ptrTy $ recTy [(0, num64), (8, ptrTy' num64)]) - ptrSubTC prov x0Ty x1Ty x2Ty (OCOffset 8) - pure [(x0, FPtrTy $ frecTy [(8, fnum64), (16, fptrTy' fnum64)])] - ] + eqTC prov x1Ty (ptrTy $ recTy [(0, num64), (8, ptrTy' num64)]) + ptrSubTC prov x0Ty x1Ty x2Ty (OCOffset 8) + pure [(x0, FPtrTy $ frecTy [(8, fnum64), (16, fptrTy' fnum64)])] + ] -- t4 = do -- x0 <- tv; x1 <- tv; x2 <- tv; x3 <- tv @@ -338,30 +393,42 @@ ptrCTests = T.testGroup "Pointer Arith Constraint Tests" -- r0 <- freshRowVar -- eqTC prov x0Ty (recTy [(0, ptrTy x0Ty)] r0) - recursiveTests :: T.TestTree -recursiveTests = T.testGroup "Recursive Type Tests" - [ mkTest "Recursive linked list" $ do - x0 <- tv - let x0Ty = varTy x0 - eqTC prov x0Ty (ptrTy $ recTy [(0, ptrTy' x0Ty)]) - pure [(x0, FPtrTy $ FNamedStruct "struct.reopt.t1")] - ] +recursiveTests = + T.testGroup + "Recursive Type Tests" + [ mkTest "Recursive linked list" $ do + x0 <- tv + let x0Ty = varTy x0 + eqTC prov x0Ty (ptrTy $ recTy [(0, ptrTy' x0Ty)]) + pure [(x0, FPtrTy $ FNamedStruct "struct.reopt.t1")] + ] conflictTests :: T.TestTree -conflictTests = T.testGroup "Conflict Type Tests" - [ mkTest "Simple conflict" $ withFresh $ \x0 -> do - let x0Ty = varTy x0 - eqTC prov x0Ty num64 - eqTC prov x0Ty (ptrTy' num64) - -- eqTC x0Ty (ptrTy' num64) - -- eqTC x0Ty (ptrTy' (ptrTy' num64)) - - pure [(x0, FConflictTy 64)] - ] +conflictTests = + T.testGroup + "Conflict Type Tests" + [ mkTest "Simple conflict" $ withFresh $ \x0 -> do + let x0Ty = varTy x0 + eqTC prov x0Ty num64 + eqTC prov x0Ty (ptrTy' num64) + -- eqTC x0Ty (ptrTy' num64) + -- eqTC x0Ty (ptrTy' (ptrTy' num64)) + + pure [(x0, FConflictTy 64)] + ] + +testReader :: Bool -> ConstraintSolvingReader +testReader doTrace = + ConstraintSolvingReader + { rMaxNumberOfRestarts = Nothing + , rPtrWidth = 64 + , rTraceUnification = doTrace + , rTraceConstraintOrigins = False + } ghciTest :: Bool -> SolverM a -> PP.Doc d -ghciTest doTrace t = PP.pretty . runSolverM doTrace False 64 $ do +ghciTest doTrace t = PP.pretty . runSolverM (testReader doTrace) $ do t >> unifyConstraints newtype TypeEnv = TypeEnv [(TyVar, FTy)] @@ -374,14 +441,16 @@ instance Show TypeEnv where -- tyEnv = TypeEnv . sortBy (compare `on` fst) mkTest :: String -> SolverM [(TyVar, FTy)] -> T.TestTree -mkTest name m = T.testCase name (runSolverM False False 64 test) - where - test = do - expected <- m - res <- unifyConstraints - let actual = [ (k, Map.findWithDefault FUnknownTy k (csTyVars res)) - | (k, _) <- expected ] - pure (TypeEnv actual T.@?= TypeEnv expected) +mkTest name m = T.testCase name (runSolverM (testReader False) test) + where + test = do + expected <- m + res <- unifyConstraints + let actual = + [ (k, Map.findWithDefault FUnknownTy k (csTyVars res)) + | (k, _) <- expected + ] + pure (TypeEnv actual T.@?= TypeEnv expected) prov :: ConstraintProvenance prov = TestingProv