Skip to content

Commit

Permalink
reverse derivative of ArrayType.get
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 19, 2024
1 parent b95691a commit cab712a
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions SciLean/Data/ArrayType/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,17 @@ theorem ArrayType.get.arg_cont.adjoint_rule (i : Idx) :
=
fun e : Elem => oneHot (i,()) e := by sorry_proof

@[fun_trans]
theorem ArrayType.get.arg_cont.revFDeriv_rule (i : Idx)
(cont : W → Cont) (hf : Differentiable K cont) :
revFDeriv K (fun w => ArrayType.get (cont w) i)
=
fun w : W =>
let xi := revFDeriv K cont w
(ArrayType.get xi.1 i, fun (de : Elem) =>
xi.2 (oneHot (i,()) de)) := by
unfold revFDeriv; fun_trans

@[fun_trans]
theorem ArrayType.set.arg_cont.adjoint_rule (i : Idx) :
adjoint K (fun c : Cont => ArrayType.set c i 0)
Expand Down

0 comments on commit cab712a

Please sign in to comment.