Skip to content

Commit

Permalink
Update likelihood loss
Browse files Browse the repository at this point in the history
  • Loading branch information
jmeziere committed Aug 5, 2024
1 parent 618a6b8 commit 19ad794
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 23 deletions.
22 changes: 13 additions & 9 deletions src/Losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ function loss(state, getDeriv, getLoss, saveRecip)
if getLoss
state.plan.tempSpace .= sqrt(c) .* state.plan.recipSpace
return mapreduce(
(i,rsp,sup) -> sup ? abs2(rsp) - LogExpFunctions.xlogy(i, abs2(rsp)) - i + LogExpFunctionss.xlogx(i) : 0.0, +,
state.intens, state.plan.tempSpace, state.recSupport
(i,rsp,sup) -> sup ? abs2(rsp) - LogExpFunctions.xlogy(i, abs2(rsp)) - i + LogExpFunctions.xlogx(i) : 0.0, +,
state.intens, state.plan.tempSpace, state.recSupport, init = 0.0
)/length(state.recipSpace)
end
elseif state.losstype == 1
Expand Down Expand Up @@ -50,8 +50,8 @@ function emptyLoss(state)
end
state.plan.tempSpace .= sqrt(c) .* state.plan.recipSpace
return mapreduce(
(i,rsp,sup) -> sup ? abs2(rsp) - LogExpFunctions.xlogy(i, abs2(rsp)) - i + LogExpFunctionss.xlogx(i) : 0.0, +,
state.intens, state.plan.tempSpace, state.recSupport
(i,rsp,sup) -> sup ? abs2(rsp) - LogExpFunctions.xlogy(i, abs2(rsp)) - i + LogExpFunctions.xlogx(i) : 0.0, +,
state.intens, state.plan.tempSpace, state.recSupport, init = 0.0
)/length(state.recipSpace)
elseif state.losstype == 1
c = 1.0
Expand Down Expand Up @@ -82,12 +82,14 @@ function lossManyAtomic!(losses, losstype, x, y, z, adds, scalings, intens, reci
))

if losstype == 0
losses[i] += scalings[i]*abs2(rsp) - LogExpFunctions.xlogy(intens[j],scalings[i]*abs2(rsp))
losses[i] += (
scalings[i]*abs2(rsp) - LogExpFunctions.xlogy(intens[j],scalings[i]*abs2(rsp)) -
intens[j] + LogExpFunctions.xlogx(intens[j])
)/length(intens)
elseif losstype == 1
losses[i] += (scalings[i]*abs(rsp) - sqrt(intens[j]))^2
losses[i] += ((scalings[i]*abs(rsp) - sqrt(intens[j]))^2)/length(intens)
end
end
losses[i] /= length(intens)
end
end

Expand Down Expand Up @@ -119,8 +121,10 @@ function scalingManyAtomic!(scalings, losstype, x, y, z, adds, intens, recipSpac
end
end

function lossManyAtomic!(losses, state, x, y, z, adds)
losses .= 0
function lossManyAtomic!(losses, state, x, y, z, adds, addLoss)
if !addLoss
losses .= 0
end
scalings = CUDA.ones(Float64, length(x))

if state.scale
Expand Down
8 changes: 6 additions & 2 deletions src/State.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ end
function slowForwardProp(state::AtomicState, x, y, z, adds, saveRecip)
state.plan.recipSpace .= state.recipSpace
for i in 1:length(x)
if isnan(x[i])
continue
end
state.plan.recipSpace .+= (2 .* adds[i] .- 1) .* exp.(-1im .* (
x[i] .* (state.G[1] .+ state.h) .+
y[i] .* (state.G[2] .+ state.k) .+
Expand Down Expand Up @@ -383,7 +386,7 @@ function setpts!(state::MultiState, x, y, z, mx, my, mz, rho, ux, uy, uz, getDe
state.rholessRealSpace .= exp.(-1im .* (state.G[1] .* ux .+ state.G[2] .* uy .+ state.G[3] .* uz))
state.realSpace[1:length(mx)] .= rho .* state.rholessRealSpace
state.realSpace[length(mx)+1:end] .= exp.(-1im .* (state.G[1] .* x .+ state.G[2] .* y .+ state.G[3] .* z))
resize!(state.rhoPlan.realSpace, length(x))
resize!(state.rhoPlan.realSpace, length(mx))
resize!(state.plan.realSpace, length(x)+length(mx))
if getDeriv
resize!(state.xDeriv, length(x))
Expand All @@ -401,7 +404,8 @@ function setpts!(state::MultiState, x, y, z, mx, my, mz, rho, ux, uy, uz, getDe
mesoY = my .+ uy
mesoZ = mz .+ uz
FINUFFT.cufinufft_setpts!(state.plan.forPlan, fullX, fullY, fullZ)
FINUFFT.cufinufft_setpts!(state.plan.revPlan, mesoX, mesoY, mesoZ)
FINUFFT.cufinufft_setpts!(state.plan.revPlan, fullX, fullY, fullZ)
FINUFFT.cufinufft_setpts!(state.rhoPlan.revPlan, mesoX, mesoY, mesoZ)
end

function forwardProp(state::MultiState, saveRecip)
Expand Down
4 changes: 2 additions & 2 deletions test/Atomic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function atomicLikelihoodWithScaling(x, y, z, h, k, l, intens, recSupport)
recipSpace .*= recSupport
intens = intens .* recSupport
c = reduce(+, intens)/mapreduce(abs2, +, recipSpace)
return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)), +, intens, recipSpace)/length(intens)
return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens)
end

function atomicLikelihoodWithoutScaling(x, y, z, h, k, l, intens, recSupport)
Expand All @@ -36,7 +36,7 @@ function atomicLikelihoodWithoutScaling(x, y, z, h, k, l, intens, recSupport)
end
recipSpace .*= recSupport
intens = intens .* recSupport
return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)), +, intens, recipSpace)/length(intens)
return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens)
end

