Skip to content

Commit

Permalink
Insert extra let bindings to avoid evaluating expressions multiple ti…
Browse files Browse the repository at this point in the history
…mes during pattern matching
  • Loading branch information
Kmeakin committed Dec 12, 2022
1 parent 0dd795c commit d24c4f9
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 53 deletions.
24 changes: 24 additions & 0 deletions fathom/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,20 @@ impl<'arena> Term<'arena> {
),
}
}

pub fn is_trivial(&self) -> bool {
match self {
Term::ItemVar(_, _)
| Term::LocalVar(_, _)
| Term::MetaVar(_, _)
| Term::InsertedMeta(_, _, _)
| Term::Universe(_)
| Term::Prim(_, _)
| Term::ConstLit(_, _) => true,
Term::RecordProj(_, head, _) => head.is_trivial(),
_ => false,
}
}
}

/// Simple patterns that have had some initial elaboration performed on them
Expand Down Expand Up @@ -431,6 +445,16 @@ impl<'arena> CheckedPattern<'arena> {
CheckedPattern::ConstLit(_, _) | CheckedPattern::RecordLit(_, _, _) => false,
}
}

pub fn is_trivial(&self) -> bool {
match self {
CheckedPattern::ReportedError(_)
| CheckedPattern::Placeholder(_)
| CheckedPattern::Binder(_, _)
| CheckedPattern::ConstLit(_, _) => true,
CheckedPattern::RecordLit(_, _, _) => false,
}
}
}

