Skip to content

Commit

Permalink
Merge pull request #161 from devmotion/negative
Browse files Browse the repository at this point in the history
Ensure non-negativity of pairwise computations
  • Loading branch information
dkarrasch authored Apr 26, 2020
2 parents f21b517 + bf7757d commit 31bff8d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 28 deletions.
71 changes: 43 additions & 28 deletions src/metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -616,15 +616,15 @@ function _pairwise!(r::AbstractMatrix, dist::SqEuclidean,
for j = 1:size(r, 2)
sb = sb2[j]
@simd for i = 1:size(r, 1)
@inbounds r[i, j] = sa2[i] + sb - 2 * r[i, j]
@inbounds r[i, j] = max(sa2[i] + sb - 2 * r[i, j], 0)
end
end
else
for j = 1:size(r, 2)
sb = sb2[j]
for i = 1:size(r, 1)
@inbounds selfterms = sa2[i] + sb
@inbounds v = selfterms - 2 * r[i, j]
@inbounds v = max(selfterms - 2 * r[i, j], 0)
if v < threshT * selfterms
# The distance is likely to be inaccurate, recalculate at higher prec.
# This reflects the following:
Expand Down Expand Up @@ -655,12 +655,12 @@ function _pairwise!(r::AbstractMatrix, dist::SqEuclidean, a::AbstractMatrix)
sa2j = sa2[j]
if threshT <= 0
@simd for i = (j + 1):n
r[i, j] = sa2[i] + sa2j - 2 * r[i, j]
r[i, j] = max(sa2[i] + sa2j - 2 * r[i, j], 0)
end
else
for i = (j + 1):n
selfterms = sa2[i] + sa2j
v = selfterms - 2 * r[i, j]
v = max(selfterms - 2 * r[i, j], 0)
if v < threshT * selfterms
v = zero(v)
for k = 1:size(a, 1)
Expand All @@ -685,7 +685,7 @@ function _pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean,
mul!(r, a', b .* w)
for j = 1:nb
@simd for i = 1:na
@inbounds r[i, j] = sa2[i] + sb2[j] - 2 * r[i, j]
@inbounds r[i, j] = max(sa2[i] + sb2[j] - 2 * r[i, j], 0)
end
end
r
Expand All @@ -704,7 +704,7 @@ function _pairwise!(r::AbstractMatrix, dist::WeightedSqEuclidean,
end
@inbounds r[j, j] = 0
@simd for i = (j + 1):n
@inbounds r[i, j] = sa2[i] + sa2[j] - 2 * r[i, j]
@inbounds r[i, j] = max(sa2[i] + sa2[j] - 2 * r[i, j], 0)
end
end
r
Expand All @@ -718,22 +718,31 @@ function _pairwise!(r::AbstractMatrix, dist::Euclidean,
sa2 = sumsq_percol(a)
sb2 = sumsq_percol(b)
threshT = convert(eltype(r), dist.thresh)
@inbounds for j = 1:nb
sb = sb2[j]
for i = 1:na
selfterms = sa2[i] + sb
v = selfterms - 2 * r[i, j]
if v < threshT * selfterms
# The distance is likely to be inaccurate, recalculate directly
# This reflects the following:
# while sqrt(x+ϵ) ≈ sqrt(x) + O(ϵ/sqrt(x)) when |x| >> ϵ,
# sqrt(x+ϵ) ≈ O(sqrt(ϵ)) otherwise.
v = zero(v)
for k = 1:m
v += (a[k, i] - b[k, j])^2
if threshT <= 0
for j = 1:nb
sb = sb2[j]
@simd for i = 1:na
@inbounds r[i, j] = sqrt(max(sa2[i] + sb - 2 * r[i, j], 0))
end
end
else
@inbounds for j = 1:nb
sb = sb2[j]
for i = 1:na
selfterms = sa2[i] + sb
v = max(selfterms - 2 * r[i, j], 0)
if v < threshT * selfterms
# The distance is likely to be inaccurate, recalculate directly
# This reflects the following:
# while sqrt(x+ϵ) ≈ sqrt(x) + O(ϵ/sqrt(x)) when |x| >> ϵ,
# sqrt(x+ϵ) ≈ O(sqrt(ϵ)) otherwise.
v = zero(v)
for k = 1:m
v += (a[k, i] - b[k, j])^2
end
end
r[i, j] = sqrt(v)
end
r[i, j] = sqrt(v)
end
end
r
Expand All @@ -750,16 +759,22 @@ function _pairwise!(r::AbstractMatrix, dist::Euclidean, a::AbstractMatrix)
end
r[j, j] = 0
sa2j = sa2[j]
for i = (j + 1):n
selfterms = sa2[i] + sa2j
v = selfterms - 2 * r[i, j]
if v < threshT * selfterms
v = zero(v)
for k = 1:m
v += (a[k, i] - a[k, j])^2
if threshT <= 0
@simd for i = (j + 1):n
r[i, j] = sqrt(max(sa2[i] + sa2j - 2 * r[i, j], 0))
end
else
for i = (j + 1):n
selfterms = sa2[i] + sa2j
v = max(selfterms - 2 * r[i, j], 0)
if v < threshT * selfterms
v = zero(v)
for k = 1:m
v += (a[k, i] - a[k, j])^2
end
end
r[i, j] = sqrt(v)
end
r[i, j] = sqrt(v)
end
end
r
Expand Down
11 changes: 11 additions & 0 deletions test/test_dists.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,17 @@ end
@test pd[2, 2] == 0
end

@testset "Euclidean non-negativity" begin
X = [0.3 0.3 + eps()]

@test all(x -> x >= 0, pairwise(SqEuclidean(), X; dims = 2))
@test all(x -> x >= 0, pairwise(SqEuclidean(), X, X; dims = 2))
@test all(x -> x >= 0, pairwise(Euclidean(), X; dims = 2))
@test all(x -> x >= 0, pairwise(Euclidean(), X, X; dims = 2))
@test all(x -> x >= 0, pairwise(WeightedSqEuclidean([1.0]), X; dims = 2))
@test all(x -> x >= 0, pairwise(WeightedSqEuclidean([1.0]), X, X; dims = 2))
end

@testset "Bregman Divergence" begin
# Some basic tests.
@test_throws ArgumentError bregman(x -> x, x -> 2*x, [1, 2, 3], [1, 2, 3])
Expand Down

0 comments on commit 31bff8d

Please sign in to comment.