Skip to content

Commit

Permalink
Solve triu (#1920)
Browse files Browse the repository at this point in the history
  • Loading branch information
fieker authored Dec 12, 2024
1 parent 56ca039 commit c18a4d2
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 16 deletions.
96 changes: 85 additions & 11 deletions src/Matrix-Strassen.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
Provides generic asymptotically fast matrix methods:
- mul and mul! using the Strassen scheme
- _solve_tril!
- lu!
- _solve_triu
- `mul` and `mul!` using the Strassen scheme
- `_solve_tril!`
- `lu!`
- `_solve_triu`
Just prefix the function name by "Strassen."; all of these functions support a
keyword argument "cutoff" to indicate when the base case should be used.
Expand Down Expand Up @@ -40,6 +40,12 @@ function mul(A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T}
end

#scheduling copied from the nmod_mat_mul in Flint
"""
Fast, recursive, generic matrix multiplication using the Strassen
trick.
`cutoff` indicates when the recursion stops and the base case is called.
"""
function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T}
sA = size(A)
sB = size(B)
Expand Down Expand Up @@ -274,17 +280,82 @@ function lu!(P::Perm{Int}, A; cutoff::Int = 300)
return r1 + r2
end

function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff)
#b*inv(T), thus solves Tx = b for T upper triangular
function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff, side::Symbol = :left)
# Recursive block solver for an upper triangular matrix T.
#   side == :right : inv(T)*b, thus solves Tx = b
#   side == :left  : delegated to _solve_triu_left (solves xT = b)
# As soon as ncols(T) <= cutoff the generic AbstractAlgebra._solve_triu is
# used as base case (for either side).
n = ncols(T)
if n <= cutoff
# NOTE(review): the first assignment (call without `side`) is immediately
# overwritten by the second one; it looks like the pre-change line of the
# diff this text was taken from - confirm and drop it.
R = AbstractAlgebra._solve_triu(T, b)
R = AbstractAlgebra._solve_triu(T, b; side)
return R
end
if side == :left
return _solve_triu_left(T, b; cutoff)
end
# From here on only the :right case (T*x = b) is handled.
@assert side == :right
# T has to be square and b needs as many rows as T.
@assert n == nrows(T) == nrows(b)

# Split sizes: the first block gets the extra row/column when n (or m) is odd.
n2 = div(n, 2) + n % 2
m = ncols(b)
m2 = div(m, 2) + m % 2
#=
Block decomposition used below (T*x = b):
b = [U X; V Y]
T = [A B; 0 C]
x = [SS RR; S R]
[0 C] [SS; S] = CS = V
[0 C] [RR; R] = CR = Y
[A B] [SS; S] = A SS + B S = U => A SS = U - BS
[A B] [RR; R] = A RR + B R = X => A RR = X - BR
=#

# Blocks of the right hand side b.
U = view(b, 1:n2, 1:m2)
X = view(b, 1:n2, m2+1:m)
V = view(b, n2+1:n, 1:m2)
Y = view(b, n2+1:n, m2+1:m)

# Blocks of T; the lower left block is never read.
A = view(T, 1:n2, 1:n2)
B = view(T, 1:n2, 1+n2:n)
C = view(T, 1+n2:n, 1+n2:n)

# Bottom block row of x: only involves C.
S = _solve_triu(C, V; cutoff, side)
R = _solve_triu(C, Y; cutoff, side)

# Top left block of x: A*SS = U - B*S.
SS = mul(B, S; cutoff)
SS = sub!(SS, U, SS)
SS = _solve_triu(A, SS; cutoff, side)

# Top right block of x: A*RR = X - B*R.
RR = mul(B, R; cutoff)
RR = sub!(RR, X, RR)
RR = _solve_triu(A, RR; cutoff, side)

# Reassemble x from its four blocks.
return [SS RR; S R]
end

function _solve_triu_left(T::MatElem, b::MatElem; cutoff::Int = cutoff)
#b*inv(T), thus solves xT = b for T upper triangular
n = ncols(T)
if n <= cutoff
R = AbstractAlgebra._solve_triu_left(T, b)
return R
end

