Skip to content

Commit

Permalink
Merge pull request #1442 from stan-dev/feature/stochastic-row-col
Browse files Browse the repository at this point in the history
Add stochastic row and column matrix types
  • Loading branch information
WardBrian authored Aug 5, 2024
2 parents 4f0d26c + 4247ad4 commit 0ea83bb
Show file tree
Hide file tree
Showing 17 changed files with 1,441 additions and 307 deletions.
3 changes: 2 additions & 1 deletion src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ let trans_bounds_values (trans : Expr.Typed.t Transformation.t) : bound_values =
| Upper upper -> {lower= `None; upper= bound_value upper}
| LowerUpper (lower, upper) ->
{lower= bound_value lower; upper= bound_value upper}
| Simplex -> {lower= `Lit 0.; upper= `Lit 1.}
| Simplex | StochasticColumn | StochasticRow ->
{lower= `Lit 0.; upper= `Lit 1.}
| PositiveOrdered -> {lower= `Lit 0.; upper= `None}
| UnitVector -> {lower= `Lit (-1.); upper= `Lit 1.}
| CholeskyCorr | CholeskyCov | Correlation | Covariance | Ordered | Offset _
Expand Down
19 changes: 17 additions & 2 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ let check_transform_shape decl_id decl_var meta = function
same_shape decl_id decl_var "lower" e1 meta
@ same_shape decl_id decl_var "upper" e2 meta
| Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered
|PositiveOrdered | Simplex | UnitVector | Identity | TupleTransformation _ ->
|PositiveOrdered | Simplex | UnitVector | Identity | TupleTransformation _
|StochasticRow | StochasticColumn ->
[]

let copy_indices indexed (var : Expr.Typed.t) =
Expand All @@ -294,7 +295,8 @@ let extract_transform_args var = function
| LowerUpper (a1, a2) | OffsetMultiplier (a1, a2) ->
[copy_indices var a1; copy_indices var a2]
| Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered
|PositiveOrdered | Simplex | UnitVector | Identity | TupleTransformation _ ->
|PositiveOrdered | Simplex | UnitVector | Identity | TupleTransformation _
|StochasticRow | StochasticColumn ->
[]

let rec param_size transform sizedtype =
Expand All @@ -321,6 +323,17 @@ let rec param_size transform sizedtype =
let k_choose_2 k =
Expr.Helpers.(binop (binop k Times (binop k Minus (int 1))) Divide (int 2))
in
let rec stoch_size f1 f2 st =
match st with
| SizedType.SMatrix (mem_pattern, d1, d2) ->
SizedType.SMatrix (mem_pattern, f1 d1, f2 d2)
| SArray (t, d) -> SizedType.SArray (stoch_size f1 f2 t, d)
| SInt | SReal | SComplex | SRowVector _ | SVector _ | STuple _
|SComplexRowVector _ | SComplexVector _ | SComplexMatrix _ ->
Common.ICE.internal_compiler_error
[%message "Expecting SMatrix, got " (st : Expr.Typed.t SizedType.t)]
in
let min_one d = Expr.Helpers.(binop d Minus (int 1)) in
match transform with
| Transformation.Identity | Lower _ | Upper _
|LowerUpper (_, _)
Expand All @@ -339,6 +352,8 @@ let rec param_size transform sizedtype =
| Simplex ->
shrink_eigen (fun d -> Expr.Helpers.(binop d Minus (int 1))) sizedtype
| CholeskyCorr | Correlation -> shrink_eigen k_choose_2 sizedtype
| StochasticRow -> stoch_size Fn.id min_one sizedtype
| StochasticColumn -> stoch_size min_one Fn.id sizedtype
| CholeskyCov ->
(* (N * (N + 1)) / 2 + (M - N) * N *)
shrink_eigen_mat
Expand Down
6 changes: 5 additions & 1 deletion src/frontend/Pretty_printing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ let pp_bracketed_transform ppf = function
pf ppf "<@[offset=%a,@ multiplier=%a@]>" pp_expression e1 pp_expression e2
| Identity | Ordered | PositiveOrdered | Simplex | UnitVector | CholeskyCorr
|CholeskyCov | Correlation | Covariance | TupleTransformation _
(* tuple transformations are handled in pp_transformed_type *) ->
|StochasticColumn
|StochasticRow (* tuple transformations are handled in pp_transformed_type *)
->
()

let rec pp_transformed_type ppf (st, trans) =
Expand Down Expand Up @@ -362,6 +364,8 @@ let rec pp_transformed_type ppf (st, trans) =
| CholeskyCov -> pf ppf "cholesky_factor_cov%a" cov_sizes_fmt ()
| Correlation -> pf ppf "corr_matrix%a" cov_sizes_fmt ()
| Covariance -> pf ppf "cov_matrix%a" cov_sizes_fmt ()
| StochasticColumn -> pf ppf "column_stochastic_matrix%a" sizes_fmt ()
| StochasticRow -> pf ppf "row_stochastic_matrix%a" sizes_fmt ()
| TupleTransformation transforms ->
(* NB this calls the top-level function to handle internal arrays etc *)
let transTypes = Middle.Utils.zip_stuple_trans_exn st transforms in
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1651,6 +1651,8 @@ and check_transformation cf tenv ut trans =
| CholeskyCov -> CholeskyCov
| Correlation -> Correlation
| Covariance -> Covariance
| StochasticColumn -> StochasticColumn
| StochasticRow -> StochasticRow
| TupleTransformation tms ->
let typesTrans = Utils.zip_utuple_trans_exn ut tms in
let tes =
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/lexer.mll
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ rule token = parse
Parser.CHOLESKYFACTORCOV }
| "corr_matrix" { lexer_logger "corr_matrix" ; Parser.CORRMATRIX }
| "cov_matrix" { lexer_logger "cov_matrix" ; Parser.COVMATRIX }
| "column_stochastic_matrix"{ lexer_logger "column_stochastic_matrix" ; Parser.STOCHASTICCOLUMNMATRIX }
| "row_stochastic_matrix" { lexer_logger "row_stochastic_matrix" ; Parser.STOCHASTICROWMATRIX }
(* Transformation keywords *)
| "lower" { lexer_logger "lower" ; Parser.LOWER }
| "upper" { lexer_logger "upper" ; Parser.UPPER }
Expand Down
Loading

0 comments on commit 0ea83bb

Please sign in to comment.