Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add max-restarts CLI arguments to stop long constraint solving loops #317

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions reopt.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ library
attoparsec-aeson,
base64,
bytestring,
composition-extra,
containers,
directory,
elf-edit >= 0.40,
Expand Down
16 changes: 15 additions & 1 deletion reopt/Main_reopt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ data Args = Args
-- ^ Trace unification of the type inference solver
, traceConstraintOrigins :: Bool
-- ^ Trace the origin of constraints in the type inference solver
, maxRestarts :: Maybe Int
-- ^ How many time the type constraint solver should restart before giving up. `Nothing` means
-- infinite restarts.
}
deriving (Generic)

Expand Down Expand Up @@ -247,6 +250,7 @@ defaultArgs =
, performRecovery = False
, traceTypeUnification = False
, traceConstraintOrigins = False
, maxRestarts = Nothing
}

------------------------------------------------------------------------
Expand Down Expand Up @@ -299,6 +303,13 @@ traceConstraintOriginsP =
long "trace-constraint-origins"
<> help "Trace the origins of constraints in the type inference engine"

maxRestartsP :: Parser (Maybe Int)
maxRestartsP =
optional $ option auto $
long "max-restarts"
<> metavar "NUMBER"
<> help "Number of times the type constraint solver should restart before giving up"

outputPathP :: Parser String
outputPathP =
strOption $
Expand Down Expand Up @@ -623,6 +634,7 @@ arguments =
<*> performRecoveryP
<*> traceTypeUnificationP
<*> traceConstraintOriginsP
<*> maxRestartsP

-- | Parser to set the path to the binary to analyze.
programPathP :: Parser String
Expand Down Expand Up @@ -668,6 +680,7 @@ argsReoptOptions args = do
, roDiscoveryOptions = args ^. #discOpts
, roDynDepPaths = dynDepPath args
, roDynDepDebugPaths = dynDepDebugPath args ++ gdbDebugDirs
, roMaxRestarts = maxRestarts args
, roTraceUnification = traceTypeUnification args
, roTraceConstraintOrigins = traceConstraintOrigins args
}
Expand Down Expand Up @@ -722,7 +735,8 @@ showConstraints args elfPath = do
doRecoverX86 funPrefix sysp symAddrMap debugTypeMap discState Map.empty

let recMod = recoveredModule recoverX86Output
pure $ genModuleConstraints recMod (Macaw.memory discState) (traceTypeUnification args) (traceConstraintOrigins args)
pure $ genModuleConstraints recMod (Macaw.memory discState)
(maxRestarts args) (traceTypeUnification args) (traceConstraintOrigins args)

mc <- handleEitherWithExit mr

