diff --git a/CHANGELOG.md b/CHANGELOG.md index f60dff7eb1..ca7131118f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,6 +102,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148] - Implement `quantizeToF16()` for WGSL frontend, and WGSL, SPIR-V, HLSL, MSL, and GLSL backends. By @jamienicol in [#6519](https://github.com/gfx-rs/wgpu/pull/6519). - Add support for GLSL `usampler*` and `isampler*`. By @DavidPeicho in [#6513](https://github.com/gfx-rs/wgpu/pull/6513). - Expose Ray Query flags as constants in WGSL. Implement candidate intersections. By @kvark in [#5429](https://github.com/gfx-rs/wgpu/pull/5429) +- Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635). #### General diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 0005cbcb0e..1a2dc17023 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -25,6 +25,8 @@ pub enum PipelineConstantError { ConstantEvaluatorError(#[from] ConstantEvaluatorError), #[error(transparent)] ValidationError(#[from] WithSpan), + #[error("workgroup_size was overridden to a negative value")] + NegativeWorkgroupSize, } /// Replace all overrides in `module` with constants. 
@@ -190,6 +192,7 @@ pub fn process_overrides<'a>( let mut entry_points = mem::take(&mut module.entry_points); for ep in entry_points.iter_mut() { process_function(&mut module, &override_map, &mut ep.function)?; + process_workgroup_size_override(&mut module, &override_map, ep)?; } module.entry_points = entry_points;
@@ -202,6 +205,51 @@ pub fn process_overrides<'a>(
     Ok((Cow::Owned(module), Cow::Owned(module_info)))
 }
 
+/// Resolve the overridable components of an entry point's `workgroup_size`.
+///
+/// For each component `i` that has a handle in `ep.workgroup_size_overrides`,
+/// look up the constant recorded for that override in `override_map`, read the
+/// literal its initializer refers to, and store it in `ep.workgroup_size[i]`.
+/// Components without a handle, and entry points whose
+/// `workgroup_size_overrides` is `None`, are left untouched.
+///
+/// # Errors
+///
+/// Returns [`PipelineConstantError::NegativeWorkgroupSize`] if a component
+/// resolves to a negative `i32`.
+///
+/// # Panics
+///
+/// Panics if a resolved initializer is anything other than a `U32` or `I32`
+/// literal.
+fn process_workgroup_size_override(
+    module: &mut Module,
+    override_map: &HandleVec<Override, Handle<Constant>>,
+    ep: &mut crate::EntryPoint,
+) -> Result<(), PipelineConstantError> {
+    if let Some(overrides) = ep.workgroup_size_overrides {
+        for (i, overridden) in overrides.iter().enumerate() {
+            if let Some(h) = *overridden {
+                let init = module.constants[override_map[h]].init;
+                ep.workgroup_size[i] = match module.global_expressions[init] {
+                    crate::Expression::Literal(crate::Literal::U32(size)) => size,
+                    crate::Expression::Literal(crate::Literal::I32(size)) => {
+                        // `u32::try_from` fails exactly when `size` is negative.
+                        u32::try_from(size)
+                            .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)?
+                    }
+                    // NOTE(review): the WGSL frontend only checks that the
+                    // expression is a `Sint`/`Uint` scalar, not its width, so a
+                    // 64-bit integer override may be able to reach this arm;
+                    // confirm.
+                    _ => unreachable!(),
+                };
+            }
+        }
+    }
+    Ok(())
+}
+
 /// Add a [`Constant`] to `module` for the override `old_h`.
 ///
 /// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
