Skip to content

Commit

Permalink
change test name folder
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Nov 14, 2023
1 parent 6acfca7 commit 8fbf8f7
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
File renamed without changes.
70 changes: 70 additions & 0 deletions tests/test_time_t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Tests for time Tensor t."""

# Author: Kilian Fatras <[email protected]>

import pytest
import torch

from torchcfm.conditional_flow_matching import (
ConditionalFlowMatcher,
ExactOptimalTransportConditionalFlowMatcher,
SchrodingerBridgeConditionalFlowMatcher,
TargetConditionalFlowMatcher,
VariancePreservingConditionalFlowMatcher,
)

seed = 1994
batch_size = 128


@pytest.mark.parametrize(
"FM",
[
ConditionalFlowMatcher(sigma=0.0),
ExactOptimalTransportConditionalFlowMatcher(sigma=0.0),
TargetConditionalFlowMatcher(sigma=0.0),
SchrodingerBridgeConditionalFlowMatcher(sigma=0.0),
VariancePreservingConditionalFlowMatcher(sigma=0.0),
],
)
def test_random_Tensor_t(FM):
# Test sample_location_and_conditional_flow functions
x0 = torch.randn(batch_size, 2)
x1 = torch.randn(batch_size, 2)

torch.manual_seed(seed)
t_given = torch.rand(batch_size)
t_given, xt, ut = FM.sample_location_and_conditional_flow(x0, x1, t=t_given)

torch.manual_seed(seed)
t_random, xt, ut = FM.sample_location_and_conditional_flow(x0, x1, t=None)

assert any(t_given == t_random)


@pytest.mark.parametrize(
"FM",
[
ExactOptimalTransportConditionalFlowMatcher(sigma=0.0),
SchrodingerBridgeConditionalFlowMatcher(sigma=0.0),
],
)
def test_guided_random_Tensor_t(FM):
# Test guided_sample_location_and_conditional_flow functions
x0 = torch.randn(batch_size, 2)
y0 = torch.randint(high=10, size=(batch_size, 1))
x1 = torch.randn(batch_size, 2)
y1 = torch.randint(high=10, size=(batch_size, 1))

torch.manual_seed(seed)
t_given = torch.rand(batch_size)
t_given, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=t_given
)

torch.manual_seed(seed)
t_random, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow(
x0, x1, y0=y0, y1=y1, t=None
)

assert any(t_given == t_random)

0 comments on commit 8fbf8f7

Please sign in to comment.