From d5138519a4ab65a54be2b39c43bedc33d6af43df Mon Sep 17 00:00:00 2001 From: Alfredo Braunstein Date: Fri, 13 Oct 2023 16:14:44 +0200 Subject: [PATCH] simplify marginals() --- src/abstract_tensor_train.jl | 55 +++++++++++++++--------------------- 1 file changed, 22 insertions(+), 33 deletions(-) diff --git a/src/abstract_tensor_train.jl b/src/abstract_tensor_train.jl index 9f474fb..2597a76 100644 --- a/src/abstract_tensor_train.jl +++ b/src/abstract_tensor_train.jl @@ -77,23 +77,16 @@ Base.isapprox(A::T, B::T; kw...) where {T<:AbstractTensorTrain} = isapprox(A.ten function accumulate_M(A::AbstractTensorTrain) L = length(A) - M = [zeros(0, 0) for _ in 1:L, _ in 1:L] - - # initial condition - for t in 1:L-1 - range_aᵗ⁺¹ = axes(A[t+1], 1) - Mᵗᵗ⁺¹ = [float((a == c)) for a in range_aᵗ⁺¹, c in range_aᵗ⁺¹] - M[t, t+1] = Mᵗᵗ⁺¹ - end - for t in 1:L-1 - Mᵗᵘ⁻¹ = M[t, t+1] - for u in t+2:L - Aᵘ⁻¹ = _reshape1(A[u-1]) - @tullio Mᵗᵘ[aᵗ⁺¹, aᵘ] := Mᵗᵘ⁻¹[aᵗ⁺¹, aᵘ⁻¹] * Aᵘ⁻¹[aᵘ⁻¹, aᵘ, x] - M[t, u] = Mᵗᵘ - Mᵗᵘ⁻¹, Mᵗᵘ = Mᵗᵘ, Mᵗᵘ⁻¹ + M = fill(zeros(0, 0), L, L) + + for u in 2:L + Au = trace(A[u-1]) + for t in 1:u-2 + M[t, u] = M[t, u-1] * Au end + # initial condition + M[u-1, u] = Matrix(I, size(A[u],1), size(A[u],1)) end return M @@ -126,25 +119,15 @@ Compute the marginal distributions ``p(x^l)`` at each site function marginals(A::AbstractTensorTrain{F,N}; l = accumulate_L(A), r = accumulate_R(A)) where {F<:Real,N} - A¹ = _reshape1(A[begin]); r² = r[2] - @tullio p¹[x] := A¹[a¹,a²,x] * r²[a²,a¹] - p¹ ./= sum(p¹) - p¹ = reshape(p¹, size(A[begin])[3:end]) - - Aᴸ = _reshape1(A[end]); lᴸ⁻¹ = l[end-1] - @tullio pᴸ[x] := lᴸ⁻¹[a¹,aᴸ] * Aᴸ[aᴸ,a¹,x] - pᴸ ./= sum(pᴸ) - pᴸ = reshape(pᴸ, size(A[end])[3:end]) - - p = map(2:length(A)-1) do t + map(eachindex(A)) do t Aᵗ = _reshape1(A[t]) - rl = r[t+1] * l[t-1] - @tullio pᵗ[x] := rl[aᵗ⁺¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] + RL = t == length(A) ? l[end-1] : + t == 1 ? r[begin+1] : + r[t+1]*l[t-1] + @tullio pᵗ[x] := RL[aᵗ⁺¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] pᵗ ./= sum(pᵗ) reshape(pᵗ, size(A[t])[3:end]) end - - return append!([p¹], p, [pᴸ]) end """ @@ -170,9 +153,15 @@ function twovar_marginals(A::AbstractTensorTrain{F,N}; rᵘ⁺¹ = u == length(A) ? Matrix(I, d, d) : r[u+1] Aᵘ = _reshape1(A[u]) Mᵗᵘ = M[t, u] - @tullio bᵗᵘ[xᵗ, xᵘ] := - lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] * - Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] * rᵘ⁺¹[aᵘ⁺¹,a¹] + rl = rᵘ⁺¹ * lᵗ⁻¹ + @tullio rlAt[aᵘ⁺¹, aᵗ⁺¹, xᵗ] := rl[aᵘ⁺¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] + @tullio rlAtMtu[aᵘ⁺¹,xᵗ,aᵘ] := rlAt[aᵘ⁺¹, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] + @tullio bᵗᵘ[xᵗ, xᵘ] := rlAtMtu[aᵘ⁺¹,xᵗ,aᵘ] * Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] + + #@tullio bᵗᵘ[xᵗ, xᵘ] := + #lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] * + #Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] * rᵘ⁺¹[aᵘ⁺¹,a¹] + bᵗᵘ ./= sum(bᵗᵘ) b[t,u] = reshape(bᵗᵘ, qs) end