Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mutating arithmetic for SRows #1659

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 165 additions & 19 deletions src/Sparse/Row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@
return A
end

function Base.empty(A::SRow)
return sparse_row(base_ring(A))
end

function zero(A::SRow)
return empty(A)
end

function swap!(A::SRow, B::SRow)
A.pos, B.pos = B.pos, A.pos
A.values, B.values = B.values, A.values
Expand Down Expand Up @@ -447,15 +455,17 @@
# Inplace scaling
#
################################################################################

@doc raw"""
scale_row!(a::SRow, b::NCRingElem) -> SRow

Returns the (left) product of $b \times a$ and reassigns the value of $a$ to this product.
For rows, the standard multiplication is from the left.
"""
function scale_row!(a::SRow{T}, b::T) where T
@assert !iszero(b)
if isone(b)
if iszero(b)
return empty!(a)

Check warning on line 467 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L467

Added line #L467 was not covered by tests
elseif isone(b)
return a
end
i = 1
Expand All @@ -465,20 +475,23 @@
deleteat!(a.values, i)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Orthogonal to this PR, but: Since you now handle b==0 at the start, the iszero(a.values[i]) check above the fold can only return true if the coefficient ring is not a domain. So it could be strengthened to something like !is_domain_type(T) && iszero(a.values[i]).

Since is_domain_type is a trait depending only on T, the compiler can eliminate the is_domain_type(T) check -- if it returns true it can elide the if block, and if it is false we the same code we have currently.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this can easily wait for a follow-up PR. (Also scale_row! and scale_row_right! could be merged)

deleteat!(a.pos, i)
else
i += 1
i += 1
end
end
return a
end

scale_row!(a::SRow, b) = scale_row!(a, base_ring(a)(b))

@doc raw"""
scale_row_right!(a::SRow, b::NCRingElem) -> SRow

Returns the (right) product of $a \times b$ and modifies $a$ to this product.
"""
function scale_row_right!(a::SRow{T}, b::T) where T
@assert !iszero(b)
if isone(b)
if iszero(b)
return empty!(a)

Check warning on line 493 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L493

Added line #L493 was not covered by tests
elseif isone(b)
return a
end
i = 1
Expand All @@ -488,16 +501,20 @@
deleteat!(a.values, i)
deleteat!(a.pos, i)
else
i += 1
i += 1
end
end
return a
end

scale_row_right!(a::SRow, b) = scale_row_right!(a, base_ring(a)(b))

Check warning on line 510 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L510

Added line #L510 was not covered by tests

function scale_row_left!(a::SRow{T}, b::T) where T
return scale_row!(a,b)
end

scale_row_left!(a::SRow, b) = scale_row_left!(a, base_ring(a)(b))

Check warning on line 516 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L516

Added line #L516 was not covered by tests

################################################################################
#
# Addition
Expand All @@ -506,22 +523,22 @@

function +(A::SRow{T}, B::SRow{T}) where T
if length(A.values) == 0
return B
return deepcopy(B)
elseif length(B.values) == 0
return A
return deepcopy(A)
end
return add_scaled_row(A, B, one(base_ring(A)))
end

function -(A::SRow{T}, B::SRow{T}) where T
if length(A) == 0
if length(B) == 0
return A
return deepcopy(A)

Check warning on line 536 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L536

Added line #L536 was not covered by tests
else
return add_scaled_row(B, A, base_ring(B)(-1))
return add_scaled_row(B, A, -1)
end
end
return add_scaled_row(B, A, base_ring(A)(-1))
return add_scaled_row(B, A, -1)
end

function -(A::SRow{T}) where {T}
Expand Down Expand Up @@ -683,10 +700,10 @@

Returns the row $c A + B$.
"""
add_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_scaled_row!(a, deepcopy(b), c)
add_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_scaled_row!(a, deepcopy(b), c)

add_left_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_left_scaled_row!(a, deepcopy(b), c)
add_right_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_right_scaled_row!(a, deepcopy(b), c)
add_left_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_left_scaled_row!(a, deepcopy(b), c)
add_right_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_right_scaled_row!(a, deepcopy(b), c)

Check warning on line 706 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L705-L706

Added lines #L705 - L706 were not covered by tests



Expand All @@ -696,7 +713,9 @@
Adds the left scaled row $c A$ to $B$.
"""
function add_scaled_row!(a::SRow{T}, b::SRow{T}, c::T, ::Val{left_side} = Val(true)) where {T, left_side}
@assert a !== b
if a === b
a = deepcopy(a)

Check warning on line 717 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L717

