Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrote corkendall (issue 634) #647

Merged
merged 23 commits into from
Feb 8, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 153 additions & 82 deletions src/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,71 +28,76 @@ corspearman(X::RealMatrix) = (Z = mapslices(tiedrank, X, dims=1); cor(Z, Z))


#######################################
#
#
# Kendall correlation
#
#
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please revert this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ugh, that was VSCode auto formatter. Fixed now.

#######################################

# Knight JASA (1966)

# Knight, William R. “A Computer Method for Calculating Kendall's Tau with Ungrouped Data.”
# Journal of the American Statistical Association, vol. 61, no. 314, 1966, pp. 436–439.
# JSTOR, www.jstor.org/stable/2282833. Accessed 15 Jan. 2021.
#
# Computes Kendall's rank correlation of `x` and `y`, adjusting for ties.
# Both `x` and `y` are overwritten (re-ordered) by this function.
# `permx` is the permutation applied to both vectors before counting; it defaults to
# `sortperm(x)`, and callers may pass a precomputed one to reuse it across several `y`.
# Returns `NaN` if either vector contains a `NaN`; throws an `ErrorException` if the
# lengths differ.
function corkendall!(x::RealVector, y::RealVector, permx=sortperm(x))
    if any(isnan, x) || any(isnan, y) return NaN end
    n = length(x)
    if n != length(y) error("Vectors must have same length") end

    # Put both vectors into the order given by `permx`.
    x[:] = x[permx]
    y[:] = y[permx]

    npairs = float(n) * (n - 1) / 2
    # ntiesx and ndoubleties are floats to avoid overflows on 32bit
    ntiesx, ndoubleties, k = 0.0, 0.0, 0

    # `k` is the number of consecutive equalities seen in the current run of tied x values.
    @inbounds for i = 2:n
        if x[i - 1] == x[i]
            k += 1
        elseif k > 0
            # Sort the corresponding chunk of y, so the rows of hcat(x,y) are
            # sorted first on x, then (where x values are tied) on y. Hence
            # double ties can be counted by calling countties.
            sort!(view(y, (i - k - 1):(i - 1)))
            ntiesx += float(k) * (k + 1) / 2
            ndoubleties += countties(y, i - k - 1, i - 1)
            k = 0
        end
    end
    # Handle a run of tied x values that extends to the end of the vector.
    if k > 0
        sort!(view(y, (n - k):n))
        ntiesx += float(k) * (k + 1) / 2
        ndoubleties += countties(y, n - k, n)
    end

    # msort! sorts y and returns the swap count; y must be sorted before countties is called.
    nswaps = msort!(y, 1, n)
    ntiesy = countties(y, 1, n)

    return (npairs + ndoubleties - ntiesx - ntiesy - 2 * nswaps) /
           sqrt((npairs - ntiesx) * (npairs - ntiesy))
end

"""
countties(x::AbstractVector, lo::Integer, hi::Integer)
PGS62 marked this conversation as resolved.
Show resolved Hide resolved

Assumes `x` is sorted. Returns the number of ties within `x[lo:hi]`.
PGS62 marked this conversation as resolved.
Show resolved Hide resolved
"""
function countties(x::AbstractVector, lo::Integer, hi::Integer)
    # Counts tied pairs within x[lo:hi], which the caller must already have sorted:
    # a run of m equal values contributes m * (m - 1) / 2 to the result.
    # Throws an ErrorException when lo < 1 or hi > length(x).
    (lo < 1 || hi > length(x)) && error("Bounds error")

    # avoid overflows on 32 bit by using floats
    # `runextra` is the length of the current run of equal values, minus one.
    runextra, total = 0.0, 0.0
    @inbounds for i = (lo + 1):hi
        if x[i] == x[i - 1]
            runextra += 1.0
        elseif runextra > 0
            total += runextra * (runextra + 1) / 2
            runextra = 0.0
        end
    end

    # Flush a run that reaches the end of the range.
    if runextra > 0
        total += runextra * (runextra + 1) / 2
    end
    return total
end

"""
corkendall(x, y=x)
Expand All @@ -102,52 +107,118 @@ matrices or vectors.
"""
function corkendall(x::RealVector, y::RealVector)
    # corkendall! overwrites its arguments, so hand it floating-point copies.
    xwork = float(copy(x))
    ywork = float(copy(y))
    return corkendall!(xwork, ywork)
end

# One correlation per column of `X` against `y`, returned as a Vector{Float64}.
# NOTE(review): a method with this same signature is defined again further down;
# in Julia the later definition replaces this one.
corkendall(X::RealMatrix, y::RealVector) = Float64[corkendall!(float(X[:,i]), float(copy(y))) for i in 1:size(X, 2)]