function atomicL2WithScaling(x, y, z, h, k, l, intens, recSupport)
Expand Down
4 changes: 2 additions & 2 deletions test/Meso.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function mesoLikelihoodWithScaling(x, y, z, rho, ux, uy, uz, h, k, l, hp, kp, lp
recipSpace .*= recSupport
intens = intens .* recSupport
c = reduce(+, intens)/mapreduce(abs2, +, recipSpace)
return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)), +, intens, recipSpace)/length(intens)
return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens)
end

function mesoLikelihoodWithoutScaling(x, y, z, rho, ux, uy, uz, h, k, l, hp, kp, lp, intens, recSupport)
Expand All @@ -40,7 +40,7 @@ function mesoLikelihoodWithoutScaling(x, y, z, rho, ux, uy, uz, h, k, l, hp, kp,
end
recipSpace .*= recSupport
intens = intens .* recSupport
return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)), +, intens, recipSpace)/length(intens)
return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens)
end

function mesoL2WithScaling(x, y, z, rho, ux, uy, uz, h, k, l, hp, kp, lp, intens, recSupport)
Expand Down
4 changes: 2 additions & 2 deletions test/Multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function multiLikelihoodWithScaling(x, y, z, mx, my, mz, rho, ux, uy, uz, h, k,
recipSpace .*= recSupport
intens = intens .* recSupport
c = reduce(+, intens)/mapreduce(abs2, +, recipSpace)
return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)), +, intens, recipSpace)/length(intens)
return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens)
end

function multiLikelihoodWithoutScaling(x, y, z, mx, my, mz, rho, ux, uy, uz, h, k, l, hp, kp, lp, intens, recSupport)
Expand Down Expand Up @@ -58,7 +58,7 @@ function multiLikelihoodWithoutScaling(x, y, z, mx, my, mz, rho, ux, uy, uz, h,
end
recipSpace .*= recSupport
intens = intens .* recSupport
return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)), +, intens, recipSpace)/length(intens)
return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens)
end

function multiL2WithScaling(x, y, z, mx, my, mz, rho, ux, uy, uz, h, k, l, hp, kp, lp, intens, recSupport)
Expand Down
4 changes: 2 additions & 2 deletions test/Traditional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function tradLikelihoodWithScaling(realSpace, intens, recSupport)
recipSpace .*= recSupport
intens = intens .* recSupport
c = reduce(+, intens)/mapreduce(abs2, +, recipSpace)
return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)), +, intens, recipSpace)/length(intens)
return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens)
end

function tradLikelihoodWithoutScaling(realSpace, intens, recSupport)
Expand All @@ -42,7 +42,7 @@ function tradLikelihoodWithoutScaling(realSpace, intens, recSupport)
end
recipSpace .*= recSupport
intens = intens .* recSupport
return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)), +, intens, recipSpace)/length(intens)
return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens)
end

function tradL2WithScaling(realSpace, intens, recSupport)
Expand Down
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ include("Multi.jl")
for i in 1:length(manyX)
losses[i] = atomicLikelihoodWithScaling(vcat(x,[manyX[i]]), vcat(y,[manyY[i]]), vcat(z,[manyZ[i]]), h.+G[1], k.+G[2], l.+G[3], intens, recSupport)
end
BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds)
BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds, false)

@test isapprox(testee, tester, rtol=1e-6)
@test all(isapprox.(Array(state.xDeriv), xDeriv, rtol=1e-6))
Expand All @@ -97,7 +97,7 @@ include("Multi.jl")
for i in 1:length(manyX)
losses[i] = atomicLikelihoodWithoutScaling(vcat(x,[manyX[i]]), vcat(y,[manyY[i]]), vcat(z,[manyZ[i]]), h.+G[1], k.+G[2], l.+G[3], intens, recSupport)
end
BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds)
BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds, false)

@test isapprox(testee, tester, rtol=1e-6)
@test all(isapprox.(Array(state.xDeriv), xDeriv, rtol=1e-6))
Expand All @@ -124,7 +124,7 @@ include("Multi.jl")
for i in 1:length(manyX)
losses[i] = atomicL2WithScaling(vcat(x,[manyX[i]]), vcat(y,[manyY[i]]), vcat(z,[manyZ[i]]), h.+G[1], k.+G[2], l.+G[3], intens, recSupport)
end
BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds)
BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds, false)

@test isapprox(testee, tester, rtol=1e-6)
@test all(isapprox.(Array(state.xDeriv), xDeriv, rtol=1e-6))
Expand All @@ -151,7 +151,7 @@ include("Multi.jl")
for i in 1:length(manyX)
losses[i] = atomicL2WithoutScaling(vcat(x,[manyX[i]]), vcat(y,[manyY[i]]), vcat(z,[manyZ[i]]), h.+G[1], k.+G[2], l.+G[3], intens, recSupport)
end
BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds)
BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds, false)

@test isapprox(testee, tester, rtol=1e-6)
@test all(isapprox.(Array(state.xDeriv), xDeriv, rtol=1e-6))
Expand Down

0 comments on commit 19ad794

Please sign in to comment.