Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solve triu #1920

Merged
merged 8 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 85 additions & 11 deletions src/Matrix-Strassen.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
Provides generic asymptotically fast matrix methods:
- mul and mul! using the Strassen scheme
- _solve_tril!
- lu!
- _solve_triu
- `mul` and `mul!` using the Strassen scheme
- `_solve_tril!`
- `lu!`
- `_solve_triu`

Just prefix the function by "Strassen." all 4 functions support a keyword
argument "cutoff" to indicate when the base case should be used.
Expand Down Expand Up @@ -40,6 +40,12 @@ function mul(A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T}
end

#scheduling copied from the nmod_mat_mul in Flint
"""
Fast, recursive, generic matrix multiplication using the Strassen
trick.

`cutoff` indicates when the recursion stops and the base case is called.
"""
function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T}
sA = size(A)
sB = size(B)
Expand Down Expand Up @@ -274,17 +280,82 @@ function lu!(P::Perm{Int}, A; cutoff::Int = 300)
return r1 + r2
end

function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff)
#b*inv(T), thus solves Tx = b for T upper triangular
function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff, side::Symbol = :left)
#side = :right: inv(T)*b, thus solves Tx = b for T upper triangular (side = :left, the default, solves xT = b instead)
n = ncols(T)
if n <= cutoff
R = AbstractAlgebra._solve_triu(T, b)
R = AbstractAlgebra._solve_triu(T, b; side)
return R
end
if side == :left
return _solve_triu_left(T, b; cutoff)
end
@assert side == :right
@assert n == nrows(T) == nrows(b)

n2 = div(n, 2) + n % 2
m = ncols(b)
m2 = div(m, 2) + m % 2
#=
b = [U X; V Y]
T = [A B; 0 C]
x = [SS RR; S R]

[0 C] [SS; S] = CS = V
[0 C] [RR; R] = CR = Y

[A B] [SS; S] = A SS + B S = U => A SS = U - BS
[A B] [RR; R] = A RR + B R = X => A RR = X - BR

=#

U = view(b, 1:n2, 1:m2)
X = view(b, 1:n2, m2+1:m)
V = view(b, n2+1:n, 1:m2)
Y = view(b, n2+1:n, m2+1:m)

A = view(T, 1:n2, 1:n2)
B = view(T, 1:n2, 1+n2:n)
C = view(T, 1+n2:n, 1+n2:n)

S = _solve_triu(C, V; cutoff, side)
R = _solve_triu(C, Y; cutoff, side)

SS = mul(B, S; cutoff)
SS = sub!(SS, U, SS)
SS = _solve_triu(A, SS; cutoff, side)

RR = mul(B, R; cutoff)
RR = sub!(RR, X, RR)
RR = _solve_triu(A, RR; cutoff, side)

return [SS RR; S R]
end

function _solve_triu_left(T::MatElem, b::MatElem; cutoff::Int = cutoff)
#b*inv(T), thus solves xT = b for T upper triangular
n = ncols(T)
if n <= cutoff
R = AbstractAlgebra._solve_triu_left(T, b)
return R
end

@assert ncols(b) == nrows(T) == n

n2 = div(n, 2) + n % 2
m = nrows(b)
m2 = div(m, 2) + m % 2
#=
b = [U V; X Y]
T = [A B; 0 C]
x = [S SS; R RR]

[S SS] [A; 0] = SA = U
[R RR] [A; 0] = RA = X
[S SS] [B; C] = SB + SS C = V => SS C = V - SB
[R RR] [B; C] = RB + RR C = Y => RR C = Y - RB

=#

U = view(b, 1:m2, 1:n2)
V = view(b, 1:m2, n2+1:n)
Expand All @@ -295,18 +366,21 @@ function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff)
B = view(T, 1:n2, 1+n2:n)
C = view(T, 1+n2:n, 1+n2:n)

S = _solve_triu(A, U; cutoff)
R = _solve_triu(A, X; cutoff)
S = _solve_triu_left(A, U; cutoff)
R = _solve_triu_left(A, X; cutoff)

SS = mul(S, B; cutoff)
SS = sub!(SS, V, SS)
SS = _solve_triu(C, SS; cutoff)
SS = _solve_triu_left(C, SS; cutoff)

RR = mul(R, B; cutoff)
RR = sub!(RR, Y, RR)
RR = _solve_triu(C, RR; cutoff)
RR = _solve_triu_left(C, RR; cutoff)
#THINK: both pairs of solving could be combined:
# solve [U; X], A to get S and R...

return [S SS; R RR]
end


end # module
91 changes: 87 additions & 4 deletions src/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2175,7 +2175,7 @@ function rref!(A::MatrixElem{T}) where {T <: FieldElement}
V[j, i] = A[j, pivots[np + i]]
end
end
V = _solve_triu(U, V, false)
V = _solve_triu_right(U, V; unipotent = false)
for i = 1:rnk
for j = 1:i
A[j, pivots[i]] = i == j ? one(R) : R()
Expand Down Expand Up @@ -3451,14 +3451,14 @@ $n\times m$ matrix $x$ such that $Ux = b$. If $U$ is singular an exception
is raised. If unit is true then $U$ is assumed to have ones on its
diagonal, and the diagonal will not be read.
"""
function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T <: FieldElement}
function _solve_triu_right(U::MatElem{T}, b::MatElem{T}; unipotent::Bool = false) where {T <: FieldElement}
n = nrows(U)
m = ncols(b)
R = base_ring(U)
X = zero(b)
Tinv = Vector{elem_type(R)}(undef, n)
tmp = Vector{elem_type(R)}(undef, n)
if unit == false
if unipotent == false
for i = 1:n
Tinv[i] = inv(U[i, i])
end
Expand All @@ -3475,7 +3475,7 @@ function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T
end
s = reduce!(s)
s = b[j, i] - s
if unit == false
if unipotent == false
s = mul!(s, s, Tinv[j])
end
tmp[j] = s
Expand All @@ -3487,6 +3487,89 @@ function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T
return X
end

@doc raw"""
    _solve_triu(U::MatElem{T}, b::MatElem{T}; side::Symbol = :left) where {T <: RingElement}

