Skip to content

Commit

Permalink
generalize twovar_marginals to k-nearest niegs
Browse files Browse the repository at this point in the history
  • Loading branch information
stecrotti committed Oct 23, 2024
1 parent a3b92ae commit 23c8768
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/UniformTensorTrains/UniformTensorTrains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module UniformTensorTrains
using ..TensorTrains
using ..TensorTrains: _reshape1

using LinearAlgebra: LinearAlgebra, dot, tr
using LinearAlgebra: LinearAlgebra, dot, tr, I
using KrylovKit: eigsolve
using TensorCast: TensorCast, @cast
using TensorTrains: TensorTrains, AbstractPeriodicTensorTrain,
Expand Down
26 changes: 20 additions & 6 deletions src/UniformTensorTrains/uniform_tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,28 @@ function TensorTrains.marginals(A::InfiniteUniformTensorTrain; B = one_normaliza
return [m]
end

function TensorTrains.twovar_marginals(A::InfiniteUniformTensorTrain{F}; B = one_normalization(A)) where F
# to be consistent with the finite-T version, this returns a `maxdist+1`x`maxdist+1` matrix `m` where `m[t,t+Δt]` is the marginal at distance `Δt` for all `t`
function TensorTrains.twovar_marginals(A::InfiniteUniformTensorTrain{F,N};
maxdist::Integer=1, B = one_normalization(A)) where {F,N}
maxdist > -1 || throw(DomainError("maxdist must be non-negative, got $maxdist"))
_, l, r = _eigen(A; B)
iter = Iterators.product(axes(A.tensor)[3:end]...)
m = map(Iterators.product(iter, iter)) do (x1, x2)
l' * (@views A.tensor[:,:,x1...] * A.tensor[:,:,x2...]) * r
m = Array{F,2*(N-2)}[zeros(F, zeros(Int, 2*(N-2))...)
for _ in 1:maxdist+1, _ in 1:maxdist+1]
M = Matrix(1.0I, size(A.tensor, 1), size(A.tensor, 1))
Aᵗ = _reshape1(A.tensor)
for Δt in 1:maxdist
@tullio lAt[aᵗ, xᵗ] := l[bᵗ] * Aᵗ[bᵗ,aᵗ,xᵗ]
@tullio lAtM[bᵗ,xᵗ] := lAt[aᵗ, xᵗ] * M[aᵗ,bᵗ]
@tullio lAtMAu[cᵗ,xᵗ,xᵘ] := lAtM[bᵗ,xᵗ] * Aᵗ[bᵗ,cᵗ,xᵘ]
@tullio b[xᵗ, xᵘ] := lAtMAu[cᵗ,xᵗ,xᵘ] * r[cᵗ]
b ./= sum(b)
bᵗᵘ = reshape(real(b), (size(A.tensor)[3:end]..., size(A.tensor)[3:end]...)...)
for t in 1:(maxdist + 1 - Δt)
m[t,t+Δt] = bᵗᵘ
end
M = M * B
end
m ./= sum(m)
return [m]
return m
end

function TensorTrains.normalize_eachmatrix!(A::InfiniteUniformTensorTrain{F}) where {F}
Expand Down
6 changes: 5 additions & 1 deletion test/uniform_tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ end

@testset "Marginals" begin
marg = only(marginals(A))
two_marg = only(twovar_marginals(A))
@test_throws DomainError twovar_marginals(A; maxdist=-2)
tv = twovar_marginals(A; maxdist=3)
@test tv[1,2] == tv[2,3] == tv[3,4]
@test tv[1,3] == tv[2,4]
two_marg = tv[1,2]
N = ndims(two_marg)
N2 = N ÷ 2
@test sum(two_marg, dims=N2+1:N)[:,:,1,1] sum(two_marg, dims=1:N2)[1,1,:,:] marg
Expand Down

0 comments on commit 23c8768

Please sign in to comment.