Skip to content

Commit

Permalink
Use generic trace methods when possible (#16)
Browse files Browse the repository at this point in the history
* resolve mysterious memory leaks
* add trace function
* fix bug in dimensions
* Move many functions to abstract_tensor_train
* simplify marginals()
* reorder product for more efficient calculation in marginals
  • Loading branch information
abraunst authored Oct 18, 2023
1 parent 444657b commit 4b04d3c
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 270 deletions.
157 changes: 142 additions & 15 deletions src/abstract_tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,27 @@ abstract type AbstractTensorTrain{F<:Number, N} end

Base.eltype(::AbstractTensorTrain{F,N}) where {N,F} = F


"""
bond_dims(A::AbstractTensorTrain)
Return a vector with the dimensions of the virtual bonds
"""
function bond_dims(A::AbstractTensorTrain)
    # one entry per tensor: the size of its first (row) index
    return map(t -> size(A[t], 1), 1:lastindex(A))
end

# Check that consecutive tensors are compatible for matrix multiplication: the second
# dimension of each tensor must equal the first dimension of the next one. The last tensor
# is compared against the first (indices wrap around via `mod1`).
# Returns `true` when every pair matches; otherwise prints the first mismatch and returns
# `false`.
function check_bond_dims(tensors::Vector{<:Array})
    n = length(tensors)
    for (t, tensor) in enumerate(tensors)
        dᵗ = size(tensor, 2)
        dᵗ⁺¹ = size(tensors[mod1(t + 1, n)], 1)
        dᵗ == dᵗ⁺¹ && continue
        println("Bond size for matrix t=$t. dᵗ=$dᵗ, dᵗ⁺¹=$dᵗ⁺¹")
        return false
    end
    return true
end


"""
normalize_eachmatrix!(A::AbstractTensorTrain)
Expand All @@ -26,34 +47,140 @@ function normalize_eachmatrix!(A::AbstractTensorTrain)
c
end

"""
    sample!(rng::AbstractRNG, x, A::AbstractTensorTrain; r = accumulate_R(A))

Sample a configuration from `A` one site at a time, overwriting `x[t]` with the value drawn
at site `t`. Return the tuple `(x, p)`, where `p = tr(Q) / tr(first(r))` and `Q` is the
product of the matrices of `A` evaluated at the sampled values.

### Optional arguments
- `r = accumulate_R(A)` pre-computed partial normalizations
"""
function StatsBase.sample!(rng::AbstractRNG, x, A::AbstractTensorTrain{F,N};
        r = accumulate_R(A)) where {F<:Real,N}
    L = length(A)
    @assert length(x) == L
    @assert all(length(xᵗ) == N-2 for xᵗ in x)
    d = first(bond_dims(A))

    # running product of the first `t` matrices, evaluated at the sampled `x¹,...,xᵗ`
    Q = Matrix(I, d, d)
    for t in eachindex(A)
        # collapse multivariate xᵗ into 1D vector, sample from it
        Aᵗ = _reshape1(A[t])
        # past the last site there is no partial product left: use the identity
        tail = t == L ? Matrix(I, d, d) : r[t+1]
        @tullio QA[i,j,y] := Q[i,m] * Aᵗ[m,j,y]
        @tullio weights[y] := QA[i,j,y] * tail[j,i]
        weights ./= sum(weights)
        xᵗ = sample_noalloc(rng, weights)
        # turn the flat index back into one entry per physical index
        x[t] .= CartesianIndices(size(A[t])[3:end])[xᵗ] |> Tuple
        # update prob
        Q = Q * Aᵗ[:,:,xᵗ]
    end
    prob = tr(Q) / tr(first(r))
    return x, prob
end


# Two tensor trains of the same type are equal when their `tensors` fields are `isequal`
function Base.:(==)(A::T, B::T) where {T<:AbstractTensorTrain}
    return isequal(A.tensors, B.tensors)
end

# Approximate comparison of the `tensors` fields; keyword arguments go to `isapprox`
function Base.isapprox(A::T, B::T; kw...) where {T<:AbstractTensorTrain}
    return isapprox(A.tensors, B.tensors; kw...)
end


"""
    accumulate_M(A::AbstractTensorTrain)

Return an `L×L` matrix of matrices `M`, with `L = length(A)`. For `t < u`, `M[t, u]` is the
product `trace(A[t+1]) * ⋯ * trace(A[u-1])`; in particular `M[u-1, u]` is the identity.
Entries with `t ≥ u` are left as empty `0×0` matrices.
"""
function accumulate_M(A::AbstractTensorTrain)
    L = length(A)
    # every entry starts as the same empty placeholder; entries are replaced, never mutated
    M = fill(zeros(0, 0), L, L)

    for u in 2:L
        Au = trace(A[u-1])
        # extend every product ending at `u-1` by one more site
        for t in 1:u-2
            M[t, u] = M[t, u-1] * Au
        end
        # initial condition
        M[u-1, u] = Matrix(I, size(A[u],1), size(A[u],1))
    end

    return M
end

"""
    trace(At)

Sum `_reshape1(At)` over its third index and return the resulting matrix.
"""
function trace(At)
    B = _reshape1(At)
    @tullio S[i,j] := B[i,j,x]
    return S
end

"""
    accumulate_L(A::AbstractTensorTrain)

Return the vector of left partial products: iterating over `A`, the `t`-th entry is the
product of `trace` of the first `t` tensors, starting from the identity.
"""
function accumulate_L(A::AbstractTensorTrain)
    d = size(A[begin], 1)
    return accumulate(*, (trace(Aᵗ) for Aᵗ in A); init = Matrix(I, d, d))
end

"""
    accumulate_R(A::AbstractTensorTrain)

Return the vector of right partial products: the `t`-th entry is the product of `trace` of
the tensors from site `t` to the last one, built by iterating over `A` in reverse.
"""
function accumulate_R(A::AbstractTensorTrain)
    d = size(A[end], 2)
    # each new factor is multiplied on the left of what has been accumulated so far
    prepend(R, Aᵗ) = Aᵗ * R
    partial = accumulate(prepend, (trace(Aᵗ) for Aᵗ in Iterators.reverse(A));
        init = Matrix(I, d, d))
    # the products were collected from the last site backwards: restore site order
    return reverse(partial)
end

"""
marginals(A::AbstractTensorTrain; l, r)
Compute the marginal distributions ``p(x^l)`` at each site
### Optional arguments
- `l = accumulate_L(A)`, `r = accumulate_R(A)` pre-computed partial normalizations
"""
function marginals(A::AbstractTensorTrain{F,N};
        l = accumulate_L(A), r = accumulate_R(A)) where {F<:Real,N}

    return map(eachindex(A)) do t
        Aᵗ = _reshape1(A[t])
        # at the two ends there is no partial product on that side: use the identity
        R = t + 1 ≤ length(A) ? r[t+1] : Matrix(I, size(A[end],2), size(A[end],2))
        L = t - 1 ≥ 1 ? l[t-1] : Matrix(I, size(A[begin],1), size(A[begin],1))
        # pᵗ[x] = Σ L[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] * R[aᵗ⁺¹,a¹], contracted in two pairwise steps
        @tullio lA[a¹,aᵗ⁺¹,x] := L[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x]
        @tullio pᵗ[x] := lA[a¹,aᵗ⁺¹,x] * R[aᵗ⁺¹,a¹]
        pᵗ ./= sum(pᵗ)
        # give the result the shape of the physical indices of `A[t]`
        reshape(pᵗ, size(A[t])[3:end])
    end
end

"""
twovar_marginals(A::AbstractTensorTrain; l, r, M, maxdist)
Compute the marginal distributions for each pair of sites ``p(x^l, x^m)``
### Optional arguments
- `l = accumulate_L(A)`, `r = accumulate_R(A)`, `M = accumulate_M(A)` pre-computed partial normalizations
- `maxdist = length(A)-1`: compute marginals only for pairs of sites at distance up to `maxdist`: ``|l-m|\\le maxdist``
"""
function twovar_marginals(A::AbstractTensorTrain{F,N};
l = accumulate_L(A), r = accumulate_R(A), M = accumulate_M(A),
maxdist = length(A)-1) where {F<:Real,N}
# Returns a `length(A) × length(A)` array `b` of arrays. Only the entries `b[t,u]` with
# `t < u ≤ t + maxdist` are filled (each normalized to sum to 1); all other entries keep
# the empty placeholder they are initialized with.
#
# target shape of each filled entry: every physical size of the first tensor, repeated twice
# NOTE(review): `qs` is (q₁,q₁,q₂,q₂,…) and is read from `A[begin]` only — confirm this is
# the intended layout when there is more than one physical index or sizes differ by site
qs = tuple(reduce(vcat, [x,x] for x in size(A[begin])[3:end])...)
# placeholders with all dimensions equal to zero
b = Array{F,2*(N-2)}[zeros(zeros(Int, 2*(N-2))...)
for _ in eachindex(A), _ in eachindex(A)]
# size of the identity used where no partial product exists (before site 1, after the last)
d = first(bond_dims(A))
for t in 1:length(A)-1
lᵗ⁻¹ = t == 1 ? Matrix(I, d, d) : l[t-1]
Aᵗ = _reshape1(A[t])
for u in t+1:min(length(A),t+maxdist)
rᵘ⁺¹ = u == length(A) ? Matrix(I, d, d) : r[u+1]
Aᵘ = _reshape1(A[u])
Mᵗᵘ = M[t, u]
# merge the two outer factors first, then contract one tensor at a time
rl = rᵘ⁺¹ * lᵗ⁻¹
@tullio rlAt[aᵘ⁺¹, aᵗ⁺¹, xᵗ] := rl[aᵘ⁺¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ]
@tullio rlAtMtu[aᵘ⁺¹,xᵗ,aᵘ] := rlAt[aᵘ⁺¹, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ]
@tullio bᵗᵘ[xᵗ, xᵘ] := rlAtMtu[aᵘ⁺¹,xᵗ,aᵘ] * Aᵘ[aᵘ, aᵘ⁺¹, xᵘ]

# the three steps above replace this single five-factor contraction:
#@tullio bᵗᵘ[xᵗ, xᵘ] :=
#lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] *
#Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] * rᵘ⁺¹[aᵘ⁺¹,a¹]

# normalize, then reshape to `qs`
bᵗᵘ ./= sum(bᵗᵘ)
b[t,u] = reshape(bᵗᵘ, qs)
end
end
b
end

"""
normalization(A::AbstractTensorTrain; l, r)
Compute the normalization ``Z=\\sum_{x^1,\\ldots,x^L} A^1(x^1)\\cdots A^L(x^L)``
"""
function normalization(A::AbstractTensorTrain; l = accumulate_L(A), r = accumulate_R(A))
    # the last left partial product spans every site: its trace is the normalization
    z = tr(l[end])
    # sanity check: the first right partial product must give the same value up to rounding
    @assert tr(r[begin]) ≈ z "z=$z, got $(tr(r[begin])), A=$A"
    return z
end

"""
compress!(A::AbstractTensorTrain; svd_trunc::SVDTrunc)
Expand Down
108 changes: 0 additions & 108 deletions src/periodic_tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ struct PeriodicTensorTrain{F<:Number, N} <: AbstractTensorTrain{F,N}

function PeriodicTensorTrain{F,N}(tensors::Vector{Array{F,N}}) where {F<:Number, N}
N > 2 || throw(ArgumentError("Tensors shold have at least 3 indices: 2 virtual and 1 physical"))
size(tensors[1],1) == size(tensors[end],2) ||
throw(ArgumentError("Number of rows of the first matrix should coincide with the number of columns of the last matrix"))
check_bond_dims(tensors) ||
throw(ArgumentError("Matrix indices for matrix product non compatible"))
return new{F,N}(tensors)
Expand Down Expand Up @@ -55,94 +53,10 @@ function rand_periodic_tt(bondsizes::AbstractVector{<:Integer}, q...)
end
# Convenience method: forward to the vector method with `L-1` bond sizes all equal to `d`
function rand_periodic_tt(d::Integer, L::Integer, q...)
    bondsizes = fill(d, L-1)
    return rand_periodic_tt(bondsizes, q...)
end

# One entry per tensor: the size of its first (row) index
function bond_dims(A::PeriodicTensorTrain)
    return map(t -> size(A[t], 1), 1:lastindex(A))
end

# Evaluate `A` at `X...`: take from each tensor the matrix selected by the corresponding
# entry, multiply the matrices in order and return the trace of the product
function evaluate(A::PeriodicTensorTrain, X...)
    matrices = (view(a, :, :, x...) for (a, x) in zip(A, X...))
    return tr(prod(matrices))
end

# Left partial products for a periodic tensor train. Returns a vector `l` with one matrix
# per site: `l[1]` is the first tensor summed over its third index (after `_reshape1`), and
# `l[t+1]` is `l[t]` contracted with tensor `t+1`, also summed over its third index.
function accumulate_L(A::PeriodicTensorTrain)
# placeholders, every one of them is overwritten below
l = [zeros(0,0) for _ in eachindex(A)]
A⁰ = _reshape1(first(A))
@reduce l⁰[a¹,a²] := sum(x) A⁰[a¹,a²,x]
l[1] = l⁰

lᵗ = l⁰
for t in 1:length(A)-1
Aᵗ = _reshape1(A[t+1])
# NOTE(review): `|=` presumably makes the macro return a freshly allocated array, so that
# the entries stored in `l` do not alias each other — confirm against the TensorCast docs
@reduce lᵗ[a¹,aᵗ⁺¹] |= sum(x,aᵗ) lᵗ[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x]
l[t+1] = lᵗ
end
return l
end

# Right partial products for a periodic tensor train. Returns a vector `r` with one matrix
# per site: `r[end]` is the last tensor summed over its third index (after `_reshape1`), and
# `r[t]` is tensor `t`, summed over its third index, contracted with `r[t+1]`.
function accumulate_R(A::PeriodicTensorTrain)
# placeholders, every one of them is overwritten below
r = [zeros(0,0) for _ in eachindex(A)]
A⁰ = _reshape1(last(A))
@reduce rᴸ[aᴸ,a¹] := sum(x) A⁰[aᴸ,a¹,x]
r[end] = rᴸ

rᵗ = rᴸ
# walk from the second-to-last site down to the first
for t in length(A)-1:-1:1
Aᵗ = _reshape1(A[t])
# NOTE(review): `|=` presumably makes the macro return a freshly allocated array, so that
# the entries stored in `r` do not alias each other — confirm against the TensorCast docs
@reduce rᵗ[aᵗ,a¹] |= sum(x,aᵗ⁺¹) Aᵗ[aᵗ,aᵗ⁺¹,x] * rᵗ[aᵗ⁺¹,a¹]
r[t] = rᵗ
end
return r
end

"""
    marginals(A::PeriodicTensorTrain; l, r)

Compute the normalized single-site marginal at each site of `A`, returned as a vector with
one array per site, each shaped like the physical indices of the corresponding tensor.

### Optional arguments
- `l = accumulate_L(A)`, `r = accumulate_R(A)` pre-computed partial normalizations
"""
function marginals(A::PeriodicTensorTrain{F,N};
        l = accumulate_L(A), r = accumulate_R(A)) where {F<:Real,N}

    # first site: only the right partial product starting at site 2 is involved
    A¹ = _reshape1(A[begin]); r² = r[2]
    @reduce p¹[x] := sum(a¹,a²) A¹[a¹,a²,x] * r²[a²,a¹]
    p¹ ./= sum(p¹)
    p¹ = reshape(p¹, size(A[begin])[3:end])

    # last site: only the left partial product ending at the second-to-last site is involved
    Aᴸ = _reshape1(A[end]); lᴸ⁻¹ = l[end-1]
    @reduce pᴸ[x] := sum(aᴸ,a¹) lᴸ⁻¹[a¹,aᴸ] * Aᴸ[aᴸ,a¹,x]
    pᴸ ./= sum(pᴸ)
    pᴸ = reshape(pᴸ, size(A[end])[3:end])

    # sites in between: left product, site tensor and right product
    p = map(2:length(A)-1) do t
        lᵗ⁻¹ = l[t-1]
        Aᵗ = _reshape1(A[t])
        rᵗ⁺¹ = r[t+1]
        @reduce pᵗ[x] := sum(a¹,aᵗ,aᵗ⁺¹) lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ,aᵗ⁺¹,x] * rᵗ⁺¹[aᵗ⁺¹,a¹]
        pᵗ ./= sum(pᵗ)
        reshape(pᵗ, size(A[t])[3:end])
    end

    return append!([p¹], p, [pᴸ])
