diff --git a/Cargo.lock b/Cargo.lock index 74910d9290..68107eda47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1444,11 +1444,14 @@ dependencies = [ [[package]] name = "half" version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +source = "git+https://github.com/FL33TW00D/half-rs.git?branch=feature/arbitrary#6bc4bea632269b53ccb6666a8508edc25fba9f3e" dependencies = [ + "arbitrary", + "bytemuck", "cfg-if", "crunchy", + "num-traits", + "serde", ] [[package]] @@ -1697,6 +1700,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + [[package]] name = "libredox" version = "0.0.2" @@ -1879,17 +1888,20 @@ dependencies = [ "codespan-reporting", "diff", "env_logger", + "half", "hexf-parse", "hlsl-snapshots", "indexmap", "itertools", "log", + "num-traits", "petgraph", "pp-rs", "ron", "rspirv", "rustc-hash", "serde", + "smallvec", "spirv 0.3.0+sdk-1.3.268.0", "termcolor", "thiserror", @@ -2019,6 +2031,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -3654,6 +3667,7 @@ dependencies = [ "flume", "getrandom", "glam", + "half", "ktx2", "log", "nanorand", diff --git a/Cargo.toml b/Cargo.toml index 68c29b671b..8e820ad7ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -193,6 +193,7 @@ ndk-sys = "0.5.0" #gpu-alloc = { path = "../gpu-alloc/gpu-alloc" } [patch.crates-io] +half = { git = "https://github.com/FL33TW00D/half-rs.git", branch = "feature/arbitrary" } #glow = { path = "../glow" } #web-sys = { path = "../wasm-bindgen/crates/web-sys" } #js-sys = { path = "../wasm-bindgen/crates/js-sys" } diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 1f4d4951f5..01767ed81d 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -35,6 +35,7 @@ encase = { workspace = true, features = ["glam"] } flume.workspace = true getrandom.workspace = true glam.workspace = true +half = { version = "2.1.0", features = ["bytemuck"] } ktx2.workspace = true log.workspace = true nanorand.workspace = true diff --git a/examples/src/lib.rs b/examples/src/lib.rs index d212fd404a..5fe61feff6 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -17,6 +17,7 @@ pub mod mipmap; pub mod msaa_line; pub mod render_to_texture; pub mod repeated_compute; +pub mod shader_f16; pub mod shadow; pub mod skybox; pub mod srgb_blend; diff --git a/examples/src/main.rs b/examples/src/main.rs index 5d29d484b6..3cf5bdc267 100644 --- a/examples/src/main.rs +++ b/examples/src/main.rs @@ -146,6 +146,12 @@ const EXAMPLES: &[ExampleDesc] = &[ webgl: false, // No RODS webgpu: true, }, + ExampleDesc { + name: "shader-f16", + function: wgpu_examples::shader_f16::main, + webgl: false, // No RODS + webgpu: true, + }, ]; fn get_example_name() -> Option { diff --git a/examples/src/shader_f16/README.md b/examples/src/shader_f16/README.md new file mode 100644 index 0000000000..9cbcd9921c --- /dev/null +++ b/examples/src/shader_f16/README.md @@ -0,0 +1,9 @@ +# shader-f16 + +Demonstrate the ability to perform compute in F16 using wgpu. + +## To Run + +``` +RUST_LOG=hello_compute cargo run --bin wgpu-examples shader_f16 +``` diff --git a/examples/src/shader_f16/mod.rs b/examples/src/shader_f16/mod.rs new file mode 100644 index 0000000000..400158e008 --- /dev/null +++ b/examples/src/shader_f16/mod.rs @@ -0,0 +1,189 @@ +use half::f16; +use std::{borrow::Cow, str::FromStr}; +use wgpu::util::DeviceExt; + +#[cfg_attr(test, allow(dead_code))] +async fn run() { + let numbers = if std::env::args().len() <= 2 { + let default = vec![ + f16::from_f32(27.), + f16::from_f32(7.), + f16::from_f32(5.), + f16::from_f32(3.), + ]; + println!("No numbers were provided, defaulting to {default:?}"); + default + } else { + std::env::args() + .skip(2) + .map(|s| f16::from_str(&s).expect("You must pass a list of positive integers!")) + .collect() + }; + + let steps = execute_gpu(&numbers).await.unwrap(); + println!("Steps: [{:?}]", steps); + #[cfg(target_arch = "wasm32")] + log::info!("Steps: [{:?}]", steps); +} + +#[cfg_attr(test, allow(dead_code))] +async fn execute_gpu(numbers: &[f16]) -> Option> { + // Instantiates instance of WebGPU + let instance = wgpu::Instance::default(); + + // `request_adapter` instantiates the general connection to the GPU + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions::default()) + .await?; + + // `request_device` instantiates the feature specific connection to the GPU, defining some parameters, + // `features` being the available features. + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: None, + required_features: wgpu::Features::SHADER_F16, + required_limits: wgpu::Limits::downlevel_defaults(), + memory_hints: Default::default(), + }, + None, + ) + .await + .unwrap(); + + execute_gpu_inner(&device, &queue, numbers).await +} + +async fn execute_gpu_inner( + device: &wgpu::Device, + queue: &wgpu::Queue, + numbers: &[f16], +) -> Option> { + // Loads the shader from WGSL + let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))), + }); + + // Gets the size in bytes of the buffer. + let size = std::mem::size_of_val(numbers) as wgpu::BufferAddress; + + // Instantiates buffer without data. + // `usage` of buffer specifies how it can be used: + // `BufferUsages::MAP_READ` allows it to be read (outside the shader). + // `BufferUsages::COPY_DST` allows it to be the destination of the copy. + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + // Instantiates buffer with data (`numbers`). + // Usage allowing the buffer to be: + // A storage buffer (can be bound within a bind group and thus available to a shader). + // The destination of a copy. + // The source of a copy. + let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Storage Buffer"), + contents: bytemuck::cast_slice(numbers), + usage: wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::COPY_SRC, + }); + + // A bind group defines how buffers are accessed by shaders. + // It is to WebGPU what a descriptor set is to Vulkan. + // `binding` here refers to the `binding` of a buffer in the shader (`layout(set = 0, binding = 0) buffer`). + + // A pipeline specifies the operation of a shader + + // Instantiates the pipeline. + let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &cs_module, + entry_point: None, + compilation_options: Default::default(), + cache: None, + }); + + // Instantiates the bind group, once again specifying the binding of buffers. + let bind_group_layout = compute_pipeline.get_bind_group_layout(0); + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: storage_buffer.as_entire_binding(), + }], + }); + + // A command encoder executes one or many pipelines. + // It is to WebGPU what a command buffer is to Vulkan. + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&compute_pipeline); + cpass.set_bind_group(0, Some(&bind_group), &[]); + cpass.insert_debug_marker("compute collatz iterations"); + cpass.dispatch_workgroups(numbers.len() as u32, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed + } + // Sets adds copy operation to command encoder. + // Will copy data from storage buffer on GPU to staging buffer on CPU. + encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size); + + // Submits command encoder for processing + queue.submit(Some(encoder.finish())); + + // Note that we're not calling `.await` here. + let buffer_slice = staging_buffer.slice(..); + // Sets the buffer up for mapping, sending over the result of the mapping back to us when it is finished. + let (sender, receiver) = flume::bounded(1); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap()); + + // Poll the device in a blocking manner so that our future resolves. + // In an actual application, `device.poll(...)` should + // be called in an event loop or on another thread. + device.poll(wgpu::Maintain::wait()).panic_on_timeout(); + + // Awaits until `buffer_future` can be read from + if let Ok(Ok(())) = receiver.recv_async().await { + // Gets contents of buffer + let data = buffer_slice.get_mapped_range(); + // Since contents are got in bytes, this converts these bytes back to u32 + let result = bytemuck::cast_slice(&data).to_vec(); + + // With the current interface, we have to make sure all mapped views are + // dropped before we unmap the buffer. + drop(data); + staging_buffer.unmap(); // Unmaps buffer from memory + // If you are familiar with C++ these 2 lines can be thought of similarly to: + // delete myPointer; + // myPointer = NULL; + // It effectively frees the memory + + // Returns data from buffer + Some(result) + } else { + panic!("failed to run compute on gpu!") + } +} + +pub fn main() { + #[cfg(not(target_arch = "wasm32"))] + { + env_logger::init(); + pollster::block_on(run()); + } + #[cfg(target_arch = "wasm32")] + { + std::panic::set_hook(Box::new(console_error_panic_hook::hook)); + console_log::init().expect("could not initialize logger"); + wasm_bindgen_futures::spawn_local(run()); + } +} diff --git a/examples/src/shader_f16/shader.wgsl b/examples/src/shader_f16/shader.wgsl new file mode 100644 index 0000000000..85ccc8b79a --- /dev/null +++ b/examples/src/shader_f16/shader.wgsl @@ -0,0 +1,9 @@ +enable f16; + +@group(0) @binding(0) +var values: array>; // this is used as both values and output for convenience + +@compute @workgroup_size(1) +fn main(@builtin(global_invocation_id) global_id: vec3) { + values[global_id.x] = fma(values[0], values[0], values[0]); +} diff --git a/naga/Cargo.toml b/naga/Cargo.toml index 19912a36a8..46075e2558 100644 --- a/naga/Cargo.toml +++ b/naga/Cargo.toml @@ -41,8 +41,8 @@ msl-out = [] ## If you want to enable MSL output it regardless of the target platform, use `naga/msl-out`. msl-out-if-target-apple = [] -serialize = ["dep:serde", "bitflags/serde", "indexmap/serde"] -deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde"] +serialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"] +deserialize = ["dep:serde", "bitflags/serde", "indexmap/serde", "half/serde"] arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"] spv-in = ["dep:petgraph", "dep:spirv"] spv-out = ["dep:spirv"] @@ -82,6 +82,10 @@ petgraph = { version = "0.6", optional = true } pp-rs = { version = "0.2.1", optional = true } hexf-parse = { version = "0.2.1", optional = true } unicode-xid = { version = "0.2.6", optional = true } +# TODO: remove `[patch]` entry in workspace `Cargo.toml` for `half` after we upstream `arbitrary` support +half = { version = "2.4.1", features = ["arbitrary", "num-traits"] } +num-traits = "0.2" +smallvec = { workspace = true, features = ["const_new"] } [build-dependencies] cfg_aliases.workspace = true diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 2ce9f22f27..70894ce405 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2647,6 +2647,9 @@ impl<'a, W: Write> Writer<'a, W> { // decimal part even it's zero which is needed for a valid glsl float constant crate::Literal::F64(value) => write!(self.out, "{value:?}LF")?, crate::Literal::F32(value) => write!(self.out, "{value:?}")?, + crate::Literal::F16(_) => { + return Err(Error::Custom("GLSL has no 16-bit float type".into())); + } // Unsigned integers need a `u` at the end // // While `core` doesn't necessarily need it, it's allowed and since `es` needs it we diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 0eb18f0e16..5efdb76234 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2383,6 +2383,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // decimal part even it's zero crate::Literal::F64(value) => write!(self.out, "{value:?}L")?, crate::Literal::F32(value) => write!(self.out, "{value:?}")?, + crate::Literal::F16(value) => write!(self.out, "{value:?}h")?, crate::Literal::U32(value) => write!(self.out, "{value}u")?, crate::Literal::I32(value) => write!(self.out, "{value}")?, crate::Literal::U64(value) => write!(self.out, "{value}uL")?, diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 7ab97f491c..404ecc5863 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -6,6 +6,8 @@ use crate::{ proc::{self, NameKey, TypeResolution}, valid, FastHashMap, FastHashSet, }; +use half::f16; +use num_traits::real::Real; #[cfg(test)] use std::ptr; use std::{ @@ -389,8 +391,12 @@ impl crate::Scalar { match self { Self { kind: Sk::Float, - width: _, + width: 4, } => "float", + Self { + kind: Sk::Float, + width: 2, + } => "half", Self { kind: Sk::Sint, width: 4, @@ -1414,6 +1420,21 @@ impl Writer { crate::Literal::F64(_) => { return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64)) } + crate::Literal::F16(value) => { + if value.is_infinite() { + let sign = if value.is_sign_negative() { "-" } else { "" }; + write!(self.out, "{sign}INFINITY")?; + } else if value.is_nan() { + write!(self.out, "NAN")?; + } else { + let suffix = if value.fract() == f16::from_f32(0.0) { + ".0h" + } else { + "h" + }; + write!(self.out, "{value}{suffix}")?; + } + } crate::Literal::F32(value) => { if value.is_infinite() { let sign = if value.is_sign_negative() { "-" } else { "" }; diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 9bd58508a1..fea852b293 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -406,6 +406,10 @@ impl super::Instruction { instruction } + pub(super) fn constant_16bit(result_type_id: Word, id: Word, low: Word) -> Self { + Self::constant(result_type_id, id, &[low]) + } + pub(super) fn constant_32bit(result_type_id: Word, id: Word, value: Word) -> Self { Self::constant(result_type_id, id, &[value]) } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 14f1fc0027..759e7ebc4d 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -870,6 +870,15 @@ impl Writer { if bits == 64 { self.capabilities_used.insert(spirv::Capability::Float64); } + if bits == 16 { + self.capabilities_used.insert(spirv::Capability::Float16); + self.capabilities_used + .insert(spirv::Capability::StorageBuffer16BitAccess); + self.capabilities_used + .insert(spirv::Capability::UniformAndStorageBuffer16BitAccess); + self.capabilities_used + .insert(spirv::Capability::StorageInputOutput16); + } Instruction::type_float(id, bits) } Sk::Bool => Instruction::type_bool(id), @@ -1233,6 +1242,10 @@ impl Writer { Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32) } crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()), + crate::Literal::F16(value) => { + let low = value.to_bits(); + Instruction::constant_16bit(type_id, id, low as u32) + } crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value), crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32), crate::Literal::U64(value) => { diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index e8b942a62c..611d50d776 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1221,6 +1221,7 @@ impl Writer { match expressions[expr] { Expression::Literal(literal) => match literal { + crate::Literal::F16(value) => write!(self.out, "{value}h")?, crate::Literal::F32(value) => write!(self.out, "{value}f")?, crate::Literal::U32(value) => write!(self.out, "{value}u")?, crate::Literal::I32(value) => { @@ -1971,6 +1972,10 @@ const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str { kind: Sk::Float, width: 4, } => "f32", + Scalar { + kind: Sk::Float, + width: 2, + } => "f16", Scalar { kind: Sk::Sint, width: 4, diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 5ad063a6b6..6b20ae3238 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -36,6 +36,7 @@ mod null; use convert::*; pub use error::Error; use function::*; +use half::f16; use indexmap::IndexSet; use crate::{ @@ -5484,6 +5485,9 @@ impl> Frontend { }) => { let low = self.next()?; match width { + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Literal + // If a numeric type’s bit width is less than 32-bits, the value appears in the low-order bits of the word. + 2 => crate::Literal::F16(f16::from_bits(low as u16)), 4 => crate::Literal::F32(f32::from_bits(low)), 8 => { inst.expect(5)?; diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 7c65d93de3..5e5a96c0b8 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -1,7 +1,7 @@ use crate::front::wgsl::parse::lexer::Token; use crate::front::wgsl::Scalar; use crate::proc::{Alignment, ConstantEvaluatorError, ResolveError}; -use crate::{SourceLocation, Span}; +use crate::{Extension, SourceLocation, Span}; use codespan_reporting::diagnostic::{Diagnostic, Label}; use codespan_reporting::files::SimpleFile; use codespan_reporting::term; @@ -135,8 +135,6 @@ pub enum NumberError { Invalid, #[error("numeric literal not representable by target type")] NotRepresentable, - #[error("unimplemented f16 type")] - UnimplementedF16, } #[derive(Copy, Clone, Debug, PartialEq)] @@ -183,6 +181,7 @@ pub(crate) enum Error<'a> { UnknownType(Span), UnknownStorageFormat(Span), UnknownConservativeDepth(Span), + UnknownExtension(Span, &'a str), SizeAttributeTooLow(Span, u32), AlignAttributeTooLow(Span, Alignment), NonPowerOfTwoAlignAttribute(Span), @@ -265,6 +264,7 @@ pub(crate) enum Error<'a> { PipelineConstantIDValue(Span), NotBool(Span), ConstAssertFailed(Span), + ExtensionNotEnabled(Span, Extension), } #[derive(Clone, Debug)] @@ -861,6 +861,16 @@ impl<'a> Error<'a> { labels: vec![(span, "evaluates to false".into())], notes: vec![], }, + Error::UnknownExtension(span, word) => ParseError { + message: format!("Unknown extension: {}. See available extensions at: https://www.w3.org/TR/WGSL/#enable-extension", word), + labels: vec![(span, "unknown extension".into())], + notes: vec![], + }, + Error::ExtensionNotEnabled(span, ref extension) => ParseError { + message: format!("Extension `{:?}` is not enabled", extension), + labels: vec![(span, "extension not enabled".into())], + notes: vec![], + }, } } } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 78e81350b4..0ffec0b83c 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1831,6 +1831,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expr: Typed = match *expr { ast::Expression::Literal(literal) => { let literal = match literal { + ast::Literal::Number(Number::F16(f)) => crate::Literal::F16(f), ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f), ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i), ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u), diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 4307ca3d9f..d30044a488 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -1,10 +1,11 @@ use crate::front::wgsl::parse::number::Number; use crate::front::wgsl::Scalar; -use crate::{Arena, FastIndexSet, Handle, Span}; +use crate::{Arena, Extension, FastIndexSet, Handle, Span}; use std::hash::Hash; #[derive(Debug, Default)] pub struct TranslationUnit<'a> { + pub directives: Arena, pub decls: Arena>, /// The common expressions arena for the entire translation unit. /// @@ -67,6 +68,12 @@ impl PartialEq for Dependency<'_> { impl Eq for Dependency<'_> {} +//A directive modifies how a WGSL program is processed by a WebGPU implementation. +#[derive(Debug)] +pub enum GlobalDirective { + Enable(EnableDirective), +} + /// A module-scope declaration. #[derive(Debug)] pub struct GlobalDecl<'a> { @@ -138,6 +145,11 @@ pub struct ResourceBinding<'a> { pub binding: Handle>, } +#[derive(Debug)] +pub struct EnableDirective { + pub enable_extension_list: Vec, +} + #[derive(Debug)] pub struct GlobalVariable<'a> { pub name: Ident<'a>, diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index 3ba71b07cc..3f04aad7f0 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -114,7 +114,10 @@ pub fn map_storage_format(word: &str, span: Span) -> Result Option { use crate::ScalarKind as Sk; match word { - // "f16" => Some(Scalar { kind: Sk::Float, width: 2 }), + "f16" => Some(Scalar { + kind: Sk::Float, + width: 2, + }), "f32" => Some(Scalar { kind: Sk::Float, width: 4, @@ -294,3 +297,10 @@ pub fn map_subgroup_operation( _ => return None, }) } + +pub fn map_extension(word: &str, span: Span) -> Result> { + match word { + "f16" => Ok(crate::Extension::F16), + _ => Err(Error::UnknownExtension(span, word)), + } +} diff --git a/naga/src/front/wgsl/parse/lexer.rs b/naga/src/front/wgsl/parse/lexer.rs index d03a448561..a97340c7c2 100644 --- a/naga/src/front/wgsl/parse/lexer.rs +++ b/naga/src/front/wgsl/parse/lexer.rs @@ -1,8 +1,10 @@ +use smallvec::SmallVec; + use super::{number::consume_number, Error, ExpectedToken}; use crate::front::wgsl::error::NumberError; use crate::front::wgsl::parse::{conv, Number}; use crate::front::wgsl::Scalar; -use crate::Span; +use crate::{Extension, Span}; type TokenSpan<'a> = (Token<'a>, Span); @@ -204,6 +206,7 @@ pub(in crate::front::wgsl) struct Lexer<'a> { pub(in crate::front::wgsl) source: &'a str, // The byte offset of the end of the last non-trivia token. last_end_offset: usize, + extensions: SmallVec<[Extension; 4]>, } impl<'a> Lexer<'a> { @@ -212,9 +215,15 @@ impl<'a> Lexer<'a> { input, source: input, last_end_offset: 0, + extensions: SmallVec::new_const(), } } + /// Add a new extension to the lexer. + pub(in crate::front::wgsl) fn add_extension(&mut self, extension: Extension) { + self.extensions.push(extension); + } + /// Calls the function with a lexer and returns the result of the function as well as the span for everything the function parsed /// /// # Examples @@ -346,6 +355,15 @@ impl<'a> Lexer<'a> { } } + pub(in crate::front::wgsl) fn next_extension_with_span( + &mut self, + ) -> Result<(&'a str, Span), Error<'a>> { + match self.next() { + (Token::Word(word), span) => Ok((word, span)), + other => Err(Error::Unexpected(other.1, ExpectedToken::Identifier)), + } + } + pub(in crate::front::wgsl) fn next_ident_with_span( &mut self, ) -> Result<(&'a str, Span), Error<'a>> { @@ -376,14 +394,19 @@ impl<'a> Lexer<'a> { /// Parses a generic scalar type, for example ``. pub(in crate::front::wgsl) fn next_scalar_generic(&mut self) -> Result> { self.expect_generic_paren('<')?; - let pair = match self.next() { - (Token::Word(word), span) => { - conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span)) - } + let (scalar, span) = match self.next() { + (Token::Word(word), span) => conv::get_scalar_type(word) + .map(|scalar| (scalar, span)) + .ok_or(Error::UnknownScalarType(span)), (_, span) => Err(Error::UnknownScalarType(span)), }?; + + if matches!(scalar, Scalar::F16) && !self.extensions.contains(&Extension::F16) { + return Err(Error::ExtensionNotEnabled(span, Extension::F16)); + } + self.expect_generic_paren('>')?; - Ok(pair) + Ok(scalar) } /// Parses a generic scalar type, for example ``. @@ -393,14 +416,20 @@ impl<'a> Lexer<'a> { &mut self, ) -> Result<(Scalar, Span), Error<'a>> { self.expect_generic_paren('<')?; - let pair = match self.next() { + + let (scalar, span) = match self.next() { (Token::Word(word), span) => conv::get_scalar_type(word) .map(|scalar| (scalar, span)) .ok_or(Error::UnknownScalarType(span)), (_, span) => Err(Error::UnknownScalarType(span)), }?; + + if matches!(scalar, Scalar::F16) && !self.extensions.contains(&Extension::F16) { + return Err(Error::ExtensionNotEnabled(span, Extension::F16)); + } + self.expect_generic_paren('>')?; - Ok(pair) + Ok((scalar, span)) } pub(in crate::front::wgsl) fn next_storage_access( @@ -458,6 +487,7 @@ fn sub_test(source: &str, expected_tokens: &[Token]) { #[test] fn test_numbers() { + use half::f16; // WGSL spec examples // // decimal integer @@ -482,14 +512,14 @@ fn test_numbers() { Token::Number(Ok(Number::AbstractFloat(0.01))), Token::Number(Ok(Number::AbstractFloat(12.34))), Token::Number(Ok(Number::F32(0.))), - Token::Number(Err(NumberError::UnimplementedF16)), + Token::Number(Ok(Number::F16(f16::from_f32(0.)))), Token::Number(Ok(Number::AbstractFloat(0.001))), Token::Number(Ok(Number::AbstractFloat(43.75))), Token::Number(Ok(Number::F32(16.))), Token::Number(Ok(Number::AbstractFloat(0.1875))), - Token::Number(Err(NumberError::UnimplementedF16)), + Token::Number(Ok(Number::F16(f16::from_f32(0.75)))), Token::Number(Ok(Number::AbstractFloat(0.12109375))), - Token::Number(Err(NumberError::UnimplementedF16)), + Token::Number(Ok(Number::F16(f16::from_f32(12.5)))), ], ); diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 3b1d60620b..58a6812ff5 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -2255,6 +2255,54 @@ impl Parser { Ok(fun) } + fn enable_extension<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ) -> Result> { + let (ext, ext_span) = lexer.next_extension_with_span()?; + let extension = conv::map_extension(ext, ext_span)?; + lexer.add_extension(extension.clone()); + Ok(extension) + } + + fn global_directive<'a>( + &mut self, + lexer: &mut Lexer<'a>, + out: &mut ast::TranslationUnit<'a>, + ) -> Result<(), Error<'a>> { + while let Token::Word("enable") = lexer.peek().0 { + let (_, enable_span) = lexer.next_ident_with_span()?; + + let mut enable_extension_list = Vec::with_capacity(4); + + // Parse the first extension + let extension = self.enable_extension(lexer)?; + enable_extension_list.push(extension); + + // Parse additional extensions separated by commas + while lexer.skip(Token::Separator(',')) { + let extension = self.enable_extension(lexer)?; + enable_extension_list.push(extension); + } + + // Require a semicolon at the end + if !lexer.skip(Token::Separator(';')) { + return Err(Error::Unexpected( + lexer.next().1, + ExpectedToken::Token(Token::Separator(';')), + )); + } + + out.directives.append( + ast::GlobalDirective::Enable(ast::EnableDirective { + enable_extension_list, + }), + enable_span, + ); + } + Ok(()) + } + fn global_decl<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -2474,6 +2522,7 @@ impl Parser { let mut lexer = Lexer::new(source); let mut tu = ast::TranslationUnit::default(); + self.global_directive(&mut lexer, &mut tu)?; loop { match self.global_decl(&mut lexer, &mut tu) { Err(error) => return Err(error), diff --git a/naga/src/front/wgsl/parse/number.rs b/naga/src/front/wgsl/parse/number.rs index ceb2cb336c..bdee00e8f3 100644 --- a/naga/src/front/wgsl/parse/number.rs +++ b/naga/src/front/wgsl/parse/number.rs @@ -1,5 +1,6 @@ use crate::front::wgsl::error::NumberError; use crate::front::wgsl::parse::lexer::Token; +use half::f16; /// When using this type assume no Abstract Int/Float for now #[derive(Copy, Clone, Debug, PartialEq)] @@ -16,6 +17,8 @@ pub enum Number { I64(i64), /// Concrete u64 U64(u64), + /// Concrete f16 + F16(f16), /// Concrete f32 F32(f32), /// Concrete f64 @@ -362,7 +365,8 @@ fn parse_hex_float(input: &str, kind: Option) -> Result Err(NumberError::NotRepresentable), }, - Some(FloatKind::F16) => Err(NumberError::UnimplementedF16), + // TODO: f16 is not supported by hexf_parse + Some(FloatKind::F16) => Err(NumberError::NotRepresentable), Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) { Ok(num) => Ok(Number::F32(num)), // can only be ParseHexfErrorKind::Inexact but we can't check since it's private @@ -398,7 +402,12 @@ fn parse_dec_float(input: &str, kind: Option) -> Result Err(NumberError::UnimplementedF16), + Some(FloatKind::F16) => { + let num = input.parse::().unwrap(); // will never fail + num.is_finite() + .then_some(Number::F16(num)) + .ok_or(NumberError::NotRepresentable) + } } } diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 038e215a6a..6b6686979a 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -268,6 +268,7 @@ pub use crate::arena::{Arena, Handle, Range, UniqueArena}; pub use crate::span::{SourceLocation, Span, SpanContext, WithSpan}; #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; +use half::f16; #[cfg(feature = "deserialize")] use serde::Deserialize; #[cfg(feature = "serialize")] @@ -870,6 +871,7 @@ pub enum Literal { F64(f64), /// May not be NaN or infinity. F32(f32), + F16(f16), U32(u32), I32(i32), U64(u64), @@ -959,6 +961,15 @@ pub struct ResourceBinding { pub binding: u32, } +/// Enable directive +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Extension { + F16, +} + /// Variable defined at module level. #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 1b7f5cf910..ba51ade9ea 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -1,6 +1,8 @@ use std::iter; use arrayvec::ArrayVec; +use half::f16; +use num_traits::{real::Real, FromPrimitive, One, ToPrimitive, Zero}; use crate::{ arena::{Arena, Handle, HandleVec, UniqueArena}, @@ -199,6 +201,7 @@ gen_component_wise_extractor! { literals: [ AbstractFloat => AbstractFloat: f64, F32 => F32: f32, + F16 => F16: f16, AbstractInt => AbstractInt: i64, U32 => U32: u32, I32 => I32: i32, @@ -219,6 +222,7 @@ gen_component_wise_extractor! { literals: [ AbstractFloat => Abstract: f64, F32 => F32: f32, + F16 => F16: f16, ], scalar_kinds: [ Float, @@ -244,6 +248,7 @@ gen_component_wise_extractor! { AbstractFloat => AbstractFloat: f64, AbstractInt => AbstractInt: i64, F32 => F32: f32, + F16 => F16: f16, I32 => I32: i32, ], scalar_kinds: [ @@ -1088,6 +1093,7 @@ impl<'a> ConstantEvaluator<'a> { component_wise_scalar(self, span, [arg], |args| match args { Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])), Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])), + Scalar::F16([e]) => Ok(Scalar::F16([e.abs()])), Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])), Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])), Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz @@ -1119,9 +1125,13 @@ impl<'a> ConstantEvaluator<'a> { } ) } - crate::MathFunction::Saturate => { - component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) }) - } + crate::MathFunction::Saturate => component_wise_float(self, span, [arg], |e| match e { + Float::F16([e]) => Ok(Float::F16( + [e.clamp(f16::from_f32(0.0), f16::from_f32(1.0))], + )), + Float::F32([e]) => Ok(Float::F32([e.clamp(0., 1.)])), + Float::Abstract([e]) => Ok(Float::Abstract([e.clamp(0., 1.)])), + }), // trigonometry crate::MathFunction::Cos => { @@ -1175,8 +1185,8 @@ impl<'a> ConstantEvaluator<'a> { component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) }) } crate::MathFunction::Round => { - // TODO: Use `f{32,64}.round_ties_even()` when available on stable. This polyfill - // is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source], + // TODO: this hit stable on 1.77, but MSRV of naga is 1.74.0 + // This polyfill is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source], // which has licensing compatible with ours. See also // . // @@ -1198,6 +1208,9 @@ impl<'a> ConstantEvaluator<'a> { component_wise_float(self, span, [arg], |e| match e { Float::Abstract([e]) => Ok(Float::Abstract([round_ties_even(e)])), Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])), + Float::F16([e]) => { + Ok(Float::F16([(f16::from_f64(round_ties_even(f64::from(e))))])) + } }) } crate::MathFunction::Fract => { @@ -1243,15 +1256,27 @@ impl<'a> ConstantEvaluator<'a> { ) } crate::MathFunction::Step => { - component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| { - Ok([if edge <= x { 1.0 } else { 0.0 }]) + component_wise_float(self, span, [arg, arg1.unwrap()], |x| match x { + Float::Abstract([edge, x]) => { + Ok(Float::Abstract([if edge <= x { 1.0 } else { 0.0 }])) + } + Float::F32([edge, x]) => Ok(Float::F32([if edge <= x { 1.0 } else { 0.0 }])), + Float::F16([edge, x]) => Ok(Float::F16([if edge <= x { + f16::one() + } else { + f16::zero() + }])), }) } crate::MathFunction::Sqrt => { component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) }) } crate::MathFunction::InverseSqrt => { - component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) }) + component_wise_float(self, span, [arg], |e| match e { + Float::Abstract([e]) => Ok(Float::Abstract([1. / e.sqrt()])), + Float::F32([e]) => Ok(Float::F32([1. / e.sqrt()])), + Float::F16([e]) => Ok(Float::F16([f16::from_f32(1. / f32::from(e).sqrt())])), + }) } // bits @@ -1529,6 +1554,7 @@ impl<'a> ConstantEvaluator<'a> { Literal::I32(v) => v, Literal::U32(v) => v as i32, Literal::F32(v) => v as i32, + Literal::F16(v) => f16::to_i32(&v).unwrap(), //Only None on NaN or Inf Literal::Bool(v) => v as i32, Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => { return make_error(); @@ -1540,6 +1566,7 @@ impl<'a> ConstantEvaluator<'a> { Literal::I32(v) => v as u32, Literal::U32(v) => v, Literal::F32(v) => v as u32, + Literal::F16(v) => f16::to_u32(&v).unwrap(), //Only None on NaN or Inf Literal::Bool(v) => v as u32, Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => { return make_error(); @@ -1555,6 +1582,7 @@ impl<'a> ConstantEvaluator<'a> { Literal::F64(v) => v as i64, Literal::I64(v) => v, Literal::U64(v) => v as i64, + Literal::F16(v) => f16::to_i64(&v).unwrap(), //Only None on NaN or Inf Literal::AbstractInt(v) => i64::try_from_abstract(v)?, Literal::AbstractFloat(v) => i64::try_from_abstract(v)?, }), @@ -1566,9 +1594,22 @@ impl<'a> ConstantEvaluator<'a> { Literal::F64(v) => v as u64, Literal::I64(v) => v as u64, Literal::U64(v) => v, + Literal::F16(v) => f16::to_u64(&v).unwrap(), //Only None on NaN or Inf Literal::AbstractInt(v) => u64::try_from_abstract(v)?, Literal::AbstractFloat(v) => u64::try_from_abstract(v)?, }), + Sc::F16 => Literal::F16(match literal { + Literal::F16(v) => v, + Literal::F32(v) => f16::from_f32(v), + Literal::F64(v) => f16::from_f64(v), + Literal::Bool(v) => f16::from_u32(v as u32).unwrap(), + Literal::I64(v) => f16::from_i64(v).unwrap(), + Literal::U64(v) => f16::from_u64(v).unwrap(), + Literal::I32(v) => f16::from_i32(v).unwrap(), + Literal::U32(v) => f16::from_u32(v).unwrap(), + Literal::AbstractFloat(v) => f16::try_from_abstract(v)?, + Literal::AbstractInt(v) => f16::try_from_abstract(v)?, + }), Sc::F32 => Literal::F32(match literal { Literal::I32(v) => v as f32, Literal::U32(v) => v as f32, @@ -1577,12 +1618,14 @@ impl<'a> ConstantEvaluator<'a> { Literal::F64(_) | Literal::I64(_) | Literal::U64(_) => { return make_error(); } + Literal::F16(v) => f16::to_f32(v), Literal::AbstractInt(v) => f32::try_from_abstract(v)?, Literal::AbstractFloat(v) => f32::try_from_abstract(v)?, }), Sc::F64 => Literal::F64(match literal { Literal::I32(v) => v as f64, Literal::U32(v) => v as f64, + Literal::F16(v) => f16::to_f64(v), Literal::F32(v) => v as f64, Literal::F64(v) => v, Literal::Bool(v) => v as u32 as f64, @@ -1594,6 +1637,7 @@ impl<'a> ConstantEvaluator<'a> { Literal::I32(v) => v != 0, Literal::U32(v) => v != 0, Literal::F32(v) => v != 0.0, + Literal::F16(v) => v != f16::zero(), Literal::Bool(v) => v, Literal::F64(_) | Literal::I64(_) @@ -1743,6 +1787,7 @@ impl<'a> ConstantEvaluator<'a> { UnaryOperator::Negate => match value { Literal::I32(v) => Literal::I32(v.wrapping_neg()), Literal::F32(v) => Literal::F32(-v), + Literal::F16(v) => Literal::F16(-v), Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()), Literal::AbstractFloat(v) => Literal::AbstractFloat(-v), _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg), @@ -1881,6 +1926,14 @@ impl<'a> ConstantEvaluator<'a> { BinaryOperator::Modulo => a % b, _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), }), + (Literal::F16(a), Literal::F16(b)) => Literal::F16(match op { + BinaryOperator::Add => a + b, + BinaryOperator::Subtract => a - b, + BinaryOperator::Multiply => a * b, + BinaryOperator::Divide => a / b, + BinaryOperator::Modulo => a % b, + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + }), (Literal::AbstractInt(a), Literal::AbstractInt(b)) => { Literal::AbstractInt(match op { BinaryOperator::Add => a.checked_add(b).ok_or_else(|| { @@ -2450,6 +2503,32 @@ impl TryFromAbstract for u64 { } } +impl TryFromAbstract for f16 { + fn try_from_abstract(value: f64) -> Result { + let f = f16::from_f64(value); + if f.is_infinite() { + return Err(ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "f16", + }); + } + Ok(f) + } +} + +impl TryFromAbstract for f16 { + fn try_from_abstract(value: i64) -> Result { + let f = f16::from_i64(value); + if f.is_none() { + return Err(ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "f16", + }); + } + Ok(f.unwrap()) + } +} + #[cfg(test)] mod tests { use std::vec; diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index abbe0c7e46..b19320fd7c 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -90,6 +90,10 @@ impl super::Scalar { kind: crate::ScalarKind::Uint, width: 4, }; + pub const F16: Self = Self { + kind: crate::ScalarKind::Float, + width: 2, + }; pub const F32: Self = Self { kind: crate::ScalarKind::Float, width: 4, @@ -157,6 +161,7 @@ impl super::Scalar { pub enum HashableLiteral { F64(u64), F32(u32), + F16(u16), U32(u32), I32(i32), U64(u64), @@ -171,6 +176,7 @@ impl From for HashableLiteral { match l { crate::Literal::F64(v) => Self::F64(v.to_bits()), crate::Literal::F32(v) => Self::F32(v.to_bits()), + crate::Literal::F16(v) => Self::F16(v.to_bits()), crate::Literal::U32(v) => Self::U32(v), crate::Literal::I32(v) => Self::I32(v), crate::Literal::U64(v) => Self::U64(v), @@ -209,6 +215,7 @@ impl crate::Literal { match *self { Self::F64(_) | Self::I64(_) | Self::U64(_) => 8, Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, + Self::F16(_) => 2, Self::Bool(_) => crate::BOOL_WIDTH, Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH, } @@ -217,6 +224,7 @@ impl crate::Literal { match *self { Self::F64(_) => crate::Scalar::F64, Self::F32(_) => crate::Scalar::F32, + Self::F16(_) => crate::Scalar::F16, Self::U32(_) => crate::Scalar::U32, Self::I32(_) => crate::Scalar::I32, Self::U64(_) => crate::Scalar::U64, diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index c314ec2ac8..72507bc835 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -143,6 +143,8 @@ bitflags::bitflags! { const SHADER_INT64_ATOMIC_MIN_MAX = 0x80000; /// Support for all atomic operations on 64-bit integers. const SHADER_INT64_ATOMIC_ALL_OPS = 0x100000; + /// Support for 16-bit floating-point types. + const SHADER_FLOAT16 = 0x200000; } } diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index c0c25dab79..61005bb19b 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -243,8 +243,8 @@ impl super::Validator { pub(super) const fn check_width(&self, scalar: crate::Scalar) -> Result<(), WidthError> { let good = match scalar.kind { crate::ScalarKind::Bool => scalar.width == crate::BOOL_WIDTH, - crate::ScalarKind::Float => { - if scalar.width == 8 { + crate::ScalarKind::Float => match scalar.width { + 8 => { if !self.capabilities.contains(Capabilities::FLOAT64) { return Err(WidthError::MissingCapability { name: "f64", @@ -252,10 +252,18 @@ impl super::Validator { }); } true - } else { - scalar.width == 4 } - } + 2 => { + if !self.capabilities.contains(Capabilities::SHADER_FLOAT16) { + return Err(WidthError::MissingCapability { + name: "f16", + flag: "FLOAT16", + }); + } + true + } + _ => scalar.width == 4, + }, crate::ScalarKind::Sint => { if scalar.width == 8 { if !self.capabilities.contains(Capabilities::SHADER_INT64) { diff --git a/naga/tests/in/float16.param.ron b/naga/tests/in/float16.param.ron new file mode 100644 index 0000000000..cf105af810 --- /dev/null +++ b/naga/tests/in/float16.param.ron @@ -0,0 +1,22 @@ +( + god_mode: true, + spv: ( + version: (1, 0), + ), + hlsl: ( + shader_model: V6_2, + binding_map: {}, + fake_missing_bindings: true, + special_constants_binding: Some((space: 1, register: 0)), + push_constants_target: Some((space: 0, register: 0)), + zero_initialize_workgroup_memory: true, + ), + msl: ( + lang_version: (1, 0), + per_entry_point_map: {}, + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: true, + zero_initialize_workgroup_memory: true, + ), +) diff --git a/naga/tests/in/float16.wgsl b/naga/tests/in/float16.wgsl new file mode 100644 index 0000000000..8db193e540 --- /dev/null +++ b/naga/tests/in/float16.wgsl @@ -0,0 +1,87 @@ +enable f16; +enable f16; //redundant directives are OK + +var private_variable: f16 = 1h; +const constant_variable: f16 = f16(15.2); + +struct UniformCompatible { + // Other types + val_u32: u32, + val_i32: i32, + val_f32: f32, + + // f16 + val_f16: f16, + val_f16_2: vec2, + val_f16_3: vec3, + val_f16_4: vec4, + final_value: f16, +} + +struct StorageCompatible { + val_f16_array_2: array, + val_f16_array_2: array, +} + +@group(0) @binding(0) +var input_uniform: UniformCompatible; + +@group(0) @binding(1) +var input_storage: UniformCompatible; + +@group(0) @binding(2) +var input_arrays: StorageCompatible; + +@group(0) @binding(3) +var output: UniformCompatible; + +@group(0) @binding(4) +var output_arrays: StorageCompatible; + +fn f16_function(x: f16) -> f16 { + var val: f16 = f16(constant_variable); + // A number too big for f16 + val += 1h - 33333h; + // Constructing an f16 from an AbstractInt + val += val + f16(5.); + // Constructing a f16 from other types and other types from f16. + val += f16(input_uniform.val_f32 + f32(val)); + // Constructing a vec3 from a i64 + val += vec3(input_uniform.val_f16).z; + + // Reading/writing to a uniform/storage buffer + output.val_f16 = input_uniform.val_f16 + input_storage.val_f16; + output.val_f16_2 = input_uniform.val_f16_2 + input_storage.val_f16_2; + output.val_f16_3 = input_uniform.val_f16_3 + input_storage.val_f16_3; + output.val_f16_4 = input_uniform.val_f16_4 + input_storage.val_f16_4; + + output_arrays.val_f16_array_2 = input_arrays.val_f16_array_2; + + // We make sure not to use 32 in these arguments, so it's clear in the results which are builtin + // constants based on the size of the type, and which are arguments. + + // Numeric functions + val += abs(val); + val += clamp(val, val, val); + //val += countLeadingZeros(val); + //val += countOneBits(val); + //val += countTrailingZeros(val); + val += dot(vec2(val), vec2(val)); + //val += extractBits(val, 15u, 28u); + //val += firstLeadingBit(val); + //val += firstTrailingBit(val); + //val += insertBits(val, 12li, 15u, 28u); + val += max(val, val); + val += min(val, val); + //val += reverseBits(val); + val += sign(val); // only for i64 + + // Make sure all the variables are used. + return f16(1.0); +} + +@compute @workgroup_size(1) +fn main() { + output.final_value = f16_function(2h); +} + diff --git a/naga/tests/out/hlsl/float16.hlsl b/naga/tests/out/hlsl/float16.hlsl new file mode 100644 index 0000000000..007306daaf --- /dev/null +++ b/naga/tests/out/hlsl/float16.hlsl @@ -0,0 +1,107 @@ +struct NagaConstants { + int first_vertex; + int first_instance; + uint other; +}; +ConstantBuffer _NagaConstants: register(b0, space1); + +struct UniformCompatible { + uint val_u32_; + int val_i32_; + float val_f32_; + half val_f16_; + half2 val_f16_2_; + int _pad5_0; + half3 val_f16_3_; + half4 val_f16_4_; + half final_value; + int _end_pad_0; +}; + +struct StorageCompatible { + half val_f16_array_2_[2]; + half val_f16_array_2_1[2]; +}; + +static const half constant_variable = 15.203125h; + +static half private_variable = 1.0h; +cbuffer input_uniform : register(b0) { UniformCompatible input_uniform; } +ByteAddressBuffer input_storage : register(t1); +ByteAddressBuffer input_arrays : register(t2); +RWByteAddressBuffer output : register(u3); +RWByteAddressBuffer output_arrays : register(u4); + +typedef half ret_Constructarray2_half_[2]; +ret_Constructarray2_half_ Constructarray2_half_(half arg0, half arg1) { + half ret[2] = { arg0, arg1 }; + return ret; +} + +half f16_function(half x) +{ + half val = 15.203125h; + + half _expr6 = val; + val = (_expr6 + (1.0h - 33344.0h)); + half _expr8 = val; + half _expr11 = val; + val = (_expr11 + (_expr8 + 5.0h)); + float _expr15 = input_uniform.val_f32_; + half _expr16 = val; + half _expr20 = val; + val = (_expr20 + half((_expr15 + float(_expr16)))); + half _expr24 = input_uniform.val_f16_; + half _expr27 = val; + val = (_expr27 + (_expr24).xxx.z); + half _expr33 = input_uniform.val_f16_; + half _expr36 = input_storage.Load(12); + output.Store(12, (_expr33 + _expr36)); + half2 _expr42 = input_uniform.val_f16_2_; + half2 _expr45 = input_storage.Load(16); + output.Store(16, (_expr42 + _expr45)); + half3 _expr51 = input_uniform.val_f16_3_; + half3 _expr54 = input_storage.Load(24); + output.Store(24, (_expr51 + _expr54)); + half4 _expr60 = input_uniform.val_f16_4_; + half4 _expr63 = input_storage.Load(32); + output.Store(32, (_expr60 + _expr63)); + half _expr69[2] = Constructarray2_half_(input_arrays.Load(0+0), input_arrays.Load(0+2)); + { + half _value2[2] = _expr69; + output_arrays.Store(0+0, _value2[0]); + output_arrays.Store(0+2, _value2[1]); + } + half _expr70 = val; + half _expr72 = val; + val = (_expr72 + abs(_expr70)); + half _expr74 = val; + half _expr75 = val; + half _expr76 = val; + half _expr78 = val; + val = (_expr78 + clamp(_expr74, _expr75, _expr76)); + half _expr80 = val; + half _expr82 = val; + half _expr85 = val; + val = (_expr85 + dot((_expr80).xx, (_expr82).xx)); + half _expr87 = val; + half _expr88 = val; + half _expr90 = val; + val = (_expr90 + max(_expr87, _expr88)); + half _expr92 = val; + half _expr93 = val; + half _expr95 = val; + val = (_expr95 + min(_expr92, _expr93)); + half _expr97 = val; + half _expr99 = val; + val = (_expr99 + sign(_expr97)); + return 1.0h; +} + +[numthreads(1, 1, 1)] +void main() +{ + const half _e3 = f16_function(2.0h); + output.Store(40, _e3); + return; +} diff --git a/naga/tests/out/hlsl/float16.ron b/naga/tests/out/hlsl/float16.ron new file mode 100644 index 0000000000..b396a4626e --- /dev/null +++ b/naga/tests/out/hlsl/float16.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_6_2", + ), + ], +) diff --git a/naga/tests/out/msl/float16.msl b/naga/tests/out/msl/float16.msl new file mode 100644 index 0000000000..aa119fb40b --- /dev/null +++ b/naga/tests/out/msl/float16.msl @@ -0,0 +1,99 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct UniformCompatible { + uint val_u32_; + int val_i32_; + float val_f32_; + half val_f16_; + char _pad4[2]; + metal::half2 val_f16_2_; + char _pad5[4]; + metal::half3 val_f16_3_; + metal::half4 val_f16_4_; + half final_value; +}; +struct type_7 { + half inner[2]; +}; +struct StorageCompatible { + type_7 val_f16_array_2_; + type_7 val_f16_array_2_1; +}; +constant half constant_variable = 15.203125; + +half f16_function( + half x, + constant UniformCompatible& input_uniform, + device UniformCompatible const& input_storage, + device StorageCompatible const& input_arrays, + device UniformCompatible& output, + device StorageCompatible& output_arrays +) { + half val = 15.203125; + half _e6 = val; + val = _e6 + (1.0 - 33344.0); + half _e8 = val; + half _e11 = val; + val = _e11 + (_e8 + 5.0); + float _e15 = input_uniform.val_f32_; + half _e16 = val; + half _e20 = val; + val = _e20 + static_cast(_e15 + static_cast(_e16)); + half _e24 = input_uniform.val_f16_; + half _e27 = val; + val = _e27 + metal::half3(_e24).z; + half _e33 = input_uniform.val_f16_; + half _e36 = input_storage.val_f16_; + output.val_f16_ = _e33 + _e36; + metal::half2 _e42 = input_uniform.val_f16_2_; + metal::half2 _e45 = input_storage.val_f16_2_; + output.val_f16_2_ = _e42 + _e45; + metal::half3 _e51 = input_uniform.val_f16_3_; + metal::half3 _e54 = input_storage.val_f16_3_; + output.val_f16_3_ = _e51 + _e54; + metal::half4 _e60 = input_uniform.val_f16_4_; + metal::half4 _e63 = input_storage.val_f16_4_; + output.val_f16_4_ = _e60 + _e63; + type_7 _e69 = input_arrays.val_f16_array_2_; + output_arrays.val_f16_array_2_ = _e69; + half _e70 = val; + half _e72 = val; + val = _e72 + metal::abs(_e70); + half _e74 = val; + half _e75 = val; + half _e76 = val; + half _e78 = val; + val = _e78 + metal::clamp(_e74, _e75, _e76); + half _e80 = val; + half _e82 = val; + half _e85 = val; + val = _e85 + metal::dot(metal::half2(_e80), metal::half2(_e82)); + half _e87 = val; + half _e88 = val; + half _e90 = val; + val = _e90 + metal::max(_e87, _e88); + half _e92 = val; + half _e93 = val; + half _e95 = val; + val = _e95 + metal::min(_e92, _e93); + half _e97 = val; + half _e99 = val; + val = _e99 + metal::sign(_e97); + return 1.0; +} + +kernel void main_( + constant UniformCompatible& input_uniform [[user(fake0)]] +, device UniformCompatible const& input_storage [[user(fake0)]] +, device StorageCompatible const& input_arrays [[user(fake0)]] +, device UniformCompatible& output [[user(fake0)]] +, device StorageCompatible& output_arrays [[user(fake0)]] +) { + half _e3 = f16_function(2.0, input_uniform, input_storage, input_arrays, output, output_arrays); + output.final_value = _e3; + return; +} diff --git a/naga/tests/out/spv/float16.spvasm b/naga/tests/out/spv/float16.spvasm new file mode 100644 index 0000000000..fa4cbaf039 --- /dev/null +++ b/naga/tests/out/spv/float16.spvasm @@ -0,0 +1,220 @@ +; SPIR-V +; Version: 1.0 +; Generator: rspirv +; Bound: 157 +OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %145 "main" +OpExecutionMode %145 LocalSize 1 1 1 +OpMemberDecorate %10 0 Offset 0 +OpMemberDecorate %10 1 Offset 4 +OpMemberDecorate %10 2 Offset 8 +OpMemberDecorate %10 3 Offset 12 +OpMemberDecorate %10 4 Offset 16 +OpMemberDecorate %10 5 Offset 24 +OpMemberDecorate %10 6 Offset 32 +OpMemberDecorate %10 7 Offset 40 +OpDecorate %11 ArrayStride 2 +OpMemberDecorate %13 0 Offset 0 +OpMemberDecorate %13 1 Offset 4 +OpDecorate %18 DescriptorSet 0 +OpDecorate %18 Binding 0 +OpDecorate %19 Block +OpMemberDecorate %19 0 Offset 0 +OpDecorate %21 NonWritable +OpDecorate %21 DescriptorSet 0 +OpDecorate %21 Binding 1 +OpDecorate %22 Block +OpMemberDecorate %22 0 Offset 0 +OpDecorate %24 NonWritable +OpDecorate %24 DescriptorSet 0 +OpDecorate %24 Binding 2 +OpDecorate %25 Block +OpMemberDecorate %25 0 Offset 0 +OpDecorate %27 DescriptorSet 0 +OpDecorate %27 Binding 3 +OpDecorate %28 Block +OpMemberDecorate %28 0 Offset 0 +OpDecorate %30 DescriptorSet 0 +OpDecorate %30 Binding 4 +OpDecorate %31 Block +OpMemberDecorate %31 0 Offset 0 +%2 = OpTypeVoid +%3 = OpTypeFloat 16 +%4 = OpTypeInt 32 0 +%5 = OpTypeInt 32 1 +%6 = OpTypeFloat 32 +%7 = OpTypeVector %3 2 +%8 = OpTypeVector %3 3 +%9 = OpTypeVector %3 4 +%10 = OpTypeStruct %4 %5 %6 %3 %7 %8 %9 %3 +%12 = OpConstant %4 2 +%11 = OpTypeArray %3 %12 +%13 = OpTypeStruct %11 %11 +%14 = OpConstant %3 2.1524e-41 +%15 = OpConstant %3 2.7121e-41 +%17 = OpTypePointer Private %3 +%16 = OpVariable %17 Private %14 +%19 = OpTypeStruct %10 +%20 = OpTypePointer Uniform %19 +%18 = OpVariable %20 Uniform +%22 = OpTypeStruct %10 +%23 = OpTypePointer StorageBuffer %22 +%21 = OpVariable %23 StorageBuffer +%25 = OpTypeStruct %13 +%26 = OpTypePointer StorageBuffer %25 +%24 = OpVariable %26 StorageBuffer +%28 = OpTypeStruct %10 +%29 = OpTypePointer StorageBuffer %28 +%27 = OpVariable %29 StorageBuffer +%31 = OpTypeStruct %13 +%32 = OpTypePointer StorageBuffer %31 +%30 = OpVariable %32 StorageBuffer +%36 = OpTypeFunction %3 %3 +%37 = OpTypePointer Uniform %10 +%38 = OpConstant %4 0 +%40 = OpTypePointer StorageBuffer %10 +%42 = OpTypePointer StorageBuffer %13 +%46 = OpConstant %3 4.3073e-41 +%47 = OpConstant %3 2.4753e-41 +%49 = OpTypePointer Function %3 +%58 = OpTypePointer Uniform %6 +%67 = OpTypePointer Uniform %3 +%68 = OpConstant %4 3 +%75 = OpTypePointer StorageBuffer %3 +%82 = OpTypePointer StorageBuffer %7 +%83 = OpTypePointer Uniform %7 +%84 = OpConstant %4 4 +%91 = OpTypePointer StorageBuffer %8 +%92 = OpTypePointer Uniform %8 +%93 = OpConstant %4 5 +%100 = OpTypePointer StorageBuffer %9 +%101 = OpTypePointer Uniform %9 +%102 = OpConstant %4 6 +%109 = OpTypePointer StorageBuffer %11 +%146 = OpTypeFunction %2 +%152 = OpConstant %3 2.2959e-41 +%155 = OpConstant %4 7 +%35 = OpFunction %3 None %36 +%34 = OpFunctionParameter %3 +%33 = OpLabel +%48 = OpVariable %49 Function %15 +%39 = OpAccessChain %37 %18 %38 +%41 = OpAccessChain %40 %21 %38 +%43 = OpAccessChain %42 %24 %38 +%44 = OpAccessChain %40 %27 %38 +%45 = OpAccessChain %42 %30 %38 +OpBranch %50 +%50 = OpLabel +%51 = OpFSub %3 %14 %46 +%52 = OpLoad %3 %48 +%53 = OpFAdd %3 %52 %51 +OpStore %48 %53 +%54 = OpLoad %3 %48 +%55 = OpFAdd %3 %54 %47 +%56 = OpLoad %3 %48 +%57 = OpFAdd %3 %56 %55 +OpStore %48 %57 +%59 = OpAccessChain %58 %39 %12 +%60 = OpLoad %6 %59 +%61 = OpLoad %3 %48 +%62 = OpFConvert %6 %61 +%63 = OpFAdd %6 %60 %62 +%64 = OpFConvert %3 %63 +%65 = OpLoad %3 %48 +%66 = OpFAdd %3 %65 %64 +OpStore %48 %66 +%69 = OpAccessChain %67 %39 %68 +%70 = OpLoad %3 %69 +%71 = OpCompositeConstruct %8 %70 %70 %70 +%72 = OpCompositeExtract %3 %71 2 +%73 = OpLoad %3 %48 +%74 = OpFAdd %3 %73 %72 +OpStore %48 %74 +%76 = OpAccessChain %67 %39 %68 +%77 = OpLoad %3 %76 +%78 = OpAccessChain %75 %41 %68 +%79 = OpLoad %3 %78 +%80 = OpFAdd %3 %77 %79 +%81 = OpAccessChain %75 %44 %68 +OpStore %81 %80 +%85 = OpAccessChain %83 %39 %84 +%86 = OpLoad %7 %85 +%87 = OpAccessChain %82 %41 %84 +%88 = OpLoad %7 %87 +%89 = OpFAdd %7 %86 %88 +%90 = OpAccessChain %82 %44 %84 +OpStore %90 %89 +%94 = OpAccessChain %92 %39 %93 +%95 = OpLoad %8 %94 +%96 = OpAccessChain %91 %41 %93 +%97 = OpLoad %8 %96 +%98 = OpFAdd %8 %95 %97 +%99 = OpAccessChain %91 %44 %93 +OpStore %99 %98 +%103 = OpAccessChain %101 %39 %102 +%104 = OpLoad %9 %103 +%105 = OpAccessChain %100 %41 %102 +%106 = OpLoad %9 %105 +%107 = OpFAdd %9 %104 %106 +%108 = OpAccessChain %100 %44 %102 +OpStore %108 %107 +%110 = OpAccessChain %109 %43 %38 +%111 = OpLoad %11 %110 +%112 = OpAccessChain %109 %45 %38 +OpStore %112 %111 +%113 = OpLoad %3 %48 +%114 = OpExtInst %3 %1 FAbs %113 +%115 = OpLoad %3 %48 +%116 = OpFAdd %3 %115 %114 +OpStore %48 %116 +%117 = OpLoad %3 %48 +%118 = OpLoad %3 %48 +%119 = OpLoad %3 %48 +%120 = OpExtInst %3 %1 FClamp %117 %118 %119 +%121 = OpLoad %3 %48 +%122 = OpFAdd %3 %121 %120 +OpStore %48 %122 +%123 = OpLoad %3 %48 +%124 = OpCompositeConstruct %7 %123 %123 +%125 = OpLoad %3 %48 +%126 = OpCompositeConstruct %7 %125 %125 +%127 = OpDot %3 %124 %126 +%128 = OpLoad %3 %48 +%129 = OpFAdd %3 %128 %127 +OpStore %48 %129 +%130 = OpLoad %3 %48 +%131 = OpLoad %3 %48 +%132 = OpExtInst %3 %1 FMax %130 %131 +%133 = OpLoad %3 %48 +%134 = OpFAdd %3 %133 %132 +OpStore %48 %134 +%135 = OpLoad %3 %48 +%136 = OpLoad %3 %48 +%137 = OpExtInst %3 %1 FMin %135 %136 +%138 = OpLoad %3 %48 +%139 = OpFAdd %3 %138 %137 +OpStore %48 %139 +%140 = OpLoad %3 %48 +%141 = OpExtInst %3 %1 FSign %140 +%142 = OpLoad %3 %48 +%143 = OpFAdd %3 %142 %141 +OpStore %48 %143 +OpReturnValue %14 +OpFunctionEnd +%145 = OpFunction %2 None %146 +%144 = OpLabel +%147 = OpAccessChain %37 %18 %38 +%148 = OpAccessChain %40 %21 %38 +%149 = OpAccessChain %42 %24 %38 +%150 = OpAccessChain %40 %27 %38 +%151 = OpAccessChain %42 %30 %38 +OpBranch %153 +%153 = OpLabel +%154 = OpFunctionCall %3 %35 %152 +%156 = OpAccessChain %75 %150 %155 +OpStore %156 %154 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/float16.wgsl b/naga/tests/out/wgsl/float16.wgsl new file mode 100644 index 0000000000..ac385c738a --- /dev/null +++ b/naga/tests/out/wgsl/float16.wgsl @@ -0,0 +1,91 @@ +struct UniformCompatible { + val_u32_: u32, + val_i32_: i32, + val_f32_: f32, + val_f16_: f16, + val_f16_2_: vec2, + val_f16_3_: vec3, + val_f16_4_: vec4, + final_value: f16, +} + +struct StorageCompatible { + val_f16_array_2_: array, + val_f16_array_2_1: array, +} + +const constant_variable: f16 = 15.203125h; + +var private_variable: f16 = 1h; +@group(0) @binding(0) +var input_uniform: UniformCompatible; +@group(0) @binding(1) +var input_storage: UniformCompatible; +@group(0) @binding(2) +var input_arrays: StorageCompatible; +@group(0) @binding(3) +var output: UniformCompatible; +@group(0) @binding(4) +var output_arrays: StorageCompatible; + +fn f16_function(x: f16) -> f16 { + var val: f16 = 15.203125h; + + let _e6 = val; + val = (_e6 + (1h - 33344h)); + let _e8 = val; + let _e11 = val; + val = (_e11 + (_e8 + 5h)); + let _e15 = input_uniform.val_f32_; + let _e16 = val; + let _e20 = val; + val = (_e20 + f16((_e15 + f32(_e16)))); + let _e24 = input_uniform.val_f16_; + let _e27 = val; + val = (_e27 + vec3(_e24).z); + let _e33 = input_uniform.val_f16_; + let _e36 = input_storage.val_f16_; + output.val_f16_ = (_e33 + _e36); + let _e42 = input_uniform.val_f16_2_; + let _e45 = input_storage.val_f16_2_; + output.val_f16_2_ = (_e42 + _e45); + let _e51 = input_uniform.val_f16_3_; + let _e54 = input_storage.val_f16_3_; + output.val_f16_3_ = (_e51 + _e54); + let _e60 = input_uniform.val_f16_4_; + let _e63 = input_storage.val_f16_4_; + output.val_f16_4_ = (_e60 + _e63); + let _e69 = input_arrays.val_f16_array_2_; + output_arrays.val_f16_array_2_ = _e69; + let _e70 = val; + let _e72 = val; + val = (_e72 + abs(_e70)); + let _e74 = val; + let _e75 = val; + let _e76 = val; + let _e78 = val; + val = (_e78 + clamp(_e74, _e75, _e76)); + let _e80 = val; + let _e82 = val; + let _e85 = val; + val = (_e85 + dot(vec2(_e80), vec2(_e82))); + let _e87 = val; + let _e88 = val; + let _e90 = val; + val = (_e90 + max(_e87, _e88)); + let _e92 = val; + let _e93 = val; + let _e95 = val; + val = (_e95 + min(_e92, _e93)); + let _e97 = val; + let _e99 = val; + val = (_e99 + sign(_e97)); + return 1h; +} + +@compute @workgroup_size(1, 1, 1) +fn main() { + let _e3 = f16_function(2h); + output.final_value = _e3; + return; +} diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index adf67f8333..25433901c4 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -902,6 +902,10 @@ fn convert_wgsl() { "int64", Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL, ), + ( + "float16", + Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL, + ), ( "subgroup-operations", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index 959f3cada7..2197a6ae6a 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -481,6 +481,10 @@ pub fn create_validator( features.contains(wgt::Features::PUSH_CONSTANTS), ); caps.set(Caps::FLOAT64, features.contains(wgt::Features::SHADER_F64)); + caps.set( + Caps::SHADER_FLOAT16, + features.contains(wgt::Features::SHADER_F16), + ); caps.set( Caps::PRIMITIVE_INDEX, features.contains(wgt::Features::SHADER_PRIMITIVE_INDEX), diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index 45d69f5584..c6c662308a 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -1,5 +1,5 @@ use std::{ - mem::{size_of, size_of_val}, + mem::{self, size_of, size_of_val}, ptr, sync::Arc, thread, @@ -378,6 +378,24 @@ impl super::Adapter { && features1.Int64ShaderOps.as_bool(), ); + let float16_supported = { + let mut features4: Direct3D12::D3D12_FEATURE_DATA_D3D12_OPTIONS4 = + unsafe { mem::zeroed() }; + let hr = unsafe { + device.CheckFeatureSupport( + Direct3D12::D3D12_FEATURE_D3D12_OPTIONS4, // https://learn.microsoft.com/en-us/windows/win32/api/d3d12/ne-d3d12-d3d12_feature#syntax + ptr::from_mut(&mut features4).cast(), + size_of::() as _, + ) + }; + hr.is_ok() && features4.Native16BitShaderOpsSupported.as_bool() + }; + + features.set( + wgt::Features::SHADER_F16, + shader_model >= naga::back::hlsl::ShaderModel::V6_2 && float16_supported, + ); + features.set( wgt::Features::SUBGROUP, shader_model >= naga::back::hlsl::ShaderModel::V6_0 diff --git a/wgpu-hal/src/dx12/types.rs b/wgpu-hal/src/dx12/types.rs index 5270c6ca8a..2df26265ea 100644 --- a/wgpu-hal/src/dx12/types.rs +++ b/wgpu-hal/src/dx12/types.rs @@ -37,3 +37,19 @@ pub struct ISwapChainPanelNative_Vtbl { swap_chain: *mut core::ffi::c_void, ) -> windows_core::HRESULT, } + +// winapi::ENUM! { +// enum D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER { +// D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER_0 = 0, +// // D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER_1, +// // D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER_2, +// } +// } +// +// winapi::STRUCT! { +// struct D3D12_FEATURE_DATA_D3D12_OPTIONS4 { +// MSAA64KBAlignedTextureSupported: winapi::shared::minwindef::BOOL, +// SharedResourceCompatibilityTier: D3D12_SHARED_RESOURCE_COMPATIBILITY_TIER, +// Native16BitShaderOpsSupported: winapi::shared::minwindef::BOOL, +// } +// } diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index ab6ae02c6f..e8f8058f95 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -382,6 +382,7 @@ impl PhysicalDeviceFeatures { vk::PhysicalDeviceShaderFloat16Int8Features::default().shader_float16(true), vk::PhysicalDevice16BitStorageFeatures::default() .storage_buffer16_bit_access(true) + .storage_input_output16(true) .uniform_and_storage_buffer16_bit_access(true), )) } else { @@ -664,7 +665,8 @@ impl PhysicalDeviceFeatures { F::SHADER_F16, f16_i8.shader_float16 != 0 && bit16.storage_buffer16_bit_access != 0 - && bit16.uniform_and_storage_buffer16_bit_access != 0, + && bit16.uniform_and_storage_buffer16_bit_access != 0 + && bit16.storage_input_output16 != 0, ); }