Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/decl assign params #1441

Merged
merged 17 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,16 @@
Set.Poly.union_list
[ acc; query_expr acc predicate
; query_initial_demotable_stmt true acc body ]
| Decl {decl_type= Type.Sized st; decl_id; _}
when SizedType.is_complex_type st ->
Set.add acc decl_id
| Decl {decl_type= Type.Sized st; decl_id; initialize; _} ->
let complex_name =
match SizedType.is_complex_type st with
| true -> Set.Poly.singleton decl_id
| false -> Set.Poly.empty in
let init_names =
match initialize with
| Assign e -> query_expr acc e
| _ -> Set.Poly.empty in
Set.union acc (Set.union complex_name init_names)
| Skip | Break | Continue | Decl _ -> acc

(** Look through a statement to see whether the objects used in it need to be
Expand Down Expand Up @@ -419,6 +426,13 @@
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
| true -> Set.add all_rhs_eigen_names assign_name
| false -> Set.Poly.empty)
| Decl {decl_id; initialize= Assign e; _} -> (
let all_rhs_eigen_names = query_var_eigen_names e in
if Set.mem aos_exits decl_id then Set.add all_rhs_eigen_names decl_id
else
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
| true -> Set.add all_rhs_eigen_names decl_id

Check warning on line 434 in src/analysis_and_optimization/Memory_patterns.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Memory_patterns.ml#L434

Added line #L434 was not covered by tests
| false -> Set.Poly.empty)
(* All other statements do not need logic here*)
| _ -> Set.Poly.empty

Expand Down Expand Up @@ -453,7 +467,7 @@
(Fun_kind.StanLib (name, sfx, Mem_pattern.AoS), exprs')
else
( Fun_kind.StanLib (name, sfx, SoA)
, List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs
, List.map ~f:(modify_expr ~force_demotion:false modifiable_set) exprs
)
| UserDefined _ as udf ->
(udf, List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs)
Expand Down Expand Up @@ -543,12 +557,56 @@
let mod_stmt stmt = modify_stmt stmt modifiable_set in
match pattern with
| Stmt.Fixed.Pattern.Decl
({decl_id; decl_type= Type.Sized sized_type; _} as decl) ->
{ decl_id
; decl_adtype
; decl_type= Type.Sized sized_type
; initialize=
Assign
({ pattern= FunApp (CompilerInternal (FnReadParam read_param), args)
; _ } as assigner) } ->
let name = decl_id in
if Set.mem modifiable_set name then
Stmt.Fixed.Pattern.Decl
{ decl_id
; decl_adtype
; decl_type=
Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type)
; initialize=
Assign
{ assigner with
pattern=
FunApp
( CompilerInternal
(FnReadParam {read_param with mem_pattern= AoS})
, List.map ~f:(mod_expr true) args ) } }
else
Stmt.Fixed.Pattern.Decl
{ decl_id
; decl_adtype
; decl_type=
Type.Sized (SizedType.modify_sizedtype_mem SoA sized_type)
; initialize=
Assign
{ assigner with
pattern=
FunApp
( CompilerInternal
(FnReadParam {read_param with mem_pattern= SoA})
, List.map ~f:(mod_expr false) args ) } }
| Stmt.Fixed.Pattern.Decl
({decl_id; decl_type= Type.Sized sized_type; initialize; _} as decl) ->
if Set.mem modifiable_set decl_id then
let init_expr =
match initialize with
| Stmt.Fixed.Pattern.Assign e ->
Stmt.Fixed.Pattern.Assign (mod_expr false e)

Check warning on line 602 in src/analysis_and_optimization/Memory_patterns.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Memory_patterns.ml#L601-L602