diff --git a/naga/src/front/glsl/functions.rs b/naga/src/front/glsl/functions.rs index 658632e872..394be22eaa 100644 --- a/naga/src/front/glsl/functions.rs +++ b/naga/src/front/glsl/functions.rs @@ -1366,6 +1366,7 @@ impl Frontend { early_depth_test: Some(crate::EarlyDepthTest { conservative: None }) .filter(|_| self.meta.early_fragment_tests), workgroup_size: self.meta.workgroup_size, + workgroup_size_overrides: None, function: Function { arguments, expressions, diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index 7122e44837..271b96926b 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -569,6 +569,7 @@ impl> super::Frontend { stage: ep.stage, early_depth_test: ep.early_depth_test, workgroup_size: ep.workgroup_size, + workgroup_size_overrides: None, function, }); diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 5173e73d79..7de6fc945b 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1311,24 +1311,54 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .collect(); if let Some(ref entry) = f.entry_point { - let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size { + let workgroup_size_info = if let Some(workgroup_size) = entry.workgroup_size { // TODO: replace with try_map once stabilized let mut workgroup_size_out = [1; 3]; + let mut workgroup_size_overrides_out = [None; 3]; for (i, size) in workgroup_size.into_iter().enumerate() { if let Some(size_expr) = size { - workgroup_size_out[i] = self.const_u32(size_expr, &mut ctx.as_const())?.0; + match self.const_u32(size_expr, &mut ctx.as_const()) { + Ok(value) => { + workgroup_size_out[i] = value.0; + } + err => { + if let Err(Error::ConstantEvaluatorError(ref ty, _)) = err { + match **ty { + crate::proc::ConstantEvaluatorError::OverrideExpr => { + workgroup_size_overrides_out[i] = + Some(self.workgroup_size_override( + size_expr, + &mut ctx.as_override(), + i, + )?); 
+ } + _ => { + err?; + } + } + } else { + err?; + } + } + } } } - workgroup_size_out + if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { + (workgroup_size_out, None) + } else { + (workgroup_size_out, Some(workgroup_size_overrides_out)) + } } else { - [0; 3] + ([0; 3], None) }; + let (workgroup_size, workgroup_size_overrides) = workgroup_size_info; ctx.module.entry_points.push(crate::EntryPoint { name: f.name.name.to_string(), stage: entry.stage, early_depth_test: entry.early_depth_test, workgroup_size, + workgroup_size_overrides, function, }); Ok(LoweredGlobalDecl::EntryPoint)
@@ -1338,6 +1368,44 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
         }
     }
 
