forked from stan-dev/stanc3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Transform_Mir.ml
1195 lines (1152 loc) · 46.1 KB
/
Transform_Mir.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
open Core
open Core.Poly
open Middle
open Mangle
(** Mutable global flag, [false] until something sets it.
    NOTE(review): presumably switches on the OpenCL-specific rewriting below;
    no reader of the flag is visible in this part of the file -- confirm. *)
let use_opencl : bool ref = ref false
(** Rewrite expression [e] bottom-up, passing the name of every user-defined
    function application and of every variable through
    [add_prefix_to_kwrds]. All other nodes are returned unchanged. *)
let translate_funapps_and_kwrds e =
  let open Expr.Fixed in
  let rename_node node =
    match node.pattern with
    | FunApp (UserDefined (fname, suffix), args) ->
        { node with
          pattern= FunApp (UserDefined (add_prefix_to_kwrds fname, suffix), args)
        }
    | Var name -> {node with pattern= Var (add_prefix_to_kwrds name)}
    | _ -> node in
  rewrite_bottom_up ~f:rename_node e
(** Pass the identifiers introduced or targeted by statement [s] through
    [add_prefix_to_kwrds]: declared names, names of user-defined functions
    called as statements, assignment targets and loop variables.

    Only [For] bodies and the sub-statements of the remaining statement kinds
    are recursed into; expressions are passed through untouched ([Fn.id]). *)
let rec change_kwrds_stmts s =
  let open Stmt.Fixed.Pattern in
  let renamed =
    match s.Stmt.Fixed.pattern with
    | Decl ({decl_id; _} as decl) ->
        Decl {decl with decl_id= add_prefix_to_kwrds decl_id}
    | NRFunApp (UserDefined (fname, sfx), args) ->
        NRFunApp (UserDefined (add_prefix_to_kwrds fname, sfx), args)
    | Assignment (lhs, ty, rhs) ->
        let lhs = Stmt.Helpers.map_lhs_variable ~f:add_prefix_to_kwrds lhs in
        Assignment (lhs, ty, rhs)
    | For ({loopvar; body; _} as loop) ->
        For
          { loop with
            loopvar= add_prefix_to_kwrds loopvar
          ; body= change_kwrds_stmts body }
    | other -> map Fn.id change_kwrds_stmts other in
  {s with pattern= renamed}
(** The set of names of functions which return an Eigen block expression *)
let eigen_block_expr_fns =
  String.Set.of_list
    ["head"; "tail"; "segment"; "col"; "row"; "block"; "sub_row"; "sub_col"]
