-
Notifications
You must be signed in to change notification settings - Fork 194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rewrote corkendall (issue 634) #647
Changes from 5 commits
47a4457
e58ebe7
2fff9f2
f4ff6da
c8672a3
8ed9553
d17295f
9e80325
dbb3298
c31d7ea
56cb219
23d5690
f9114c1
913843d
3f5132e
9628be3
7b74349
180ff30
0ccd3be
746eaf6
beb289a
704fcce
bd2cf5c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -28,71 +28,76 @@ corspearman(X::RealMatrix) = (Z = mapslices(tiedrank, X, dims=1); cor(Z, Z)) | |||||||||||||
|
||||||||||||||
|
||||||||||||||
####################################### | ||||||||||||||
# | ||||||||||||||
# | ||||||||||||||
# Kendall correlation | ||||||||||||||
# | ||||||||||||||
# | ||||||||||||||
####################################### | ||||||||||||||
|
||||||||||||||
# Knight JASA (1966) | ||||||||||||||
|
||||||||||||||
function corkendall!(x::RealVector, y::RealVector) | ||||||||||||||
# Knight, William R. “A Computer Method for Calculating Kendall's Tau with Ungrouped Data.” | ||||||||||||||
# Journal of the American Statistical Association, vol. 61, no. 314, 1966, pp. 436–439. | ||||||||||||||
# JSTOR, www.jstor.org/stable/2282833. Accessed 15 Jan. 2021. | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
function corkendall!(x::RealVector, y::RealVector, permx=sortperm(x)) | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
if any(isnan, x) || any(isnan, y) return NaN end | ||||||||||||||
n = length(x) | ||||||||||||||
if n != length(y) error("Vectors must have same length") end | ||||||||||||||
|
||||||||||||||
# Initial sorting | ||||||||||||||
pm = sortperm(y) | ||||||||||||||
x[:] = x[pm] | ||||||||||||||
y[:] = y[pm] | ||||||||||||||
pm[:] = sortperm(x) | ||||||||||||||
x[:] = x[pm] | ||||||||||||||
|
||||||||||||||
# Counting ties in x and y | ||||||||||||||
iT = 1 | ||||||||||||||
nT = 0 | ||||||||||||||
iU = 1 | ||||||||||||||
nU = 0 | ||||||||||||||
for i = 2:n | ||||||||||||||
if x[i] == x[i-1] | ||||||||||||||
iT += 1 | ||||||||||||||
else | ||||||||||||||
nT += iT*(iT - 1) | ||||||||||||||
iT = 1 | ||||||||||||||
end | ||||||||||||||
if y[i] == y[i-1] | ||||||||||||||
iU += 1 | ||||||||||||||
else | ||||||||||||||
nU += iU*(iU - 1) | ||||||||||||||
iU = 1 | ||||||||||||||
x[:] = x[permx] | ||||||||||||||
y[:] = y[permx] | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
|
||||||||||||||
npairs = float(n) * (n - 1) / 2 | ||||||||||||||
# ntiesx, ntiesy, ndoubleties are floats to avoid overflows on 32bit | ||||||||||||||
ntiesx, ntiesy, ndoubleties, k, nswaps = 0.0, 0.0, 0.0, 0, 0 | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
|
||||||||||||||
@inbounds for i = 2:n | ||||||||||||||
if x[i - 1] == x[i] | ||||||||||||||
k += 1 | ||||||||||||||
elseif k > 0 | ||||||||||||||
# Sort the corresponding chunk of y, so the rows of hcat(x,y) are | ||||||||||||||
# sorted first on x, then (where x values are tied) on y. Hence | ||||||||||||||
# double ties can be counted by calling countties. | ||||||||||||||
sort!(view(y, (i - k - 1):(i - 1))) | ||||||||||||||
ntiesx += float(k) * (k + 1) / 2 | ||||||||||||||
ndoubleties += countties(y, i - k - 1, i - 1) | ||||||||||||||
k = 0 | ||||||||||||||
end | ||||||||||||||
end | ||||||||||||||
if iT > 1 nT += iT*(iT - 1) end | ||||||||||||||
nT = div(nT,2) | ||||||||||||||
if iU > 1 nU += iU*(iU - 1) end | ||||||||||||||
nU = div(nU,2) | ||||||||||||||
|
||||||||||||||
# Sort y after x | ||||||||||||||
y[:] = y[pm] | ||||||||||||||
|
||||||||||||||
# Calculate double ties | ||||||||||||||
iV = 1 | ||||||||||||||
nV = 0 | ||||||||||||||
jV = 1 | ||||||||||||||
for i = 2:n | ||||||||||||||
if x[i] == x[i-1] && y[i] == y[i-1] | ||||||||||||||
iV += 1 | ||||||||||||||
else | ||||||||||||||
nV += iV*(iV - 1) | ||||||||||||||
iV = 1 | ||||||||||||||
end | ||||||||||||||
if k > 0 | ||||||||||||||
sort!(view(y, ((n - k):n))) | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
ntiesx += float(k) * (k + 1) / 2 | ||||||||||||||
ndoubleties += countties(y, n - k, n) | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
end | ||||||||||||||
if iV > 1 nV += iV*(iV - 1) end | ||||||||||||||
nV = div(nV,2) | ||||||||||||||
|
||||||||||||||
nD = div(n*(n - 1),2) | ||||||||||||||
return (nD - nT - nU + nV - 2swaps!(y)) / (sqrt(nD - nT) * sqrt(nD - nU)) | ||||||||||||||
nswaps = msort!(y, 1, n) | ||||||||||||||
ntiesy = countties(y, 1, n) | ||||||||||||||
|
||||||||||||||
(npairs + ndoubleties - ntiesx - ntiesy - 2 * nswaps) / | ||||||||||||||
sqrt((npairs - ntiesx) * (npairs - ntiesy)) | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
""" | ||||||||||||||
countties(x::AbstractVector, lo::Integer, hi::Integer) | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
|
||||||||||||||
Assumes `x` is sorted. Returns the number of ties within `x[lo:hi]`. | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
""" | ||||||||||||||
function countties(x::AbstractVector, lo::Integer, hi::Integer) | ||||||||||||||
# avoid overflows on 32 bit by using floats | ||||||||||||||
thistiecount, result = 0.0, 0.0 | ||||||||||||||
(lo < 1 || hi > length(x)) && error("Bounds error") | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
@inbounds for i = (lo + 1):hi | ||||||||||||||
if x[i] == x[i - 1] | ||||||||||||||
thistiecount += 1.0 | ||||||||||||||
elseif thistiecount > 0 | ||||||||||||||
result += thistiecount * (thistiecount + 1) / 2 | ||||||||||||||
thistiecount = 0.0 | ||||||||||||||
end | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
if thistiecount > 0 | ||||||||||||||
result += thistiecount * (thistiecount + 1) / 2 | ||||||||||||||
end | ||||||||||||||
result | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
""" | ||||||||||||||
corkendall(x, y=x) | ||||||||||||||
|
@@ -102,52 +107,118 @@ matrices or vectors. | |||||||||||||
""" | ||||||||||||||
corkendall(x::RealVector, y::RealVector) = corkendall!(float(copy(x)), float(copy(y))) | ||||||||||||||
|
||||||||||||||
corkendall(X::RealMatrix, y::RealVector) = Float64[corkendall!(float(X[:,i]), float(copy(y))) for i in 1:size(X, 2)] | ||||||||||||||
|
||||||||||||||
corkendall(x::RealVector, Y::RealMatrix) = (n = size(Y,2); reshape(Float64[corkendall!(float(copy(x)), float(Y[:,i])) for i in 1:n], 1, n)) | ||||||||||||||
corkendall(X::RealMatrix, y::RealVector) = (permy = sortperm(y);Float64[corkendall!(float(copy(y)), float(X[:,i]), permy) for i in 1:size(X, 2)]) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I assume this is correct and faster?
Suggested change
BTW, do you think we really need to call […]? Same suggestion for below. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think you are correct that the […] |
||||||||||||||
|
||||||||||||||
corkendall(X::RealMatrix, Y::RealMatrix) = Float64[corkendall!(float(X[:,i]), float(Y[:,j])) for i in 1:size(X, 2), j in 1:size(Y, 2)] | ||||||||||||||
corkendall(x::RealVector, Y::RealMatrix) = (n = size(Y, 2); permx = sortperm(x); reshape(Float64[corkendall!(float(copy(x)), float(Y[:,i]), permx) for i in 1:n], 1, n)) | ||||||||||||||
|
||||||||||||||
function corkendall(X::RealMatrix) | ||||||||||||||
n = size(X, 2) | ||||||||||||||
C = Matrix{eltype(X)}(I, n, n) | ||||||||||||||
for j = 2:n | ||||||||||||||
for i = 1:j-1 | ||||||||||||||
C[i,j] = corkendall!(X[:,i],X[:,j]) | ||||||||||||||
C[j,i] = C[i,j] | ||||||||||||||
C = ones(float(eltype(X)), n, n)# avoids dependency on LinearAlgebra | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
@inbounds for j = 2:n | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
permx = sortperm(X[:,j]) | ||||||||||||||
for i = 1:j - 1 | ||||||||||||||
C[j,i] = corkendall!(X[:,j], X[:,i], permx) | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
C[i,j] = C[j,i] | ||||||||||||||
end | ||||||||||||||
end | ||||||||||||||
return C | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
function corkendall(X::RealMatrix, Y::RealMatrix) | ||||||||||||||
nr = size(X, 2) | ||||||||||||||
nc = size(Y, 2) | ||||||||||||||
C = zeros(float(eltype(X)), nr, nc) | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
@inbounds for j = 1:nr | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
permx = sortperm(X[:,j]) | ||||||||||||||
for i = 1:nc | ||||||||||||||
C[j,i] = corkendall!(X[:,j], Y[:,i], permx) | ||||||||||||||
end | ||||||||||||||
end | ||||||||||||||
return C | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
# Auxiliary functions for Kendall's rank correlation | ||||||||||||||
|
||||||||||||||
function swaps!(x::RealVector) | ||||||||||||||
n = length(x) | ||||||||||||||
if n == 1 return 0 end | ||||||||||||||
n2 = div(n, 2) | ||||||||||||||
xl = view(x, 1:n2) | ||||||||||||||
xr = view(x, n2+1:n) | ||||||||||||||
nsl = swaps!(xl) | ||||||||||||||
nsr = swaps!(xr) | ||||||||||||||
sort!(xl) | ||||||||||||||
sort!(xr) | ||||||||||||||
return nsl + nsr + mswaps(xl,xr) | ||||||||||||||
end | ||||||||||||||
# Tests appear to show that a value of 64 is optimal, | ||||||||||||||
# but note that the equivalent constant in base/sort.jl is 20. | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
const SMALL_THRESHOLD = 64 | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
|
||||||||||||||
function mswaps(x::RealVector, y::RealVector) | ||||||||||||||
i = 1 | ||||||||||||||
j = 1 | ||||||||||||||
nSwaps = 0 | ||||||||||||||
n = length(x) | ||||||||||||||
while i <= n && j <= length(y) | ||||||||||||||
if y[j] < x[i] | ||||||||||||||
nSwaps += n - i + 1 | ||||||||||||||
# Copied from https://github.com/JuliaLang/julia/commit/28330a2fef4d9d149ba0fd3ffa06347b50067647 dated 20 Sep 2020 | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
""" | ||||||||||||||
msort!(v::AbstractVector, lo::Integer, hi::Integer, t=similar(v, 0)) | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
|
||||||||||||||
Mutates `v` by sorting elements `v[lo:hi]` using the merge sort algorithm. | ||||||||||||||
This method is a copy-paste-edit of sort! in base/sort.jl (the method specialised on MergeSortAlg), | ||||||||||||||
but amended to return the bubblesort distance. | ||||||||||||||
""" | ||||||||||||||
function msort!(v::AbstractVector, lo::Integer, hi::Integer, t=similar(v, 0)) | ||||||||||||||
# avoid overflow errors (if length(v)> 2^16) on 32 bit by using float | ||||||||||||||
nswaps = 0.0 | ||||||||||||||
@inbounds if lo < hi | ||||||||||||||
hi - lo <= SMALL_THRESHOLD && return isort!(v, lo, hi) | ||||||||||||||
|
||||||||||||||
m = midpoint(lo, hi) | ||||||||||||||
(length(t) < m - lo + 1) && resize!(t, m - lo + 1) | ||||||||||||||
|
||||||||||||||
nswaps = msort!(v, lo, m, t) | ||||||||||||||
nswaps += msort!(v, m + 1, hi, t) | ||||||||||||||
|
||||||||||||||
i, j = 1, lo | ||||||||||||||
while j <= m | ||||||||||||||
t[i] = v[j] | ||||||||||||||
i += 1 | ||||||||||||||
j += 1 | ||||||||||||||
else | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
i, k = 1, lo | ||||||||||||||
while k < j <= hi | ||||||||||||||
if v[j] < t[i] | ||||||||||||||
v[k] = v[j] | ||||||||||||||
j += 1 | ||||||||||||||
nswaps += m - lo + 1 - (i - 1) | ||||||||||||||
else | ||||||||||||||
v[k] = t[i] | ||||||||||||||
i += 1 | ||||||||||||||
end | ||||||||||||||
k += 1 | ||||||||||||||
end | ||||||||||||||
while k < j | ||||||||||||||
v[k] = t[i] | ||||||||||||||
k += 1 | ||||||||||||||
i += 1 | ||||||||||||||
end | ||||||||||||||
end | ||||||||||||||
return nSwaps | ||||||||||||||
return nswaps | ||||||||||||||
end | ||||||||||||||
|
||||||||||||||
# This function is also copied from base/sort.jl | ||||||||||||||
midpoint(lo::T, hi::T) where T <: Integer = lo + ((hi - lo) >>> 0x01) | ||||||||||||||
midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...) | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
|
||||||||||||||
# Copied from https://github.com/JuliaLang/julia/commit/28330a2fef4d9d149ba0fd3ffa06347b50067647 dated 20 Sep 2020 | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
""" | ||||||||||||||
isort!(v::AbstractVector, lo::Integer, hi::Integer) | ||||||||||||||
|
||||||||||||||
Mutates `v` by sorting elements `v[lo:hi]` using the insertion sort algorithm. | ||||||||||||||
This method is a copy-paste-edit of sort! in base/sort.jl (the method specialised on InsertionSortAlg), | ||||||||||||||
amended to return the bubblesort distance. | ||||||||||||||
""" | ||||||||||||||
function isort!(v::AbstractVector, lo::Integer, hi::Integer) | ||||||||||||||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
if lo == hi return 0.0 end | ||||||||||||||
nswaps = 0.0 | ||||||||||||||
@inbounds for i = lo + 1:hi | ||||||||||||||
j = i | ||||||||||||||
x = v[i] | ||||||||||||||
while j > lo | ||||||||||||||
if x < v[j - 1] | ||||||||||||||
nswaps += 1.0 | ||||||||||||||
v[j] = v[j - 1] | ||||||||||||||
j -= 1 | ||||||||||||||
continue | ||||||||||||||
end | ||||||||||||||
break | ||||||||||||||
end | ||||||||||||||
v[j] = x | ||||||||||||||
end | ||||||||||||||
return nswaps | ||||||||||||||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,17 +26,63 @@ c22 = corspearman(x2, x2) | |
|
||
# corkendall | ||
|
||
@test corkendall(x1, y) ≈ -0.105409255338946 | ||
@test corkendall(x2, y) ≈ -0.117851130197758 | ||
@test_throws ErrorException("Vectors must have same length") corkendall([1,2,3,4], [1,2,3]) | ||
@test isnan(corkendall([1,2], [3,NaN])) | ||
@test isnan(corkendall([1,1,1], [1,2,3])) | ||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@test corkendall(X, y) ≈ [-0.105409255338946, -0.117851130197758] | ||
@test corkendall(y, X) ≈ [-0.105409255338946 -0.117851130197758] | ||
@test corkendall(x1, y) == -1 / sqrt(90) | ||
@test corkendall(x2, y) == -1 / sqrt(72) | ||
@test corkendall(X, y) == [-1 / sqrt(90), -1 / sqrt(72)] | ||
@test corkendall(y, X) == [-1 / sqrt(90) -1 / sqrt(72)] | ||
|
||
n = 100_000 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please use a lower value if possible to avoid slowing down CI. Please also add comments to briefly explain the purpose of each series of tests when it's not obvious (e.g. overflow, small threshold...). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've used a slightly lower value, n = 78_000, then a bit later in the code redefined n to 100. I think test/rankcorr.jl executes in under a second (vs 7 seconds before). |
||
@test corkendall(repeat(x1, n), repeat(y, n)) ≈ -1 / sqrt(90) | ||
@test corkendall(repeat(x2, n), repeat(y, n)) ≈ -1 / sqrt(72) | ||
@test corkendall(repeat(X, n), repeat(y, n)) ≈ [-1 / sqrt(90), -1 / sqrt(72)] | ||
@test corkendall(repeat(y, n), repeat(X, n)) ≈ [-1 / sqrt(90) -1 / sqrt(72)] | ||
|
||
c11 = corkendall(x1, x1) | ||
c12 = corkendall(x1, x2) | ||
c22 = corkendall(x2, x2) | ||
|
||
@test c11 ≈ 1.0 | ||
@test c22 ≈ 1.0 | ||
@test c11 == 1.0 | ||
@test c22 == 1.0 | ||
@test c12 == 3 / sqrt(20) | ||
|
||
@test corkendall(X, X) ≈ [c11 c12; c12 c22] | ||
@test corkendall(X) ≈ [c11 c12; c12 c22] | ||
|
||
@test corkendall(repeat(X, n), repeat(X, n)) ≈ [c11 c12; c12 c22] | ||
@test corkendall(repeat(X, n)) ≈ [c11 c12; c12 c22] | ||
|
||
@test corkendall(collect(1:n), collect(1:n)) == 1.0 | ||
@test corkendall(collect(1:n), reverse(collect(1:n))) == -1.0 | ||
@test isnan(corkendall(repeat([1], n), collect(1:n))) | ||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@test corkendall(repeat([0,1,1,0], n), repeat([1,0,1,0], n)) == 0.0 | ||
|
||
z = [1 1 1; | ||
1 1 2; | ||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
1 2 2; | ||
1 2 2; | ||
1 2 1; | ||
2 1 2; | ||
1 1 2; | ||
2 2 2] | ||
|
||
@test corkendall(z) == [1 0 1 / 3; 0 1 0;1 / 3 0 1] | ||
PGS62 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@test corkendall(z, z) == [1 0 1 / 3; 0 1 0;1 / 3 0 1] | ||
@test corkendall(z[:,1], z) == [1 0 1 / 3] | ||
@test corkendall(z, z[:,1]) == [1;0;1 / 3] | ||
|
||
z = float(z) | ||
@test corkendall(z) == [1 0 1 / 3; 0 1 0;1 / 3 0 1] | ||
@test corkendall(z, z) == [1 0 1 / 3; 0 1 0;1 / 3 0 1] | ||
@test corkendall(z[:,1], z) == [1 0 1 / 3] | ||
@test corkendall(z, z[:,1]) == [1;0;1 / 3] | ||
|
||
w = repeat(z, n) | ||
@test corkendall(w) == [1 0 1 / 3; 0 1 0;1 / 3 0 1] | ||
@test corkendall(w, w) == [1 0 1 / 3; 0 1 0;1 / 3 0 1] | ||
@test corkendall(w[:,1], w) == [1 0 1 / 3] | ||
@test corkendall(w, w[:,1]) == [1;0;1 / 3] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please revert this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ugh, that was VSCode auto formatter. Fixed now.