Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/decl assign params #1441

Merged
merged 17 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,16 @@
Set.Poly.union_list
[ acc; query_expr acc predicate
; query_initial_demotable_stmt true acc body ]
| Decl {decl_type= Type.Sized st; decl_id; _}
when SizedType.is_complex_type st ->
Set.add acc decl_id
| Decl {decl_type= Type.Sized st; decl_id; initialize; _} ->
let complex_name =
match SizedType.is_complex_type st with
| true -> Set.Poly.singleton decl_id
| false -> Set.Poly.empty in
let init_names =
match initialize with
| Assign e -> query_expr acc e
| _ -> Set.Poly.empty in
Set.union acc (Set.union complex_name init_names)
| Skip | Break | Continue | Decl _ -> acc

(** Look through a statement to see whether the objects used in it need to be
Expand Down Expand Up @@ -419,6 +426,13 @@
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
| true -> Set.add all_rhs_eigen_names assign_name
| false -> Set.Poly.empty)
| Decl {decl_id; initialize= Assign e; _} -> (
let all_rhs_eigen_names = query_var_eigen_names e in
if Set.mem aos_exits decl_id then Set.add all_rhs_eigen_names decl_id
else
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
| true -> Set.add all_rhs_eigen_names decl_id

Check warning on line 434 in src/analysis_and_optimization/Memory_patterns.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Memory_patterns.ml#L434

Added line #L434 was not covered by tests
| false -> Set.Poly.empty)
(* All other statements do not need logic here*)
| _ -> Set.Poly.empty

Expand Down Expand Up @@ -453,7 +467,7 @@
(Fun_kind.StanLib (name, sfx, Mem_pattern.AoS), exprs')
else
( Fun_kind.StanLib (name, sfx, SoA)
, List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs
, List.map ~f:(modify_expr ~force_demotion:false modifiable_set) exprs
)
| UserDefined _ as udf ->
(udf, List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs)
Expand Down Expand Up @@ -543,12 +557,56 @@
let mod_stmt stmt = modify_stmt stmt modifiable_set in
match pattern with
| Stmt.Fixed.Pattern.Decl
({decl_id; decl_type= Type.Sized sized_type; _} as decl) ->
{ decl_id
; decl_adtype
; decl_type= Type.Sized sized_type
; initialize=
Assign
({ pattern= FunApp (CompilerInternal (FnReadParam read_param), args)
; _ } as assigner) } ->
let name = decl_id in
if Set.mem modifiable_set name then
Stmt.Fixed.Pattern.Decl
{ decl_id
; decl_adtype
; decl_type=
Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type)
; initialize=
Assign
{ assigner with
pattern=
FunApp
( CompilerInternal
(FnReadParam {read_param with mem_pattern= AoS})
, List.map ~f:(mod_expr true) args ) } }
else
Stmt.Fixed.Pattern.Decl
{ decl_id
; decl_adtype
; decl_type=
Type.Sized (SizedType.modify_sizedtype_mem SoA sized_type)
; initialize=
Assign
{ assigner with
pattern=
FunApp
( CompilerInternal
(FnReadParam {read_param with mem_pattern= SoA})
, List.map ~f:(mod_expr false) args ) } }
| Stmt.Fixed.Pattern.Decl
({decl_id; decl_type= Type.Sized sized_type; initialize; _} as decl) ->
if Set.mem modifiable_set decl_id then
let init_expr =
match initialize with
| Stmt.Fixed.Pattern.Assign e ->
Stmt.Fixed.Pattern.Assign (mod_expr false e)

Check warning on line 602 in src/analysis_and_optimization/Memory_patterns.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Memory_patterns.ml#L601-L602