@assert ncols(b) == nrows(T) == n

n2 = div(n, 2) + n % 2
m = nrows(b)
m2 = div(m, 2) + m % 2
#=
b = [U V; X Y]
T = [A B; 0 C]
x = [S SS; R RR]
[S SS] [A; 0] = SA = U
[R RR] [A; 0] = RA = X
[S SS] [B; C] = SB + SS C = V => SS C = V - SB
[R RR] [B; C] = RB + RR C = Y => RR C = Y - RB
=#

U = view(b, 1:m2, 1:n2)
V = view(b, 1:m2, n2+1:n)
Expand All @@ -295,18 +366,21 @@ function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff)
B = view(T, 1:n2, 1+n2:n)
C = view(T, 1+n2:n, 1+n2:n)

S = _solve_triu(A, U; cutoff)
R = _solve_triu(A, X; cutoff)
S = _solve_triu_left(A, U; cutoff)
R = _solve_triu_left(A, X; cutoff)

SS = mul(S, B; cutoff)
SS = sub!(SS, V, SS)
SS = _solve_triu(C, SS; cutoff)
SS = _solve_triu_left(C, SS; cutoff)

RR = mul(R, B; cutoff)
RR = sub!(RR, Y, RR)
RR = _solve_triu(C, RR; cutoff)
RR = _solve_triu_left(C, RR; cutoff)
#THINK: both pairs of solving could be combined:
# solve [U; X], A to get S and R...

return [S SS; R RR]
end


end # module
91 changes: 87 additions & 4 deletions src/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2135,7 +2135,7 @@ function rref!(A::MatrixElem{T}) where {T <: FieldElement}
V[j, i] = A[j, pivots[np + i]]
end
end
V = _solve_triu(U, V, false)
V = _solve_triu_right(U, V; unipotent = false)
for i = 1:rnk
for j = 1:i
A[j, pivots[i]] = i == j ? one(R) : R()
Expand Down Expand Up @@ -3411,14 +3411,14 @@ $n\times m$ matrix $x$ such that $Ux = b$. If $U$ is singular an exception
is raised. If `unipotent` is true then $U$ is assumed to have ones on its
diagonal, and the diagonal will not be read.
"""
function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T <: FieldElement}
function _solve_triu_right(U::MatElem{T}, b::MatElem{T}; unipotent::Bool = false) where {T <: FieldElement}
n = nrows(U)
m = ncols(b)
R = base_ring(U)
X = zero(b)
Tinv = Vector{elem_type(R)}(undef, n)
tmp = Vector{elem_type(R)}(undef, n)
if unit == false
if unipotent == false
for i = 1:n
Tinv[i] = inv(U[i, i])
end
Expand All @@ -3435,7 +3435,7 @@ function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T
end
s = reduce!(s)
s = b[j, i] - s
if unit == false
if unipotent == false
s = mul!(s, s, Tinv[j])
end
tmp[j] = s
Expand All @@ -3447,6 +3447,89 @@ function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T
return X
end

@doc raw"""
    _solve_triu(U::MatElem{T}, b::MatElem{T}; side::Symbol = :left) where {T <: RingElement}

Solve a linear system with the non-singular $n\times n$ upper triangular
matrix $U$ by back substitution.

- If `side = :left` (the default), $b$ has to be an $m\times n$ matrix over the
  same ring and an $m\times n$ matrix $x$ with $xU = b$ is returned.
- If `side = :right`, $b$ has to be an $n\times m$ matrix over the same ring
  and an $n\times m$ matrix $x$ with $Ux = b$ is returned.

Only the entries of $U$ on and above the diagonal are read. The divisions by
the diagonal entries are exact divisions; if the system cannot be solved this
way, an error will be raised. Any other value of `side` throws an
`ArgumentError`.

