diff --git a/vsaa/antialiasers/eedi3.py b/vsaa/antialiasers/eedi3.py index 2120eab..2d6efee 100644 --- a/vsaa/antialiasers/eedi3.py +++ b/vsaa/antialiasers/eedi3.py @@ -34,17 +34,17 @@ class EEDI3(_Antialiaser): opencl: bool = dc_field(default=False, kw_only=True) mclip: vs.VideoNode | None = None - sclip_aa: type[Antialiaser] | Antialiaser | Literal[True] | None = dc_field( + sclip_aa: type[Antialiaser] | Antialiaser | Literal[True] | vs.VideoNode | None = dc_field( default_factory=lambda: nnedi3.Nnedi3 ) def __post_init__(self) -> None: super().__post_init__() - self._sclip_aa: Antialiaser | Literal[True] | None + self._sclip_aa: Antialiaser | Literal[True] | vs.VideoNode | None - if self.sclip_aa and self.sclip_aa is not True and not isinstance(self.sclip_aa, Antialiaser): - self._sclip_aa = self.sclip_aa() + if self.sclip_aa and self.sclip_aa is not True and not isinstance(self.sclip_aa, (Antialiaser, vs.VideoNode)): + self._sclip_aa = self.sclip_aa() # type: ignore[operator] else: self._sclip_aa = self.sclip_aa # type: ignore[assignment] @@ -67,23 +67,7 @@ def get_aa_args(self, clip: vs.VideoNode, **kwargs: Any) -> dict[str, Any]: def interpolate(self, clip: vs.VideoNode, double_y: bool, **kwargs: Any) -> vs.VideoNode: aa_kwargs = self.get_aa_args(clip, **kwargs) - - if self._sclip_aa and ((('sclip' in kwargs) and not kwargs['sclip']) or 'sclip' not in kwargs): - if self._sclip_aa is True: - if double_y: - raise CustomValueError("You can't pass sclip_aa=True when supersampling!", self.__class__) - aa_kwargs.update(sclip=clip) - else: - sclip_args = self._sclip_aa.get_aa_args(clip, **(dict(mclip=kwargs.get('mclip', None)))) - - if double_y: - sclip_args |= self._sclip_aa.get_ss_args(clip) - else: - sclip_args |= self._sclip_aa.get_sr_args(clip) - - aa_kwargs.update( - sclip=self._sclip_aa.interpolate(clip, double_y or not self.drop_fields, **sclip_args) - ) + aa_kwargs = self._handle_sclip(clip, double_y, aa_kwargs, **kwargs) function = core.eedi3m.EEDI3CL if self.opencl else core.eedi3m.EEDI3 @@ -93,6 +77,35 @@ def interpolate(self, clip: vs.VideoNode, double_y: bool, **kwargs: Any) -> vs.V return self.shift_interpolate(clip, interpolated, double_y, **kwargs) + def _handle_sclip( + self, clip: vs.VideoNode, double_y: bool, aa_kwargs: dict[str, Any], **kwargs: Any + ) -> dict[str, Any]: + if not self._sclip_aa or ('sclip' in kwargs and kwargs['sclip']): + return aa_kwargs + + if self._sclip_aa is True or isinstance(self._sclip_aa, vs.VideoNode): + if double_y: + if self._sclip_aa is True: + raise CustomValueError("You can't pass sclip_aa=True when supersampling!", self.__class__) + + if (clip.width, clip.height) != (self._sclip_aa.width, self._sclip_aa.height): + raise CustomValueError( + f'The dimensions of sclip_aa ({self._sclip_aa.width}x{self._sclip_aa.height}) ' + f'don\'t match the expected dimensions ({clip.width}x{clip.height})!', + self.__class__ + ) + + aa_kwargs.update(dict(sclip=clip)) + + return aa_kwargs + + sclip_args = self._sclip_aa.get_aa_args(clip, mclip=kwargs.get('mclip')) + sclip_args.update(self._sclip_aa.get_ss_args(clip) if double_y else self._sclip_aa.get_sr_args(clip)) + + aa_kwargs.update(dict(sclip=self._sclip_aa.interpolate(clip, double_y or not self.drop_fields, **sclip_args))) + + return aa_kwargs + _shift = 0.5 @inject_self.property