From 1b60fe6a3375d33785b99f32422e278c43e5cce8 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Fri, 29 Nov 2024 18:30:28 +0200 Subject: [PATCH 1/3] Fix typing for RandomErasing --- torchvision/transforms/v2/_augment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index b1dd5083408..4fce2e756c4 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -58,7 +58,7 @@ def __init__( p: float = 0.5, scale: Sequence[float] = (0.02, 0.33), ratio: Sequence[float] = (0.3, 3.3), - value: float = 0.0, + value: Union[float, Sequence[float]] = 0.0, inplace: bool = False, ): super().__init__(p=p) From f8d84ed84e9719f11b370bc22b3671a3972e086d Mon Sep 17 00:00:00 2001 From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com> Date: Tue, 3 Dec 2024 12:26:21 +0200 Subject: [PATCH 2/3] Trying to fix mypy error --- torchvision/transforms/v2/_augment.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 4fce2e756c4..8e9f8ce13f4 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -77,13 +77,13 @@ def __init__( self.scale = scale self.ratio = ratio if isinstance(value, (int, float)): - self.value = [float(value)] + value = [float(value)] elif isinstance(value, str): - self.value = None + value = None elif isinstance(value, (list, tuple)): - self.value = [float(v) for v in value] - else: - self.value = value + value = [float(v) for v in value] + + self.value: Optional[Sequence[float]] = value self.inplace = inplace self._log_ratio = torch.log(torch.tensor(self.ratio)) From dcf5319c3722fdf4828d99c0505c882de52ee0a1 Mon Sep 17 00:00:00 2001 From: ancestor-mithril <58839912+ancestor-mithril@users.noreply.github.com> Date: Tue, 3 Dec 2024 13:03:46 +0200 Subject: [PATCH 3/3] Removed space --- torchvision/transforms/v2/_augment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 8e9f8ce13f4..d5fb8a894bc 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -82,7 +82,6 @@ def __init__( value = None elif isinstance(value, (list, tuple)): value = [float(v) for v in value] - self.value: Optional[Sequence[float]] = value self.inplace = inplace