forked from stan-dev/stanc3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSemantic_check.ml
1904 lines (1755 loc) · 70.3 KB
/
Semantic_check.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
(** Semantic validation of the AST. *)
(* Idea: check many things related to identifiers that are hard to check
during parsing and are in fact irrelevant for building up the parse tree. *)
open Core_kernel
open Symbol_table
open Middle
open Ast
open Errors
module Validate = Common.Validation.Make (Semantic_error)
(* There is a semantic checking function for each AST node that calls
the checking functions for its children left to right. *)
(* The top-level function semantic_check_program checks the AST while
operating on (1) a global symbol table vm, and (2) a structure of type
context_flags_record used to communicate information down the AST. *)
(** [check_of_compatible_return_type rt1 srt2] holds when a function declared
    to return [rt1] may have a body whose inferred statement return type is
    [srt2]. A void function accepts no return, a void return, or any return
    type. A returning function accepts any return type or a completed return
    of the same type. As a special case, an [int] result is accepted where
    [real] is declared. *)
let check_of_compatible_return_type rt1 srt2 =
  let open UnsizedType in
  match rt1 with
  | Void -> (
    match srt2 with
    | NoReturnType | Incomplete Void | Complete Void | AnyReturnType -> true
    | _ -> false )
  | ReturnType declared -> (
    match (declared, srt2) with
    | _, AnyReturnType -> true
    | UReal, Complete (ReturnType UInt) -> true
    | _, Complete (ReturnType inferred) -> declared = inferred
    | _ -> false )
(** Origin blocks, used to keep track of where variables are declared.
    [MathLibrary] marks Stan Math library functions that are referenced by
    name as values. [Functions] is the user-defined functions block. The
    remaining constructors name the program blocks. [TData] and [TParam] are
    presumably the transformed data and transformed parameters blocks, and
    [GQuant] the generated quantities block. *)
type originblock =
  | MathLibrary
  | Functions
  | Data
  | TData
  | Param
  | TParam
  | Model
  | GQuant
(** Presumably toggles whether every declared function must also be defined.
    NOTE(review): it is not read in this part of the file, so confirm at its
    use site. *)
let check_that_all_functions_have_definition = ref true
(** Name of the model being checked. [semantic_check_identifier] rejects
    identifiers equal to it. *)
let model_name = ref ""
(** Global symbol table shared by all checking functions in this module. *)
let vm = Symbol_table.initialize ()
(* Record structure holding flags and other markers about context, used for
   error reporting. It is threaded down the AST during checking, and its flags
   decide which constructs are allowed at the current position. *)
type context_flags_record =
  { current_block: originblock  (** Block currently being checked. *)
  ; in_toplevel_decl: bool
        (** Checking a top-level declaration. [_rng] calls and variables
            from [Param]/[TParam]/[GQuant] are rejected while this is set. *)
  ; in_fun_def: bool  (** Inside a function definition. *)
  ; in_returning_fun_def: bool
        (** Presumably inside a non-void function definition. It is not
            read in this part of the file. *)
  ; in_rng_fun_def: bool
        (** Inside an [_rng] function definition, where [_rng] calls are
            allowed. *)
  ; in_lp_fun_def: bool
        (** Inside an [_lp] function definition, which permits [_lp] calls
            and target access. *)
  ; in_udf_dist_def: bool
        (** Inside a user-defined distribution definition, which permits
            unnormalized distribution calls. *)
  ; loop_depth: int
        (** Presumably the nesting depth of enclosing loops. It is not read
            in this part of the file. *) }
(* Some helper functions *)

(** [dup_exists l] is [true] exactly when some string occurs more than once
    in [l]. *)
let dup_exists l = Option.is_some (List.find_a_dup l ~compare:String.compare)
(** The unsized type inferred for the typed expression [ue]. *)
let type_of_expr_typed ue = ue.emeta.type_

(** Autodiff level of a variable with type [ut] whose origin block is [at].
    Variables from [Param], [TParam], [Model] or [Functions] are autodiffable
    unless their type contains an int or the current block is [GQuant].
    Everything else is data only. *)
let calculate_autodifftype cf at ut =
  let autodiff_origin =
    match at with Param | TParam | Model | Functions -> true | _ -> false
  in
  if
    autodiff_origin
    && (not (UnsizedType.contains_int ut))
    && not (cf.current_block = GQuant)
  then UnsizedType.AutoDiffable
  else UnsizedType.DataOnly

(** [true] when [ue] has type [int]. *)
let has_int_type ue = match ue.emeta.type_ with UInt -> true | _ -> false

(** [true] when [ue] has type [int[]]. *)
let has_int_array_type ue =
  match ue.emeta.type_ with UArray UInt -> true | _ -> false

(** [true] when [ue] has type [int] or [real]. *)
let has_int_or_real_type ue =
  has_int_type ue || (match ue.emeta.type_ with UReal -> true | _ -> false)
(** Returns every spelling under which the distribution function [id] may be
    declared: the name itself plus its alternative density/mass forms and
    the deprecated suffix forms, all located at [id.id_loc]. For example,
    [foo_lpdf] gives [foo_lpdf; foo_lpmf; foo_log]. [multiply_log] and
    [binomial_coefficient_log] end in [_log] but are not distributions, so
    they map only to themselves. *)
let probability_distribution_name_variants id =
  let name = id.name in
  (* Each rule pairs a suffix with the replacement suffixes to try. The first
     matching rule wins, so "_cdf_log" and "_ccdf_log" must come before
     "_log". *)
  let rules =
    [ ("_lpmf", ["_lpdf"; "_log"]); ("_lpdf", ["_lpmf"; "_log"])
    ; ("_lcdf", ["_cdf_log"]); ("_lccdf", ["_ccdf_log"])
    ; ("_cdf_log", ["_lcdf"]); ("_ccdf_log", ["_lccdf"])
    ; ("_log", ["_lpmf"; "_lpdf"]) ]
  in
  let variants =
    if
      List.mem
        ["multiply_log"; "binomial_coefficient_log"]
        name ~equal:String.equal
    then [name]
    else
      match
        List.find rules ~f:(fun (suffix, _) -> String.is_suffix name ~suffix)
      with
      | None -> [name]
      | Some (suffix, replacements) ->
          let stem = String.drop_suffix name (String.length suffix) in
          name :: List.map replacements ~f:(fun r -> stem ^ r)
  in
  List.map variants ~f:(fun n -> {name= n; id_loc= id.id_loc})
(** Least upper bound of two return types. [int] and [real] join to [real],
    identical types join to themselves, and anything else is reported as
    mismatched return types at [loc]. *)
let lub_rt loc rt1 rt2 =
  let is_int_real_pair =
    match (rt1, rt2) with
    | UnsizedType.ReturnType UInt, UnsizedType.ReturnType UReal
     |ReturnType UReal, ReturnType UInt ->
        true
    | _ -> false
  in
  if is_int_real_pair then Validate.ok (UnsizedType.ReturnType UReal)
  else if rt1 = rt2 then Validate.ok rt2
  else Validate.error (Semantic_error.mismatched_return_types loc rt1 rt2)