Added line #L717 was not covered by tests
end
i = 1
j = 1
t = base_ring(a)()
Expand Down Expand Up @@ -735,17 +754,144 @@
return b
end

add_scaled_row!(a::SRow{T}, b::SRow{T}, c) where {T} = add_scaled_row!(a, b, base_ring(a)(c))

add_scaled_row!(a::SRow{T}, b::SRow{T}, c, side::Val) where {T} = add_scaled_row!(a, b, base_ring(a)(c), side)

Check warning on line 759 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L759

Added line #L759 was not covered by tests

# ignore tmp argument
add_scaled_row!(a::SRow{T}, b::SRow{T}, c::T, tmp::SRow{T}) where T = add_scaled_row!(a, b, c)
add_scaled_row!(a::SRow{T}, b::SRow{T}, c, tmp::SRow{T}) where T = add_scaled_row!(a, b, c)

Check warning on line 762 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L762

Added line #L762 was not covered by tests

add_left_scaled_row!(a::SRow{T}, b::SRow{T}, c::T) where T = add_scaled_row!(a, b, c)
add_left_scaled_row!(a::SRow{T}, b::SRow{T}, c) where T = add_scaled_row!(a, b, c)

Check warning on line 764 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L764

Added line #L764 was not covered by tests

@doc raw"""
add_right_scaled_row!(A::SRow{T}, B::SRow{T}, c::T) -> SRow{T}

Return the right scaled row $c A$ to $B$ by changing $B$ in place.
Return the right scaled row $A c$ to $B$ by changing $B$ in place.
"""
add_right_scaled_row!(a::SRow{T}, b::SRow{T}, c::T) where T = add_scaled_row!(a, b, c, Val(false))
add_right_scaled_row!(a::SRow{T}, b::SRow{T}, c) where T = add_scaled_row!(a, b, c, Val(false))


################################################################################
#
# Mutating arithmetics
#
################################################################################

function zero!(z::SRow)
return empty!(z)
end

function neg!(z::SRow{T}, x::SRow{T}) where T
if z === x
return neg!(x)

Check warning on line 786 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L786

Added line #L786 was not covered by tests
end
swap!(z, -x)
return z
end

function neg!(z::SRow)
for i in 1:length(z)
z.values[i] = neg!(z.values[i])
end
return z
end

function add!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
if z === x
return add!(x, y)
elseif z === y
return add!(y, x)
end
swap!(z, x + y)
return z
end

function add!(z::SRow{T}, x::SRow{T}) where T
if z === x
return scale_row!(z, 2)
end
return add_scaled_row!(x, z, one(base_ring(x)))
end

function sub!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
if z === x
return sub!(x, y)
elseif z === y
return neg!(sub!(y, x))
end
swap!(z, x - y)
return z
end

function sub!(z::SRow{T}, x::SRow{T}) where T
if z === x
return empty!(z)
end
return add_scaled_row!(x, z, -1)
end

function mul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
error("Not implemented")

Check warning on line 834 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L833-L834

Added lines #L833 - L834 were not covered by tests
end

function mul!(z::SRow{T}, x::SRow{T}, c) where T
if z === x
return scale_row_right!(x, c)

Check warning on line 839 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L837-L839

Added lines #L837 - L839 were not covered by tests
end
swap!(z, x * c)
return z

Check warning on line 842 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L841-L842

Added lines #L841 - L842 were not covered by tests
end

function mul!(z::SRow{T}, c, y::SRow{T}) where T
if z === y
return scale_row_left!(y, c)

Check warning on line 847 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L845-L847

Added lines #L845 - L847 were not covered by tests
end
swap!(z, c * y)
return z

Check warning on line 850 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L849-L850

Added lines #L849 - L850 were not covered by tests
end

function addmul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
error("Not implemented")

Check warning on line 854 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L853-L854

Added lines #L853 - L854 were not covered by tests
end

function addmul!(z::SRow{T}, x::SRow{T}, y) where T
if z === x
return scale_row_right!(x, y+1)

Check warning on line 859 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L857-L859

Added lines #L857 - L859 were not covered by tests
end
return add_right_scaled_row!(x, z, y)

Check warning on line 861 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L861

Added line #L861 was not covered by tests
end

function addmul!(z::SRow{T}, x, y::SRow{T}) where T
if z === x
return scale_row_left!(y, x+1)

Check warning on line 866 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L864-L866

Added lines #L864 - L866 were not covered by tests
end
return add_left_scaled_row!(y, z, x)

Check warning on line 868 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L868

Added line #L868 was not covered by tests
end

function submul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
error("Not implemented")