Added lines #L601 - L602 were not covered by tests
| Default -> Default
| Uninit -> Uninit in
Stmt.Fixed.Pattern.Decl
{ decl with
decl_type=
Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type) }
Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type)
; initialize= init_expr }
else
Decl
{ decl with
Expand Down
10 changes: 8 additions & 2 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
[free_vars_expr e1; free_vars_expr e2; free_vars_stmt b.pattern]
| Profile (_, l) | Block l | SList l ->
Set.Poly.union_list (List.map ~f:(fun s -> free_vars_stmt s.pattern) l)
| Decl {initialize= Assign e; _} -> free_vars_expr e

Check warning on line 76 in src/analysis_and_optimization/Monotone_framework.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Monotone_framework.ml#L76

Added line #L76 was not covered by tests
| Decl _ | Break | Continue | Return None | Skip -> Set.Poly.empty

(** A variation on free_vars_stmt, where we do not recursively count free
Expand All @@ -81,6 +82,7 @@
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
(s : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) =
match s with
| Decl {initialize= Assign e; _} -> free_vars_expr e
| Assignment _ | Return _ | TargetPE _ | JacobianPE _ | NRFunApp _ | Decl _
|Break | Continue | Skip ->
free_vars_stmt
Expand Down Expand Up @@ -622,7 +624,9 @@
let rec used_expressions_stmt_help f
(s : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) =
match s with
| TargetPE e | JacobianPE e | Return (Some e) -> f e
| TargetPE e | JacobianPE e | Return (Some e) | Decl {initialize= Assign e; _}

Check warning on line 627 in src/analysis_and_optimization/Monotone_framework.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Monotone_framework.ml#L627

Added line #L627 was not covered by tests
->
f e
| Assignment (l, _, e) -> Set.union (f e) (used_expressions_lval f l)
| IfElse (e, b1, Some b2) ->
Expr.Typed.Set.union_list
Expand Down Expand Up @@ -656,7 +660,9 @@
let top_used_expressions_stmt_help f
(s : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) =
match s with
| TargetPE e | JacobianPE e | Return (Some e) -> f e
| TargetPE e | JacobianPE e | Return (Some e) | Decl {initialize= Assign e; _}

Check warning on line 663 in src/analysis_and_optimization/Monotone_framework.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Monotone_framework.ml#L663

Added line #L663 was not covered by tests
->
f e
| Assignment (l, _, e) -> Set.union (f e) (used_expressions_lval f l)
| While (e, _) | IfElse (e, _, _) -> f e
| NRFunApp (k, l) ->
Expand Down
17 changes: 13 additions & 4 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
{ decl_adtype= DataOnly
; decl_id= returned
; decl_type= Sized SInt
; initialize= true }
; initialize= Default }
; meta= Location_span.empty }
; Stmt.Fixed.
{ pattern=
Expand Down Expand Up @@ -294,7 +294,7 @@
(Type.to_unsized decl_type)
; decl_id= inline_return_name
; decl_type
; initialize= false } ]
; initialize= Uninit } ]
(* We should minimize the code that's having its variables
replaced to avoid conflict with the (two) new dummy
variables introduced by inlining *)
Expand Down Expand Up @@ -464,6 +464,10 @@
Block (List.map l ~f:(inline_function_statement propto adt fim))
| SList l ->
SList (List.map l ~f:(inline_function_statement propto adt fim))
| Decl {decl_adtype; decl_id; decl_type; initialize= Assign expr} ->
let d, s, e = inline_function_expression propto adt fim expr in
slist_concat_no_loc (d @ s)
(Decl {decl_adtype; decl_id; decl_type; initialize= Assign e})
| Decl r -> Decl r
| Skip -> Skip
| Break -> Break
Expand Down Expand Up @@ -748,6 +752,9 @@
(Middle.Stmt.Helpers.lhs_indices lhs)
then stmt
else Skip
| Decl {decl_id; initialize= Assign e; _} ->
if Set.mem live_variables_s decl_id || cannot_remove_expr e then stmt
else Skip

Check warning on line 757 in src/analysis_and_optimization/Optimize.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Optimize.ml#L756-L757

