From a74086d0e5e221de8df919a54c4d434436bd372e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Thu, 2 Jan 2020 16:12:03 -0300 Subject: [PATCH] Make sure gaussian_gramian is generic --- src/utils.jl | 5 +++-- test/basic.jl | 29 ++++++++++++++++++----------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 9893a3c..6aa2e4f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,7 +2,8 @@ # Licensed under the MIT License. See LICENSE in the project root. # ------------------------------------------------------------------ +euclidsq(x, y) = sum((x[i] - y[i])^2 for i in eachindex(x)) + function gaussian_gramian(xs, ys; σ=1) - euclidsq(x,y) = sum((x .- y).^2) - [exp(-euclidsq(x,y)/2σ^2) for x in xs, y in ys] + [exp(-euclidsq(x, y) / 2σ^2) for x in xs, y in ys] end diff --git a/test/basic.jl b/test/basic.jl index 7415313..fe09b24 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1,18 +1,25 @@ @testset "Basic" begin - for (d_nu, d_de) in [pair₁, pair₂] - Random.seed!(123) - x_nu, x_de = rand(d_nu, 100), rand(d_de, 200) + @testset "Gramian" begin + x_nu, x_de = [rand(2) for i=1:100], [rand(2) for i=1:200] + G = DensityRatioEstimation.gaussian_gramian(x_nu, x_de, σ=1.0) + @test size(G) == (length(x_nu), length(x_de)) + @test all(G .> 0) - @testset "Gramian" begin - G = DensityRatioEstimation.gaussian_gramian(x_nu, x_de, σ=1.0) - @test size(G) == (length(x_nu), length(x_de)) - @test all(G .> 0) + G = DensityRatioEstimation.gaussian_gramian(x_nu, x_nu, σ=2.0) + @test issymmetric(G) + @test all(G .> 0) - G = DensityRatioEstimation.gaussian_gramian(x_nu, x_nu, σ=2.0) - @test issymmetric(G) - @test all(G .> 0) - end + # features can be any indexable + x_nu = [(a=1.,b=2.),(a=3.,b=4.)] + x_de = [(a=1.,b=2.),(a=3.,b=4.),(a=5.,b=6.)] + G = DensityRatioEstimation.gaussian_gramian(x_nu, x_de) + @test size(G) == (2, 3) + @test all(G .> 0) + end + for (d_nu, d_de) in [pair₁, pair₂] + Random.seed!(123) + x_nu, x_de = rand(d_nu, 100), rand(d_de, 200) @testset "$dre -- $optlib" for (dre, optlib) in [(KMM(), JuMPLib), (KLIEP(), OptimLib), (KLIEP(), ConvexLib)]