Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Operators for higher dimensional Tensorspaces #390

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
6 changes: 6 additions & 0 deletions src/Caching/blockbanded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@

diagblockshift(a,b) = error("Developer: Not implemented for blocklengths $a, $b")

function diagblockshift(a::BroadcastArray{<:Int, 1}, b::BroadcastArray{<:Int, 1})
if a.f == b.f
return 0
end
error("Broadcastvectors of blocklengths $a, $b are not implemented.")

Check warning on line 29 in src/Caching/blockbanded.jl

View check run for this annotation

Codecov / codecov/patch

src/Caching/blockbanded.jl#L29

Added line #L29 was not covered by tests
end
function diagblockshift(a::AbstractRange, b::AbstractRange)
@assert step(a) == step(b)
first(b)-first(a)
Expand Down
131 changes: 122 additions & 9 deletions src/Multivariate/TensorSpace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@

const InfOnes = Ones{Int,1,Tuple{OneToInf{Int}}}
const Tensorizer2D{AA, BB} = Tensorizer{Tuple{AA, BB}}
const MixedTrivConstTensorizer{d} = Tensorizer{<:Tuple{Vararg{Union{InfOnes, SVector{1, <:Int}},d}}} # const or trivial
# TrivialTensorizer and ConstantTensorizer are special cases of MixedTrivConstTensorizer
const TrivialTensorizer{d} = Tensorizer{NTuple{d,InfOnes}}
const ConstantTensorizer{d} = Tensorizer{<:NTuple{d,SVector{1, <:Int}}} # for all dimensions constant

eltype(::Type{<:Tensorizer{<:Tuple{Vararg{Any,N}}}}) where {N} = NTuple{N,Int}
dimensions(a::Tensorizer) = map(sum,a.blocks)
Expand All @@ -40,6 +43,24 @@

Base.keys(a::Tensorizer) = oneto(length(a))


function start(a::ConstantTensorizer{d}) where {d}
@assert length(a) == 1
block = ntuple(one, d)
return (block, (0,1))

Check warning on line 50 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L47-L50

Added lines #L47 - L50 were not covered by tests
end

function next(a::ConstantTensorizer{d}, iterator_tuple) where {d}
(block, (i,tot)) = iterator_tuple
ret = block
ret, (block, (i+1,tot))

Check warning on line 56 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L53-L56

Added lines #L53 - L56 were not covered by tests
end

function done(a::ConstantTensorizer, iterator_tuple)::Bool
i, tot = last(iterator_tuple)
return i ≥ tot

Check warning on line 61 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L59-L61

Added lines #L59 - L61 were not covered by tests
end

function start(a::TrivialTensorizer{d}) where {d}
# ((block_dim_1, block_dim_2,...), (itaration_number, iterator, iterator_state)), (itemssofar, length)
block = ntuple(one, d)
Expand Down Expand Up @@ -78,21 +99,71 @@
ret, ((block, (j, iterator, iter_state)), (i,tot))
end


function done(a::TrivialTensorizer, iterator_tuple)::Bool
i, tot = last(iterator_tuple)
return i ≥ tot
end


function start(a::MixedTrivConstTensorizer{d}) where {d}

Check warning on line 108 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L108

Added line #L108 was not covered by tests
# const indices are always left to be one
relevant_ind = filter!(i->i≠0, map(i->a.blocks[i] isa SVector{1, <:Int} ? 0 : i,1:length(a.blocks)))
real_d = length(relevant_ind)

Check warning on line 111 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L110-L111

Added lines #L110 - L111 were not covered by tests
# ((block_dim_1, block_dim_2,...), (itaration_number, iterator, iterator_state)), (itemssofar, length)
block = ones(Int, real_d)
return (block, (relevant_ind, real_d, 0, nothing, nothing)), (0,length(a))

Check warning on line 114 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L113-L114

Added lines #L113 - L114 were not covered by tests
end

function next(a::MixedTrivConstTensorizer{d}, iterator_tuple) where {d}
(block, (relevant_ind, real_d, j, iterator, iter_state)), (i,tot) = iterator_tuple

Check warning on line 118 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L117-L118

Added lines #L117 - L118 were not covered by tests

@inline function check_block_finished(j, iterator, block)
if iterator === nothing
return true

Check warning on line 122 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L120-L122

Added lines #L120 - L122 were not covered by tests
end
# there are N-1 over d-1 combinations in a block
amount_combinations_block = binomial(sum(block)-1, real_d-1)

Check warning on line 125 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L125

Added line #L125 was not covered by tests
# check if all combinations have been iterated over
amount_combinations_block <= j

Check warning on line 127 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L127

Added line #L127 was not covered by tests
end

ret_vec = ones(Int, d)
ret_vec[relevant_ind] = reverse(block)
ret = Tuple(SVector{d}(ret_vec))

Check warning on line 132 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L130-L132

Added lines #L130 - L132 were not covered by tests

if check_block_finished(j, iterator, block) # end of new block

