Skip to content

Commit

Permalink
unthunk input to Tridiagonal_pullback
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed May 20, 2024
1 parent dfbd363 commit 938623b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,9 @@ end

function rrule(::Type{Tridiagonal}, dl, d, du)
y = Tridiagonal(dl, d, du)
@views function ∇Tridiagonal(∂y)
@views function Tridiagonal_pullback(ȳ)
∂y = unthunk(ȳ)
return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1))
end
return y, ∇Tridiagonal
return y, Tridiagonal_pullback
end

0 comments on commit 938623b

Please sign in to comment.