diff --git a/vsaa/funcs.py b/vsaa/funcs.py index 1e29f49..a411709 100644 --- a/vsaa/funcs.py +++ b/vsaa/funcs.py @@ -7,7 +7,7 @@ from vsexprtools import complexpr_available, norm_expr from vskernels import Bilinear, Box, Catrom, NoScale, Scaler, ScalerT from vsmasktools import EdgeDetect, EdgeDetectT, Prewitt, ScharrTCanny -from vsrgtools import MeanMode, RepairMode, box_blur, contrasharpening_median, repair, unsharp_masked +from vsrgtools import MeanMode, RepairMode, bilateral, box_blur, contrasharpening_median, repair, unsharp_masked from vstools import ( MISSING, CustomRuntimeError, CustomValueError, FormatsMismatchError, FunctionUtil, KwargsT, MissingT, PlanesT, VSFunction, get_h, get_peak_value, get_w, get_y, join, normalize_planes, plane, scale_8bit, scale_value, split, vs @@ -308,7 +308,7 @@ def based_aa( downscaler: ScalerT | None = None, supersampler: ScalerT | ShaderFile | Path | Literal[False] = ArtCNN.C16F64, eedi3_kwargs: KwargsT | None = dict(alpha=0.125, beta=0.25, vthresh0=12, vthresh1=24, field=1), - prefilter: vs.VideoNode | VSFunction | None = None, postfilter: VSFunction | None = None, + prefilter: vs.VideoNode | VSFunction | None = None, postfilter: VSFunction | None | Literal[False] = None, show_mask: bool = False, planes: PlanesT = 0, **kwargs: Any ) -> vs.VideoNode: @@ -346,8 +346,13 @@ def based_aa( introducing too much ringing. Default: ArtCNN.C16F64. :param eedi3_kwargs: Keyword arguments to pass on to EEDI3. - :param prefilter: Prefilter to apply before anti-aliasing. Default: None. - :param postfilter: Postfilter to apply after anti-aliasing. Default: None. + :param prefilter: Prefilter to apply before anti-aliasing. + Must be a VideoNode, a function that takes a VideoNode and returns a VideoNode, + or None. Default: None. + :param postfilter: Postfilter to apply after anti-aliasing. + Must be a function that takes a VideoNode and returns a VideoNode, or None. + If None, applies a median-filtered bilateral smoother to clean halos + created during antialiasing. Default: None. :param show_mask: If True, returns the edge detection mask instead of the processed clip. Default: False :param planes: Planes to process. Default: Luma only. @@ -355,7 +360,7 @@ def based_aa( :return: Anti-aliased clip or edge detection mask if show_mask is True. :raises CustomRuntimeError: If required packages are missing. - :raises CustomValueError: If rfactor is not above 0.0. + :raises CustomValueError: If rfactor is not above 0.0, or invalid prefilter/postfilter is passed. """ ... @@ -366,7 +371,7 @@ def based_aa( downscaler: ScalerT | None = None, supersampler: ScalerT | ShaderFile | Path | Literal[False] | MissingT = MISSING, eedi3_kwargs: KwargsT | None = dict(alpha=0.125, beta=0.25, vthresh0=12, vthresh1=24, field=1), - prefilter: vs.VideoNode | VSFunction | None = None, postfilter: VSFunction | None = None, + prefilter: vs.VideoNode | VSFunction | None = None, postfilter: VSFunction | None | Literal[False] = None, show_mask: bool = False, planes: PlanesT = 0, **kwargs: Any ) -> vs.VideoNode: @@ -395,14 +400,8 @@ def based_aa( supersampler = PlaceboShader(supersampler) - if prefilter: - if isinstance(prefilter, vs.VideoNode): - FormatsMismatchError.check(based_aa, func.work_clip, prefilter) - ss_clip = prefilter - else: - ss_clip = prefilter(func.work_clip) - else: - ss_clip = func.work_clip + if rfactor <= 0.0: + raise CustomValueError('rfactor must be greater than 0!', based_aa, rfactor) aaw, aah = [(round(r * rfactor) + 1) & ~1 for r in (func.work_clip.width, func.work_clip.height)] @@ -412,9 +411,6 @@ def based_aa( and max(aah, func.work_clip.height) % min(aah, func.work_clip.height) == 0 ) else Catrom - if rfactor <= 0.0: - raise CustomValueError('rfactor must be greater than 0!', based_aa, rfactor) - if rfactor < 1.0: downscaler, supersampler = supersampler, downscaler @@ -427,6 +423,17 @@ def based_aa( if show_mask: return mask + if prefilter: + if isinstance(prefilter, vs.VideoNode): + FormatsMismatchError.check(based_aa, func.work_clip, prefilter) + ss_clip = prefilter + elif callable(prefilter): + ss_clip = prefilter(func.work_clip) + else: + raise CustomValueError('Invalid prefilter!', based_aa, prefilter) + else: + ss_clip = func.work_clip + supersampler = Scaler.ensure_obj(supersampler, based_aa) downscaler = Scaler.ensure_obj(downscaler, based_aa) @@ -436,7 +443,14 @@ def based_aa( aa = Eedi3(mclip=mclip, sclip_aa=True).aa(ss, **eedi3_kwargs | kwargs) aa = downscaler.scale(aa, func.work_clip.width, func.work_clip.height) - aa_out = postfilter(aa) if postfilter else aa + if postfilter is None: + aa_out = MeanMode.MEDIAN(aa, func.work_clip, bilateral(aa)) + elif callable(postfilter): + aa_out = postfilter(aa) + elif postfilter is False: + aa_out = aa + else: + raise CustomValueError('Invalid postfilter!', based_aa, postfilter) if pskip: no_aa = downscaler.scale(ss, func.work_clip.width, func.work_clip.height)