Expand Down
3 changes: 3 additions & 0 deletions src/Reopt.hs
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ data ReoptOptions = ReoptOptions
, roDynDepPaths :: ![FilePath]
-- ^ Additional paths to search for dynamic dependencies.
, roDynDepDebugPaths :: ![FilePath]
, roMaxRestarts :: Maybe Int
-- ^ Additional paths to search for debug versions of dynamic dependencies.
, roTraceUnification :: !Bool
-- ^ Trace unification in the solver
Expand All @@ -499,6 +500,7 @@ defaultReoptOptions =
, roDiscoveryOptions = reoptDefaultDiscoveryOptions
, roDynDepPaths = []
, roDynDepDebugPaths = []
, roMaxRestarts = Nothing
, roTraceUnification = False
, roTraceConstraintOrigins = False
}
Expand Down Expand Up @@ -2745,6 +2747,7 @@ reoptRecoveryLoop symAddrMap rOpts funPrefix sysp debugTypeMap firstDiscState =
genModuleConstraints
recMod
(Macaw.memory discState')
(roMaxRestarts rOpts)
(roTraceUnification rOpts)
(roTraceConstraintOrigins rOpts)

Expand Down
31 changes: 21 additions & 10 deletions src/Reopt/TypeInference/ConstraintGen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ import Reopt.TypeInference.Solver.Constraints (
ConstraintProvenance (..),
FnRepProvenance (..),
)
import Reopt.TypeInference.Solver.Monad (ConstraintSolvingReader (..))
import Reopt.TypeInference.Solver.Types (TyF (..))

-- This algorithm proceeds in stages:
Expand Down Expand Up @@ -241,20 +242,29 @@ inSolverM = CGenM . lift . lift

runCGenM ::
Memory (ArchAddrWidth arch) ->
Maybe Int ->
Bool ->
Bool ->
CGenM CGenGlobalContext arch a ->
a
runCGenM mem traceWanted orig (CGenM m) = runSolverM traceWanted orig ptrWidth $ do
let segs = memSegments mem
-- Allocate a row variable for each memory segment
memRows <- Map.fromList <$> mapM (\seg -> (,) seg <$> S.freshRowVar) segs
let ctxt0 =
CGenGlobalContext
{ _cgenMemory = mem
, _cgenMemoryRegions = memRows
runCGenM mem maxRestarts traceWanted orig (CGenM m) = do
let initReader =
ConstraintSolvingReader
{ rMaxNumberOfRestarts = maxRestarts
, rPtrWidth = ptrWidth
, rTraceConstraintOrigins = orig
, rTraceUnification = traceWanted
}
evalStateT (Reader.runReaderT m ctxt0) st0
runSolverM initReader $ do
let segs = memSegments mem
-- Allocate a row variable for each memory segment
memRows <- Map.fromList <$> mapM (\seg -> (,) seg <$> S.freshRowVar) segs
let ctxt0 =
CGenGlobalContext
{ _cgenMemory = mem
, _cgenMemoryRegions = memRows
}
evalStateT (Reader.runReaderT m ctxt0) st0
where
ptrWidth = widthVal (memWidth mem)

Expand Down Expand Up @@ -981,10 +991,11 @@ genModuleConstraints ::
FoldableFC (ArchFn arch) =>
RecoveredModule arch ->
Memory (ArchAddrWidth arch) ->
Maybe Int ->
Bool ->
Bool ->
ModuleConstraints arch
genModuleConstraints m mem traceWanted orig = runCGenM mem traceWanted orig $ do
genModuleConstraints m mem maxRestarts traceWanted orig = runCGenM mem maxRestarts traceWanted orig $ do
-- allocate type variables for functions without types
-- FIXME: we currently ignore hints

Expand Down
59 changes: 37 additions & 22 deletions src/Reopt/TypeInference/Solver/Monad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module Reopt.TypeInference.Solver.Monad (
Conditional (..),
Conditional',
Conjunction (..),
ConstraintSolvingReader (..),
ConstraintSolvingState (..),
defineRowVar,
defineTyVar,
Expand Down Expand Up @@ -43,15 +44,16 @@ module Reopt.TypeInference.Solver.Monad (
withFresh,
) where

import Control.Lens (Lens', use, (%%=), (%=), (<<+=))
import Control.Monad.State (MonadState, State, evalState)
import Control.Lens (Lens', view, (%%=), (%=), (<<+=))
import Data.Foldable (asum)
import Data.Function.Slip (slipr)
import Data.Generics.Labels ()
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map
import GHC.Generics (Generic)
import Prettyprinter qualified as PP

import Control.Monad.RWS.Strict
import Reopt.TypeInference.Solver.Constraints (
ConstraintProvenance,
EqC (EqC),
Expand Down Expand Up @@ -81,6 +83,16 @@ import Reopt.TypeInference.Solver.UnionFindMap qualified as UM

type Conditional' = Conditional ([EqC], [EqRowC])

data ConstraintSolvingReader = ConstraintSolvingReader
{ rMaxNumberOfRestarts :: Maybe Int
, rPtrWidth :: Int
-- ^ The width of a pointer, in bits. This can go away when tyvars have an associated size, it is
-- only used for PtrAddC solving.
, rTraceUnification :: Bool
, rTraceConstraintOrigins :: Bool
}
deriving (Generic)

data ConstraintSolvingState = ConstraintSolvingState
{ ctxEqCs :: [EqC]
, ctxEqRowCs :: [EqRowC]
Expand All @@ -90,22 +102,16 @@ data ConstraintSolvingState = ConstraintSolvingState
, nextTraceId :: Int
, nextRowVar :: Int
, nextTyVar :: Int
, ptrWidth :: Int
-- ^ The width of a pointer, in bits. This can go away when
-- tyvars have an associated size, it is only used for PtrAddC
-- solving.
, ctxTyVars :: UnionFindMap TyVar TyVar ITy'
-- ^ The union-find data-structure mapping each tyvar onto its
-- representative tv. If no mapping exists, it is a self-mapping.
, ctxRowVars :: UnionFindMap RowVar RowInfo (FieldMap TyVar)
, -- Debugging
ctxTraceUnification :: Bool
, ctxTraceConstraintOrigins :: Bool
, ctxNumberOfRestarts :: Int
}
deriving (Generic)

emptyContext :: Int -> Bool -> Bool -> ConstraintSolvingState
emptyContext w trace orig =
emptyConstraintSolvingState :: ConstraintSolvingState
emptyConstraintSolvingState =
ConstraintSolvingState
{ ctxEqCs = []
, ctxEqRowCs = []
Expand All @@ -115,20 +121,29 @@ emptyContext w trace orig =
, nextTraceId = 0
, nextRowVar = 0
, nextTyVar = 0
, ptrWidth = w
, ctxTyVars = UM.empty
, ctxRowVars = UM.empty
, ctxTraceUnification = trace
, ctxTraceConstraintOrigins = orig
, ctxNumberOfRestarts = 0
}

newtype SolverM a = SolverM
{ getSolverM :: State ConstraintSolvingState a
{ getSolverM :: RWS ConstraintSolvingReader () ConstraintSolvingState a
}
deriving (Applicative, Functor, Monad, MonadState ConstraintSolvingState)

runSolverM :: Bool -> Bool -> Int -> SolverM a -> a
runSolverM b o w = flip evalState (emptyContext w b o) . getSolverM
deriving
( Applicative
, Functor
, Monad
, MonadState ConstraintSolvingState
, MonadReader ConstraintSolvingReader
)

runSolverM ::
ConstraintSolvingReader ->
SolverM a ->
a
runSolverM initReader = fst . slipr evalRWS initReader initState . getSolverM
where
initState = emptyConstraintSolvingState

--------------------------------------------------------------------------------
-- Adding constraints
Expand Down Expand Up @@ -272,13 +287,13 @@ unsafeUnifyTyVars root leaf = #ctxTyVars %= UM.unify root leaf
-- Other stuff

ptrWidthNumTy :: SolverM ITy'
ptrWidthNumTy = NumTy <$> use #ptrWidth
ptrWidthNumTy = NumTy <$> view #rPtrWidth

traceUnification :: SolverM Bool
traceUnification = use #ctxTraceUnification
traceUnification = view #rTraceUnification

traceConstraintOrigins :: SolverM Bool
traceConstraintOrigins = use #ctxTraceConstraintOrigins
traceConstraintOrigins = view #rTraceConstraintOrigins

--------------------------------------------------------------------------------
-- Conditional constraints
Expand Down
50 changes: 32 additions & 18 deletions src/Reopt/TypeInference/Solver/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Data.Set qualified as Set
import Debug.Trace (trace)
import Prettyprinter qualified as PP

import Control.Monad.RWS (asks)
import Control.Monad.Trans (lift)
import Reopt.TypeInference.Solver.Constraints (
ConstraintProvenance (..),
Expand All @@ -43,6 +44,7 @@ import Reopt.TypeInference.Solver.Finalize (
import Reopt.TypeInference.Solver.Monad (
Conditional (..),
Conditional',
ConstraintSolvingReader (rMaxNumberOfRestarts, rPtrWidth),
ConstraintSolvingState (..),
SolverM,
addEqC,
Expand Down Expand Up @@ -180,17 +182,19 @@ solveHeadReset fld doit = do
put resetSt

-- Forget everything we know in resetSt about the eqvs for tv
let eqs = eqvClasses (ctxTyVars oldSt)
eqsTv = Map.findWithDefault [] tv eqs
let
eqs = eqvClasses (ctxTyVars oldSt)
eqsTv = Map.findWithDefault [] tv eqs
traverse_ undefineTyVar eqsTv

-- FIXME: gross
defineTyVar tv (ConflictTy (ptrWidth resetSt))
ptrWidth <- asks rPtrWidth
defineTyVar tv (ConflictTy ptrWidth)
-- FIXME: this could cause problems if we allocate tyvars after
-- we start solving. Because we don't, this should work.
mapM_ (addTyVarEq' ConflictProv tv) eqsTv -- retain eqv class for conflict var.
get
put resetSt'
put $ resetSt'{ctxNumberOfRestarts = ctxNumberOfRestarts resetSt + 1}

solveFirst ::
Lens' ConstraintSolvingState [a] ->
Expand Down Expand Up @@ -227,8 +231,9 @@ _solveAll fld solve = do
go acc progd [] = restore acc $> progd -- finished here, we didn't so anything.
go acc progd (c : cs) = do
(retain, progress) <- solve c
let acc' = if retain == Retain then c : acc else acc
progd' = progd || madeProgress progress
let
acc' = if retain == Retain then c : acc else acc
progd' = progd || madeProgress progress
go acc' progd' cs

-- | @preprocess l f# just pre-processes the element at @l@, and so
Expand All @@ -246,11 +251,17 @@ preprocess fld f =
fld %= (<> r)

solverLoop :: SolverM ()
solverLoop = evalStateT go =<< get
solverLoop = do
maxRestarts <- asks rMaxNumberOfRestarts
evalStateT (go (exceeds maxRestarts)) =<< get
where
go = do
exceeds (Just maxRestarts) r = r > maxRestarts
exceeds Nothing _ = False

go isTooMany = do
tooManyRestarts <- gets $ isTooMany . ctxNumberOfRestarts
keepGoing <- orM solvers
when keepGoing go
when (keepGoing && not tooManyRestarts) (go isTooMany)

solvers =
[ solveHeadReset #ctxEqCs solveEqC
Expand Down Expand Up @@ -371,14 +382,16 @@ solveConditional c = traceContext' "solveConditional" c $ do
solveEqRowC :: EqRowC -> SolverM ()
solveEqRowC eqc = traceContext' "solveEqRowC" eqc $ do
(le, m_lfm) <- lookupRowExpr (eqRowLHS eqc)
let lo = rowExprShift le
lv = rowExprVar le
lfm = fromMaybe emptyFieldMap m_lfm
let
lo = rowExprShift le
lv = rowExprVar le
lfm = fromMaybe emptyFieldMap m_lfm

(re, m_rfm) <- lookupRowExpr (eqRowRHS eqc)
let ro = rowExprShift re
rv = rowExprVar re
rfm = fromMaybe emptyFieldMap m_rfm
let
ro = rowExprShift re
rv = rowExprVar re
rfm = fromMaybe emptyFieldMap m_rfm

case () of
_
Expand All @@ -390,8 +403,9 @@ solveEqRowC eqc = traceContext' "solveEqRowC" eqc $ do
unify delta lowv lowfm highv highfm = do
undefineRowVar highv
unsafeUnifyRowVars (RowExprShift delta lowv) highv
let highfm' = shiftFieldMap delta highfm
(lowfm', newEqs) = unifyFieldMaps lowfm highfm'
let
highfm' = shiftFieldMap delta highfm
(lowfm', newEqs) = unifyFieldMaps lowfm highfm'
defineRowVar lowv lowfm'
traverse_ (uncurry (addTyVarEq' FromEqRowCProv)) newEqs

Expand Down Expand Up @@ -437,7 +451,7 @@ solveEqC eqc = do
-- when the type variable is conflicted.
unifyTypes :: TyVar -> ITy' -> ITy' -> SolverM (Maybe TyVar)
unifyTypes tv ty1 ty2 = do
pW <- gets ptrWidth
pW <- asks rPtrWidth
case (ty1, ty2) of
_ | ty1 == ty2 -> pure Nothing
(NumTy i, NumTy i')
Expand Down
Loading
Loading