Added lines #L756 - L757 were not covered by tests
(* NOTE: we never get rid of declarations as we might not be able to
remove an assignment to a variable
due to side effects. *)
Expand Down Expand Up @@ -835,7 +842,8 @@
| Some ([] | [Index.All] | [Index.All; Index.All]) ->
{ stmt with
pattern=
Stmt.Fixed.Pattern.Decl {decl_pat with initialize= false} }
Stmt.Fixed.Pattern.Decl {decl_pat with initialize= Uninit}
}
| None | Some _ -> stmt)
| None -> stmt)
| Block block_lst ->
Expand Down Expand Up @@ -972,7 +980,7 @@
{ decl_adtype= Expr.Typed.adlevel_of key
; decl_id= data
; decl_type= Type.Unsized (Expr.Typed.type_of key)
; initialize= true }
; initialize= Default }
; meta= Location_span.empty }
:: accum) in
let lazy_code_motion_base i stmt =
Expand Down Expand Up @@ -1003,6 +1011,7 @@
let f stmt =
match stmt with
| Stmt.Fixed.Pattern.Assignment ((LVariable x, []), _, e')
|Decl {decl_id= x; initialize= Assign e'; _}
when Map.mem m e'
&& Expr.Typed.equal {e' with pattern= Var x}
(Map.find_exn m e') ->
Expand Down
9 changes: 5 additions & 4 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,8 @@ let create_decl_with_assign decl_id declc decl_type initial_value transform
() } in
let decl =
Stmt.
{ Fixed.pattern= Decl {decl_adtype; decl_id; decl_type; initialize= true}
{ Fixed.pattern=
Decl {decl_adtype; decl_id; decl_type; initialize= Default}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should use the logic on the next line which inspects the rhs_assignment to set Initialize

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this code it's not modifying the params. Making changes in places that do no modify params is out of the scope of this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good -- I think we would really want to fix #1295 before doing this -- probably by writing a generic-ish check_size function in Stan which we can call after the assignment is completed.

; meta= smeta } in
let rhs_assignment =
Option.map
Expand Down Expand Up @@ -583,7 +584,7 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
{ decl_adtype= Expr.Typed.adlevel_of iteratee'
; decl_id= loopvar.name
; decl_type= Unsized decl_type
; initialize= true } } in
; initialize= Default } } in
let assignment var =
Stmt.Fixed.
{ pattern=
Expand Down Expand Up @@ -629,7 +630,7 @@ and trans_packed_assign loc trans_stmt lvals rhs assign_op =
{ decl_adtype= rhs.emeta.ad_level
; decl_id= sym
; decl_type= Unsized rhs_type
; initialize= false }
; initialize= Uninit }
; meta= rhs.emeta.loc } in
let assign =
{ temp with
Expand Down Expand Up @@ -743,7 +744,7 @@ let rec trans_sizedtype_decl declc tr name st =
{ decl_type= Sized SInt
; decl_id
; decl_adtype= DataOnly
; initialize= true }
; initialize= Default }
; meta= e.meta.loc } in
let assign =
{ Stmt.Fixed.pattern=
Expand Down
18 changes: 13 additions & 5 deletions src/middle/Stmt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
{ decl_adtype: UnsizedType.autodifftype
; decl_id: string
; decl_type: 'a Type.t
; initialize: bool }
; initialize: 'a decl_init }

Check warning on line 28 in src/middle/Stmt.ml

View check run for this annotation

Codecov / codecov/patch

src/middle/Stmt.ml#L28

Added line #L28 was not covered by tests
[@@deriving sexp, hash, map, fold, compare]

and 'e lvalue = 'e lbase * 'e Index.t list
Expand All @@ -34,6 +34,9 @@
and 'e lbase = LVariable of string | LTupleProjection of 'e lvalue * int
[@@deriving sexp, hash, map, compare, fold]

and 'a decl_init = Uninit | Default | Assign of 'a

Check warning on line 37 in src/middle/Stmt.ml

View check run for this annotation

Codecov / codecov/patch

src/middle/Stmt.ml#L37

Added line #L37 was not covered by tests
[@@deriving sexp, hash, map, fold, compare]

let rec pp_lvalue pp_e ppf (lbase, idcs) =
match lbase with
| LVariable v -> Fmt.pf ppf "%s%a" v (Index.pp_indices pp_e) idcs
Expand Down Expand Up @@ -70,9 +73,14 @@
| Block stmts ->
Fmt.pf ppf "{@;<1 2>@[<v>%a@]@;}" Fmt.(list pp_s ~sep:cut) stmts
| SList stmts -> Fmt.(list pp_s ~sep:cut |> vbox) ppf stmts
| Decl {decl_adtype; decl_id; decl_type; _} ->
Fmt.pf ppf "@[<hov 2>%a%a@ %s;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id
| Decl {decl_adtype; decl_id; decl_type; initialize} -> (
match initialize with
| Assign e ->

Check warning on line 78 in src/middle/Stmt.ml

View check run for this annotation

Codecov / codecov/patch

src/middle/Stmt.ml#L78

Added line #L78 was not covered by tests
Fmt.pf ppf "@[<hov 2>%a%a@ %s = %a;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id pp_e e

Check warning on line 80 in src/middle/Stmt.ml

View check run for this annotation

Codecov / codecov/patch

src/middle/Stmt.ml#L80

Added line #L80 was not covered by tests
| Uninit | Default ->
Fmt.pf ppf "@[<hov 2>%a%a@ %s;@]" UnsizedType.pp_autodifftype
decl_adtype (Type.pp pp_e) decl_type decl_id)

include Foldable.Make2 (struct
type nonrec ('a, 'b) t = ('a, 'b) t
Expand Down Expand Up @@ -143,7 +151,7 @@
{ decl_adtype= Expr.Typed.adlevel_of e
; decl_id= sym
; decl_type= Unsized (Expr.Typed.type_of e)
; initialize= true }
; initialize= Default }
; meta= e.meta.loc } in
let assign =
{ decl with
Expand Down
5 changes: 4 additions & 1 deletion src/middle/Stmt.mli
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ module Fixed : sig
{ decl_adtype: UnsizedType.autodifftype
; decl_id: string
; decl_type: 'a Type.t
; initialize: bool }
; initialize: 'a decl_init }
[@@deriving sexp, hash, compare]

and 'e lvalue = 'e lbase * 'e Index.t list
Expand All @@ -32,6 +32,9 @@ module Fixed : sig
and 'e lbase = LVariable of string | LTupleProjection of 'e lvalue * int
[@@deriving sexp, hash, map, compare, fold]

and 'a decl_init = Uninit | Default | Assign of 'a
[@@deriving sexp, hash, map, fold, compare]

include Pattern.S2 with type ('a, 'b) t := ('a, 'b) t
end

Expand Down
36 changes: 26 additions & 10 deletions src/stan_math_backend/Lower_stmt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,12 @@ let rec initialize_value st adtype =
(adtype : UnsizedType.autodifftype)]