(** Checks that [id] is a fresh name for a variable or, when [is_udf], for a
    user-defined function:
    - a user-defined function may not reuse a Stan Math function name. This
      includes the variadic reduce_sum and ODE functions, which are currently
      not in the math signatures;
    - no identifier may carry the unnormalized [_lupdf]/[_lupmf] suffix;
    - the name must not already be bound in the symbol table. *)
let check_fresh_variable_basic id is_udf =
  let clashes_with_stan_math () =
    Stan_math_signatures.is_stan_math_function_name id.name
    || Stan_math_signatures.is_reduce_sum_fn id.name
    || Stan_math_signatures.is_variadic_ode_fn id.name
  in
  if is_udf && clashes_with_stan_math () then
    Semantic_error.ident_is_stanmath_name id.id_loc id.name |> Validate.error
  else if Utils.is_unnormalized_distribution id.name then
    ( if is_udf then Semantic_error.udf_is_unnormalized_fn id.id_loc id.name
    else Semantic_error.ident_has_unnormalized_suffix id.id_loc id.name )
    |> Validate.error
  else
    match Symbol_table.look vm id.name with
    | None -> Validate.ok ()
    | Some _ -> Semantic_error.ident_in_use id.id_loc id.name |> Validate.error
(** Runs [check_fresh_variable_basic] on every spelling of [id] produced by
    [probability_distribution_name_variants], accumulating all errors. *)
let check_fresh_variable id is_udf =
  let accumulate acc variant =
    Validate.apply_const acc (check_fresh_variable_basic variant is_udf)
  in
  List.fold
    (probability_distribution_name_variants id)
    ~init:(Validate.ok ()) ~f:accumulate
(* == SEMANTIC CHECK OF PROGRAM ELEMENTS ==================================== *)

(* Assignment operators and autodiff levels have nothing to validate. These
   functions exist so every AST node has a corresponding checking function. *)
let semantic_check_assignmentoperator op = Validate.ok op

let semantic_check_autodifftype at = Validate.ok at

