Skip to content

Commit

Permalink
Merge pull request #1281 from stan-dev/fix/copy-prop-profile-blocks
Browse files Browse the repository at this point in the history
Consistently treat profile blocks as blocks in optimizer
  • Loading branch information
WardBrian authored Jan 27, 2023
2 parents ca2a7bf + ff18775 commit bd516d5
Show file tree
Hide file tree
Showing 7 changed files with 1,107 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/analysis_and_optimization/Dataflow_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ let build_cf_graphs ?(flatten_loops = false) ?(blocks_after_body = true)
is visited after substatements *)
let substmt_preds =
match stmt with
| Block _ when blocks_after_body -> in_state.exits
| (Block _ | Profile _) when blocks_after_body -> in_state.exits
| _ -> Set.Poly.singleton label in
(* The accumulated state after traversing substatements *)
let substmt_state_unlooped, substmt_map =
Expand Down Expand Up @@ -174,7 +174,7 @@ let build_cf_graphs ?(flatten_loops = false) ?(blocks_after_body = true)
Set.Poly.diff substmt_state_unlooped.breaks in_state.breaks ]
in
({substmt_state_unlooped with exits= loop_exits}, loop_predecessors)
| Block _ when blocks_after_body ->
| (Block _ | Profile _) when blocks_after_body ->
(* Block statements are preceded by the natural exit points of the block
body *)
let block_predecessors = substmt_state_unlooped.exits in
Expand Down
3 changes: 2 additions & 1 deletion src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ let rec var_declarations Stmt.Fixed.{pattern; _} : string Set.Poly.t =
| IfElse (_, s, None) | While (_, s) | For {body= s; _} -> var_declarations s
| IfElse (_, s1, Some s2) ->
Set.Poly.union (var_declarations s1) (var_declarations s2)
| Block slist | SList slist ->
| Block slist | SList slist | Profile (_, slist) ->
Set.Poly.union_list (List.map ~f:var_declarations slist)
| _ -> Set.Poly.empty

Expand Down Expand Up @@ -431,6 +431,7 @@ let cleanup_empty_stmts stmts =
let is_decl = function {pattern= Decl _; _} -> true | _ -> false in
let flatten_block s =
match s.pattern with
(* NB: Does not include Profile since we don't want to remove those blocks *)
| SList ls | Block ls ->
if List.for_all ~f:(Fn.non is_decl) ls then ls else [s]
| _ -> [s] in
Expand Down
6 changes: 4 additions & 2 deletions src/middle/Stmt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ module Fixed = struct
| For {loopvar; lower; upper; body} ->
Fmt.pf ppf "for(%s in %a:%a) %a" loopvar pp_e lower pp_e upper pp_s
body
| Profile (_, stmts) ->
Fmt.pf ppf "{@;<1 2>@[<v>%a@]@;}" Fmt.(list pp_s ~sep:cut) stmts
| Profile (name, stmts) ->
Fmt.pf ppf "profile(%s){@;<1 2>@[<v>%a@]@;}" name
Fmt.(list pp_s ~sep:cut)
stmts
| Block stmts ->
Fmt.pf ppf "{@;<1 2>@[<v>%a@]@;}" Fmt.(list pp_s ~sep:cut) stmts
| SList stmts -> Fmt.(list pp_s ~sep:cut |> vbox) ppf stmts
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
data {
int<lower=0> N;
}
parameters {
matrix[N,N] X;
}

model {
row_vector[N] vec;

profile ("test") {
row_vector[N] vec2 = columns_dot_self(X);
vec = vec2;
}

target += sum(vec);

row_vector[N] vec3;
{
row_vector[N] vec4 = columns_dot_self(X);
vec3 = vec4;
}

target += sum(vec3);

}
Loading

0 comments on commit bd516d5

Please sign in to comment.