Skip to content

Commit

Permalink
Merge pull request #1441 from stan-dev/feature/decl-assign-params
Browse files Browse the repository at this point in the history
Feature/decl assign params
  • Loading branch information
SteveBronder authored Jul 25, 2024
2 parents 606a395 + 36271ca commit 4f0d26c
Show file tree
Hide file tree
Showing 31 changed files with 10,203 additions and 13,258 deletions.
70 changes: 64 additions & 6 deletions src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,16 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
Set.Poly.union_list
[ acc; query_expr acc predicate
; query_initial_demotable_stmt true acc body ]
| Decl {decl_type= Type.Sized st; decl_id; _}
when SizedType.is_complex_type st ->
Set.add acc decl_id
| Decl {decl_type= Type.Sized st; decl_id; initialize; _} ->
let complex_name =
match SizedType.is_complex_type st with
| true -> Set.Poly.singleton decl_id
| false -> Set.Poly.empty in
let init_names =
match initialize with
| Assign e -> query_expr acc e
| _ -> Set.Poly.empty in
Set.union acc (Set.union complex_name init_names)
| Skip | Break | Continue | Decl _ -> acc

(** Look through a statement to see whether the objects used in it need to be
Expand Down Expand Up @@ -419,6 +426,13 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t)
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
| true -> Set.add all_rhs_eigen_names assign_name
| false -> Set.Poly.empty)
| Decl {decl_id; initialize= Assign e; _} -> (
let all_rhs_eigen_names = query_var_eigen_names e in
if Set.mem aos_exits decl_id then Set.add all_rhs_eigen_names decl_id
else
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
| true -> Set.add all_rhs_eigen_names decl_id
| false -> Set.Poly.empty)
(* All other statements do not need logic here*)
| _ -> Set.Poly.empty

