Skip to content

Commit

Permalink
Merge pull request #80 from atong01/add_tests
Browse files Browse the repository at this point in the history
Add tests on lambda_t, low sigma values and guided functions
  • Loading branch information
kilianFatras authored Nov 30, 2023
2 parents e96c52c + 3f08702 commit 3c0cfdd
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,10 @@ jobs:
pip install -e .
- name: Run tests and collect coverage
run: pytest . --cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/
run: pytest . --cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/ --cov-fail-under=30

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
name: codecov-torchcfm
verbose: true
fail_ci_if_error: true
3 changes: 1 addition & 2 deletions .github/workflows/test_runner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ jobs:
pip install -e .
- name: Run tests and collect coverage
run: pytest runner --cov runner # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
run: pytest runner --cov runner --cov-fail-under=30 # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
name: codecov-runner
verbose: true
fail_ci_if_error: true
3 changes: 2 additions & 1 deletion tests/test_conditional_flow_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def sample_plan(method, x0, x1, sigma):

@pytest.mark.parametrize("method", ["vp_cfm", "t_cfm", "sb_cfm", "exact_ot_cfm", "i_cfm"])
# Test both integer and floating sigma
@pytest.mark.parametrize("sigma", [0.0, 0.5, 1.5, 0, 1])
@pytest.mark.parametrize("sigma", [0.0, 5e-4, 0.5, 1.5, 0, 1])
@pytest.mark.parametrize("shape", [[1], [2], [1, 2], [3, 4, 5]])
def test_fm(method, sigma, shape):
batch_size = TEST_BATCH_SIZE
Expand All @@ -107,6 +107,7 @@ def test_fm(method, sigma, shape):
torch.manual_seed(TEST_SEED)
np.random.seed(TEST_SEED)
t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True)
_ = FM.compute_lambda(t)

if method in ["sb_cfm", "exact_ot_cfm"]:
torch.manual_seed(TEST_SEED)
Expand Down
15 changes: 8 additions & 7 deletions tests/test_time_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def test_random_Tensor_t(FM):
SchrodingerBridgeConditionalFlowMatcher(sigma=0.1),
],
)
def test_guided_random_Tensor_t(FM):
@pytest.mark.parametrize("return_noise", [True, False])
def test_guided_random_Tensor_t(FM, return_noise):
# Test guided_sample_location_and_conditional_flow functions
x0 = torch.randn(batch_size, 2)
y0 = torch.randint(high=10, size=(batch_size, 1))
Expand All @@ -58,13 +59,13 @@ def test_guided_random_Tensor_t(FM):

torch.manual_seed(seed)
t_given = torch.rand(batch_size)
t_given, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=t_given
)
t_given = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=t_given, return_noise=return_noise
)[0]

torch.manual_seed(seed)
t_random, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=None
)
t_random = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=None, return_noise=return_noise
)[0]

assert any(t_given == t_random)

0 comments on commit 3c0cfdd

Please sign in to comment.