Added lines #L601 - L602 were not covered by tests
| Default -> Default
| Uninit -> Uninit in
Stmt.Fixed.Pattern.Decl
{ decl with
decl_type=
Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type) }
Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type)
; initialize= init_expr }
else
Decl
{ decl with
Expand Down
11 changes: 9 additions & 2 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
[free_vars_expr e1; free_vars_expr e2; free_vars_stmt b.pattern]
| Profile (_, l) | Block l | SList l ->
Set.Poly.union_list (List.map ~f:(fun s -> free_vars_stmt s.pattern) l)
| Decl {initialize= Assign e; _} -> free_vars_expr e

Check warning on line 76 in src/analysis_and_optimization/Monotone_framework.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Monotone_framework.ml#L76

Added line #L76 was not covered by tests
| Decl _ | Break | Continue | Return None | Skip -> Set.Poly.empty

(** A variation on free_vars_stmt, where we do not recursively count free
Expand All @@ -81,6 +82,7 @@
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
(s : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) =
match s with
| Decl {initialize= Assign e; _} -> free_vars_expr e
| Assignment _ | Return _ | TargetPE _ | JacobianPE _ | NRFunApp _ | Decl _
|Break | Continue | Skip ->
free_vars_stmt
Expand Down Expand Up @@ -472,6 +474,7 @@
match s with
| Assignment (lhs, _, _) ->
Set.Poly.singleton (Middle.Stmt.Helpers.lhs_variable lhs)
| Decl {decl_id; initialize= Assign _; _} -> Set.Poly.singleton decl_id
| TargetPE _ | JacobianPE _ -> Set.Poly.singleton "target"
| NRFunApp
( ( UserDefined (_, (FnTarget | FnJacobian))
Expand Down Expand Up @@ -622,7 +625,9 @@
let rec used_expressions_stmt_help f
(s : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) =
match s with
| TargetPE e | JacobianPE e | Return (Some e) -> f e
| TargetPE e | JacobianPE e | Return (Some e) | Decl {initialize= Assign e; _}

Check warning on line 628 in src/analysis_and_optimization/Monotone_framework.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Monotone_framework.ml#L628

Added line #L628 was not covered by tests
->
f e
| Assignment (l, _, e) -> Set.union (f e) (used_expressions_lval f l)
| IfElse (e, b1, Some b2) ->
Expr.Typed.Set.union_list
Expand Down Expand Up @@ -656,7 +661,9 @@
let top_used_expressions_stmt_help f
(s : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) =
match s with
| TargetPE e | JacobianPE e | Return (Some e) -> f e
| TargetPE e | JacobianPE e | Return (Some e) | Decl {initialize= Assign e; _}

Check warning on line 664 in src/analysis_and_optimization/Monotone_framework.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Monotone_framework.ml#L664

Added line #L664 was not covered by tests
->
f e
| Assignment (l, _, e) -> Set.union (f e) (used_expressions_lval f l)
| While (e, _) | IfElse (e, _, _) -> f e
| NRFunApp (k, l) ->
Expand Down
20 changes: 15 additions & 5 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
{ decl_adtype= DataOnly
; decl_id= returned
; decl_type= Sized SInt
; initialize= true }
; initialize= Default }
; meta= Location_span.empty }
; Stmt.Fixed.
{ pattern=
Expand Down Expand Up @@ -294,7 +294,7 @@
(Type.to_unsized decl_type)
; decl_id= inline_return_name
; decl_type
; initialize= false } ]
; initialize= Uninit } ]
(* We should minimize the code that's having its variables
replaced to avoid conflict with the (two) new dummy
variables introduced by inlining *)
Expand Down Expand Up @@ -464,6 +464,10 @@
Block (List.map l ~f:(inline_function_statement propto adt fim))
| SList l ->
SList (List.map l ~f:(inline_function_statement propto adt fim))
| Decl {decl_adtype; decl_id; decl_type; initialize= Assign expr} ->
let d, s, e = inline_function_expression propto adt fim expr in
slist_concat_no_loc (d @ s)
(Decl {decl_adtype; decl_id; decl_type; initialize= Assign e})
| Decl r -> Decl r
| Skip -> Skip
| Break -> Break
Expand Down Expand Up @@ -752,6 +756,9 @@
remove an assignment to a variable
due to side effects. *)
(* TODO: maybe we should revisit that. *)
| Decl ({decl_id; initialize= Assign e; _} as decl) ->
if Set.mem live_variables_s decl_id || cannot_remove_expr e then stmt
else Decl {decl with initialize= Uninit}

