From 4d5df69ca06851c1922dac5ede06f5eed20c2a18 Mon Sep 17 00:00:00 2001 From: nilfm Date: Tue, 26 Mar 2024 15:23:15 -0400 Subject: [PATCH] cambi.c: validate input value ranges --- libvmaf/src/feature/cambi.c | 55 +++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/libvmaf/src/feature/cambi.c b/libvmaf/src/feature/cambi.c index 094e24c32..d7512d39c 100644 --- a/libvmaf/src/feature/cambi.c +++ b/libvmaf/src/feature/cambi.c @@ -608,7 +608,62 @@ static void anti_dithering_filter(VmafPicture *pic, unsigned width, unsigned hei } } +static int validate_image_lbd(const VmafPicture *pic) { + int bpc = pic->bpc; + if (bpc == 8) return 0; + uint8_t max_val = (1 << bpc) - 1; + int channel = 0; + uint8_t *data = (uint8_t *)pic->data[channel]; + size_t stride = pic->stride[channel]; + for (unsigned i = 0; i < pic->h[channel]; i++) { + for (unsigned j = 0; j < pic->w[channel]; j++) { + if (data[i * stride + j] > max_val) { + vmaf_log( + VMAF_LOG_LEVEL_ERROR, + "Invalid input. The input contains values greater than %d, which exceeds the maximum value for a %d-bit depth format.", + max_val, bpc + ); + return -EINVAL; + } + } + } + return 0; +} + +static int validate_image_hbd(const VmafPicture *pic) { + int bpc = pic->bpc; + if (bpc == 16) return 0; + uint16_t max_val = (1 << bpc) - 1; + int channel = 0; + uint16_t *data = (uint16_t *)pic->data[channel]; + size_t stride = pic->stride[channel] / 2; + for (unsigned i = 0; i < pic->h[channel]; i++) { + for (unsigned j = 0; j < pic->w[channel]; j++) { + if (data[i * stride + j] > max_val) { + vmaf_log( + VMAF_LOG_LEVEL_ERROR, + "Invalid input. The input contains values greater than %d, which exceeds the maximum value for a %d-bit depth format.", + max_val, bpc + ); + return -EINVAL; + } + } + } + return 0; +} + +static int validate_image(const VmafPicture *pic) { + if (pic->bpc <= 8) { + return validate_image_lbd(pic); + } else { + return validate_image_hbd(pic); + } +} + static int cambi_preprocessing(const VmafPicture *image, VmafPicture *preprocessed, int width, int height, int enc_bitdepth) { + if (validate_image(image)) { + return -EINVAL; + } if (image->bpc >= 10) { decimate_generic_uint16_and_convert_to_10b(image, preprocessed, width, height); }