# One correlation of `x` against each column of `Y`, returned as a 1×n Matrix{Float64}.
# NOTE(review): a method with this same signature is defined again further down;
# in Julia the later definition replaces this one.
corkendall(x::RealVector, Y::RealMatrix) = (n = size(Y,2); reshape(Float64[corkendall!(float(copy(x)), float(Y[:,i])) for i in 1:n], 1, n))
function corkendall(X::RealMatrix, y::RealVector)
    # The sort order of `y` is computed once and shared by every column comparison.
    permy = sortperm(y)
    ncols = size(X, 2)
    # A fresh copy of `y` is made per column because corkendall! overwrites its inputs.
    return Float64[corkendall!(float(copy(y)), float(X[:,i]), permy) for i in 1:ncols]
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this is correct and faster?

Suggested change
corkendall(X::RealMatrix, y::RealVector) = (permy = sortperm(y);Float64[corkendall!(float(copy(y)), float(X[:,i]), permy) for i in 1:size(X, 2)])
function corkendall(X::RealMatrix, y::RealVector)
permy = sortperm(y)
y′ = float(copy(y))
return Float64[corkendall!(y′, float(X[:,i]), permy) for i in 1:size(X, 2)]
end

BTW, do you think we really need to call float here? This will make an additional copy for integer vectors

Same suggestion for below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are correct that the float calls are not necessary. Also I don't think that call to Float64 is necessary either. Both were inherited from the existing code. I will get rid of them, assuming that doesn't cause test failures.


# Correlation of every column of `X` with every column of `Y`, as a Matrix{Float64}.
# NOTE(review): a method with this same signature is defined again further down
# (as a multi-line function); in Julia the later definition replaces this one.
corkendall(X::RealMatrix, Y::RealMatrix) = Float64[corkendall!(float(X[:,i]), float(Y[:,j])) for i in 1:size(X, 2), j in 1:size(Y, 2)]
function corkendall(x::RealVector, Y::RealMatrix)
    # The sort order of `x` is computed once and shared by every column comparison.
    permx = sortperm(x)
    n = size(Y, 2)
    # A fresh copy of `x` is made per column because corkendall! overwrites its inputs.
    taus = Float64[corkendall!(float(copy(x)), float(Y[:,i]), permx) for i in 1:n]
    # Callers receive a 1×n row matrix.
    return reshape(taus, 1, n)
end

function corkendall(X::RealMatrix)
    # Pairwise Kendall correlation between the columns of `X`, as a symmetric n×n matrix.
    n = size(X, 2)
    # Start from all ones so the diagonal is already 1; avoids dependency on LinearAlgebra.
    C = ones(float(eltype(X)), n, n)
    @inbounds for j = 2:n
        # The sort order of column j is computed once and reused for every i < j.
        permx = sortperm(X[:, j])
        for i = 1:(j - 1)
            # Column slices are copies, so corkendall! may overwrite them freely.
            C[j, i] = corkendall!(X[:, j], X[:, i], permx)
            C[i, j] = C[j, i]
        end
    end
    return C
end

function corkendall(X::RealMatrix, Y::RealMatrix)
    # Kendall correlation of every column of `X` with every column of `Y`.
    # Entry [j, i] of the result pairs column j of `X` with column i of `Y`.
    nr = size(X, 2)
    nc = size(Y, 2)
    C = zeros(float(eltype(X)), nr, nc)
    @inbounds for j = 1:nr
        # The sort order of X's column j is computed once and reused for every column
        # of Y; a view is enough here because sortperm does not modify its argument.
        permx = sortperm(view(X, :, j))
        for i = 1:nc
            # Column slices are copies, so corkendall! may overwrite them freely.
            C[j, i] = corkendall!(X[:, j], Y[:, i], permx)
        end
    end
    return C
end

# Auxiliary functions for Kendall's rank correlation

function swaps!(x::RealVector)
    # Divide-and-conquer swap count: the counts of the two halves plus the count
    # reported by `mswaps` for the (sorted) halves.
    # Side effect: each half of `x` is sorted in place; the halves are not merged.
    n = length(x)
    # Zero or one element: nothing to count. The previous `n == 1` test let an empty
    # vector fall through and recurse on empty views without ever terminating.
    n <= 1 && return 0
    n2 = div(n, 2)
    xl = view(x, 1:n2)
    xr = view(x, n2+1:n)
    nsl = swaps!(xl)
    nsr = swaps!(xr)
    sort!(xl)
    sort!(xr)
    return nsl + nsr + mswaps(xl, xr)
