Skip to content

Commit

Permalink
isanitize the solve_triu stuff a little bit
Browse files Browse the repository at this point in the history
  • Loading branch information
fieker committed Dec 3, 2024
1 parent 7bbbc7f commit 2a5d632
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 10 deletions.
92 changes: 85 additions & 7 deletions src/Matrix-Strassen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Provides generic asymptotically fast matrix methods:
- _solve_tril!
- lu!
- _solve_triu
- _solve_triu_left
Just prefix the function by "Strassen." all 4 functions support a keyword
argument "cutoff" to indicate when the base case should be used.
Expand Down Expand Up @@ -40,6 +41,12 @@ function mul(A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T}
end

#scheduling copied from the nmod_mat_mul in Flint
"""
Fast, recursive, generic matrix multiplication using the Strassen
trick.
`cutoff` indicates when the recursion stops and the base case is called.
"""
function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T}
sA = size(A)
sB = size(B)
Expand Down Expand Up @@ -274,17 +281,85 @@ function lu!(P::Perm{Int}, A; cutoff::Int = 300)
return r1 + r2
end

function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff)
#b*inv(T), thus solves Tx = b for T upper triangular
function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff, side::Symbol = :left)
#inv(T)*b, thus solves Tx = b for T upper triangular
n = ncols(T)
if n <= cutoff
R = AbstractAlgebra._solve_triu(T, b; side)
return R
end
if side == :left
return _solve_triu_left(T, b; side)
end
@assert side == :right
@assert n == nrows(T) == nrows(b)

n2 = div(n, 2) + n % 2
m = ncols(b)
m2 = div(m, 2) + m % 2
#=
b = [U X; V Y]
T = [A B; 0 C]
x = [SS RR; S R]
[0 C] [SS; S] = CS = V
[0 C] [RR; R] = CR = Y
[A B] [SS; S] = A SS + B S = U => A SS = U - BS
[A B] [RR; R] = A RR + B R = U => A RR = X - BR
=#

U = view(b, 1:n2, 1:m2)
X = view(b, 1:n2, m2+1:m)
V = view(b, n2+1:n, 1:m2)
Y = view(b, n2+1:n, m2+1:m)

A = view(T, 1:n2, 1:n2)
B = view(T, 1:n2, 1+n2:n)
C = view(T, 1+n2:n, 1+n2:n)

S = _solve_triu(C, V; cutoff, side)
R = _solve_triu(C, Y; cutoff, side)

SS = mul(B, S; cutoff)
SS = sub!(SS, U, SS)
SS = _solve_triu(A, SS; cutoff, side)

RR = mul(B, R; cutoff)
RR = sub!(RR, X, RR)
RR = _solve_triu(A, RR; cutoff, side)

return [SS RR; S R]
end

function _solve_triu_left(T::MatElem, b::MatElem; cutoff::Int = cutoff, side::Symbol = :left)
#b*inv(T), thus solves xT = b for T upper triangular
n = ncols(T)
if n <= cutoff
R = AbstractAlgebra._solve_triu(T, b)
R = AbstractAlgebra._solve_triu(T, b; side)
return R
end
if side == :right
return _solve_triu(T, b, side)

Check warning on line 344 in src/Matrix-Strassen.jl

View check run for this annotation

Codecov / codecov/patch

src/Matrix-Strassen.jl#L343-L344

Added lines #L343 - L344 were not covered by tests
end
@assert side == :left
@assert ncols(b) == nrows(T) == n

Check warning on line 347 in src/Matrix-Strassen.jl

View check run for this annotation

Codecov / codecov/patch

src/Matrix-Strassen.jl#L346-L347

Added lines #L346 - L347 were not covered by tests

n2 = div(n, 2) + n % 2
m = nrows(b)
m2 = div(m, 2) + m % 2
#=
b = [U X; V Y]
T = [A B; 0 C]
x = [S SS; R RR]
[S SS] [A; 0] = SA = U
[R RR] [A; 0] = RA = V
[S SS] [B; C] = SB + SS C = X => SS C = Y - SB
[R RR] [B; C] = RB + RR C = Y => RR C = Y - RB
=#

