diff --git a/candle-core/tests/matmul_tests.rs b/candle-core/tests/matmul_tests.rs index e3e181072b..c1c16401a8 100644 --- a/candle-core/tests/matmul_tests.rs +++ b/candle-core/tests/matmul_tests.rs @@ -49,6 +49,20 @@ fn matmul(device: &Device) -> Result<()> { Ok(()) } +fn matmul_bf16(device: &Device) -> Result<()> { + if !device.supports_bf16() { + return Ok(()); + } + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?.to_dtype(DType::BF16)?; + + let c = a.matmul(&b)?.to_dtype(DType::F32)?; + assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]); + Ok(()) +} + fn broadcast_matmul(device: &Device) -> Result<()> { let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?; let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?; @@ -96,6 +110,12 @@ fn mm_layout(device: &Device) -> Result<()> { } test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal); +test_device!( + matmul_bf16, + matmul_bf16_cpu, + matmul_bf16_gpu, + matmul_bf16_metal +); test_device!( broadcast_matmul, broadcast_matmul_cpu,