Skip to content

Commit

Permalink
Handle SparseArrays.jl dep as a Pkg extension (#251)
Browse files Browse the repository at this point in the history
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
dkarrasch and devmotion authored Jul 22, 2023
1 parent 11f744e commit b8cddf2
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 79 deletions.
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
name = "Distances"
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.10.8"
version = "0.10.9"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[weakdeps]
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
DistancesSparseArraysExt = "SparseArrays"

[compat]
StatsAPI = "1"
julia = "1"

[extras]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["OffsetArrays", "Random", "Test", "Unitful"]
test = ["OffsetArrays", "Random", "SparseArrays", "Test", "Unitful"]
83 changes: 83 additions & 0 deletions ext/DistancesSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
module DistancesSparseArraysExt

using Distances
import Distances: _evaluate
using Distances: UnionMetrics, result_type, eval_start, eval_op, eval_reduce, eval_end
using SparseArrays: SparseVectorUnion, nonzeroinds, nonzeros, nnz
using SparseArrays: SparseVectorUnion

# Apply `d`'s elementwise operation to `ai`, using `zero(eltype(b))` as the second
# operand in place of an element of `b`.
function eval_op_a(d, ai, b)
    bzero = zero(eltype(b))
    return eval_op(d, ai, bzero)
end
# Apply `d`'s elementwise operation to `bi`, using `zero(eltype(a))` as the first
# operand in place of an element of `a`.
function eval_op_b(d, bi, a)
    azero = zero(eltype(a))
    return eval_op(d, azero, bi)
end

# Parameter-free evaluation of a `UnionMetrics` distance between two sparse vectors.
#
# Only stored entries are visited. The two index lists are walked in lockstep; an index
# stored in just one vector is paired with a zero of the other vector's element type.
# NOTE(review): the lockstep walk assumes `nonzeroinds` yields indices in increasing
# order -- confirm against SparseArrays.
#
# It is assumed that eval_reduce(d, s, eval_op(d, zero(eltype(a)), zero(eltype(b)))) == s
# This justifies ignoring all terms where both inputs are zero.
#
# Throws `DimensionMismatch` (when bounds checking is active) if the lengths differ.
Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::SparseVectorUnion, b::SparseVectorUnion, ::Nothing)
    @boundscheck length(a) == length(b) || throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
    iszero(length(a)) && return zero(result_type(d, a, b))

    aidx, avals, na = nonzeroinds(a), nonzeros(a), nnz(a)
    bidx, bvals, nb = nonzeroinds(b), nonzeros(b), nnz(b)

    acc = eval_start(d, a, b)
    i = 1
    j = 1
    # Merge phase: both vectors still have stored entries left.
    @inbounds while i <= na && j <= nb
        ka = aidx[i]
        kb = bidx[j]
        if ka < kb
            # index stored only in `a`
            term = eval_op_a(d, avals[i], b)
            i += 1
        elseif kb < ka
            # index stored only in `b`
            term = eval_op_b(d, bvals[j], a)
            j += 1
        else
            # index stored in both vectors
            term = eval_op(d, avals[i], bvals[j])
            i += 1
            j += 1
        end
        acc = eval_reduce(d, acc, term)
    end
    # Drain whatever remains of `a`, then of `b` (at most one of these loops runs).
    @inbounds for k in i:na
        acc = eval_reduce(d, acc, eval_op_a(d, avals[k], b))
    end
    @inbounds for k in j:nb
        acc = eval_reduce(d, acc, eval_op_b(d, bvals[k], a))
    end
    return eval_end(d, acc)
end

# Sparse-vector method of `Distances._bhattacharyya_coeff`.
#
# The name is qualified with `Distances.` so that this definition adds a method to the
# package's function. This module only `import`s `_evaluate`, so an unqualified
# definition would create a separate, module-local function that callers of
# `Distances._bhattacharyya_coeff` would never reach.
#
# Returns the tuple `(s, sum(nonzeros(a)), sum(nonzeros(b)))`, where `s` accumulates
# `sqrt(a[i] * b[i])` over the indices stored in both vectors. Indices stored in only one
# vector are skipped, since their product with the other vector's zero contributes
# nothing to `s`.
# NOTE(review): the lockstep walk assumes `nonzeroinds` yields indices in increasing
# order -- confirm against SparseArrays.
@inline function Distances._bhattacharyya_coeff(a::SparseVectorUnion, b::SparseVectorUnion)
    anzind = nonzeroinds(a)
    bnzind = nonzeroinds(b)
    anzval = nonzeros(a)
    bnzval = nonzeros(b)
    ma = nnz(a)
    mb = nnz(b)

    ia = 1; ib = 1
    # Start from a zero of the type that `sqrt(a[i] * b[i])` produces so the
    # accumulator does not change type inside the loop.
    s = zero(typeof(sqrt(oneunit(eltype(a)) * oneunit(eltype(b)))))
    @inbounds while ia <= ma && ib <= mb
        ja = anzind[ia]
        jb = bnzind[ib]
        if ja == jb
            s += sqrt(anzval[ia] * bnzval[ib])
            ia += 1; ib += 1
        elseif ja < jb
            ia += 1
        else
            ib += 1
        end
    end
    # efficient method for sum for SparseVectorView is missing,
    # so sum the stored values directly
    return s, sum(anzval), sum(bnzval)
end

end # module DistancesSparseArraysExt
7 changes: 5 additions & 2 deletions src/Distances.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module Distances

