How to handle a log prob for reverse mode? #903
Thinking about this more, say for example we were going to write a matrix multiply of two autodiffable matrices.

```stan
parameters {
  matrix[N, N] A;
  matrix[N, N] B;
}
transformed parameters {
  matrix[N, N] C = A * B;
}
```

The C++ we want to generate is something of the form

```cpp
// place the matrices on the autodiff stack
auto arena_A = stan::math::to_arena(A);
auto arena_B = stan::math::to_arena(B);
// Do the forward pass, promoting the value matrix multiply to var
auto C = to_arena(promote_to_var(multiply(value_of(arena_A), value_of(arena_B))));
// place a reverse pass callback on the callback stack,
// passing in the objects we need for the adjoint calcs
reverse_pass_callback([arena_A, arena_B, C]() mutable {
  // Do the adjoint accumulation for A and B
  adjoint_of(arena_A) += multiply(adjoint_of(C), transpose(value_of(arena_B)));
  adjoint_of(arena_B) += multiply(transpose(value_of(arena_A)), adjoint_of(C));
});
```

In a much more abstract form this can be written like

```
// place the matrices on the autodiff stack
A_Arena = StackAllocStmt A;
B_Arena = StackAllocStmt B;
// Do the forward pass, promoting the value matrix multiply to var
Return = ForwardPass (A_Arena, B_Arena);
// place a reverse pass callback on the callback stack,
// passing in the objects we need for the adjoint calcs
ReverseCallBack(
  [A_Arena, B_Arena, Return],
  Input1Adjoint,
  Input2Adjoint
)
```

I think for the functions themselves we would have a type

```ocaml
(**
 * Expresses base operations we will need later.
 * We could probably do something clever to generate these from
 * the Stan Math signatures hash table.
 *)
type ('lhs, 'rhs) Multiply = UFun (StanLib ("multiply", ...), ['lhs, 'rhs])
type ('t) Transpose = UFun (StanLib ("transpose", ...), ['t])
(**
 * For tagging whether we need the value or the adjoint of an input
 *)
type ReversePair = Value | Adjoint
(**
 * Tags we can use later to deduce which input is needed in the adjoint calculation
 *)
type AdjointArgs =
  | ReturnAdj of ReversePair
  | FirstArg of ReversePair
  | SecondArg of ReversePair
  | ThirdArg of ReversePair
(* The actual adjoint function; I think it can just be like this, where it's essentially a tag *)
type ('expr) adjoint_fun = 'expr
(**
 * Function comprising the forward and reverse pass.
 * The type is composed of
 *  - Name of the function
 *  - Return type
 *  - List of input argument types
 *  - List of functions for each adjoint calculation
 *)
type reverse_mode_function =
  string * UnsizedType.returntype * UnsizedType.t list * adjoint_fun list
```

One thing to note in the above, if the focus is just on functions implemented

With something like multiply we would have

```ocaml
type ('lhs_expr, 'rhs_expr) ReverseMultiply =
  UFun (reverse_mode_function ("multiply",
    UnsizedType.UMatrix,
    [UMatrix, UMatrix],
    [adjoint_fun (Multiply (ReturnAdj Adjoint, Transpose (SecondArg Value))),
     adjoint_fun (Multiply (Transpose (FirstArg Value), ReturnAdj Adjoint))]),
  ['lhs_expr, 'rhs_expr])
```

And with that I think we can do each of the steps for generating the C++ we want.

```cpp
auto lhs_arena = to_arena(lhs);
auto rhs_arena = to_arena(rhs);
auto C = to_arena(promote_to_var(multiply(value_of(lhs_arena), value_of(rhs_arena))));
reverse_pass_callback([lhs_arena, rhs_arena, C]() mutable {
  // Do the adjoint accumulation for lhs and rhs
  adjoint_of(lhs_arena) += multiply(adjoint_of(C), transpose(value_of(rhs_arena)));
  adjoint_of(rhs_arena) += multiply(transpose(value_of(lhs_arena)), adjoint_of(C));
});
```

We can also look at the

One thing that's nice about this is that it can be done in pieces. For instance, if we have multiple reverse mode functions next to each other, we can put their forward and reverse passes together so we only call one reverse pass callback and calculate the adjoints for multiple functions at once.

One glaring hole I've thought of so far is how to handle temporaries, i.e. when a function call has temporaries in it, what do we name the thing assigned to the arena? For example:

```stan
matrix[N, M] SomeObj = multiply(add(X, Y), Z);
```

I haven't totally thought that through yet. We could just make up some hashes to name the temporaries and pull them out:

```stan
matrix[N, M] hash_tmp = add(X, Y);
matrix[N, M] SomeObj = multiply(hash_tmp, Z);
```

Then do the stuff above to make the reverse mode passes for each statement.
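The temporary-hoisting idea (pulling nested calls out into named temporaries so each statement is a single function call) can be sketched as a small recursive pass over an expression tree. This is a toy under assumptions: `Expr`, `ref`, `call`, and `Hoister` are hypothetical names, statements are plain strings, and temporaries are named `tmp_1`, `tmp_2`, and so on from a counter rather than hashes; any scheme that yields unique names would work.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical expression node: a named variable (op empty) or a call op(args...).
struct Expr {
  std::string op;
  std::string name;
  std::vector<Expr> args;
};

Expr ref(const std::string& name) { return Expr{"", name, {}}; }
Expr call(const std::string& op, std::vector<Expr> args) {
  return Expr{op, "", std::move(args)};
}

// Pulls every nested call out into its own named temporary, so each emitted
// statement is a single function call on named arguments.
class Hoister {
 public:
  std::vector<std::string> stmts;

  // Emits `lhs = <e>;`, emitting temporaries for nested calls first
  // (left to right), so evaluation order is preserved.
  void emit(const std::string& lhs, const Expr& e) {
    std::string rhs = e.name;
    if (!e.op.empty()) {
      rhs = e.op + "(";
      for (std::size_t i = 0; i < e.args.size(); ++i) {
        if (i > 0) rhs += ", ";
        rhs += flatten(e.args[i]);
      }
      rhs += ")";
    }
    stmts.push_back(lhs + " = " + rhs + ";");
  }

 private:
  int counter_ = 0;

  // Returns a plain name for e, hoisting it into a new temporary if it is a call.
  std::string flatten(const Expr& e) {
    if (e.op.empty()) return e.name;
    std::string tmp = "tmp_" + std::to_string(++counter_);
    emit(tmp, e);
    return tmp;
  }
};
```

For `SomeObj = multiply(add(X, Y), Z)`, calling `h.emit("SomeObj", call("multiply", {call("add", {ref("X"), ref("Y")}), ref("Z")}))` yields `tmp_1 = add(X, Y);` followed by `SomeObj = multiply(tmp_1, Z);`, and each of those statements can then get its own forward pass and reverse pass callback.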
Is there a way to make something like this efficient?
It's similar to what I did in my C++ example code on the forums and in the AD Handbook. I never figured out a way to get a type for the callback that made it efficient to store or access.
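One way to get a callback type that is cheap to store, sketched here as a self-contained C++ toy rather than a description of Stan Math's internals: wrap each lambda in a class template deriving from a small abstract base with a virtual `chain()`. The lambda's captures live inline in that object, the call inside `chain()` is statically known, and the reverse pass pays one virtual call per callback. I believe `reverse_pass_callback` already works roughly this way, but every name below (`CallbackBase`, `Callback`, `CallbackStack`, `VarMat`, `multiply_rev`) is hypothetical, and `std::unique_ptr`/`std::shared_ptr` stand in for arena allocation.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

// Abstract base: the stack only needs to know "this can run in the reverse pass".
struct CallbackBase {
  virtual ~CallbackBase() = default;
  virtual void chain() = 0;
};

// One concrete subclass per lambda type. Captures are stored inline and
// the call to f() is statically known; only chain() itself is virtual.
template <typename F>
struct Callback final : CallbackBase {
  F f;
  explicit Callback(F&& fn) : f(std::move(fn)) {}
  void chain() override { f(); }
};

// Hypothetical stand-in for the reverse pass callback stack
// (a real implementation would allocate the callbacks in an arena).
class CallbackStack {
 public:
  template <typename F>
  void push(F&& fn) {
    using Fn = std::decay_t<F>;
    stack_.push_back(std::make_unique<Callback<Fn>>(Fn(std::forward<F>(fn))));
  }
  // Reverse pass: run callbacks in the reverse order they were pushed.
  void grad() {
    for (auto it = stack_.rbegin(); it != stack_.rend(); ++it) (*it)->chain();
  }

 private:
  std::vector<std::unique_ptr<CallbackBase>> stack_;
};

// Hypothetical n x n matrix holding values and adjoints (row-major).
struct VarMat {
  std::size_t n;
  std::vector<double> val, adj;
  explicit VarMat(std::size_t n_) : n(n_), val(n_ * n_, 0.0), adj(n_ * n_, 0.0) {}
  double& v(std::size_t i, std::size_t j) { return val[i * n + j]; }
  double& a(std::size_t i, std::size_t j) { return adj[i * n + j]; }
};

// Forward pass C = A * B on the values, then push one callback doing
//   adj(A) += adj(C) * B^T   and   adj(B) += A^T * adj(C).
std::shared_ptr<VarMat> multiply_rev(CallbackStack& stack,
                                     std::shared_ptr<VarMat> A,
                                     std::shared_ptr<VarMat> B) {
  const std::size_t n = A->n;
  auto C = std::make_shared<VarMat>(n);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t k = 0; k < n; ++k) C->v(i, j) += A->v(i, k) * B->v(k, j);
  stack.push([A, B, C, n]() {
    for (std::size_t i = 0; i < n; ++i)
      for (std::size_t j = 0; j < n; ++j)
        for (std::size_t k = 0; k < n; ++k) {
          A->a(i, k) += C->a(i, j) * B->v(k, j);  // (adj(C) * B^T)(i, k)
          B->a(k, j) += A->v(i, k) * C->a(i, j);  // (A^T * adj(C))(k, j)
        }
  });
  return C;
}
```

As a check against the adjoint rules in the first example: with A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]] and adj(C) set to all ones, the reverse pass gives adj(A) = [[11, 15], [11, 15]] and adj(B) = [[4, 4], [6, 6]].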
I'd like to make a separate log prob for reverse mode autodiff, as I think several optimizations we've discussed for the future (the new var matrix type, reverse mode in the compiler, etc.) would be a lot easier to work with if we had one.

Working on the new matrix type for reverse mode autodiff, I find that I'm doing a lot of weird things to handle the fact that `log_prob` can take in types with scalars of `double`, `var`, `fvar<double>`, `fvar<var>`, etc. It would be nice to have a separate `log_prob` in the program, or somewhere, where we modify the `UnsizedType.autodifftype` definition to include a `RevAutoDiff` case, and then for the new reverse mode log prob we just read the old one and promote any `AutoDiffable` to a `RevAutoDiff`.

But I'm not sure where we should put this or when we should create it. My first thought is to put it right after the optimization pass on the transformed MIR. I want to make the new reverse mode log prob after the MIR optimizations / transforms so that if we ever do something in the future where we move statements across blocks (like moving data from `log_prob` or `generated_quantities` to `prepare_data`), we won't have to deal with any weirdness from having two `log_prob`s. Then, if we want to do any additional optimization passes after generating the new log prob with var types, we can do them there.

Any thoughts on this? If folks think this is an okay idea, I might go for it today.
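The promotion step could look something like the following C++ sketch. It is only an illustration under assumptions: stanc3 itself is written in OCaml, and `AutoDiffType`, `Decl`, and `make_rev_log_prob` are hypothetical, heavily simplified stand-ins for `UnsizedType.autodifftype` and the MIR `log_prob` statements. The pass copies the already optimized `log_prob` and promotes every `AutoDiffable` declaration to `RevAutoDiff`, leaving `DataOnly` ones alone.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical mirror of UnsizedType.autodifftype with the proposed extra case.
enum class AutoDiffType { DataOnly, AutoDiffable, RevAutoDiff };

// Hypothetical, heavily simplified declaration inside log_prob.
struct Decl {
  std::string name;
  AutoDiffType ad;
};

// Builds the reverse mode log_prob from the (already optimized) generic one:
// copy every declaration and promote AutoDiffable to RevAutoDiff.
// The original log_prob is left unchanged.
std::vector<Decl> make_rev_log_prob(const std::vector<Decl>& log_prob) {
  std::vector<Decl> rev = log_prob;
  for (auto& d : rev) {
    if (d.ad == AutoDiffType::AutoDiffable) d.ad = AutoDiffType::RevAutoDiff;
  }
  return rev;
}
```

Because the copy is taken from the optimizer's output, any statement motion between blocks has already happened once, and any further reverse-mode-specific passes can run on the copy alone.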