diff --git a/Makefile b/Makefile index ffda2b2..fbd873c 100644 --- a/Makefile +++ b/Makefile @@ -52,6 +52,7 @@ sources = \ src/cps/dead-code.sml \ src/cps/uncurry.sml \ src/cps/loop.sml \ + src/cps/ref-cell.sml \ src/cps/inline.sml \ src/cps/decompose-recursive.sml \ src/cps/unpack-record-parameter.sml \ diff --git a/src/cps.sml b/src/cps.sml index 70d247c..63fbaa7 100644 --- a/src/cps.sml +++ b/src/cps.sml @@ -946,7 +946,6 @@ structure CpsSimplify :> sig end = struct local structure F = FSyntax structure C = CSyntax - structure P = Primitives datatype frequency = datatype CpsUsageAnalysis.frequency in type Context = { nextVId : int ref @@ -1144,9 +1143,6 @@ and alphaConvert (ctx : Context, subst : C.Value TypedSyntax.VIdMap.map, csubst } end | alphaConvert (_, _, _, e as C.Unreachable) = e -datatype simplify_result = VALUE of C.Value - | SIMPLE_EXP of C.SimpleExp - | NOT_SIMPLIFIED type value_info = { exp : C.SimpleExp option, isDiscardableFunction : bool } fun isDiscardableDec (dec, env : value_info TypedSyntax.VIdMap.map) = case dec of diff --git a/src/cps/dead-code.sml b/src/cps/dead-code.sml index 3bced1e..30d6590 100644 --- a/src/cps/dead-code.sml +++ b/src/cps/dead-code.sml @@ -78,12 +78,7 @@ local structure CpsUsageAnalysis :> sig datatype frequency = NEVER | ONCE | MANY type usage = { call : frequency - , project : frequency - , ref_read : frequency - , ref_write : frequency , other : frequency - , returnConts : CSyntax.CVarSet.set - , labels : (string option) Syntax.LabelMap.map } type cont_usage = { direct : frequency, indirect : frequency } val neverUsed : usage @@ -106,21 +101,11 @@ local | oneMore ONCE = MANY | oneMore (many as MANY) = many type usage = { call : frequency - , project : frequency - , ref_read : frequency - , ref_write : frequency , other : frequency - , returnConts : CSyntax.CVarSet.set - , labels : (string option) Syntax.LabelMap.map } type cont_usage = { direct : frequency, indirect : frequency } val neverUsed : usage = { call = NEVER - , project = NEVER - , ref_read = NEVER - , ref_write = NEVER , other = NEVER - , returnConts = CSyntax.CVarSet.empty - , labels = Syntax.LabelMap.empty } val neverUsedCont : cont_usage = { direct = NEVER, indirect = NEVER } type usage_table = (usage ref) TypedSyntax.VIdTable.hash_table @@ -128,14 +113,14 @@ local fun getValueUsage (table : usage_table, v) = case TypedSyntax.VIdTable.find table v of SOME r => !r - | NONE => { call = MANY, project = MANY, ref_read = MANY, ref_write = MANY, other = MANY, returnConts = CSyntax.CVarSet.empty, labels = Syntax.LabelMap.empty } (* unknown *) + | NONE => { call = MANY, other = MANY } (* unknown *) fun getContUsage (table : cont_usage_table, c) = case CSyntax.CVarTable.find table c of SOME r => !r | NONE => { direct = MANY, indirect = MANY } (* unknown *) fun useValue env (C.Var v) = (case TypedSyntax.VIdTable.find env v of - SOME r => let val { call, project, ref_read, ref_write, other, returnConts, labels } = !r - in r := { call = call, project = project, ref_read = ref_read, ref_write = ref_write, other = oneMore other, returnConts = returnConts, labels = labels } + SOME r => let val { call, other } = !r + in r := { call = call, other = oneMore other } end | NONE => () ) @@ -148,75 +133,22 @@ local | useValue _ (C.Char16Const _) = () | useValue _ (C.StringConst _) = () | useValue _ (C.String16Const _) = () - fun useValueAsCallee (env, cont, C.Var v) + fun useValueAsCallee (env, C.Var v) = (case TypedSyntax.VIdTable.find env v of - SOME r => let val { call, project, ref_read, ref_write, other, returnConts, labels } = !r - in r := { call = oneMore call, project = project, ref_read = ref_read, ref_write = ref_write, other = other, returnConts = C.CVarSet.add (returnConts, cont), labels = labels } + SOME r => let val { call, other } = !r + in r := { call = oneMore call, other = other } end | NONE => () ) - | useValueAsCallee (_, _, C.Unit) = () - | useValueAsCallee (_, _, C.Nil) = () - | useValueAsCallee (_, _, C.BoolConst _) = () - | useValueAsCallee (_, _, C.IntConst _) = () - | useValueAsCallee (_, _, C.WordConst _) = () - | useValueAsCallee (_, _, C.CharConst _) = () - | useValueAsCallee (_, _, C.Char16Const _) = () - | useValueAsCallee (_, _, C.StringConst _) = () - | useValueAsCallee (_, _, C.String16Const _) = () - fun useValueAsRecord (env, label, result, C.Var v) - = (case TypedSyntax.VIdTable.find env v of - SOME r => let val { call, project, ref_read, ref_write, other, returnConts, labels } = !r - val result' = case result of - SOME (TypedSyntax.MkVId (name, _)) => SOME name - | NONE => NONE - fun mergeOption (x as SOME _, _) = x - | mergeOption (NONE, y) = y - in r := { call = call, project = oneMore project, ref_read = ref_read, ref_write = ref_write, other = other, returnConts = returnConts, labels = Syntax.LabelMap.insertWith mergeOption (labels, label, result') } - end - | NONE => () - ) - | useValueAsRecord (_, _, _, C.Unit) = () - | useValueAsRecord (_, _, _, C.Nil) = () - | useValueAsRecord (_, _, _, C.BoolConst _) = () - | useValueAsRecord (_, _, _, C.IntConst _) = () - | useValueAsRecord (_, _, _, C.WordConst _) = () - | useValueAsRecord (_, _, _, C.CharConst _) = () - | useValueAsRecord (_, _, _, C.Char16Const _) = () - | useValueAsRecord (_, _, _, C.StringConst _) = () - | useValueAsRecord (_, _, _, C.String16Const _) = () - fun useValueAsRefRead (env, C.Var v) - = (case TypedSyntax.VIdTable.find env v of - SOME r => let val { call, project, ref_read, ref_write, other, returnConts, labels } = !r - in r := { call = call, project = project, ref_read = oneMore ref_read, ref_write = ref_write, other = other, returnConts = returnConts, labels = labels } - end - | NONE => () - ) - | useValueAsRefRead (_, C.Unit) = () - | useValueAsRefRead (_, C.Nil) = () - | useValueAsRefRead (_, C.BoolConst _) = () - | useValueAsRefRead (_, C.IntConst _) = () - | useValueAsRefRead (_, C.WordConst _) = () - | useValueAsRefRead (_, C.CharConst _) = () - | useValueAsRefRead (_, C.Char16Const _) = () - | useValueAsRefRead (_, C.StringConst _) = () - | useValueAsRefRead (_, C.String16Const _) = () - fun useValueAsRefWrite (env, C.Var v) - = (case TypedSyntax.VIdTable.find env v of - SOME r => let val { call, project, ref_read, ref_write, other, returnConts, labels } = !r - in r := { call = call, project = project, ref_read = ref_read, ref_write = oneMore ref_write, other = other, returnConts = returnConts, labels = labels } - end - | NONE => () - ) - | useValueAsRefWrite (_, C.Unit) = () - | useValueAsRefWrite (_, C.Nil) = () - | useValueAsRefWrite (_, C.BoolConst _) = () - | useValueAsRefWrite (_, C.IntConst _) = () - | useValueAsRefWrite (_, C.WordConst _) = () - | useValueAsRefWrite (_, C.CharConst _) = () - | useValueAsRefWrite (_, C.Char16Const _) = () - | useValueAsRefWrite (_, C.StringConst _) = () - | useValueAsRefWrite (_, C.String16Const _) = () + | useValueAsCallee (_, C.Unit) = () + | useValueAsCallee (_, C.Nil) = () + | useValueAsCallee (_, C.BoolConst _) = () + | useValueAsCallee (_, C.IntConst _) = () + | useValueAsCallee (_, C.WordConst _) = () + | useValueAsCallee (_, C.CharConst _) = () + | useValueAsCallee (_, C.Char16Const _) = () + | useValueAsCallee (_, C.StringConst _) = () + | useValueAsCallee (_, C.String16Const _) = () fun useContVarIndirect cenv (v : C.CVar) = (case C.CVarTable.find cenv v of SOME r => let val { direct, indirect } = !r in r := { direct = direct, indirect = oneMore indirect } @@ -239,23 +171,18 @@ local else C.CVarTable.insert cenv (v, ref neverUsedCont) in - fun goSimpleExp (env, _, _, _, _, C.PrimOp { primOp = FSyntax.PrimCall Primitives.Ref_set, tyargs = _, args = [r, v] }) = (useValueAsRefWrite (env, r); useValue env v) - | goSimpleExp (env, _, _, _, _, C.PrimOp { primOp = FSyntax.PrimCall Primitives.Ref_read, tyargs = _, args = [r] }) = useValueAsRefRead (env, r) - | goSimpleExp (env, _, _, _, _, C.PrimOp { primOp = _, tyargs = _, args }) = List.app (useValue env) args - | goSimpleExp (env, _, _, _, _, C.Record fields) = Syntax.LabelMap.app (useValue env) fields - | goSimpleExp (_, _, _, _, _, C.ExnTag { name = _, payloadTy = _ }) = () - | goSimpleExp (env, _, _, _, results, C.Projection { label, record, fieldTypes = _ }) = (case results of - [result] => useValueAsRecord (env, label, result, record) - | _ => () (* should not occur *) - ) - | goSimpleExp (env, renv, cenv, crenv, _, C.Abs { contParam, params, body, attr = _ }) + fun goSimpleExp (env, _, _, _, C.PrimOp { primOp = _, tyargs = _, args }) = List.app (useValue env) args + | goSimpleExp (env, _, _, _, C.Record fields) = Syntax.LabelMap.app (useValue env) fields + | goSimpleExp (_, _, _, _, C.ExnTag { name = _, payloadTy = _ }) = () + | goSimpleExp (env, _, _, _, C.Projection { label = _, record, fieldTypes = _ }) = useValue env record + | goSimpleExp (env, renv, cenv, crenv, C.Abs { contParam, params, body, attr = _ }) = ( List.app (fn p => add (env, p)) params ; addC (cenv, contParam) ; goCExp (env, renv, cenv, crenv, body) ) and goDec (env, renv, cenv, crenv) = fn C.ValDec { exp, results } => - ( goSimpleExp (env, renv, cenv, crenv, results, exp) + ( goSimpleExp (env, renv, cenv, crenv, exp) ; List.app (fn SOME result => add (env, result) | NONE => () ) results @@ -295,7 +222,7 @@ local ; goCExp (env, renv, cenv, crenv, cont) ) | C.App { applied, cont, args, attr = _ } => - ( useValueAsCallee (env, cont, applied) + ( useValueAsCallee (env, applied) ; useContVarIndirect cenv cont ; List.app (useValue env) args ) @@ -331,9 +258,7 @@ in structure CpsDeadCodeElimination : sig val goCExp : CpsSimplify.Context * CSyntax.CExp -> CSyntax.CExp end = struct -local structure F = FSyntax - structure C = CSyntax - structure P = Primitives +local structure C = CSyntax datatype frequency = datatype CpsUsageAnalysis.frequency in type Context = { base : CpsSimplify.Context @@ -343,7 +268,6 @@ type Context = { base : CpsSimplify.Context , cont_rec_usage : CpsUsageAnalysis.cont_usage_table , dead_code_analysis : CpsDeadCodeAnalysis.usage } -datatype param_transform = KEEP | ELIMINATE | UNPACK of (C.Var * Syntax.Label) list fun simplifyDec (ctx : Context, appliedCont : C.CVar option) (dec, (env, cenv, subst, csubst, acc : C.Dec list)) = case dec of C.ValDec { exp, results } => @@ -357,17 +281,17 @@ fun simplifyDec (ctx : Context, appliedCont : C.CVar option) (dec, (env, cenv, s in case (exp, results) of (C.Abs { contParam, params, body, attr }, [SOME result]) => (case CpsUsageAnalysis.getValueUsage (#usage ctx, result) of - { call = NEVER, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, returnConts = _, labels = _ } => + { call = NEVER, other = NEVER } => ( #simplificationOccurred (#base ctx) := true ; (env, cenv, subst, csubst, acc) ) - | { call = ONCE, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, returnConts = _, labels = _ } => + | { call = ONCE, other = NEVER } => let val body = simplifyCExp (ctx, env, cenv, subst, csubst, body) val env = TypedSyntax.VIdMap.insert (env, result, { exp = SOME (C.Abs { contParam = contParam, params = params, body = body, attr = attr }), isDiscardableFunction = CpsSimplify.isDiscardableExp (env, body) }) val () = #simplificationOccurred (#base ctx) := true in (env, cenv, subst, csubst, acc) end - | u => let val body = simplifyCExp (ctx, env, cenv, subst, csubst, body) + | _ => let val body = simplifyCExp (ctx, env, cenv, subst, csubst, body) val exp = C.Abs { contParam = contParam, params = params, body = body, attr = attr } val env = TypedSyntax.VIdMap.insert (env, result, { exp = NONE, isDiscardableFunction = CpsSimplify.isDiscardableExp (env, body) }) val dec = C.ValDec { exp = exp @@ -417,7 +341,7 @@ fun simplifyDec (ctx : Context, appliedCont : C.CVar option) (dec, (env, cenv, s else let val body = simplifyCExp (ctx, env, cenv, subst, csubst, body) val params = List.map (fn SOME p => (case CpsUsageAnalysis.getValueUsage (#usage ctx, p) of - { call = NEVER, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, ... } => NONE + { call = NEVER, other = NEVER, ... } => NONE | _ => SOME p ) | NONE => NONE @@ -486,9 +410,9 @@ and simplifyCExp (ctx : Context, env : CpsSimplify.value_info TypedSyntax.VIdMap val subst = ListPair.foldlEq (fn (p, a, subst) => TypedSyntax.VIdMap.insert (subst, p, a)) subst (params, args) val csubst = C.CVarMap.insert (csubst, contParam, cont) val canOmitAlphaConversion = case CpsUsageAnalysis.getValueUsage (#usage ctx, applied) of - { call = ONCE, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, returnConts = _, labels = _ } => + { call = ONCE, other = NEVER } => (case CpsUsageAnalysis.getValueUsage (#rec_usage ctx, applied) of - { call = NEVER, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, returnConts = _, labels = _ } => true + { call = NEVER, other = NEVER } => true | _ => false ) | _ => false diff --git a/src/cps/decompose-recursive.sml b/src/cps/decompose-recursive.sml index 3e9e6d0..c971812 100644 --- a/src/cps/decompose-recursive.sml +++ b/src/cps/decompose-recursive.sml @@ -9,7 +9,7 @@ local structure C = CSyntax in fun goDec ctx (dec, acc) = case dec of - C.ValDec { exp = C.Abs { contParam, params, body, attr }, results as [SOME name] } => + C.ValDec { exp = C.Abs { contParam, params, body, attr }, results as [SOME _] } => C.ValDec { exp = C.Abs { contParam = contParam, params = params, body = goCExp (ctx, body), attr = attr }, results = results } :: acc | C.ValDec { exp = _, results = _ } => dec :: acc | C.RecDec defs => diff --git a/src/cps/loop.sml b/src/cps/loop.sml index ddebc9c..1e0c4b0 100644 --- a/src/cps/loop.sml +++ b/src/cps/loop.sml @@ -44,17 +44,17 @@ local else TypedSyntax.VIdTable.insert env (v, ref neverUsed) in - fun goSimpleExp (env, _, _, C.PrimOp { primOp = _, tyargs = _, args }) = () - | goSimpleExp (env, _, _, C.Record fields) = () - | goSimpleExp (_, _, _, C.ExnTag { name = _, payloadTy = _ }) = () - | goSimpleExp (env, _, results, C.Projection { label, record, fieldTypes = _ }) = () - | goSimpleExp (env, renv, _, C.Abs { contParam, params, body, attr = _ }) + fun goSimpleExp (_, _, C.PrimOp _) = () + | goSimpleExp (_, _, C.Record _) = () + | goSimpleExp (_, _, C.ExnTag _) = () + | goSimpleExp (_, _, C.Projection _) = () + | goSimpleExp (env, renv, C.Abs { contParam = _, params, body, attr = _ }) = ( List.app (fn p => add (env, p)) params ; goCExp (env, renv, body) ) and goDec (env, renv) = fn C.ValDec { exp, results } => - ( goSimpleExp (env, renv, results, exp) + ( goSimpleExp (env, renv, exp) ; List.app (fn SOME result => add (env, result) | NONE => () ) results @@ -70,7 +70,7 @@ local ; TypedSyntax.VIdMap.appi (fn (f, v) => TypedSyntax.VIdTable.insert renv (f, v)) recursiveEnv ; List.app (fn { name, ... } => TypedSyntax.VIdTable.insert env (name, ref neverUsed)) defs end - | C.ContDec { name, params, body } => + | C.ContDec { name = _, params, body } => ( List.app (Option.app (fn p => add (env, p))) params ; goCExp (env, renv, body) ) @@ -86,15 +86,15 @@ local ( List.app (goDec (env, renv)) decs ; goCExp (env, renv, cont) ) - | C.App { applied, cont, args, attr = _ } => + | C.App { applied, cont, args = _, attr = _ } => ( useValueAsCallee (env, cont, applied) ) - | C.AppCont { applied, args } => () - | C.If { cond, thenCont, elseCont } => + | C.AppCont { applied = _, args = _ } => () + | C.If { cond = _, thenCont, elseCont } => ( goCExp (env, renv, thenCont) ; goCExp (env, renv, elseCont) ) - | C.Handle { body, handler = (e, h), successfulExitIn, successfulExitOut } => + | C.Handle { body, handler = (e, h), successfulExitIn = _, successfulExitOut = _ } => ( goCExp (env, renv, body) ; add (env, e) ; goCExp (env, renv, h) @@ -117,7 +117,7 @@ in type Context = { base : CpsSimplify.Context , rec_usage : CpsUsageAnalysis.usage_table } -fun simplifyDec (ctx : Context, appliedCont : C.CVar option) (dec, acc : C.Dec list) +fun simplifyDec (ctx : Context) (dec, acc : C.Dec list) = case dec of C.ValDec { exp, results } => (case (exp, results) of @@ -172,10 +172,7 @@ fun simplifyDec (ctx : Context, appliedCont : C.CVar option) (dec, acc : C.Dec l and simplifyCExp (ctx : Context, e) = case e of C.Let { decs, cont } => - let val appliedCont = case cont of - C.AppCont { applied, args = _ } => SOME applied - | _ => NONE - val revDecs = List.foldl (simplifyDec (ctx, appliedCont)) [] decs + let val revDecs = List.foldl (simplifyDec ctx) [] decs in CpsTransform.prependRevDecs (revDecs, simplifyCExp (ctx, cont)) end | C.App _ => e diff --git a/src/cps/ref-cell.sml b/src/cps/ref-cell.sml new file mode 100644 index 0000000..0e18570 --- /dev/null +++ b/src/cps/ref-cell.sml @@ -0,0 +1,297 @@ +(* + * Copyright (c) 2024 ARATA Mizuki + * This file is part of LunarML. + *) +(* + * Eliminate constant ref-cells. + *) +local + structure CpsUsageAnalysis :> sig + datatype frequency = NEVER | ONCE | MANY + type usage = { ref_read : frequency + , ref_write : frequency + , other : frequency + } + val neverUsed : usage + type usage_table + val getValueUsage : usage_table * TypedSyntax.VId -> usage + val analyze : CSyntax.CExp -> { usage : usage_table + , rec_usage : usage_table + } + end = struct + local structure C = CSyntax + in + datatype frequency = NEVER | ONCE | MANY + fun oneMore NEVER = ONCE + | oneMore ONCE = MANY + | oneMore (many as MANY) = many + type usage = { ref_read : frequency + , ref_write : frequency + , other : frequency + } + val neverUsed : usage = { ref_read = NEVER + , ref_write = NEVER + , other = NEVER + } + type usage_table = (usage ref) TypedSyntax.VIdTable.hash_table + fun getValueUsage (table : usage_table, v) + = case TypedSyntax.VIdTable.find table v of + SOME r => !r + | NONE => { ref_read = MANY, ref_write = MANY, other = MANY } (* unknown *) + fun useValue env (C.Var v) = (case TypedSyntax.VIdTable.find env v of + SOME r => let val { ref_read, ref_write, other } = !r + in r := { ref_read = ref_read, ref_write = ref_write, other = oneMore other } + end + | NONE => () + ) + | useValue _ C.Unit = () + | useValue _ C.Nil = () + | useValue _ (C.BoolConst _) = () + | useValue _ (C.IntConst _) = () + | useValue _ (C.WordConst _) = () + | useValue _ (C.CharConst _) = () + | useValue _ (C.Char16Const _) = () + | useValue _ (C.StringConst _) = () + | useValue _ (C.String16Const _) = () + fun useValueAsRefRead (env, C.Var v) + = (case TypedSyntax.VIdTable.find env v of + SOME r => let val { ref_read, ref_write, other } = !r + in r := { ref_read = oneMore ref_read, ref_write = ref_write, other = other } + end + | NONE => () + ) + | useValueAsRefRead (_, C.Unit) = () + | useValueAsRefRead (_, C.Nil) = () + | useValueAsRefRead (_, C.BoolConst _) = () + | useValueAsRefRead (_, C.IntConst _) = () + | useValueAsRefRead (_, C.WordConst _) = () + | useValueAsRefRead (_, C.CharConst _) = () + | useValueAsRefRead (_, C.Char16Const _) = () + | useValueAsRefRead (_, C.StringConst _) = () + | useValueAsRefRead (_, C.String16Const _) = () + fun useValueAsRefWrite (env, C.Var v) + = (case TypedSyntax.VIdTable.find env v of + SOME r => let val { ref_read, ref_write, other } = !r + in r := { ref_read = ref_read, ref_write = oneMore ref_write, other = other } + end + | NONE => () + ) + | useValueAsRefWrite (_, C.Unit) = () + | useValueAsRefWrite (_, C.Nil) = () + | useValueAsRefWrite (_, C.BoolConst _) = () + | useValueAsRefWrite (_, C.IntConst _) = () + | useValueAsRefWrite (_, C.WordConst _) = () + | useValueAsRefWrite (_, C.CharConst _) = () + | useValueAsRefWrite (_, C.Char16Const _) = () + | useValueAsRefWrite (_, C.StringConst _) = () + | useValueAsRefWrite (_, C.String16Const _) = () + local + fun add (env, v) = if TypedSyntax.VIdTable.inDomain env v then + raise Fail ("goCExp: duplicate name in AST: " ^ TypedSyntax.print_VId v) + else + TypedSyntax.VIdTable.insert env (v, ref neverUsed) + in + fun goSimpleExp (env, _, C.PrimOp { primOp = FSyntax.PrimCall Primitives.Ref_set, tyargs = _, args = [r, v] }) = (useValueAsRefWrite (env, r); useValue env v) + | goSimpleExp (env, _, C.PrimOp { primOp = FSyntax.PrimCall Primitives.Ref_read, tyargs = _, args = [r] }) = useValueAsRefRead (env, r) + | goSimpleExp (env, _, C.PrimOp { primOp = _, tyargs = _, args }) = List.app (useValue env) args + | goSimpleExp (env, _, C.Record fields) = Syntax.LabelMap.app (useValue env) fields + | goSimpleExp (_, _, C.ExnTag { name = _, payloadTy = _ }) = () + | goSimpleExp (env, _, C.Projection { label = _, record, fieldTypes = _ }) = useValue env record + | goSimpleExp (env, renv, C.Abs { contParam = _, params, body, attr = _ }) + = ( List.app (fn p => add (env, p)) params + ; goCExp (env, renv, body) + ) + and goDec (env, renv) + = fn C.ValDec { exp, results } => + ( goSimpleExp (env, renv, exp) + ; List.app (fn SOME result => add (env, result) + | NONE => () + ) results + ) + | C.RecDec defs => + let val recursiveEnv = List.foldl (fn ({ name, ... }, m) => TypedSyntax.VIdMap.insert (m, name, ref neverUsed)) TypedSyntax.VIdMap.empty defs + in TypedSyntax.VIdMap.appi (fn (f, v) => TypedSyntax.VIdTable.insert env (f, v)) recursiveEnv + ; List.app (fn { contParam = _, params, body, ... } => + ( List.app (fn p => add (env, p)) params + ; goCExp (env, renv, body) + ) + ) defs + ; TypedSyntax.VIdMap.appi (fn (f, v) => TypedSyntax.VIdTable.insert renv (f, v)) recursiveEnv + ; List.app (fn { name, ... } => TypedSyntax.VIdTable.insert env (name, ref neverUsed)) defs + end + | C.ContDec { name = _, params, body } => + ( List.app (Option.app (fn p => add (env, p))) params + ; goCExp (env, renv, body) + ) + | C.RecContDec defs => + List.app (fn (_, params, body) => ( List.app (Option.app (fn p => add (env, p))) params + ; goCExp (env, renv, body) + ) + ) defs + | C.ESImportDec { pure = _, specs, moduleName = _ } => List.app (fn (_, vid) => add (env, vid)) specs + and goCExp (env : (usage ref) TypedSyntax.VIdTable.hash_table, renv, cexp) + = case cexp of + C.Let { decs, cont } => + ( List.app (goDec (env, renv)) decs + ; goCExp (env, renv, cont) + ) + | C.App { applied, cont = _, args, attr = _ } => + ( useValue env applied + ; List.app (useValue env) args + ) + | C.AppCont { applied = _, args } => + List.app (useValue env) args + | C.If { cond, thenCont, elseCont } => + ( useValue env cond + ; goCExp (env, renv, thenCont) + ; goCExp (env, renv, elseCont) + ) + | C.Handle { body, handler = (e, h), successfulExitIn = _, successfulExitOut = _ } => + ( goCExp (env, renv, body) + ; add (env, e) + ; goCExp (env, renv, h) + ) + | C.Unreachable => () + end (* local *) + fun analyze exp = let val usage = TypedSyntax.VIdTable.mkTable (1, Fail "usage table lookup failed") + val rusage = TypedSyntax.VIdTable.mkTable (1, Fail "rusage table lookup failed") + in goCExp (usage, rusage, exp) + ; { usage = usage, rec_usage = rusage } + end + end (* local *) + end (* strucuture CpsUsageAnalysis *) +in +structure CpsConstantRefCell : sig + val goCExp : CpsSimplify.Context * CSyntax.CExp -> CSyntax.CExp + end = struct +local structure F = FSyntax + structure C = CSyntax + structure P = Primitives + datatype frequency = datatype CpsUsageAnalysis.frequency +in +type Context = { base : CpsSimplify.Context + , usage : CpsUsageAnalysis.usage_table + , rec_usage : CpsUsageAnalysis.usage_table + } +datatype simplify_result = VALUE of C.Value + | SIMPLE_EXP of C.SimpleExp + | NOT_SIMPLIFIED +(*: val simplifySimpleExp : CpsUsageAnalysis.usage_table * { exp : C.SimpleExp option } TypedSyntax.VIdMap.map * C.SimpleExp -> simplify_result *) +fun simplifySimpleExp (usage, env, C.PrimOp { primOp, tyargs = _, args }) + = (case (primOp, args) of + (F.PrimCall P.Ref_read, [C.Var v]) => + let val u = CpsUsageAnalysis.getValueUsage (usage, v) + in case (#ref_write u, #other u) of + (CpsUsageAnalysis.NEVER, CpsUsageAnalysis.NEVER) => + (case TypedSyntax.VIdMap.find (env, v) of + SOME { exp = SOME (C.PrimOp { primOp = F.PrimCall P.Ref_ref, tyargs = _, args = [initialValue] }) } => + VALUE initialValue + | _ => NOT_SIMPLIFIED + ) + | _ => NOT_SIMPLIFIED + end + | _ => NOT_SIMPLIFIED + ) + | simplifySimpleExp (_, _, _) = NOT_SIMPLIFIED +and simplifyDec (ctx : Context) (dec, (env, subst, acc : C.Dec list)) + = case dec of + C.ValDec { exp, results } => + let val exp = CpsSimplify.substSimpleExp (subst, C.CVarMap.empty, exp) + in case simplifySimpleExp (#usage ctx, env, exp) of + VALUE v => let val () = #simplificationOccurred (#base ctx) := true + val subst = case results of + [SOME result] => TypedSyntax.VIdMap.insert (subst, result, v) + | [NONE] => subst + | _ => subst (* should not occur *) + in (env, subst, acc) + end + | simplified => + let val () = case simplified of + SIMPLE_EXP _ => #simplificationOccurred (#base ctx) := true + | VALUE _ => #simplificationOccurred (#base ctx) := true (* shoud not occur *) + | NOT_SIMPLIFIED => () + val exp = case simplified of + SIMPLE_EXP exp => exp + | _ => exp + in case (exp, results) of + (C.Abs { contParam, params, body, attr }, [SOME result]) => + let val body = simplifyCExp (ctx, env, subst, body) + val exp = C.Abs { contParam = contParam, params = params, body = body, attr = attr } + val env = TypedSyntax.VIdMap.insert (env, result, { exp = NONE }) + val dec = C.ValDec { exp = exp + , results = [SOME result] + } + in (env, subst, dec :: acc) + end + | _ => (case (C.isDiscardable exp, results) of + (true, [NONE]) => (env, subst, acc) + | (_, [SOME result]) => let val dec = C.ValDec { exp = exp + , results = [SOME result] + } + val env = TypedSyntax.VIdMap.insert (env, result, { exp = SOME exp }) + in (env, subst, dec :: acc) + end + | _ => let val dec = C.ValDec { exp = exp + , results = results + } + in (env, subst, dec :: acc) + end + ) + end + end + | C.RecDec defs => + let val defs = List.map (fn { name, contParam, params, body, attr } => { name = name, contParam = contParam, params = params, body = simplifyCExp (ctx, env, subst, body), attr = attr }) defs + val decs = C.RecDec defs :: acc + in (env, subst, decs) + end + | C.ContDec { name, params, body } => + let val body = simplifyCExp (ctx, env, subst, body) + val dec = C.ContDec { name = name + , params = params + , body = body + } + in (env, subst, dec :: acc) + end + | C.RecContDec defs => + let val dec = C.RecContDec (List.map (fn (name, params, body) => (name, params, simplifyCExp (ctx, env, subst, body))) defs) + in (env, subst, dec :: acc) + end + | C.ESImportDec _ => (env, subst, dec :: acc) +and simplifyCExp (ctx : Context, env : { exp : CSyntax.SimpleExp option } TypedSyntax.VIdMap.map, subst : C.Value TypedSyntax.VIdMap.map, e) + = case e of + C.Let { decs, cont } => + let val (env, subst, revDecs) = List.foldl (simplifyDec ctx) (env, subst, []) decs + in CpsTransform.prependRevDecs (revDecs, simplifyCExp (ctx, env, subst, cont)) + end + | C.App { applied, cont, args, attr } => + let val applied = CpsSimplify.substValue subst applied + val args = List.map (CpsSimplify.substValue subst) args + in C.App { applied = applied, cont = cont, args = args, attr = attr } + end + | C.AppCont { applied, args } => + let val args = List.map (CpsSimplify.substValue subst) args + in C.AppCont { applied = applied, args = args } + end + | C.If { cond, thenCont, elseCont } => + C.If { cond = cond + , thenCont = simplifyCExp (ctx, env, subst, thenCont) + , elseCont = simplifyCExp (ctx, env, subst, elseCont) + } + | C.Handle { body, handler = (e, h), successfulExitIn, successfulExitOut } => + C.Handle { body = simplifyCExp (ctx, env, subst, body) + , handler = (e, simplifyCExp (ctx, env, subst, h)) + , successfulExitIn = successfulExitIn + , successfulExitOut = successfulExitOut + } + | C.Unreachable => e +fun goCExp (ctx : CpsSimplify.Context, exp) + = let val usage = CpsUsageAnalysis.analyze exp + val ctx' = { base = ctx + , usage = #usage usage + , rec_usage = #rec_usage usage + } + in simplifyCExp (ctx', TypedSyntax.VIdMap.empty, TypedSyntax.VIdMap.empty, exp) + end +end (* local *) +end (* structure CpsConstantRefCell *) +end; (* local *) diff --git a/src/cps/uncurry.sml b/src/cps/uncurry.sml index 37cc1d2..4bc1823 100644 --- a/src/cps/uncurry.sml +++ b/src/cps/uncurry.sml @@ -8,7 +8,7 @@ structure CpsUncurry : sig local structure C = CSyntax in (*: val tryUncurry : C.SimpleExp -> ((C.Var list) list * C.CVar * C.CExp) option *) -fun tryUncurry (exp as C.Abs { contParam, params, body as C.Let { decs, cont = C.AppCont { applied = k, args = [C.Var v] } }, attr = { isWrapper = false } }) +fun tryUncurry (C.Abs { contParam, params, body as C.Let { decs, cont = C.AppCont { applied = k, args = [C.Var v] } }, attr = { isWrapper = false } }) = (case decs of [C.ValDec { exp, results = [SOME f] }] => if contParam = k andalso v = f then diff --git a/src/cps/unpack-record-parameter.sml b/src/cps/unpack-record-parameter.sml index 458954d..0378e61 100644 --- a/src/cps/unpack-record-parameter.sml +++ b/src/cps/unpack-record-parameter.sml @@ -69,22 +69,22 @@ local | useValue _ (C.Char16Const _) = () | useValue _ (C.StringConst _) = () | useValue _ (C.String16Const _) = () - fun useValueAsCallee (env, cont, C.Var v) + fun useValueAsCallee (env, C.Var v) = (case TypedSyntax.VIdTable.find env v of SOME r => let val { call, project, other, labels } = !r in r := { call = oneMore call, project = project, other = other, labels = labels } end | NONE => () ) - | useValueAsCallee (_, _, C.Unit) = () - | useValueAsCallee (_, _, C.Nil) = () - | useValueAsCallee (_, _, C.BoolConst _) = () - | useValueAsCallee (_, _, C.IntConst _) = () - | useValueAsCallee (_, _, C.WordConst _) = () - | useValueAsCallee (_, _, C.CharConst _) = () - | useValueAsCallee (_, _, C.Char16Const _) = () - | useValueAsCallee (_, _, C.StringConst _) = () - | useValueAsCallee (_, _, C.String16Const _) = () + | useValueAsCallee (_, C.Unit) = () + | useValueAsCallee (_, C.Nil) = () + | useValueAsCallee (_, C.BoolConst _) = () + | useValueAsCallee (_, C.IntConst _) = () + | useValueAsCallee (_, C.WordConst _) = () + | useValueAsCallee (_, C.CharConst _) = () + | useValueAsCallee (_, C.Char16Const _) = () + | useValueAsCallee (_, C.StringConst _) = () + | useValueAsCallee (_, C.String16Const _) = () fun useValueAsRecord (env, label, result, C.Var v) = (case TypedSyntax.VIdTable.find env v of SOME r => let val { call, project, other, labels } = !r @@ -112,7 +112,7 @@ local end | NONE => () ) - fun useContVarDirect cenv (v : C.CVar) = () + fun useContVarDirect _ (_ : C.CVar) = () local fun add (env, v) = if TypedSyntax.VIdTable.inDomain env v then raise Fail ("goCExp: duplicate name in AST: " ^ TypedSyntax.print_VId v) @@ -177,7 +177,7 @@ local ; goCExp (env, renv, cenv, crenv, cont) ) | C.App { applied, cont, args, attr = _ } => - ( useValueAsCallee (env, cont, applied) + ( useValueAsCallee (env, applied) ; useContVarIndirect cenv cont ; List.app (useValue env) args ) @@ -212,9 +212,7 @@ in structure CpsUnpackRecordParameter : sig val goCExp : CpsSimplify.Context * CSyntax.CExp -> CSyntax.CExp end = struct -local structure F = FSyntax - structure C = CSyntax - structure P = Primitives +local structure C = CSyntax datatype frequency = datatype CpsUsageAnalysis.frequency in type Context = { base : CpsSimplify.Context @@ -273,33 +271,34 @@ fun simplifyDec (ctx : Context) (dec, acc : C.Dec list) | (_, ELIMINATE, acc) => acc | (_, UNPACK fields, acc) => List.map #1 fields @ acc ) [] (params, paramTransforms) - val decs = ListPair.foldrEq (fn (p, UNPACK fields, decs) => C.ValDec { exp = C.Record (List.foldl (fn ((fieldVar, label), map) => Syntax.LabelMap.insert (map, label, C.Var fieldVar)) Syntax.LabelMap.empty fields), results = [SOME p] } :: decs - | (_, KEEP, decs) => decs - | (_, ELIMINATE, decs) => decs - ) [] (params, paramTransforms) - val body = case decs of - [] => body - | _ => C.Let { decs = decs, cont = body } - val body = simplifyCExp (ctx, body) - val exp = C.Abs { contParam = contParam, params = params', body = body, attr = attr } - val result' = CpsSimplify.renewVId (#base ctx, result) + val workerDecs = ListPair.foldrEq (fn (p, UNPACK fields, decs) => C.ValDec { exp = C.Record (List.foldl (fn ((fieldVar, label), map) => Syntax.LabelMap.insert (map, label, C.Var fieldVar)) Syntax.LabelMap.empty fields), results = [SOME p] } :: decs + | (_, KEEP, decs) => decs + | (_, ELIMINATE, decs) => decs + ) [] (params, paramTransforms) + val workerBody = case workerDecs of + [] => body + | _ => C.Let { decs = workerDecs, cont = body } + val workerBody = simplifyCExp (ctx, body) + val worker = C.Abs { contParam = contParam, params = params', body = workerBody, attr = attr } + val workerName = CpsSimplify.renewVId (#base ctx, result) val wrapperBody = let val k = CpsSimplify.genContSym (#base ctx) - val params' = List.map (fn p => CpsSimplify.renewVId (#base ctx, p)) params + val params'' = List.map (fn p => CpsSimplify.renewVId (#base ctx, p)) params val (decs, args) = ListPair.foldrEq (fn (p, KEEP, (decs, args)) => (decs, C.Var p :: args) | (_, ELIMINATE, acc) => acc | (p, UNPACK fields, (decs, args)) => List.foldr (fn ((v, label), (decs, args)) => - let val dec = C.ValDec { exp = C.Projection { label = label, record = C.Var p, fieldTypes = Syntax.LabelMap.empty (* dummy *) } + let val v = CpsSimplify.renewVId (#base ctx, v) + val dec = C.ValDec { exp = C.Projection { label = label, record = C.Var p, fieldTypes = Syntax.LabelMap.empty (* dummy *) } , results = [SOME v] } in (dec :: decs, C.Var v :: args) end ) (decs, args) fields - ) ([], []) (params', paramTransforms) + ) ([], []) (params'', paramTransforms) in C.Abs { contParam = k - , params = params' + , params = params'' , body = C.Let { decs = decs - , cont = C.App { applied = C.Var result' + , cont = C.App { applied = C.Var workerName , cont = k , args = args , attr = {} @@ -308,9 +307,9 @@ fun simplifyDec (ctx : Context) (dec, acc : C.Dec list) , attr = { isWrapper = true } } end - val dec1 = C.ValDec { exp = wrapperBody, results = [SOME result] } - val dec2 = C.ValDec { exp = exp, results = [SOME result'] } - in dec2 :: dec1 :: acc + val workerDec = C.ValDec { exp = worker, results = [SOME workerName] } + val wrapperDec = C.ValDec { exp = wrapperBody, results = [SOME result] } + in wrapperDec :: workerDec :: acc end | NONE => let val body = simplifyCExp (ctx, body) val exp = C.Abs { contParam = contParam, params = params, body = body, attr = attr } @@ -387,10 +386,10 @@ fun simplifyDec (ctx : Context) (dec, acc : C.Dec list) let val body = if TypedSyntax.VIdMap.isEmpty wrappers then body else - C.recurseCExp (fn e as C.App { applied = C.Var applied, cont, args, attr } => + C.recurseCExp (fn e as C.App { applied = C.Var applied, cont, args, attr = _ } => (case TypedSyntax.VIdMap.find (wrappers, applied) of NONE => e - | SOME { contParam, params, body, attr } => + | SOME { contParam, params, body, attr = _ } => let val subst = ListPair.foldlEq (fn (p, a, subst) => TypedSyntax.VIdMap.insert (subst, p, a)) TypedSyntax.VIdMap.empty (params, args) val csubst = C.CVarMap.singleton (contParam, cont) in CpsSimplify.alphaConvert (#base ctx, subst, csubst, body) @@ -402,8 +401,7 @@ fun simplifyDec (ctx : Context) (dec, acc : C.Dec list) in { name = name, contParam = contParam, params = params, body = body, attr = attr } end ) defs - val decs = C.RecDec defs :: acc - in decs + in TypedSyntax.VIdMap.foldli (fn (name, abs, acc) => C.ValDec { exp = C.Abs abs, results = [SOME name] } :: acc) (C.RecDec defs :: acc) wrappers end | C.ContDec { name, params, body } => let val shouldTransformParams = if #indirect (CpsUsageAnalysis.getContUsage (#cont_usage ctx, name)) = NEVER then @@ -523,10 +521,25 @@ fun simplifyDec (ctx : Context) (dec, acc : C.Dec list) | NONE => { origName = name, body = body, newName = name, newParams = params, inline = (params, NONE) } end val defs' = List.map transform defs - val dec = C.RecContDec (List.map (fn { newName, newParams, body, ... } => (newName, newParams, simplifyCExp (ctx, body))) defs') - in dec :: acc + val dec = C.RecContDec (List.map (fn { newName, newParams, body, ... } => + let val body = C.recurseCExp (fn e as C.AppCont { applied, args } => + (case List.find (fn { origName, ... } => origName = applied) defs' of + SOME { inline = (params, SOME wrapperBody), ... } => + let val subst = ListPair.foldlEq (fn (SOME p, a, subst) => TypedSyntax.VIdMap.insert (subst, p, a) | (NONE, _, subst) => subst) TypedSyntax.VIdMap.empty (params, args) + in CpsSimplify.alphaConvert (#base ctx, subst, C.CVarMap.empty, wrapperBody) + end + | _ => e + ) + | e => e + ) body + in (newName, newParams, simplifyCExp (ctx, body)) + end + ) defs') + in List.foldl (fn ({ origName, inline = (params, SOME wrapperBody), ... }, acc) => C.ContDec { name = origName, params = params, body = wrapperBody } :: acc + | (_, acc) => acc + ) (dec :: acc) defs' end - | C.ESImportDec { pure, specs, moduleName } => dec :: acc + | C.ESImportDec _ => dec :: acc and simplifyCExp (ctx : Context, e) = case e of C.Let { decs, cont } => diff --git a/src/lunarml-common.mlb b/src/lunarml-common.mlb index c61cdd6..b4f89c3 100644 --- a/src/lunarml-common.mlb +++ b/src/lunarml-common.mlb @@ -43,6 +43,7 @@ cps.sml cps/dead-code.sml cps/uncurry.sml cps/loop.sml +cps/ref-cell.sml cps/inline.sml cps/decompose-recursive.sml cps/unpack-record-parameter.sml diff --git a/src/main.sml b/src/main.sml index e002640..f14bfbd 100644 --- a/src/main.sml +++ b/src/main.sml @@ -90,13 +90,14 @@ fun optimizeCps (_ : { nextVId : int ref, printTimings : bool }) cexp 0 = cexp val ctx' = { nextVId = #nextVId ctx , simplificationOccurred = ref false } - (* val cexp = CpsDeadCodeElimination.goCExp (ctx', cexp) *) + val cexp = CpsDeadCodeElimination.goCExp (ctx', cexp) val cexp = CpsUncurry.goCExp (ctx', cexp) (* val cexp = CpsUnpackRecordParameter.goCExp (ctx', cexp) *) - (* val cexp = CpsLoopOptimization.goCExp (ctx', cexp) *) + val cexp = CpsLoopOptimization.goCExp (ctx', cexp) val cexp = CpsDecomposeRecursive.goCExp (ctx', cexp) + val cexp = CpsConstantRefCell.goCExp (ctx', cexp) val cexp = CpsInline.goCExp (ctx', cexp) - val cexp = CpsMiscOptimization.goCExp (ctx', cexp) + (* val cexp = CpsMiscOptimization.goCExp (ctx', cexp) *) in if #printTimings ctx then print (" " ^ LargeInt.toString (Time.toMicroseconds (#usr (Timer.checkCPUTimer timer))) ^ " us\n") else