Check warning on line 872 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L871-L872

Added lines #L871 - L872 were not covered by tests
end

function submul!(z::SRow{T}, x::SRow{T}, y) where T
if z === x
return scale_row_right!(x, -y+1)

Check warning on line 877 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L875-L877

Added lines #L875 - L877 were not covered by tests
end
return add_right_scaled_row!(x, z, -y)

Check warning on line 879 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L879

Added line #L879 was not covered by tests
end

function submul!(z::SRow{T}, x, y::SRow{T}) where T
if z === x
return scale_row_left!(y, -x+1)

Check warning on line 884 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L882-L884

Added lines #L882 - L884 were not covered by tests
end
return add_left_scaled_row!(y, z, -x)

Check warning on line 886 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L886

Added line #L886 was not covered by tests
end


# ignore temp variable
addmul!(z::SRow{T}, x::SRow{T}, y, t) where T = addmul!(z, x, y)
addmul!(z::SRow{T}, x, y::SRow{T}, t) where T = addmul!(z, x, y)
submul!(z::SRow{T}, x::SRow{T}, y, t) where T = submul!(z, x, y)
submul!(z::SRow{T}, x, y::SRow{T}, t) where T = submul!(z, x, y)

Check warning on line 894 in src/Sparse/Row.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/Row.jl#L891-L894

Added lines #L891 - L894 were not covered by tests


################################################################################
Expand Down
4 changes: 3 additions & 1 deletion src/Sparse/ZZRow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@

function add_scaled_row(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingElem, sr::SRow{ZZRingElem} = sparse_row(ZZ))
empty!(sr)
@assert c != 0
n = ZZRingElem()
pi = 1
pj = 1
Expand Down Expand Up @@ -323,6 +322,9 @@
end

function add_scaled_row!(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingElem, sr::SRow{ZZRingElem} = sparse_row(ZZ))
if iszero(c)
return Aj

Check warning on line 326 in src/Sparse/ZZRow.jl

View check run for this annotation

Codecov / codecov/patch

src/Sparse/ZZRow.jl#L326

Added line #L326 was not covered by tests
end
_t = sr
sr = add_scaled_row(Ai, Aj, c, sr)
@assert _t === sr
Expand Down
36 changes: 36 additions & 0 deletions test/Sparse/Row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,40 @@
B = sparse_row(F,[1],[y])
C = add_scaled_row(A,B,F(1))
@test C == A+B

# mutating arithmetic
randcoeff() = begin
n = rand((1,1,1,2,5,7,15))
return rand(-2^n:2^n)
end
Main.equality(A::SRow, B::SRow) = A == B
@testset "mutating arithmetic; R = $R" for R in (ZZ, QQ)
for _ in 1:10
maxind_A = rand(0:10)
inds_A = Hecke.Random.randsubseq(1:maxind_A, rand())
vals_A = elem_type(R)[R(rand((-1, 1)) * rand(1:10)) for _ in 1:length(inds_A)]
A = sparse_row(R, inds_A, vals_A)

maxind_B = rand(0:10)
inds_B = Hecke.Random.randsubseq(1:maxind_B, rand())
vals_B = elem_type(R)[R(rand((-1, 1)) * rand(1:10)) for _ in 1:length(inds_B)]
B = sparse_row(R, inds_B, vals_B)

test_mutating_op_like_zero(zero, zero!, A)

test_mutating_op_like_neg(-, neg!, A)

test_mutating_op_like_add(+, add!, A, B)
test_mutating_op_like_add(-, sub!, A, B)
# test_mutating_op_like_mul(*, mul!, A, randcoeff(); right_factor_is_scalar=true)
# test_mutating_op_like_mul(*, mul!, randcoeff(), A; left_factor_is_scalar=true)
# test_mutating_op_like_mul(*, mul!, A, ZZ(randcoeff()); right_factor_is_scalar=true)
# test_mutating_op_like_mul(*, mul!, ZZ(randcoeff()), A; left_factor_is_scalar=true)

# test_mutating_op_like_addmul((a, b, c) -> a + b*c, addmul!, A, B, randcoeff(); right_factor_is_scalar=true)
# test_mutating_op_like_addmul((a, b, c) -> a + b*c, addmul!, A, randcoeff(), B; left_factor_is_scalar=true)
# test_mutating_op_like_addmul((a, b, c) -> a - b*c, submul!, A, B, randcoeff(); right_factor_is_scalar=true)
# test_mutating_op_like_addmul((a, b, c) -> a - b*c, submul!, A, randcoeff(), B; left_factor_is_scalar=true)
end
end
end
Loading