end
# Tests appear to show that a value of 64 is optimal,
# but note that the equivalent constant in base/sort.jl is 20.
PGS62 marked this conversation as resolved.
Show resolved Hide resolved
# Ranges with `hi - lo <= SMALL_THRESHOLD` are handed to `isort!` by `msort!`
# instead of being split further.
const SMALL_THRESHOLD = 64
PGS62 marked this conversation as resolved.
Show resolved Hide resolved

function mswaps(x::RealVector, y::RealVector)
i = 1
j = 1
nSwaps = 0
n = length(x)
while i <= n && j <= length(y)
if y[j] < x[i]
nSwaps += n - i + 1
# Copy was from https://github.com/JuliaLang/julia/commit/28330a2fef4d9d149ba0fd3ffa06347b50067647 dated 20 Sep 2020
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Copy was from https://github.com/JuliaLang/julia/commit/28330a2fef4d9d149ba0fd3ffa06347b50067647 dated 20 Sep 2020
# Implementation copied from Julia Base
# (commit 28330a2fef4d9d149ba0fd3ffa06347b50067647, dated 20 Sep 2020)

"""
msort!(v::AbstractVector, lo::Integer, hi::Integer, t=similar(v, 0))
PGS62 marked this conversation as resolved.
Show resolved Hide resolved

Mutates `v` by sorting elements `v[lo:hi]` using the merge sort algorithm.
This method is a copy-paste-edit of sort! in base/sort.jl (the method specialised on MergeSortAlg),
but amended to return the bubblesort distance.
"""
function msort!(v::AbstractVector, lo::Integer, hi::Integer, t=similar(v, 0))
    # Sorts v[lo:hi] in place by merge sort and returns the accumulated swap count
    # as a float. `t` is scratch space, grown on demand and shared by recursive calls.
    # avoid overflow errors (if length(v)> 2^16) on 32 bit by using float
    nswaps = 0.0
    @inbounds if lo < hi
        # Short ranges are delegated to insertion sort, which also returns a swap count.
        hi - lo <= SMALL_THRESHOLD && return isort!(v, lo, hi)

        m = midpoint(lo, hi)
        (length(t) < m - lo + 1) && resize!(t, m - lo + 1)

        nswaps = msort!(v, lo, m, t)
        nswaps += msort!(v, m + 1, hi, t)

        # Copy the (now sorted) left half v[lo:m] into the scratch buffer.
        i, j = 1, lo
        while j <= m
            t[i] = v[j]
            i += 1
            j += 1
        end

        # Merge the scratch buffer with the right half v[m+1:hi] back into v.
        i, k = 1, lo
        while k < j <= hi
            if v[j] < t[i]
                v[k] = v[j]
                j += 1
                # The right-half element moves ahead of every left-half element not
                # yet merged; there are (m - lo + 1) - (i - 1) of those.
                nswaps += m - lo + 1 - (i - 1)
            else
                v[k] = t[i]
                i += 1
            end
            k += 1
        end
        # Right half exhausted: copy what is left of the scratch buffer.
        while k < j
            v[k] = t[i]
            k += 1
            i += 1
        end
    end
    return nswaps
end

# This function is also copied from base/sort.jl
function midpoint(lo::T, hi::T) where {T<:Integer}
    # Computed as `lo` plus half the span (halved with a right shift)
    # rather than as (lo + hi) ÷ 2.
    halfspan = (hi - lo) >>> 0x01
    return lo + halfspan
end

function midpoint(lo::Integer, hi::Integer)
    # Mixed integer types: promote to a common type, then use the method above.
    a, b = promote(lo, hi)
    return midpoint(a, b)
end
PGS62 marked this conversation as resolved.
Show resolved Hide resolved

# Implementation copied from Julia Base
# (https://github.com/JuliaLang/julia/commit/28330a2fef4d9d149ba0fd3ffa06347b50067647, dated 20 Sep 2020)
"""
    isort!(v::AbstractVector, lo::Integer, hi::Integer)

Mutates `v` by sorting elements `v[lo:hi]` using the insertion sort algorithm.
This method is a copy-paste-edit of sort! in base/sort.jl (the method specialised on InsertionSortAlg),
amended to return the bubblesort distance.
"""
function isort!(v::AbstractVector, lo::Integer, hi::Integer)
    if lo == hi return 0.0 end
    # The count is kept as a float, matching the value returned for a single element.
    nswaps = 0.0
    @inbounds for i = (lo + 1):hi
        j = i
        x = v[i]
        # Shift larger elements one slot to the right; each shift counts as one swap.
        while j > lo && x < v[j - 1]
            nswaps += 1.0
            v[j] = v[j - 1]
            j -= 1
        end
        v[j] = x
    end
    return nswaps