Check warning on line 134 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L134

Added line #L134 was not covered by tests
# set up iterator for new block
current_sum = sum(block)
iterator = multiexponents(real_d, current_sum+1-real_d)
iter_state = nothing
j = 0

Check warning on line 139 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L136-L139

Added lines #L136 - L139 were not covered by tests
end

# increase block, or initialize new block
_res, iter_state = iterate(iterator, iter_state)

Check warning on line 143 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L143

Added line #L143 was not covered by tests
# res = Tuple(SVector{real_d}(_res))
block = _res.+1
j = j+1

Check warning on line 146 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L145-L146

Added lines #L145 - L146 were not covered by tests

ret, ((block, (relevant_ind, real_d, j, iterator, iter_state)), (i,tot))

Check warning on line 148 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L148

Added line #L148 was not covered by tests
end

function done(a::MixedTrivConstTensorizer, iterator_tuple)::Bool
i, tot = last(iterator_tuple)
return i ≥ tot

Check warning on line 153 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L151-L153

Added lines #L151 - L153 were not covered by tests
end


# (blockrow,blockcol), (subrow,subcol), (rowshift,colshift), (numblockrows,numblockcols), (itemssofar, length)
start(a::Tensorizer2D) = _start(a)
start(a::TrivialTensorizer{2}) = _start(a)
start(a::MixedTrivConstTensorizer{2}) = _start(a)

Check warning on line 160 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L160

Added line #L160 was not covered by tests

_start(a) = (1,1, 1,1, 0,0, a.blocks[1][1],a.blocks[2][1]), (0,length(a))

next(a::Tensorizer2D, state) = _next(a, state::typeof(_start(a)))
next(a::TrivialTensorizer{2}, state) = _next(a, state::typeof(_start(a)))
next(a::MixedTrivConstTensorizer{2}, state) = _next(a, state::typeof(_start(a)))

Check warning on line 166 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L166

Added line #L166 was not covered by tests

function _next(a, st)
(K,J, k,j, rsh,csh, n,m), (i,tot) = st
Expand Down Expand Up @@ -121,6 +192,7 @@

done(a::Tensorizer2D, state) = _done(a, state::typeof(_start(a)))
done(a::TrivialTensorizer{2}, state) = _done(a, state::typeof(_start(a)))
done(a::MixedTrivConstTensorizer{2}, state) = _done(a, state::typeof(_start(a)))

Check warning on line 195 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L195

Added line #L195 was not covered by tests

function _done(a, st)::Bool
i, tot = last(st)
Expand Down Expand Up @@ -216,7 +288,18 @@

blocklengths(::TrivialTensorizer{2}) = 1:∞


## anonymous function needed in order to compare if two blocklenghts are equal
_blocklengths_trivialTensorizer(d) = let d=d
x->binomial(x+(d-2), d-1)
end
blocklengths(::TrivialTensorizer{d}) where {d} = _blocklengths_trivialTensorizer(d).(1:∞)
blocklengths(::ConstantTensorizer) = SVector(1)
function blocklengths(a::MixedTrivConstTensorizer{d}) where {d}
real_d = mapreduce(bl->bl isa SVector{1, <:Int} ? 0 : 1, +, a.blocks)
real_d == 0 && return SVector(1)
real_d == 1 && return 1:∞
return _blocklengths_trivialTensorizer(real_d).(1:∞)

Check warning on line 301 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L296-L301

Added lines #L296 - L301 were not covered by tests
end

blocklengths(it::Tensorizer) = tensorblocklengths(it.blocks...)
blocklengths(it::CachedIterator) = blocklengths(it.iterator)
Expand Down Expand Up @@ -304,7 +387,15 @@
const TensorSpaceND{d, D, R} = TensorSpace{<:NTuple{d, <:UnivariateSpace}, D, R}

tensorizer(sp::TensorSpace) = Tensorizer(map(blocklengths,sp.spaces))
blocklengths(S::TensorSpace) = tensorblocklengths(map(blocklengths,S.spaces)...)
function blocklengths(S::TensorSpace)
list_blocks = map(blocklengths,S.spaces)
if all(x->x == Ones{Int}(ℵ₀), list_blocks)
d = length(S.spaces)
return _blocklengths_trivialTensorizer(d).(1:∞)
else
return tensorblocklengths(list_blocks...)
end
end


# the evaluation is *, so the type will be the same as *
Expand All @@ -327,11 +418,11 @@
==(A::TensorSpace{<:NTuple{N,Space}}, B::TensorSpace{<:NTuple{N,Space}}) where {N} =
factors(A) == factors(B)

conversion_rule(a::TensorSpace{<:NTuple{2,Space}}, b::TensorSpace{<:NTuple{2,Space}}) =
conversion_type(a.spaces[1],b.spaces[1]) ⊗ conversion_type(a.spaces[2],b.spaces[2])
conversion_rule(a::TensorSpace{<:NTuple{N,Space}}, b::TensorSpace{<:NTuple{N,Space}}) where {N} =
mapreduce((a,b)->conversion_type(a,b),⊗,a.spaces,b.spaces)