(*Initialize an object of a given size.*)
let lower_assign_sized st adtype initialize =
if initialize then Some (initialize_value st adtype) else None
let lower_assign_sized st adtype (initialize : 'a Stmt.Fixed.Pattern.decl_init)
=
match initialize with
| Assign e -> Some (lower_expr e)
| Default -> Some (initialize_value st adtype)
| Uninit -> None

let lower_unsized_decl name ut adtype =
let type_ =
Expand All @@ -103,17 +107,24 @@ let lower_unsized_decl name ut adtype =
| true, _ -> TypeLiteral "matrix_cl<double>" in
make_variable_defn ~type_ ~name ()

let lower_possibly_opencl_decl name st adtype =
let lower_possibly_opencl_decl name st adtype
(initialize : 'a Stmt.Fixed.Pattern.decl_init) =
let ut = SizedType.to_unsized st in
let mem_pattern = SizedType.get_mem_pattern st in
match (Transform_Mir.is_opencl_var name, ut) with
| _, UnsizedType.(UInt | UReal) | false, _ ->
lower_possibly_var_decl adtype ut mem_pattern
| _, UnsizedType.(UInt | UReal) | false, _ -> (
match initialize with
| Assign
Expr.Fixed.
{ pattern= FunApp (CompilerInternal (Internal_fun.FnReadParam _), _)
; _ } ->
Auto
| _ -> lower_possibly_var_decl adtype ut mem_pattern)
| true, UArray UInt -> TypeLiteral "matrix_cl<int>"
| true, _ -> TypeLiteral "matrix_cl<double>"

let lower_sized_decl name st adtype initialize =
let type_ = lower_possibly_opencl_decl name st adtype in
let type_ = lower_possibly_opencl_decl name st adtype initialize in
let init =
lower_assign_sized st adtype initialize
|> Option.value_map ~default:Uninitialized ~f:(fun i -> Assignment i) in
Expand Down Expand Up @@ -320,7 +331,7 @@ let rec lower_statement Stmt.Fixed.{pattern; meta} : stmt list =
| Return e -> [Return (Option.map ~f:lower_expr e)]
| Block ls -> [Stmts.block (lower_statements ls)]
| SList ls -> lower_statements ls
| Decl {decl_adtype; decl_id; decl_type; initialize; _} ->
| Decl {decl_adtype; decl_id; decl_type; initialize} ->
[lower_decl decl_id decl_type decl_adtype initialize]
| Profile (name, ls) -> [lower_profile name (lower_statements ls)]

Expand All @@ -333,17 +344,22 @@ module Testing = struct
(Fmt.option Cpp.Printing.pp_expr)
(lower_assign_sized
(SArray (SArray (SMatrix (AoS, int 2, int 3), int 4), int 5))
DataOnly false)
DataOnly Stmt.Fixed.Pattern.Default)
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
|> print_endline;
[%expect {| |}]
[%expect
{|
std::vector<std::vector<Eigen::Matrix<double,-1,-1>>>(5,
std::vector<Eigen::Matrix<double,-1,-1>>(4,
Eigen::Matrix<double,-1,-1>::Constant(2, 3,
std::numeric_limits<double>::quiet_NaN()))) |}]

let%expect_test "set size mat array" =
let int = Expr.Helpers.int in
Fmt.str "@[<v>%a@]"
(Fmt.option Cpp.Printing.pp_expr)
(lower_assign_sized
(SArray (SArray (SMatrix (AoS, int 2, int 3), int 4), int 5))
DataOnly true)
DataOnly Stmt.Fixed.Pattern.Default)
|> print_endline;
[%expect
{|
Expand Down
Loading