Expand Down Expand Up @@ -453,7 +467,7 @@ let rec modify_kind ?force_demotion:(force = false)
(Fun_kind.StanLib (name, sfx, Mem_pattern.AoS), exprs')
else
( Fun_kind.StanLib (name, sfx, SoA)
, List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs
, List.map ~f:(modify_expr ~force_demotion:false modifiable_set) exprs
)
| UserDefined _ as udf ->
(udf, List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs)
Expand Down Expand Up @@ -543,12 +557,56 @@ let rec modify_stmt_pattern
let mod_stmt stmt = modify_stmt stmt modifiable_set in
match pattern with
| Stmt.Fixed.Pattern.Decl
({decl_id; decl_type= Type.Sized sized_type; _} as decl) ->
{ decl_id
; decl_adtype
; decl_type= Type.Sized sized_type
; initialize=
Assign
({ pattern= FunApp (CompilerInternal (FnReadParam read_param), args)
; _ } as assigner) } ->
let name = decl_id in
if Set.mem modifiable_set name then
Stmt.Fixed.Pattern.Decl
{ decl_id
; decl_adtype
; decl_type=
Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type)
; initialize=
Assign
{ assigner with
pattern=
FunApp
( CompilerInternal
(FnReadParam {read_param with mem_pattern= AoS})
, List.map ~f:(mod_expr true) args ) } }
else
Stmt.Fixed.Pattern.Decl
{ decl_id
; decl_adtype
; decl_type=
Type.Sized (SizedType.modify_sizedtype_mem SoA sized_type)
; initialize=
Assign
{ assigner with
pattern=
FunApp
( CompilerInternal
(FnReadParam {read_param with mem_pattern= SoA})
, List.map ~f:(mod_expr false) args ) } }
| Stmt.Fixed.Pattern.Decl
({decl_id; decl_type= Type.Sized sized_type; initialize; _} as decl) ->
if Set.mem modifiable_set decl_id then
let init_expr =
match initialize with
| Stmt.Fixed.Pattern.Assign e ->
Stmt.Fixed.Pattern.Assign (mod_expr false e)
| Default -> Default
| Uninit -> Uninit in
Stmt.Fixed.Pattern.Decl
{ decl with
decl_type=
Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type) }
Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type)
; initialize= init_expr }
else
Decl
{ decl with
Expand Down
11 changes: 9 additions & 2 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ let rec free_vars_stmt (s : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t)
[free_vars_expr e1; free_vars_expr e2; free_vars_stmt b.pattern]
| Profile (_, l) | Block l | SList l ->
Set.Poly.union_list (List.map ~f:(fun s -> free_vars_stmt s.pattern) l)
| Decl {initialize= Assign e; _} -> free_vars_expr e
| Decl _ | Break | Continue | Return None | Skip -> Set.Poly.empty

(** A variation on free_vars_stmt, where we do not recursively count free
Expand All @@ -81,6 +82,7 @@ let top_free_vars_stmt
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
(s : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) =
match s with
| Decl {initialize= Assign e; _} -> free_vars_expr e
| Assignment _ | Return _ | TargetPE _ | JacobianPE _ | NRFunApp _ | Decl _
|Break | Continue | Skip ->
free_vars_stmt
Expand Down Expand Up @@ -472,6 +474,7 @@ let assigned_vars_stmt (s : (Expr.Typed.t, 'a) Stmt.Fixed.Pattern.t) =
match s with
| Assignment (lhs, _, _) ->
Set.Poly.singleton (Middle.Stmt.Helpers.lhs_variable lhs)
| Decl {decl_id; initialize= Assign _; _} -> Set.Poly.singleton decl_id
| TargetPE _ | JacobianPE _ -> Set.Poly.singleton "target"
| NRFunApp
( ( UserDefined (_, (FnTarget | FnJacobian))
Expand Down Expand Up @@ -622,7 +625,9 @@ let used_expressions_expr e = Expr.Typed.Set.singleton e
let rec used_expressions_stmt_help f
(s : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) =
match s with
| TargetPE e | JacobianPE e | Return (Some e) -> f e
| TargetPE e | JacobianPE e | Return (Some e) | Decl {initialize= Assign e; _}
->
f e
| Assignment (l, _, e) -> Set.union (f e) (used_expressions_lval f l)
| IfElse (e, b1, Some b2) ->
Expr.Typed.Set.union_list
Expand Down Expand Up @@ -656,7 +661,9 @@ let used_expressions_stmt = used_expressions_stmt_help used_expressions_expr
let top_used_expressions_stmt_help f
(s : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) =
match s with
| TargetPE e | JacobianPE e | Return (Some e) -> f e
| TargetPE e | JacobianPE e | Return (Some e) | Decl {initialize= Assign e; _}
->
f e
| Assignment (l, _, e) -> Set.union (f e) (used_expressions_lval f l)
| While (e, _) | IfElse (e, _, _) -> f e
| NRFunApp (k, l) ->
Expand Down
20 changes: 15 additions & 5 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ let handle_early_returns (fname : string) opt_var stmt =
{ decl_adtype= DataOnly
; decl_id= returned
; decl_type= Sized SInt
; initialize= true }
; initialize= Default }
; meta= Location_span.empty }
; Stmt.Fixed.
{ pattern=
Expand Down Expand Up @@ -294,7 +294,7 @@ let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e)
(Type.to_unsized decl_type)
; decl_id= inline_return_name
; decl_type
; initialize= false } ]
; initialize= Uninit } ]
(* We should minimize the code that's having its variables
replaced to avoid conflict with the (two) new dummy
variables introduced by inlining *)
Expand Down Expand Up @@ -464,6 +464,10 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.{pattern; meta} =
Block (List.map l ~f:(inline_function_statement propto adt fim))
| SList l ->
SList (List.map l ~f:(inline_function_statement propto adt fim))
| Decl {decl_adtype; decl_id; decl_type; initialize= Assign expr} ->
let d, s, e = inline_function_expression propto adt fim expr in
slist_concat_no_loc (d @ s)
(Decl {decl_adtype; decl_id; decl_type; initialize= Assign e})
| Decl r -> Decl r
| Skip -> Skip
| Break -> Break
Expand Down Expand Up @@ -752,6 +756,9 @@ let dead_code_elimination (mir : Program.Typed.t) =
remove an assignment to a variable
due to side effects. *)
(* TODO: maybe we should revisit that. *)
| Decl ({decl_id; initialize= Assign e; _} as decl) ->
if Set.mem live_variables_s decl_id || cannot_remove_expr e then stmt
else Decl {decl with initialize= Uninit}
| Decl _ | TargetPE _ | JacobianPE _
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip ->
Expand Down Expand Up @@ -828,14 +835,16 @@ let rec find_assignment_idx (name : string) Stmt.Fixed.{pattern; _} =
and unenforce_initialize (lst : Stmt.Located.t list) =
let rec unenforce_initialize_patt (Stmt.Fixed.{pattern; _} as stmt) sub_lst =
match pattern with
| Stmt.Fixed.Pattern.Decl ({decl_id; _} as decl_pat) -> (
| Stmt.Fixed.Pattern.Decl ({decl_id; initialize= Default; _} as decl_pat)
-> (
match List.hd sub_lst with
| Some next_stmt -> (
match find_assignment_idx decl_id next_stmt with
| Some ([] | [Index.All] | [Index.All; Index.All]) ->
{ stmt with
pattern=
Stmt.Fixed.Pattern.Decl {decl_pat with initialize= false} }
Stmt.Fixed.Pattern.Decl {decl_pat with initialize= Uninit}
}
| None | Some _ -> stmt)
| None -> stmt)
| Block block_lst ->
Expand Down Expand Up @@ -972,7 +981,7 @@ let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) =
{ decl_adtype= Expr.Typed.adlevel_of key
; decl_id= data
; decl_type= Type.Unsized (Expr.Typed.type_of key)
; initialize= true }
; initialize= Default }
; meta= Location_span.empty }
:: accum) in
let lazy_code_motion_base i stmt =
Expand Down Expand Up @@ -1003,6 +1012,7 @@ let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) =
let f stmt =
match stmt with
| Stmt.Fixed.Pattern.Assignment ((LVariable x, []), _, e')
|Decl {decl_id= x; initialize= Assign e'; _}
when Map.mem m e'
&& Expr.Typed.equal {e' with pattern= Var x}
(Map.find_exn m e') ->
Expand Down
9 changes: 5 additions & 4 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,8 @@ let create_decl_with_assign decl_id declc decl_type initial_value transform
() } in
let decl =
Stmt.
{ Fixed.pattern= Decl {decl_adtype; decl_id; decl_type; initialize= true}
{ Fixed.pattern=
Decl {decl_adtype; decl_id; decl_type; initialize= Default}
; meta= smeta } in
let rhs_assignment =
Option.map
Expand Down Expand Up @@ -583,7 +584,7 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
{ decl_adtype= Expr.Typed.adlevel_of iteratee'
; decl_id= loopvar.name
; decl_type= Unsized decl_type
; initialize= true } } in
; initialize= Default } } in
let assignment var =
Stmt.Fixed.
{ pattern=
Expand Down Expand Up @@ -629,7 +630,7 @@ and trans_packed_assign loc trans_stmt lvals rhs assign_op =
{ decl_adtype= rhs.emeta.ad_level
; decl_id= sym
; decl_type= Unsized rhs_type
; initialize= false }
; initialize= Uninit }
; meta= rhs.emeta.loc } in
let assign =
{ temp with
Expand Down Expand Up @@ -743,7 +744,7 @@ let rec trans_sizedtype_decl declc tr name st =
{ decl_type= Sized SInt
; decl_id
; decl_adtype= DataOnly
; initialize= true }
; initialize= Default }
; meta= e.meta.loc } in
let assign =
{ Stmt.Fixed.pattern=
Expand Down
18 changes: 13 additions & 5 deletions src/middle/Stmt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ module Fixed = struct
{ decl_adtype: UnsizedType.autodifftype
; decl_id: string
; decl_type: 'a Type.t
; initialize: bool }
; initialize: 'a decl_init }
[@@deriving sexp, hash, map, fold, compare]

and 'e lvalue = 'e lbase * 'e Index.t list
Expand All @@ -34,6 +34,9 @@ module Fixed = struct
and 'e lbase = LVariable of string | LTupleProjection of 'e lvalue * int
[@@deriving sexp, hash, map, compare, fold]

and 'a decl_init = Uninit | Default | Assign of 'a
[@@deriving sexp, hash, map, fold, compare]

let rec pp_lvalue pp_e ppf (lbase, idcs) =
match lbase with
| LVariable v -> Fmt.pf ppf "%s%a" v (Index.pp_indices pp_e) idcs
Expand Down Expand Up @@ -70,9 +73,14 @@ module Fixed = struct
| Block stmts ->
Fmt.pf ppf "{@;<1 2>@[<v>%a@]@;}" Fmt.(list pp_s ~sep:cut) stmts
| SList stmts -> Fmt.(list pp_s ~sep:cut |> vbox) ppf stmts
| Decl {decl_adtype; decl_id; decl_type; _} ->
Fmt.pf ppf "@[<hov 2>%a%a@ %s;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id
| Decl {decl_adtype; decl_id; decl_type; initialize} -> (
match initialize with
| Assign e ->
Fmt.pf ppf "@[<hov 2>%a%a@ %s = %a;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id pp_e e
| Uninit | Default ->
Fmt.pf ppf "@[<hov 2>%a%a@ %s;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id)

include Foldable.Make2 (struct
type nonrec ('a, 'b) t = ('a, 'b) t
Expand Down Expand Up @@ -143,7 +151,7 @@ module Helpers = struct
{ decl_adtype= Expr.Typed.adlevel_of e
; decl_id= sym
; decl_type= Unsized (Expr.Typed.type_of e)
; initialize= true }
; initialize= Default }
; meta= e.meta.loc } in
let assign =
{ decl with
Expand Down
5 changes: 4 additions & 1 deletion src/middle/Stmt.mli
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ module Fixed : sig
{ decl_adtype: UnsizedType.autodifftype
; decl_id: string
; decl_type: 'a Type.t
; initialize: bool }
; initialize: 'a decl_init }
[@@deriving sexp, hash, compare]

and 'e lvalue = 'e lbase * 'e Index.t list
Expand All @@ -32,6 +32,9 @@ module Fixed : sig
and 'e lbase = LVariable of string | LTupleProjection of 'e lvalue * int
[@@deriving sexp, hash, map, compare, fold]

and 'a decl_init = Uninit | Default | Assign of 'a
[@@deriving sexp, hash, map, fold, compare]

include Pattern.S2 with type ('a, 'b) t := ('a, 'b) t
end

Expand Down
Loading

0 comments on commit 4f0d26c

Please sign in to comment.