Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create manifold diffusion #39

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions examples/manifolds.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using Pkg;
Pkg.add(url="https://github.com/JuliaManifolds/ManifoldMeasures.jl")
Pkg.add(["MeasureBase", "Plots"])
using Manifolds, ManifoldMeasures, MeasureBase, Plots

function expectation(x_t, t; process = process, target_dist = target_dist, n_samples = 1000)
x0samples = [sampleforward(process, t, x_t) for i in 1:n_samples]
sampleweights = [mapslices(p -> target_pdf(target_dist, p), x0sample, dims = [1]) for x0sample in x0samples]
x_0_expectation = similar(x_t)
for pointindex in CartesianIndices(size(x_t)[2:end])
x_0_expectation[:, pointindex] = sum( (samp -> samp[:, pointindex]).(x0samples) .* (sampweights -> sampweights[pointindex]).(sampleweights) )
x_0_expectation[:, pointindex] = project(process.manifold, x_0_expectation[:, pointindex])
end
x_0_expectation
end

target_pdf(target_dist, point) = sum( MeasureBase.density_def(dist, point) * weight for (dist, weight) in zip(target_dist.dists, target_dist.weights) )

# N-dimensional sphere - 1 is a circle, 2 is a sphere, 3 is a quaternion, etc.
N = 2
manifold = Sphere(N)
# Distributions on the sphere
dists = [
ManifoldMeasures.VonMisesFisher(manifold, μ = project(manifold, [1.0, 1.0, 1.0]), κ = 70),
ManifoldMeasures.VonMisesFisher(manifold, μ = project(manifold, [-1.0, -1.0, -1.0]), κ = 70),
ManifoldMeasures.VonMisesFisher(manifold, μ = project(manifold, [0.99995, 9.99934e-5, 0.0]), κ = 1.5)
]
unnormalized_weights = [1, 1, 1]
target_dist = (dists = dists, weights = unnormalized_weights ./ sum(unnormalized_weights))

# Diffusion process
process = ManifoldBrownianDiffusion(manifold, 1.0)
d = (1, )
x_T = hcat(rand(uniform_distribution(manifold, zeros(N + 1)), d...)...)
timesteps = timeschedule(exp, log, 0.001, 20, 100)

@time diffusion_samples = samplebackward(expectation, process, timesteps, x_T)

function target_sample(target_dist)
r = rand()
for (dist, weight) in zip(target_dist.dists, target_dist.weights)
r -= weight
r < 0 && return rand(dist)
end
end

target_samples = hcat([target_sample(target_dist) for i in eachindex(diffusion_samples)]...)

coordvectors(samples) = [samples[i, :] for i in 1:size(samples)[1]]

pl_S1_diffusion_samples = plot(title = "Diffusion samples", coordvectors(diffusion_samples)..., size=(400, 400), st = :scatter, xlim = (-1.1, 1.1), ylim = (-1.1, 1.1),
alpha = 0.3, color = "blue")
pl_S1_target_samples = plot(title = "Target samples", coordvectors(target_samples)..., size=(400, 400), st = :scatter, xlim = (-1.1, 1.1), ylim = (-1.1, 1.1),
alpha = 0.3, color="red")
plot(pl_S1_diffusion_samples, pl_S1_target_samples, size = (800, 400))
76 changes: 76 additions & 0 deletions src/manifolds.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
If you want to use a manifold not in Manifolds.jl, you can define a new type that inherits from Manifold and implement the following functions:
- project(manifold, point)
- shortest_path_interpolation(rng, process, point_0, point_t, s, t)

x_t = Array(euclidean_dimension, batch_dims...)
"""

struct ManifoldBrownianDiffusion{M <: AbstractManifold, T <: Real} <: SamplingProcess
manifold::M
rate::T
getsteps::Function
end

ManifoldBrownianDiffusion(manifold::AbstractManifold, rate::T) where T <: Real = ManifoldBrownianDiffusion(manifold, rate, (t) -> range(min(t, 0.05), t, step=0.05))

pointindices(X) = CartesianIndices(size(X)[2:end])

function project!(x_to, x_from, manifold)
for pointindex in pointindices(x_from)
x_to[:, pointindex] = project(manifold, x_from[:, pointindex])
end
end

function sampleforward(rng::AbstractRNG, process::ManifoldBrownianDiffusion{M, T}, t::Real, x_0) where M where T
x_t = similar(x_0)
project!(x_t, x_0, process.manifold)
for step in process.getsteps(t * process.rate)
x_t .+= randn(rng, T, size(x_t)...) .* sqrt(step)
project!(x_t, x_t, process.manifold)
end
return x_t
end

shortestpath_interpolation(rng::AbstractRNG, process::ManifoldBrownianDiffusion, p_0::AbstractVector, p_t::AbstractVector, s, t) =
shortest_geodesic(process.manifold, p_0, p_t, s / t)

function shortestpath_interpolation!(x_s, rng::AbstractRNG, process::ManifoldBrownianDiffusion, x_0, x_t, s, t)
for pointindex in pointindices(x_s)
x_s[:, pointindex] = shortestpath_interpolation(rng, process, x_0[:, pointindex], x_t[:, pointindex], s, t)
end
end

# Empirically, this is good - should be the same as diffusing C from B as is done in the rotational/angular cases (nice symmetries) for spheres at least.
function endpoint_conditioned_sample(rng::AbstractRNG, process::ManifoldBrownianDiffusion, s::Real, t::Real, x_0, x_t)
B = sampleforward(rng, process, s, x_0)
C = sampleforward(rng, process, t-s, x_t)
shortestpath_interpolation!(C, rng, process, B, C, s, t)
return C
end

# This has not been tested yet. An optional way of a heuristic brownian bridge for general manifolds.
function endpoint_conditioned_sample_distance_proportional(rng::AbstractRNG, process::ManifoldBrownianDiffusion, s::Real, t::Real, x_0, x_t)
x_s = similar(x_0)
D = process.rate # Diffusion coefficient
dt = min(s, 0.05) # timestep
for pointindex in pointindices(x_from)
p_0 = x_0[:, pointindex]
p_t = x_t[:, pointindex]
d = distance(process.manifold, p_0, p_t)
t_remaining = t
p_cur = p_0
while (t - t_remaining) < s
dp = randn(rng, size(p_cur)...) .* (2 * D * dt)
p_new = p_cur .+ dp
d_remaining = distance(process.manifold, p_new, p_t)
factor = 1 - (d_remaining / d) * (t_remaining / T)
p_adjusted = shortest_geodesic(p_new, p_t, factor)

p_cur = project(process.manifold, p_adjusted)
t_remaining -= dt
end
x_s[:, pointindex] = p_cur
end
x_s
end