end

# Two-site marginals for a periodic tensor train. Returns a `length(A) × length(A)` array
# `b` of arrays. Only the entries `b[t,u]` with `t < u ≤ t + maxdist` are filled (each
# normalized to sum to 1); all other entries keep the empty placeholder they start with.
function twovar_marginals(A::PeriodicTensorTrain{F,N};
l = accumulate_L(A), r = accumulate_R(A), M = accumulate_M(A),
maxdist = length(A)-1) where {F<:Real,N}
# target shape of each filled entry: every physical size of the first tensor, repeated twice
# NOTE(review): `qs` is (q₁,q₁,q₂,q₂,…) and is read from `A[begin]` only — confirm this is
# the intended layout when there is more than one physical index or sizes differ by site
qs = tuple(reduce(vcat, [x,x] for x in size(A[begin])[3:end])...)
# placeholders with all dimensions equal to zero
b = Array{F,2*(N-2)}[zeros(zeros(Int, 2*(N-2))...)
for _ in eachindex(A), _ in eachindex(A)]
# size of the identity used where no partial product exists (before site 1, after the last)
d = first(bond_dims(A))
for t in 1:length(A)-1
lᵗ⁻¹ = t == 1 ? Matrix(I, d, d) : l[t-1]
Aᵗ = _reshape1(A[t])
for u in t+1:min(length(A),t+maxdist)
rᵘ⁺¹ = u == length(A) ? Matrix(I, d, d) : r[u+1]
Aᵘ = _reshape1(A[u])
Mᵗᵘ = M[t, u]
# single contraction of all five factors, leaving only the two physical indices
@tullio bᵗᵘ[xᵗ, xᵘ] :=
lᵗ⁻¹[a¹,aᵗ] * Aᵗ[aᵗ, aᵗ⁺¹, xᵗ] * Mᵗᵘ[aᵗ⁺¹, aᵘ] *
Aᵘ[aᵘ, aᵘ⁺¹, xᵘ] * rᵘ⁺¹[aᵘ⁺¹,a¹]
# normalize, then reshape to `qs`
bᵗᵘ ./= sum(bᵗᵘ)
b[t,u] = reshape(bᵗᵘ, qs)
end
end
b
end

