Skip to content

Commit

Permalink
simplify marginals()
Browse files Browse the repository at this point in the history
  • Loading branch information
abraunst committed Oct 13, 2023
1 parent 1c8d8f0 commit d513851
Showing 1 changed file with 22 additions and 33 deletions.
55 changes: 22 additions & 33 deletions src/abstract_tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,16 @@ Base.isapprox(A::T, B::T; kw...) where {T<:AbstractTensorTrain} = isapprox(A.ten

function accumulate_M(A::AbstractTensorTrain)
L = length(A)
M = [zeros(0, 0) for _ in 1:L, _ in 1:L]

# initial condition
for t in 1:L-1
range_aᵗ⁺¹ = axes(A[t+1], 1)
Mᵗᵗ⁺¹ = [float((a == c)) for a in range_aᵗ⁺¹, c in range_aᵗ⁺¹]
M[t, t+1] = Mᵗᵗ⁺¹
end

for t in 1:L-1
Mᵗᵘ⁻¹ = M[t, t+1]
for u in t+2:L
Aᵘ⁻¹ = _reshape1(A[u-1])
@tullio Mᵗᵘ[aᵗ⁺¹, aᵘ] := Mᵗᵘ⁻¹[aᵗ⁺¹, aᵘ⁻¹] * Aᵘ⁻¹[aᵘ⁻¹, aᵘ, x]
M[t, u] = Mᵗᵘ
Mᵗᵘ⁻¹, Mᵗᵘ = Mᵗᵘ, Mᵗᵘ⁻¹
M = fill(zeros(0, 0), L, L)

for u in 2:L
Au = trace(A[u-1])
for t in 1:u-2
M[t, u] = M[t, u-1] * Au
end
# initial condition
M[u-1, u] = Matrix(I, size(A[u],1), size(A[u],1))
end

return M
Expand Down Expand Up @@ -126,25 +119,15 @@ Compute the marginal distributions ``p(x^l)`` at each site
function marginals(A::AbstractTensorTrain{F,N};
l = accumulate_L(A), r = accumulate_R(A)) where {F<:Real,N}

= _reshape1(A[begin]); r² = r[2]
@tullio p¹[x] := A¹[a¹,a²,x] * r²[a²,a¹]
./= sum(p¹)
= reshape(p¹, size(A[begin])[3:end])

Aᴸ = _reshape1(A[end]); lᴸ⁻¹ = l[end-1]
@tullio pᴸ[x] := lᴸ⁻¹[a¹,aᴸ] * Aᴸ[aᴸ,a¹,x]
pᴸ ./= sum(pᴸ)
pᴸ = reshape(pᴸ, size(A[end])[3:end])

p = map(2:length(A)-1) do t
map(eachindex(A)) do t
Aᵗ = _reshape1(A[t])
rl = r[t+1] * l[t-1]
@tullio pᵗ[x] := rl[aᵗ⁺¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x]
RL = t == length(A) ? l[end-1] :
t == 1 ? r[begin+1] :
r[t+1]*l[t-1]
@tullio pᵗ[x] := RL[aᵗ⁺¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x]
pᵗ ./= sum(pᵗ)
reshape(pᵗ, size(A[t])[3:end])
end

return append!([p¹], p, [pᴸ])
end

"""
Expand All @@ -170,9 +153,15 @@ function twovar_marginals(A::AbstractTensorTrain{F,N};
rᵘ⁺¹ = u == length(A) ? Matrix(I, d, d) : r[u+1]
Aᵘ = _reshape1(A[u])
Mᵗᵘ = M[t, u]
@tullio bᵗᵘ[xᵗ, xᵘ] :=
lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] *
Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] * rᵘ⁺¹[aᵘ⁺¹,a¹]
rl = rᵘ⁺¹ * lᵗ⁻¹
@tullio rlAt[aᵘ⁺¹, aᵗ⁺¹, xᵗ] := rl[aᵘ⁺¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ]
@tullio rlAtMtu[aᵘ⁺¹,xᵗ,aᵘ] := rlAt[aᵘ⁺¹, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ]
@tullio bᵗᵘ[xᵗ, xᵘ] := rlAtMtu[aᵘ⁺¹,xᵗ,aᵘ] * Aᵘ[aᵘ, aᵘ⁺¹, xᵘ]

#@tullio bᵗᵘ[xᵗ, xᵘ] :=
#lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] *
#Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] * rᵘ⁺¹[aᵘ⁺¹,a¹]

bᵗᵘ ./= sum(bᵗᵘ)
b[t,u] = reshape(bᵗᵘ, qs)
end
Expand Down

0 comments on commit d513851

Please sign in to comment.