Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds inlined recursion with pretty conservative guard #61

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
41 changes: 41 additions & 0 deletions backend-es/test/snapshots-out/Snapshot.RecursionInlined01.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// @inline Snapshot.RecursionInlined01.append always
import * as $runtime from "../runtime.js";
const $List = (tag, _1, _2) => ({tag, _1, _2});
const Nil = /* #__PURE__ */ $List("Nil");
const Cons = value0 => value1 => $List("Cons", value0, value1);
const append = v => v1 => {
if (v.tag === "Nil") { return v1; }
if (v.tag === "Cons") { return $List("Cons", v._1, append(v._2)(v1)); }
$runtime.fail();
};
const test1 = /* #__PURE__ */ $List(
"Cons",
"a",
/* #__PURE__ */ $List(
"Cons",
"b",
/* #__PURE__ */ $List(
"Cons",
"c",
/* #__PURE__ */ $List("Cons", "d", /* #__PURE__ */ $List("Cons", "e", /* #__PURE__ */ $List("Cons", "f", /* #__PURE__ */ $List("Cons", "g", Nil))))
)
)
);
const test2 = z => $List(
"Cons",
"a",
$List(
"Cons",
"b",
$List(
"Cons",
"c",
(() => {
if (z.tag === "Nil") { return $List("Cons", "d", $List("Cons", "e", $List("Cons", "f", $List("Cons", "g", Nil)))); }
if (z.tag === "Cons") { return $List("Cons", z._1, append(z._2)($List("Cons", "d", $List("Cons", "e", $List("Cons", "f", $List("Cons", "g", Nil)))))); }
$runtime.fail();
})()
)
)
);
export {$List, Cons, Nil, append, test1, test2};
26 changes: 26 additions & 0 deletions backend-es/test/snapshots-out/Snapshot.RecursionInlined02.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// @inline Snapshot.RecursionInlined02.addStuff always
const $List = (tag, _1, _2) => ({tag, _1, _2});
const Nil = /* #__PURE__ */ $List("Nil");
const Cons = value0 => value1 => $List("Cons", value0, value1);
const addStuff = x => ys => {
if (x > 0) { return 1 + addStuff(x - 1 | 0)(ys) | 0; }
if (x < 0) { return -1 + addStuff(x + 1 | 0)(ys) | 0; }
return ys;
};
const test1 = 42;
const test2 = z => 2 + addStuff(1)((() => {
if (z > 0) {
const $0 = z - 1 | 0;
if ($0 > 0) { return 2 + addStuff($0 - 1 | 0)(5) | 0; }
if ($0 < 0) { return 0 + addStuff($0 + 1 | 0)(5) | 0; }
return 6;
}
if (z < 0) {
const $0 = z + 1 | 0;
if ($0 > 0) { return 0 + addStuff($0 - 1 | 0)(5) | 0; }
if ($0 < 0) { return -2 + addStuff($0 + 1 | 0)(5) | 0; }
return 4;
}
return 5;
})()) | 0;
export {$List, Cons, Nil, addStuff, test1, test2};
48 changes: 48 additions & 0 deletions backend-es/test/snapshots-out/Snapshot.RecursionInlined03.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// @inline Snapshot.RecursionInlined03.append always
// This doesn't quite work yet, because the inlining of append depends
// on the analysis of the local `b`, which we don't have (easy) access to.
// If we can somehow grab its usage, we'd likely see that access + case == total
import * as Partial from "../Partial/index.js";
const List = x => x;
const nil = {type: "nil", value: undefined};
const cons = head => tail => ({type: "cons", value: {head, tail}});
const append = v => b => {
if (v.type === "cons") { return {type: "cons", value: {head: v.value.head, tail: append(v.value.tail)(b)}}; }
if (v.type === "nil") { return b; }
return Partial._crashWith("Data.Variant: pattern match failure [" + v.type + "]");
};
const test1 = {
type: "cons",
value: {
head: "a",
tail: {
type: "cons",
value: {
head: "b",
tail: /* #__PURE__ */ append({type: "cons", value: {head: "c", tail: nil}})({
type: "cons",
value: {head: "d", tail: {type: "cons", value: {head: "e", tail: {type: "cons", value: {head: "f", tail: {type: "cons", value: {head: "g", tail: nil}}}}}}}
})
}
}
}
};
const test2 = z => (
{
type: "cons",
value: {
head: "a",
tail: {
type: "cons",
value: {
head: "b",
tail: append({type: "cons", value: {head: "c", tail: z}})({
type: "cons",
value: {head: "d", tail: {type: "cons", value: {head: "e", tail: {type: "cons", value: {head: "f", tail: {type: "cons", value: {head: "g", tail: nil}}}}}}}
})
}
}
}
}
);
export {List, append, cons, nil, test1, test2};
17 changes: 17 additions & 0 deletions backend-es/test/snapshots-out/Snapshot.RecursionInlinedBroken.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// @inline Snapshot.RecursionInlinedBroken.addStuffBroken always
// This will recurse out of control and stop when it hits the recursion limit
const $List = (tag, _1, _2) => ({tag, _1, _2});
const Nil = /* #__PURE__ */ $List("Nil");
const Cons = value0 => value1 => $List("Cons", value0, value1);
const addStuffBroken = v => v1 => {
if (v === 0) { return v1; }
return 1 + addStuffBroken(v - 1 | 0)(v1) | 0;
};
const test1 = v => 2 + addStuffBroken(-5)(4) | 0;
const test2 = z => 2 + addStuffBroken(-5)((() => {
if (z === 0) { return 5; }
const $0 = z - 1 | 0;
if ($0 === 0) { return 6; }
return 2 + addStuffBroken($0 - 1 | 0)(5) | 0;
})()) | 0;
export {$List, Cons, Nil, addStuffBroken, test1, test2};
16 changes: 16 additions & 0 deletions backend-es/test/snapshots/Snapshot.RecursionInlined01.purs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- @inline Snapshot.RecursionInlined01.append always
module Snapshot.RecursionInlined01 where

