From b319b480535c8ac40699984589a7ad53ca0dff85 Mon Sep 17 00:00:00 2001 From: Kilian Date: Thu, 30 Nov 2023 10:27:28 -0500 Subject: [PATCH 1/5] add return_noise test for guided functions --- tests/test_time_t.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/tests/test_time_t.py b/tests/test_time_t.py index 89d9499..e4c5e46 100644 --- a/tests/test_time_t.py +++ b/tests/test_time_t.py @@ -49,22 +49,35 @@ def test_random_Tensor_t(FM): SchrodingerBridgeConditionalFlowMatcher(sigma=0.1), ], ) -def test_guided_random_Tensor_t(FM): +@pytest.mark.parametrize("return_noise", [True, False]) +def test_guided_random_Tensor_t(FM, return_noise): # Test guided_sample_location_and_conditional_flow functions x0 = torch.randn(batch_size, 2) y0 = torch.randint(high=10, size=(batch_size, 1)) x1 = torch.randn(batch_size, 2) y1 = torch.randint(high=10, size=(batch_size, 1)) - torch.manual_seed(seed) - t_given = torch.rand(batch_size) - t_given, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow( - x0, x1, y0=y0, y1=y1, t=t_given - ) + if return_noise: + torch.manual_seed(seed) + t_given = torch.rand(batch_size) + t_given, xt, ut, y0, y1, eps = FM.guided_sample_location_and_conditional_flow( + x0, x1, y0=y0, y1=y1, t=t_given, return_noise=return_noise + ) - torch.manual_seed(seed) - t_random, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow( - x0, x1, y0=y0, y1=y1, t=None - ) + torch.manual_seed(seed) + t_random, xt, ut, y0, y1, eps = FM.guided_sample_location_and_conditional_flow( + x0, x1, y0=y0, y1=y1, t=None, return_noise=return_noise + ) + else: + torch.manual_seed(seed) + t_given = torch.rand(batch_size) + t_given, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow( + x0, x1, y0=y0, y1=y1, t=t_given + ) + + torch.manual_seed(seed) + t_random, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow( + x0, x1, y0=y0, y1=y1, t=None + ) assert any(t_given == t_random) From 2ee5088e51d3fc04e207ad0e2b267ee572942c63 Mon Sep 17 00:00:00 2001 From: Kilian Date: Thu, 30 Nov 2023 10:27:52 -0500 Subject: [PATCH 2/5] add low sigma value to get warning --- tests/test_conditional_flow_matcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_conditional_flow_matcher.py b/tests/test_conditional_flow_matcher.py index 6a12d4e..2c00a1e 100644 --- a/tests/test_conditional_flow_matcher.py +++ b/tests/test_conditional_flow_matcher.py @@ -92,7 +92,7 @@ def sample_plan(method, x0, x1, sigma): @pytest.mark.parametrize("method", ["vp_cfm", "t_cfm", "sb_cfm", "exact_ot_cfm", "i_cfm"]) # Test both integer and floating sigma -@pytest.mark.parametrize("sigma", [0.0, 0.5, 1.5, 0, 1]) +@pytest.mark.parametrize("sigma", [0.0, 5e-4, 0.5, 1.5, 0, 1]) @pytest.mark.parametrize("shape", [[1], [2], [1, 2], [3, 4, 5]]) def test_fm(method, sigma, shape): batch_size = TEST_BATCH_SIZE From ea6e5d01f419628291e918af9a73a422dd853a85 Mon Sep 17 00:00:00 2001 From: Kilian Date: Thu, 30 Nov 2023 10:31:24 -0500 Subject: [PATCH 3/5] compute lambda_t in test --- tests/test_conditional_flow_matcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_conditional_flow_matcher.py b/tests/test_conditional_flow_matcher.py index 2c00a1e..b080470 100644 --- a/tests/test_conditional_flow_matcher.py +++ b/tests/test_conditional_flow_matcher.py @@ -107,6 +107,7 @@ def test_fm(method, sigma, shape): torch.manual_seed(TEST_SEED) np.random.seed(TEST_SEED) t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True) + _ = FM.compute_lambda(t) if method in ["sb_cfm", "exact_ot_cfm"]: torch.manual_seed(TEST_SEED) From d0202a8bb018a0b366f1368a8704d3890d6042b8 Mon Sep 17 00:00:00 2001 From: Alex Tong Date: Thu, 30 Nov 2023 11:05:06 -0500 Subject: [PATCH 4/5] Cleanup test --- tests/test_time_t.py | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/tests/test_time_t.py b/tests/test_time_t.py index e4c5e46..d79ef7e 100644 --- a/tests/test_time_t.py +++ b/tests/test_time_t.py @@ -57,27 +57,15 @@ def test_guided_random_Tensor_t(FM, return_noise): x1 = torch.randn(batch_size, 2) y1 = torch.randint(high=10, size=(batch_size, 1)) - if return_noise: - torch.manual_seed(seed) - t_given = torch.rand(batch_size) - t_given, xt, ut, y0, y1, eps = FM.guided_sample_location_and_conditional_flow( - x0, x1, y0=y0, y1=y1, t=t_given, return_noise=return_noise - ) - - torch.manual_seed(seed) - t_random, xt, ut, y0, y1, eps = FM.guided_sample_location_and_conditional_flow( - x0, x1, y0=y0, y1=y1, t=None, return_noise=return_noise - ) - else: - torch.manual_seed(seed) - t_given = torch.rand(batch_size) - t_given, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow( - x0, x1, y0=y0, y1=y1, t=t_given - ) + torch.manual_seed(seed) + t_given = torch.rand(batch_size) + t_given = FM.guided_sample_location_and_conditional_flow( + x0, x1, y0=y0, y1=y1, t=t_given, return_noise=return_noise + )[0] - torch.manual_seed(seed) - t_random, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow( - x0, x1, y0=y0, y1=y1, t=None - ) + torch.manual_seed(seed) + t_random = FM.guided_sample_location_and_conditional_flow( + x0, x1, y0=y0, y1=y1, t=None, return_noise=return_noise + )[0] assert any(t_given == t_random) From 3f08702f44ac1e9f71f758ddb881a470d054393d Mon Sep 17 00:00:00 2001 From: Alex Tong Date: Thu, 30 Nov 2023 11:13:23 -0500 Subject: [PATCH 5/5] Cleanup coverage workflow configuration --- .github/workflows/test.yaml | 3 +-- .github/workflows/test_runner.yaml | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index c68097d..23e57cb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -62,11 +62,10 @@ jobs: pip install -e . - name: Run tests and collect coverage - run: pytest . --cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/ + run: pytest . --cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/ --cov-fail-under=30 - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: name: codecov-torchcfm verbose: true - fail_ci_if_error: true diff --git a/.github/workflows/test_runner.yaml b/.github/workflows/test_runner.yaml index f840b04..e901d93 100644 --- a/.github/workflows/test_runner.yaml +++ b/.github/workflows/test_runner.yaml @@ -64,11 +64,10 @@ jobs: pip install -e . - name: Run tests and collect coverage - run: pytest runner --cov runner # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER + run: pytest runner --cov runner --cov-fail-under=30 # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: name: codecov-runner verbose: true - fail_ci_if_error: true