diff --git a/tests/test_graph.py b/tests/test_graph.py index 14fe15b..40c9e2b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,3 +1,4 @@ +import shutil import torch import pytest import numpy as np @@ -434,7 +435,9 @@ def test_GraphSeqStats(): seq_stats = GraphStats(classes=classes, bin_size=1) seq_stats.add_batch(pred, ref) - # seq_stats.plot('./seq_stats') - # seq_stats.report('./seq_stats') + outdir = './test_stats' + seq_stats.plot(outdir) + seq_stats.report(outdir) + shutil.rmtree(outdir) # fmt:on