Skip to content

Commit

Permalink
normalize last matrix as well
Browse files Browse the repository at this point in the history
stecrotti committed May 27, 2024
1 parent 49d690a commit 820a7ad
Showing 3 changed files with 21 additions and 22 deletions.
21 changes: 10 additions & 11 deletions src/periodic_tensor_train.jl
Original file line number Diff line number Diff line change
@@ -91,16 +91,16 @@ function orthogonalize_right!(C::PeriodicTensorTrain{F}; svd_trunc=TruncThresh(1
c = Logarithmic(one(F))

for t in length(C):-1:2
mt = maximum(abs, M)
if !isnan(mt) && !isinf(mt) && !iszero(mt)
M ./= mt
c *= mt
end
U, λ, V = svd_trunc(M)
@cast Aᵗ[m, n, x] := V'[m, (n, x)] x 1:q
C[t] = _reshapeas(Aᵗ, C[t])
Cᵗ⁻¹ = _reshape1(C[t-1])
@tullio D[m, n, x] := Cᵗ⁻¹[m, k, x] * U[k, n] * λ[n]
m = maximum(abs, D)
if !isnan(m) && !isinf(m) && !iszero(m)
D ./= m
c *= m
end
@cast M[m, (n, x)] := D[m, n, x]
end
C[begin] = _reshapeas(D, C[begin])
@@ -115,18 +115,17 @@ function orthogonalize_left!(A::PeriodicTensorTrain{F}; svd_trunc=TruncThresh(1e
D = fill(1.0,1,1,1) # initialize
c = Logarithmic(one(F))


for t in 1:length(A)-1
mt = maximum(abs, M)
if !isnan(mt) && !isinf(mt) && !iszero(mt)
M ./= mt
c *= mt
end
U, λ, V = svd_trunc(M)
@cast Aᵗ[m, n, x] := U[(m, x), n] x 1:q
A[t] = _reshapeas(Aᵗ, A[t])
Aᵗ⁺¹ = _reshape1(A[t+1])
@tullio D[m, n, x] := λ[m] * V'[m, l] * Aᵗ⁺¹[l, n, x]
m = maximum(abs, D)
if !isnan(m) && !isinf(m) && !iszero(m)
D ./= m
c *= m
end
@cast M[(m, x), n] |= D[m, n, x]
end
U, λ, V = svd_trunc(M)
20 changes: 10 additions & 10 deletions src/tensor_train.jl
Original file line number Diff line number Diff line change
@@ -76,16 +76,16 @@ function orthogonalize_right!(C::TensorTrain{F}; svd_trunc=TruncThresh(1e-6)) wh
c = Logarithmic(one(F))

for t in length(C):-1:2
mt = maximum(abs, M)
if !isnan(mt) && !isinf(mt) && !iszero(mt)
M ./= mt
c *= mt
end
U, λ, V = svd_trunc(M)
@cast Aᵗ[m, n, x] := V'[m, (n, x)] x 1:q
C[t] = _reshapeas(Aᵗ, C[t])
Cᵗ⁻¹ = _reshape1(C[t-1])
@tullio D[m, n, x] := Cᵗ⁻¹[m, k, x] * U[k, n] * λ[n]
m = maximum(abs, D)
if !isnan(m) && !isinf(m) && !iszero(m)
D ./= m
c *= m
end
@cast M[m, (n, x)] := D[m, n, x]
end
C[begin] = _reshapeas(D, C[begin])
@@ -108,16 +108,16 @@ function orthogonalize_left!(C::TensorTrain{F}; svd_trunc=TruncThresh(1e-6)) whe
c = Logarithmic(one(F))

for t in 1:length(C)-1
mt = maximum(abs, M)
if !isnan(mt) && !isinf(mt) && !iszero(mt)
M ./= mt
c *= mt
end
U, λ, V = svd_trunc(M)
@cast Aᵗ[m, n, x] := U[(m, x), n] x 1:q
C[t] = _reshapeas(Aᵗ, C[t])
Cᵗ⁺¹ = _reshape1(C[t+1])
@tullio D[m, n, x] := λ[m] * V'[m, l] * Cᵗ⁺¹[l, n, x]
m = maximum(abs, D)
if !isnan(m) && !isinf(m) && !iszero(m)
D ./= m
c *= m
end
@cast M[(m, x), n] |= D[m, n, x]
end
C[end] = _reshapeas(D, C[end])
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
@@ -14,6 +14,6 @@ function sample_noalloc(rng::AbstractRNG, w)
i += 1
cw > t && return i
end
@assert false
@assert false "$w"

Check warning on line 17 in src/utils.jl

Codecov / codecov/patch

src/utils.jl#L17

Added line #L17 was not covered by tests
end
sample_noalloc(w) = sample_noalloc(GLOBAL_RNG, w)

0 comments on commit 820a7ad

Please sign in to comment.