Skip to content

Commit

Permalink
Expand yuvxyb-math API (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
FreezyLemon authored Nov 24, 2024
1 parent b156513 commit 9f3a6e6
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 86 deletions.
5 changes: 4 additions & 1 deletion src/yuv_rgb/color.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use av_data::pixel::{ColorPrimaries, MatrixCoefficients};
use yuvxyb_math::{ColVector, Matrix, RowVector};

use super::{ycbcr_to_ypbpr, ypbpr_to_ycbcr};
use crate::{ConversionError, Pixel, Yuv, YuvConfig};

type ColVector = yuvxyb_math::ColVector<f32>;
type Matrix = yuvxyb_math::Matrix<f32>;
type RowVector = yuvxyb_math::RowVector<f32>;

pub fn get_yuv_to_rgb_matrix(config: YuvConfig) -> Result<Matrix, ConversionError> {
get_rgb_to_yuv_matrix(config).map(|m| m.invert())
}
Expand Down
205 changes: 121 additions & 84 deletions yuvxyb-math/src/matrix.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,60 @@
use crate::multiply_add;
// These types are intended for floats, which are not `Eq`
#![allow(clippy::derive_partial_eq_without_eq)]

#[derive(Debug, Clone)]
use std::ops::{Div, Mul, Neg};

use crate::mul_add::FastMulAdd;

#[derive(Debug, Clone, PartialEq)]
#[must_use]
pub struct RowVector(f32, f32, f32);
pub struct RowVector<T>(T, T, T);

impl RowVector {
pub const fn new(x: f32, y: f32, z: f32) -> Self {
impl<T: Copy> RowVector<T> {
pub const fn new(x: T, y: T, z: T) -> Self {
Self(x, y, z)
}

#[must_use]
pub const fn x(&self) -> f32 {
pub const fn x(&self) -> T {
self.0
}
#[must_use]
pub const fn y(&self) -> f32 {
pub const fn y(&self) -> T {
self.1
}
#[must_use]
pub const fn z(&self) -> f32 {
pub const fn z(&self) -> T {
self.2
}

pub const fn values(self) -> [T; 3] {
let Self(x, y, z) = self;
[x, y, z]
}
}

impl<T> RowVector<T>
where
T: Copy + FastMulAdd + Mul<T, Output = T> + Div<T, Output = T> + Neg<Output = T>,
{
pub fn cross(&self, other: &Self) -> Self {
let Self(sx, sy, sz) = *self;
let Self(ox, oy, oz) = *other;

Self::new(
multiply_add(sy, oz, -(sz * oy)),
multiply_add(sz, ox, -(sx * oz)),
multiply_add(sx, oy, -(sy * ox)),
sy.fast_mul_add(oz, -(sz * oy)),
sz.fast_mul_add(ox, -(sx * oz)),
sx.fast_mul_add(oy, -(sy * ox)),
)
}

#[must_use]
pub fn dot(&self, other: &Self) -> f32 {
multiply_add(
self.0,
other.0,
multiply_add(self.1, other.1, self.2 * other.2),
)
pub fn dot(&self, other: &Self) -> T {
self.0
.fast_mul_add(other.0, self.1.fast_mul_add(other.1, self.2 * other.2))
}

pub fn scalar_div(&self, x: f32) -> Self {
pub fn scalar_div(&self, x: T) -> Self {
Self(self.0 / x, self.1 / x, self.2 / x)
}

Expand All @@ -51,94 +63,118 @@ impl RowVector {
}
}

impl From<[f32; 3]> for RowVector {
fn from(value: [f32; 3]) -> Self {
Self::new(value[0], value[1], value[2])
impl<T: Copy> From<[T; 3]> for RowVector<T> {
fn from(value: [T; 3]) -> Self {
let [x, y, z] = value;
Self::new(x, y, z)
}
}

#[derive(Debug, Clone, PartialEq)]
#[must_use]
pub struct ColVector(f32, f32, f32);
pub struct ColVector<T>(T, T, T);

impl ColVector {
pub const fn new(r: f32, g: f32, b: f32) -> Self {
impl<T: Copy> ColVector<T> {
pub const fn new(r: T, g: T, b: T) -> Self {
Self(r, g, b)
}

#[must_use]
pub const fn r(&self) -> f32 {
pub const fn r(&self) -> T {
self.0
}
#[must_use]
pub const fn g(&self) -> f32 {
pub const fn g(&self) -> T {
self.1
}
#[must_use]
pub const fn b(&self) -> f32 {
pub const fn b(&self) -> T {
self.2
}

pub const fn transpose(self) -> RowVector {
pub const fn transpose(self) -> RowVector<T> {
RowVector::new(self.0, self.1, self.2)
}

pub const fn values(self) -> [T; 3] {
let Self(r, g, b) = self;
[r, g, b]
}
}

impl From<[f32; 3]> for ColVector {
fn from(value: [f32; 3]) -> Self {
Self::new(value[0], value[1], value[2])
impl<T: Copy> From<[T; 3]> for ColVector<T> {
fn from(value: [T; 3]) -> Self {
let [r, g, b] = value;
Self::new(r, g, b)
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
#[must_use]
pub struct Matrix(RowVector, RowVector, RowVector);
pub struct Matrix<T>(RowVector<T>, RowVector<T>, RowVector<T>);

impl Matrix {
pub const fn new(r1: RowVector, r2: RowVector, r3: RowVector) -> Self {
impl<T: Copy> Matrix<T> {
pub const fn new(r1: RowVector<T>, r2: RowVector<T>, r3: RowVector<T>) -> Self {
Self(r1, r2, r3)
}

pub const fn r1(&self) -> &RowVector {
pub const fn r1(&self) -> &RowVector<T> {
&self.0
}
pub const fn r2(&self) -> &RowVector {
pub const fn r2(&self) -> &RowVector<T> {
&self.1
}
pub const fn r3(&self) -> &RowVector {
pub const fn r3(&self) -> &RowVector<T> {
&self.2
}

pub const fn transpose(self) -> Self {
let Self(r1, r2, r3) = self;

let RowVector(s11, s12, s13) = r1;
let RowVector(s21, s22, s23) = r2;
let RowVector(s31, s32, s33) = r3;

Self::new(
RowVector::new(s11, s21, s31),
RowVector::new(s12, s22, s32),
RowVector::new(s13, s23, s33),
)
}
}

impl Matrix<f32> {
pub const fn identity() -> Self {
Self::new(
RowVector::new(1.0, 0.0, 0.0),
RowVector::new(0.0, 1.0, 0.0),
RowVector::new(0.0, 0.0, 1.0),
)
}
}

impl Matrix<f64> {
pub const fn identity() -> Self {
Self::new(
RowVector::new(1.0, 0.0, 0.0),
RowVector::new(0.0, 1.0, 0.0),
RowVector::new(0.0, 0.0, 1.0),
)
}
}

pub fn scalar_div(&self, x: f32) -> Self {
impl<T> Matrix<T>
where
T: Copy + FastMulAdd + Mul<T, Output = T> + Div<T, Output = T> + Neg<Output = T>,
{
pub fn scalar_div(&self, x: T) -> Self {
Self(
self.0.scalar_div(x),
self.1.scalar_div(x),
self.2.scalar_div(x),
)
}

pub const fn transpose(self) -> Self {
let Self(r1, r2, r3) = self;

let RowVector(s11, s12, s13) = r1;
let RowVector(s21, s22, s23) = r2;
let RowVector(s31, s32, s33) = r3;

Self::new(
RowVector::new(s11, s21, s31),
RowVector::new(s12, s22, s32),
RowVector::new(s13, s23, s33),
)
}

/// Will panic if the matrix is not invertible
pub fn invert(&self) -> Self {
// Cramer's rule
Expand All @@ -148,23 +184,20 @@ impl Matrix {
let RowVector(s21, s22, s23) = *r2;
let RowVector(s31, s32, s33) = *r3;

let minor_11 = multiply_add(s22, s33, -(s32 * s23));
let minor_12 = multiply_add(s21, s33, -(s31 * s23));
let minor_13 = multiply_add(s21, s32, -(s31 * s22));
let minor_11 = s22.fast_mul_add(s33, -(s32 * s23));
let minor_12 = s21.fast_mul_add(s33, -(s31 * s23));
let minor_13 = s21.fast_mul_add(s32, -(s31 * s22));

let minor_21 = multiply_add(s12, s33, -(s32 * s13));
let minor_22 = multiply_add(s11, s33, -(s31 * s13));
let minor_23 = multiply_add(s11, s32, -(s31 * s12));
let minor_21 = s12.fast_mul_add(s33, -(s32 * s13));
let minor_22 = s11.fast_mul_add(s33, -(s31 * s13));
let minor_23 = s11.fast_mul_add(s32, -(s31 * s12));

let minor_31 = multiply_add(s12, s23, -(s22 * s13));
let minor_32 = multiply_add(s11, s23, -(s21 * s13));
let minor_33 = multiply_add(s11, s22, -(s21 * s12));
let minor_31 = s12.fast_mul_add(s23, -(s22 * s13));
let minor_32 = s11.fast_mul_add(s23, -(s21 * s13));
let minor_33 = s11.fast_mul_add(s22, -(s21 * s12));

let determinant = multiply_add(
s11,
minor_11,
-multiply_add(s12, minor_12, -(s13 * minor_13)),
);
let determinant =
s11.fast_mul_add(minor_11, -s12.fast_mul_add(minor_12, -(s13 * minor_13)));

Self::new(
RowVector::new(minor_11, -minor_12, minor_13),
Expand All @@ -175,13 +208,13 @@ impl Matrix {
.scalar_div(determinant)
}

pub fn mul_vec(&self, rhs: &ColVector) -> ColVector {
pub fn mul_vec(&self, rhs: &ColVector<T>) -> ColVector<T> {
let Self(ref r1, ref r2, ref r3) = *self;

ColVector::new(
multiply_add(r1.0, rhs.0, multiply_add(r1.1, rhs.1, r1.2 * rhs.2)),
multiply_add(r2.0, rhs.0, multiply_add(r2.1, rhs.1, r2.2 * rhs.2)),
multiply_add(r3.0, rhs.0, multiply_add(r3.1, rhs.1, r3.2 * rhs.2)),
r1.0.fast_mul_add(rhs.0, r1.1.fast_mul_add(rhs.1, r1.2 * rhs.2)),
r2.0.fast_mul_add(rhs.0, r2.1.fast_mul_add(rhs.1, r2.2 * rhs.2)),
r3.0.fast_mul_add(rhs.0, r3.1.fast_mul_add(rhs.1, r3.2 * rhs.2)),
)
}

Expand All @@ -191,31 +224,35 @@ impl Matrix {

Self::new(
RowVector::new(
multiply_add(r1.0, o1.0, multiply_add(r1.1, o2.0, r1.2 * o3.0)),
multiply_add(r1.0, o1.1, multiply_add(r1.1, o2.1, r1.2 * o3.1)),
multiply_add(r1.0, o1.2, multiply_add(r1.1, o2.2, r1.2 * o3.2)),
r1.0.fast_mul_add(o1.0, r1.1.fast_mul_add(o2.0, r1.2 * o3.0)),
r1.0.fast_mul_add(o1.1, r1.1.fast_mul_add(o2.1, r1.2 * o3.1)),
r1.0.fast_mul_add(o1.2, r1.1.fast_mul_add(o2.2, r1.2 * o3.2)),
),
RowVector::new(
multiply_add(r2.0, o1.0, multiply_add(r2.1, o2.0, r2.2 * o3.0)),
multiply_add(r2.0, o1.1, multiply_add(r2.1, o2.1, r2.2 * o3.1)),
multiply_add(r2.0, o1.2, multiply_add(r2.1, o2.2, r2.2 * o3.2)),
r2.0.fast_mul_add(o1.0, r2.1.fast_mul_add(o2.0, r2.2 * o3.0)),
r2.0.fast_mul_add(o1.1, r2.1.fast_mul_add(o2.1, r2.2 * o3.1)),
r2.0.fast_mul_add(o1.2, r2.1.fast_mul_add(o2.2, r2.2 * o3.2)),
),
RowVector::new(
multiply_add(r3.0, o1.0, multiply_add(r3.1, o2.0, r3.2 * o3.0)),
multiply_add(r3.0, o1.1, multiply_add(r3.1, o2.1, r3.2 * o3.1)),
multiply_add(r3.0, o1.2, multiply_add(r3.1, o2.2, r3.2 * o3.2)),
r3.0.fast_mul_add(o1.0, r3.1.fast_mul_add(o2.0, r3.2 * o3.0)),
r3.0.fast_mul_add(o1.1, r3.1.fast_mul_add(o2.1, r3.2 * o3.1)),
r3.0.fast_mul_add(o1.2, r3.1.fast_mul_add(o2.2, r3.2 * o3.2)),
),
)
}

#[must_use]
pub fn mul_arr(&self, rhs: [f32; 3]) -> [f32; 3] {
pub fn mul_arr(&self, rhs: [T; 3]) -> [T; 3] {
let Self(ref r1, ref r2, ref r3) = *self;

[
multiply_add(r1.0, rhs[0], multiply_add(r1.1, rhs[1], r1.2 * rhs[2])),
multiply_add(r2.0, rhs[0], multiply_add(r2.1, rhs[1], r2.2 * rhs[2])),
multiply_add(r3.0, rhs[0], multiply_add(r3.1, rhs[1], r3.2 * rhs[2])),
r1.0.fast_mul_add(rhs[0], r1.1.fast_mul_add(rhs[1], r1.2 * rhs[2])),
r2.0.fast_mul_add(rhs[0], r2.1.fast_mul_add(rhs[1], r2.2 * rhs[2])),
r3.0.fast_mul_add(rhs[0], r3.1.fast_mul_add(rhs[1], r3.2 * rhs[2])),
]
}

pub const fn values(self) -> [[T; 3]; 3] {
[self.0.values(), self.1.values(), self.2.values()]
}
}
35 changes: 34 additions & 1 deletion yuvxyb-math/src/mul_add.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// We want the "suboptimal" (less accurate) case here because it is faster.
#![allow(clippy::suboptimal_flops)]

use std::ops::{Add, Mul};

/// Computes (a * b) + c, leveraging FMA if available
#[inline]
#[allow(clippy::suboptimal_flops)]
#[must_use]
pub fn multiply_add(a: f32, b: f32, c: f32) -> f32 {
if cfg!(target_feature = "fma") {
Expand All @@ -9,3 +13,32 @@ pub fn multiply_add(a: f32, b: f32, c: f32) -> f32 {
a * b + c
}
}

pub trait FastMulAdd: Sized + Mul<Self, Output = Self> + Add<Self, Output = Self> {
/// Computes (self * a) + b, leveraging FMA if available.
///
/// If FMA is not available, the implementation should prefer computation
/// speed over accuracy (i.e. compute (self * a) + b without the benefits
/// of just one rounding error).
fn fast_mul_add(self, a: Self, b: Self) -> Self;
}

impl FastMulAdd for f32 {
fn fast_mul_add(self, a: Self, b: Self) -> Self {
if cfg!(target_feature = "fma") {
self.mul_add(a, b)
} else {
self * a + b
}
}
}

impl FastMulAdd for f64 {
fn fast_mul_add(self, a: Self, b: Self) -> Self {
if cfg!(target_feature = "fma") {
self.mul_add(a, b)
} else {
self * a + b
}
}
}

0 comments on commit 9f3a6e6

Please sign in to comment.