Skip to content

Commit

Permalink
reorder product for more efficient calculation in marginals
Browse files Browse the repository at this point in the history
  • Loading branch information
abraunst authored and stecrotti committed Oct 17, 2023
1 parent 068713e commit 92e3e40
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/abstract_tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ for t in eachindex(A)
rᵗ⁺¹ = t == L ? Matrix(I, d, d) : r[t+1]
# collapse multivariate xᵗ into 1D vector, sample from it
Aᵗ = _reshape1(A[t])
@tullio p[x] := Q[k,m] * Aᵗ[m,n,x] * rᵗ⁺¹[n,k]
@tullio QA[k,n,x] := Q[k,m] * Aᵗ[m,n,x]
@tullio p[x] := QA[k,n,x] * rᵗ⁺¹[n,k]
p ./= sum(p)
xᵗ = sample_noalloc(rng, p)
x[t] .= CartesianIndices(size(A[t])[3:end])[xᵗ] |> Tuple
Expand Down Expand Up @@ -121,10 +122,11 @@ function marginals(A::AbstractTensorTrain{F,N};

map(eachindex(A)) do t
Aᵗ = _reshape1(A[t])
RL = t == length(A) ? l[end-1] :
t == 1 ? r[begin+1] :
r[t+1]*l[t-1]
@tullio pᵗ[x] := RL[aᵗ⁺¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x]
R = t + 1 length(A) ? r[t+1] : Matrix(I, size(A[end],2), size(A[end],2))
L = t - 1 1 ? l[t-1] : Matrix(I, size(A[begin],1), size(A[begin],1))
@tullio lA[a¹,aᵗ⁺¹,x] := L[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x]
@tullio pᵗ[x] := lA[a¹,aᵗ⁺¹,x] * R[aᵗ⁺¹,a¹]
#@reduce pᵗ[x] := sum(a¹,aᵗ,aᵗ⁺¹) lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] * rᵗ⁺¹[aᵗ⁺¹,a¹]
pᵗ ./= sum(pᵗ)
reshape(pᵗ, size(A[t])[3:end])
end
Expand Down

0 comments on commit 92e3e40

Please sign in to comment.