end
58 changes: 52 additions & 6 deletions test/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,63 @@ c22 = corspearman(x2, x2)

# corkendall

@test corkendall(x1, y) ≈ -0.105409255338946
@test corkendall(x2, y) ≈ -0.117851130197758
@test_throws ErrorException("Vectors must have same length") corkendall([1,2,3,4], [1,2,3])
@test isnan(corkendall([1,2], [3,NaN]))
@test isnan(corkendall([1,1,1], [1,2,3]))
PGS62 marked this conversation as resolved.
Show resolved Hide resolved

@test corkendall(X, y) ≈ [-0.105409255338946, -0.117851130197758]
@test corkendall(y, X) ≈ [-0.105409255338946 -0.117851130197758]
@test corkendall(x1, y) == -1 / sqrt(90)
@test corkendall(x2, y) == -1 / sqrt(72)
@test corkendall(X, y) == [-1 / sqrt(90), -1 / sqrt(72)]
@test corkendall(y, X) == [-1 / sqrt(90) -1 / sqrt(72)]

n = 100_000
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use a lower value if possible to avoid slowing down CI. Please also add comments to briefly explain the purpose of each series of tests when it's not obvious (e.g. overflow, small threshold...).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've used a slightly lower value, n = 78_000, and then a bit later in the code redefined n to 100. I think test/rankcorr.jl executes in under a second (vs 7 seconds before).

@test corkendall(repeat(x1, n), repeat(y, n)) ≈ -1 / sqrt(90)
@test corkendall(repeat(x2, n), repeat(y, n)) ≈ -1 / sqrt(72)
@test corkendall(repeat(X, n), repeat(y, n)) ≈ [-1 / sqrt(90), -1 / sqrt(72)]
@test corkendall(repeat(y, n), repeat(X, n)) ≈ [-1 / sqrt(90) -1 / sqrt(72)]

c11 = corkendall(x1, x1)
c12 = corkendall(x1, x2)
c22 = corkendall(x2, x2)

@test c11 ≈ 1.0
@test c22 ≈ 1.0
@test c11 == 1.0
@test c22 == 1.0
@test c12 == 3 / sqrt(20)

@test corkendall(X, X) ≈ [c11 c12; c12 c22]
@test corkendall(X) ≈ [c11 c12; c12 c22]

@test corkendall(repeat(X, n), repeat(X, n)) ≈ [c11 c12; c12 c22]
@test corkendall(repeat(X, n)) ≈ [c11 c12; c12 c22]

@test corkendall(collect(1:n), collect(1:n)) == 1.0
@test corkendall(collect(1:n), reverse(collect(1:n))) == -1.0
@test isnan(corkendall(repeat([1], n), collect(1:n)))
PGS62 marked this conversation as resolved.
Show resolved Hide resolved

@test corkendall(repeat([0,1,1,0], n), repeat([1,0,1,0], n)) == 0.0

z = [1 1 1;
1 1 2;
PGS62 marked this conversation as resolved.
Show resolved Hide resolved
1 2 2;
1 2 2;
1 2 1;
2 1 2;
1 1 2;
2 2 2]

@test corkendall(z) == [1 0 1 / 3; 0 1 0;1 / 3 0 1]
PGS62 marked this conversation as resolved.
Show resolved Hide resolved
@test corkendall(z, z) == [1 0 1 / 3; 0 1 0;1 / 3 0 1]
@test corkendall(z[:,1], z) == [1 0 1 / 3]
@test corkendall(z, z[:,1]) == [1;0;1 / 3]

z = float(z)
@test corkendall(z) == [1 0 1 / 3; 0 1 0;1 / 3 0 1]
@test corkendall(z, z) == [1 0 1 / 3; 0 1 0;1 / 3 0 1]
@test corkendall(z[:,1], z) == [1 0 1 / 3]
@test corkendall(z, z[:,1]) == [1;0;1 / 3]

w = repeat(z, n)
@test corkendall(w) == [1 0 1 / 3; 0 1 0;1 / 3 0 1]
@test corkendall(w, w) == [1 0 1 / 3; 0 1 0;1 / 3 0 1]
@test corkendall(w[:,1], w) == [1 0 1 / 3]
@test corkendall(w, w[:,1]) == [1;0;1 / 3]