Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite: removing dimensionality and allow non-bijective transformations #183

Closed
wants to merge 48 commits into from

Conversation

torfjelde
Copy link
Member

This is a draft for a proper overhaul of Bijectors.jl

Goals with this PR are:

  1. Remove the annoying dimensionality by being explict about when we're working with a collection of inputs vs. single input.
    • Also removes the restriction that input and output of bijectors should be the same size, type, etc.
  2. Allow more than just bijectors (differentiable bijections with differentiable inverses).
  3. Resolve some simpler but outstanding issues:
  4. We have a lot of hacks in this package to make AD work across the board. Some of these vanish due to (1), but I'm also taking this opportunity to try to check what we actually need and what we can get rid of due to improvements upstream.
  5. Add support for mutating methods, e.g. transform!, logabsdetjac!.
  6. Add a bunch of sane default implementations, e.g. mutating versions, batched versions, etc.

Also, this is likely to result end up including a lot of changes, so we might end up splitting this into multiple PRs once it becomes more than a draft. But for now it's all here.

torfjelde added 30 commits June 5, 2021 06:46
@yebai
Copy link
Member

yebai commented Feb 3, 2022

Let's perhaps make a push to finish the work here. @torfjelde

@@ -0,0 +1,144 @@
# Bijectors.jl
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@devmotion I'm curious what you think of this design:) You can ignore the code for now, the main thing is just whether or not you're happy with the idea of doing

forward(b, x) = forward_single(b, x)

# Non-batched version
forward_single(b, x) = ...

# Batched version
forward_multiple(b, x) = ...

We can then introduce a Batch type which generalizes ColVecs etc., and do

forward(b, x::AbstractBatch) = forward_multiple(b, x)

My plan is to split this into two PRs:

  1. Remove all "official" support for batched computation, thus always assuming that the input given represents a single input (some bijectors might still support it, but there's not going to be a "official" support for it). In this PR there is no forward_single, etc., just forward.
  2. Add support for batching and overload forward_single instead of forward.

But ideally we'd fully adopt the ChangesOfVariables.jl interface, i.e. replace forward with with_logabsdet_jacobian. The issue here is that we of course can no longer do

with_logabsdet_jacobian(b, x) = with_logabsdet_jacobian_single(b, x)

etc. which makes me want to keep forward or ChangesOfVariables decides to take on a similar interface and encourage people to instead implement with_logabsdet_jacobian_single.

Thoughts?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it will be a huge improvement if batching is declared more explicitly. Is there a specific reason why one needs forward_single and forward_multiple instead of just forward(::MyBijector, x) and forward(::MyBijector, x::AbstractBatch)? More functions means more entry points and hence more possible confusion for developers. If there's no dedicated _single function anymore, it would also work better with the CoV API, I assume?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is method ambiguity 😕 If only forward or with_logabsdet_jacobian is the entry-point, then we cannot provide sane defaults, e.g. forward(b, x::AbstractBatch) without every other implementation of forward being very explicit.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this is annoying - but there are multiple different options I think. One would be e.g. to implement (possibly different) default implementations but not as forward(...) but some forward_batch_style1(..) (whatever) and then to "activate" it by defining forward(b::MyBijector, x::AbstractBatch) = forward_batch_style1(b, x) etc. I.e., batch support would be enabled manually but without much effort. And one could reduce the amount of code even more with some helper macro.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that works of course, but I think this will be easy to forget vs. always just implemented *_single and batching is guaranteed to at least work.

Both approaches feel sub-optimal 😕

@yebai
Copy link
Member

yebai commented Feb 2, 2023

@torfjelde do we still need this after #214?

@torfjelde
Copy link
Member Author

No this can be closed:)

@yebai yebai closed this Feb 2, 2023
@yebai yebai deleted the tor/rewrite branch February 2, 2023 18:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants