-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Widening the scope of the package and dropping support for batching #214
Conversation
This is a fairly big one @devmotion , but I would greatly appreciate it if you had a super-quick look. The main part is just removing the dimensionality completely from the definition of the bijectors, in addition to a couple of small things:
I'm also going to make a separate PR to remove the stuff related to This will be a huge breaking release, as I think it's time we just rip the band-aid off. |
handling batches
…tors.jl into tor/write-without-batch
@yebai you might also want to take a look at this |
InvertibleBatchNorm | ||
Coupling, | ||
InvertibleBatchNorm, | ||
elementwise |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm... I wonder if we could reuse some existing functionality in the ecosystem here. And/or if there is a shorter name. Regarding the first point, e.g., Transducers.Map
seems similar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy to take suggestions, but I'm also okay with elementwise
, so IMO this shouldn't hold this PR back.
Not too big of a fan to depend use Transducers.Map
though; seems like unnecessary complexity just to make Base.Fix1(broadcast, f)
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I don't want to depend on Transducers either. Unfortunately, it seems we can't just define a curried version
Base.map(t::Transform) = ...
or
Base.broadcast(t::Transform) = ...
since we would like to use elementwise
also for functions such as exp
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since we would like to use elementwise also for functions such as exp?
Exactly 😕
src/Bijectors.jl
Outdated
Base.@deprecate NamedBijector(bs) NamedTransform(bs) | ||
|
||
@noinline function Base.inv(b::AbstractBijector) | ||
Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv) | ||
inverse(b) | ||
end | ||
Base.@deprecate Exp() elementwise(exp) false | ||
Base.@deprecate Log() elementwise(log) false |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are deprecations needed if it is a breaking release? Or would it be sufficient to add them to some changelog/announcement/NEWS.md?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed!
src/interface.jl
Outdated
|
||
Transform `x` using `b`, treating `x` as a single input. | ||
""" | ||
transform(f::Function, x) = f(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if it matters, but Julia won't specialize on f
here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah good catch!
""" | ||
isclosedform(b::Bijector)::bool | ||
isclosedform(b⁻¹::Inverse{<:Bijector})::bool | ||
logabsdetjac(b, x) = last(with_logabsdet_jacobian(b, x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this needed? Seems like something that - if desired - should maybe go to ChangesOfVariables.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I personally find this useful, and the last time we spoke about this (JuliaMath/ChangesOfVariables.jl#3), there didn't seem to be a desire to add it 😕
src/interface.jl
Outdated
# Useful for checking if compositions, etc. are invertible or not. | ||
Base.:+(::NotInvertible, ::Invertible) = NotInvertible() | ||
Base.:+(::Invertible, ::NotInvertible) = NotInvertible() | ||
Base.:+(::NotInvertible, ::NotInvertible) = NotInvertible() | ||
Base.:+(::Invertible, ::Invertible) = Invertible() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would recommend removing these definitions. Seems like a misuse of +
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
""" | ||
with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x)) | ||
inverse(t::Transform) = Inverse(t) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just remove this definition, and it should be sufficient to operate with InverseFunctions.NoInverse
instead of Invertible
/NoInvertible
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
Co-authored-by: David Widmann <[email protected]>
I think I've replied/addressed your comments now @devmotion Some I'm of the opinion that we leave until later PRs, given how long this PR has been in the pipeline + a lot of improvements we want to do (e.g. adding a |
Yeah, let's just make another breaking release if it becomes necessary. I think I checked yesterday and Bijectors only has ~9 direct dependents, so it's not too bad if we iterate and release multiple breaking versions if needed (of course, it would be better to have the optimal design right away but that's completely unrealistic 😄). |
Bueno! You "happy" with the current version of the PR then? |
Thanks, @torfjelde @devmotion -- it looks good to me. I agree that we can keep improving the design in new PRs. |
This PR makes DPPL compatible with the changes to come in TuringLang/Bijectors.jl#214. Tests are passing locally. Closes #455 Closes #456
…eep existing compat) (#469) * Fixed a typo in tutorial (#451) * CompatHelper: bump compat for Turing to 0.24 for package turing, (keep existing compat) (#450) This pull request changes the compat entry for the `Turing` package from `0.21` to `0.21, 0.24` for package turing. This keeps the compat entries for earlier versions. Note: I have not tested your package with this new compat entry. It is your responsibility to make sure that your package tests pass before you merge this pull request. Co-authored-by: Hong Ge <[email protected]> * Some minor utility improvements (#452) This PR does the following: - Moves the `varname_leaves` from `TestUtils` to main module. - It can be very useful in Turing.jl for constructing `Chains` and the like, so I think it's a good idea to make it part of the main module rather than keeping it "hidden" there. - Makes the default `varinfo` in the constructor of `LogDensityFunction` be `model.context` rather than a new `DynamicPPL.DefaultContext`. - The `context` pass to `evaluate!!` will override the leaf-context in `model.context`, and so the current default constructor always uses `DefaultContext` as the leaf-context, even if the `Model` has been `contextualize`d with some other leaf-context, e.g. `PriorContext`. This PR fixes this issue. * Always run CI (#453) I find the current `bors` workflow a bit tedious. Most of the time, I summon `bors` to see the CI results (see e.g. #438). Given that most `CI` tests are quick (< 10mins), we can always run them by default. The most time-consuming `IntegrationTests` is still run by `bors` to avoid excessive CI runs. * Compat with new Bijectors.jl (#454) This PR makes DPPL compatible with the changes to come in TuringLang/Bijectors.jl#214. Tests are passing locally. Closes #455 Closes #456 * Another Bijectors.jl compat bound bump (#457) * CompatHelper: bump compat for MCMCChains to 6 for package test, (keep existing compat) (#467) This pull request changes the compat entry for the `MCMCChains` package from `4.0.4, 5` to `4.0.4, 5, 6` for package test. This keeps the compat entries for earlier versions. Note: I have not tested your package with this new compat entry. It is your responsibility to make sure that your package tests pass before you merge this pull request. Co-authored-by: Hong Ge <[email protected]> * CompatHelper: bump compat for AbstractPPL to 0.6 for package test, (keep existing compat) --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: github-actions[bot] <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]>
This PR is an attempt at a couple of things:
Bijector
.TODOs:
logpdf_with_jac
,logpdf_forward
,forward(d::Distribution, ...)
?forward
no longer works for multiple samples because we no longer have support for batched inputs.