-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Replace Zygote adjoints with ChainRules' rrules #153
Conversation
src/chainrules.jl
Outdated
function ChainRulesCore.rrule(::typeof(ColVecs), X::AbstractMatrix) | ||
return ColVecs(X), vecs_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(RowVecs), X::AbstractMatrix) | ||
return RowVecs(X), vecs_pullback | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@willtebbutt How do we go about defining rrule
s for constructors?
For example, here typeof(ColVecs)
results in UnionAll
which is probably not what we want.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it make sense to replace typeof(ColVecs)
with Type{ColVecs}
?
Maybe JuliaDiff/ChainRulesCore.jl#150 could simplify some definitions? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking good. Just needs to use the tests in ChainRulesCore
-- you might need to define to_vec
for a few things, but that's about it.
throw(error("In slow method")) | ||
end | ||
|
||
function ChainRulesCore.rrule(::Type{ColVecs}, X::AbstractMatrix) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the right thing to do here would be ::Type{<:ColVecs}
, would it not?
return ColVecs(X), vecs_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::Type{RowVecs}, X::AbstractMatrix) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here?
Replaced by #208 |
#116 (comment)
Currently a lot of AD tests fail. Still need to work on revamping the AD tests to use ChainRulesTestUtils.jl.
TODO
rrule_test
for each of therrules
. Plain usage of them results in errors withFiniteDifferences
. Figure of if this is due to the incorrect rrules.test/utils.jl
seems pass with minimal imports but fail other wise. Figure out why.