Check warning on line 422 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L421-L422

Added lines #L421 - L422 were not covered by tests

maxspace_rule(a::TensorSpace{<:NTuple{2,Space}}, b::TensorSpace{<:NTuple{2,Space}}) =
maxspace(a.spaces[1],b.spaces[1]) ⊗ maxspace(a.spaces[2],b.spaces[2])
maxspace_rule(a::TensorSpace{<:NTuple{N,Space}}, b::TensorSpace{<:NTuple{N,Space}}) where {N} =
mapreduce((a,b)->maxspace(a,b),⊗,a.spaces,b.spaces)

function spacescompatible(A::TensorSpace{<:NTuple{N,Space}}, B::TensorSpace{<:NTuple{N,Space}}) where {N}
_spacescompatible(factors(A), factors(B))
Expand Down Expand Up @@ -621,7 +712,7 @@
ret
end

@inline function totensoriterator(it::TrivialTensorizer{d},M::AbstractVector) where {d}
@inline function totensoriterator(it::MixedTrivConstTensorizer{d} ,M::AbstractVector) where {d}

Check warning on line 715 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L715

Added line #L715 was not covered by tests
B=block(it,length(M))
return it, M, B
end
Expand Down Expand Up @@ -664,7 +755,29 @@
evaluate(f::AbstractVector,S::TensorSpace2D,x,y) = ProductFun(totensor(S,f),S)(x,y)

# ND evaluation functions of Trivial Spaces
evaluate(f::AbstractVector,S::TensorSpaceND,x) = TrivialTensorFun(totensor(S, f)..., S)(x...)
not_const_spaces_indices(S) = filter!(i->i≠0, map(i->S.spaces[i] isa ConstantSpace ? 0 : i,1:length(S.spaces)))
function evaluate(f::AbstractVector,S::TensorSpaceND,x)
if !any(s->s isa ConstantSpace, S.spaces)
return TrivialTensorFun(totensor(S, f)..., S)(x...)

Check warning on line 761 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L758-L761

Added lines #L758 - L761 were not covered by tests
end
not_cons_indices = not_const_spaces_indices(S)
xmod = if length(x) == length(not_cons_indices)
x

Check warning on line 765 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L763-L765

Added lines #L763 - L765 were not covered by tests
else
x[not_cons_indices]

Check warning on line 767 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L767

Added line #L767 was not covered by tests
end
(length(not_cons_indices) == 0) && return f[1]
S_new = reduce(⊗, S.spaces[not_cons_indices])
if length(not_cons_indices) > 2
return TrivialTensorFun(totensor(S_new, f)..., S_new)(x...)
elseif length(S_new) == 2
return ProductFun(totensor(S_new, f), S_new)(x...)
elseif length(S_new) == 1
return Fun(S_new[1], f)(x)

Check warning on line 776 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L769-L776

Added lines #L769 - L776 were not covered by tests
else
error("This should not happen")

Check warning on line 778 in src/Multivariate/TensorSpace.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TensorSpace.jl#L778

Added line #L778 was not covered by tests
end
end

coefficientmatrix(f::Fun{<:AbstractProductSpace}) = totensor(space(f),f.coefficients)

Expand Down
10 changes: 4 additions & 6 deletions src/Multivariate/TrivialTensorFun.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,22 @@

struct TrivialTensorFun{d, SS<:TensorSpaceND{d}, T<:Number} <: MultivariateFun{T, d}
space::SS
coefficients::Vector{T}
coefficients::AbstractVector{T}
iterator::TrivialTensorizer{d}
orders::Block{1, Int}
end


function TrivialTensorFun(iter::TrivialTensorizer{d},cfs::Vector{T},blk::Block, sp::TensorSpaceND{d}) where {T<:Number,d}
if any(map(dimension, sp.spaces).!=ℵ₀)
error("This Space is not a Trivial Tensor space!")
end
function TrivialTensorFun(iter::TrivialTensorizer{d},

Check warning on line 12 in src/Multivariate/TrivialTensorFun.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TrivialTensorFun.jl#L12

Added line #L12 was not covered by tests
cfs::AbstractVector{T}, blk::Block, sp::TensorSpaceND{d}) where {T<:Number,d}
TrivialTensorFun(sp, cfs, iter, blk)
end

(f::TrivialTensorFun)(x...) = evaluate(f, x...)

# TensorSpace evaluation
function evaluate(f::TrivialTensorFun{d, SS, T},x...) where {d, SS, T}
highest_order = f.orders.n[1]
highest_order = f.orders.n[1]-1

Check warning on line 21 in src/Multivariate/TrivialTensorFun.jl

View check run for this annotation

Codecov / codecov/patch

src/Multivariate/TrivialTensorFun.jl#L21

Added line #L21 was not covered by tests
n = length(f.coefficients)

# this could be lazy evaluated for the sparse case
Expand Down
Loading