From 4b04d3ca9e7bf0b21be82e637655c62067e37b0c Mon Sep 17 00:00:00 2001 From: Alfredo Braunstein Date: Wed, 18 Oct 2023 09:24:57 -0300 Subject: [PATCH] Use generic trace methods when possible (#16) * resolve misterious memory leaks * add trace function * fix bug in dimensions * Move many functions to abstract_tensor_train * simplify marginals() * reorder product for more efficient calculation in marginals --- src/abstract_tensor_train.jl | 157 +++++++++++++++++++++++++++++++---- src/periodic_tensor_train.jl | 108 ------------------------ src/tensor_train.jl | 145 -------------------------------- test/tensor_train.jl | 4 +- 4 files changed, 144 insertions(+), 270 deletions(-) diff --git a/src/abstract_tensor_train.jl b/src/abstract_tensor_train.jl index 3be200b..67a716d 100644 --- a/src/abstract_tensor_train.jl +++ b/src/abstract_tensor_train.jl @@ -8,6 +8,27 @@ abstract type AbstractTensorTrain{F<:Number, N} end Base.eltype(::AbstractTensorTrain{F,N}) where {N,F} = F + +""" + bond_dims(A::AbstractTensorTrain) + +Return a vector with the dimensions of the virtual bonds +""" +bond_dims(A::AbstractTensorTrain) = [size(A[t], 1) for t in 1:lastindex(A)] + +function check_bond_dims(tensors::Vector{<:Array}) + for t in 1:lastindex(tensors) + dᵗ = size(tensors[t],2) + dᵗ⁺¹ = size(tensors[mod1(t+1, length(tensors))],1) + if dᵗ != dᵗ⁺¹ + println("Bond size for matrix t=$t. dᵗ=$dᵗ, dᵗ⁺¹=$dᵗ⁺¹") + return false + end + end + return true +end + + """ normalize_eachmatrix!(A::AbstractTensorTrain) @@ -26,34 +47,140 @@ function normalize_eachmatrix!(A::AbstractTensorTrain) c end +function StatsBase.sample!(rng::AbstractRNG, x, A::AbstractTensorTrain{F,N}; + r = accumulate_R(A)) where {F<:Real,N} +L = length(A) +@assert length(x) == L +@assert all(length(xᵗ) == N-2 for xᵗ in x) +d = first(bond_dims(A)) + +Q = Matrix(I, d, d) # stores product of the first `t` matrices, evaluated at the sampled `x¹,...,xᵗ` +for t in eachindex(A) + rᵗ⁺¹ = t == L ? Matrix(I, d, d) : r[t+1] + # collapse multivariate xᵗ into 1D vector, sample from it + Aᵗ = _reshape1(A[t]) + @tullio QA[k,n,x] := Q[k,m] * Aᵗ[m,n,x] + @tullio p[x] := QA[k,n,x] * rᵗ⁺¹[n,k] + p ./= sum(p) + xᵗ = sample_noalloc(rng, p) + x[t] .= CartesianIndices(size(A[t])[3:end])[xᵗ] |> Tuple + # update prob + Q = Q * Aᵗ[:,:,xᵗ] +end +p = tr(Q) / tr(first(r)) +return x, p +end + + Base.:(==)(A::T, B::T) where {T<:AbstractTensorTrain} = isequal(A.tensors, B.tensors) Base.isapprox(A::T, B::T; kw...) where {T<:AbstractTensorTrain} = isapprox(A.tensors, B.tensors; kw...) function accumulate_M(A::AbstractTensorTrain) L = length(A) - M = [zeros(0, 0) for _ in 1:L, _ in 1:L] - - # initial condition - for t in 1:L-1 - range_aᵗ⁺¹ = axes(A[t+1], 1) - Mᵗᵗ⁺¹ = [float((a == c)) for a in range_aᵗ⁺¹, c in range_aᵗ⁺¹] - M[t, t+1] = Mᵗᵗ⁺¹ - end - for t in 1:L-1 - Mᵗᵘ⁻¹ = M[t, t+1] - for u in t+2:L - Aᵘ⁻¹ = _reshape1(A[u-1]) - @tullio Mᵗᵘ[aᵗ⁺¹, aᵘ] := Mᵗᵘ⁻¹[aᵗ⁺¹, aᵘ⁻¹] * Aᵘ⁻¹[aᵘ⁻¹, aᵘ, x] - M[t, u] = Mᵗᵘ - Mᵗᵘ⁻¹, Mᵗᵘ = Mᵗᵘ, Mᵗᵘ⁻¹ + M = fill(zeros(0, 0), L, L) + + for u in 2:L + Au = trace(A[u-1]) + for t in 1:u-2 + M[t, u] = M[t, u-1] * Au end + # initial condition + M[u-1, u] = Matrix(I, size(A[u],1), size(A[u],1)) end return M end +trace(At) = @tullio _[aᵗ,aᵗ⁺¹] := _reshape1(At)[aᵗ,aᵗ⁺¹,x] + +function accumulate_L(A::AbstractTensorTrain) + L = Matrix(I, size(A[begin],1), size(A[begin],1)) + map(trace(Atx) for Atx in A) do At + L = L * At + end +end + +function accumulate_R(A::AbstractTensorTrain) + R = Matrix(I, size(A[end],2), size(A[end],2)) + map(trace(Atx) for Atx in Iterators.reverse(A)) do At + R = At * R + end |> reverse +end + +""" + marginals(A::AbstractTensorTrain; l, r) + +Compute the marginal distributions ``p(x^l)`` at each site + +### Optional arguments +- `l = accumulate_L(A)`, `r = accumulate_R(A)` pre-computed partial normalizations +""" +function marginals(A::AbstractTensorTrain{F,N}; + l = accumulate_L(A), r = accumulate_R(A)) where {F<:Real,N} + + map(eachindex(A)) do t + Aᵗ = _reshape1(A[t]) + R = t + 1 ≤ length(A) ? r[t+1] : Matrix(I, size(A[end],2), size(A[end],2)) + L = t - 1 ≥ 1 ? l[t-1] : Matrix(I, size(A[begin],1), size(A[begin],1)) + @tullio lA[a¹,aᵗ⁺¹,x] := L[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] + @tullio pᵗ[x] := lA[a¹,aᵗ⁺¹,x] * R[aᵗ⁺¹,a¹] + #@reduce pᵗ[x] := sum(a¹,aᵗ,aᵗ⁺¹) lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] * rᵗ⁺¹[aᵗ⁺¹,a¹] + pᵗ ./= sum(pᵗ) + reshape(pᵗ, size(A[t])[3:end]) + end +end + +""" + twovar_marginals(A::AbstractTensorTrain; l, r, M, Δlmax) + +Compute the marginal distributions for each pair of sites ``p(x^l, x^m)`` + +### Optional arguments +- `l = accumulate_L(A)`, `r = accumulate_R(A)`, `M = accumulate_M(A)` pre-computed partial normalizations +- `maxdist = length(A)`: compute marginals only at distance `maxdist`: ``|l-m|\\le maxdist`` +""" +function twovar_marginals(A::AbstractTensorTrain{F,N}; + l = accumulate_L(A), r = accumulate_R(A), M = accumulate_M(A), + maxdist = length(A)-1) where {F<:Real,N} + qs = tuple(reduce(vcat, [x,x] for x in size(A[begin])[3:end])...) + b = Array{F,2*(N-2)}[zeros(zeros(Int, 2*(N-2))...) + for _ in eachindex(A), _ in eachindex(A)] + d = first(bond_dims(A)) + for t in 1:length(A)-1 + lᵗ⁻¹ = t == 1 ? Matrix(I, d, d) : l[t-1] + Aᵗ = _reshape1(A[t]) + for u in t+1:min(length(A),t+maxdist) + rᵘ⁺¹ = u == length(A) ? Matrix(I, d, d) : r[u+1] + Aᵘ = _reshape1(A[u]) + Mᵗᵘ = M[t, u] + rl = rᵘ⁺¹ * lᵗ⁻¹ + @tullio rlAt[aᵘ⁺¹, aᵗ⁺¹, xᵗ] := rl[aᵘ⁺¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] + @tullio rlAtMtu[aᵘ⁺¹,xᵗ,aᵘ] := rlAt[aᵘ⁺¹, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] + @tullio bᵗᵘ[xᵗ, xᵘ] := rlAtMtu[aᵘ⁺¹,xᵗ,aᵘ] * Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] + + #@tullio bᵗᵘ[xᵗ, xᵘ] := + #lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] * + #Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] * rᵘ⁺¹[aᵘ⁺¹,a¹] + + bᵗᵘ ./= sum(bᵗᵘ) + b[t,u] = reshape(bᵗᵘ, qs) + end + end + b +end + +""" + normalization(A::AbstractTensorTrain; l, r) + +Compute the normalization ``Z=\\sum_{x^1,\\ldots,x^L} A^1(x^1)\\cdots A^L(x^L)`` +""" +function normalization(A::AbstractTensorTrain; l = accumulate_L(A), r = accumulate_R(A)) + z = tr(l[end]) + @assert tr(r[begin]) ≈ z "z=$z, got $(tr(r[begin])), A=$A" # sanity check + z +end """ compress!(A::AbstractTensorTrain; svd_trunc::SVDTrunc) diff --git a/src/periodic_tensor_train.jl b/src/periodic_tensor_train.jl index 66c9aec..b7f4586 100644 --- a/src/periodic_tensor_train.jl +++ b/src/periodic_tensor_train.jl @@ -10,8 +10,6 @@ struct PeriodicTensorTrain{F<:Number, N} <: AbstractTensorTrain{F,N} function PeriodicTensorTrain{F,N}(tensors::Vector{Array{F,N}}) where {F<:Number, N} N > 2 || throw(ArgumentError("Tensors shold have at least 3 indices: 2 virtual and 1 physical")) - size(tensors[1],1) == size(tensors[end],2) || - throw(ArgumentError("Number of rows of the first matrix should coincide with the number of columns of the last matrix")) check_bond_dims(tensors) || throw(ArgumentError("Matrix indices for matrix product non compatible")) return new{F,N}(tensors) @@ -55,94 +53,10 @@ function rand_periodic_tt(bondsizes::AbstractVector{<:Integer}, q...) end rand_periodic_tt(d::Integer, L::Integer, q...) = rand_periodic_tt(fill(d, L-1), q...) -bond_dims(A::PeriodicTensorTrain) = [size(A[t], 1) for t in 1:lastindex(A)] - evaluate(A::PeriodicTensorTrain, X...) = tr(prod(@view a[:, :, x...] for (a,x) in zip(A, X...))) -function accumulate_L(A::PeriodicTensorTrain) - l = [zeros(0,0) for _ in eachindex(A)] - A⁰ = _reshape1(first(A)) - @reduce l⁰[a¹,a²] := sum(x) A⁰[a¹,a²,x] - l[1] = l⁰ - - lᵗ = l⁰ - for t in 1:length(A)-1 - Aᵗ = _reshape1(A[t+1]) - @reduce lᵗ[a¹,aᵗ⁺¹] |= sum(x,aᵗ) lᵗ[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] - l[t+1] = lᵗ - end - return l -end - -function accumulate_R(A::PeriodicTensorTrain) - r = [zeros(0,0) for _ in eachindex(A)] - A⁰ = _reshape1(last(A)) - @reduce rᴸ[aᴸ,a¹] := sum(x) A⁰[aᴸ,a¹,x] - r[end] = rᴸ - - rᵗ = rᴸ - for t in length(A)-1:-1:1 - Aᵗ = _reshape1(A[t]) - @reduce rᵗ[aᵗ,a¹] |= sum(x,aᵗ⁺¹) Aᵗ[aᵗ,aᵗ⁺¹,x] * rᵗ[aᵗ⁺¹,a¹] - r[t] = rᵗ - end - return r -end -function marginals(A::PeriodicTensorTrain{F,N}; - l = accumulate_L(A), r = accumulate_R(A)) where {F<:Real,N} - - A¹ = _reshape1(A[begin]); r² = r[2] - @reduce p¹[x] := sum(a¹,a²) A¹[a¹,a²,x] * r²[a²,a¹] - p¹ ./= sum(p¹) - p¹ = reshape(p¹, size(A[begin])[3:end]) - - Aᴸ = _reshape1(A[end]); lᴸ⁻¹ = l[end-1] - @reduce pᴸ[x] := sum(aᴸ,a¹) lᴸ⁻¹[a¹,aᴸ] * Aᴸ[aᴸ,a¹,x] - pᴸ ./= sum(pᴸ) - pᴸ = reshape(pᴸ, size(A[end])[3:end]) - - p = map(2:length(A)-1) do t - lᵗ⁻¹ = l[t-1] - Aᵗ = _reshape1(A[t]) - rᵗ⁺¹ = r[t+1] - @reduce pᵗ[x] := sum(a¹,aᵗ,aᵗ⁺¹) lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] * rᵗ⁺¹[aᵗ⁺¹,a¹] - pᵗ ./= sum(pᵗ) - reshape(pᵗ, size(A[t])[3:end]) - end - return append!([p¹], p, [pᴸ]) -end - -function twovar_marginals(A::PeriodicTensorTrain{F,N}; - l = accumulate_L(A), r = accumulate_R(A), M = accumulate_M(A), - maxdist = length(A)-1) where {F<:Real,N} - qs = tuple(reduce(vcat, [x,x] for x in size(A[begin])[3:end])...) - b = Array{F,2*(N-2)}[zeros(zeros(Int, 2*(N-2))...) - for _ in eachindex(A), _ in eachindex(A)] - d = first(bond_dims(A)) - for t in 1:length(A)-1 - lᵗ⁻¹ = t == 1 ? Matrix(I, d, d) : l[t-1] - Aᵗ = _reshape1(A[t]) - for u in t+1:min(length(A),t+maxdist) - rᵘ⁺¹ = u == length(A) ? Matrix(I, d, d) : r[u+1] - Aᵘ = _reshape1(A[u]) - Mᵗᵘ = M[t, u] - @tullio bᵗᵘ[xᵗ, xᵘ] := - lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] * - Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] * rᵘ⁺¹[aᵘ⁺¹,a¹] - bᵗᵘ ./= sum(bᵗᵘ) - b[t,u] = reshape(bᵗᵘ, qs) - end - end - b -end - -function normalization(A::PeriodicTensorTrain; l = accumulate_L(A), r = accumulate_R(A)) - z = tr(l[end]) - @assert tr(r[begin]) ≈ z "z=$z, got $(tr(r[begin])), A=$A" # sanity check - z -end function _compose(f, A::PeriodicTensorTrain{F,NA}, B::PeriodicTensorTrain{F,NB}) where {F,NA,NB} @assert NA == NB @@ -164,28 +78,6 @@ end PeriodicTensorTrain(A::TensorTrain) = PeriodicTensorTrain(A.tensors) -function StatsBase.sample!(rng::AbstractRNG, x, A::PeriodicTensorTrain{F,N}; - r = accumulate_R(A)) where {F<:Real,N} - L = length(A) - @assert length(x) == L - @assert all(length(xᵗ) == N-2 for xᵗ in x) - d = first(bond_dims(A)) - - Q = Matrix(I, d, d) # stores product of the first `t` matrices, evaluated at the sampled `x¹,...,xᵗ` - for t in eachindex(A) - rᵗ⁺¹ = t == L ? Matrix(I, d, d) : r[t+1] - # collapse multivariate xᵗ into 1D vector, sample from it - Aᵗ = _reshape1(A[t]) - @tullio p[x] := Q[k,m] * Aᵗ[m,n,x] * rᵗ⁺¹[n,k] - p ./= sum(p) - xᵗ = sample_noalloc(rng, p) - x[t] .= CartesianIndices(size(A[t])[3:end])[xᵗ] |> Tuple - # update prob - Q = Q * Aᵗ[:,:,xᵗ] - end - p = tr(Q) / tr(first(r)) - return x, p -end function orthogonalize_right!(C::PeriodicTensorTrain; svd_trunc=TruncThresh(1e-6)) C⁰ = _reshape1(C[begin]) diff --git a/src/tensor_train.jl b/src/tensor_train.jl index 19ca53a..3271970 100644 --- a/src/tensor_train.jl +++ b/src/tensor_train.jl @@ -26,17 +26,6 @@ end check_bond_dims, length, eachindex -function check_bond_dims(tensors::Vector{<:Array}) - for t in 1:lastindex(tensors)-1 - dᵗ = size(tensors[t],2) - dᵗ⁺¹ = size(tensors[t+1],1) - if dᵗ != dᵗ⁺¹ - println("Bond size for matrix t=$t. dᵗ=$dᵗ, dᵗ⁺¹=$dᵗ⁺¹") - return false - end - end - return true -end """ uniform_tt(bondsizes::AbstractVector{<:Integer}, q...) @@ -68,13 +57,6 @@ function rand_tt(bondsizes::AbstractVector{<:Integer}, q...) end rand_tt(d::Integer, L::Integer, q...) = rand_tt([1; fill(d, L-1); 1], q...) -""" - bond_dims(A::AbstractTensorTrain) - -Return a vector with the dimensions of the virtual bonds -""" -bond_dims(A::TensorTrain) = [size(A[t], 2) for t in 1:lastindex(A)-1] - """ evaluate(A::AbstractTensorTrain, X...) @@ -144,112 +126,6 @@ function orthogonalize_left!(C::TensorTrain; svd_trunc=TruncThresh(1e-6)) end -function accumulate_L(A::TensorTrain) - l = [zeros(0) for _ in eachindex(A)] - A⁰ = _reshape1(A[begin]) - @reduce l⁰[a¹] := sum(x) A⁰[1,a¹,x] - l[1] = l⁰ - - lᵗ = l⁰ - for t in 1:length(A)-1 - Aᵗ = _reshape1(A[t+1]) - @reduce lᵗ[aᵗ⁺¹] |= sum(x,aᵗ) lᵗ[aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] - l[t+1] = lᵗ - end - return l -end - -function accumulate_R(A::TensorTrain) - r = [zeros(0) for _ in eachindex(A)] - Aᵀ = _reshape1(A[end]) - @reduce rᵀ[aᵀ] := sum(x) Aᵀ[aᵀ,1,x] - r[end] = rᵀ - - rᵗ = rᵀ - for t in length(A)-1:-1:1 - Aᵗ = _reshape1(A[t]) - @reduce rᵗ[aᵗ] |= sum(x,aᵗ⁺¹) Aᵗ[aᵗ,aᵗ⁺¹,x] * rᵗ[aᵗ⁺¹] - r[t] = rᵗ - end - return r -end - -""" - marginals(A::AbstractTensorTrain; l, r) - -Compute the marginal distributions ``p(x^l)`` at each site - -### Optional arguments -- `l = accumulate_L(A)`, `r = accumulate_R(A)` pre-computed partial normalizations -""" -function marginals(A::TensorTrain{F,N}; - l = accumulate_L(A), r = accumulate_R(A)) where {F<:Real,N} - - A⁰ = _reshape1(A[begin]); r¹ = r[2] - @reduce p⁰[x] := sum(a¹) A⁰[1,a¹,x] * r¹[a¹] - p⁰ ./= sum(p⁰) - p⁰ = reshape(p⁰, size(A[begin])[3:end]) - - Aᵀ = _reshape1(A[end]); lᵀ⁻¹ = l[end-1] - @reduce pᵀ[x] := sum(aᵀ) lᵀ⁻¹[aᵀ] * Aᵀ[aᵀ,1,x] - pᵀ ./= sum(pᵀ) - pᵀ = reshape(pᵀ, size(A[end])[3:end]) - - p = map(2:length(A)-1) do t - lᵗ⁻¹ = l[t-1] - Aᵗ = _reshape1(A[t]) - rᵗ⁺¹ = r[t+1] - @reduce pᵗ[x] := sum(aᵗ,aᵗ⁺¹) lᵗ⁻¹[aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] * rᵗ⁺¹[aᵗ⁺¹] - pᵗ ./= sum(pᵗ) - reshape(pᵗ, size(A[t])[3:end]) - end - - return append!([p⁰], p, [pᵀ]) -end - -""" - twovar_marginals(A::AbstractTensorTrain; l, r, M, Δlmax) - -Compute the marginal distributions for each pair of sites ``p(x^l, x^m)`` - -### Optional arguments -- `l = accumulate_L(A)`, `r = accumulate_R(A)`, `M = accumulate_M(A)` pre-computed partial normalizations -- `maxdist = length(A)`: compute marginals only at distance `maxdist`: ``|l-m|\\le maxdist`` -""" -function twovar_marginals(A::TensorTrain{F,N}; - l = accumulate_L(A), r = accumulate_R(A), M = accumulate_M(A), - maxdist = length(A)-1) where {F<:Real,N} - qs = tuple(reduce(vcat, [x,x] for x in size(A[begin])[3:end])...) - b = Array{F,2*(N-2)}[zeros(zeros(Int, 2*(N-2))...) - for _ in eachindex(A), _ in eachindex(A)] - for t in 1:length(A)-1 - lᵗ⁻¹ = t == 1 ? [1.0;] : l[t-1] - Aᵗ = _reshape1(A[t]) - for u in t+1:min(length(A),t+maxdist) - rᵘ⁺¹ = u == length(A) ? [1.0;] : r[u+1] - Aᵘ = _reshape1(A[u]) - Mᵗᵘ = M[t, u] - @tullio bᵗᵘ[xᵗ, xᵘ] := - lᵗ⁻¹[aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] * - Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] * rᵘ⁺¹[aᵘ⁺¹] - bᵗᵘ ./= sum(bᵗᵘ) - b[t,u] = reshape(bᵗᵘ, qs) - end - end - b -end - -""" - normalization(A::AbstractTensorTrain; l, r) - -Compute the normalization ``Z=\\sum_{x^1,\\ldots,x^L} A^1(x^1)\\cdots A^L(x^L)`` -""" -function normalization(A::TensorTrain; l = accumulate_L(A), r = accumulate_R(A)) - z = only(l[end]) - @assert only(r[begin]) ≈ z "z=$z, got $(only(r[begin])), A=$A" # sanity check - z -end - # used to do stuff like `A+B` with `A,B` tensor trains function _compose(f, A::TensorTrain{F,NA}, B::TensorTrain{F,NB}) where {F,NA,NB} @assert NA == NB @@ -273,24 +149,3 @@ function _compose(f, A::TensorTrain{F,NA}, B::TensorTrain{F,NB}) where {F,NA,NB} TensorTrain(tensors) end -function StatsBase.sample!(rng::AbstractRNG, x, A::TensorTrain{F,N}; - r = accumulate_R(A)) where {F<:Real,N} - L = length(A) - @assert length(x) == L - @assert all(length(xᵗ) == N-2 for xᵗ in x) - - Q = ones(F, 1, 1) # stores product of the first `t` matrices, evaluated at the sampled `x¹,...,xᵗ` - for t in eachindex(A) - rᵗ⁺¹ = t == L ? ones(F,1) : r[t+1] - # collapse multivariate xᵗ into 1D vector, sample from it - Aᵗ = _reshape1(A[t]) - @tullio p[x] := Q[m] * Aᵗ[m,n,x] * rᵗ⁺¹[n] - p ./= sum(p) - xᵗ = sample_noalloc(rng, p) - x[t] .= CartesianIndices(size(A[t])[3:end])[xᵗ] |> Tuple - # update prob - Q = Q * Aᵗ[:,:,xᵗ] - end - p = only(Q) / only(first(r)) - return x, p -end \ No newline at end of file diff --git a/test/tensor_train.jl b/test/tensor_train.jl index e9bb0df..fcf6cf8 100644 --- a/test/tensor_train.jl +++ b/test/tensor_train.jl @@ -22,7 +22,7 @@ end tensors = [rand(1,3,2), rand(3,4,2), rand(4,10,2), rand(10,1,2)] C = TensorTrain(tensors) - @test bond_dims(C) == [3,4,10] + @test bond_dims(C) == [1,3,4,10] @test eltype(C) == eltype(1.0) x = [rand(1:2,1) for _ in C] @@ -107,7 +107,7 @@ end m = TensorTrains.accumulate_M(A) Z = only(l[end]) @test only(r[begin]) ≈ Z - @test l[begin]' * m[1,end] * r[end] ≈ Z + @test only(l[begin] * m[1,end] * r[end]) ≈ Z end @testset "Sum of TTs" begin