(** Eval indexed eigen types in UDF calls to prevent
    infinite template expansion if the call is recursive.

    Infinite expansion can happen only when the call graph is cyclic.
    The strategy here is to build the call graph one edge at a time
    and check if adding that edge creates a cycle. If it does, insert
    an [eval()] to stop template expansion.

    All relevant function calls are recorded transitively in [callgraph],
    meaning if [A] calls [B] and [B] calls [C] then [callgraph[A] = {B,C}].
    In the worst case every function calls every other, [callgraph] has
    size O(n^2) and this algorithm is O(n^3) so it's important to track
    only the function calls that really can propagate eigen templates.

    Functions are processed in the order in which they appear in
    [functions_block]; [callgraph] only contains what the functions processed
    so far have recorded.

    @param functions_block the function definitions to rewrite
    @return the definitions in the same order; a definition without any
    Eigen-typed argument is returned unchanged
*)
let break_eigen_cycles functions_block =
(* function name -> set of function names recorded as reachable from it *)
let callgraph = String.Table.create () in
(* Rewrite the body of [f]. [fun_args] holds the names of the Eigen-typed
arguments of [f]; [calls] collects the callees (and what they reach) that
receive an expression built from one of those arguments. *)
let eval_eigen_cycles fun_args calls (f : _ Program.fun_def) =
let open Expr.Fixed in
(* True for a variable named in [fun_args], an indexing of such an
expression, or a call to one of [eigen_block_expr_fns] whose first argument
is such an expression. *)
let rec is_potentially_recursive = function
| {pattern= Var name; _} -> Set.mem fun_args name
| {pattern= Indexed (e, _); _} -> is_potentially_recursive e
| {pattern= FunApp (StanLib (fname, _, _), e :: _); _} ->
Set.mem eigen_block_expr_fns fname && is_potentially_recursive e
| _ -> false in
(* Process the arguments of a call to [name]: the arguments are rewritten
recursively first, then either returned as they are or with every
potentially recursive one wrapped in a call to eval. *)
let rec map_args name args =
let args = List.map ~f:rewrite_expr args in
(* [can_recurse] becomes true as soon as one argument is potentially
recursive; [eval_args] is [args] with exactly those arguments wrapped *)
let can_recurse, eval_args =
List.fold_map ~init:false
~f:(fun is_rec e ->
if is_potentially_recursive e then
( true
, {e with pattern= FunApp (StanLib ("eval", FnPlain, AoS), [e])}
)
else (is_rec, e))
args in
if not can_recurse then args
(* direct recursion *)
else if name = f.fdname then eval_args
else
match Hashtbl.find callgraph name with
(* the callee is already known to reach the current function: this call
would close a cycle *)
| Some nested when Hash_set.mem nested f.fdname -> eval_args
| Some nested ->
(* [calls] records all functions reachable from the current function *)
Hash_set.add calls name;
Hash_set.iter nested ~f:(Hash_set.add calls);
args
| None ->
Hash_set.add calls name;
args
and rewrite_expr : Expr.Typed.t -> Expr.Typed.t = function
| {pattern= FunApp ((UserDefined (name, _) as kind), args); _} as e ->
{e with pattern= FunApp (kind, map_args name args)}
| { pattern=
FunApp
( (StanLib (_, _, _) as kind)
, ({pattern= Var name; meta= {type_= UFun _; _}} as f) :: args )
; _ } as e ->
(* higher-order function -- just pretend it's a direct call *)
{e with pattern= FunApp (kind, f :: map_args name args)}
| e -> {e with pattern= Pattern.map rewrite_expr e.pattern} in
(* Statement-level counterpart of [rewrite_expr]: calls to user-defined
functions used as statements go through [map_args] as well *)
let rec rewrite_stmt s =
let open Stmt.Fixed in
match s with
| {pattern= Pattern.NRFunApp ((UserDefined (name, _) as kind), args); _}
as s ->
{s with pattern= NRFunApp (kind, map_args name args)}
| s -> {s with pattern= Pattern.map rewrite_expr rewrite_stmt s.pattern}
in
Program.map_fun_def rewrite_stmt f in
(* Rewrite one function definition and fold the calls it makes into
[callgraph] *)
let break_cycles (Program.{fdname; fdargs; _} as fd) =
(* names of the arguments of Eigen type *)
let fun_args =
List.filter_map fdargs ~f:(fun (_, n, t) ->
if UnsizedType.is_eigen_type t then Some n else None)
|> String.Set.of_list in
if Set.is_empty fun_args then fd
else
let calls = String.Hash_set.create () in
let fndef = eval_eigen_cycles fun_args calls fd in
if not (Hash_set.is_empty calls) then (
(* update [callgraph] with the call paths going through the current function *)
Hashtbl.map_inplace callgraph ~f:(fun x ->
if Hash_set.mem x fdname then Hash_set.union calls x else x);
Hashtbl.update callgraph fdname
~f:(Option.value_map ~f:(Hash_set.union calls) ~default:calls));
fndef in
List.map ~f:break_cycles functions_block
(** Per-function conditions under which the arguments of a call are NOT moved
    to OpenCL, keyed by function name.

    Each value is a list of alternatives. An alternative is a list of
    [(argument index (0-based), autodiff level, type)] triples, and it applies
    when every triple matches the corresponding argument of the call (see
    [check_type] in [switch_expr_to_opencl]: the argument type must be equal
    to the given type and its autodiff level must be convertible to the given
    level). A call is left alone as soon as one alternative applies.

    @raise _ at module initialisation if a function name is listed twice
    ([String.Map.of_alist_exn]). *)
let opencl_trigger_restrictions =
  String.Map.of_alist_exn
    [ ( "bernoulli_lpmf"
      , [ [ (0, UnsizedType.DataOnly, UnsizedType.UArray UnsizedType.UInt)
          ; (1, UnsizedType.DataOnly, UnsizedType.UReal) ] ] )
    ; ( "bernoulli_logit_glm_lpmf"
      , [ (* Array of conditions under which we do not want to move to OpenCL *)
          [(1, UnsizedType.DataOnly, UnsizedType.URowVector)]
          (* Argument 1 (0-based indexing) is a row vector *) ] )
    ; ( "categorical_logit_glm_lpmf"
      , [[(1, UnsizedType.DataOnly, UnsizedType.URowVector)]] )
    ; ( "exponential_lpdf"
      , [ [ (0, UnsizedType.AutoDiffable, UnsizedType.UVector)
          ; (1, UnsizedType.DataOnly, UnsizedType.UReal) ] ] )
    ; ( "neg_binomial_2_log_glm_lpmf"
      , [[(1, UnsizedType.DataOnly, UnsizedType.URowVector)]] )
    ; ( "normal_id_glm_lpdf"
      , [[(1, UnsizedType.DataOnly, UnsizedType.URowVector)]] )
    ; ( "ordered_logistic_glm_lpmf"
      , [[(1, UnsizedType.DataOnly, UnsizedType.URowVector)]] )
    ; ( "poisson_log_glm_lpmf"
      , [[(1, UnsizedType.DataOnly, UnsizedType.URowVector)]] )
    ; ("std_normal_lpdf", [[(0, UnsizedType.AutoDiffable, UnsizedType.UVector)]])
    ; ( "uniform_lpdf"
      , [ [ (0, UnsizedType.AutoDiffable, UnsizedType.UVector)
          ; (1, UnsizedType.DataOnly, UnsizedType.UReal)
            (* the second bound is argument 2; index 1 used to be listed
               twice, which left the third argument unchecked *)
          ; (2, UnsizedType.DataOnly, UnsizedType.UReal) ] ] ) ]
(** Names of the functions whose calls [switch_expr_to_opencl] considers for
    moving their arguments to OpenCL. *)
let opencl_supported_functions =
  String.Set.of_list
    [ (* Bernoulli / binomial family *)
      "bernoulli_lpmf"; "bernoulli_logit_lpmf"; "bernoulli_logit_glm_lpmf"
    ; "binomial_lpmf"; "binomial_logit_lpmf"; "binomial_logit_glm_lpmf"
      (* remaining discrete distributions *)
    ; "categorical_logit_glm_lpmf"; "neg_binomial_lpmf"; "neg_binomial_2_lpmf"
    ; "neg_binomial_2_log_lpmf"; "neg_binomial_2_log_glm_lpmf"
    ; "ordered_logistic_glm_lpmf"; "poisson_lpmf"; "poisson_log_lpmf"
    ; "poisson_log_glm_lpmf" (* continuous distributions *); "beta_lpdf"
    ; "beta_proportion_lpdf"; "cauchy_lpdf"; "chi_square_lpdf"
    ; "double_exponential_lpdf"; "exp_mod_normal_lpdf"; "exponential_lpdf"
    ; "frechet_lpdf"; "gamma_lpdf"; "gumbel_lpdf"; "inv_chi_square_lpdf"
    ; "inv_gamma_lpdf"; "logistic_lpdf"; "lognormal_lpdf"; "normal_lpdf"
    ; "normal_id_glm_lpdf"; "pareto_lpdf"; "pareto_type_2_lpdf"
    ; "rayleigh_lpdf"; "scaled_inv_chi_square_lpdf"; "skew_normal_lpdf"
    ; "std_normal_lpdf"; "student_t_lpdf"; "uniform_lpdf"; "weibull_lpdf" ]
(** Suffix appended to a variable name to refer to its OpenCL counterpart
    (see [switch_expr_to_opencl] and [is_opencl_var]). *)
let opencl_suffix : string = "_opencl__"
(** Wrap [e] in a call to the library function [to_matrix_cl]; the metadata
    of [e] is reused for the resulting call expression. *)
let to_matrix_cl e =
  let open Expr.Fixed in
  {e with pattern= FunApp (StanLib ("to_matrix_cl", FnPlain, AoS), [e])}
(** Rewrite the calls to [opencl_supported_functions] inside an expression so
    that their arguments live on the OpenCL device.

    For such a call, unless one of the alternatives registered for the
    function in [opencl_trigger_restrictions] applies, each argument is
    replaced as follows:
    - a variable listed in [available_cl_vars] becomes the variable of the
      same name with [opencl_suffix] appended;
    - any other argument of type [UInt] or [UReal] is kept as it is;
    - everything else is wrapped in [to_matrix_cl].

    The arguments of a supported call are not traversed any further; all
    other expressions are traversed recursively.

    @param available_cl_vars names of the variables that have an OpenCL copy
    @raise _ if a restriction refers to an argument index the call does not
    have ([List.nth_exn]) *)
let rec switch_expr_to_opencl available_cl_vars (Expr.Fixed.{pattern; _} as e) =
  let has_cl_copy name = List.mem available_cl_vars name ~equal:( = ) in
  let move_arg
      (Expr.Fixed.{pattern; meta= {Expr.Typed.Meta.type_; _}} as arg) =
    match (pattern, type_) with
    | Var name, _ when has_cl_copy name ->
        Expr.Fixed.{arg with pattern= Var (name ^ opencl_suffix)}
    | _, UnsizedType.(UInt | UReal) -> arg
    | _, _ -> to_matrix_cl arg in
  (* does argument number [idx] have type [ty] and an autodiff level that can
     be converted to [ad]? *)
  let condition_holds args (idx, ad, ty) =
    let arg = List.nth_exn args idx in
    Expr.Typed.type_of arg = ty
    && UnsizedType.autodifftype_can_convert (Expr.Typed.adlevel_of arg) ad in
  let is_restricted args alternatives =
    List.exists alternatives ~f:(fun conditions ->
        List.for_all conditions ~f:(condition_holds args)) in
  let move_unless_restricted args alternatives =
    if Option.exists alternatives ~f:(is_restricted args) then args
    else List.map args ~f:move_arg in
  match pattern with
  | FunApp (StanLib (fname, sfx, mem_pattern), args)
    when Set.mem opencl_supported_functions fname ->
      let alternatives = Map.find opencl_trigger_restrictions fname in
      { e with
        pattern=
          FunApp
            ( StanLib (fname, sfx, mem_pattern)
            , move_unless_restricted args alternatives ) }
  | other ->
      { e with
        pattern=
          Expr.Fixed.Pattern.map
            (switch_expr_to_opencl available_cl_vars)
            other }
(** The unsized scalar type stored in a sized type: arrays are looked
    through, real vectors/row vectors/matrices give [UReal], their complex
    counterparts give [UComplex], and any other type is converted with
    [SizedType.to_unsized]. *)
let rec base_type st =
  match st with
  | SizedType.SArray (element, _) -> base_type element
  | SVector _ | SRowVector _ | SMatrix _ -> UnsizedType.UReal
  | SComplexVector _ | SComplexRowVector _ | SComplexMatrix _ ->
      UnsizedType.UComplex
  | other -> SizedType.to_unsized other
(** Name of the position counter variable that the generated read loops below
    reset and increment while walking a flattened array. *)
let pos : string = "pos__"
(** Expression metadata for a value of sized type [st]: the empty metadata
    with the unsized version of [st] as type and the [DataOnly] autodiff level
    filled in for that type. *)
let meta_from_sizedtype st =
  let unsized = SizedType.to_unsized st in
  Expr.Typed.Meta.
    { empty with
      type_= unsized
    ; adlevel= UnsizedType.fill_adtype_for_type DataOnly unsized }
(** Replace every literal [.] in [name] by [_dot_]. *)
let munge_tuple_name name =
  let literal_dot = Str.regexp_string "." in
  Str.global_replace literal_dot "_dot_" name
(** Name of the temporary used for the tuple component called [name]: the
    munged name (see [munge_tuple_name]) followed by [_temp__]. *)
let make_tuple_temp name =
  let munged = munge_tuple_name name in
  munged ^ "_temp__"
(** This function is essentially copied from [var_context_read],
    but rather than calling ReadDataFn, this indexes
    into the flattened versions of the tuple data
    created by [var_context_read] when it encounters an array of tuples

    Every statement built here carries an empty location.

    @param enclosing_tuple_name The name (in the sense of [Stmt.Helpers.get_lhs_name])
      of the element of the tuple this recursive call is handling. This is
      used to generate the appropriate [_flat__] variable to pull from
    @param origin_type The type of the flat variable for this call, if one exists.
      In situations where this is an array of tuples still, this type is unused.
    @return the list of statements which fill the lvalue given as first
      component of the third argument (of sized type [st], the third
      component) and advance the position counter of the flat variable
*)
let rec var_context_read_inside_tuple enclosing_tuple_name origin_type
((decl_id_lval : 'a Stmt.Fixed.Pattern.lvalue), _, st) =
let smeta =
(* avoid a bunch of redundant current_statement assigns *)
Location_span.empty in
let unsized = SizedType.to_unsized st in
let scalar = base_type st in
let flat_type = UnsizedType.UArray scalar in
let decl_id = Stmt.Helpers.get_lhs_name decl_id_lval in
let decl_var =
{ Expr.Fixed.pattern= Var decl_id
; meta= Expr.Typed.Meta.{loc= smeta; type_= unsized; adlevel= DataOnly} }
in
let swrap stmt = {Stmt.Fixed.pattern= stmt; meta= smeta} in
let pos_var = {Expr.Fixed.pattern= Var pos; meta= Expr.Typed.Meta.empty} in
let flat_name decl_id = munge_tuple_name decl_id ^ "_flat__" in
(* names of the flat variable of the enclosing tuple element and of its
position counter *)
let enclosing_tuple_flat, enclosing_tuple_pos =
let name = munge_tuple_name enclosing_tuple_name in
(name ^ "_flat__", name ^ "_flat__pos__") in
(* the flat variable as an expression of type [origin_type] *)
let origin_name =
let var = Expr.Helpers.variable enclosing_tuple_flat in
{ var with
meta=
{ var.meta with
type_= origin_type
; adlevel= UnsizedType.fill_adtype_for_type DataOnly origin_type } }
in
(* position counter + io size of [st]: the value of the counter once this
element has been consumed *)
let type_size =
Expr.Helpers.(
binop (variable enclosing_tuple_pos) Plus (SizedType.io_size st)) in
let end_position = Expr.Helpers.(binop type_size Minus loop_bottom) in
(* the part of the flat variable holding this element *)
let origin =
match unsized with
| UInt | UReal | UComplex ->
(* Scalars get one index *)
Expr.Helpers.add_int_index origin_name
(Index.Single (Expr.Helpers.variable enclosing_tuple_pos))
| _ ->
Expr.Helpers.add_int_index origin_name
(Index.Between
(Expr.Helpers.variable enclosing_tuple_pos, end_position)) in
(* statement assigning [type_size] to the position counter *)
let incr_tuple_pos =
Stmt.Fixed.Pattern.Assignment
(Stmt.Helpers.lvariable enclosing_tuple_pos, UInt, type_size)
|> swrap in
match st with
| SInt | SReal | SComplex ->
[Assignment (decl_id_lval, unsized, origin) |> swrap; incr_tuple_pos]
| SArray ((SInt | SReal), _) ->
[Assignment (decl_id_lval, flat_type, origin) |> swrap; incr_tuple_pos]
| STuple subtypes ->
(* recurse into every component, addressed as projection i + 1 of the
lvalue and named enclosing name, dot, i + 1 *)
let elements =
List.mapi
~f:(fun iter x ->
( (Stmt.Fixed.Pattern.LTupleProjection (decl_id_lval, iter + 1), [])
, smeta
, x ))
subtypes in
let enclosing_names =
List.mapi
~f:(fun i _ -> enclosing_tuple_name ^ "." ^ string_of_int (i + 1))
subtypes in
List.map2_exn
~f:(fun name projection ->
var_context_read_inside_tuple name origin_type projection)
enclosing_names elements
|> List.concat
| SArray _ when SizedType.contains_tuple st ->
(* array of tuples: declare one temporary per component, then loop over
the array dimensions, fill the temporaries and assign their tuple to the
indexed element *)
let tupl, dims = SizedType.get_array_dims st in
let tuple_component_names, tuple_types =
match tupl with
| STuple subtypes ->
( List.mapi
~f:(fun i _ ->
enclosing_tuple_name ^ "." ^ string_of_int (i + 1))
subtypes
, subtypes )
| _ -> ([], []) in
let temps =
List.map2_exn
~f:(fun name t ->
Stmt.Fixed.Pattern.Decl
{ decl_adtype=
UnsizedType.fill_adtype_for_type DataOnly
(SizedType.to_unsized t)
; decl_id= make_tuple_temp name
; decl_type= Sized t
; initialize= Default }
|> swrap)
tuple_component_names tuple_types in
let loop =
(* assignment of the tuple of temporaries to the lvalue indexed by the
loop variables in reverse order *)
let final_assignment loopvars =
let assign_lval =
let lbase, idxs = decl_id_lval in
( lbase
, idxs @ List.map ~f:(fun e -> Index.Single e) (List.rev loopvars)
) in
[ Stmt.Fixed.Pattern.Assignment
( assign_lval
, unsized
, Expr.Helpers.tuple_expr
(List.map2_exn
~f:(fun n st ->
Expr.Fixed.
{ pattern= Var (make_tuple_temp n)
; meta= meta_from_sizedtype st })
tuple_component_names tuple_types) )
|> swrap ] in
[ Stmt.Helpers.mk_nested_for (List.rev dims)
(fun loopvars ->
Stmt.Fixed.
{ meta= smeta
; pattern=
SList
((List.map2_exn
~f:(fun io_name st ->
let temp_name = make_tuple_temp io_name in
var_context_read_inside_tuple io_name
(UnsizedType.wind_array_type
(SizedType.to_unsized st, List.length dims))
(Stmt.Helpers.lvariable temp_name, smeta, st))
tuple_component_names tuple_types
|> List.concat)
@ final_assignment loopvars) })
smeta ] in
[Block (temps @ loop) |> swrap]
| SVector _ | SRowVector _ | SMatrix _ | SComplexMatrix _
|SComplexRowVector _ | SComplexVector _ | SArray _ ->
(* remaining containers: copy the relevant part of the enclosing flat
variable into a local flat array, then assign it scalar by scalar *)
let decl, assign, flat_var =
let decl_id_flat = flat_name decl_id in
( Stmt.Fixed.Pattern.Decl
{ decl_adtype= AutoDiffable
; decl_id= decl_id_flat
; decl_type= Unsized flat_type
; initialize= Default }
|> swrap
, Assignment (Stmt.Helpers.lvariable decl_id_flat, flat_type, origin)
|> swrap
, { Expr.Fixed.pattern= Var decl_id_flat
; meta=
Expr.Typed.Meta.{loc= smeta; type_= flat_type; adlevel= DataOnly}
} ) in
(* loop body: assign the scalar at the current position of the local flat
array, then increment the position variable by one *)
let bodyfn _ var =
let pos_increment =
[ Assignment
( Stmt.Helpers.lvariable pos
, UInt
, Expr.Helpers.(binop pos_var Plus one) )
|> swrap ] in
let read_indexed _ =
{ Expr.Fixed.pattern= Indexed (flat_var, [Single pos_var])
; meta= Expr.Typed.Meta.{flat_var.meta with type_= scalar} } in
SList
(Stmt.Helpers.assign_indexed (SizedType.to_unsized st) decl_id_lval
smeta read_indexed var
:: pos_increment)
|> swrap in
(* the position variable restarts at [loop_bottom] for every container *)
let pos_reset =
Stmt.Fixed.Pattern.Assignment
(Stmt.Helpers.lvariable pos, UInt, Expr.Helpers.loop_bottom)
|> swrap in
[ Block
[ decl; assign; pos_reset
; Stmt.Helpers.for_scalar_inv st bodyfn decl_var smeta; incr_tuple_pos
]
|> swrap ]
(** Build the statements which read the value of a variable out of the
    var_context (through [FnReadData]) and store it in the given lvalue.

    The argument is a triple: the lvalue to assign to, the location used for
    the main generated statements, and the sized type [st] of the value.

    - scalars read element [loop_bottom] of the flat data;
    - arrays of int or real are assigned the flat data directly;
    - tuples are read component by component;
    - arrays containing tuples go through flat variables and
      [var_context_read_inside_tuple] (see the comment in that case);
    - all other containers are read into a flat array and then assigned
      scalar by scalar. *)
let rec var_context_read_internal
((decl_id_lval : 'a Stmt.Fixed.Pattern.lvalue), smeta, st) =
let unsized = SizedType.to_unsized st in
let scalar = base_type st in
let flat_type = UnsizedType.UArray scalar in
let decl_id = Stmt.Helpers.get_lhs_name decl_id_lval in
let decl_var =
{ Expr.Fixed.pattern= Var decl_id
; meta= Expr.Typed.Meta.{loc= smeta; type_= unsized; adlevel= DataOnly} }
in
let swrap stmt = {Stmt.Fixed.pattern= stmt; meta= smeta} in
let swrap_noloc stmt =
(* not strictly necessary, but lets us cut down on the number of
current_statement__ = X lines in the generated code *)
{Stmt.Fixed.pattern= stmt; meta= Location_span.empty} in
let pos_var = {Expr.Fixed.pattern= Var pos; meta= Expr.Typed.Meta.empty} in
let flat_name decl_id = munge_tuple_name decl_id ^ "_flat__" in
(* call to [FnReadData] whose only argument is the string literal holding
[remove_prefix decl_id]; the call expression has type [flat_type] *)
let readfnapp decl_id flat_type =
Expr.Helpers.internal_funapp FnReadData
[{decl_var with pattern= Lit (Str, remove_prefix decl_id)}]
Expr.Typed.Meta.{decl_var.meta with type_= flat_type} in
match st with
| SInt | SReal | SComplex ->
let e =
{ Expr.Fixed.pattern=
Indexed
(readfnapp decl_id flat_type, [Single Expr.Helpers.loop_bottom])
; meta= {decl_var.meta with type_= unsized} } in
[Assignment (decl_id_lval, unsized, e) |> swrap]
| SArray ((SInt | SReal), _) ->
[ Assignment (decl_id_lval, flat_type, readfnapp decl_id flat_type)
|> swrap ]
| STuple subtypes ->
let sub_sts =
List.mapi
~f:(fun iter x ->
( (Stmt.Fixed.Pattern.LTupleProjection (decl_id_lval, iter + 1), [])
, (if iter = 0 then smeta
(* don't repeat locations in inner loops *)
else Location_span.empty)
, x ))
subtypes in
List.concat_map ~f:var_context_read_internal sub_sts
| SArray _ when SizedType.contains_tuple st ->
(* The IO format for tuples is complicated in this case.
Therefore, we need to do the following
1. Make "_flat__" decls for everything
2. Declare a temp for each item of this tuple
3. in a loop:
i. call [var_context_read_inside_tuple] with the temp variable as the destination
this function does essentially the same things recursively, but it doesn't create
more "_flat__" variables for deeper nested arrays-of-tuples.
ii. assign those temps (forwarding as tuple) to this variable, properly indexed.
*)
let tupl, dims = SizedType.get_array_dims st in
let flat_decls =
(* Here we need to go recursively all the way down the tuple *)
let flat_io_names =
UnsizedType.enumerate_tuple_names_io decl_id
(SizedType.to_unsized tupl) in
let flat_vars = List.map ~f:flat_name flat_io_names in
let flat_types = SizedType.flatten_tuple_io tupl in
(* per flat name: declare the flat array, fill it from [FnReadData],
declare its position counter and set the counter to [loop_bottom] *)
List.map3_exn
~f:(fun variable_name io_name st ->
let typ = SizedType.to_unsized st in
let scalar_type = UnsizedType.internal_scalar typ in
let array_type = UnsizedType.UArray scalar_type in
[ Stmt.Fixed.Pattern.Decl
{ decl_adtype= AutoDiffable
; decl_id= variable_name
; decl_type= Unsized array_type
; initialize= Default }
|> swrap_noloc
; Assignment
( Stmt.Helpers.lvariable variable_name
, typ
, readfnapp io_name array_type )
|> swrap
; Stmt.Fixed.Pattern.Decl
{ decl_adtype= DataOnly
; decl_id= variable_name ^ "pos__"
; decl_type= Unsized UInt
; initialize= Default }
|> swrap_noloc
; Stmt.Fixed.Pattern.Assignment
( Stmt.Helpers.lvariable (variable_name ^ "pos__")
, UInt
, Expr.Helpers.loop_bottom )
|> swrap_noloc ])
flat_vars flat_io_names flat_types
|> List.concat in
(* from now on, we only care about things at this level,
calling [var_context_read_inside_tuple] *)
let tuple_component_names, tuple_types =
match tupl with
| STuple subtypes ->
( List.mapi
~f:(fun i _ -> decl_id ^ "." ^ string_of_int (i + 1))
subtypes
, subtypes )
| _ -> (* impossible by above pattern match *) ([], []) in
let temps =
List.map2_exn
~f:(fun name t ->
Stmt.Fixed.Pattern.Decl
{ decl_adtype=
UnsizedType.fill_adtype_for_type DataOnly
(SizedType.to_unsized t)
; decl_id= make_tuple_temp name
; decl_type= Sized t
; initialize= Default }
|> swrap_noloc)
tuple_component_names tuple_types in
let loop =
(* assignment of the tuple of temporaries to the lvalue indexed by the
loop variables in reverse order *)
let final_assignment loopvars =
let assign_lval =
let lbase, idxs = decl_id_lval in
( lbase
, idxs @ List.map ~f:(fun e -> Index.Single e) (List.rev loopvars)
) in
[ Stmt.Fixed.Pattern.Assignment
( assign_lval
, unsized
, Expr.Helpers.tuple_expr
(List.map2_exn
~f:(fun n st ->
Expr.Fixed.
{ pattern= Var (make_tuple_temp n)
; meta= meta_from_sizedtype st })
tuple_component_names tuple_types) )
|> swrap_noloc ] in
[ Stmt.Helpers.mk_nested_for (List.rev dims)
(fun loopvars ->
SList
((List.map2_exn
~f:(fun io_name st ->
let temp_name = make_tuple_temp io_name in
var_context_read_inside_tuple io_name
(UnsizedType.wind_array_type
(SizedType.to_unsized st, List.length dims))
( Stmt.Helpers.lvariable temp_name
, Location_span.empty
, st ))
tuple_component_names tuple_types
|> List.concat)
@ final_assignment loopvars)
|> swrap_noloc)
Location_span.empty ] in
[Block (flat_decls @ temps @ loop) |> swrap]
| SVector _ | SRowVector _ | SMatrix _ | SComplexMatrix _
|SComplexRowVector _ | SComplexVector _ | SArray _ ->
(* remaining containers: read everything into a flat array, then assign
it scalar by scalar *)
let decl, assign, flat_var =
let decl_id_flat = flat_name decl_id in
( Stmt.Fixed.Pattern.Decl
{ decl_adtype= AutoDiffable
; decl_id= decl_id_flat
; decl_type= Unsized flat_type
; initialize= Uninit }
|> swrap
, Assignment
( Stmt.Helpers.lvariable decl_id_flat
, flat_type
, readfnapp decl_id flat_type )
|> swrap
, { Expr.Fixed.pattern= Var decl_id_flat
; meta=
Expr.Typed.Meta.{loc= smeta; type_= flat_type; adlevel= DataOnly}
} ) in
(* loop body: assign the scalar at the current position of the flat
array, then increment the position variable by one *)
let bodyfn _ var =
let pos_increment =
[ Assignment
( Stmt.Helpers.lvariable pos
, UInt
, Expr.Helpers.(binop pos_var Plus one) )
|> swrap_noloc ] in
let read_indexed _ =
{ Expr.Fixed.pattern= Indexed (flat_var, [Single pos_var])
; meta= Expr.Typed.Meta.{flat_var.meta with type_= scalar} } in
SList
(Stmt.Helpers.assign_indexed (SizedType.to_unsized st) decl_id_lval
Location_span.empty read_indexed var
:: pos_increment)
|> swrap_noloc in
(* the position variable restarts at [loop_bottom] for every container *)
let pos_reset =
Stmt.Fixed.Pattern.Assignment
(Stmt.Helpers.lvariable pos, UInt, Expr.Helpers.loop_bottom)
|> swrap_noloc in
[ Block
[ decl; assign; pos_reset
; Stmt.Helpers.for_scalar_inv st bodyfn decl_var Location_span.empty
]
|> swrap ]
(** [var_context_read_internal] in the shape expected from a reader (see
    [add_reads]): the statements performing the read, paired with an optional
    initialiser expression which is always [None] here. *)
let var_context_read p =
  let read_stmts = var_context_read_internal p in
  (* this never uses the declare-define fast path at the moment *)
  (read_stmts, None)
(*
  Get the dimension expressions that are expected by constrain/unconstrain
  functions for a sized type.

  For constrains that return square / lower triangular matrices the C++
  only wants one of the matrix dimensions. This applies to the CholeskyCorr,
  Correlation and Covariance transformations; every other transformation
  gets the result of SizedType.get_dims.
*)
let read_constrain_dims constrain_transform st =
  (* array dimensions, followed by a single dimension for the innermost
     vector or matrix (for matrices: the last one stored in the type) *)
  let rec single_matrix_dim = function
    | SizedType.SInt | SReal | SComplex | STuple _ -> []
    | SArray (element, dim) -> dim :: single_matrix_dim element
    | SVector (_, len)
     |SRowVector (_, len)
     |SComplexVector len
     |SComplexRowVector len ->
        [len]
    | SMatrix (_, _, last) | SComplexMatrix (_, last) -> [last] in
  let wants_single_dim =
    match constrain_transform with
    | Transformation.CholeskyCorr | Correlation | Covariance -> true
    | _ -> false in
  if wants_single_dim then single_matrix_dim st else SizedType.get_dims st
(** Expression calling [FnReadDeserializer] with the dimensions of
    [out_constrained_st] as arguments. The expression is located at [loc],
    has the unsized version of [out_constrained_st] as type and is
    [AutoDiffable]. *)
let plain_deserializer_read loc out_constrained_st =
  let type_ = SizedType.to_unsized out_constrained_st in
  let meta = Expr.Typed.Meta.create ~loc ~type_ ~adlevel:AutoDiffable () in
  Expr.Helpers.internal_funapp FnReadDeserializer
    (SizedType.get_dims out_constrained_st)
    meta
(** Build the read of a parameter through [FnReadParam].

    The argument is a triple: the lvalue to fill, the location for the
    generated statements, and the output variable description, of which the
    constrained sized type, the block and the transformation are used.

    @return a pair. The first component is the list of statements performing
    the read; it is empty when the variable is not in the [Parameters] block.
    The second component is [Some read] only when a single read expression is
    assigned to a plain, unindexed variable, in which case the caller may use
    [read] as initialiser of the declaration instead (see [add_reads]). *)
let param_deserializer_read
( decl_id_lval
, smeta
, Program.{out_constrained_st= cst; out_block; out_trans; _} ) =
if not (out_block = Parameters) then ([], None)
else
(* read expression for one non-tuple value of sized type [cst] under the
transformation [out_trans] *)
let basic_read (cst, out_trans) =
let ut = SizedType.to_unsized cst in
let emeta =
Expr.Typed.Meta.create ~loc:smeta ~type_:ut
~adlevel:(UnsizedType.fill_adtype_for_type AutoDiffable ut)
() in
let dims = read_constrain_dims out_trans cst in
Expr.Helpers.internal_funapp
(FnReadParam
{ constrain= out_trans
; dims
; mem_pattern= SizedType.get_mem_pattern cst })
[] emeta in
let rec read_stmt (lval, cst, out_trans) =
match cst with
| SizedType.SArray _ when SizedType.contains_tuple cst ->
(* array of tuples: loop over the array dimensions and read the tuple
into the lvalue indexed by the loop variables in reverse order *)
let tupl, array_dims = SizedType.get_array_dims cst in
( [ Stmt.Helpers.mk_nested_for (List.rev array_dims)
(fun loopvars ->
Stmt.Fixed.
{ meta= smeta
; pattern=
SList
(fst
(read_stmt
(let lbase, idxs = lval in
( ( lbase
, idxs
@ List.map
~f:(fun e -> Index.Single e)
(List.rev loopvars) )
, tupl
, out_trans )))) })
smeta ]
, None )
| SizedType.STuple _ ->
(* tuple: pair every component type with its transformation and read
the components one after the other into projection i + 1 *)
let subtys =
Utils.(zip_stuple_trans_exn cst (tuple_trans_exn out_trans)) in
let sub_sts =
List.mapi
~f:(fun iter (st, trans) ->
( (Stmt.Fixed.Pattern.LTupleProjection (lval, iter + 1), [])
, st
, trans ))
subtys in
(List.concat_map ~f:(Fn.compose fst read_stmt) sub_sts, None)
| _ -> (
let read = basic_read (cst, out_trans) in
( [ Stmt.Fixed.
{ pattern=
Pattern.Assignment (lval, SizedType.to_unsized cst, read)
; meta= smeta } ]
, (* if we're assigning to a top level variable, we can opt into to the declare-define *)
match lval with
| Stmt.Fixed.Pattern.LVariable _, [] -> Some read
| _ -> None )) in
read_stmt (decl_id_lval, cst, out_trans)
(** Replace every [.] and every [-] in [str] by an underscore. *)
let escape_name str =
  String.map str ~f:(function '.' | '-' -> '_' | other -> other)
(** Make sure that all if-while-and-for bodies are safely wrapped in a block in
    such a way that we can insert a location update before. The blocks make sure
    that the program with the inserted location update is still well-formed C++ though.

    Sub-statements are processed first. Then, for [IfElse], [While] and [For]
    only, every direct sub-statement is put in a block: a [Block] is kept, an
    [SList] is turned into a [Block] with the same statements, and any other
    statement becomes the single element of a new [Block].
*)
let rec ensure_body_in_block (Stmt.Fixed.{pattern; _} as stmt) =
  let open Stmt.Fixed in
  let as_block body =
    match body.pattern with
    | Block stmts | SList stmts -> {body with pattern= Pattern.Block stmts}
    | _ -> {body with pattern= Pattern.Block [body]} in
  let recursed = Pattern.map Fn.id ensure_body_in_block pattern in
  let wrapped =
    match recursed with
    | IfElse (_, _, _) | While (_, _) | For _ ->
        Pattern.map Fn.id as_block recursed
    | _ -> recursed in
  {stmt with pattern= wrapped}
(** Remove [SList] nodes from a list of statements.

    An [SList] found directly in [ls] is replaced by its (recursively
    flattened) elements. Inside the resulting statements, an [SList] that is a
    direct element of a [Block] is spliced into that block; all other
    statements are traversed recursively. *)
let rec flatten_slists_list ls =
  let open Stmt.Fixed in
  (* the elements of an [SList], or the statement itself as a singleton *)
  let splice stmt =
    match stmt.pattern with SList inner -> inner | _ -> [stmt] in
  let rec flatten_stmt stmt =
    match stmt.pattern with
    | Block body ->
        let body =
          List.concat_map body ~f:(fun child -> splice (flatten_stmt child))
        in
        {stmt with pattern= Pattern.Block body}
    | other -> {stmt with pattern= Pattern.map Fn.id flatten_stmt other} in
  let top_level =
    List.concat_map ls ~f:(fun stmt ->
        match stmt.pattern with
        | SList inner -> flatten_slists_list inner
        | _ -> [stmt]) in
  List.map top_level ~f:flatten_stmt
(* Expect test for [flatten_slists_list]. The input nests SLists around and
inside Blocks and a While; the expected output keeps the outer Block, the
While and the Block forming its body, and contains no SList any more. *)
let%expect_test "Flatten slists" =
(* helpers building expressions and statements with unit metadata *)
let e pattern = Expr.Fixed.{meta= (); pattern} in
let s pattern = Stmt.Fixed.{meta= (); pattern} in
let stmt =
Stmt.Fixed.Pattern.(
[ SList
[ Block
[ SList
[ While (e (Var "hi"), Block [SList [Break |> s] |> s] |> s)
|> s ]
|> s ]
|> s ]
|> s ]
|> flatten_slists_list) in
print_s [%sexp (stmt : (unit, unit) Stmt.Fixed.t list)];
[%expect
{|
(((pattern
(Block
(((pattern
(While ((pattern (Var hi)) (meta ()))
((pattern (Block (((pattern Break) (meta ()))))) (meta ()))))
(meta ())))))
(meta ()))) |}]
(** Attach a read to the declaration of every variable listed in [vars].

    @param vars triples [(name, location, output variable)]
    @param mkread reader builder; given [(lvalue, location, output variable)]
      it returns the statements performing the read and, optionally, an
      expression that can initialise the declaration directly
    @param stmts the statements to scan; only declarations found directly in
      this list are considered
    @return [stmts] where each matching declaration either gets the optional
      expression as initialiser (the read statements are then dropped) or is
      followed by the read statements
    @raise _ if [vars] contains the same name twice
      ([String.Map.of_alist_exn]) *)
let add_reads vars mkread stmts =
  let readable =
    vars
    |> List.map ~f:(fun (id, loc, outvar) -> (id, (loc, outvar)))
    |> String.Map.of_alist_exn in
  let expand (Stmt.Fixed.{pattern; _} as stmt) =
    match pattern with
    | Decl ({decl_id; _} as decl) -> (
      match Map.find readable decl_id with
      | None -> [stmt]
      | Some (loc, out) -> (
        match mkread (Stmt.Helpers.lvariable decl_id, loc, out) with
        | _, Some init ->
            [{stmt with pattern= Decl {decl with initialize= Assign init}}]
        | reader, None -> stmt :: reader ) )
    | _ -> [stmt] in
  List.concat_map stmts ~f:expand
(** Build the [FnWriteParam] statements which write the variable [decl_id].

    With [~unconstrain:false] (the default) a single write of the whole
    variable without unconstraining transformation is produced. With
    [~unconstrain:true], tuples are written component by component with the
    matching component transformation, arrays under a tuple transformation
    are written element by element in nested loops, and everything else is
    written at once with its transformation.

    All generated statements carry an empty location.

    @param unconstrain whether the transformation is passed on to the write
    @return the list of write statements *)
let param_serializer_write ?(unconstrain = false)
(decl_id, Program.{out_constrained_st; out_trans; _}) =
let rec write (var, st, trans) =
match (unconstrain, st, trans) with
| ( true
, SizedType.STuple subtypes
, Transformation.TupleTransformation transforms ) ->
(* pair projection i + 1 of [var] with the component type and the
component transformation *)
let tuple_elements =
subtypes
|> List.mapi ~f:(fun iter x ->
(Expr.Helpers.add_tuple_index var (iter + 1), x))
|> List.map2_exn ~f:(fun t (v, st) -> (v, st, t)) transforms in
List.concat_map ~f:write tuple_elements
| true, SArray _, TupleTransformation _ ->
(* array under a tuple transformation: loop over the array dimensions
and write [var] indexed by the loop variables in reverse order *)
let tupl, array_dims = SizedType.get_array_dims st in
[ Stmt.Helpers.mk_nested_for (List.rev array_dims)
(fun loopvars ->
Stmt.Fixed.
{ meta= Location_span.empty
; pattern=
SList
(write
( List.fold ~f:Expr.Helpers.add_int_index ~init:var
(List.map
~f:(fun e -> Index.Single e)
(List.rev loopvars))
, tupl
, trans )) })
Location_span.empty ]
| true, _, _ ->
[ Stmt.Helpers.internal_nrfunapp
(FnWriteParam {unconstrain_opt= Some trans; var})
[] Location_span.empty ]
| false, _, _ ->
[ Stmt.Helpers.internal_nrfunapp
(FnWriteParam {unconstrain_opt= None; var})
[] Location_span.empty ] in
(* the variable itself, as a [DataOnly] expression of the unsized type *)
let decl_var =
{ Expr.Fixed.pattern= Var decl_id
; meta=
Expr.Typed.Meta.
{ loc= Location_span.empty
; type_= SizedType.to_unsized out_constrained_st
; adlevel= DataOnly } } in
write (decl_var, out_constrained_st, out_trans)
(**
  Generate write instructions for unconstrained types. For scalars,
  matrices, vectors, and arrays with one dimension we can write
  these directly, but for arrays of arrays/vectors/matrices we
  need to write them in "column major order"

  The argument is a triple: the name of the variable, the location given to
  the generated loops, and the output variable description, of which only
  the constrained sized type is used. Every write is an [FnWriteParam]
  without unconstraining transformation.
*)
let param_unconstrained_serializer_write
(decl_id, smeta, Program.{out_constrained_st; _}) =
let rec write (var, st) =
match st with
| SizedType.STuple subtypes ->
(* tuples: write projection i + 1 for every component *)
let elements =
List.mapi
~f:(fun iter x -> (Expr.Helpers.add_tuple_index var (iter + 1), x))
subtypes in
List.concat_map ~f:write elements
| _ when SizedType.is_recursive_container st ->
(* loop over the dimensions returned by [get_scalar_and_dims] and write
[var] indexed by the loop variables in reverse order *)
let nonarray_st, array_dims = SizedType.get_scalar_and_dims st in
[ Stmt.Helpers.mk_nested_for (List.rev array_dims)
(fun loopvars ->
Stmt.Fixed.
{ meta= Location_span.empty
; pattern=
SList
(write
( List.fold ~f:Expr.Helpers.add_int_index ~init:var
(List.map
~f:(fun e -> Index.Single e)
(List.rev loopvars))
, nonarray_st )) })
smeta ]
| _ ->
[ Stmt.Helpers.internal_nrfunapp
(FnWriteParam {unconstrain_opt= None; var})
[] Location_span.empty ] in
(* the variable itself, as a [DataOnly] expression of the unsized type *)
let var =
{ Expr.Fixed.pattern= Var decl_id
; meta=
Expr.Typed.Meta.
{ loc= Location_span.empty
; type_= SizedType.to_unsized out_constrained_st
; adlevel= DataOnly } } in
write (var, out_constrained_st)
(** Reads in parameters from a var_context, the same way as is done in the constructor,
    and then writes out the unconstrained versions.

    The result is the declaration of [decl_id] with its constrained sized
    type, followed by the statements of [var_context_read_internal] and of
    [param_serializer_write ~unconstrain:true]. *)
let var_context_unconstrain_transform (decl_id, smeta, outvar) =
  let Program.{out_constrained_st= st; _} = outvar in
  let decl =
    Stmt.Fixed.
      { pattern=
          Decl
            { decl_adtype=
                UnsizedType.fill_adtype_for_type AutoDiffable
                  (SizedType.to_unsized st)
            ; decl_id
            ; decl_type= Type.Sized st
            ; initialize= Default }
      ; meta= smeta } in
  (* the shape of this list expression is kept as it was so that the two
     calls are evaluated in the same order as before *)
  decl
  :: var_context_read_internal (Stmt.Helpers.lvariable decl_id, smeta, st)
  @ param_serializer_write ~unconstrain:true (decl_id, outvar)
(** Reads in parameters from a serializer and then writes out the unconstrained versions.

    The result is the declaration of [decl_id] with its constrained sized
    type, followed by the statements reading it through
    [plain_deserializer_read] and by the statements of
    [param_serializer_write ~unconstrain:true]. *)
let array_unconstrain_transform (decl_id, smeta, outvar) =
let decl =
Stmt.Fixed.
{ pattern=
Decl
{ decl_adtype=
UnsizedType.fill_adtype_for_type AutoDiffable
(SizedType.to_unsized outvar.Program.out_constrained_st)
; decl_id
; decl_type= Type.Sized outvar.Program.out_constrained_st
; initialize= Default }
; meta= smeta } in
(* statements filling [lval], a value of sized type [st] *)
let rec read (lval, st) =
match st with
| SizedType.STuple subtypes ->
(* tuples: read projection i + 1 for every component *)
let elements =
List.mapi
~f:(fun iter x ->
((Stmt.Fixed.Pattern.LTupleProjection (lval, iter + 1), []), x))
subtypes in
List.concat_map ~f:read elements
| _ when SizedType.contains_tuple st ->
(* container of tuples: loop over the dimensions and read into [lval]
indexed by the loop variables in reverse order *)
let tupl, array_dims = SizedType.get_scalar_and_dims st in
[ Stmt.Helpers.mk_nested_for (List.rev array_dims)
(fun loopvars ->
Stmt.Fixed.
{ meta= Location_span.empty
; pattern=
SList
(read
(let lbase, idxs = lval in
( ( lbase
, idxs
@ List.map
~f:(fun e -> Index.Single e)
(List.rev loopvars) )
, tupl ))) })
smeta ]
| _ when SizedType.is_recursive_container st ->
(* non-tuple containing array *)
let nonarray_st, array_dims = SizedType.get_scalar_and_dims st in
[ Stmt.Helpers.mk_nested_for (List.rev array_dims)
(fun loopvars ->
let assign_lval =
let lbase, idxs = lval in
( lbase
, idxs
@ List.map ~f:(fun e -> Index.Single e) (List.rev loopvars) )
in
Stmt.Fixed.
{ meta= smeta
; pattern=
Assignment
( assign_lval
, SizedType.to_unsized nonarray_st
, plain_deserializer_read smeta
(SizedType.internal_scalar nonarray_st) ) })
smeta ]
| _ ->
(* everything else is read with a single deserializer call *)
[ Stmt.Fixed.
{ meta= smeta
; pattern=
Assignment
( lval
, SizedType.to_unsized st
, plain_deserializer_read smeta st ) } ] in
decl
:: read (Stmt.Helpers.lvariable decl_id, outvar.Program.out_constrained_st)
@ param_serializer_write ~unconstrain:true (decl_id, outvar)
(** Fold step telling whether an expression contains a variable whose name
    satisfies [is_vident]. The result is [true] when [accum] already is, or
    when the expression is, or contains, such a variable. *)
let rec contains_var_expr is_vident accum Expr.Fixed.{pattern; _} =
  if accum then true
  else
    match pattern with
    | Var v when is_vident v -> true
    | other -> Expr.Fixed.Pattern.fold (contains_var_expr is_vident) false other
(** [insert_before f to_insert xs] inserts the list [to_insert] right before
    the first element of [xs] satisfying [f]. If there is no such element,
    [to_insert] is appended at the end. *)
let rec insert_before f to_insert xs =
  match xs with
  | [] -> to_insert
  | first :: _ when f first -> to_insert @ xs
  | first :: rest -> first :: insert_before f to_insert rest
(** Whether a variable name ends with [opencl_suffix]. *)
let is_opencl_var name = String.is_suffix name ~suffix:opencl_suffix
(** Fold step adding to [accum] the names of the variables of an expression
    which satisfy [is_target]. A matching variable contributes its own name;
    any other expression contributes what its sub-expressions contain. *)
let rec collect_vars_expr is_target accum Expr.Fixed.{pattern; _} =
  let found =
    match pattern with
    | Var name when is_target name -> String.Set.singleton name
    | other ->
        Expr.Fixed.Pattern.fold
          (collect_vars_expr is_target)
          String.Set.empty other in
  Set.union accum found
(** The set of names of the OpenCL variables (see [is_opencl_var]) used in
    the expressions of statement [s] and of all its sub-statements. *)
let collect_opencl_vars s =
  let rec gather found stmt =
    Stmt.Fixed.Pattern.fold
      (collect_vars_expr is_opencl_var)
      gather found stmt.Stmt.Fixed.pattern in
  gather String.Set.empty s
let%expect_test "collect vars expr" =
let mkvar s = Expr.{Fixed.pattern= Var s; meta= Typed.Meta.empty} in
let args = List.map ~f:mkvar ["y"; "x_opencl__"; "z"; "w_opencl__"] in
let fnapp =
Expr.
{ Fixed.pattern= FunApp (StanLib ("print", FnPlain, AoS), args)
; meta= Typed.Meta.empty } in