"""
    normalization(A::PeriodicTensorTrain; l, r)

Return the normalization of `A`, computed as `tr(l[end])`.

### Optional arguments
- `l = accumulate_L(A)`, `r = accumulate_R(A)` pre-computed partial normalizations
"""
function normalization(A::PeriodicTensorTrain; l = accumulate_L(A), r = accumulate_R(A))
    # the last left partial product spans every site: its trace is the normalization
    z = tr(l[end])
    # sanity check: the first right partial product must give the same value up to rounding
    @assert tr(r[begin]) ≈ z "z=$z, got $(tr(r[begin])), A=$A"
    return z
end

function _compose(f, A::PeriodicTensorTrain{F,NA}, B::PeriodicTensorTrain{F,NB}) where {F,NA,NB}
@assert NA == NB
Expand All @@ -164,28 +78,6 @@ end

# Build a periodic tensor train from the `tensors` field of a `TensorTrain`
function PeriodicTensorTrain(A::TensorTrain)
    return PeriodicTensorTrain(A.tensors)
end

"""
    sample!(rng::AbstractRNG, x, A::PeriodicTensorTrain; r = accumulate_R(A))

Sample a configuration from `A` one site at a time, overwriting `x[t]` with the value drawn
at site `t`. Return the tuple `(x, p)`, where `p = tr(Q) / tr(first(r))` and `Q` is the
product of the matrices of `A` evaluated at the sampled values.

### Optional arguments
- `r = accumulate_R(A)` pre-computed partial normalizations
"""
function StatsBase.sample!(rng::AbstractRNG, x, A::PeriodicTensorTrain{F,N};
        r = accumulate_R(A)) where {F<:Real,N}
    L = length(A)
    @assert length(x) == L
    @assert all(length(xᵗ) == N-2 for xᵗ in x)
    d = first(bond_dims(A))

    # running product of the first `t` matrices, evaluated at the sampled `x¹,...,xᵗ`
    Q = Matrix(I, d, d)
    for t in eachindex(A)
        # collapse multivariate xᵗ into 1D vector, sample from it
        Aᵗ = _reshape1(A[t])
        # past the last site there is no partial product left: use the identity
        tail = t == L ? Matrix(I, d, d) : r[t+1]
        @tullio weights[y] := Q[i,m] * Aᵗ[m,j,y] * tail[j,i]
        weights ./= sum(weights)
        xᵗ = sample_noalloc(rng, weights)
        # turn the flat index back into one entry per physical index
        x[t] .= CartesianIndices(size(A[t])[3:end])[xᵗ] |> Tuple
        # update prob
        Q = Q * Aᵗ[:,:,xᵗ]
    end
    prob = tr(Q) / tr(first(r))
    return x, prob
end

function orthogonalize_right!(C::PeriodicTensorTrain; svd_trunc=TruncThresh(1e-6))
C⁰ = _reshape1(C[begin])
Expand Down
Loading

0 comments on commit 4b04d3c

Please sign in to comment.