U = view(b, 1:m2, 1:n2)
V = view(b, 1:m2, n2+1:n)
Expand All @@ -295,18 +370,21 @@ function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff)
B = view(T, 1:n2, 1+n2:n)
C = view(T, 1+n2:n, 1+n2:n)

S = _solve_triu(A, U; cutoff)
R = _solve_triu(A, X; cutoff)
S = _solve_triu_left(U, A; cutoff, side)
R = _solve_triu_left(X, A; cutoff, side)

Check warning on line 374 in src/Matrix-Strassen.jl

View check run for this annotation

Codecov / codecov/patch

src/Matrix-Strassen.jl#L373-L374

Added lines #L373 - L374 were not covered by tests

SS = mul(S, B; cutoff)
SS = sub!(SS, V, SS)
SS = _solve_triu(C, SS; cutoff)
SS = _solve_triu_left(SS, C; cutoff, side)

Check warning on line 378 in src/Matrix-Strassen.jl

View check run for this annotation

Codecov / codecov/patch

src/Matrix-Strassen.jl#L378

Added line #L378 was not covered by tests

RR = mul(R, B; cutoff)
RR = sub!(RR, Y, RR)
RR = _solve_triu(C, RR; cutoff)
RR = _solve_triu_left(RR, C; cutoff, side)

Check warning on line 382 in src/Matrix-Strassen.jl

View check run for this annotation

Codecov / codecov/patch

src/Matrix-Strassen.jl#L382

Added line #L382 was not covered by tests
#THINK: both pairs of solving could be combined:
# solve [U; X], A to get S and R...

return [S SS; R RR]
end


end # module
14 changes: 11 additions & 3 deletions src/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3497,7 +3497,11 @@ will be raised.
See also [`AbstractAlgebra.__solve_triu_left`](@ref)
"""
function _solve_triu(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement}
function _solve_triu(U::MatElem{T}, b::MatElem{T}; side::Symbol = :left) where {T <: RingElement}
if side == :left
return _solve_triu_left(U, b; side)
end
@assert side == :right
n = nrows(U)
m = ncols(b)
R = base_ring(U)
Expand Down Expand Up @@ -3525,7 +3529,7 @@ function _solve_triu(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement}
end

@doc raw"""
__solve_triu_left(b::MatElem{T}, U::MatElem{T}) where {T <: RingElement}
_solve_triu_left(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement}
Given a non-singular $n\times n$ matrix $U$ over a field which is upper
triangular, and an $m\times n$ matrix $b$ over the same ring, return an
Expand All @@ -3535,7 +3539,11 @@ will be raised.
See also [`_solve_triu`](@ref) or [`can__solve_left_reduced_triu`](@ref) when
$U$ is not square or not of full rank.
"""
function __solve_triu_left(b::MatElem{T}, U::MatElem{T}) where {T <: RingElement}
function _solve_triu_left(U::MatElem{T}, b::MatElem{T}; side::Symbol = :left) where {T <: RingElement}
if side == :right
return _solve_triu(U, b; side)

Check warning on line 3544 in src/Matrix.jl

View check run for this annotation

Codecov / codecov/patch

src/Matrix.jl#L3544

Added line #L3544 was not covered by tests
end
@assert side == :left
n = ncols(U)
m = nrows(b)
R = base_ring(U)
Expand Down
12 changes: 12 additions & 0 deletions test/Solve-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,15 @@ end
@test ncols(S) == 3
@test base_ring(S) == QQ
end

@testset "solve_triu" begin
A = matrix(ZZ, 10, 10, [i<=j ? i+j-1 : 0 for i=1:10 for j=1:10])
x = matrix(ZZ, rand(-10:10, 10, 10))
@test AbstractAlgebra._solve_triu(A, A*x; side = :right) == x
@test AbstractAlgebra._solve_triu(A, x*A; side = :left) == x

A = matrix(ZZ, 20, 20, [i<=j ? i+j-1 : 0 for i=1:20 for j=1:20])
x = matrix(ZZ, rand(-10:10, 20, 20))
@test AbstractAlgebra.Strassen._solve_triu(A, A*x; cutoff = 10, side = :right) == x
@test AbstractAlgebra.Strassen._solve_triu(A, x*A; cutoff = 10, side = :left) == x
end

0 comments on commit 2a5d632

Please sign in to comment.