See also [`AbstractAlgebra._solve_triu_left`](@ref) and [`Strassen`](@ref) for
asymptotically fast versions.
"""
function _solve_triu(U::MatElem{T}, b::MatElem{T}; side::Symbol = :left) where {T <: RingElement}
  if side == :left
    return _solve_triu_left(U, b)
  end
  # Explicit check instead of `@assert`: assertions may be disabled, which
  # would silently treat any invalid symbol as `:right`.
  side == :right || throw(ArgumentError("side must be :left or :right, got :$side"))
  n = nrows(U)
  m = ncols(b)
  R = base_ring(U)
  X = zero(b)
  # Scratch column; entry j is always written before it is read, so no
  # initialisation is required.
  tmp = Vector{elem_type(R)}(undef, n)
  t = R() # temporary for addmul!
  for i = 1:m
    # Back substitution for column i, from the last row upwards.
    for j = n:-1:1
      s = R(0)
      for k = j + 1:n
        # s = s + U[j, k] * tmp[k]
        s = addmul!(s, U[j, k], tmp[k], t)
      end
      s = b[j, i] - s
      tmp[j] = divexact(s, U[j, j])
    end
    for j = 1:n
      X[j, i] = tmp[j]
    end
  end
  return X
end

@doc raw"""
    _solve_triu_left(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement}

Given a non-singular $n\times n$ upper triangular matrix $U$ and an
$m\times n$ matrix $b$ over the same ring, return an $m\times n$ matrix $x$
such that $xU = b$, computed by forward substitution.

Only the entries of $U$ on and above the diagonal are read. The divisions by
the diagonal entries are exact divisions; if the system cannot be solved this
way, an error will be raised.

See also [`_solve_triu`](@ref) and [`Strassen`](@ref) for asymptotically fast
versions.
"""
function _solve_triu_left(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement}
  n = ncols(U)
  m = nrows(b)
  R = base_ring(U)
  X = zero(b)
  # Scratch row; entry j is always written before it is read, so no
  # initialisation is required.
  tmp = Vector{elem_type(R)}(undef, n)
  t = R() # temporary for addmul!
  for i = 1:m
    # Forward substitution for row i, from the first column onwards.
    for j = 1:n
      s = R()
      for k = 1:j-1
        # s = s + U[k, j] * tmp[k]
        s = addmul!(s, U[k, j], tmp[k], t)
      end
      s = b[i, j] - s
      tmp[j] = divexact(s, U[j, j])
    end
    for j = 1:n
      X[i, j] = tmp[j]
    end
  end
  return X
end

#solves A x = B for A intended to be lower triangular
#only the lower part is used. if f is true, then the diagonal is assumed to be 1
#used to use lu!
Expand Down
12 changes: 12 additions & 0 deletions test/Solve-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,15 @@ end
@test ncols(S) == 3
@test base_ring(S) == QQ
end

@testset "solve_triu" begin
  # n x n upper triangular integer matrix with entry i + j - 1 at (i, j) for
  # i <= j; its diagonal 1, 3, 5, ... is never zero.
  upper_triangular(n) = matrix(ZZ, n, n, [i <= j ? i + j - 1 : 0 for i in 1:n for j in 1:n])

  # Generic substitution based solver: recover a known solution on both sides.
  U = upper_triangular(10)
  sol = matrix(ZZ, rand(-10:10, 10, 10))
  @test AbstractAlgebra._solve_triu(U, U*sol; side = :right) == sol
  @test AbstractAlgebra._solve_triu(U, sol*U; side = :left) == sol

  # Strassen variant; cutoff = 10 < 20 forces one level of recursion.
  U = upper_triangular(20)
  sol = matrix(ZZ, rand(-10:10, 20, 20))
  @test AbstractAlgebra.Strassen._solve_triu(U, U*sol; cutoff = 10, side = :right) == sol
  @test AbstractAlgebra.Strassen._solve_triu(U, sol*U; cutoff = 10, side = :left) == sol
end
2 changes: 1 addition & 1 deletion test/generic/Matrix-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2658,7 +2658,7 @@ end
M = randmat_triu(S, -100:100)
b = rand(U, -100:100)

x = AbstractAlgebra._solve_triu(M, b, false)
x = AbstractAlgebra._solve_triu_right(M, b; unipotent = false)

@test M*x == b
end
Expand Down

0 comments on commit c18a4d2

Please sign in to comment.