Skip to content

Commit

Permalink
Merge pull request #856 from rybern/constraint-refactor-2
Browse files Browse the repository at this point in the history
Call new deserializer backend for constrained reads
  • Loading branch information
seantalts authored Apr 3, 2021
2 parents 036d930 + b8e666e commit 282439e
Show file tree
Hide file tree
Showing 35 changed files with 7,686 additions and 16,118 deletions.
1 change: 0 additions & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import org.stan.Utils

def utils = new org.stan.Utils()
def skipExpressionTests = false

/* Functions that runs a sh command and returns the stdout */
def runShell(String command){
def output = sh (returnStdout: true, script: "${command}").trim()
Expand Down
8 changes: 4 additions & 4 deletions src/analysis_and_optimization/Factor_graph.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ let extract_factors_statement stmt =
match stmt with
| Stmt.Fixed.Pattern.TargetPE e ->
List.map (summation_terms e) ~f:(fun x -> TargetTerm x)
| NRFunApp (_, f, _) when Internal_fun.of_string_opt f = Some FnReject ->
[Reject]
| NRFunApp (_, s, args) when String.suffix s 3 = "_lp" ->
| NRFunApp (CompilerInternal FnReject, _) -> [Reject]
| NRFunApp ((UserDefined s | StanLib s), args) when String.suffix s 3 = "_lp"
->
[LPFunction (s, args)]
| Assignment (_, _)
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand Down
14 changes: 7 additions & 7 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ let rec num_expr_value (v : Expr.Typed.t) : (float * string) option =
| {pattern= Fixed.Pattern.Lit (Real, str); _}
|{pattern= Fixed.Pattern.Lit (Int, str); _} ->
Some (float_of_string str, str)
| {pattern= Fixed.Pattern.FunApp (StanLib, "PMinus__", [v]); _} -> (
| {pattern= Fixed.Pattern.FunApp (StanLib "PMinus__", [v]); _} -> (
match num_expr_value v with
| Some (v, s) -> Some (-.v, "-" ^ s)
| None -> None )
Expand Down Expand Up @@ -252,7 +252,7 @@ let rec expr_var_set Expr.Fixed.({pattern; meta}) =
match pattern with
| Var s -> Set.Poly.singleton (VVar s, meta)
| Lit _ -> Set.Poly.empty
| FunApp (_, _, exprs) -> union_recur exprs
| FunApp (_, exprs) -> union_recur exprs
| TernaryIf (expr1, expr2, expr3) -> union_recur [expr1; expr2; expr3]
| Indexed (expr, ix) ->
Set.Poly.union_list (expr_var_set expr :: List.map ix ~f:index_var_set)
Expand All @@ -270,7 +270,7 @@ and index_var_set ix =
let stmt_rhs stmt =
match stmt with
| Stmt.Fixed.Pattern.For vars -> Set.Poly.of_list [vars.lower; vars.upper]
| NRFunApp (_, _, exprs) -> Set.Poly.of_list exprs
| NRFunApp (_, exprs) -> Set.Poly.of_list exprs
| IfElse (rhs, _, _)
|While (rhs, _)
|Assignment (_, rhs)
Expand All @@ -296,7 +296,7 @@ let expr_assigned_var Expr.Fixed.({pattern; _}) =
(** See interface file *)
let rec summation_terms (Expr.Fixed.({pattern; _}) as rhs) =
match pattern with
| FunApp (_, "Plus__", [e1; e2]) ->
| FunApp (StanLib "Plus__", [e1; e2]) ->
List.append (summation_terms e1) (summation_terms e2)
| _ -> [rhs]

Expand Down Expand Up @@ -356,7 +356,7 @@ let expr_subst_stmt m = map_rec_stmt_loc (expr_subst_stmt_base m)
let rec expr_depth Expr.Fixed.({pattern; _}) =
match pattern with
| Var _ | Lit (_, _) -> 0
| FunApp (_, _, l) ->
| FunApp (_, l) ->
1
+ Option.value ~default:0
(List.max_elt ~compare:compare_int (List.map ~f:expr_depth l))
Expand Down Expand Up @@ -394,9 +394,9 @@ let rec update_expr_ad_levels autodiffable_variables
Expr.Typed.{e with meta= Meta.{e.meta with adlevel= AutoDiffable}}
else {e with meta= {e.meta with adlevel= DataOnly}}
| Lit (_, _) -> {e with meta= {e.meta with adlevel= DataOnly}}
| FunApp (o, f, l) ->
| FunApp (kind, l) ->
let l = List.map ~f:(update_expr_ad_levels autodiffable_variables) l in
{pattern= FunApp (o, f, l); meta= {e.meta with adlevel= ad_level_sup l}}
{pattern= FunApp (kind, l); meta= {e.meta with adlevel= ad_level_sup l}}
| TernaryIf (e1, e2, e3) ->
let e1 = update_expr_ad_levels autodiffable_variables e1 in
let e2 = update_expr_ad_levels autodiffable_variables e2 in
Expand Down
38 changes: 22 additions & 16 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ let rec free_vars_expr (e : Expr.Typed.t) =
match e.pattern with
| Var x -> Set.Poly.singleton x
| Lit (_, _) -> Set.Poly.empty
| FunApp (_, f, l) ->
Set.Poly.union_list (Set.Poly.singleton f :: List.map ~f:free_vars_expr l)
| FunApp (kind, l) -> free_vars_fnapp kind l
| TernaryIf (e1, e2, e3) ->
Set.Poly.union_list (List.map ~f:free_vars_expr [e1; e2; e3])
| Indexed (e, l) ->
Expand All @@ -45,6 +44,13 @@ and free_vars_idx (i : Expr.Typed.t Index.t) =
| Single e | Upfrom e | MultiIndex e -> free_vars_expr e
| Between (e1, e2) -> Set.Poly.union (free_vars_expr e1) (free_vars_expr e2)

and free_vars_fnapp kind l =
let arg_vars = List.map ~f:free_vars_expr l in
match kind with
| Fun_kind.UserDefined f ->
Set.Poly.union_list (Set.Poly.singleton f :: List.map ~f:free_vars_expr l)
| _ -> Set.Poly.union_list arg_vars

(** Calculate the free (non-bound) variables in a statement *)
let rec free_vars_stmt
(s : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) =
Expand All @@ -53,8 +59,7 @@ let rec free_vars_stmt
free_vars_expr e
| Assignment ((_, _, l), e) ->
Set.Poly.union_list (free_vars_expr e :: List.map ~f:free_vars_idx l)
| NRFunApp (_, f, l) ->
Set.Poly.union_list (Set.Poly.singleton f :: List.map ~f:free_vars_expr l)
| NRFunApp (kind, l) -> free_vars_fnapp kind l
| IfElse (e, b1, Some b2) ->
Set.Poly.union_list
[free_vars_expr e; free_vars_stmt b1.pattern; free_vars_stmt b2.pattern]
Expand Down Expand Up @@ -314,7 +319,7 @@ let constant_propagation_transfer
| Decl {decl_id= s; _} | Assignment ((s, _, _ :: _), _) ->
Map.remove m s
| TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand Down Expand Up @@ -373,7 +378,7 @@ let expression_propagation_transfer
in
Set.Poly.fold kills ~init:m ~f:kill_var
| TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand Down Expand Up @@ -414,7 +419,7 @@ let copy_propagation_transfer (globals : string Set.Poly.t)
in
Set.Poly.fold kills ~init:m ~f:kill_var
| TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand All @@ -435,11 +440,11 @@ let assigned_vars_stmt (s : (Expr.Typed.t, 'a) Stmt.Fixed.Pattern.t) =
match s with
| Assignment ((x, _, _), _) -> Set.Poly.singleton x
| TargetPE _ -> Set.Poly.singleton "target"
| NRFunApp (_, s, _) when String.suffix s 3 = "_lp" ->
| NRFunApp ((UserDefined s | StanLib s), _) when String.suffix s 3 = "_lp" ->
Set.Poly.singleton "target"
| For {loopvar= x; _} -> Set.Poly.singleton x
| Decl {decl_id= _; _}
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand Down Expand Up @@ -478,9 +483,10 @@ let reaching_definitions_transfer
|For {loopvar= x; _} ->
Set.filter p ~f:(fun (y, _) -> y = x)
| TargetPE _ -> Set.filter p ~f:(fun (y, _) -> y = "target")
| NRFunApp (_, s, _) when String.suffix s 3 = "_lp" ->
| NRFunApp ((UserDefined s | StanLib s), _)
when String.suffix s 3 = "_lp" ->
Set.filter p ~f:(fun (y, _) -> y = "target")
| NRFunApp (_, _, _)
| NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand Down Expand Up @@ -523,7 +529,7 @@ let live_variables_transfer (never_kill : string Set.Poly.t)
| Assignment ((x, _, []), _) | Decl {decl_id= x; _} ->
Set.Poly.singleton x
| TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip
|IfElse (_, _, _)
|While (_, _)
Expand All @@ -542,7 +548,7 @@ let rec used_subexpressions_expr (e : Expr.Typed.t) =
(Expr.Typed.Set.singleton e)
( match e.pattern with
| Var _ | Lit (_, _) -> Expr.Typed.Set.empty
| FunApp (_, _, l) ->
| FunApp (_, l) ->
Expr.Typed.Set.union_list (List.map ~f:used_subexpressions_expr l)
| TernaryIf (e1, e2, e3) ->
Expr.Typed.Set.union_list
Expand Down Expand Up @@ -580,7 +586,7 @@ let rec used_expressions_stmt_help f
[ f e
; used_expressions_stmt_help f b1.pattern
; used_expressions_stmt_help f b2.pattern ]
| NRFunApp (_, _, l) -> Expr.Typed.Set.union_list (List.map ~f l)
| NRFunApp (_, l) -> Expr.Typed.Set.union_list (List.map ~f l)
| Decl _ | Return None | Break | Continue | Skip -> Expr.Typed.Set.empty
| IfElse (e, b, None) | While (e, b) ->
Expr.Typed.Set.union (f e) (used_expressions_stmt_help f b.pattern)
Expand Down Expand Up @@ -614,7 +620,7 @@ let top_used_expressions_stmt_help f
(Expr.Typed.Set.union_list
(List.map ~f:(used_expressions_idx_help f) l))
| While (e, _) | IfElse (e, _, _) -> f e
| NRFunApp (_, _, l) -> Expr.Typed.Set.union_list (List.map ~f l)
| NRFunApp (_, l) -> Expr.Typed.Set.union_list (List.map ~f l)
| Profile _ | Block _ | SList _ | Decl _
|Return None
|Break | Continue | Skip ->
Expand Down Expand Up @@ -899,7 +905,7 @@ let rec declared_variables_stmt
| Decl {decl_id= x; _} -> Set.Poly.singleton x
| Assignment (_, _)
|TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip ->
Set.Poly.empty
| IfElse (_, b1, Some b2) ->
Expand Down
114 changes: 63 additions & 51 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ let rec inline_function_expression propto adt fim
match pattern with
| Var _ -> ([], [], e)
| Lit (_, _) -> ([], [], e)
| FunApp (t, s, es) -> (
| FunApp (kind, es) -> (
let dse_list =
List.map ~f:(inline_function_expression propto adt fim) es
in
Expand All @@ -231,30 +231,45 @@ let rec inline_function_expression propto adt fim
List.concat (List.rev (List.map ~f:(function _, x, _ -> x) dse_list))
in
let es = List.map ~f:(function _, _, x -> x) dse_list in
let s = if propto then s else Middle.Utils.stdlib_distribution_name s in
match Map.find fim s with
| None -> (d_list, s_list, {e with pattern= FunApp (t, s, es)})
| Some (rt, args, b) ->
let x = Gensym.generate ~prefix:"inline_" () in
let handle = handle_early_returns (Some x) in
let d_list2, s_list2, (e : Expr.Typed.t) =
( [ Stmt.Fixed.Pattern.Decl
{decl_adtype= adt; decl_id= x; decl_type= Option.value_exn rt}
]
(* We should minimize the code that's having its variables
match kind with
| CompilerInternal _ ->
(d_list, s_list, {e with pattern= FunApp (kind, es)})
| UserDefined fname | StanLib fname -> (
let fname =
if propto then fname
else Middle.Utils.stdlib_distribution_name fname
in
match Map.find fim fname with
| None ->
let fun_kind =
match kind with
| Fun_kind.UserDefined _ -> Fun_kind.UserDefined fname
| _ -> StanLib fname
in
(d_list, s_list, {e with pattern= FunApp (fun_kind, es)})
| Some (rt, args, b) ->
let x = Gensym.generate ~prefix:"inline_" () in
let handle = handle_early_returns (Some x) in
let d_list2, s_list2, (e : Expr.Typed.t) =
( [ Stmt.Fixed.Pattern.Decl
{ decl_adtype= adt
; decl_id= x
; decl_type= Option.value_exn rt } ]
(* We should minimize the code that's having its variables
replaced to avoid conflict with the (two) new dummy
variables introduced by inlining *)
, [handle (replace_fresh_local_vars (subst_args_stmt args es b))]
, { pattern= Var x
; meta=
Expr.Typed.Meta.
{ type_= Type.to_unsized (Option.value_exn rt)
; adlevel= adt
; loc= Location_span.empty } } )
in
let d_list = d_list @ d_list2 in
let s_list = s_list @ s_list2 in
(d_list, s_list, e) )
, [ handle
(replace_fresh_local_vars (subst_args_stmt args es b)) ]
, { pattern= Var x
; meta=
Expr.Typed.Meta.
{ type_= Type.to_unsized (Option.value_exn rt)
; adlevel= adt
; loc= Location_span.empty } } )
in
let d_list = d_list @ d_list2 in
let s_list = s_list @ s_list2 in
(d_list, s_list, e) ) )
| TernaryIf (e1, e2, e3) ->
let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in
let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in
Expand Down Expand Up @@ -347,7 +362,7 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.({pattern; meta}) =
| TargetPE e ->
let d, s, e = inline_function_expression propto adt fim e in
slist_concat_no_loc (d @ s) (TargetPE e)
| NRFunApp (t, s, es) ->
| NRFunApp (kind, es) ->
let dse_list =
List.map ~f:(inline_function_expression propto adt fim) es
in
Expand All @@ -362,14 +377,17 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.({pattern; meta}) =
in
let es = List.map ~f:(function _, _, x -> x) dse_list in
slist_concat_no_loc (d_list @ s_list)
( match Map.find fim s with
| None -> NRFunApp (t, s, es)
| Some (_, args, b) ->
let b = replace_fresh_local_vars b in
let b = handle_early_returns None b in
(subst_args_stmt args es
{pattern= b; meta= Location_span.empty})
.pattern )
( match kind with
| CompilerInternal _ -> NRFunApp (kind, es)
| UserDefined s | StanLib s -> (
match Map.find fim s with
| None -> NRFunApp (kind, es)
| Some (_, args, b) ->
let b = replace_fresh_local_vars b in
let b = handle_early_returns None b in
(subst_args_stmt args es
{pattern= b; meta= Location_span.empty})
.pattern ) )
| Return e -> (
match e with
| None -> Return None
Expand Down Expand Up @@ -499,7 +517,7 @@ let rec contains_top_break_or_continue Stmt.Fixed.({pattern; _}) =
| Break | Continue -> true
| Assignment (_, _)
|TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Return _ | Decl _
|While (_, _)
|For _ | Skip ->
Expand Down Expand Up @@ -565,7 +583,7 @@ let unroll_loop_one_step_statement _ =
else
IfElse
( Expr.Fixed.
{lower with pattern= FunApp (StanLib, "Geq__", [upper; lower])}
{lower with pattern= FunApp (StanLib "Geq__", [upper; lower])}
, { pattern=
(let body_unrolled =
subst_args_stmt [loopvar] [lower]
Expand All @@ -581,8 +599,7 @@ let unroll_loop_one_step_statement _ =
{ lower with
pattern=
FunApp
( StanLib
, "Plus__"
( StanLib "Plus__"
, [lower; Expr.Helpers.loop_bottom] ) } }
; meta= Location_span.empty }
in
Expand Down Expand Up @@ -666,26 +683,21 @@ and accum_any pred b e = b || expr_any pred e

let can_side_effect_top_expr (e : Expr.Typed.t) =
match e.pattern with
| FunApp (t, f, _) ->
String.suffix f 3 = "_lp"
|| (t = CompilerInternal && f = Internal_fun.to_string FnReadParam)
|| (t = CompilerInternal && f = Internal_fun.to_string FnReadData)
|| (t = CompilerInternal && f = Internal_fun.to_string FnWriteParam)
|| (t = CompilerInternal && f = Internal_fun.to_string FnConstrain)
|| (t = CompilerInternal && f = Internal_fun.to_string FnValidateSize)
|| (t = CompilerInternal && f = Internal_fun.to_string FnValidateSize)
|| t = CompilerInternal
&& f = Internal_fun.to_string FnValidateSizeSimplex
|| t = CompilerInternal
&& f = Internal_fun.to_string FnValidateSizeUnitVector
|| (t = CompilerInternal && f = Internal_fun.to_string FnUnconstrain)
| FunApp ((UserDefined f | StanLib f), _) -> String.suffix f 3 = "_lp"
| FunApp
( CompilerInternal
( FnReadParam _ | FnReadData | FnWriteParam | FnConstrain _
| FnValidateSize | FnValidateSizeSimplex | FnValidateSizeUnitVector
| FnUnconstrain _ )
, _ ) ->
true
| _ -> false

let cannot_duplicate_expr (e : Expr.Typed.t) =
let pred e =
can_side_effect_top_expr e
|| ( match e.pattern with
| FunApp (_, f, _) -> String.suffix f 4 = "_rng"
| FunApp ((UserDefined f | StanLib f), _) -> String.suffix f 4 = "_rng"
| _ -> false )
|| (preserve_stability && UnsizedType.is_autodiffable e.meta.type_)
in
Expand Down Expand Up @@ -746,7 +758,7 @@ let dead_code_elimination (mir : Program.Typed.t) =
due to side effects. *)
(* TODO: maybe we should revisit that. *)
| Decl _ | TargetPE _
|NRFunApp (_, _, _)
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip ->
stmt
| IfElse (e, b1, b2) -> (
Expand Down
Loading

0 comments on commit 282439e

Please sign in to comment.