From 19ad794bee203786ba016762b43287819846788c Mon Sep 17 00:00:00 2001 From: jmeziere Date: Sun, 4 Aug 2024 22:55:52 -0600 Subject: [PATCH] Update likelihood loss --- src/Losses.jl | 22 +++++++++++++--------- src/State.jl | 8 ++++++-- test/Atomic.jl | 4 ++-- test/Meso.jl | 4 ++-- test/Multi.jl | 4 ++-- test/Traditional.jl | 4 ++-- test/runtests.jl | 8 ++++---- 7 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/Losses.jl b/src/Losses.jl index 8e656ff..b85790f 100644 --- a/src/Losses.jl +++ b/src/Losses.jl @@ -14,8 +14,8 @@ function loss(state, getDeriv, getLoss, saveRecip) if getLoss state.plan.tempSpace .= sqrt(c) .* state.plan.recipSpace return mapreduce( - (i,rsp,sup) -> sup ? abs2(rsp) - LogExpFunctions.xlogy(i, abs2(rsp)) - i + LogExpFunctionss.xlogx(i) : 0.0, +, - state.intens, state.plan.tempSpace, state.recSupport + (i,rsp,sup) -> sup ? abs2(rsp) - LogExpFunctions.xlogy(i, abs2(rsp)) - i + LogExpFunctions.xlogx(i) : 0.0, +, + state.intens, state.plan.tempSpace, state.recSupport, init = 0.0 )/length(state.recipSpace) end elseif state.losstype == 1 @@ -50,8 +50,8 @@ function emptyLoss(state) end state.plan.tempSpace .= sqrt(c) .* state.plan.recipSpace return mapreduce( - (i,rsp,sup) -> sup ? abs2(rsp) - LogExpFunctions.xlogy(i, abs2(rsp)) - i + LogExpFunctionss.xlogx(i) : 0.0, +, - state.intens, state.plan.tempSpace, state.recSupport + (i,rsp,sup) -> sup ? abs2(rsp) - LogExpFunctions.xlogy(i, abs2(rsp)) - i + LogExpFunctions.xlogx(i) : 0.0, +, + state.intens, state.plan.tempSpace, state.recSupport, init = 0.0 )/length(state.recipSpace) elseif state.losstype == 1 c = 1.0 @@ -82,12 +82,14 @@ function lossManyAtomic!(losses, losstype, x, y, z, adds, scalings, intens, reci )) if losstype == 0 - losses[i] += scalings[i]*abs2(rsp) - LogExpFunctions.xlogy(intens[j],scalings[i]*abs2(rsp)) + losses[i] += ( + scalings[i]*abs2(rsp) - LogExpFunctions.xlogy(intens[j],scalings[i]*abs2(rsp)) - + intens[j] + LogExpFunctions.xlogx(intens[j]) + )/length(intens) elseif losstype == 1 - losses[i] += (scalings[i]*abs(rsp) - sqrt(intens[j]))^2 + losses[i] += ((scalings[i]*abs(rsp) - sqrt(intens[j]))^2)/length(intens) end end - losses[i] /= length(intens) end end @@ -119,8 +121,10 @@ function scalingManyAtomic!(scalings, losstype, x, y, z, adds, intens, recipSpac end end -function lossManyAtomic!(losses, state, x, y, z, adds) - losses .= 0 +function lossManyAtomic!(losses, state, x, y, z, adds, addLoss) + if !addLoss + losses .= 0 + end scalings = CUDA.ones(Float64, length(x)) if state.scale diff --git a/src/State.jl b/src/State.jl index 1550099..ce472b1 100644 --- a/src/State.jl +++ b/src/State.jl @@ -108,6 +108,9 @@ end function slowForwardProp(state::AtomicState, x, y, z, adds, saveRecip) state.plan.recipSpace .= state.recipSpace for i in 1:length(x) + if isnan(x[i]) + continue + end state.plan.recipSpace .+= (2 .* adds[i] .- 1) .* exp.(-1im .* ( x[i] .* (state.G[1] .+ state.h) .+ y[i] .* (state.G[2] .+ state.k) .+ @@ -383,7 +386,7 @@ function setpts!(state::MultiState, x, y, z, mx, my, mz, rho, ux, uy, uz, getDe state.rholessRealSpace .= exp.(-1im .* (state.G[1] .* ux .+ state.G[2] .* uy .+ state.G[3] .* uz)) state.realSpace[1:length(mx)] .= rho .* state.rholessRealSpace state.realSpace[length(mx)+1:end] .= exp.(-1im .* (state.G[1] .* x .+ state.G[2] .* y .+ state.G[3] .* z)) - resize!(state.rhoPlan.realSpace, length(x)) + resize!(state.rhoPlan.realSpace, length(mx)) resize!(state.plan.realSpace, length(x)+length(mx)) if getDeriv resize!(state.xDeriv, length(x)) @@ -401,7 +404,8 @@ function setpts!(state::MultiState, x, y, z, mx, my, mz, rho, ux, uy, uz, getDe mesoY = my .+ uy mesoZ = mz .+ uz FINUFFT.cufinufft_setpts!(state.plan.forPlan, fullX, fullY, fullZ) - FINUFFT.cufinufft_setpts!(state.plan.revPlan, mesoX, mesoY, mesoZ) + FINUFFT.cufinufft_setpts!(state.plan.revPlan, fullX, fullY, fullZ) + FINUFFT.cufinufft_setpts!(state.rhoPlan.revPlan, mesoX, mesoY, mesoZ) end function forwardProp(state::MultiState, saveRecip) diff --git a/test/Atomic.jl b/test/Atomic.jl index 246adae..3ff4691 100644 --- a/test/Atomic.jl +++ b/test/Atomic.jl @@ -16,7 +16,7 @@ function atomicLikelihoodWithScaling(x, y, z, h, k, l, intens, recSupport) recipSpace .*= recSupport intens = intens .* recSupport c = reduce(+, intens)/mapreduce(abs2, +, recipSpace) - return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)), +, intens, recipSpace)/length(intens) + return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens) end function atomicLikelihoodWithoutScaling(x, y, z, h, k, l, intens, recSupport) @@ -36,7 +36,7 @@ function atomicLikelihoodWithoutScaling(x, y, z, h, k, l, intens, recSupport) end recipSpace .*= recSupport intens = intens .* recSupport - return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)), +, intens, recipSpace)/length(intens) + return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens) end function atomicL2WithScaling(x, y, z, h, k, l, intens, recSupport) diff --git a/test/Meso.jl b/test/Meso.jl index ba07584..1b0b782 100644 --- a/test/Meso.jl +++ b/test/Meso.jl @@ -18,7 +18,7 @@ function mesoLikelihoodWithScaling(x, y, z, rho, ux, uy, uz, h, k, l, hp, kp, lp recipSpace .*= recSupport intens = intens .* recSupport c = reduce(+, intens)/mapreduce(abs2, +, recipSpace) - return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)), +, intens, recipSpace)/length(intens) + return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens) end function mesoLikelihoodWithoutScaling(x, y, z, rho, ux, uy, uz, h, k, l, hp, kp, lp, intens, recSupport) @@ -40,7 +40,7 @@ function mesoLikelihoodWithoutScaling(x, y, z, rho, ux, uy, uz, h, k, l, hp, kp, end recipSpace .*= recSupport intens = intens .* recSupport - return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)), +, intens, recipSpace)/length(intens) + return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens) end function mesoL2WithScaling(x, y, z, rho, ux, uy, uz, h, k, l, hp, kp, lp, intens, recSupport) diff --git a/test/Multi.jl b/test/Multi.jl index 35054ea..b7ded07 100644 --- a/test/Multi.jl +++ b/test/Multi.jl @@ -27,7 +27,7 @@ function multiLikelihoodWithScaling(x, y, z, mx, my, mz, rho, ux, uy, uz, h, k, recipSpace .*= recSupport intens = intens .* recSupport c = reduce(+, intens)/mapreduce(abs2, +, recipSpace) - return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)), +, intens, recipSpace)/length(intens) + return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens) end function multiLikelihoodWithoutScaling(x, y, z, mx, my, mz, rho, ux, uy, uz, h, k, l, hp, kp, lp, intens, recSupport) @@ -58,7 +58,7 @@ function multiLikelihoodWithoutScaling(x, y, z, mx, my, mz, rho, ux, uy, uz, h, end recipSpace .*= recSupport intens = intens .* recSupport - return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)), +, intens, recipSpace)/length(intens) + return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens) end function multiL2WithScaling(x, y, z, mx, my, mz, rho, ux, uy, uz, h, k, l, hp, kp, lp, intens, recSupport) diff --git a/test/Traditional.jl b/test/Traditional.jl index dc11a09..71f59fe 100644 --- a/test/Traditional.jl +++ b/test/Traditional.jl @@ -19,7 +19,7 @@ function tradLikelihoodWithScaling(realSpace, intens, recSupport) recipSpace .*= recSupport intens = intens .* recSupport c = reduce(+, intens)/mapreduce(abs2, +, recipSpace) - return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)), +, intens, recipSpace)/length(intens) + return mapreduce((i,r) -> c*abs2(r) - LogExpFunctions.xlogy(i,c*abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens) end function tradLikelihoodWithoutScaling(realSpace, intens, recSupport) @@ -42,7 +42,7 @@ function tradLikelihoodWithoutScaling(realSpace, intens, recSupport) end recipSpace .*= recSupport intens = intens .* recSupport - return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)), +, intens, recipSpace)/length(intens) + return mapreduce((i,r) -> abs2(r) - LogExpFunctions.xlogy(i,abs2(r)) - i + LogExpFunctions.xlogx(i), +, intens, recipSpace)/length(intens) end function tradL2WithScaling(realSpace, intens, recSupport) diff --git a/test/runtests.jl b/test/runtests.jl index ee47fd9..211e869 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,7 +70,7 @@ include("Multi.jl") for i in 1:length(manyX) losses[i] = atomicLikelihoodWithScaling(vcat(x,[manyX[i]]), vcat(y,[manyY[i]]), vcat(z,[manyZ[i]]), h.+G[1], k.+G[2], l.+G[3], intens, recSupport) end - BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds) + BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds, false) @test isapprox(testee, tester, rtol=1e-6) @test all(isapprox.(Array(state.xDeriv), xDeriv, rtol=1e-6)) @@ -97,7 +97,7 @@ include("Multi.jl") for i in 1:length(manyX) losses[i] = atomicLikelihoodWithoutScaling(vcat(x,[manyX[i]]), vcat(y,[manyY[i]]), vcat(z,[manyZ[i]]), h.+G[1], k.+G[2], l.+G[3], intens, recSupport) end - BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds) + BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds, false) @test isapprox(testee, tester, rtol=1e-6) @test all(isapprox.(Array(state.xDeriv), xDeriv, rtol=1e-6)) @@ -124,7 +124,7 @@ include("Multi.jl") for i in 1:length(manyX) losses[i] = atomicL2WithScaling(vcat(x,[manyX[i]]), vcat(y,[manyY[i]]), vcat(z,[manyZ[i]]), h.+G[1], k.+G[2], l.+G[3], intens, recSupport) end - BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds) + BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds, false) @test isapprox(testee, tester, rtol=1e-6) @test all(isapprox.(Array(state.xDeriv), xDeriv, rtol=1e-6)) @@ -151,7 +151,7 @@ include("Multi.jl") for i in 1:length(manyX) losses[i] = atomicL2WithoutScaling(vcat(x,[manyX[i]]), vcat(y,[manyY[i]]), vcat(z,[manyZ[i]]), h.+G[1], k.+G[2], l.+G[3], intens, recSupport) end - BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds) + BcdiCore.lossManyAtomic!(cuLosses, state, cuManyX, cuManyY, cuManyZ, cuAdds, false) @test isapprox(testee, tester, rtol=1e-6) @test all(isapprox.(Array(state.xDeriv), xDeriv, rtol=1e-6))