macro_rules! def_prims {
Expand Down
4 changes: 2 additions & 2 deletions fathom/src/surface/distillation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,11 +393,11 @@ impl<'interner, 'arena, 'env> Context<'interner, 'arena, 'env> {
match core_term {
core::Term::ItemVar(_span, var) => match self.get_item_name(*var) {
Some(name) => Term::Name((), name),
None => todo!("misbound variable"), // TODO: error?
None => panic!("misbound item variable: {var:?}"),
},
core::Term::LocalVar(_span, var) => match self.get_local_name(*var) {
Some(name) => Term::Name((), name),
None => todo!("misbound variable"), // TODO: error?
None => panic!("Unbound local variable: {var:?}"),
},
core::Term::MetaVar(_span, var) => match self.get_hole_name(*var) {
Some(name) => Term::Hole((), name),
Expand Down
114 changes: 105 additions & 9 deletions fathom/src/surface/elaboration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -804,21 +804,52 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
match (surface_term, expected_type.as_ref()) {
(Term::Let(range, def_pattern, def_type, def_expr, body_expr), _) => {
let (def_pattern, def_type_value) = self.synth_ann_pattern(def_pattern, *def_type);
let scrut = self.check_scrutinee(def_expr, def_type_value.clone());
let mut scrut = self.check_scrutinee(def_expr, def_type_value.clone());
let value = self.eval_env().eval(scrut.expr);

// Bind the scrut to a fresh variable if it is unsafe to evaluate multiple times,
// and may be evaluated multiple times by the pattern match compiler
let extra_def = match (scrut.expr.is_trivial(), def_pattern.is_trivial()) {
(false, false) => {
let def_name = None; // TODO: generate a fresh name
let def_type = self.quote_env().quote(self.scope, &scrut.r#type);
let def_expr = scrut.expr.clone();

let var = core::Term::LocalVar(def_expr.span(), env::Index::last());
scrut.expr = self.scope.to_scope(var);
(self.local_env).push_def(def_name, value.clone(), scrut.r#type.clone());
Some((def_name, def_type, def_expr))
}
_ => None,
};

let initial_len = self.local_env.len();
self.push_local_def(&def_pattern, value, scrut.r#type.clone());
let body_expr = self.check(body_expr, &expected_type);
self.local_env.truncate(initial_len);

let matrix = PatMatrix::singleton(scrut, def_pattern);
self.elab_match(
let expr = self.elab_match(
matrix,
&[body_expr],
*range,
def_expr.range(),
PatternMode::Let,
)
);
let expr = match extra_def {
None => expr,
Some((def_name, def_type, def_expr)) => {
self.local_env.pop();
core::Term::Let(
range.into(),
def_name,
self.scope.to_scope(def_type),
self.scope.to_scope(def_expr),
self.scope.to_scope(expr),
)
}
};
expr
}
(Term::If(range, cond_expr, then_expr, else_expr), _) => {
let cond_expr = self.check(cond_expr, &self.bool_type.clone());
Expand Down Expand Up @@ -1110,9 +1141,25 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
}
Term::Let(range, def_pattern, def_type, def_expr, body_expr) => {
let (def_pattern, def_type_value) = self.synth_ann_pattern(def_pattern, *def_type);
let scrut = self.check_scrutinee(def_expr, def_type_value.clone());
let mut scrut = self.check_scrutinee(def_expr, def_type_value.clone());
let value = self.eval_env().eval(scrut.expr);

// Bind the scrut to a fresh variable if it is unsafe to evaluate multiple times,
// and may be evaluated multiple times by the pattern match compiler
let extra_def = match (scrut.expr.is_trivial(), def_pattern.is_trivial()) {
(false, false) => {
let def_name = None; // TODO: generate a fresh name
let def_type = self.quote_env().quote(self.scope, &scrut.r#type);
let def_expr = scrut.expr.clone();

let var = core::Term::LocalVar(def_expr.span(), env::Index::last());
scrut.expr = self.scope.to_scope(var);
(self.local_env).push_def(def_name, value.clone(), scrut.r#type.clone());
Some((def_name, def_type, def_expr))
}
_ => None,
};

let initial_len = self.local_env.len();
self.push_local_def(&def_pattern, value, scrut.r#type.clone());
let (body_expr, body_type) = self.synth(body_expr);
Expand All @@ -1126,6 +1173,19 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
def_expr.range(),
PatternMode::Let,
);
let expr = match extra_def {
None => expr,
Some((def_name, def_type, def_expr)) => {
self.local_env.pop();
core::Term::Let(
range.into(),
def_name,
self.scope.to_scope(def_type),
self.scope.to_scope(def_expr),
self.scope.to_scope(expr),
)
}
};
(expr, body_type)
}
Term::If(range, cond_expr, then_expr, else_expr) => {
Expand Down Expand Up @@ -1817,15 +1877,37 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
expected_type: &ArcValue<'arena>,
) -> core::Term<'arena> {
let expected_type = self.elim_env().force(expected_type);
let scrut = self.synth_scrutinee(scrutinee_expr);
let mut scrut = self.synth_scrutinee(scrutinee_expr);
let value = self.eval_env().eval(scrut.expr);

let patterns: Vec<_> = equations
.iter()
.map(|(pat, _)| self.check_pattern(pat, &scrut.r#type))
.collect();

// Bind the scrut to a fresh variable if it is unsafe to evaluate multiple times,
// and may be evaluated multiple times by the pattern match compiler
let extra_def = match (
scrut.expr.is_trivial(),
patterns.iter().all(|pat| pat.is_trivial()),
) {
(false, false) => {
let def_name = None; // TODO: generate a fresh name
let def_type = self.quote_env().quote(self.scope, &scrut.r#type);
let def_expr = scrut.expr.clone();

let var = core::Term::LocalVar(def_expr.span(), env::Index::last());
scrut.expr = self.scope.to_scope(var);
(self.local_env).push_def(def_name, value.clone(), scrut.r#type.clone());
Some((def_name, def_type, def_expr))
}
_ => None,
};

let mut rows = Vec::with_capacity(equations.len());
let mut exprs = Vec::with_capacity(equations.len());

for (pat, expr) in equations {
let pattern = self.check_pattern(pat, &scrut.r#type);

for (pattern, (_, expr)) in patterns.into_iter().zip(equations) {
let initial_len = self.local_env.len();
self.push_pattern(
&pattern,
Expand All @@ -1841,7 +1923,21 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
}

let matrix = patterns::PatMatrix::new(rows);
self.elab_match(matrix, &exprs, range, scrut.range, PatternMode::Match)
let expr = self.elab_match(matrix, &exprs, range, scrut.range, PatternMode::Match);
let expr = match extra_def {
None => expr,
Some((def_name, def_type, def_expr)) => {
self.local_env.pop();
core::Term::Let(
range.into(),
def_name,
self.scope.to_scope(def_type),
self.scope.to_scope(def_expr),
self.scope.to_scope(expr),
)
}
};
expr
}

fn synth_scrutinee(&mut self, scrutinee_expr: &Term<'_, ByteRange>) -> Scrutinee<'arena> {
Expand Down
13 changes: 9 additions & 4 deletions tests/succeed/record-patterns/let-check.snap
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
stdout = '''
let _ : () = ();
let x : Bool = (false, true)._0;
let y : Bool = (false, true)._1;
let a : Bool = (false, true)._0;
let b : Bool = (false, true)._1;
let _ : () = ();
let _ : (Bool, Bool) = (false, true);
let x : Bool = _._0;
let y : Bool = _._1;
let _ : (Bool, Bool) = (false, true);
let a : Bool = _._0;
let _ : (Bool, Bool) = (false, true);
let b : Bool = _._1;
let _ : (Bool, Bool) = (false, true);
let _ : (Bool, Bool) = (false, true);
() : ()
'''
Expand Down
13 changes: 9 additions & 4 deletions tests/succeed/record-patterns/let-synth.snap
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
stdout = '''
let _ : () = ();
let x : Bool = (false, true)._0;
let y : Bool = (false, true)._1;
let a : Bool = (false, true)._0;
let b : Bool = (false, true)._1;
let _ : () = ();
let _ : (Bool, Bool) = (false, true);
let x : Bool = _._0;
let y : Bool = _._1;
let _ : (Bool, Bool) = (false, true);
let a : Bool = _._0;
let _ : (Bool, Bool) = (false, true);
let b : Bool = _._1;
let _ : (Bool, Bool) = (false, true);
let _ : (Bool, Bool) = (false, true);
() : ()
'''
Expand Down
52 changes: 18 additions & 34 deletions tests/succeed/record-patterns/match-bool-pairs.snap
Original file line number Diff line number Diff line change
@@ -1,38 +1,22 @@
stdout = '''
let and1 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
then if (x, y)._1 then true else false
else false;
let and2 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
then if (x, y)._1 then true else false
else false;
let and3 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
then if (x, y)._1 then true else false
else if (x, y)._1 then false
else false;
let or1 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
then true
else if (x, y)._1 then true
else false;
let or2 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
then true
else if (x, y)._1 then true
else false;
let or3 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
then if (x, y)._1 then true else true
else if (x, y)._1 then true
else false;
let xor1 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
then if (x, y)._1 then false else true
else if (x, y)._1 then true
else false;
let xor2 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
then if (x, y)._1 then false else true
else if (x, y)._1 then true
else false;
let xor3 : Bool -> Bool -> Bool = fun x y => if (x, y)._0
then if (x, y)._1 then false else true
else if (x, y)._1 then true
else false;
let and1 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
if _._0 then if _._1 then true else false else false;
let and2 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
if _._0 then if _._1 then true else false else false;
let and3 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
if _._0 then if _._1 then true else false else if _._1 then false else false;
let or1 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
if _._0 then true else if _._1 then true else false;
let or2 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
if _._0 then true else if _._1 then true else false;
let or3 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
if _._0 then if _._1 then true else true else if _._1 then true else false;
let xor1 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
if _._0 then if _._1 then false else true else if _._1 then true else false;
let xor2 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
if _._0 then if _._1 then false else true else if _._1 then true else false;
let xor3 : Bool -> Bool -> Bool = fun x y => let _ : (Bool, Bool) = (x, y);
if _._0 then if _._1 then false else true else if _._1 then true else false;
() : ()
'''
stderr = ''

0 comments on commit d24c4f9

Please sign in to comment.