diff --git a/aloscene/frame.py b/aloscene/frame.py index 304fc6fd..ba71711a 100644 --- a/aloscene/frame.py +++ b/aloscene/frame.py @@ -465,6 +465,9 @@ def norm_minmax_sym(self): tensor = 2 * tensor - 1.0 elif tensor.normalization == "255": tensor = 2 * (tensor / 255.0) - 1.0 + elif tensor.mean_std is not None: + tensor = tensor.norm01() + tensor = 2 * tensor - 1.0 else: raise Exception(f"Can't convert from {tensor.normalization} to norm255") tensor.mean_std = None @@ -484,7 +487,7 @@ def mean_std_norm(self, mean, std, name) -> Frame: """ tensor = self mean_tensor, std_tensor = self._get_mean_std_tensor( - tensor.shape, tensor.names, tensor._resnet_mean_std, device=tensor.device + tensor.shape, tensor.names, (mean, std), device=tensor.device ) if tensor.normalization == "01": tensor = tensor - mean_tensor @@ -507,10 +510,27 @@ def mean_std_norm(self, mean, std, name) -> Frame: return tensor + def norm_meanstd(self, mean_std=None, name=None) -> Frame: + """Returns z-norm of the current frame. + This method will simply call `frame.mean_std_norm()` with the mean/std property of the frame and the selected name. + Instead of a custom mean/std, you can use the resnet norm based on the normalized use of resnet on pytorch. + Examples + -------- + >>> frame_resnet = frame.norm_resnet(name="resnet") + >>> frame_custom_norm = frame.norm_resnet(name="custom") + """ + if name == "resnet": + return self.norm_resnet() + elif mean_std is not None and len(mean_std) == 2: + return self.mean_std_norm(mean=mean_std[0], std=mean_std[1], name=name) + elif self.mean_std is not None: + return self.mean_std_norm(mean=self.mean_std[0], std=self.mean_std[1], name=name) + else: + raise Exception("Please pass a mean_std tuple or use the resnet norm") + def norm_resnet(self) -> Frame: """Normalized the current frame based on the normalized use on resnet on pytorch. This method will simply call `frame.mean_std_norm()` with the resnet mean/std and the name `resnet`. - Examples -------- >>> frame_resnet = frame.norm_resnet()