Check warning on line 761 in src/analysis_and_optimization/Optimize.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Optimize.ml#L760-L761

Added lines #L760 - L761 were not covered by tests
| Decl _ | TargetPE _ | JacobianPE _
|NRFunApp (_, _)
|Break | Continue | Return _ | Skip ->
Expand Down Expand Up @@ -828,14 +835,16 @@
and unenforce_initialize (lst : Stmt.Located.t list) =
let rec unenforce_initialize_patt (Stmt.Fixed.{pattern; _} as stmt) sub_lst =
match pattern with
| Stmt.Fixed.Pattern.Decl ({decl_id; _} as decl_pat) -> (
| Stmt.Fixed.Pattern.Decl ({decl_id; initialize= Default; _} as decl_pat)
-> (
match List.hd sub_lst with
| Some next_stmt -> (
match find_assignment_idx decl_id next_stmt with
| Some ([] | [Index.All] | [Index.All; Index.All]) ->
{ stmt with
pattern=
Stmt.Fixed.Pattern.Decl {decl_pat with initialize= false} }
Stmt.Fixed.Pattern.Decl {decl_pat with initialize= Uninit}
}
| None | Some _ -> stmt)
| None -> stmt)
| Block block_lst ->
Expand Down Expand Up @@ -972,7 +981,7 @@
{ decl_adtype= Expr.Typed.adlevel_of key
; decl_id= data
; decl_type= Type.Unsized (Expr.Typed.type_of key)
; initialize= true }
; initialize= Default }
; meta= Location_span.empty }
:: accum) in
let lazy_code_motion_base i stmt =
Expand Down Expand Up @@ -1003,6 +1012,7 @@
let f stmt =
match stmt with
| Stmt.Fixed.Pattern.Assignment ((LVariable x, []), _, e')
|Decl {decl_id= x; initialize= Assign e'; _}
when Map.mem m e'
&& Expr.Typed.equal {e' with pattern= Var x}
(Map.find_exn m e') ->
Expand Down
9 changes: 5 additions & 4 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,8 @@ let create_decl_with_assign decl_id declc decl_type initial_value transform
() } in
let decl =
Stmt.
{ Fixed.pattern= Decl {decl_adtype; decl_id; decl_type; initialize= true}
{ Fixed.pattern=
Decl {decl_adtype; decl_id; decl_type; initialize= Default}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should use the logic on the next line which inspects the rhs_assignment to set Initialize

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this code it's not modifying the params. Making changes in places that do no modify params is out of the scope of this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good -- I think we would really want to fix #1295 before doing this -- probably by writing a generic-ish check_size function in Stan which we can call after the assignment is completed.

; meta= smeta } in
let rhs_assignment =
Option.map
Expand Down Expand Up @@ -583,7 +584,7 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
{ decl_adtype= Expr.Typed.adlevel_of iteratee'
; decl_id= loopvar.name
; decl_type= Unsized decl_type
; initialize= true } } in
; initialize= Default } } in
let assignment var =
Stmt.Fixed.
{ pattern=
Expand Down Expand Up @@ -629,7 +630,7 @@ and trans_packed_assign loc trans_stmt lvals rhs assign_op =
{ decl_adtype= rhs.emeta.ad_level
; decl_id= sym
; decl_type= Unsized rhs_type
; initialize= false }
; initialize= Uninit }
; meta= rhs.emeta.loc } in
let assign =
{ temp with
Expand Down Expand Up @@ -743,7 +744,7 @@ let rec trans_sizedtype_decl declc tr name st =
{ decl_type= Sized SInt
; decl_id
; decl_adtype= DataOnly
; initialize= true }
; initialize= Default }
; meta= e.meta.loc } in
let assign =
{ Stmt.Fixed.pattern=
Expand Down
18 changes: 13 additions & 5 deletions src/middle/Stmt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
{ decl_adtype: UnsizedType.autodifftype
; decl_id: string
; decl_type: 'a Type.t
; initialize: bool }
; initialize: 'a decl_init }

Check warning on line 28 in src/middle/Stmt.ml

View check run for this annotation

Codecov / codecov/patch

src/middle/Stmt.ml#L28

Added line #L28 was not covered by tests
[@@deriving sexp, hash, map, fold, compare]

and 'e lvalue = 'e lbase * 'e Index.t list
Expand All @@ -34,6 +34,9 @@
and 'e lbase = LVariable of string | LTupleProjection of 'e lvalue * int
[@@deriving sexp, hash, map, compare, fold]

and 'a decl_init = Uninit | Default | Assign of 'a

Check warning on line 37 in src/middle/Stmt.ml

View check run for this annotation

Codecov / codecov/patch

src/middle/Stmt.ml#L37

Added line #L37 was not covered by tests
[@@deriving sexp, hash, map, fold, compare]

let rec pp_lvalue pp_e ppf (lbase, idcs) =
match lbase with
| LVariable v -> Fmt.pf ppf "%s%a" v (Index.pp_indices pp_e) idcs
Expand Down Expand Up @@ -70,9 +73,14 @@
| Block stmts ->
Fmt.pf ppf "{@;<1 2>@[<v>%a@]@;}" Fmt.(list pp_s ~sep:cut) stmts
| SList stmts -> Fmt.(list pp_s ~sep:cut |> vbox) ppf stmts
| Decl {decl_adtype; decl_id; decl_type; _} ->
Fmt.pf ppf "@[<hov 2>%a%a@ %s;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id
| Decl {decl_adtype; decl_id; decl_type; initialize} -> (
match initialize with
| Assign e ->

Check warning on line 78 in src/middle/Stmt.ml

View check run for this annotation

Codecov / codecov/patch

src/middle/Stmt.ml#L78

Added line #L78 was not covered by tests
Fmt.pf ppf "@[<hov 2>%a%a@ %s = %a;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id pp_e e

Check warning on line 80 in src/middle/Stmt.ml

View check run for this annotation

Codecov / codecov/patch

src/middle/Stmt.ml#L80

Added line #L80 was not covered by tests
| Uninit | Default ->
Fmt.pf ppf "@[<hov 2>%a%a@ %s;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id)

include Foldable.Make2 (struct
type nonrec ('a, 'b) t = ('a, 'b) t
Expand Down Expand Up @@ -143,7 +151,7 @@
{ decl_adtype= Expr.Typed.adlevel_of e
; decl_id= sym
; decl_type= Unsized (Expr.Typed.type_of e)
; initialize= true }
; initialize= Default }
; meta= e.meta.loc } in
let assign =
{ decl with
Expand Down
5 changes: 4 additions & 1 deletion src/middle/Stmt.mli
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ module Fixed : sig
{ decl_adtype: UnsizedType.autodifftype
; decl_id: string
; decl_type: 'a Type.t
; initialize: bool }
; initialize: 'a decl_init }
[@@deriving sexp, hash, compare]

and 'e lvalue = 'e lbase * 'e Index.t list
Expand All @@ -32,6 +32,9 @@ module Fixed : sig
and 'e lbase = LVariable of string | LTupleProjection of 'e lvalue * int
[@@deriving sexp, hash, map, compare, fold]

and 'a decl_init = Uninit | Default | Assign of 'a
[@@deriving sexp, hash, map, fold, compare]

include Pattern.S2 with type ('a, 'b) t := ('a, 'b) t
end

Expand Down
Loading