Solve a linear system whose coefficient matrix $U$ is a non-singular
$n\times n$ upper triangular matrix.

If `side = :left` (the default), $b$ has to be $m\times n$ and an $m\times n$
matrix $x$ with $xU = b$ is returned; this case is delegated to
`_solve_triu_left`.

If `side = :right`, $b$ has to be $n\times m$ and an $n\times m$ matrix $x$
with $Ux = b$ is returned.

Every step divides by a diagonal entry of $U$ using `divexact`. If the system
cannot be solved this way, an error will be raised.

See also [`AbstractAlgebra._solve_triu_left`](@ref) and [`Strassen`](@ref) for
asymptotically fast versions.
"""
function _solve_triu(U::MatElem{T}, b::MatElem{T}; side::Symbol = :left) where {T <: RingElement}
  # The left-sided problem x*U = b has its own routine.
  side == :left && return _solve_triu_left(U, b)
  @assert side == :right

  K = base_ring(U)
  dim = nrows(U)
  sol = zero(b)
  col = Vector{elem_type(K)}(undef, dim)
  scratch = K() # temporary handed to addmul!

  for c in 1:ncols(b)
    # Load the current (still zero) column of the solution into the buffer.
    for r in 1:dim
      col[r] = sol[r, c]
    end
    # Back substitution: row r only needs the entries below it, which are
    # already final because we walk from the last row upwards.
    for r in dim:-1:1
      acc = K(0)
      for k in (r + 1):dim
        # acc = acc + U[r, k] * col[k]
        acc = addmul!(acc, U[r, k], col[k], scratch)
      end
      col[r] = divexact(b[r, c] - acc, U[r, r])
    end
    # Write the finished column back into the result.
    for r in 1:dim
      sol[r, c] = col[r]
    end
  end
  return sol
end

@doc raw"""
    _solve_triu_left(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement}

Given a non-singular $n\times n$ upper triangular matrix $U$ and an
$m\times n$ matrix $b$ over the same ring, return an $m\times n$ matrix $x$
such that $xU = b$.

Every step divides by a diagonal entry of $U$ using `divexact`. If the system
cannot be solved this way, an error will be raised.

See also [`_solve_triu`](@ref) and [`Strassen`](@ref) for asymptotically fast
versions.
"""
function _solve_triu_left(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement}
  K = base_ring(U)
  dim = ncols(U)
  sol = zero(b)
  row = Vector{elem_type(K)}(undef, dim)
  scratch = K() # temporary handed to addmul!

  for r in 1:nrows(b)
    # Load the current (still zero) row of the solution into the buffer.
    for c in 1:dim
      row[c] = sol[r, c]
    end
    # Forward substitution: column c only needs the entries to its left,
    # which are already final because we walk from the first column onwards.
    for c in 1:dim
      acc = K()
      for k in 1:(c - 1)
        # acc = acc + U[k, c] * row[k]
        acc = addmul!(acc, U[k, c], row[k], scratch)
      end
      row[c] = divexact(b[r, c] - acc, U[c, c])
    end
    # Write the finished row back into the result.
    for c in 1:dim
      sol[r, c] = row[c]
    end
  end
  return sol
end

#solves A x = B for A intended to be lower triangular
#only the lower part is used. if f is true, then the diagonal is assumed to be 1
#used to use lu!
Expand Down
12 changes: 12 additions & 0 deletions test/Solve-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,15 @@ end
@test ncols(S) == 3
@test base_ring(S) == QQ
end

@testset "solve_triu" begin
  # Row-major entries of an n x n upper triangular integer matrix whose
  # (i, j) entry for i <= j is i + j - 1, so the diagonal is 1, 3, 5, ...
  triu_entries(n) = [i <= j ? i + j - 1 : 0 for i in 1:n for j in 1:n]

  # Generic implementation: recover a random integer matrix from both
  # the right-sided (U*x) and the left-sided (x*U) product.
  U = matrix(ZZ, 10, 10, triu_entries(10))
  sol = matrix(ZZ, rand(-10:10, 10, 10))
  @test AbstractAlgebra._solve_triu(U, U*sol; side = :right) == sol
  @test AbstractAlgebra._solve_triu(U, sol*U; side = :left) == sol

  # Strassen implementation, called with a cutoff smaller than the dimension.
  U = matrix(ZZ, 20, 20, triu_entries(20))
  sol = matrix(ZZ, rand(-10:10, 20, 20))
  @test AbstractAlgebra.Strassen._solve_triu(U, U*sol; cutoff = 10, side = :right) == sol
  @test AbstractAlgebra.Strassen._solve_triu(U, sol*U; cutoff = 10, side = :left) == sol
end
2 changes: 1 addition & 1 deletion test/generic/Matrix-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2656,7 +2656,7 @@ end
M = randmat_triu(S, -100:100)
b = rand(U, -100:100)

x = AbstractAlgebra._solve_triu(M, b, false)
x = AbstractAlgebra._solve_triu_right(M, b; unipotent = false)

@test M*x == b
end
Expand Down
Loading