From b12f8fa15ae7d3b38eba569d026175805c84b495 Mon Sep 17 00:00:00 2001 From: guillaumehu Date: Mon, 13 Nov 2023 14:12:05 +0100 Subject: [PATCH] property for sigma int/float --- torchcfm/conditional_flow_matching.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torchcfm/conditional_flow_matching.py b/torchcfm/conditional_flow_matching.py index 3e518ef..99a9742 100644 --- a/torchcfm/conditional_flow_matching.py +++ b/torchcfm/conditional_flow_matching.py @@ -54,7 +54,16 @@ def __init__(self, sigma: float = 0.0): ---------- sigma : float """ - self.sigma = sigma + self._sigma = sigma + + @property + def sigma(self): + if isinstance(self._sigma, float): + return self._sigma + elif isinstance(self._sigma, int): + return float(self._sigma) + else: + raise ValueError("Sigma must be a float or int.") def compute_mu_t(self, x0, x1, t): """