diff --git a/Project.toml b/Project.toml index 48d3ec1..e8670c8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Distances" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.1" +version = "0.10.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/generic.jl b/src/generic.jl index e3460fc..66f01c8 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -38,6 +38,7 @@ result_type(f, a::Type, b::Type) = typeof(f(oneunit(a), oneunit(b))) # don't req _eltype(a) = __eltype(Base.IteratorEltype(a), a) _eltype(::Type{T}) where {T} = eltype(T) === T ? T : _eltype(eltype(T)) +_eltype(::Type{Union{Missing, T}}) where {T} = Union{Missing, T} __eltype(::Base.HasEltype, a) = _eltype(eltype(a)) __eltype(::Base.EltypeUnknown, a) = _eltype(typeof(first(a))) diff --git a/src/haversine.jl b/src/haversine.jl index 306f61e..31f9562 100644 --- a/src/haversine.jl +++ b/src/haversine.jl @@ -60,3 +60,5 @@ function (dist::SphericalAngle)(x, y) end spherical_angle(x, y) = SphericalAngle()(x, y) + +result_type(::Union{Haversine, SphericalAngle}, ::Type, ::Type) = Float64 diff --git a/src/metrics.jl b/src/metrics.jl index 190d619..693c76f 100644 --- a/src/metrics.jl +++ b/src/metrics.jl @@ -315,8 +315,8 @@ function _evaluate(dist::UnionMetrics, a::Number, b::Number, p) end eval_start(d::UnionMetrics, a, b) = _eval_start(d, _eltype(a), _eltype(b)) -_eval_start(d, ::Type{Ta}, ::Type{Tb}) where {Ta,Tb} = - _eval_start(d, _eltype(Ta), _eltype(Tb), parameters(d)) +_eval_start(d::UnionMetrics, ::Type{Ta}, ::Type{Tb}) where {Ta,Tb} = + _eval_start(d, Ta, Tb, parameters(d)) _eval_start(d::UnionMetrics, ::Type{Ta}, ::Type{Tb}, ::Nothing) where {Ta,Tb} = zero(typeof(eval_op(d, oneunit(Ta), oneunit(Tb)))) _eval_start(d::UnionMetrics, ::Type{Ta}, ::Type{Tb}, p) where {Ta,Tb} = diff --git a/test/test_dists.jl b/test/test_dists.jl index 12a7e21..a8a692a 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -182,6 +182,7 @@ end ([4.0, 5.0, 6.0, 7.0], [3. 8.; 9. 1.0])) x, y = T.(_x), T.(_y) for (x, y) in ((x, y), + (convert(Array{Union{Missing, T}}, x), convert(Array{Union{Missing, T}}, y)), ((Iterators.take(x, 4), Iterators.take(y, 4))), # iterator (((x[i] for i in 1:length(x)), (y[i] for i in 1:length(y)))), # generator ) @@ -640,6 +641,13 @@ end test_pairwise(SqMahalanobis(Q), X, Y, T) test_pairwise(Mahalanobis(Q), X, Y, T) + + m, nx, ny = 2, 8, 6 + + X = rand(T, m, nx) + Y = rand(T, m, ny) + test_pairwise(Haversine(), X, Y, T) + test_pairwise(SphericalAngle(), X, Y, T) end function test_scalar_pairwise(dist, x, y, T)