From 92e3e40688da1763f98fdcc111496e551271413d Mon Sep 17 00:00:00 2001 From: Alfredo Braunstein Date: Fri, 13 Oct 2023 16:35:04 +0200 Subject: [PATCH] reorder product for more efficient calculation in marginals --- src/abstract_tensor_train.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/abstract_tensor_train.jl b/src/abstract_tensor_train.jl index 2597a76..67a716d 100644 --- a/src/abstract_tensor_train.jl +++ b/src/abstract_tensor_train.jl @@ -59,7 +59,8 @@ for t in eachindex(A) rᵗ⁺¹ = t == L ? Matrix(I, d, d) : r[t+1] # collapse multivariate xᵗ into 1D vector, sample from it Aᵗ = _reshape1(A[t]) - @tullio p[x] := Q[k,m] * Aᵗ[m,n,x] * rᵗ⁺¹[n,k] + @tullio QA[k,n,x] := Q[k,m] * Aᵗ[m,n,x] + @tullio p[x] := QA[k,n,x] * rᵗ⁺¹[n,k] p ./= sum(p) xᵗ = sample_noalloc(rng, p) x[t] .= CartesianIndices(size(A[t])[3:end])[xᵗ] |> Tuple @@ -121,10 +122,11 @@ function marginals(A::AbstractTensorTrain{F,N}; map(eachindex(A)) do t Aᵗ = _reshape1(A[t]) - RL = t == length(A) ? l[end-1] : - t == 1 ? r[begin+1] : - r[t+1]*l[t-1] - @tullio pᵗ[x] := RL[aᵗ⁺¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] + R = t + 1 ≤ length(A) ? r[t+1] : Matrix(I, size(A[end],2), size(A[end],2)) + L = t - 1 ≥ 1 ? l[t-1] : Matrix(I, size(A[begin],1), size(A[begin],1)) + @tullio lA[a¹,aᵗ⁺¹,x] := L[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] + @tullio pᵗ[x] := lA[a¹,aᵗ⁺¹,x] * R[aᵗ⁺¹,a¹] + #@reduce pᵗ[x] := sum(a¹,aᵗ,aᵗ⁺¹) lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] * rᵗ⁺¹[aᵗ⁺¹,a¹] pᵗ ./= sum(pᵗ) reshape(pᵗ, size(A[t])[3:end]) end