Skip to content

Commit

Permalink
use butterfly function for in-memory fft!
Browse files Browse the repository at this point in the history
  • Loading branch information
KlausC committed Feb 3, 2024
1 parent 764c1c9 commit 22b4424
Showing 1 changed file with 66 additions and 23 deletions.
89 changes: 66 additions & 23 deletions src/fourier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ using principal `2dd`th root of unity `ω = X^z, ω^dd == -1`.
Result overwrites input `F`. `W` is workspace of same size as `F`. Assume `n * dd == length(F)`.
"""
function fft!(n::Integer, F::V, dd::Int, z::Int, W::V) where {R,V<:AbstractVector{R}}
function fft!(n::Integer, F::V, dd::Int, z::Int) where {R,V<:AbstractVector{R}}
n > 0 && count_ones(n) == 1 || throw(ArgumentError("n must be power of 2"))
0 < dd && count_ones(dd) == 1 || throw(ArgumentError("dd must be power of 2 < n"))
length(F) == n * dd || throw(ArgumentError("length(F) must be n * dd"))
mod(z * n ÷ 2, 2dd) == dd || throw(ArgumentError("w^($(n)/2) must be -1"))
#mod(z * n ÷ 2, 2dd) == dd || throw(ArgumentError("w^($(n)/2) must be -1"))

F0 = F
m = n
Expand All @@ -95,24 +95,8 @@ function fft!(n::Integer, F::V, dd::Int, z::Int, W::V) where {R,V<:AbstractVecto
m = k
d += d
end
m = 2
while m <= n
k = m ÷ 2
for j = 0:m:n-1
for i = 0:k-1
for l = 1:dd
W[(2i+j)*dd+l] = F[(i+j)*dd+l]
W[(2i+j+1)*dd+l] = F[(i+k+j)*dd+l]
end
end
end
F, W = W, F
m += m
end
if F !== F0
F0 .= F
end
F0
butterfly!(F, dd, n)
F
end

"""
Expand Down Expand Up @@ -205,10 +189,10 @@ function _schoenhage_strassen!(k::Int, FF::Q, GG::Q, W::Q) where Q<:AbstractVect

scale!(FF, 2d, z)
scale!(GG, 2d, z)
fft!(δ, FF, 2d, w, W)
fft!(δ, GG, 2d, w, W)
fft!(δ, FF, 2d, w)
fft!(δ, GG, 2d, w)
convolute_all!(FF, FF, GG, 2d, W)
fft!(δ, FF, 2d, -w, W)
fft!(δ, FF, 2d, -w)
scale!(FF, 2d, -z)
FF .= sdiv.(FF, δ)
shrink!(FF, d, δ)
Expand Down Expand Up @@ -505,3 +489,62 @@ end
function effort_sch(N)
(N * ilog2(N) + 1) * e_factor
end

"""
revert(a::Int, b::Int)
Return `a` with bits `0:n-1` in reverse order.
"""
function revert(a, n)
accu = (a >> UInt8(n)) << UInt8(n)
for i = 1:n
accu += accu
accu += isodd(a)
a >>= 0x1
end
accu
end

"""
butterfly!(a, d, k)
Permute in-memory so `a_new[revert(i)*d+l] == a[i*d] for i ∈ 0:2^k-1 for l ∈ 1:d`.
"""
function butterfly!(F::AbstractVector, dd::Int, n::Int)
k = ilog2(n)
@assert n * dd == length(F)
id = 0
for i = 0:n-1
j = revert(i, k)
if i < j
jd = j * dd
for l = 1:dd
F[id+l], F[jd+l] = F[jd+l], F[id+l]
end
end
id += dd
end
F
end

function butterfly2!(F::AbstractVector, dd::Int, n::Int, W::AbstractVector)
m = 2
F0 = F
while m <= n
k = m ÷ 2
for j = 0:m:n-1
for i = 0:k-1
for l = 1:dd
W[(2i+j)*dd+l] = F[(i+j)*dd+l]
W[(2i+j+1)*dd+l] = F[(i+k+j)*dd+l]
end
end
end
F, W = W, F
m += m
end
if F !== F0
F0 .= F
end
F0
end

0 comments on commit 22b4424

Please sign in to comment.