diff --git a/src/Matrix-Strassen.jl b/src/Matrix-Strassen.jl index a07c8b1834..caa5d62e23 100644 --- a/src/Matrix-Strassen.jl +++ b/src/Matrix-Strassen.jl @@ -1,9 +1,9 @@ """ Provides generic asymptotically fast matrix methods: - - mul and mul! using the Strassen scheme - - _solve_tril! - - lu! - - _solve_triu + - `mul` and `mul!` using the Strassen scheme + - `_solve_tril!` + - `lu!` + - `_solve_triu` Just prefix the function by "Strassen." all 4 functions support a keyword argument "cutoff" to indicate when the base case should be used. @@ -40,6 +40,12 @@ function mul(A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T} end #scheduling copied from the nmod_mat_mul in Flint +""" +Fast, recursive, generic matrix multiplication using the Strassen +trick. + +`cutoff` indicates when the recursion stops and the base case is called. +""" function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T} sA = size(A) sB = size(B) @@ -274,17 +280,82 @@ function lu!(P::Perm{Int}, A; cutoff::Int = 300) return r1 + r2 end -function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff) - #b*inv(T), thus solves Tx = b for T upper triangular +function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff, side::Symbol = :left) + #inv(T)*b, thus solves Tx = b for T upper triangular n = ncols(T) if n <= cutoff - R = AbstractAlgebra._solve_triu(T, b) + R = AbstractAlgebra._solve_triu(T, b; side) return R end + if side == :left + return _solve_triu_left(T, b; cutoff) + end + @assert side == :right + @assert n == nrows(T) == nrows(b) + + n2 = div(n, 2) + n % 2 + m = ncols(b) + m2 = div(m, 2) + m % 2 + #= + b = [U X; V Y] + T = [A B; 0 C] + x = [SS RR; S R] + + [0 C] [SS; S] = CS = V + [0 C] [RR; R] = CR = Y + + [A B] [SS; S] = A SS + B S = U => A SS = U - BS + [A B] [RR; R] = A RR + B R = U => A RR = X - BR + + =# + + U = view(b, 1:n2, 1:m2) + X = view(b, 1:n2, m2+1:m) + V = view(b, n2+1:n, 1:m2) + Y = view(b, n2+1:n, m2+1:m) + + A = view(T, 1:n2, 1:n2) + B = view(T, 1:n2, 1+n2:n) + C = view(T, 1+n2:n, 1+n2:n) + + S = _solve_triu(C, V; cutoff, side) + R = _solve_triu(C, Y; cutoff, side) + + SS = mul(B, S; cutoff) + SS = sub!(SS, U, SS) + SS = _solve_triu(A, SS; cutoff, side) + + RR = mul(B, R; cutoff) + RR = sub!(RR, X, RR) + RR = _solve_triu(A, RR; cutoff, side) + + return [SS RR; S R] +end + +function _solve_triu_left(T::MatElem, b::MatElem; cutoff::Int = cutoff) + #b*inv(T), thus solves xT = b for T upper triangular + n = ncols(T) + if n <= cutoff + R = AbstractAlgebra._solve_triu_left(T, b) + return R + end + + @assert ncols(b) == nrows(T) == n n2 = div(n, 2) + n % 2 m = nrows(b) m2 = div(m, 2) + m % 2 + #= + b = [U X; V Y] + T = [A B; 0 C] + x = [S SS; R RR] + + [S SS] [A; 0] = SA = U + [R RR] [A; 0] = RA = V + [S SS] [B; C] = SB + SS C = X => SS C = Y - SB + [R RR] [B; C] = RB + RR C = Y => RR C = Y - RB + + =# U = view(b, 1:m2, 1:n2) V = view(b, 1:m2, n2+1:n) @@ -295,18 +366,21 @@ function _solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff) B = view(T, 1:n2, 1+n2:n) C = view(T, 1+n2:n, 1+n2:n) - S = _solve_triu(A, U; cutoff) - R = _solve_triu(A, X; cutoff) + S = _solve_triu_left(A, U; cutoff) + R = _solve_triu_left(A, X; cutoff) SS = mul(S, B; cutoff) SS = sub!(SS, V, SS) - SS = _solve_triu(C, SS; cutoff) + SS = _solve_triu_left(C, SS; cutoff) RR = mul(R, B; cutoff) RR = sub!(RR, Y, RR) - RR = _solve_triu(C, RR; cutoff) + RR = _solve_triu_left(C, RR; cutoff) + #THINK: both pairs of solving could be combined: + # solve [U; X], A to get S and R... return [S SS; R RR] end + end # module diff --git a/src/Matrix.jl b/src/Matrix.jl index dc2d363d63..493678e45e 100644 --- a/src/Matrix.jl +++ b/src/Matrix.jl @@ -2135,7 +2135,7 @@ function rref!(A::MatrixElem{T}) where {T <: FieldElement} V[j, i] = A[j, pivots[np + i]] end end - V = _solve_triu(U, V, false) + V = _solve_triu_right(U, V; unipotent = false) for i = 1:rnk for j = 1:i A[j, pivots[i]] = i == j ? one(R) : R() @@ -3411,14 +3411,14 @@ $n\times m$ matrix $x$ such that $Ux = b$. If $U$ is singular an exception is raised. If unit is true then $U$ is assumed to have ones on its diagonal, and the diagonal will not be read. """ -function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T <: FieldElement} +function _solve_triu_right(U::MatElem{T}, b::MatElem{T}; unipotent::Bool = false) where {T <: FieldElement} n = nrows(U) m = ncols(b) R = base_ring(U) X = zero(b) Tinv = Vector{elem_type(R)}(undef, n) tmp = Vector{elem_type(R)}(undef, n) - if unit == false + if unipotent == false for i = 1:n Tinv[i] = inv(U[i, i]) end @@ -3435,7 +3435,7 @@ function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T end s = reduce!(s) s = b[j, i] - s - if unit == false + if unipotent == false s = mul!(s, s, Tinv[j]) end tmp[j] = s @@ -3447,6 +3447,89 @@ function _solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T return X end +@doc raw""" + _solve_triu(U::MatElem{T}, b::MatElem{T}; side::Symbol = :left) where {T <: RingElement} + +Let $U$ be a non-singular $n\times n$ upper triangular matrix $U$ over a field. If +`side = :right`, let $b$ +be an $n\times m$ matrix $b$ over the same field, return an +$n\times m$ matrix $x$ such that $Ux = b$. If this is not possible, an error +will be raised. + +If `side = :left`, the default, $b$ has to be $m \times n$. In this case +$xU = b$ is solved - or an error raised. + +See also [`AbstractAlgebra._solve_triu_left`](@ref) and [`Strassen`](@ref) for + asymptotically fast versions. +""" +function _solve_triu(U::MatElem{T}, b::MatElem{T}; side::Symbol = :left) where {T <: RingElement} + if side == :left + return _solve_triu_left(U, b) + end + @assert side == :right + n = nrows(U) + m = ncols(b) + R = base_ring(U) + X = zero(b) + tmp = Vector{elem_type(R)}(undef, n) + t = R() + for i = 1:m + for j = 1:n + tmp[j] = X[j, i] + end + for j = n:-1:1 + s = R(0) + for k = j + 1:n + s = addmul!(s, U[j, k], tmp[k], t) +# s = s + U[j, k] * tmp[k] + end + s = b[j, i] - s + tmp[j] = divexact(s, U[j,j]) + end + for j = 1:n + X[j, i] = tmp[j] + end + end + return X +end + +@doc raw""" + _solve_triu_left(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement} + +Given a non-singular $n\times n$ matrix $U$ over a field which is upper +triangular, and an $m\times n$ matrix $b$ over the same ring, return an +$m\times n$ matrix $x$ such that $xU = b$. If this is not possible, an error +will be raised. + +See also [`_solve_triu`](@ref) and [`Strassen`](@ref) for asymptotically fast + versions. +""" +function _solve_triu_left(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement} + n = ncols(U) + m = nrows(b) + R = base_ring(U) + X = zero(b) + tmp = Vector{elem_type(R)}(undef, n) + t = R() + for i = 1:m + for j = 1:n + tmp[j] = X[i, j] + end + for j = 1:n + s = R() + for k = 1:j-1 + s = addmul!(s, U[k, j], tmp[k], t) + end + s = b[i, j] - s + tmp[j] = divexact(s, U[j,j]) + end + for j = 1:n + X[i, j] = tmp[j] + end + end + return X +end + #solves A x = B for A intended to be lower triangular #only the lower part is used. if f is true, then the diagonal is assumed to be 1 #used to use lu! diff --git a/test/Solve-test.jl b/test/Solve-test.jl index ee19dbb564..849d3d37f7 100644 --- a/test/Solve-test.jl +++ b/test/Solve-test.jl @@ -259,3 +259,15 @@ end @test ncols(S) == 3 @test base_ring(S) == QQ end + +@testset "solve_triu" begin + A = matrix(ZZ, 10, 10, [i<=j ? i+j-1 : 0 for i=1:10 for j=1:10]) + x = matrix(ZZ, rand(-10:10, 10, 10)) + @test AbstractAlgebra._solve_triu(A, A*x; side = :right) == x + @test AbstractAlgebra._solve_triu(A, x*A; side = :left) == x + + A = matrix(ZZ, 20, 20, [i<=j ? i+j-1 : 0 for i=1:20 for j=1:20]) + x = matrix(ZZ, rand(-10:10, 20, 20)) + @test AbstractAlgebra.Strassen._solve_triu(A, A*x; cutoff = 10, side = :right) == x + @test AbstractAlgebra.Strassen._solve_triu(A, x*A; cutoff = 10, side = :left) == x +end diff --git a/test/generic/Matrix-test.jl b/test/generic/Matrix-test.jl index f57f9f3f38..559a3d7584 100644 --- a/test/generic/Matrix-test.jl +++ b/test/generic/Matrix-test.jl @@ -2658,7 +2658,7 @@ end M = randmat_triu(S, -100:100) b = rand(U, -100:100) - x = AbstractAlgebra._solve_triu(M, b, false) + x = AbstractAlgebra._solve_triu_right(M, b; unipotent = false) @test M*x == b end