data List a = Nil | Cons a (List a)
infixr 5 Cons as :

append :: forall a. List a -> List a -> List a
append Nil ys = ys
append (Cons x xs) ys = Cons x (append xs ys)

infixr 4 append as <>

test1 :: List String
test1 = ("a" : "b" : "c" : Nil) <> ("d" : "e" : "f" : "g" : Nil)
test2 :: List String -> List String
test2 z = ("a" : "b" : "c" : z) <> ("d" : "e" : "f" : "g" : Nil)
20 changes: 20 additions & 0 deletions backend-es/test/snapshots/Snapshot.RecursionInlined02.purs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- @inline Snapshot.RecursionInlined02.addStuff always
module Snapshot.RecursionInlined02 where

import Prelude

data List a = Nil | Cons a (List a)
infixr 5 Cons as :

addStuff :: Int -> Int -> Int
addStuff x ys
| x > 0 = 1 + addStuff (x - 1) ys
| x < 0 = (-1) + addStuff (x + 1) ys
| otherwise = ys

infixr 4 addStuff as ++

test1 :: Int
test1 = 38 ++ 4
test2 :: Int -> Int
test2 z = 3 ++ z ++ 5
31 changes: 31 additions & 0 deletions backend-es/test/snapshots/Snapshot.RecursionInlined03.purs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
-- @inline Snapshot.RecursionInlined03.append always
-- This doesn't quite work yet, because the inlining of append depends
-- on the analysis of the local `b`, which we don't have (easy) access to.
-- If we can somehow grab its usage, we'd likely see that access + case == total
module Snapshot.RecursionInlined03 where

import Prelude

import Data.Variant (Variant, inj, case_, on)
import Type.Proxy (Proxy(..))

newtype List a = List (Variant (nil :: Unit, cons :: { head :: a, tail :: List a }))

cons :: forall a. a -> List a -> List a
cons head tail = List $ inj (Proxy :: _ "cons") {head, tail}
nil :: forall a. List a
nil = List $ inj (Proxy :: _ "nil") unit
infixr 5 cons as :

