Skip to content

Commit

Permalink
more detailed benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
stecrotti committed Oct 17, 2023
1 parent 5a74cde commit ae92ea3
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
using BenchmarkTools
using TensorTrains
import TensorTrains: accumulate_L, accumulate_R

SUITE = BenchmarkGroup()

L = 10
qs = (2, 5, 3)
d = 10
q = rand_tt(d, L, qs...)
p = rand_periodic_tt(d, L+1, qs...)

SUITE["accumulators"] = BenchmarkGroup()
SUITE["accumulators"]["accumL_tensortrain"] = @benchmarkable accumulate_L($q)
# SUITE["accumulators"]["accumR_tensortrain"] = @benchmarkable accumulate_R($q)
SUITE["accumulators"]["accumL_periodic"] = @benchmarkable accumulate_L($p)
# SUITE["accumulators"]["accumR_periodic"] = @benchmarkable accumulate_R($p)

SUITE["marginals"] = BenchmarkGroup()
SUITE["marginals"]["marginals_tensortrain"] = @benchmarkable marginals($q)
SUITE["marginals"]["marginals_periodic"] = @benchmarkable marginals($p)

SUITE["twovar_marginals"] = BenchmarkGroup()
SUITE["marginals"]["twovar_marginals_tensortrain"] = @benchmarkable twovar_marginals($q)
SUITE["marginals"]["twovar_marginals_periodic"] = @benchmarkable twovar_marginals($p)

SUITE["orthogonalize"] = BenchmarkGroup()
SUITE["orthogonalize"]["orth_left_tensortrain"] = @benchmarkable orthogonalize_left!($q)
svd_trunc = TruncThresh(0.0)
SUITE["orthogonalize"]["orth_left_tensortrain"] = @benchmarkable orthogonalize_left!($q; svd_trunc=$svd_trunc)
SUITE["orthogonalize"]["orth_left_periodic"] = @benchmarkable orthogonalize_left!($p; svd_trunc=$svd_trunc)

SUITE["sampling"] = BenchmarkGroup()
x = [[rand(1:q) for q in qs] for _ in 1:L]
Expand All @@ -25,7 +37,10 @@ function nsamples!(x, q, n)
end
end
SUITE["sampling"]["sample_tensortrain"] = @benchmarkable nsamples!($x, $q, 20)
SUITE["sampling"]["sample_periodic"] = @benchmarkable nsamples!($x, $p, 20)

SUITE["dot"] = BenchmarkGroup()
q2 = rand_tt(d, L, qs...)
SUITE["dot"]["dot_tensortrain"] = @benchmarkable dot($q, $q2)
p2 = rand_periodic_tt(d, L, qs...)
SUITE["dot"]["dot_periodic"] = @benchmarkable dot($p, $p2)

0 comments on commit ae92ea3

Please sign in to comment.