using LinearAlgebra
using Statistics
using SparseArrays: SparseVectorUnion, nonzeroinds, nonzeros, nnz
using Statistics: mean
import StatsAPI: pairwise, pairwise!

export
Expand Down Expand Up @@ -120,4 +119,8 @@ include("bregman.jl")

include("deprecated.jl")

# Fallback for Julia versions whose `Base` does not define `get_extension`
# (presumably those without package-extension support -- confirm): include the
# SparseArrays extension file directly as part of this module.
@static if !isdefined(Base, :get_extension)
include("../ext/DistancesSparseArraysExt.jl")
end

end # module end
26 changes: 0 additions & 26 deletions src/bhattacharyya.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,32 +50,6 @@ end
return sqab, asum, bsum
end

# Sparse-vector version of `_bhattacharyya_coeff`.
#
# Returns `(sqab, asum, bsum)`: `sqab` accumulates `sqrt(a[i] * b[i])` over the indices
# stored in both vectors, and `asum` / `bsum` are the sums of the stored values.
# NOTE(review): the lockstep walk assumes `nonzeroinds` yields indices in increasing
# order -- confirm against SparseArrays.
@inline function _bhattacharyya_coeff(a::SparseVectorUnion, b::SparseVectorUnion)
    aidx, avals, na = nonzeroinds(a), nonzeros(a), nnz(a)
    bidx, bvals, nb = nonzeroinds(b), nonzeros(b), nnz(b)

    T = typeof(sqrt(oneunit(eltype(a)) * oneunit(eltype(b))))
    sqab = zero(T)
    i = 1
    j = 1
    @inbounds while i <= na && j <= nb
        ka = aidx[i]
        kb = bidx[j]
        if ka == kb
            sqab += sqrt(avals[i] * bvals[j])
        end
        # Advance whichever cursor sits on the smaller index; both on a tie.
        i += ka <= kb
        j += kb <= ka
    end
    # efficient method for sum for SparseVectorView is missing
    asum = sum(avals)
    bsum = sum(bvals)
    return sqab, asum, bsum
end

# Faster pair- and column-wise versions TBD...


Expand Down
49 changes: 0 additions & 49 deletions src/metrics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,55 +308,6 @@ Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::AbstractArray, b
end
end

# Apply `d`'s elementwise operation to `ai`, using `zero(eltype(b))` as the second
# operand in place of an element of `b`.
function eval_op_a(d, ai, b)
    bzero = zero(eltype(b))
    return eval_op(d, ai, bzero)
end
# Apply `d`'s elementwise operation to `bi`, using `zero(eltype(a))` as the first
# operand in place of an element of `a`.
function eval_op_b(d, bi, a)
    azero = zero(eltype(a))
    return eval_op(d, azero, bi)
end

# Parameter-free evaluation of a `UnionMetrics` distance between two sparse vectors.
#
# Only stored entries are visited. The two index lists are walked in lockstep; an index
# stored in just one vector is paired with a zero of the other vector's element type.
# NOTE(review): the lockstep walk assumes `nonzeroinds` yields indices in increasing
# order -- confirm against SparseArrays.
#
# It is assumed that eval_reduce(d, s, eval_op(d, zero(eltype(a)), zero(eltype(b)))) == s
# This justifies ignoring all terms where both inputs are zero.
#
# Throws `DimensionMismatch` (when bounds checking is active) if the lengths differ.
Base.@propagate_inbounds function _evaluate(d::UnionMetrics, a::SparseVectorUnion, b::SparseVectorUnion, ::Nothing)
    @boundscheck length(a) == length(b) || throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
    iszero(length(a)) && return zero(result_type(d, a, b))

    aidx, avals, na = nonzeroinds(a), nonzeros(a), nnz(a)
    bidx, bvals, nb = nonzeroinds(b), nonzeros(b), nnz(b)

    acc = eval_start(d, a, b)
    i = 1
    j = 1
    # Merge phase: both vectors still have stored entries left.
    @inbounds while i <= na && j <= nb
        ka = aidx[i]
        kb = bidx[j]
        if ka < kb
            # index stored only in `a`
            term = eval_op_a(d, avals[i], b)
            i += 1
        elseif kb < ka
            # index stored only in `b`
            term = eval_op_b(d, bvals[j], a)
            j += 1
        else
            # index stored in both vectors
            term = eval_op(d, avals[i], bvals[j])
            i += 1
            j += 1
        end
        acc = eval_reduce(d, acc, term)
    end
    # Drain whatever remains of `a`, then of `b` (at most one of these loops runs).
    @inbounds for k in i:na
        acc = eval_reduce(d, acc, eval_op_a(d, avals[k], b))
    end
    @inbounds for k in j:nb
        acc = eval_reduce(d, acc, eval_op_b(d, bvals[k], a))
    end
    return eval_end(d, acc)
end


# Scalar inputs without parameters: apply the elementwise operation once to `a` and `b`
# and pass the result through `eval_end`.
function _evaluate(dist::UnionMetrics, a::Number, b::Number, ::Nothing)
    combined = eval_op(dist, a, b)
    return eval_end(dist, combined)
end
function _evaluate(dist::UnionMetrics, a::Number, b::Number, p)
length(p) != 1 && throw(DimensionMismatch("inputs are scalars but parameters have length $(length(p))."))
Expand Down

2 comments on commit b8cddf2

@dkarrasch
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88066

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.9 -m "<description of version>" b8cddf2a2a62458a8b5180a228649b874beb0800
git push origin v0.10.9

Please sign in to comment.