Skip to content

Commit

Permalink
Merge branch 'main' into add_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
atong01 authored Nov 23, 2023
2 parents c0a8dd3 + 7cb209d commit 77bdc91
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions torchcfm/conditional_flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import math
import warnings
from typing import Union

import torch

Expand Down Expand Up @@ -48,12 +49,12 @@ class ConditionalFlowMatcher:
- score function $\nabla log p_t(x|x0, x1)$
"""

def __init__(self, sigma: float = 0.0):
def __init__(self, sigma: Union[float, int] = 0.0):
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.
Parameters
----------
sigma : float
sigma : Union[float, int]
"""
self.sigma = sigma

Expand Down Expand Up @@ -216,15 +217,15 @@ class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher):
It overrides the sample_location_and_conditional_flow.
"""

def __init__(self, sigma: float = 0.0):
def __init__(self, sigma: Union[float, int] = 0.0):
r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.
Parameters
----------
sigma : float
sigma : Union[float, int]
ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]).
"""
self.sigma = sigma
super().__init__(sigma)
self.ot_sampler = OTPlanSampler(method="exact")

def sample_location_and_conditional_flow(self, x0, x1, return_noise=False):
Expand Down Expand Up @@ -383,13 +384,13 @@ class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher):
sample_location_and_conditional_flow functions.
"""

def __init__(self, sigma: float = 1.0, ot_method="exact"):
def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"):
r"""Initialize the SchrodingerBridgeConditionalFlowMatcher class. It requires the hyper-
parameter $\sigma$ and the entropic OT map.
Parameters
----------
sigma : float
sigma : Union[float, int]
ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]).
we use exact as the default as we found this to perform better
(more accurate and faster) in practice for reasonable batch sizes.
Expand All @@ -400,7 +401,7 @@ def __init__(self, sigma: float = 1.0, ot_method="exact"):
raise ValueError(f"Sigma must be strictly positive, got {sigma}.")
elif sigma < 1e-3:
warnings.warn("Small sigma values may lead to numerical instability.")

Check warning on line 403 in torchcfm/conditional_flow_matching.py

View check run for this annotation

Codecov / codecov/patch

torchcfm/conditional_flow_matching.py#L403

Added line #L403 was not covered by tests
self.sigma = sigma
super().__init__(sigma)
self.ot_method = ot_method
self.ot_sampler = OTPlanSampler(method=ot_method, reg=2 * self.sigma**2)

Expand Down

0 comments on commit 77bdc91

Please sign in to comment.