-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
normalizing frames with custom mean_std instead of defaulting to resnet #316
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -465,6 +465,9 @@ def norm_minmax_sym(self): | |
tensor = 2 * tensor - 1.0 | ||
elif tensor.normalization == "255": | ||
tensor = 2 * (tensor / 255.0) - 1.0 | ||
elif tensor.mean_std is not None: | ||
tensor = tensor.norm01() | ||
tensor = 2 * tensor - 1.0 | ||
else: | ||
raise Exception(f"Can't convert from {tensor.normalization} to norm255") | ||
tensor.mean_std = None | ||
|
@@ -484,7 +487,7 @@ def mean_std_norm(self, mean, std, name) -> Frame: | |
""" | ||
tensor = self | ||
mean_tensor, std_tensor = self._get_mean_std_tensor( | ||
tensor.shape, tensor.names, tensor._resnet_mean_std, device=tensor.device | ||
tensor.shape, tensor.names, (mean, std), device=tensor.device | ||
) | ||
if tensor.normalization == "01": | ||
tensor = tensor - mean_tensor | ||
|
@@ -507,10 +510,27 @@ def mean_std_norm(self, mean, std, name) -> Frame: | |
|
||
return tensor | ||
|
||
def norm_meanstd(self, mean_std=None, name=None) -> Frame: | ||
"""Returns z-norm of the current frame. | ||
This method will simply call `frame.mean_std_norm()` with the mean/std property of the frame and the selected name. | ||
Instead of a custom mean/std, you can use the resnet norm based on the normalized use of resnet on pytorch. | ||
Examples | ||
-------- | ||
>>> frame_resnet = frame.norm_resnet(name="resnet") | ||
>>> frame_custom_norm = frame.norm_resnet(name="custom") | ||
""" | ||
if name == "resnet": | ||
return self.norm_resnet() | ||
elif mean_std is not None and len(mean_std) == 2: | ||
return self.mean_std_norm(mean=mean_std[0], std=mean_std[1], name=name) | ||
elif self.mean_std is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there's no difference between this condition and the one above. |
||
return self.mean_std_norm(mean=self.mean_std[0], std=self.mean_std[1], name=name) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The attribute |
||
else: | ||
raise Exception("Please pass a mean_std tuple or use the resnet norm") | ||
|
||
def norm_resnet(self) -> Frame: | ||
"""Normalized the current frame based on the normalized use on resnet on pytorch. This method will | ||
simply call `frame.mean_std_norm()` with the resnet mean/std and the name `resnet`. | ||
|
||
Examples | ||
-------- | ||
>>> frame_resnet = frame.norm_resnet() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having two methodes with nearly the same name is not a good practice. try to fix
mean_std_norm
instead