Skip to content

Commit

Permalink
Eedi3: Allow VideoNode to be passed to sclip_aa, separate logic (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
LightArrowsEXE authored Oct 7, 2024
1 parent 40dc819 commit 3f9b79c
Showing 1 changed file with 34 additions and 21 deletions.
55 changes: 34 additions & 21 deletions vsaa/antialiasers/eedi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,17 @@ class EEDI3(_Antialiaser):
opencl: bool = dc_field(default=False, kw_only=True)

mclip: vs.VideoNode | None = None
sclip_aa: type[Antialiaser] | Antialiaser | Literal[True] | None = dc_field(
sclip_aa: type[Antialiaser] | Antialiaser | Literal[True] | vs.VideoNode | None = dc_field(
default_factory=lambda: nnedi3.Nnedi3
)

def __post_init__(self) -> None:
super().__post_init__()

self._sclip_aa: Antialiaser | Literal[True] | None
self._sclip_aa: Antialiaser | Literal[True] | vs.VideoNode | None

if self.sclip_aa and self.sclip_aa is not True and not isinstance(self.sclip_aa, Antialiaser):
self._sclip_aa = self.sclip_aa()
if self.sclip_aa and self.sclip_aa is not True and not isinstance(self.sclip_aa, (Antialiaser, vs.VideoNode)):
self._sclip_aa = self.sclip_aa() # type: ignore[operator]
else:
self._sclip_aa = self.sclip_aa # type: ignore[assignment]

Expand All @@ -67,23 +67,7 @@ def get_aa_args(self, clip: vs.VideoNode, **kwargs: Any) -> dict[str, Any]:

def interpolate(self, clip: vs.VideoNode, double_y: bool, **kwargs: Any) -> vs.VideoNode:
aa_kwargs = self.get_aa_args(clip, **kwargs)

if self._sclip_aa and ((('sclip' in kwargs) and not kwargs['sclip']) or 'sclip' not in kwargs):
if self._sclip_aa is True:
if double_y:
raise CustomValueError("You can't pass sclip_aa=True when supersampling!", self.__class__)
aa_kwargs.update(sclip=clip)
else:
sclip_args = self._sclip_aa.get_aa_args(clip, **(dict(mclip=kwargs.get('mclip', None))))

if double_y:
sclip_args |= self._sclip_aa.get_ss_args(clip)
else:
sclip_args |= self._sclip_aa.get_sr_args(clip)

aa_kwargs.update(
sclip=self._sclip_aa.interpolate(clip, double_y or not self.drop_fields, **sclip_args)
)
aa_kwargs = self._handle_sclip(clip, double_y, aa_kwargs, **kwargs)

function = core.eedi3m.EEDI3CL if self.opencl else core.eedi3m.EEDI3

Expand All @@ -93,6 +77,35 @@ def interpolate(self, clip: vs.VideoNode, double_y: bool, **kwargs: Any) -> vs.V

return self.shift_interpolate(clip, interpolated, double_y, **kwargs)

def _handle_sclip(
self, clip: vs.VideoNode, double_y: bool, aa_kwargs: dict[str, Any], **kwargs: Any
) -> dict[str, Any]:
if not self._sclip_aa or ('sclip' in kwargs and kwargs['sclip']):
return aa_kwargs

if self._sclip_aa is True or isinstance(self._sclip_aa, vs.VideoNode):
if double_y:
if self._sclip_aa is True:
raise CustomValueError("You can't pass sclip_aa=True when supersampling!", self.__class__)

if (clip.width, clip.height) != (self._sclip_aa.width, self._sclip_aa.height):
raise CustomValueError(
f'The dimensions of sclip_aa ({self._sclip_aa.width}x{self._sclip_aa.height}) '
f'don\'t match the expected dimensions ({clip.width}x{clip.height})!',
self.__class__
)

aa_kwargs.update(dict(sclip=clip))

return aa_kwargs

sclip_args = self._sclip_aa.get_aa_args(clip, mclip=kwargs.get('mclip'))
sclip_args.update(self._sclip_aa.get_ss_args(clip) if double_y else self._sclip_aa.get_sr_args(clip))

aa_kwargs.update(dict(sclip=self._sclip_aa.interpolate(clip, double_y or not self.drop_fields, **sclip_args)))

return aa_kwargs

_shift = 0.5

@inject_self.property
Expand Down

0 comments on commit 3f9b79c

Please sign in to comment.