From 9afaf534071729f2192d37698dd9d7fd904f70c4 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 8 Oct 2024 13:07:02 -0700 Subject: [PATCH 1/2] [naga spv-out] Rename `make_local` to `LocalType::from_inner`. Change the free function `back::spv::make_local` into an associated function `LocalType::from_inner`. --- naga/src/back/spv/block.rs | 8 +-- naga/src/back/spv/mod.rs | 102 +++++++++++++++++++----------------- naga/src/back/spv/writer.rs | 14 ++--- 3 files changed, 66 insertions(+), 58 deletions(-) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 496ae4fad8..4f203ada84 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3,8 +3,8 @@ Implementations for `BlockContext` methods. */ use super::{ - helpers, index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, - Dimension, Error, Instruction, LocalType, LookupType, ResultMember, Writer, WriterFlags, + helpers, index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error, + Instruction, LocalType, LookupType, ResultMember, Writer, WriterFlags, }; use crate::{arena::Handle, proc::TypeResolution, Statement}; use spirv::Word; @@ -1809,7 +1809,9 @@ impl<'w> BlockContext<'w> { Some(ty) => ty, None => LookupType::Handle(ty_handle), }, - TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()), + TypeResolution::Value(ref inner) => { + LookupType::Local(LocalType::from_inner(inner).unwrap()) + } }; let result_type_id = self.get_type_id(result_lookup_ty); diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index e6397017c5..93e9a466c4 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -246,9 +246,9 @@ impl LocalImageType { /// never synthesizes new struct types, so `LocalType` has nothing for that. /// /// Each `LocalType` variant should be handled identically to its analogous -/// `TypeInner` variant. You can use the [`make_local`] function to help with -/// this, by converting everything possible to a `LocalType` before inspecting -/// it. +/// `TypeInner` variant. You can use the [`LocalType::from_inner`] function to +/// help with this, by converting everything possible to a `LocalType` before +/// inspecting it. /// /// ## `LocalType` equality and SPIR-V `OpType` uniqueness /// @@ -357,52 +357,56 @@ struct LookupFunctionType { return_type_id: Word, } -fn make_local(inner: &crate::TypeInner) -> Option { - Some(match *inner { - crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => LocalType::Value { - vector_size: None, - scalar, - pointer_space: None, - }, - crate::TypeInner::Vector { size, scalar } => LocalType::Value { - vector_size: Some(size), - scalar, - pointer_space: None, - }, - crate::TypeInner::Matrix { - columns, - rows, - scalar, - } => LocalType::Matrix { - columns, - rows, - width: scalar.width, - }, - crate::TypeInner::Pointer { base, space } => LocalType::Pointer { - base, - class: helpers::map_storage_class(space), - }, - crate::TypeInner::ValuePointer { - size, - scalar, - space, - } => LocalType::Value { - vector_size: size, - scalar, - pointer_space: Some(helpers::map_storage_class(space)), - }, - crate::TypeInner::Image { - dim, - arrayed, - class, - } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)), - crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler, - crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure, - crate::TypeInner::RayQuery => LocalType::RayQuery, - crate::TypeInner::Array { .. } - | crate::TypeInner::Struct { .. } - | crate::TypeInner::BindingArray { .. } => return None, - }) +impl LocalType { + fn from_inner(inner: &crate::TypeInner) -> Option { + Some(match *inner { + crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => { + LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + } + } + crate::TypeInner::Vector { size, scalar } => LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + }, + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => LocalType::Matrix { + columns, + rows, + width: scalar.width, + }, + crate::TypeInner::Pointer { base, space } => LocalType::Pointer { + base, + class: helpers::map_storage_class(space), + }, + crate::TypeInner::ValuePointer { + size, + scalar, + space, + } => LocalType::Value { + vector_size: size, + scalar, + pointer_space: Some(helpers::map_storage_class(space)), + }, + crate::TypeInner::Image { + dim, + arrayed, + class, + } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)), + crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler, + crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure, + crate::TypeInner::RayQuery => LocalType::RayQuery, + crate::TypeInner::Array { .. } + | crate::TypeInner::Struct { .. } + | crate::TypeInner::BindingArray { .. } => return None, + }) + } } #[derive(Debug)] diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index cd02c22195..27f2cbfdb6 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1,10 +1,10 @@ use super::{ block::DebugInfoInner, helpers::{contains_builtin, global_needs_wrapper, map_storage_class}, - make_local, Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, - EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, - LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, Options, - PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE, + Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, EntryPointContext, Error, + Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalType, LocalVariable, + LogicalLayout, LookupFunctionType, LookupType, Options, PhysicalLayout, PipelineOptions, + ResultMember, Writer, WriterFlags, BITS_PER_BYTE, }; use crate::{ arena::{Handle, HandleVec, UniqueArena}, @@ -254,7 +254,9 @@ impl Writer { pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType { match *tr { TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle), - TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()), + TypeResolution::Value(ref inner) => { + LookupType::Local(LocalType::from_inner(inner).unwrap()) + } } } @@ -1025,7 +1027,7 @@ impl Writer { // because some types which map to the same LocalType have different // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569 self.request_type_capabilities(&ty.inner)?; - let id = if let Some(local) = make_local(&ty.inner) { + let id = if let Some(local) = LocalType::from_inner(&ty.inner) { // This type can be represented as a `LocalType`, so check if we've // already written an instruction for it. If not, do so now, with // `write_type_declaration_local`. From 3c926cf88096d395e36267fba13f013d2aab8f64 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 8 Oct 2024 14:50:24 -0700 Subject: [PATCH 2/2] [naga spv-out] Expand LocalType to permit pointers to matrices. In `back::spv`: - Factor out the numeric variants of `LocalType` into a new enum, `NumericType`. - Split the `Value` variant into `Numeric` and `LocalPointer` variants, and let `LocalPointer` point to any numeric type, including matrices. In subsequent commits, we'll need to spill matrices out into temporary local variables. This means we'll need to generate SPIR-V pointer-to-matrix types, so `LocalType` needs to be able to represent that. --- naga/src/back/spv/block.rs | 181 ++++++++++++---------------------- naga/src/back/spv/image.rs | 97 +++++++++--------- naga/src/back/spv/mod.rs | 68 +++++++------ naga/src/back/spv/ray.rs | 73 ++++++-------- naga/src/back/spv/subgroup.rs | 12 +-- naga/src/back/spv/writer.rs | 169 ++++++++++++------------------- 6 files changed, 251 insertions(+), 349 deletions(-) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 4f203ada84..d7c56d3f4c 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -4,7 +4,7 @@ Implementations for `BlockContext` methods. use super::{ helpers, index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error, - Instruction, LocalType, LookupType, ResultMember, Writer, WriterFlags, + Instruction, LocalType, LookupType, NumericType, ResultMember, Writer, WriterFlags, }; use crate::{arena::Handle, proc::TypeResolution, Statement}; use spirv::Word; @@ -105,10 +105,9 @@ impl Writer { position_id: Word, body: &mut Vec, ) -> Result<(), Error> { - let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::F32, - pointer_space: Some(spirv::StorageClass::Output), + let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::LocalPointer { + base: NumericType::Scalar(crate::Scalar::F32), + class: spirv::StorageClass::Output, })); let index_y_id = self.get_index_constant(1); let access_id = self.id_gen.next(); @@ -119,11 +118,9 @@ impl Writer { &[index_y_id], )); - let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::F32, - pointer_space: None, - })); + let float_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::F32), + ))); let load_id = self.id_gen.next(); body.push(Instruction::load(float_type_id, load_id, access_id, None)); @@ -145,11 +142,9 @@ impl Writer { frag_depth_id: Word, body: &mut Vec, ) -> Result<(), Error> { - let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::F32, - pointer_space: None, - })); + let float_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::F32), + ))); let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0)); let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0)); @@ -830,12 +825,8 @@ impl<'w> BlockContext<'w> { let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?; if let Some(size) = maybe_size { - let ty = LocalType::Value { - vector_size: Some(size), - scalar, - pointer_space: None, - } - .into(); + let ty = + LocalType::Numeric(NumericType::Vector { size, scalar }).into(); self.temp_list.clear(); self.temp_list.resize(size as _, arg1_id); @@ -950,12 +941,9 @@ impl<'w> BlockContext<'w> { &crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar(scalar), ) => { - let selector_type_id = - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(size), - scalar, - pointer_space: None, - })); + let selector_type_id = self.get_type_id(LookupType::Local( + LocalType::Numeric(NumericType::Vector { size, scalar }), + )); self.temp_list.clear(); self.temp_list.resize(size as usize, arg2_id); @@ -998,12 +986,8 @@ impl<'w> BlockContext<'w> { Mf::CountTrailingZeros => { let uint_id = match *arg_ty { crate::TypeInner::Vector { size, scalar } => { - let ty = LocalType::Value { - vector_size: Some(size), - scalar, - pointer_space: None, - } - .into(); + let ty = + LocalType::Numeric(NumericType::Vector { size, scalar }).into(); self.temp_list.clear(); self.temp_list.resize( @@ -1040,12 +1024,8 @@ impl<'w> BlockContext<'w> { Mf::CountLeadingZeros => { let (int_type_id, int_id, width) = match *arg_ty { crate::TypeInner::Vector { size, scalar } => { - let ty = LocalType::Value { - vector_size: Some(size), - scalar, - pointer_space: None, - } - .into(); + let ty = + LocalType::Numeric(NumericType::Vector { size, scalar }).into(); self.temp_list.clear(); self.temp_list.resize( @@ -1061,11 +1041,9 @@ impl<'w> BlockContext<'w> { ) } crate::TypeInner::Scalar(scalar) => ( - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar, - pointer_space: None, - })), + self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(scalar), + ))), self.writer .get_constant_scalar_with(scalar.width * 8 - 1, scalar)?, scalar.width, @@ -1130,14 +1108,9 @@ impl<'w> BlockContext<'w> { .writer .get_constant_scalar(crate::Literal::U32(bit_width as u32)); - let u32_type = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar { - kind: crate::ScalarKind::Uint, - width: 4, - }, - pointer_space: None, - })); + let u32_type = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::U32), + ))); // o = min(offset, w) let offset_id = self.gen_id(); @@ -1186,14 +1159,9 @@ impl<'w> BlockContext<'w> { .writer .get_constant_scalar(crate::Literal::U32(bit_width as u32)); - let u32_type = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar { - kind: crate::ScalarKind::Uint, - width: 4, - }, - pointer_space: None, - })); + let u32_type = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::U32), + ))); // o = min(offset, w) let offset_id = self.gen_id(); @@ -1259,23 +1227,16 @@ impl<'w> BlockContext<'w> { Mf::Pack4xU8 => (crate::ScalarKind::Uint, false), _ => unreachable!(), }; - let uint_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar { - kind: crate::ScalarKind::Uint, - width: 4, - }, - pointer_space: None, - })); + let uint_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::U32), + ))); - let int_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar { + let int_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar { kind: int_type, width: 4, - }, - pointer_space: None, - })); + }), + ))); let mut last_instruction = Instruction::new(spirv::Op::Nop); @@ -1352,24 +1313,17 @@ impl<'w> BlockContext<'w> { _ => unreachable!(), }; - let sint_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar { - kind: crate::ScalarKind::Sint, - width: 4, - }, - pointer_space: None, - })); + let sint_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::I32), + ))); let eight = self.writer.get_constant_scalar(crate::Literal::U32(8)); - let int_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar { + let int_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar { kind: int_type, width: 4, - }, - pointer_space: None, - })); + }), + ))); block .body .reserve(usize::from(VEC_LENGTH) * 2 + usize::from(is_signed)); @@ -1533,11 +1487,10 @@ impl<'w> BlockContext<'w> { self.writer.get_constant_scalar_with(0, src_scalar)?; let zero_id = match src_size { Some(size) => { - let ty = LocalType::Value { - vector_size: Some(size), + let ty = LocalType::Numeric(NumericType::Vector { + size, scalar: src_scalar, - pointer_space: None, - } + }) .into(); self.temp_list.clear(); @@ -1562,11 +1515,10 @@ impl<'w> BlockContext<'w> { self.writer.get_constant_scalar_with(1, dst_scalar)?; let (accept_id, reject_id) = match src_size { Some(size) => { - let ty = LocalType::Value { - vector_size: Some(size), + let ty = LocalType::Numeric(NumericType::Vector { + size, scalar: dst_scalar, - pointer_space: None, - } + }) .into(); self.temp_list.clear(); @@ -1704,12 +1656,12 @@ impl<'w> BlockContext<'w> { self.temp_list.clear(); self.temp_list.resize(size as usize, condition_id); - let bool_vector_type_id = - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(size), + let bool_vector_type_id = self.get_type_id(LookupType::Local( + LocalType::Numeric(NumericType::Vector { + size, scalar: condition_scalar, - pointer_space: None, - })); + }), + )); let id = self.gen_id(); block.body.push(Instruction::composite_construct( @@ -2031,11 +1983,11 @@ impl<'w> BlockContext<'w> { ) { self.temp_list.clear(); - let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(rows), - scalar: crate::Scalar::float(width), - pointer_space: None, - })); + let vector_type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { + size: rows, + scalar: crate::Scalar::float(width), + }))); for index in 0..columns as u32 { let column_id_left = self.gen_id(); @@ -2737,20 +2689,15 @@ impl<'w> BlockContext<'w> { crate::AtomicFunction::Exchange { compare: Some(cmp) } => { let scalar_type_id = match *value_inner { crate::TypeInner::Scalar(scalar) => { - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar, - pointer_space: None, - })) + self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(scalar), + ))) } _ => unimplemented!(), }; - let bool_type_id = - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::BOOL, - pointer_space: None, - })); + let bool_type_id = self.get_type_id(LookupType::Local( + LocalType::Numeric(NumericType::Scalar(crate::Scalar::BOOL)), + )); let cas_result_id = self.gen_id(); let equality_result_id = self.gen_id(); diff --git a/naga/src/back/spv/image.rs b/naga/src/back/spv/image.rs index 769971d136..a76d015f3f 100644 --- a/naga/src/back/spv/image.rs +++ b/naga/src/back/spv/image.rs @@ -4,7 +4,7 @@ Generating SPIR-V for image operations. use super::{ selection::{MergeTuple, Selection}, - Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType, + Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType, NumericType, }; use crate::arena::Handle; use spirv::Word; @@ -126,11 +126,10 @@ impl Load { // the right SPIR-V type for the access instruction here. let type_id = match image_class { crate::ImageClass::Depth { .. } => { - ctx.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Quad), + ctx.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { + size: crate::VectorSize::Quad, scalar: crate::Scalar::F32, - pointer_space: None, - })) + }))) } _ => result_type_id, }; @@ -292,15 +291,15 @@ impl<'w> BlockContext<'w> { // Find the component type of `coordinates`, and figure out the size the // combined coordinate vector will have. let (component_scalar, size) = match *inner_ty { - Ti::Scalar(scalar @ crate::Scalar { width: 4, .. }) => (scalar, Some(Vs::Bi)), + Ti::Scalar(scalar @ crate::Scalar { width: 4, .. }) => (scalar, Vs::Bi), Ti::Vector { scalar: scalar @ crate::Scalar { width: 4, .. }, size: Vs::Bi, - } => (scalar, Some(Vs::Tri)), + } => (scalar, Vs::Tri), Ti::Vector { scalar: scalar @ crate::Scalar { width: 4, .. }, size: Vs::Tri, - } => (scalar, Some(Vs::Quad)), + } => (scalar, Vs::Quad), Ti::Vector { size: Vs::Quad, .. } => { return Err(Error::Validation("extending vec4 coordinate")); } @@ -340,11 +339,9 @@ impl<'w> BlockContext<'w> { } }; let reconciled_array_index_id = if let Some(cast) = cast { - let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: component_scalar, - pointer_space: None, - })); + let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(component_scalar), + ))); let reconciled_id = self.gen_id(); block.body.push(Instruction::unary( cast, @@ -358,11 +355,11 @@ impl<'w> BlockContext<'w> { }; // Find the SPIR-V type for the combined coordinates/index vector. - let type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: size, - scalar: component_scalar, - pointer_space: None, - })); + let type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { + size, + scalar: component_scalar, + }))); // Schmear the coordinates and index together. let value_id = self.gen_id(); @@ -374,7 +371,7 @@ impl<'w> BlockContext<'w> { Ok(ImageCoordinates { value_id, type_id, - size, + size: Some(size), }) } @@ -529,11 +526,9 @@ impl<'w> BlockContext<'w> { &[spirv::Capability::ImageQuery], )?; - let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::I32, - pointer_space: None, - })); + let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::I32), + ))); // If `level` is `Some`, clamp it to fall within bounds. This must // happen first, because we'll use it to query the image size for @@ -616,11 +611,9 @@ impl<'w> BlockContext<'w> { )?; let bool_type_id = self.writer.get_bool_type_id(); - let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::I32, - pointer_space: None, - })); + let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::I32), + ))); let null_id = access.out_of_bounds_value(self); @@ -683,11 +676,15 @@ impl<'w> BlockContext<'w> { ); // Compare the coordinates against the bounds. - let coords_bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: coordinates.size, - scalar: crate::Scalar::BOOL, - pointer_space: None, - })); + let coords_numeric_type = match coordinates.size { + Some(size) => NumericType::Vector { + size, + scalar: crate::Scalar::BOOL, + }, + None => NumericType::Scalar(crate::Scalar::BOOL), + }; + let coords_bool_type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(coords_numeric_type))); let coords_conds_id = self.gen_id(); selection.block().body.push(Instruction::binary( spirv::Op::ULessThan, @@ -838,11 +835,10 @@ impl<'w> BlockContext<'w> { _ => false, }; let sample_result_type_id = if needs_sub_access { - self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Quad), + self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { + size: crate::VectorSize::Quad, scalar: crate::Scalar::F32, - pointer_space: None, - })) + }))) } else { result_type_id }; @@ -1038,11 +1034,16 @@ impl<'w> BlockContext<'w> { 4 => Some(crate::VectorSize::Quad), _ => None, }; - let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size, - scalar: crate::Scalar::U32, - pointer_space: None, - })); + let vector_numeric_type = match vector_size { + Some(size) => NumericType::Vector { + size, + scalar: crate::Scalar::U32, + }, + None => NumericType::Scalar(crate::Scalar::U32), + }; + + let extended_size_type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(vector_numeric_type))); let (query_op, level_id) = match class { Ic::Sampled { multi: true, .. } @@ -1108,11 +1109,11 @@ impl<'w> BlockContext<'w> { Id::D2 | Id::Cube => crate::VectorSize::Tri, Id::D3 => crate::VectorSize::Quad, }; - let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(vec_size), - scalar: crate::Scalar::U32, - pointer_space: None, - })); + let extended_size_type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { + size: vec_size, + scalar: crate::Scalar::U32, + }))); let id_extended = self.gen_id(); let mut inst = Instruction::image_query( spirv::Op::ImageQuerySizeLod, diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 93e9a466c4..aa4de68462 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -231,6 +231,21 @@ impl LocalImageType { } } +/// A numeric type, for use in [`LocalType`]. +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +enum NumericType { + Scalar(crate::Scalar), + Vector { + size: crate::VectorSize, + scalar: crate::Scalar, + }, + Matrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + scalar: crate::Scalar, + }, +} + /// A SPIR-V type constructed during code generation. /// /// This is the variant of [`LookupType`] used to represent types that might not @@ -276,19 +291,11 @@ impl LocalImageType { /// [`TypeInner`]: crate::TypeInner #[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] enum LocalType { - /// A scalar, vector, or pointer to one of those. - Value { - /// If `None`, this represents a scalar type. If `Some`, this represents - /// a vector type of the given size. - vector_size: Option, - scalar: crate::Scalar, - pointer_space: Option, - }, - /// A matrix of floating-point values. - Matrix { - columns: crate::VectorSize, - rows: crate::VectorSize, - width: crate::Bytes, + /// A numeric type. + Numeric(NumericType), + LocalPointer { + base: NumericType, + class: spirv::StorageClass, }, Pointer { base: Handle, @@ -361,38 +368,39 @@ impl LocalType { fn from_inner(inner: &crate::TypeInner) -> Option { Some(match *inner { crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => { - LocalType::Value { - vector_size: None, - scalar, - pointer_space: None, - } + LocalType::Numeric(NumericType::Scalar(scalar)) + } + crate::TypeInner::Vector { size, scalar } => { + LocalType::Numeric(NumericType::Vector { size, scalar }) } - crate::TypeInner::Vector { size, scalar } => LocalType::Value { - vector_size: Some(size), - scalar, - pointer_space: None, - }, crate::TypeInner::Matrix { columns, rows, scalar, - } => LocalType::Matrix { + } => LocalType::Numeric(NumericType::Matrix { columns, rows, - width: scalar.width, - }, + scalar, + }), crate::TypeInner::Pointer { base, space } => LocalType::Pointer { base, class: helpers::map_storage_class(space), }, crate::TypeInner::ValuePointer { - size, + size: Some(size), scalar, space, - } => LocalType::Value { - vector_size: size, + } => LocalType::LocalPointer { + base: NumericType::Vector { size, scalar }, + class: helpers::map_storage_class(space), + }, + crate::TypeInner::ValuePointer { + size: None, scalar, - pointer_space: Some(helpers::map_storage_class(space)), + space, + } => LocalType::LocalPointer { + base: NumericType::Scalar(scalar), + class: helpers::map_storage_class(space), }, crate::TypeInner::Image { dim, diff --git a/naga/src/back/spv/ray.rs b/naga/src/back/spv/ray.rs index bc2c4ce3c6..c2daf4b3f6 100644 --- a/naga/src/back/spv/ray.rs +++ b/naga/src/back/spv/ray.rs @@ -2,7 +2,7 @@ Generating SPIR-V for ray query operations. */ -use super::{Block, BlockContext, Instruction, LocalType, LookupType}; +use super::{Block, BlockContext, Instruction, LocalType, LookupType, NumericType}; use crate::arena::Handle; impl<'w> BlockContext<'w> { @@ -22,11 +22,9 @@ impl<'w> BlockContext<'w> { let desc_id = self.cached[descriptor]; let acc_struct_id = self.get_handle_id(acceleration_structure); - let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::U32, - pointer_space: None, - })); + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::U32), + ))); let ray_flags_id = self.gen_id(); block.body.push(Instruction::composite_extract( flag_type_id, @@ -42,11 +40,9 @@ impl<'w> BlockContext<'w> { &[1], )); - let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::F32, - pointer_space: None, - })); + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::F32), + ))); let tmin_id = self.gen_id(); block.body.push(Instruction::composite_extract( scalar_type_id, @@ -62,11 +58,11 @@ impl<'w> BlockContext<'w> { &[3], )); - let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Tri), - scalar: crate::Scalar::F32, - pointer_space: None, - })); + let vector_type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { + size: crate::VectorSize::Tri, + scalar: crate::Scalar::F32, + }))); let ray_origin_id = self.gen_id(); block.body.push(Instruction::composite_extract( vector_type_id, @@ -116,11 +112,9 @@ impl<'w> BlockContext<'w> { spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, )); - let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::U32, - pointer_space: None, - })); + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::U32), + ))); let kind_id = self.gen_id(); block.body.push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionTypeKHR, @@ -170,11 +164,9 @@ impl<'w> BlockContext<'w> { intersection_id, )); - let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::F32, - pointer_space: None, - })); + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::F32), + ))); let t_id = self.gen_id(); block.body.push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionTKHR, @@ -184,11 +176,11 @@ impl<'w> BlockContext<'w> { intersection_id, )); - let barycentrics_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Bi), - scalar: crate::Scalar::F32, - pointer_space: None, - })); + let barycentrics_type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { + size: crate::VectorSize::Bi, + scalar: crate::Scalar::F32, + }))); let barycentrics_id = self.gen_id(); block.body.push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionBarycentricsKHR, @@ -198,11 +190,9 @@ impl<'w> BlockContext<'w> { intersection_id, )); - let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::BOOL, - pointer_space: None, - })); + let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(crate::Scalar::BOOL), + ))); let front_face_id = self.gen_id(); block.body.push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionFrontFaceKHR, @@ -212,11 +202,12 @@ impl<'w> BlockContext<'w> { intersection_id, )); - let transform_type_id = self.get_type_id(LookupType::Local(LocalType::Matrix { - columns: crate::VectorSize::Quad, - rows: crate::VectorSize::Tri, - width: 4, - })); + let transform_type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + scalar: crate::Scalar::F32, + }))); let object_to_world_id = self.gen_id(); block.body.push(Instruction::ray_query_get_intersection( spirv::Op::RayQueryGetIntersectionObjectToWorldKHR, diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs index c952cb11a7..eb273f5f19 100644 --- a/naga/src/back/spv/subgroup.rs +++ b/naga/src/back/spv/subgroup.rs @@ -1,4 +1,4 @@ -use super::{Block, BlockContext, Error, Instruction}; +use super::{Block, BlockContext, Error, Instruction, NumericType}; use crate::{ arena::Handle, back::spv::{LocalType, LookupType}, @@ -16,11 +16,11 @@ impl<'w> BlockContext<'w> { "GroupNonUniformBallot", &[spirv::Capability::GroupNonUniformBallot], )?; - let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Quad), - scalar: crate::Scalar::U32, - pointer_space: None, - })); + let vec4_u32_type_id = + self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar::U32, + }))); let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); let predicate = if let Some(predicate) = *predicate { self.cached[predicate] diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 27f2cbfdb6..cfff00be40 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -3,8 +3,8 @@ use super::{ helpers::{contains_builtin, global_needs_wrapper, map_storage_class}, Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalType, LocalVariable, - LogicalLayout, LookupFunctionType, LookupType, Options, PhysicalLayout, PipelineOptions, - ResultMember, Writer, WriterFlags, BITS_PER_BYTE, + LogicalLayout, LookupFunctionType, LookupType, NumericType, Options, PhysicalLayout, + PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE, }; use crate::{ arena::{Handle, HandleVec, UniqueArena}, @@ -291,83 +291,52 @@ impl Writer { } pub(super) fn get_uint_type_id(&mut self) -> Word { - let local_type = LocalType::Value { - vector_size: None, - scalar: crate::Scalar::U32, - pointer_space: None, - }; + let local_type = LocalType::Numeric(NumericType::Scalar(crate::Scalar::U32)); self.get_type_id(local_type.into()) } pub(super) fn get_float_type_id(&mut self) -> Word { - let local_type = LocalType::Value { - vector_size: None, - scalar: crate::Scalar::F32, - pointer_space: None, - }; + let local_type = LocalType::Numeric(NumericType::Scalar(crate::Scalar::F32)); self.get_type_id(local_type.into()) } pub(super) fn get_uint3_type_id(&mut self) -> Word { - let local_type = LocalType::Value { - vector_size: Some(crate::VectorSize::Tri), + let local_type = LocalType::Numeric(NumericType::Vector { + size: crate::VectorSize::Tri, scalar: crate::Scalar::U32, - pointer_space: None, - }; + }); self.get_type_id(local_type.into()) } pub(super) fn get_float_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { - let lookup_type = LookupType::Local(LocalType::Value { - vector_size: None, - scalar: crate::Scalar::F32, - pointer_space: Some(class), - }); - if let Some(&id) = self.lookup_type.get(&lookup_type) { - id - } else { - let id = self.id_gen.next(); - let ty_id = self.get_float_type_id(); - let instruction = Instruction::type_pointer(id, class, ty_id); - instruction.to_words(&mut self.logical_layout.declarations); - self.lookup_type.insert(lookup_type, id); - id - } + let local_type = LocalType::LocalPointer { + base: NumericType::Scalar(crate::Scalar::F32), + class, + }; + self.get_type_id(local_type.into()) } pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { - let lookup_type = LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Tri), - scalar: crate::Scalar::U32, - pointer_space: Some(class), - }); - if let Some(&id) = self.lookup_type.get(&lookup_type) { - id - } else { - let id = self.id_gen.next(); - let ty_id = self.get_uint3_type_id(); - let instruction = Instruction::type_pointer(id, class, ty_id); - instruction.to_words(&mut self.logical_layout.declarations); - self.lookup_type.insert(lookup_type, id); - id - } + let local_type = LocalType::LocalPointer { + base: NumericType::Vector { + size: crate::VectorSize::Tri, + scalar: crate::Scalar::U32, + }, + class, + }; + self.get_type_id(local_type.into()) } pub(super) fn get_bool_type_id(&mut self) -> Word { - let local_type = LocalType::Value { - vector_size: None, - scalar: crate::Scalar::BOOL, - pointer_space: None, - }; + let local_type = LocalType::Numeric(NumericType::Scalar(crate::Scalar::BOOL)); self.get_type_id(local_type.into()) } pub(super) fn get_bool3_type_id(&mut self) -> Word { - let local_type = LocalType::Value { - vector_size: Some(crate::VectorSize::Tri), + let local_type = LocalType::Numeric(NumericType::Vector { + size: crate::VectorSize::Tri, scalar: crate::Scalar::BOOL, - pointer_space: None, - }; + }); self.get_type_id(local_type.into()) } @@ -935,62 +904,50 @@ impl Writer { Ok(()) } + fn write_numeric_type_declaration_local(&mut self, id: Word, numeric: NumericType) { + let instruction = + match numeric { + NumericType::Scalar(scalar) => self.make_scalar(id, scalar), + NumericType::Vector { size, scalar } => { + let scalar_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Scalar(scalar), + ))); + Instruction::type_vector(id, scalar_id, size) + } + NumericType::Matrix { + columns, + rows, + scalar, + } => { + let column_id = self.get_type_id(LookupType::Local(LocalType::Numeric( + NumericType::Vector { size: rows, scalar }, + ))); + Instruction::type_matrix(id, column_id, columns) + } + }; + + instruction.to_words(&mut self.logical_layout.declarations); + } + fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) { let instruction = match local_ty { - LocalType::Value { - vector_size: None, - scalar, - pointer_space: None, - } => self.make_scalar(id, scalar), - LocalType::Value { - vector_size: Some(size), - scalar, - pointer_space: None, - } => { - let scalar_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar, - pointer_space: None, - })); - Instruction::type_vector(id, scalar_id, size) + LocalType::Numeric(numeric) => { + self.write_numeric_type_declaration_local(id, numeric); + return; } - LocalType::Matrix { - columns, - rows, - width, - } => { - let vector_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(rows), - scalar: crate::Scalar::float(width), - pointer_space: None, - })); - Instruction::type_matrix(id, vector_id, columns) + LocalType::LocalPointer { base, class } => { + let base_id = self.get_type_id(LookupType::Local(LocalType::Numeric(base))); + Instruction::type_pointer(id, class, base_id) } LocalType::Pointer { base, class } => { let type_id = self.get_type_id(LookupType::Handle(base)); Instruction::type_pointer(id, class, type_id) } - LocalType::Value { - vector_size, - scalar, - pointer_space: Some(class), - } => { - let type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size, - scalar, - pointer_space: None, - })); - Instruction::type_pointer(id, class, type_id) - } LocalType::Image(image) => { - let local_type = LocalType::Value { - vector_size: None, - scalar: crate::Scalar { - kind: image.sampled_type, - width: 4, - }, - pointer_space: None, - }; + let local_type = LocalType::Numeric(NumericType::Scalar(crate::Scalar { + kind: image.sampled_type, + width: 4, + })); let type_id = self.get_type_id(LookupType::Local(local_type)); Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format) } @@ -1224,11 +1181,9 @@ impl Writer { self.debugs.push(Instruction::name(id, name)); } } - let type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: None, - scalar: value.scalar(), - pointer_space: None, - })); + let type_id = self.get_type_id(LookupType::Local(LocalType::Numeric(NumericType::Scalar( + value.scalar(), + )))); let instruction = match *value { crate::Literal::F64(value) => { let bits = value.to_bits();