+    /// Lower one `workgroup_size` component as an override-expression.
+    ///
+    /// Lowers `size_expr` in `ctx`, registers its type, and, if that type's
+    /// scalar kind is `Sint` or `Uint`, appends a new override named
+    /// `__workgroup_size_{i}` whose initializer is the lowered expression.
+    /// Returns the handle of the new override.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`Error::ExpectedConstExprConcreteIntegerScalar`] if the
+    /// expression's type has any other scalar kind (or none), and propagates
+    /// errors from lowering the expression or registering its type.
+    fn workgroup_size_override(
+        &mut self,
+        size_expr: Handle<ast::Expression<'source>>,
+        ctx: &mut ExpressionContext<'source, '_, '_>,
+        i: usize,
+    ) -> Result<Handle<crate::Override>, Error<'source>> {
+        let span = ctx.ast_expressions.get_span(size_expr);
+        let expr = self.expression(size_expr, ctx)?;
+        let ty = ctx.register_type(expr)?;
+        match ctx.module.types[ty].inner.scalar_kind() {
+            Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+                // NOTE(review): the name depends only on the component index,
+                // so several entry points with overridden sizes get overrides
+                // with identical names; confirm that this is harmless.
+                let size_override = crate::Override {
+                    name: Some(format!("__workgroup_size_{}", i)),
+                    id: None,
+                    ty,
+                    init: Some(expr),
+                };
+                Ok(ctx.module.overrides.append(size_override, span))
+            }
+            _ => Err(Error::ExpectedConstExprConcreteIntegerScalar(span)),
+        }
+    }
+
     fn block(
         &mut self,
         b: &ast::Block<'source>,
diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 4afbfff9d7..1c1929efa2 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -2186,6 +2186,8 @@ pub struct EntryPoint { pub early_depth_test: Option, /// Workgroup size for compute stages pub workgroup_size: [u32; 3], + /// Override expressions for workgroup size + pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function.
pub function: Function, } diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 974080e998..e314078c2b 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -1854,6 +1854,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_vert"), arguments: [ @@ -2156,6 +2157,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_frag"), arguments: [], @@ -2348,6 +2350,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_through_ptr"), arguments: [], @@ -2430,6 +2433,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_to_ptr_components"), arguments: [], diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 974080e998..e314078c2b 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -1854,6 +1854,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_vert"), arguments: [ @@ -2156,6 +2157,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("foo_frag"), arguments: [], @@ -2348,6 +2350,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_through_ptr"), arguments: [], @@ -2430,6 +2433,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("assign_to_ptr_components"), arguments: [], diff --git a/naga/tests/out/ir/atomic_i_increment.compact.ron b/naga/tests/out/ir/atomic_i_increment.compact.ron index 7d024f4e81..5bb6820258 
100644 --- a/naga/tests/out/ir/atomic_i_increment.compact.ron +++ b/naga/tests/out/ir/atomic_i_increment.compact.ron @@ -263,6 +263,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("stage::test_atomic_i_increment_wrap"), arguments: [], diff --git a/naga/tests/out/ir/atomic_i_increment.ron b/naga/tests/out/ir/atomic_i_increment.ron index aab4c07206..ae14821330 100644 --- a/naga/tests/out/ir/atomic_i_increment.ron +++ b/naga/tests/out/ir/atomic_i_increment.ron @@ -288,6 +288,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("stage::test_atomic_i_increment_wrap"), arguments: [], diff --git a/naga/tests/out/ir/collatz.compact.ron b/naga/tests/out/ir/collatz.compact.ron index 48ce8e76bc..6a7aebe544 100644 --- a/naga/tests/out/ir/collatz.compact.ron +++ b/naga/tests/out/ir/collatz.compact.ron @@ -257,6 +257,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [ diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index 48ce8e76bc..6a7aebe544 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -257,6 +257,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [ diff --git a/naga/tests/out/ir/fetch_depth.compact.ron b/naga/tests/out/ir/fetch_depth.compact.ron index 0d998e205c..f10ccb94f7 100644 --- a/naga/tests/out/ir/fetch_depth.compact.ron +++ b/naga/tests/out/ir/fetch_depth.compact.ron @@ -176,6 +176,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("cull::fetch_depth_wrap"), arguments: [], diff --git a/naga/tests/out/ir/fetch_depth.ron b/naga/tests/out/ir/fetch_depth.ron index 
c66b7eb065..d25e046d57 100644 --- a/naga/tests/out/ir/fetch_depth.ron +++ b/naga/tests/out/ir/fetch_depth.ron @@ -246,6 +246,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (32, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("cull::fetch_depth_wrap"), arguments: [], diff --git a/naga/tests/out/ir/index-by-value.compact.ron b/naga/tests/out/ir/index-by-value.compact.ron index f0ea76f496..93a9821426 100644 --- a/naga/tests/out/ir/index-by-value.compact.ron +++ b/naga/tests/out/ir/index-by-value.compact.ron @@ -300,6 +300,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("index_let_array_1d"), arguments: [ diff --git a/naga/tests/out/ir/index-by-value.ron b/naga/tests/out/ir/index-by-value.ron index f0ea76f496..93a9821426 100644 --- a/naga/tests/out/ir/index-by-value.ron +++ b/naga/tests/out/ir/index-by-value.ron @@ -300,6 +300,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("index_let_array_1d"), arguments: [ diff --git a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron index e762de0385..56be2f8ab6 100644 --- a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron +++ b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron @@ -85,6 +85,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("f"), arguments: [], diff --git a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron index e762de0385..56be2f8ab6 100644 --- a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron +++ b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron @@ -85,6 +85,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + 
workgroup_size_overrides: None, function: ( name: Some("f"), arguments: [], diff --git a/naga/tests/out/ir/overrides-ray-query.compact.ron b/naga/tests/out/ir/overrides-ray-query.compact.ron index f7d05aa92f..10cad83538 100644 --- a/naga/tests/out/ir/overrides-ray-query.compact.ron +++ b/naga/tests/out/ir/overrides-ray-query.compact.ron @@ -111,6 +111,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/overrides-ray-query.ron b/naga/tests/out/ir/overrides-ray-query.ron index f7d05aa92f..10cad83538 100644 --- a/naga/tests/out/ir/overrides-ray-query.ron +++ b/naga/tests/out/ir/overrides-ray-query.ron @@ -111,6 +111,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index d2df01c0db..d99beb19c6 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -108,6 +108,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index d2df01c0db..d99beb19c6 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -108,6 +108,7 @@ stage: Compute, early_depth_test: None, workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, function: ( name: Some("main"), arguments: [], diff --git a/naga/tests/out/ir/shadow.compact.ron b/naga/tests/out/ir/shadow.compact.ron index 39a25fd10b..24b4674515 100644 --- a/naga/tests/out/ir/shadow.compact.ron +++ b/naga/tests/out/ir/shadow.compact.ron @@ -958,6 +958,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: 
( name: Some("fs_main_wrap"), arguments: [ diff --git a/naga/tests/out/ir/shadow.ron b/naga/tests/out/ir/shadow.ron index 196536d56b..386b9d36b0 100644 --- a/naga/tests/out/ir/shadow.ron +++ b/naga/tests/out/ir/shadow.ron @@ -1236,6 +1236,7 @@ stage: Fragment, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("fs_main_wrap"), arguments: [ diff --git a/naga/tests/out/ir/spec-constants.compact.ron b/naga/tests/out/ir/spec-constants.compact.ron index 9ea75cd468..cde3117225 100644 --- a/naga/tests/out/ir/spec-constants.compact.ron +++ b/naga/tests/out/ir/spec-constants.compact.ron @@ -495,6 +495,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("main_wrap"), arguments: [ diff --git a/naga/tests/out/ir/spec-constants.ron b/naga/tests/out/ir/spec-constants.ron index 5d48e94efc..fa4139a1da 100644 --- a/naga/tests/out/ir/spec-constants.ron +++ b/naga/tests/out/ir/spec-constants.ron @@ -601,6 +601,7 @@ stage: Vertex, early_depth_test: None, workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, function: ( name: Some("main_wrap"), arguments: [ diff --git a/tests/tests/shader/mod.rs b/tests/tests/shader/mod.rs index 7d6ed7aaaa..f05fbac25c 100644 --- a/tests/tests/shader/mod.rs +++ b/tests/tests/shader/mod.rs @@ -19,6 +19,7 @@ pub mod compilation_messages; pub mod data_builtins; pub mod numeric_builtins; pub mod struct_layout; +pub mod workgroup_size_overrides; pub mod zero_init_workgroup_mem; #[derive(Clone, Copy, PartialEq)]
diff --git a/tests/tests/shader/workgroup_size_overrides.rs b/tests/tests/shader/workgroup_size_overrides.rs
new file mode 100644
index 0000000000..458011079a
--- /dev/null
+++ b/tests/tests/shader/workgroup_size_overrides.rs
@@ -0,0 +1,134 @@
+use std::mem::size_of_val;
+use wgpu::util::DeviceExt;
+use wgpu::{BufferDescriptor, BufferUsages, Maintain, MapMode};
+use wgpu_test::{fail_if, gpu_test, GpuTestConfiguration, TestParameters, TestingContext};
+
+/// Compute shader whose workgroup size is the override-expression `n - 2`.
+///
+/// Each invocation writes `local_invocation_index + 2` into its own slot of
+/// `output`, so a single workgroup fills as many slots as its resolved size.
+const SHADER: &str = r#"
+    override n = 3;
+
+    @group(0) @binding(0)
+    var<storage, read_write> output: array<u32>;
+
+    @compute @workgroup_size(n - 2)
+    fn main(@builtin(local_invocation_index) lii: u32) {
+        output[lii] = lii + 2;
+    }
+"#;
+
+/// Runs the shader with the default `n`, with `n = 4`, and with `n = 1`; the
+/// last case makes `n - 2` negative and is expected to fail.
+#[gpu_test]
+static WORKGROUP_SIZE_OVERRIDES: GpuTestConfiguration = GpuTestConfiguration::new()
+    .parameters(TestParameters::default().limits(wgpu::Limits::default()))
+    .run_async(move |ctx| async move {
+        workgroup_size_overrides(&ctx, None, &[2, 0, 0], false).await;
+        workgroup_size_overrides(&ctx, Some(4), &[2, 3, 0], false).await;
+        workgroup_size_overrides(&ctx, Some(1), &[0, 0, 0], true).await;
+    });
+
+/// Create a compute pipeline from [`SHADER`], run one workgroup, and compare
+/// the resulting buffer contents with `out`.
+///
+/// `n`, when `Some`, is supplied as the pipeline-overridable constant `"n"`;
+/// when `None`, default compilation options are used. Pipeline creation is
+/// wrapped in `fail_if` with `should_fail`; when `should_fail` is true the
+/// function returns right after it without dispatching anything.
+// NOTE(review): the element type of `n` was lost from this copy of the patch;
+// `u32` is assumed (it has to convert into the constants map's value type).
+async fn workgroup_size_overrides(
+    ctx: &TestingContext,
+    n: Option<u32>,
+    out: &[u32],
+    should_fail: bool,
+) {
+    let module = ctx
+        .device
+        .create_shader_module(wgpu::ShaderModuleDescriptor {
+            label: None,
+            source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(SHADER)),
+        });
+    // Only used when `n` is `Some`; the `0` fallback is a placeholder that is
+    // never passed to pipeline creation.
+    let pipeline_options = wgpu::PipelineCompilationOptions {
+        constants: &[("n".to_owned(), n.unwrap_or(0).into())].into(),
+        ..Default::default()
+    };
+    let compute_pipeline = fail_if(
+        &ctx.device,
+        should_fail,
+        || {
+            ctx.device
+                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
+                    label: None,
+                    layout: None,
+                    module: &module,
+                    entry_point: Some("main"),
+                    compilation_options: if n.is_some() {
+                        pipeline_options
+                    } else {
+                        wgpu::PipelineCompilationOptions::default()
+                    },
+                    cache: None,
+                })
+        },
+        None,
+    );
+    if should_fail {
+        return;
+    }
+    // Zero-initialised storage buffer that the shader writes into.
+    let init: &[u32] = &[0, 0, 0];
+    let init_size: u64 = size_of_val(init).try_into().unwrap();
+    let buffer = DeviceExt::create_buffer_init(
+        &ctx.device,
+        &wgpu::util::BufferInitDescriptor {
+            label: None,
+            contents: bytemuck::cast_slice(init),
+            usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
+        },
+    );
+    // Mappable buffer the results are copied into for reading on the CPU.
+    let mapping_buffer = ctx.device.create_buffer(&BufferDescriptor {
+        label: Some("mapping buffer"),
+        size: init_size,
+        usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ,
+        mapped_at_creation: false,
+    });
+    let mut encoder = ctx
+        .device
+        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
+    {
+        // Scope the pass so that it ends before the encoder is used again.
+        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
+            label: None,
+            timestamp_writes: None,
+        });
+        cpass.set_pipeline(&compute_pipeline);
+        let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
+        let bind_group_entries = [wgpu::BindGroupEntry {
+            binding: 0,
+            resource: buffer.as_entire_binding(),
+        }];
+        let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
+            label: None,
+            layout: &bind_group_layout,
+            entries: &bind_group_entries,
+        });
+        cpass.set_bind_group(0, &bind_group, &[]);
+        cpass.dispatch_workgroups(1, 1, 1);
+    }
+    encoder.copy_buffer_to_buffer(&buffer, 0, &mapping_buffer, 0, init_size);
+    ctx.queue.submit(Some(encoder.finish()));
+
+    mapping_buffer.slice(..).map_async(MapMode::Read, |_| ());
+    ctx.async_poll(Maintain::wait()).await.panic_on_timeout();
+
+    let mapped = mapping_buffer.slice(..).get_mapped_range();
+
+    let typed: &[u32] = bytemuck::cast_slice(&mapped);
+    assert_eq!(typed, out);
+}