(** Validates an unsized type by walking into function and array types. For a
    function type, the errors of its return type are followed by those of
    each argument's autodiff level and type. All other types are valid. *)
let rec semantic_check_unsizedtype : UnsizedType.t -> unit Validate.t =
 fun ut ->
  match ut with
  | UFun (args, rt) ->
      let check_arg acc (at, arg_ut) =
        let acc = Validate.apply_const acc (semantic_check_autodifftype at) in
        Validate.apply_const acc (semantic_check_unsizedtype arg_ut)
      in
      List.fold args ~init:(semantic_check_returntype rt) ~f:check_arg
  | UArray elem -> semantic_check_unsizedtype elem
  | _ -> Validate.ok ()

and semantic_check_returntype : UnsizedType.returntype -> unit Validate.t =
 fun rt ->
  match rt with
  | ReturnType ut -> semantic_check_unsizedtype ut
  | Void -> Validate.ok ()
(* -- Identifiers ----------------------------------------------------------- *)

(** Words that may not be used as identifiers: Stan literals and reserved
    words, Stan version macros, and C++ keywords (presumably because
    identifiers are carried into the generated C++). Each word appears
    exactly once; "true" and "false" were previously listed twice. *)
let reserved_keywords =
  [ "true"; "false"; "repeat"; "until"; "then"; "var"; "fvar"; "STAN_MAJOR"
  ; "STAN_MINOR"; "STAN_PATCH"; "STAN_MATH_MAJOR"; "STAN_MATH_MINOR"
  ; "STAN_MATH_PATCH"; "alignas"; "alignof"; "and"; "and_eq"; "asm"; "auto"
  ; "bitand"; "bitor"; "bool"; "break"; "case"; "catch"; "char"; "char16_t"
  ; "char32_t"; "class"; "compl"; "const"; "constexpr"; "const_cast"
  ; "continue"; "decltype"; "default"; "delete"; "do"; "double"; "dynamic_cast"
  ; "else"; "enum"; "explicit"; "export"; "extern"; "float"; "for"; "friend"
  ; "goto"; "if"; "inline"; "int"; "long"; "mutable"; "namespace"; "new"
  ; "noexcept"; "not"; "not_eq"; "nullptr"; "operator"; "or"; "or_eq"
  ; "private"; "protected"; "public"; "register"; "reinterpret_cast"; "return"
  ; "short"; "signed"; "sizeof"; "static"; "static_assert"; "static_cast"
  ; "struct"; "switch"; "template"; "this"; "thread_local"; "throw"; "try"
  ; "typedef"; "typeid"; "typename"; "union"; "unsigned"; "using"; "virtual"
  ; "void"; "volatile"; "wchar_t"; "while"; "xor"; "xor_eq" ]
(** Rejects an identifier that equals the model name, ends in a double
    underscore, or is one of [reserved_keywords]. *)
let semantic_check_identifier id =
  let is_reserved name =
    String.is_suffix name ~suffix:"__"
    || List.mem reserved_keywords name ~equal:String.equal
  in
  if id.name = !model_name then
    Validate.error (Semantic_error.ident_is_model_name id.id_loc id.name)
  else if is_reserved id.name then
    Validate.error (Semantic_error.ident_is_keyword id.id_loc id.name)
  else Validate.ok ()
(* -- Operators ------------------------------------------------------------- *)

(* Operators need no checking of their own. Operator applications are typed
   by semantic_check_binop, semantic_check_prefixop and
   semantic_check_postfixop. *)
let semantic_check_operator _op = Validate.ok ()

(* == Expressions =========================================================== *)

(** The (autodiff level, unsized type) pair of a typed expression, as used
    for signature matching. *)
let arg_type {emeta= {ad_level; type_; _}; _} = (ad_level, type_)

(** [arg_type] applied to each expression of a list. *)
let get_arg_types es = List.map es ~f:arg_type
(* -- Function application -------------------------------------------------- *)

(** [map_rect] may not receive an [_lp] or [_rng] function as its first
    argument. *)
let semantic_check_fn_map_rect ~loc id es =
  let is_forbidden fname =
    String.is_suffix fname ~suffix:"_lp"
    || String.is_suffix fname ~suffix:"_rng"
  in
  match es with
  | {expr= Variable arg1; _} :: _
    when String.equal id.name "map_rect" && is_forbidden arg1.name ->
      Validate.error (Semantic_error.invalid_map_rect_fn loc arg1.name)
  | _ -> Validate.ok ()
(** A plain (non-conditional) function application must not name a function
    ending in one of [Utils.conditioning_suffices]. Such functions must be
    applied with conditional-distribution notation. *)
let semantic_check_fn_conditioning ~loc id =
  let needs_conditioning =
    List.exists Utils.conditioning_suffices ~f:(fun suffix ->
        String.is_suffix id.name ~suffix )
  in
  if needs_conditioning then
    Validate.error (Semantic_error.conditioning_required loc)
  else Validate.ok ()
(** Functions whose name ends in [_lp] may only be called from the model
    block or from inside another [_lp] function definition. *)
let semantic_check_fn_target_plus_equals cf ~loc id =
  let allowed_here = cf.in_lp_fun_def || cf.current_block = Model in
  if String.is_suffix id.name ~suffix:"_lp" && not allowed_here then
    Validate.error
      (Semantic_error.target_plusequals_outisde_model_or_logprob loc)
  else Validate.ok ()
(** [_rng] functions may not be called in a top-level declaration, in the
    [TParam] or [Model] blocks, or inside a function definition that is not
    itself an [_rng] function. *)
let semantic_check_fn_rng cf ~loc id =
  if not (String.is_suffix id.name ~suffix:"_rng") then Validate.ok ()
  else if cf.in_toplevel_decl then
    Validate.error (Semantic_error.invalid_decl_rng_fn loc)
  else if
    (cf.in_fun_def && not cf.in_rng_fun_def)
    || cf.current_block = TParam
    || cf.current_block = Model
  then Validate.error (Semantic_error.invalid_rng_fn loc)
  else Validate.ok ()
(** Unnormalized distribution functions may only be called in the model
    block, or inside a user-defined distribution or [_lp] function
    definition. *)
let semantic_check_unnormalized cf ~loc id =
  let permitted_here =
    cf.current_block = Model
    || (cf.in_fun_def && (cf.in_udf_dist_def || cf.in_lp_fun_def))
  in
  if Utils.is_unnormalized_distribution id.name && not permitted_here then
    Validate.error (Semantic_error.invalid_unnormalized_fn loc)
  else Validate.ok ()
(** Builds a conditional-distribution application when [is_cond_dist] is set,
    otherwise a plain function application, from one (kind, name, args)
    triple. *)
let mk_fun_app ~is_cond_dist (kind, name, args) =
  match is_cond_dist with
  | true -> CondDistApp (kind, name, args)
  | false -> FunApp (kind, name, args)
(* Regular function application *)

(** Types an application of a user-defined function [id] to the typed
    arguments [es]. The signature is looked up in the symbol table under the
    function's normalized name. Void functions, mismatched arguments,
    non-function bindings and undeclared names are each reported with their
    own error. *)
let semantic_check_fn_normal ~is_cond_dist ~loc id es =
  let fname = id.name in
  match Symbol_table.look vm (Utils.normalized_name fname) with
  | None ->
      Validate.error
        (Semantic_error.returning_fn_expected_undeclaredident_found loc fname)
  | Some (_, UnsizedType.UFun (_, Void)) ->
      Validate.error
        (Semantic_error.returning_fn_expected_nonreturning_found loc fname)
  | Some (_, UFun (listedtypes, rt))
    when not
           (UnsizedType.check_compatible_arguments_mod_conv fname listedtypes
              (get_arg_types es)) ->
      let found = List.map es ~f:type_of_expr_typed in
      Validate.error
        (Semantic_error.illtyped_userdefined_fn_app loc fname listedtypes rt
           found)
  | Some (_, UFun (_, ReturnType ut)) ->
      Validate.ok
        (mk_typed_expression
           ~expr:(mk_fun_app ~is_cond_dist (UserDefined, id, es))
           ~ad_level:(expr_ad_lub es) ~type_:ut ~loc)
  | Some _ ->
      (* The name is bound, but not to a function. *)
      Validate.error (Semantic_error.returning_fn_expected_nonfn_found loc fname)
(* Stan-Math function application *)

(** Types an application of a Stan Math function by looking up its return
    type for the argument types of [es]. A void result is reported as a
    non-returning function used as an expression. A missing signature is
    reported as an ill-typed library call. *)
let semantic_check_fn_stan_math ~is_cond_dist ~loc id es =
  let arg_types = get_arg_types es in
  match Stan_math_signatures.stan_math_returntype id.name arg_types with
  | Some (UnsizedType.ReturnType ut) ->
      Validate.ok
        (mk_typed_expression
           ~expr:(mk_fun_app ~is_cond_dist (StanLib, id, es))
           ~ad_level:(expr_ad_lub es) ~type_:ut ~loc)
  | Some Void ->
      Validate.error
        (Semantic_error.returning_fn_expected_nonreturning_found loc id.name)
  | None ->
      Validate.error
        (Semantic_error.illtyped_stanlib_fn_app loc id.name
           (List.map es ~f:type_of_expr_typed))
(** [arg_match (ad, t) e] holds when the typed expression [e] may be passed
    for a parameter of autodiff level [ad] and type [t], allowing the usual
    conversions. *)
let arg_match (param_ad, param_type) arg =
  let types_ok =
    UnsizedType.check_of_same_type_mod_conv "" param_type arg.emeta.type_
  in
  types_ok && UnsizedType.autodifftype_can_convert param_ad arg.emeta.ad_level

(** Pairwise [arg_match] over two lists of equal length. Lists of different
    lengths never match. *)
let args_match params args =
  if List.length params <> List.length args then false
  else List.for_all2_exn params args ~f:arg_match
(** Types a call to a reduce_sum-style function.

    Expected call shape: [f, sliced, n, args...], where
    - [f] is a function value returning [real] whose parameters are
      [(slice_t, int, int, fun_args...)];
    - [sliced] can be passed as [f]'s first parameter ([arg_match]), and both
      its type and [slice_t] are in
      [Stan_math_signatures.reduce_sum_slice_types];
    - [n] has type [int] (presumably the grain size; confirm);
    - the trailing [args] must match [fun_args] ([args_match]).
    The result has type [real]. If only the trailing [args] fail to match,
    [illtyped_reduce_sum] reports [f]'s expected parameters. Any other shape
    yields [illtyped_reduce_sum_generic]. *)
let semantic_check_reduce_sum ~is_cond_dist ~loc id es =
  match es with
  (* f : (slice_t, int, int, fun_args...) => real, then sliced, then an int *)
  | { emeta=
        { type_=
            UnsizedType.UFun
              ( ((_, sliced_arg_fun_type) as sliced_arg_fun)
                :: (_, UInt) :: (_, UInt) :: fun_args
              , ReturnType UReal ); _ }; _ }
    :: sliced :: {emeta= {type_= UInt; _}; _} :: args
    when arg_match sliced_arg_fun sliced
         && List.mem Stan_math_signatures.reduce_sum_slice_types
              sliced.emeta.type_ ~equal:( = )
         && List.mem Stan_math_signatures.reduce_sum_slice_types
              sliced_arg_fun_type ~equal:( = ) ->
      (* The fixed part of the signature matched. Only the forwarded
         arguments remain to be checked. *)
      if args_match fun_args args then
        mk_typed_expression
          ~expr:(mk_fun_app ~is_cond_dist (StanLib, id, es))
          ~ad_level:(expr_ad_lub es) ~type_:UnsizedType.UReal ~loc
        |> Validate.ok
      else
        Semantic_error.illtyped_reduce_sum loc id.name
          (List.map ~f:type_of_expr_typed es)
          (sliced_arg_fun :: fun_args)
        |> Validate.error
  | _ ->
      es
      |> List.map ~f:type_of_expr_typed
      |> Semantic_error.illtyped_reduce_sum_generic loc id.name
      |> Validate.error
(** Types a call to one of the variadic ODE solvers.

    Expected call shape: [f, mandatory args..., variadic args...], where [f]
    is a function value.
    - [f]'s leading parameters must match
      [Stan_math_signatures.variadic_ode_mandatory_fun_args], and its return
      type must match [Stan_math_signatures.variadic_ode_fun_return_type].
    - The mandatory call arguments must match
      [Stan_math_signatures.variadic_ode_mandatory_arg_types], followed by
      [Stan_math_signatures.variadic_ode_tol_arg_types] for the
      tolerance-taking variants ([is_variadic_ode_tol_fn]).
    - The remaining call arguments must match [f]'s remaining parameters.
    If only the last check fails, [f]'s parameters are reported. Any other
    mismatch yields a generic [illtyped_variadic_ode] error. The result type
    is [Stan_math_signatures.variadic_ode_return_type]. *)
let semantic_check_variadic_ode ~is_cond_dist ~loc id es =
  let mandatory_arg_types =
    if Stan_math_signatures.is_variadic_ode_tol_fn id.name then
      Stan_math_signatures.variadic_ode_mandatory_arg_types
      @ Stan_math_signatures.variadic_ode_tol_arg_types
    else Stan_math_signatures.variadic_ode_mandatory_arg_types
  in
  (* The split points are derived from the signature lists they are matched
     against, rather than hard-coded (previously 3 or 6 call arguments and 2
     function parameters), so they cannot drift out of sync. *)
  let num_mandatory_args = List.length mandatory_arg_types in
  let num_mandatory_fun_args =
    List.length Stan_math_signatures.variadic_ode_mandatory_fun_args
  in
  let arg_types = List.map ~f:type_of_expr_typed es in
  let generic_variadic_ode_semantic_error =
    Semantic_error.illtyped_variadic_ode loc id.name arg_types []
    |> Validate.error
  in
  let fun_arg_match (x_ad, x_t) (y_ad, y_t) =
    UnsizedType.check_of_same_type_mod_conv "" x_t y_t
    && UnsizedType.autodifftype_can_convert x_ad y_ad
  in
  let fun_args_match a b =
    List.length a = List.length b && List.for_all2_exn ~f:fun_arg_match a b
  in
  match es with
  | {emeta= {type_= UnsizedType.UFun (fun_args, ReturnType return_type); _}; _}
    :: args ->
      let mandatory_args, variadic_args =
        List.split_n args num_mandatory_args
      in
      let mandatory_fun_args, variadic_fun_args =
        List.split_n fun_args num_mandatory_fun_args
      in
      if
        fun_args_match mandatory_fun_args
          Stan_math_signatures.variadic_ode_mandatory_fun_args
        && UnsizedType.check_of_same_type_mod_conv "" return_type
             Stan_math_signatures.variadic_ode_fun_return_type
        && args_match mandatory_arg_types mandatory_args
      then
        if args_match variadic_fun_args variadic_args then
          mk_typed_expression
            ~expr:(mk_fun_app ~is_cond_dist (StanLib, id, es))
            ~ad_level:(expr_ad_lub es)
            ~type_:Stan_math_signatures.variadic_ode_return_type ~loc
          |> Validate.ok
        else
          Semantic_error.illtyped_variadic_ode loc id.name arg_types fun_args
          |> Validate.error
      else generic_variadic_ode_semantic_error
  | _ -> generic_variadic_ode_semantic_error
(** Classifies an application of [id] to [es] as [StanLib] or [UserDefined].
    It is a library call when Stan Math has a signature accepting these
    argument types, or when the name is a Stan Math function and nothing of
    that name is bound in the symbol table. *)
let fn_kind_from_application id es =
  (* The whole application is inspected, not just the name, because
     user-defined functions can technically shadow constants in StanLib. *)
  let has_matching_stan_math_signature =
    Option.is_some
      (Stan_math_signatures.stan_math_returntype id.name (get_arg_types es))
  in
  let unshadowed_stan_math_name () =
    Option.is_none (Symbol_table.look vm id.name)
    && Stan_math_signatures.is_stan_math_function_name id.name
  in
  if has_matching_stan_math_signature || unshadowed_stan_math_name () then
    StanLib
  else UserDefined
(** Determines the function kind of the application [id(es)] and dispatches
    to the matching checker: reduce_sum, variadic ODE or generic Stan Math
    for library calls, and [semantic_check_fn_normal] for user-defined
    calls. *)
let semantic_check_fn ~is_cond_dist ~loc id es =
  match fn_kind_from_application id es with
  | UserDefined -> semantic_check_fn_normal ~is_cond_dist ~loc id es
  | StanLib ->
      if Stan_math_signatures.is_reduce_sum_fn id.name then
        semantic_check_reduce_sum ~is_cond_dist ~loc id es
      else if Stan_math_signatures.is_variadic_ode_fn id.name then
        semantic_check_variadic_ode ~is_cond_dist ~loc id es
      else semantic_check_fn_stan_math ~is_cond_dist ~loc id es
(* -- Ternary If ------------------------------------------------------------ *)

(** Types [pe ? te : fe]. The condition must have type [int] and the
    branches must have a common type ([UnsizedType.common_type]), which
    becomes the result type. *)
let semantic_check_ternary_if loc (pe, te, fe) =
  let result_type =
    match pe.emeta.type_ with
    | UInt -> UnsizedType.common_type (te.emeta.type_, fe.emeta.type_)
    | _ -> None
  in
  match result_type with
  | Some type_ ->
      Validate.ok
        (mk_typed_expression
           ~expr:(TernaryIf (pe, te, fe))
           ~ad_level:(expr_ad_lub [pe; te; fe])
           ~type_ ~loc)
  | None ->
      Validate.error
        (Semantic_error.illtyped_ternary_if loc pe.emeta.type_ te.emeta.type_
           fe.emeta.type_)
(* -- Binary (Infix) Operators ---------------------------------------------- *)

(** Types [le op re] using the Stan Math operator signatures. A missing or
    void signature is reported as an ill-typed binary operation. *)
let semantic_check_binop loc op (le, re) =
  match
    Stan_math_signatures.operator_stan_math_return_type op
      (List.map ~f:arg_type [le; re])
  with
  | Some (ReturnType type_) ->
      Validate.ok
        (mk_typed_expression
           ~expr:(BinOp (le, op, re))
           ~ad_level:(expr_ad_lub [le; re])
           ~type_ ~loc)
  | Some Void | None ->
      Validate.error
        (Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_)
(** Extracts the value of a successful validation. On failure it raises
    [Failure] whose message is every accumulated semantic error,
    pretty-printed and joined with [Fmt.cut]. *)
let to_exn v =
  let render = Fmt.(to_to_string @@ list ~sep:cut Semantic_error.pp) in
  match Validate.to_result v with
  | Ok x -> x
  | Error errs -> failwith (render errs)

(** [semantic_check_binop] that raises (see [to_exn]) instead of returning
    errors. *)
let semantic_check_binop_exn loc op (le, re) =
  to_exn (semantic_check_binop loc op (le, re))
(* -- Prefix Operators ------------------------------------------------------ *)

(** Types the prefix application [op e] from the operator's Stan Math
    signature. A missing or void signature is an ill-typed prefix
    operation. *)
let semantic_check_prefixop loc op e =
  let result_type =
    match
      Stan_math_signatures.operator_stan_math_return_type op [arg_type e]
    with
    | Some (UnsizedType.ReturnType ut) -> Some ut
    | Some Void | None -> None
  in
  match result_type with
  | None ->
      Semantic_error.illtyped_prefix_op loc op e.emeta.type_ |> Validate.error
  | Some type_ ->
      mk_typed_expression ~expr:(PrefixOp (op, e))
        ~ad_level:(expr_ad_lub [e]) ~type_ ~loc
      |> Validate.ok

(* -- Postfix operators ----------------------------------------------------- *)

(** Types the postfix application [e op] from the operator's Stan Math
    signature. A missing or void signature is an ill-typed postfix
    operation. *)
let semantic_check_postfixop loc op e =
  let result_type =
    match
      Stan_math_signatures.operator_stan_math_return_type op [arg_type e]
    with
    | Some (UnsizedType.ReturnType ut) -> Some ut
    | Some Void | None -> None
  in
  match result_type with
  | None ->
      Semantic_error.illtyped_postfix_op loc op e.emeta.type_ |> Validate.error
  | Some type_ ->
      mk_typed_expression ~expr:(PostfixOp (e, op))
        ~ad_level:(expr_ad_lub [e]) ~type_ ~loc
      |> Validate.ok
(* -- Variables ------------------------------------------------------------- *)

(** Types a reference to the variable [id] in context [cf].

    The name is looked up under [Utils.stdlib_distribution_name]. If it is
    unbound but names a Stan Math function, it becomes a
    [UMathLibraryFunction] value; otherwise it is out of scope. A bound name
    is rejected in two cases: it comes from [Param], [TParam] or [GQuant]
    inside a top-level declaration, or it is an unnormalized distribution
    used outside the model block or a distribution/[_lp] function
    definition. *)
let semantic_check_variable cf loc id =
  let unnormalized_allowed =
    (cf.in_fun_def && (cf.in_udf_dist_def || cf.in_lp_fun_def))
    || cf.current_block = Model
  in
  let typed_variable ad_level type_ =
    mk_typed_expression ~expr:(Variable id) ~ad_level ~type_ ~loc
    |> Validate.ok
  in
  match Symbol_table.look vm (Utils.stdlib_distribution_name id.name) with
  | None ->
      if Stan_math_signatures.is_stan_math_function_name id.name then
        typed_variable
          (calculate_autodifftype cf MathLibrary UMathLibraryFunction)
          UMathLibraryFunction
      else Semantic_error.ident_not_in_scope loc id.name |> Validate.error
  | Some ((Param | TParam | GQuant), _) when cf.in_toplevel_decl ->
      Semantic_error.non_data_variable_size_decl loc |> Validate.error
  | Some (originblock, type_) ->
      if Utils.is_unnormalized_distribution id.name && not unnormalized_allowed
      then Semantic_error.invalid_unnormalized_fn loc |> Validate.error
      else typed_variable (calculate_autodifftype cf originblock type_) type_
(* -- Conditioned Distribution Application ---------------------------------- *)

(** Conditional-distribution notation is only allowed for functions whose
    name ends in one of [Utils.conditioning_suffices]. *)
let semantic_check_conddist_name ~loc id =
  let has_conditioning_suffix =
    List.exists Utils.conditioning_suffices ~f:(fun suffix ->
        String.is_suffix id.name ~suffix )
  in
  if has_conditioning_suffix then Validate.ok ()
  else Validate.error (Semantic_error.conditional_notation_not_allowed loc)
(* -- Array Expressions ----------------------------------------------------- *)

(** Folds the elements [es] into a running (autodiff level, type), starting
    from [(ad_level, type_)]. An element's level replaces the running level
    whenever [UnsizedType.autodifftype_can_convert element_level
    running_level] holds. The running type is widened with
    [UnsizedType.common_type]. The first element with no common type stops
    the fold with [Error (running_type, element_meta)]. *)
let check_consistent_types ad_level type_ es =
  let step acc e =
    match acc with
    | Error _ -> acc
    | Ok (running_ad, running_ty) -> (
        let e_ad = e.emeta.ad_level in
        let next_ad =
          if UnsizedType.autodifftype_can_convert e_ad running_ad then e_ad
          else running_ad
        in
        match UnsizedType.common_type (running_ty, e.emeta.type_) with
        | None -> Error (running_ty, e.emeta)
        | Some next_ty -> Ok (next_ad, next_ty) )
  in
  List.fold es ~init:(Ok (ad_level, type_)) ~f:step
(** Types an array expression. It must be non-empty and all elements must
    share a common type. The result is an array of that type at the joined
    autodiff level ([check_consistent_types]). *)
let semantic_check_array_expr ~loc es =
  match es with
  | [] -> Validate.error (Semantic_error.empty_array loc)
  | first :: rest -> (
      let joined =
        check_consistent_types first.emeta.ad_level first.emeta.type_ rest
      in
      match joined with
      | Error (expected, meta) ->
          Validate.error
            (Semantic_error.mismatched_array_types meta.loc expected meta.type_)
      | Ok (ad_level, elt_type) ->
          Validate.ok
            (mk_typed_expression ~expr:(ArrayExpr es) ~ad_level
               ~type_:(UnsizedType.UArray elt_type) ~loc) )
(* -- Row Vector Expression ------------------------------------------------- *)

(** Types a row-vector literal. If the first element is a row vector, every
    element must have a common type with row vector and the result is a
    matrix. Otherwise every element must have a common type with [real] and
    the result is a row vector; an empty literal is a data-only row
    vector. *)
let semantic_check_rowvector ~loc es =
  let typed type_ ad_level =
    mk_typed_expression ~expr:(RowVectorExpr es) ~ad_level ~type_ ~loc
    |> Validate.ok
  in
  match es with
  | {emeta= {ad_level; type_= UnsizedType.URowVector; _}; _} :: rest -> (
    match check_consistent_types ad_level URowVector rest with
    | Error (_, meta) ->
        Semantic_error.invalid_matrix_types meta.loc meta.type_
        |> Validate.error
    | Ok (joined_ad, _) -> typed UnsizedType.UMatrix joined_ad )
  | _ -> (
    match check_consistent_types DataOnly UReal es with
    | Error (_, meta) ->
        Semantic_error.invalid_row_vector_types meta.loc meta.type_
        |> Validate.error
    | Ok (joined_ad, _) -> typed UnsizedType.URowVector joined_ad )
(* -- Indexed Expressions --------------------------------------------------- *)

(* Tuple builders for use with applicative lifting ([liftA2]/[liftA3]). *)
let tuple2 x y = (x, y)

let tuple3 x y z = (x, y, z)

(** Pairs an index with the type used when typing the indexed expression:
    the index expression's own type for [Single], and [int] for every other
    index kind. *)
let index_with_type idx =
  let ty =
    match idx with Single e -> e.emeta.type_ | _ -> UnsizedType.UInt
  in
  (idx, ty)
(** Infers the unsized type of [ut] indexed by [indices].

    Indices are consumed left to right. A single [int] index removes one
    dimension: an array yields its element type, a vector or row vector
    yields [real], and a matrix yields a row vector. Any other index ([All],
    a range, or a single [int[]] index) keeps the dimension. For an array
    the result is re-wrapped in [UArray] through the continuation [k];
    vectors, row vectors and matrices keep their type. A matrix indexed by
    exactly [multi-index, int] yields a (column) vector. Indexing any other
    type reports [not_indexable] at [loc].

    NOTE(review): after a multi-index on a vector or matrix, the remaining
    indices are applied to the same type. For example, a matrix indexed by
    [:, int, int] is typed as [real] instead of being rejected. Confirm
    whether excess indices are caught elsewhere. *)
let inferred_unsizedtype_of_indexed ~loc ut indices =
  (* [k] is applied to the final result. Each multi-indexed array dimension
     composes one more [UArray] wrapper onto it. *)
  let rec aux k ut xs =
    match (ut, xs) with
    (* A matrix indexed by [multi-index, int] yields a vector. *)
    | UnsizedType.UMatrix, [(All, _); (Single _, UnsizedType.UInt)]
     |UMatrix, [(Upfrom _, _); (Single _, UInt)]
     |UMatrix, [(Downfrom _, _); (Single _, UInt)]
     |UMatrix, [(Between _, _); (Single _, UInt)]
     |UMatrix, [(Single _, UArray UInt); (Single _, UInt)] ->
        k @@ Validate.ok UnsizedType.UVector
    | _, [] -> k @@ Validate.ok ut
    | _, next :: rest -> (
      match next with
      (* A single int index removes one dimension. *)
      | Single _, UInt -> (
        match ut with
        | UArray inner_ty -> aux k inner_ty rest
        | UVector | URowVector -> aux k UReal rest
        | UMatrix -> aux k URowVector rest
        | _ -> Semantic_error.not_indexable loc ut |> Validate.error )
      (* Any multi-index keeps the dimension. *)
      | _ -> (
        match ut with
        | UArray inner_ty ->
            let k' =
              Fn.compose k (Validate.map ~f:(fun t -> UnsizedType.UArray t))
            in
            aux k' inner_ty rest
        | UVector | URowVector | UMatrix -> aux k ut rest
        | _ -> Semantic_error.not_indexable loc ut |> Validate.error ) )
  in
  aux Fn.id ut (List.map ~f:index_with_type indices)
(** Raising variant of [inferred_unsizedtype_of_indexed]; errors are turned
    into [Failure] by [to_exn]. *)
let inferred_unsizedtype_of_indexed_exn ~loc ut indices =
  to_exn (inferred_unsizedtype_of_indexed ~loc ut indices)
(** Autodiff level of an indexed expression: the least upper bound
    ([UnsizedType.lub_ad_type]) of the indexed expression's level [at] and
    the levels of every expression used in [uindices]. [All] contributes
    [DataOnly]. *)
let inferred_ad_type_of_indexed at uindices =
  let index_level = function
    | All -> UnsizedType.DataOnly
    | Single e | Upfrom e | Downfrom e ->
        UnsizedType.lub_ad_type [at; e.emeta.ad_level]
    | Between (lo, hi) ->
        UnsizedType.lub_ad_type [at; lo.emeta.ad_level; hi.emeta.ad_level]
  in
  UnsizedType.lub_ad_type (at :: List.map uindices ~f:index_level)
(** Types the indexed expression [e[indices]]. The indices and [e] are
    checked independently, so errors from both are reported. The result type
    comes from [inferred_unsizedtype_of_indexed] and the autodiff level from
    [inferred_ad_type_of_indexed]. *)
let rec semantic_check_indexed ~loc ~cf e indices =
  Validate.(
    indices
    |> List.map ~f:(semantic_check_index cf)
    |> sequence
    |> liftA2 tuple2 (semantic_check_expression cf e)
    >>= fun (ue, uindices) ->
    let at = inferred_ad_type_of_indexed ue.emeta.ad_level uindices in
    uindices
    |> inferred_unsizedtype_of_indexed ~loc ue.emeta.type_
    |> map ~f:(fun ut ->
           mk_typed_expression
             ~expr:(Indexed (ue, uindices))
             ~ad_level:at ~type_:ut ~loc ))

(** Checks a single index. [All] is always valid, a single index must have
    type [int] or [int[]], and range bounds must have type [int]. *)
and semantic_check_index cf = function
  | All -> Validate.ok All
  (* Check that indexes have int (container) type *)
  | Single e ->
      Validate.(
        semantic_check_expression cf e
        >>= fun ue ->
        if has_int_type ue || has_int_array_type ue then ok @@ Single ue
        else
          Semantic_error.int_intarray_or_range_expected ue.emeta.loc
            ue.emeta.type_
          |> error)
  | Upfrom e ->
      semantic_check_expression_of_int_type cf e "Range bound"
      |> Validate.map ~f:(fun e -> Upfrom e)
  | Downfrom e ->
      semantic_check_expression_of_int_type cf e "Range bound"
      |> Validate.map ~f:(fun e -> Downfrom e)
  | Between (e1, e2) ->
      let le = semantic_check_expression_of_int_type cf e1 "Range bound"
      and ue = semantic_check_expression_of_int_type cf e2 "Range bound" in
      Validate.liftA2 (fun l u -> Between (l, u)) le ue

(* -- Top-level expressions ------------------------------------------------- *)

(** Type-checks the untyped expression [{emeta; expr}] in context [cf].
    Returns the typed expression or the accumulated semantic errors.
    Sub-expressions are checked first, then the node is typed by the helper
    for its kind. *)
and semantic_check_expression cf ({emeta; expr} : Ast.untyped_expression) :
    Ast.typed_expression Validate.t =
  (* [GetLP] and [GetTarget] both read the target. That is only allowed in an
     [_lp] function definition, the model block, or the transformed
     parameters block, and the result is a [real]. The two cases share this
     helper so they cannot drift apart. *)
  let check_target_access typed_expr =
    if
      not
        ( cf.in_lp_fun_def || cf.current_block = Model
        || cf.current_block = TParam )
    then
      Semantic_error.target_plusequals_outisde_model_or_logprob emeta.loc
      |> Validate.error
    else
      mk_typed_expression ~expr:typed_expr
        ~ad_level:(calculate_autodifftype cf cf.current_block UReal)
        ~type_:UReal ~loc:emeta.loc
      |> Validate.ok
  in
  match expr with
  | TernaryIf (e1, e2, e3) ->
      let pe = semantic_check_expression cf e1
      and te = semantic_check_expression cf e2
      and fe = semantic_check_expression cf e3 in
      Validate.(liftA3 tuple3 pe te fe >>= semantic_check_ternary_if emeta.loc)
  | BinOp (e1, op, e2) ->
      let le = semantic_check_expression cf e1
      and re = semantic_check_expression cf e2
      (* Integer division is accepted, but an informational message with a
         suggested rewrite is printed to stdout. *)
      and warn_int_division (x, y) =
        match (x.emeta.type_, y.emeta.type_, op) with
        | UInt, UInt, Divide ->
            let hint ppf () =
              match (x.expr, y.expr) with
              | IntNumeral x, _ ->
                  Fmt.pf ppf "%s.0 / %a" x Pretty_printing.pp_expression y
              | _, Ast.IntNumeral y ->
                  Fmt.pf ppf "%a / %s.0" Pretty_printing.pp_expression x y
              | _ ->
                  Fmt.pf ppf "%a * 1.0 / %a" Pretty_printing.pp_expression x
                    Pretty_printing.pp_expression y
            in
            (* The trailing "@." ends the message with a newline and flushes
               stdout, so consecutive messages do not run together. *)
            Fmt.pr
              "@[<v>@[<hov 0>Info: Found int division at %s:@]@ @[<hov \
               2>%a@]@,@[<hov>%a@]@ @[<hov 2>%a@]@,@[<hov>%a@]@]@."
              (Location_span.to_string x.emeta.loc)
              Pretty_printing.pp_expression {expr; emeta} Fmt.text
              "Values will be rounded towards zero. If rounding is not \
               desired you can write the division as"
              hint () Fmt.text
              "If rounding is intended please use the integer division \
               operator %/%." ;
            (x, y)
        | _ -> (x, y)
      in
      Validate.(
        liftA2 tuple2 le re |> map ~f:warn_int_division
        |> apply_const (semantic_check_operator op)
        >>= semantic_check_binop emeta.loc op)
  | PrefixOp (op, e) ->
      Validate.(
        semantic_check_expression cf e
        |> apply_const (semantic_check_operator op)
        >>= semantic_check_prefixop emeta.loc op)
  | PostfixOp (e, op) ->
      Validate.(
        semantic_check_expression cf e
        |> apply_const (semantic_check_operator op)
        >>= semantic_check_postfixop emeta.loc op)
  | Variable id ->
      semantic_check_variable cf emeta.loc id
      |> Validate.apply_const (semantic_check_identifier id)
  | IntNumeral s -> (
    (* Integer literals must be below 2^31; the comparison is done on the
       literal parsed as a float. *)
    match float_of_string_opt s with
    | Some i when i < 2_147_483_648.0 ->
        mk_typed_expression ~expr:(IntNumeral s) ~ad_level:DataOnly ~type_:UInt
          ~loc:emeta.loc
        |> Validate.ok
    | _ -> Semantic_error.bad_int_literal emeta.loc |> Validate.error )
  | RealNumeral s ->
      mk_typed_expression ~expr:(RealNumeral s) ~ad_level:DataOnly ~type_:UReal
        ~loc:emeta.loc
      |> Validate.ok
  | FunApp (_, id, es) ->
      semantic_check_funapp ~is_cond_dist:false id es cf emeta
  | CondDistApp (_, id, es) ->
      semantic_check_funapp ~is_cond_dist:true id es cf emeta
  | GetLP -> check_target_access GetLP
  | GetTarget -> check_target_access GetTarget
  | ArrayExpr es ->
      Validate.(
        es
        |> List.map ~f:(semantic_check_expression cf)
        |> sequence
        >>= fun ues -> semantic_check_array_expr ~loc:emeta.loc ues)
  | RowVectorExpr es ->
      Validate.(
        es
        |> List.map ~f:(semantic_check_expression cf)
        |> sequence
        >>= semantic_check_rowvector ~loc:emeta.loc)
  | Paren e ->
      semantic_check_expression cf e
      |> Validate.map ~f:(fun ue ->
             mk_typed_expression ~expr:(Paren ue) ~ad_level:ue.emeta.ad_level
               ~type_:ue.emeta.type_ ~loc:emeta.loc )
  | Indexed (e, indices) -> semantic_check_indexed ~loc:emeta.loc ~cf e indices

(** Types a function application (plain when [is_cond_dist] is false,
    conditional-distribution notation otherwise). The arguments are checked
    first. The typed application from [semantic_check_fn] is then combined
    with the identifier, [map_rect], naming, [_lp], [_rng] and unnormalized
    checks, accumulating all errors. *)
and semantic_check_funapp ~is_cond_dist id es cf emeta =
  let name_check =
    if is_cond_dist then semantic_check_conddist_name
    else semantic_check_fn_conditioning
  in
  Validate.(
    es
    |> List.map ~f:(semantic_check_expression cf)
    |> sequence
    >>= fun ues ->
    semantic_check_fn ~is_cond_dist ~loc:emeta.loc id ues
    |> apply_const (semantic_check_identifier id)
    |> apply_const (semantic_check_fn_map_rect ~loc:emeta.loc id ues)
    |> apply_const (name_check ~loc:emeta.loc id)
    |> apply_const (semantic_check_fn_target_plus_equals cf ~loc:emeta.loc id)
    |> apply_const (semantic_check_fn_rng cf ~loc:emeta.loc id)
    |> apply_const (semantic_check_unnormalized cf ~loc:emeta.loc id))

(** Checks [e] and requires it to have type [int]. [name] describes the
    expression in the error message. *)
and semantic_check_expression_of_int_type cf e name =
  Validate.(
    semantic_check_expression cf e
    >>= fun ue ->
    if has_int_type ue then ok ue
    else Semantic_error.int_expected ue.emeta.loc name ue.emeta.type_ |> error)

(** Checks [e] and requires it to have type [int] or [real]. [name]
    describes the expression in the error message. *)
and semantic_check_expression_of_int_or_real_type cf e name =
  Validate.(
    semantic_check_expression cf e
    >>= fun ue ->
    if has_int_or_real_type ue then ok ue
    else
      Semantic_error.int_or_real_expected ue.emeta.loc name ue.emeta.type_
      |> error)
(** Checks [e] and requires its type to be a scalar or exactly [t]. [name]
    describes the expression in the error message. *)
let semantic_check_expression_of_scalar_or_type cf t e name =
  let accept ue =
    let ty = ue.emeta.type_ in
    if UnsizedType.is_scalar_type ty || ty = t then Validate.ok ue
    else
      Validate.error
        (Semantic_error.scalar_or_type_expected ue.emeta.loc name t ty)
  in
  Validate.(semantic_check_expression cf e >>= accept)
(* -- Sized Types ----------------------------------------------------------- *)
(* Type-check the size expressions inside a sized type and rebuild it with
   typed expressions. Every size expression must have an int type; the label
   passed to [check_size] names the size in the error message. Array element
   types are checked recursively, and two-part results are combined with
   [Validate.liftA2]. *)
let rec semantic_check_sizedtype cf sized_type =
  let check_size label size_expr =
    semantic_check_expression_of_int_type cf size_expr label
  in
  match sized_type with
  | SizedType.SInt -> Validate.ok SizedType.SInt
  | SReal -> Validate.ok SizedType.SReal
  | SVector e ->
      Validate.map (check_size "Vector sizes" e) ~f:(fun ue ->
          SizedType.SVector ue )
  | SRowVector e ->
      Validate.map (check_size "Row vector sizes" e) ~f:(fun ue ->
          SizedType.SRowVector ue )
  | SMatrix (rows, cols) ->
      let typed_rows = check_size "Matrix sizes" rows in
      let typed_cols = check_size "Matrix sizes" cols in
      Validate.liftA2
        (fun r c -> SizedType.SMatrix (r, c))
        typed_rows typed_cols
  | SArray (elem, len) ->
      let typed_elem = semantic_check_sizedtype cf elem in
      let typed_len = check_size "Array sizes" len in
      Validate.liftA2
        (fun st ue -> SizedType.SArray (st, ue))
        typed_elem typed_len
(* -- Transformations ------------------------------------------------------- *)
(* Type-check the expressions carried by a variable transformation
   (bounds, offset, multiplier). Each expression must be a scalar or have
   exactly the declared type [ut]. Transformations that carry no expression
   are rebuilt as-is. *)
let semantic_check_transformation cf ut transformation =
  let check label e =
    semantic_check_expression_of_scalar_or_type cf ut e label
  in
  match transformation with
  | Program.Identity -> Validate.ok Program.Identity
  | Lower e ->
      Validate.map (check "Lower bound" e) ~f:(fun ue -> Program.Lower ue)
  | Upper e ->
      Validate.map (check "Upper bound" e) ~f:(fun ue -> Program.Upper ue)
  | LowerUpper (lo, hi) ->
      let typed_lo = check "Lower bound" lo in
      let typed_hi = check "Upper bound" hi in
      Validate.liftA2
        (fun l h -> Program.LowerUpper (l, h))
        typed_lo typed_hi
  | Offset e ->
      Validate.map (check "Offset" e) ~f:(fun ue -> Program.Offset ue)
  | Multiplier e ->
      Validate.map (check "Multiplier" e) ~f:(fun ue -> Program.Multiplier ue)
  | OffsetMultiplier (off, mult) ->
      let typed_off = check "Offset" off in
      let typed_mult = check "Multiplier" mult in
      Validate.liftA2
        (fun o m -> Program.OffsetMultiplier (o, m))
        typed_off typed_mult
  (* Constraint-only transformations: nothing to check, but the constructor
     must be rebuilt because the expression type parameter changes. *)
  | Ordered -> Validate.ok Program.Ordered
  | PositiveOrdered -> Validate.ok Program.PositiveOrdered
  | Simplex -> Validate.ok Program.Simplex
  | UnitVector -> Validate.ok Program.UnitVector
  | CholeskyCorr -> Validate.ok Program.CholeskyCorr
  | CholeskyCov -> Validate.ok Program.CholeskyCov
  | Correlation -> Validate.ok Program.Correlation
  | Covariance -> Validate.ok Program.Covariance
(* -- Printables ------------------------------------------------------------ *)
(* Type-check one argument of a print/reject statement. String literals pass
   through unchanged; expressions are type-checked and rejected if their type
   is a function type ([UFun] or [UMathLibraryFunction]). *)
let semantic_check_printable cf printable =
  match printable with
  | PString s -> Validate.ok (PString s)
  | PExpr e ->
      (* Print/reject expressions cannot be of function type. *)
      let open Validate in
      semantic_check_expression cf e
      >>= fun ue ->
      ( match ue.emeta.type_ with
      | UFun _ | UMathLibraryFunction ->
          error (Semantic_error.not_printable ue.emeta.loc)
      | _ -> ok (PExpr ue) )
(* -- Truncations ----------------------------------------------------------- *)
(* Type-check the bound expressions of a truncation. Each bound must have an
   int or real type and is reported as "Truncation bound" on error. *)
let semantic_check_truncation cf truncation =
  let check_bound e =
    semantic_check_expression_of_int_or_real_type cf e "Truncation bound"
  in
  match truncation with
  | NoTruncate -> Validate.ok NoTruncate
  | TruncateUpFrom e ->
      Validate.map (check_bound e) ~f:(fun ue -> TruncateUpFrom ue)
  | TruncateDownFrom e ->
      Validate.map (check_bound e) ~f:(fun ue -> TruncateDownFrom ue)
  | TruncateBetween (lo, hi) ->
      let typed_lo = check_bound lo in
      let typed_hi = check_bound hi in
      Validate.liftA2
        (fun l h -> TruncateBetween (l, h))
        typed_lo typed_hi
(* == Statements ============================================================ *)
(* -- Non-returning function application ------------------------------------ *)
(* Reject a non-returning call to a function whose name ends in "_lp" unless
   we are inside an _lp function definition ([cf.in_lp_fun_def]) or in the
   model block. Any other call passes with [ok ()]. *)
let semantic_check_nrfn_target ~loc ~cf id =
  let is_lp_fn = String.is_suffix id.name ~suffix:"_lp" in
  let lp_allowed_here = cf.in_lp_fun_def || cf.current_block = Model in
  if is_lp_fn && not lp_allowed_here then
    Validate.error
      (Semantic_error.target_plusequals_outisde_model_or_logprob loc)
  else Validate.ok ()
(* Type-check a non-returning call [id(es)] to a user-defined function.
   The call is accepted only when [id] is bound in [vm] to a function with a
   [Void] return type whose parameter types are compatible (modulo
   conversion) with the types of [es]. Each failure mode (undeclared name,
   ill-typed arguments, returning function, non-function) produces its own
   semantic error. *)
let semantic_check_nrfn_normal ~loc id es =
  match Symbol_table.look vm id.name with
  | None ->
      Validate.error
        (Semantic_error.nonreturning_fn_expected_undeclaredident_found loc
           id.name)
  | Some (_, UFun (listedtypes, Void)) ->
      if
        UnsizedType.check_compatible_arguments_mod_conv id.name listedtypes
          (get_arg_types es)
      then
        Validate.ok
          (mk_typed_statement
             ~stmt:(NRFunApp (UserDefined, id, es))
             ~return_type:NoReturnType ~loc)
      else
        Validate.error
          (Semantic_error.illtyped_userdefined_fn_app loc id.name listedtypes
             Void
             (List.map es ~f:type_of_expr_typed))
  | Some (_, UFun (_, ReturnType _)) ->
      Validate.error
        (Semantic_error.nonreturning_fn_expected_returning_found loc id.name)
  | Some _ ->
      Validate.error
        (Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name)
let semantic_check_nrfn_stan_math ~loc id es =
Validate.(
match
Stan_math_signatures.stan_math_returntype id.name (get_arg_types es)
with
| Some UnsizedType.Void ->
mk_typed_statement
~stmt:(NRFunApp (StanLib, id, es))
~return_type:NoReturnType ~loc
|> ok
| Some (UnsizedType.ReturnType _) ->
Semantic_error.nonreturning_fn_expected_returning_found loc id.name
|> error
| None ->
es
|> List.map ~f:type_of_expr_typed
|> Semantic_error.illtyped_stanlib_fn_app loc id.name