Skip to content

Commit

Permalink
Merge branch 'master' into feature/beta_neg_binomial_lccdf
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Nov 14, 2024
2 parents 48daf70 + 6915754 commit 9745565
Show file tree
Hide file tree
Showing 148 changed files with 3,930 additions and 242 deletions.
5 changes: 5 additions & 0 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e)
, Fun_kind.with_unnormalized_suffix fname
|> Option.value ~default:fname )
| FnLpdf _ -> (Fun_kind.FnLpdf false, fname)
| FnLpmf propto' when propto' && propto ->
( Fun_kind.FnLpmf true
, Fun_kind.with_unnormalized_suffix fname
|> Option.value ~default:fname )
| FnLpmf _ -> (FnLpmf false, fname)
| _ -> (suffix, fname) in
match Map.find fim fname' with
| None ->
Expand Down
7 changes: 5 additions & 2 deletions src/analysis_and_optimization/Pedantic_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ let list_possible_nonlinear (mir : Program.Typed.t) : Location_span.t Set.Poly.t
| Stmt.Fixed.Pattern.TargetPE
{ pattern=
Expr.Fixed.Pattern.FunApp
((StanLib (_, FnLpdf _, _) | UserDefined (_, FnLpdf _)), e :: _)
( ( StanLib (_, (FnLpdf _ | FnLpmf _), _)
| UserDefined (_, (FnLpdf _ | FnLpmf _)) )
, e :: _ )
; _ }
when not (is_linear true e) ->
Set.Poly.singleton stmt.meta
Expand Down Expand Up @@ -307,7 +309,8 @@ let compiletime_value_of_expr
let list_distributions (mir : Program.Typed.t) : dist_info Set.Poly.t =
let take_dist (expr : Expr.Typed.t) =
match expr.pattern with
| Expr.Fixed.Pattern.FunApp (StanLib (fname, FnLpdf true, _), arg_exprs) ->
| Expr.Fixed.Pattern.FunApp
(StanLib (fname, (FnLpdf true | FnLpmf true), _), arg_exprs) ->
let fname = chop_dist_name fname |> Option.value_exn in
let params = parameter_set mir in
let data = data_set mir in
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ type ('e, 's, 'l, 'f) statement =
| Tilde of
{ arg: 'e
; distribution: identifier
; kind: 'f
; args: 'e list
; truncation: 'e truncation }
| Break
Expand Down
26 changes: 14 additions & 12 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -531,17 +531,18 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
NRFunApp (trans_fn_kind fn_kind name, trans_exprs args) |> swrap
| Ast.TargetPE e -> TargetPE (trans_expr e) |> swrap
| Ast.JacobianPE e -> JacobianPE (trans_expr e) |> swrap
| Ast.Tilde {arg; distribution; args; truncation} ->
let suffix =
Stan_math_signatures.dist_name_suffix ud_dists distribution.name in
let name = distribution.name ^ suffix in
let kind =
let possible_names =
List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices
|> String.Set.of_list in
if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists then
Fun_kind.UserDefined (name, FnLpdf true)
else StanLib (name, FnLpdf true, AoS) in
| Ast.Tilde {arg; distribution; args; truncation; kind} ->
let sfx =
match kind with
| UserDefined (FnLpdf _) | StanLib (FnLpdf _) -> "_lpdf"
| UserDefined (FnLpmf _) | StanLib (FnLpmf _) -> "_lpmf"
| _ ->
Common.ICE.internal_compiler_error
[%message
"Impossible: tilde with non-distribution after typechecking"
(distribution : Ast.identifier)
(kind : Ast.fun_kind)] in
let name = distribution.name ^ sfx in
let add_dist =
let adlevel =
if
Expand All @@ -551,7 +552,8 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
else DataOnly in
Stmt.Fixed.Pattern.TargetPE
Expr.
{ Fixed.pattern= FunApp (kind, trans_exprs (arg :: args))
{ Fixed.pattern=
FunApp (trans_fn_kind kind name, trans_exprs (arg :: args))
; meta= Typed.Meta.create ~type_:UReal ~loc:mloc ~adlevel () } in
swrap add_dist @ truncate_dist ud_dists distribution arg args truncation
| Ast.Print ps ->
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Deprecation_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ let expired (major, minor) =
let deprecated_functions = String.Map.of_alist_exn []
let stan_lib_deprecations = deprecated_functions

(* TODO deprecate other pre-variadics like algebra_solver? *)
let deprecated_odes =
String.Map.of_alist_exn
[ ("integrate_ode", ("ode_rk45", (3, 0)))
Expand Down
16 changes: 4 additions & 12 deletions src/frontend/Info.ml
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,10 @@ let rec get_function_calls_stmt ud_dists (funs, distrs) stmt =
| Print _ -> (Set.add funs "print", distrs)
| Reject _ -> (Set.add funs "reject", distrs)
| FatalError _ -> (Set.add funs "fatal_error", distrs)
| Tilde {distribution; _} ->
let possible_names =
List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices
|> String.Set.of_list in
if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists then
(funs, distrs)
else
let suffix =
Stan_math_signatures.dist_name_suffix ud_dists distribution.name
in
let name = distribution.name ^ Utils.unnormalized_suffix suffix in
(funs, Set.add distrs name)
| Tilde {distribution; kind= StanLib (FnLpdf _); _} ->
(funs, Set.add distrs (distribution.name ^ "_lupdf"))
| Tilde {distribution; kind= StanLib (FnLpmf _); _} ->
(funs, Set.add distrs (distribution.name ^ "_lupmf"))
| _ -> (funs, distrs) in
fold_statement get_function_calls_expr
(get_function_calls_stmt ud_dists)
Expand Down
8 changes: 1 addition & 7 deletions src/frontend/Preprocessor.ml
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ let restore_prior_lexbuf () =
lexbuf.lex_start_p <- old_pos;
old_lexbuf

let maybe_remove_quotes str =
let open String in
if is_prefix str ~prefix:"\"" && is_suffix str ~suffix:"\"" then
drop_suffix (drop_prefix str 1) 1
else str

let find_include_fs lookup_paths fname =
let rec loop paths =
match paths with
Expand Down Expand Up @@ -131,7 +125,7 @@ let find_include fname =

let try_get_new_lexbuf fname =
let lexbuf = Stack.top_exn include_stack in
let new_lexbuf, file = find_include (maybe_remove_quotes fname) in
let new_lexbuf, file = find_include fname in
lexer_logger ("opened " ^ file);
new_lexbuf.lex_start_p <-
new_file_start_position file
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/Pretty_printing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ and pp_statement ppf ({stmt= s_content; smeta= {loc}} as ss : untyped_statement)
pf ppf "%a(@[%a);@]" pp_identifier id pp_list_of_expression (es, loc)
| TargetPE e -> pf ppf "target += %a;" pp_expression e
| JacobianPE e -> pf ppf "jacobian += %a;" pp_expression e
| Tilde {arg= e; distribution= id; args= es; truncation= t} ->
| Tilde {arg= e; distribution= id; args= es; truncation= t; kind= _} ->
pf ppf "%a ~ %a(@[%a)@]%a;" pp_expression e pp_identifier id
pp_list_of_expression (es, loc) pp_truncation t
| Break -> pf ppf "break;"
Expand Down
8 changes: 6 additions & 2 deletions src/frontend/SignatureMismatch.ml
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,10 @@ let check_variadic_args ~allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
let wrap_func_error x =
TypeMismatch (minimal_func_type, func_type, Some x) |> wrap_err in
let suffix = Fun_kind.without_propto suffix in
if suffix = FnPlain || (allow_lpdf && suffix = FnLpdf ()) then
if
suffix = FnPlain
|| (allow_lpdf && (suffix = FnLpdf () || suffix = FnLpmf ()))
then
match check_compatible_arguments 1 mandatory mandatory_fun_arg_tys with
| Error x -> wrap_func_error (InputMismatch x)
| Ok _ -> (
Expand All @@ -323,7 +326,8 @@ let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) =
let suffix_str = function
| Fun_kind.FnPlain -> "a pure function"
| FnRng -> "an rng function"
| FnLpdf () -> "a probability density or mass function"
| FnLpdf () -> "a probability density function"
| FnLpmf () -> "a probability mass function"
| FnTarget -> "an _lp function"
| FnJacobian -> "a _jacobian function" in
let index_str = function
Expand Down
75 changes: 32 additions & 43 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ let in_jacobian_function cf =

let in_udf_distribution cf =
match cf.containing_function with
| NonReturning (FnLpdf ()) | Returning (FnLpdf (), _) -> true
| NonReturning (FnLpdf ())
|Returning (FnLpdf (), _)
|NonReturning (FnLpmf ())
|Returning (FnLpmf (), _) ->
true
| _ -> false

let context block =
Expand Down Expand Up @@ -113,18 +117,12 @@ let verify_identifier id : unit =
"Variable name 'jacobian' will be a reserved word starting in Stan 2.38. \
Please rename it!";
if id.name = !model_name then
Semantic_error.ident_is_model_name id.id_loc id.name |> error
else if
Semantic_error.ident_is_model_name id.id_loc id.name |> error;
if
String.is_suffix id.name ~suffix:"__"
|| List.mem reserved_keywords id.name ~equal:String.equal
then Semantic_error.ident_is_keyword id.id_loc id.name |> error

let distribution_name_variants name =
match Utils.split_distribution_suffix name with
| Some (stem, "lpmf") -> [name; stem ^ "_lpdf"]
| Some (stem, "lpdf") -> [name; stem ^ "_lpmf"]
| _ -> [name]

(** verify that the variable being declared is previous unused.
allowed to shadow StanLib *)
let verify_name_fresh_var loc tenv name =
Expand Down Expand Up @@ -160,10 +158,8 @@ let verify_name_fresh_udf loc tenv name =
- is not already in use (for now)
*)
let verify_name_fresh tenv id ~is_udf =
let f =
if is_udf then verify_name_fresh_udf id.id_loc tenv
else verify_name_fresh_var id.id_loc tenv in
List.iter ~f (distribution_name_variants id.name)
if is_udf then verify_name_fresh_udf id.id_loc tenv id.name
else verify_name_fresh_var id.id_loc tenv id.name

let is_of_compatible_return_type rt1 srt2 =
UnsizedType.(
Expand Down Expand Up @@ -293,7 +289,7 @@ let check_id cf loc tenv id =
| {kind= `Variable {origin; _}; type_} :: _ ->
(calculate_autodifftype cf origin type_, type_)
| { kind= `UserDefined | `UserDeclared _
; type_= UFun (args, rt, FnLpdf _, mem_pattern) }
; type_= UFun (args, rt, (FnLpdf _ | FnLpmf _), mem_pattern) }
:: _ ->
let type_ =
UnsizedType.UFun
Expand Down Expand Up @@ -513,10 +509,10 @@ let check_normal_fn ~is_cond_dist loc tenv id es =
let is_known_family s =
List.mem known_families s ~equal:String.equal in
match suffix with
| ("lpmf" | "lumpf") when Env.mem tenv (prefix ^ "_lpdf") ->
| ("lpmf" | "lupmf") when Env.mem tenv (prefix ^ "_lpdf") ->
Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc
(prefix, suffix)
| ("lpdf" | "lumdf") when Env.mem tenv (prefix ^ "_lpmf") ->
| ("lpdf" | "lupdf") when Env.mem tenv (prefix ^ "_lpmf") ->
Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc
(prefix, suffix)
| _ ->
Expand Down Expand Up @@ -582,7 +578,7 @@ let find_matching_first_order_fn tenv matches fname =
| Error None -> SignatureMismatch.SignatureErrors (List.hd_exn errs)

let make_function_variable cf loc id = function
| UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) ->
| UnsizedType.UFun (args, rt, (FnLpdf _ | FnLpmf _), mem_pattern) ->
let type_ =
UnsizedType.UFun
(args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in
Expand Down Expand Up @@ -1283,26 +1279,30 @@ let check_tilde_distribution loc tenv id arguments =
let name = id.name in
let argumenttypes = List.map ~f:arg_type arguments in
let name_w_suffix_dist suffix =
SignatureMismatch.matching_function tenv (name ^ suffix) argumenttypes in
( SignatureMismatch.matching_function tenv (name ^ suffix) argumenttypes
, suffix ) in
let distributions =
List.map ~f:name_w_suffix_dist Utils.distribution_suffices in
match
List.min_elt distributions ~compare:SignatureMismatch.compare_match_results
List.min_elt distributions ~compare:(fun (m1, _) (m2, _) ->
SignatureMismatch.compare_match_results m1 m2)
with
| Some (UniqueMatch (_, _, p)) ->
Promotion.promote_list arguments p
| Some (UniqueMatch (_, f, p), sfx) ->
let suffix =
Fun_kind.suffix_from_name (name ^ Utils.unnormalized_suffix sfx) in
(Promotion.promote_list arguments p, f suffix)
(* real return type is enforced by [verify_fundef_dist_rt] *)
| None | Some (SignatureErrors ([], _)) ->
| None | Some (SignatureErrors ([], _), _) ->
(* Function is non existent *)
Semantic_error.invalid_tilde_no_such_dist loc name
(List.hd_exn argumenttypes |> snd |> UnsizedType.is_int_type)
|> error
| Some (AmbiguousMatch sigs) ->
| Some (AmbiguousMatch sigs, _) ->
Semantic_error.ambiguous_function_promotion loc id.name
(Some (List.map ~f:type_of_expr_typed arguments))
sigs
|> error
| Some (SignatureErrors (l, b)) ->
| Some (SignatureErrors (l, b), _) ->
arguments
|> List.map ~f:(fun e -> e.emeta.type_)
|> Semantic_error.illtyped_fn_app loc id.name (l, b)
Expand Down Expand Up @@ -1349,11 +1349,12 @@ let check_tilde loc cf tenv distribution truncation arg args =
verify_distribution_pdf_pmf distribution;
verify_valid_distribution_pos loc cf;
verify_distribution_cdf_ccdf loc distribution;
let promoted_args =
let promoted_args, kind =
check_tilde_distribution loc tenv distribution (te :: tes) in
let te, tes = (List.hd_exn promoted_args, List.tl_exn promoted_args) in
verify_distribution_cdf_defined loc tenv distribution ttrunc tes;
let stmt = Tilde {arg= te; distribution; args= tes; truncation= ttrunc} in
let stmt =
Tilde {arg= te; distribution; args= tes; truncation= ttrunc; kind} in
mk_typed_statement ~stmt ~loc ~return_type:Incomplete

(* Break and continue only occur in loops. *)
Expand Down Expand Up @@ -1692,9 +1693,7 @@ and check_var_decl loc cf tenv sized_ty trans

(* function definitions *)
and exists_matching_fn_declared tenv id arg_tys rt =
let options =
List.concat_map ~f:(Env.find tenv) (distribution_name_variants id.name)
in
let options = Env.find tenv id.name in
let f = function
| Env.{kind= `UserDeclared _; type_= UFun (listedtypes, rt', _, _)}
when arg_tys = listedtypes && rt = rt' ->
Expand All @@ -1703,9 +1702,7 @@ and exists_matching_fn_declared tenv id arg_tys rt =
List.exists ~f options

and verify_unique_signature tenv loc id arg_tys rt =
let existing =
List.concat_map ~f:(Env.find tenv) (distribution_name_variants id.name)
in
let existing = Env.find tenv id.name in
let same_args = function
| Env.{type_= UFun (listedtypes, _, _, _); _}
when List.map ~f:snd arg_tys = List.map ~f:snd listedtypes ->
Expand Down Expand Up @@ -1822,16 +1819,8 @@ and check_fundef loc cf tenv return_ty id args body =
are not modified in function. (passed by const ref) *)
(`Variable {origin; readonly= true; global= false})) in
let context =
let is_udf_dist name =
List.exists
~f:(fun suffix -> String.is_suffix name ~suffix)
Utils.distribution_suffices in
let kind =
if is_udf_dist id.name then Fun_kind.FnLpdf ()
else if String.is_suffix id.name ~suffix:"_rng" then FnRng
else if String.is_suffix id.name ~suffix:"_lp" then FnTarget
else if String.is_suffix id.name ~suffix:"_jacobian" then FnJacobian
else FnPlain in
Fun_kind.suffix_from_name id.name |> Fun_kind.forget_normalization in
{ cf with
containing_function=
UnsizedType.returntype_to_type_opt return_ty
Expand All @@ -1850,12 +1839,12 @@ and check_statement (cf : context_flags_record) (tenv : Env.t)
(s : Ast.untyped_statement) : Env.t * typed_statement =
let loc = s.smeta.loc in
match s.stmt with
| NRFunApp (_, id, es) -> (tenv, check_nr_fn_app loc cf tenv id es)
| NRFunApp ((), id, es) -> (tenv, check_nr_fn_app loc cf tenv id es)
| Assignment {assign_lhs; assign_op; assign_rhs} ->
(tenv, check_assignment loc cf tenv assign_lhs assign_op assign_rhs)
| TargetPE e -> (tenv, check_target_pe loc cf tenv e)
| JacobianPE e -> (tenv, check_jacobian_pe loc cf tenv e)
| Tilde {arg; distribution; args; truncation} ->
| Tilde {arg; distribution; args; truncation; kind= ()} ->
(tenv, check_tilde loc cf tenv distribution truncation arg args)
| Break -> (tenv, check_break loc cf)
| Continue -> (tenv, check_continue loc cf)
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ atomic_statement:
RPAREN ot=option(truncation) SEMICOLON
{ grammar_logger "tilde_statement" ;
let t = match ot with Some tt -> tt | None -> NoTruncate in
Tilde {arg= e; distribution= id; args= es; truncation= t }
Tilde {arg= e; distribution= id; args= es; truncation= t; kind=() }
}
| TARGET PLUSASSIGN e=expression SEMICOLON
{ grammar_logger "targetpe_statement" ; TargetPE e }
Expand Down
Loading

0 comments on commit 9745565

Please sign in to comment.