diff --git a/src/yuv_rgb/color.rs b/src/yuv_rgb/color.rs index 0fba4c7..60fd53e 100644 --- a/src/yuv_rgb/color.rs +++ b/src/yuv_rgb/color.rs @@ -1,9 +1,12 @@ use av_data::pixel::{ColorPrimaries, MatrixCoefficients}; -use yuvxyb_math::{ColVector, Matrix, RowVector}; use super::{ycbcr_to_ypbpr, ypbpr_to_ycbcr}; use crate::{ConversionError, Pixel, Yuv, YuvConfig}; +type ColVector = yuvxyb_math::ColVector; +type Matrix = yuvxyb_math::Matrix; +type RowVector = yuvxyb_math::RowVector; + pub fn get_yuv_to_rgb_matrix(config: YuvConfig) -> Result { get_rgb_to_yuv_matrix(config).map(|m| m.invert()) } diff --git a/yuvxyb-math/src/matrix.rs b/yuvxyb-math/src/matrix.rs index c106a8f..59f764e 100644 --- a/yuvxyb-math/src/matrix.rs +++ b/yuvxyb-math/src/matrix.rs @@ -1,48 +1,60 @@ -use crate::multiply_add; +// These types are intended for floats, which are not `Eq` +#![allow(clippy::derive_partial_eq_without_eq)] -#[derive(Debug, Clone)] +use std::ops::{Div, Mul, Neg}; + +use crate::mul_add::FastMulAdd; + +#[derive(Debug, Clone, PartialEq)] #[must_use] -pub struct RowVector(f32, f32, f32); +pub struct RowVector(T, T, T); -impl RowVector { - pub const fn new(x: f32, y: f32, z: f32) -> Self { +impl RowVector { + pub const fn new(x: T, y: T, z: T) -> Self { Self(x, y, z) } #[must_use] - pub const fn x(&self) -> f32 { + pub const fn x(&self) -> T { self.0 } #[must_use] - pub const fn y(&self) -> f32 { + pub const fn y(&self) -> T { self.1 } #[must_use] - pub const fn z(&self) -> f32 { + pub const fn z(&self) -> T { self.2 } + pub const fn values(self) -> [T; 3] { + let Self(x, y, z) = self; + [x, y, z] + } +} + +impl RowVector +where + T: Copy + FastMulAdd + Mul + Div + Neg, +{ pub fn cross(&self, other: &Self) -> Self { let Self(sx, sy, sz) = *self; let Self(ox, oy, oz) = *other; Self::new( - multiply_add(sy, oz, -(sz * oy)), - multiply_add(sz, ox, -(sx * oz)), - multiply_add(sx, oy, -(sy * ox)), + sy.fast_mul_add(oz, -(sz * oy)), + sz.fast_mul_add(ox, -(sx * oz)), + sx.fast_mul_add(oy, -(sy * ox)), ) } #[must_use] - pub fn dot(&self, other: &Self) -> f32 { - multiply_add( - self.0, - other.0, - multiply_add(self.1, other.1, self.2 * other.2), - ) + pub fn dot(&self, other: &Self) -> T { + self.0 + .fast_mul_add(other.0, self.1.fast_mul_add(other.1, self.2 * other.2)) } - pub fn scalar_div(&self, x: f32) -> Self { + pub fn scalar_div(&self, x: T) -> Self { Self(self.0 / x, self.1 / x, self.2 / x) } @@ -51,64 +63,97 @@ impl RowVector { } } -impl From<[f32; 3]> for RowVector { - fn from(value: [f32; 3]) -> Self { - Self::new(value[0], value[1], value[2]) +impl From<[T; 3]> for RowVector { + fn from(value: [T; 3]) -> Self { + let [x, y, z] = value; + Self::new(x, y, z) } } #[derive(Debug, Clone, PartialEq)] #[must_use] -pub struct ColVector(f32, f32, f32); +pub struct ColVector(T, T, T); -impl ColVector { - pub const fn new(r: f32, g: f32, b: f32) -> Self { +impl ColVector { + pub const fn new(r: T, g: T, b: T) -> Self { Self(r, g, b) } #[must_use] - pub const fn r(&self) -> f32 { + pub const fn r(&self) -> T { self.0 } #[must_use] - pub const fn g(&self) -> f32 { + pub const fn g(&self) -> T { self.1 } #[must_use] - pub const fn b(&self) -> f32 { + pub const fn b(&self) -> T { self.2 } - pub const fn transpose(self) -> RowVector { + pub const fn transpose(self) -> RowVector { RowVector::new(self.0, self.1, self.2) } + + pub const fn values(self) -> [T; 3] { + let Self(r, g, b) = self; + [r, g, b] + } } -impl From<[f32; 3]> for ColVector { - fn from(value: [f32; 3]) -> Self { - Self::new(value[0], value[1], value[2]) +impl From<[T; 3]> for ColVector { + fn from(value: [T; 3]) -> Self { + let [r, g, b] = value; + Self::new(r, g, b) } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] #[must_use] -pub struct Matrix(RowVector, RowVector, RowVector); +pub struct Matrix(RowVector, RowVector, RowVector); -impl Matrix { - pub const fn new(r1: RowVector, r2: RowVector, r3: RowVector) -> Self { +impl Matrix { + pub const fn new(r1: RowVector, r2: RowVector, r3: RowVector) -> Self { Self(r1, r2, r3) } - pub const fn r1(&self) -> &RowVector { + pub const fn r1(&self) -> &RowVector { &self.0 } - pub const fn r2(&self) -> &RowVector { + pub const fn r2(&self) -> &RowVector { &self.1 } - pub const fn r3(&self) -> &RowVector { + pub const fn r3(&self) -> &RowVector { &self.2 } + pub const fn transpose(self) -> Self { + let Self(r1, r2, r3) = self; + + let RowVector(s11, s12, s13) = r1; + let RowVector(s21, s22, s23) = r2; + let RowVector(s31, s32, s33) = r3; + + Self::new( + RowVector::new(s11, s21, s31), + RowVector::new(s12, s22, s32), + RowVector::new(s13, s23, s33), + ) + } +} + +impl Matrix { + pub const fn identity() -> Self { + Self::new( + RowVector::new(1.0, 0.0, 0.0), + RowVector::new(0.0, 1.0, 0.0), + RowVector::new(0.0, 0.0, 1.0), + ) + } +} + +impl Matrix { pub const fn identity() -> Self { Self::new( RowVector::new(1.0, 0.0, 0.0), @@ -116,8 +161,13 @@ impl Matrix { RowVector::new(0.0, 0.0, 1.0), ) } +} - pub fn scalar_div(&self, x: f32) -> Self { +impl Matrix +where + T: Copy + FastMulAdd + Mul + Div + Neg, +{ + pub fn scalar_div(&self, x: T) -> Self { Self( self.0.scalar_div(x), self.1.scalar_div(x), @@ -125,20 +175,6 @@ impl Matrix { ) } - pub const fn transpose(self) -> Self { - let Self(r1, r2, r3) = self; - - let RowVector(s11, s12, s13) = r1; - let RowVector(s21, s22, s23) = r2; - let RowVector(s31, s32, s33) = r3; - - Self::new( - RowVector::new(s11, s21, s31), - RowVector::new(s12, s22, s32), - RowVector::new(s13, s23, s33), - ) - } - /// Will panic if the matrix is not invertible pub fn invert(&self) -> Self { // Cramer's rule @@ -148,23 +184,20 @@ impl Matrix { let RowVector(s21, s22, s23) = *r2; let RowVector(s31, s32, s33) = *r3; - let minor_11 = multiply_add(s22, s33, -(s32 * s23)); - let minor_12 = multiply_add(s21, s33, -(s31 * s23)); - let minor_13 = multiply_add(s21, s32, -(s31 * s22)); + let minor_11 = s22.fast_mul_add(s33, -(s32 * s23)); + let minor_12 = s21.fast_mul_add(s33, -(s31 * s23)); + let minor_13 = s21.fast_mul_add(s32, -(s31 * s22)); - let minor_21 = multiply_add(s12, s33, -(s32 * s13)); - let minor_22 = multiply_add(s11, s33, -(s31 * s13)); - let minor_23 = multiply_add(s11, s32, -(s31 * s12)); + let minor_21 = s12.fast_mul_add(s33, -(s32 * s13)); + let minor_22 = s11.fast_mul_add(s33, -(s31 * s13)); + let minor_23 = s11.fast_mul_add(s32, -(s31 * s12)); - let minor_31 = multiply_add(s12, s23, -(s22 * s13)); - let minor_32 = multiply_add(s11, s23, -(s21 * s13)); - let minor_33 = multiply_add(s11, s22, -(s21 * s12)); + let minor_31 = s12.fast_mul_add(s23, -(s22 * s13)); + let minor_32 = s11.fast_mul_add(s23, -(s21 * s13)); + let minor_33 = s11.fast_mul_add(s22, -(s21 * s12)); - let determinant = multiply_add( - s11, - minor_11, - -multiply_add(s12, minor_12, -(s13 * minor_13)), - ); + let determinant = + s11.fast_mul_add(minor_11, -s12.fast_mul_add(minor_12, -(s13 * minor_13))); Self::new( RowVector::new(minor_11, -minor_12, minor_13), @@ -175,13 +208,13 @@ impl Matrix { .scalar_div(determinant) } - pub fn mul_vec(&self, rhs: &ColVector) -> ColVector { + pub fn mul_vec(&self, rhs: &ColVector) -> ColVector { let Self(ref r1, ref r2, ref r3) = *self; ColVector::new( - multiply_add(r1.0, rhs.0, multiply_add(r1.1, rhs.1, r1.2 * rhs.2)), - multiply_add(r2.0, rhs.0, multiply_add(r2.1, rhs.1, r2.2 * rhs.2)), - multiply_add(r3.0, rhs.0, multiply_add(r3.1, rhs.1, r3.2 * rhs.2)), + r1.0.fast_mul_add(rhs.0, r1.1.fast_mul_add(rhs.1, r1.2 * rhs.2)), + r2.0.fast_mul_add(rhs.0, r2.1.fast_mul_add(rhs.1, r2.2 * rhs.2)), + r3.0.fast_mul_add(rhs.0, r3.1.fast_mul_add(rhs.1, r3.2 * rhs.2)), ) } @@ -191,31 +224,35 @@ impl Matrix { Self::new( RowVector::new( - multiply_add(r1.0, o1.0, multiply_add(r1.1, o2.0, r1.2 * o3.0)), - multiply_add(r1.0, o1.1, multiply_add(r1.1, o2.1, r1.2 * o3.1)), - multiply_add(r1.0, o1.2, multiply_add(r1.1, o2.2, r1.2 * o3.2)), + r1.0.fast_mul_add(o1.0, r1.1.fast_mul_add(o2.0, r1.2 * o3.0)), + r1.0.fast_mul_add(o1.1, r1.1.fast_mul_add(o2.1, r1.2 * o3.1)), + r1.0.fast_mul_add(o1.2, r1.1.fast_mul_add(o2.2, r1.2 * o3.2)), ), RowVector::new( - multiply_add(r2.0, o1.0, multiply_add(r2.1, o2.0, r2.2 * o3.0)), - multiply_add(r2.0, o1.1, multiply_add(r2.1, o2.1, r2.2 * o3.1)), - multiply_add(r2.0, o1.2, multiply_add(r2.1, o2.2, r2.2 * o3.2)), + r2.0.fast_mul_add(o1.0, r2.1.fast_mul_add(o2.0, r2.2 * o3.0)), + r2.0.fast_mul_add(o1.1, r2.1.fast_mul_add(o2.1, r2.2 * o3.1)), + r2.0.fast_mul_add(o1.2, r2.1.fast_mul_add(o2.2, r2.2 * o3.2)), ), RowVector::new( - multiply_add(r3.0, o1.0, multiply_add(r3.1, o2.0, r3.2 * o3.0)), - multiply_add(r3.0, o1.1, multiply_add(r3.1, o2.1, r3.2 * o3.1)), - multiply_add(r3.0, o1.2, multiply_add(r3.1, o2.2, r3.2 * o3.2)), + r3.0.fast_mul_add(o1.0, r3.1.fast_mul_add(o2.0, r3.2 * o3.0)), + r3.0.fast_mul_add(o1.1, r3.1.fast_mul_add(o2.1, r3.2 * o3.1)), + r3.0.fast_mul_add(o1.2, r3.1.fast_mul_add(o2.2, r3.2 * o3.2)), ), ) } #[must_use] - pub fn mul_arr(&self, rhs: [f32; 3]) -> [f32; 3] { + pub fn mul_arr(&self, rhs: [T; 3]) -> [T; 3] { let Self(ref r1, ref r2, ref r3) = *self; [ - multiply_add(r1.0, rhs[0], multiply_add(r1.1, rhs[1], r1.2 * rhs[2])), - multiply_add(r2.0, rhs[0], multiply_add(r2.1, rhs[1], r2.2 * rhs[2])), - multiply_add(r3.0, rhs[0], multiply_add(r3.1, rhs[1], r3.2 * rhs[2])), + r1.0.fast_mul_add(rhs[0], r1.1.fast_mul_add(rhs[1], r1.2 * rhs[2])), + r2.0.fast_mul_add(rhs[0], r2.1.fast_mul_add(rhs[1], r2.2 * rhs[2])), + r3.0.fast_mul_add(rhs[0], r3.1.fast_mul_add(rhs[1], r3.2 * rhs[2])), ] } + + pub const fn values(self) -> [[T; 3]; 3] { + [self.0.values(), self.1.values(), self.2.values()] + } } diff --git a/yuvxyb-math/src/mul_add.rs b/yuvxyb-math/src/mul_add.rs index 1543ba2..5b08f66 100644 --- a/yuvxyb-math/src/mul_add.rs +++ b/yuvxyb-math/src/mul_add.rs @@ -1,6 +1,10 @@ +// We want the "suboptimal" (less accurate) case here because it is faster. +#![allow(clippy::suboptimal_flops)] + +use std::ops::{Add, Mul}; + /// Computes (a * b) + c, leveraging FMA if available #[inline] -#[allow(clippy::suboptimal_flops)] #[must_use] pub fn multiply_add(a: f32, b: f32, c: f32) -> f32 { if cfg!(target_feature = "fma") { @@ -9,3 +13,32 @@ pub fn multiply_add(a: f32, b: f32, c: f32) -> f32 { a * b + c } } + +pub trait FastMulAdd: Sized + Mul + Add { + /// Computes (self * a) + b, leveraging FMA if available. + /// + /// If FMA is not available, the implementation should prefer computation + /// speed over accuracy (i.e. compute (self * a) + b without the benefits + /// of just one rounding error). + fn fast_mul_add(self, a: Self, b: Self) -> Self; +} + +impl FastMulAdd for f32 { + fn fast_mul_add(self, a: Self, b: Self) -> Self { + if cfg!(target_feature = "fma") { + self.mul_add(a, b) + } else { + self * a + b + } + } +} + +impl FastMulAdd for f64 { + fn fast_mul_add(self, a: Self, b: Self) -> Self { + if cfg!(target_feature = "fma") { + self.mul_add(a, b) + } else { + self * a + b + } + } +}