Skip to content

Commit

Permalink
Switch asserts to tests for PnP
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel committed May 28, 2024
1 parent 2497651 commit 119a463
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions test/testRegularization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
model(x) = x
# reduced constructor, checking defaults
pnp_reg = PnPRegularization(model, [2])
@assert pnp_reg.λ == 1.0
@assert pnp_reg.model == model
@assert pnp_reg.shape == [2]
@assert pnp_reg.input_transform == RegularizedLeastSquares.MinMaxTransform
@assert pnp_reg.ignoreIm == false
@test pnp_reg.λ == 1.0
@test pnp_reg.model == model
@test pnp_reg.shape == [2]
@test pnp_reg.input_transform == RegularizedLeastSquares.MinMaxTransform
@test pnp_reg.ignoreIm == false
# full constructor
pnp_reg = PnPRegularization(0.1; model=model, shape=[2], input_transform=x -> x, ignoreIm=true)
# full constructor defaults
pnp_reg = PnPRegularization(0.1; model=model, shape=[2])
@assert pnp_reg.input_transform == RegularizedLeastSquares.MinMaxTransform
@assert pnp_reg.ignoreIm == false
@test pnp_reg.input_transform == RegularizedLeastSquares.MinMaxTransform
@test pnp_reg.ignoreIm == false
# unnecessary kwargs are ignored
pnp_reg = PnPRegularization(0.1; model=model, shape=[2], input_transform=x -> x, ignoreIm=true, sMtHeLsE=1)
end
Expand All @@ -27,9 +27,14 @@ end
b = A * x

for solver in supported_solvers
S = createLinearSolver(solver, A, iterations=2; reg=[pnp_reg])
x_approx = solve!(S, b)
@info "PnP Regularization and $solver Compatibility"
@test try
S = createLinearSolver(solver, A, iterations=2; reg=[pnp_reg])
x_approx = solve!(S, b)
@info "PnP Regularization and $solver Compatibility"
true
catch ex
false
end
end
end

Expand All @@ -38,7 +43,7 @@ end
pnp_reg = PnPRegularization(0.1; model=x -> zeros(eltype(x), size(x)), shape=[2], input_transform=RegularizedLeastSquares.IdentityTransform)
out = prox!(pnp_reg, [1.0, 2.0], 0.1)
@info out
@assert out == [0.9, 1.8]
@test out == [0.9, 1.8]
end


Expand All @@ -49,24 +54,24 @@ end
input_transform=RegularizedLeastSquares.IdentityTransform
)
out = prox!(pnp_reg, [1.0 + 1.0im, 2.0 + 2.0im], 0.1)
@assert real(out) == [0.9, 1.8]
@assert imag(out) == [0.9, 1.8]
@test real(out) == [0.9, 1.8]
@test imag(out) == [0.9, 1.8]
# ignoreIm = true
pnp_reg = PnPRegularization(
0.1; model=x -> zeros(eltype(x), size(x)), shape=[2],
input_transform=RegularizedLeastSquares.IdentityTransform,
ignoreIm=true
)
out = prox!(pnp_reg, [1.0 + 1.0im, 2.0 + 2.0im], 0.1)
@assert real(out) == [0.9, 1.8]
@assert imag(out) == [1.0, 2.0]
@test real(out) == [0.9, 1.8]
@test imag(out) == [1.0, 2.0]
end


@testset "PnP Prox λ clipping" begin
pnp_reg = PnPRegularization(0.1; model=x -> zeros(eltype(x), size(x)), shape=[2], input_transform=RegularizedLeastSquares.IdentityTransform)
out = @test_warn "$(typeof(pnp_reg)) was given λ with value 1.5. Valid range is [0, 1]. λ changed to temp" prox!(pnp_reg, [1.0, 2.0], 1.5)
@assert out == [0.0, 0.0]
@test out == [0.0, 0.0]
out = @test_warn "$(typeof(pnp_reg)) was given λ with value -1.5. Valid range is [0, 1]. λ changed to temp" prox!(pnp_reg, [1.0, 2.0], -1.5)
@assert out == [1.0, 2.0]
@test out == [1.0, 2.0]
end

0 comments on commit 119a463

Please sign in to comment.