append :: forall a. List a -> List a -> List a
append (List a) b = (case_
# on (Proxy :: _ "nil") (\_ -> b)
# on (Proxy :: _ "cons") (\{head, tail} -> cons head (append tail b))) a

infixr 4 append as <>

test1 :: List String
test1 = ("a" : "b" : "c" : nil) <> ("d" : "e" : "f" : "g" : nil)

test2 :: List String -> List String
test2 z = ("a" : "b" : "c" : z) <> ("d" : "e" : "f" : "g" : nil)
22 changes: 22 additions & 0 deletions backend-es/test/snapshots/Snapshot.RecursionInlinedBroken.purs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- @inline Snapshot.RecursionInlinedBroken.addStuffBroken always
-- This will recurse out of control and stop when it hits the recursion limit
module Snapshot.RecursionInlinedBroken where

import Prelude

data List a = Nil | Cons a (List a)
infixr 5 Cons as :

addStuffBroken :: Int -> Int -> Int
addStuffBroken 0 ys = ys
addStuffBroken x ys = 1 + addStuffBroken (x - 1) ys

infixr 4 addStuffBroken as ++

-- we make this Unit -> Int so that the module can be loaded with an
-- import statement. otherwise, it will execute the broken (-3) ++ 4 and the
-- tests will fail
test1 :: Unit -> Int
test1 _ = (-3) ++ 4
test2 :: Int -> Int
test2 z = (-3) ++ z ++ 5
2 changes: 1 addition & 1 deletion src/PureScript/Backend/Optimizer/Convert.purs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ toBackendTopLevelBindingGroup env = case _ of

toTopLevelBackendBinding :: Array (Qualified Ident) -> ConvertEnv -> Binding Ann -> Accum ConvertEnv (Tuple Ident (WithDeps NeutralExpr))
toTopLevelBackendBinding group env (Binding _ ident cfn) = do
let evalEnv = Env { currentModule: env.currentModule, evalExtern: makeExternEval env, locals: [], directives: env.directives }
let evalEnv = Env { currentModule: env.currentModule, evalExtern: makeExternEval env, locals: [], punt: Set.empty, directives: env.directives, blockNextRecursion: false }
let backendExpr = toBackendExpr cfn env
let Tuple impl expr' = toExternImpl env group (optimize (getCtx env) evalEnv (Qualified (Just env.currentModule) ident) env.rewriteLimit backendExpr)
{ accum: env
Expand Down
50 changes: 40 additions & 10 deletions src/PureScript/Backend/Optimizer/Semantics.purs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ data BackendSemantics
| NeutLocal (Maybe Ident) Level
| NeutVar (Qualified Ident)
| NeutStop (Qualified Ident)
| RecurseWithRecklessAbandon (Qualified Ident)
| NeutData (Qualified Ident) ConstructorType ProperName Ident (Array (Tuple String BackendSemantics))
| NeutCtorDef (Qualified Ident) ConstructorType ProperName Ident (Array String)
| NeutApp BackendSemantics (Spine BackendSemantics)
Expand Down Expand Up @@ -86,6 +87,7 @@ data BackendRewrite
| RewriteLetAssoc (Array (LetBindingAssoc BackendExpr)) BackendExpr
| RewriteEffectBindAssoc (Array (EffectBindingAssoc BackendExpr)) BackendExpr
| RewriteStop (Qualified Ident)
| RewriteRecurse (Qualified Ident)
| RewriteUnpackOp (Maybe Ident) Level UnpackOp BackendExpr
| RewriteDistBranchesLet (Maybe Ident) Level (NonEmptyArray (Pair BackendExpr)) BackendExpr BackendExpr
| RewriteDistBranchesOp (NonEmptyArray (Pair BackendExpr)) BackendExpr DistOp
Expand Down Expand Up @@ -164,6 +166,8 @@ newtype Env = Env
, evalExtern :: Env -> Qualified Ident -> Array ExternSpine -> Maybe BackendSemantics
, locals :: Array (LocalBinding BackendSemantics)
, directives :: InlineDirectiveMap
, punt :: Set.Set (Qualified Ident)
, blockNextRecursion :: Boolean
}

derive instance Newtype Env _
Expand All @@ -183,6 +187,16 @@ insertDirective ref acc dir = Map.alter
Just $ Map.singleton acc dir
ref

recursable :: Array BackendSemantics -> Boolean
recursable = go
where
go arr
| Just { head, tail } <- Array.uncons arr = case head of
NeutData _ _ _ _ _ -> go tail
NeutLit _ -> go tail
_ -> false
| otherwise = true

addStop :: Env -> EvalRef -> InlineAccessor -> Env
addStop (Env env) ref acc = Env env
{ directives = Map.alter
Expand All @@ -195,22 +209,31 @@ addStop (Env env) ref acc = Env env
env.directives
}

puntMe :: Env -> Array (Qualified Ident) -> Env
puntMe (Env env) quals = Env env
{ punt = Set.union env.punt (Set.fromFoldable quals)
}

class Eval f where
eval :: Env -> f -> BackendSemantics

instance Eval f => Eval (BackendSyntax f) where
eval env = case _ of
Var qual ->
evalExtern env qual []
eval env@(Env envx@{ punt }) = case _ of
Var qual
| qual `Set.member` punt -> RecurseWithRecklessAbandon qual
| otherwise -> evalExtern env qual []
Local ident lvl ->
case lookupLocal env lvl of
Just (One sem) -> sem
Just (Group group) | Just sem <- flip Tuple.lookup group =<< ident ->
force sem
_ ->
unsafeCrashWith $ "Unbound local at level " <> show (unwrap lvl)
App hd tl ->
evalApp env (eval env hd) (NonEmptyArray.toArray (eval env <$> tl))
App hd tl -> do
let tailEvaled = NonEmptyArray.toArray (eval env <$> tl)
let canRecurse = recursable tailEvaled
let headEvaled = eval (Env envx { blockNextRecursion = not canRecurse }) hd
evalApp env headEvaled tailEvaled
UncurriedApp hd tl ->
evalUncurriedApp env (eval env hd) (eval env <$> tl)
UncurriedAbs idents body -> do
Expand Down Expand Up @@ -274,11 +297,12 @@ instance Eval f => Eval (BackendSyntax f) where
guardFailOver snd (map (eval env) <$> fields) $ NeutData qual ct ty tag

instance Eval BackendExpr where
eval = go
eval (Env e@{ blockNextRecursion }) = go (Env e { blockNextRecursion = false })
where
go env = case _ of
ExprRewrite _ rewrite ->
case rewrite of
RewriteRecurse ident -> if blockNextRecursion then mkSemExtern ident [] else eval env (Var ident :: BackendSyntax BackendExpr)
RewriteInline _ _ binding body ->
go (bindLocal env (One (eval env binding))) body
RewriteUncurry ident _ args binding body ->
Expand Down Expand Up @@ -826,9 +850,12 @@ primOpOrdNot = case _ of
OpGt -> OpLte
OpGte -> OpLt

mkSemExtern :: Qualified Ident -> Array ExternSpine -> BackendSemantics
mkSemExtern qual spine = SemExtern qual spine (defer \_ -> neutralSpine (NeutVar qual) spine)

evalExtern :: Env -> Qualified Ident -> Array ExternSpine -> BackendSemantics
evalExtern env@(Env e) qual spine = case e.evalExtern env qual spine of
Nothing -> SemExtern qual spine (defer \_ -> neutralSpine (NeutVar qual) spine)
Nothing -> mkSemExtern qual spine
Just sem -> sem

envForGroup :: Env -> EvalRef -> InlineAccessor -> Array (Qualified Ident) -> Env
Expand All @@ -846,15 +873,15 @@ evalExternFromImpl env@(Env e) qual (Tuple analysis impl) spine = case spine of
Just InlineNever ->
Just $ NeutStop qual
Just InlineAlways ->
Just $ eval (envForGroup env ref InlineRef group) expr
Just $ eval (puntMe env group) expr
Just (InlineArity _) ->
Nothing
_ ->
case expr of
NeutralExpr (Lit lit) | shouldInlineExternLiteral lit ->
Just $ eval (envForGroup env ref InlineRef group) expr
Just $ eval (puntMe env group) expr
_ | shouldInlineExternReference qual analysis expr ->
Just $ eval (envForGroup env ref InlineRef group) expr
Just $ eval (puntMe env group) expr
_ ->
Nothing
ExternCtor _ ct ty tag [] ->
Expand Down Expand Up @@ -1056,6 +1083,7 @@ quote = go
foldr (buildBranchCond ctx) (quote ctx <<< force $ def) branches'

-- Non-block constructors
RecurseWithRecklessAbandon ident -> ExprRewrite (withRewrite (analyzeDefault ctx (Var ident))) $ RewriteRecurse ident
SemExtern _ _ sem ->
go ctx (force sem)
SemLam ident k -> do
Expand Down Expand Up @@ -1529,6 +1557,8 @@ freeze init = Tuple (analysisOf init) (go init)
NeutralExpr $ Let ident level (NeutralExpr (Abs args (go binding))) (go body)
RewriteStop qual ->
NeutralExpr $ Var qual
RewriteRecurse qual ->
NeutralExpr $ Var qual
RewriteLetAssoc bindings body ->
case NonEmptyArray.fromArray bindings of
Just bindings' -> do
Expand Down