diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index 473b545..603d15d 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -7,6 +7,7 @@ import math import warnings +from typing import Union import torch @@ -48,12 +49,12 @@ class ConditionalFlowMatcher: - score function $\nabla log p_t(x|x0, x1)$ """ - def __init__(self, sigma: float = 0.0): + def __init__(self, sigma: Union[float, int] = 0.0): r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$. Parameters ---------- - sigma : float + sigma : Union[float, int] """ self.sigma = sigma @@ -216,15 +217,15 @@ class ExactOptimalTransportConditionalFlowMatcher(ConditionalFlowMatcher): It overrides the sample_location_and_conditional_flow. """ - def __init__(self, sigma: float = 0.0): + def __init__(self, sigma: Union[float, int] = 0.0): r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$. Parameters ---------- - sigma : float + sigma : Union[float, int] ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]). """ - self.sigma = sigma + super().__init__(sigma) self.ot_sampler = OTPlanSampler(method="exact") def sample_location_and_conditional_flow(self, x0, x1, return_noise=False): @@ -383,13 +384,13 @@ class SchrodingerBridgeConditionalFlowMatcher(ConditionalFlowMatcher): sample_location_and_conditional_flow functions. """ - def __init__(self, sigma: float = 1.0, ot_method="exact"): + def __init__(self, sigma: Union[float, int] = 1.0, ot_method="exact"): r"""Initialize the SchrodingerBridgeConditionalFlowMatcher class. It requires the hyper- parameter $\sigma$ and the entropic OT map. Parameters ---------- - sigma : float + sigma : Union[float, int] ot_sampler: exact OT method to draw couplings (x0, x1) (see Eq.(17) [1]). we use exact as the default as we found this to perform better (more accurate and faster) in practice for reasonable batch sizes. @@ -400,7 +401,7 @@ def __init__(self, sigma: float = 1.0, ot_method="exact"): raise ValueError(f"Sigma must be strictly positive, got {sigma}.") elif sigma < 1e-3: warnings.warn("Small sigma values may lead to numerical instability.") - self.sigma = sigma + super().__init__(sigma) self.ot_method = ot_method self.ot_sampler = OTPlanSampler(method=ot_method, reg=2 * self.sigma**2)