Skip to content

Commit

Permalink
[naga spv-out] Spill arrays and matrices for runtime indexing.
Browse files Browse the repository at this point in the history
Improve handling of `Access` expressions whose base is an array or
matrix (not a pointer to such), and whose index is not known at
compile time. SPIR-V does not have instructions that can do this
directly, so spill such values to temporary variables, and perform the
accesses using `OpAccessChain` instructions applied to the
temporaries.

Permit dynamic indexing of matrices in validation.

Handle matrices and arrays consistently; remove special cases for
arrays.

When performing chains of accesses like `a[i].x[j]`, do not reify
intermediate values; generate a single `OpAccessIndex` for the entire
thing.

For details, see the comments on the new tracking structures in
`naga::back::spv::Function`.

Add snapshot test `index-by-value.wgsl`.

Fixes gfx-rs#6358.
Alternative to gfx-rs#6362.
  • Loading branch information
jimblandy committed Oct 10, 2024
1 parent b005c17 commit c279c58
Show file tree
Hide file tree
Showing 13 changed files with 896 additions and 128 deletions.
6 changes: 6 additions & 0 deletions naga/src/arena/handle_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ impl<T> HandleSet<T> {
}
}

impl<T> Default for HandleSet<T> {
fn default() -> Self {
Self::new()
}
}

pub trait ArenaType<T> {
fn len(&self) -> usize;
}
Expand Down
139 changes: 115 additions & 24 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,34 +297,48 @@ impl<'w> BlockContext<'w> {
// that actually dereferences the pointer.
0
}
_ if self.function.spilled_accesses.contains(base) => {
// As far as Naga IR is concerned, this expression does not yield
// a pointer, but we spilled it to a temporary variable.

// The base expression is something we spilled to a temporary
// variable, so mark this access as spilled as well.
self.function.spilled_accesses.insert(expr_handle);
self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
}
crate::TypeInner::Vector { .. } => {
self.write_vector_access(expr_handle, base, index, block)?
}
crate::TypeInner::Array {
base: ty_element, ..
} => {
let index_id = self.cached[index];
let base_id = self.cached[base];
let base_ty = match self.fun_info[base].ty {
TypeResolution::Handle(handle) => handle,
TypeResolution::Value(_) => {
return Err(Error::Validation(
"Array types should always be in the arena",
))
crate::TypeInner::Array { .. } | crate::TypeInner::Matrix { .. } => {
// See if `index` is known at compile time.
match GuardedIndex::from_expression(index, self.ir_function, self.ir_module)
{
GuardedIndex::Known(value) => {
// If `index` is known, we can just use `OpCompositeExtract`.
//
// We never need bounds checks for these cases: everything
// size is statically known and checked in validation.
let id = self.gen_id();
let base_id = self.cached[base];
block.body.push(Instruction::composite_extract(
result_type_id,
id,
base_id,
&[value],
));
id
}
};
let (id, variable) = self.writer.promote_access_expression_to_variable(
result_type_id,
base_id,
base_ty,
index_id,
ty_element,
block,
)?;
self.function.internal_variables.push(variable);
id
GuardedIndex::Expression(_) => {
self.spill_to_internal_variable(base, block);
self.function.spilled_accesses.insert(expr_handle);
self.maybe_access_spilled_composite(
expr_handle,
block,
result_type_id,
)?
}
}
}
// wgpu#4337: Support `crate::TypeInner::Matrix`
crate::TypeInner::BindingArray {
base: binding_type, ..
} => {
Expand Down Expand Up @@ -396,6 +410,15 @@ impl<'w> BlockContext<'w> {
// that actually dereferences the pointer.
0
}
_ if self.function.spilled_accesses.contains(base) => {
// As far as Naga IR is concerned, this expression does not yield
// a pointer, but we spilled it to a temporary variable.

// The base expression is something we spilled to a temporary
// variable, so mark this access as spilled as well.
self.function.spilled_accesses.insert(expr_handle);
self.maybe_access_spilled_composite(expr_handle, block, result_type_id)?
}
crate::TypeInner::Vector { .. }
| crate::TypeInner::Matrix { .. }
| crate::TypeInner::Array { .. }
Expand Down Expand Up @@ -1760,7 +1783,15 @@ impl<'w> BlockContext<'w> {
crate::Expression::FunctionArgument(index) => {
break self.function.parameter_id(index);
}
ref other => unimplemented!("Unexpected pointer expression {:?}", other),
ref other => {
let Some(spilled) = self.function.spilled_composites.get(&expr_handle) else {
unimplemented!("Unexpected pointer expression {:?}", other);
};

// The root id of the `OpAccessChain` instruction is the temporary
// variable we spilled the composite to.
break spilled.id;
}
}
};

Expand Down Expand Up @@ -1961,6 +1992,66 @@ impl<'w> BlockContext<'w> {
}
}

fn spill_to_internal_variable(&mut self, base: Handle<crate::Expression>, block: &mut Block) {
// Generate an internal variable of the appropriate type for `base`.
let variable_id = self.writer.id_gen.next();
let pointer_type_id = self
.writer
.get_resolution_pointer_id(&self.fun_info[base].ty, spirv::StorageClass::Function);
let variable = super::LocalVariable {
id: variable_id,
instruction: Instruction::variable(
pointer_type_id,
variable_id,
spirv::StorageClass::Function,
None,
),
};

let base_id = self.cached[base];
block
.body
.push(Instruction::store(variable.id, base_id, None));
self.function.spilled_composites.insert(base, variable);
}

/// Generate an access to a spilled temporary, if necessary.
///
/// Given `access`, an [`Access`] or [`AccessIndex`] expression that refers
/// to a component of a composite value that has been spilled to a temporary
/// variable, determine whether other expressions are going to use
/// `access`'s value:
///
/// - If so, perform the access and cache that as the value of `access`.
///
/// - Otherwise, generate no code and cache no value for `access`.
///
/// Return `Ok(0)` if no value was fetched, or `Ok(id)` if we loaded it into
/// the instruction given by `id`.
///
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
fn maybe_access_spilled_composite(
&mut self,
access: Handle<crate::Expression>,
block: &mut Block,
result_type_id: Word,
) -> Result<Word, Error> {
let access_uses = self.function.access_uses.get(&access).map_or(0, |r| *r);
if access_uses == self.fun_info[access].ref_count {
// This expression is only used by other `Access` and
// `AccessIndex` expressions, so we don't need to cache a
// value for it yet.
Ok(0)
} else {
// There are other expressions that are going to expect this
// expression's value to be cached, not just other `Access` or
// `AccessIndex` expressions. We must actually perform the
// access on the spill variable now.
self.write_checked_load(access, block, result_type_id)
}
}

/// Build the instructions for matrix - matrix column operations
#[allow(clippy::too_many_arguments)]
fn write_matrix_matrix_column_op(
Expand Down
75 changes: 60 additions & 15 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,38 @@ struct Function {
signature: Option<Instruction>,
parameters: Vec<FunctionArgument>,
variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
internal_variables: Vec<LocalVariable>,

/// A map from expressions that yield composite values (arrays, matrices) to
/// the temporary variables we have spilled them to. Spilling allows us to
/// render an arbitrary chain of [`Access`] and [`AccessIndex`] expressions
/// as a single `OpAccessChain` and `OpLoad` instruction (plus bounds checks).
///
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
spilled_composites: crate::FastIndexMap<Handle<crate::Expression>, LocalVariable>,

/// A set of expressions that are either in [`spilled_composites`] or refer
/// to some component/element of such.
///
/// [`spilled_composites`]: Function::spilled_composites
spilled_accesses: crate::arena::HandleSet<crate::Expression>,

/// A map from expressions to the number of [`Access`] and [`AccessIndex`]
/// expressions that use them as a base value. If an expression has no
/// entry, it is never used as a [`Access`] or [`AccessIndex`] base.
///
/// We use this, together with [`ExpressionInfo::ref_count`], to recognize
/// the tips of chains of [`Access`] and [`AccessIndex`] expressions based
/// on spilled values --- expressions in [`spilled_composites`]. We defer
/// generating code for the chain until we reach its tip, so we can handle
/// it with a single instruction.
///
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
/// [`ExpressionInfo::ref_count`]: crate::valid::ExpressionInfo
/// [`spilled_composites`]: Function::spilled_composites
access_uses: crate::FastHashMap<Handle<crate::Expression>, usize>,

blocks: Vec<TerminatedBlock>,
entry_point_context: Option<EntryPointContext>,
}
Expand Down Expand Up @@ -246,6 +277,27 @@ enum NumericType {
},
}

impl NumericType {
const fn from_inner(inner: &crate::TypeInner) -> Option<Self> {
match *inner {
crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => {
Some(NumericType::Scalar(scalar))
}
crate::TypeInner::Vector { size, scalar } => Some(NumericType::Vector { size, scalar }),
crate::TypeInner::Matrix {
columns,
rows,
scalar,
} => Some(NumericType::Matrix {
columns,
rows,
scalar,
}),
_ => None,
}
}
}

/// A SPIR-V type constructed during code generation.
///
/// This is the variant of [`LookupType`] used to represent types that might not
Expand Down Expand Up @@ -367,21 +419,14 @@ struct LookupFunctionType {
impl LocalType {
fn from_inner(inner: &crate::TypeInner) -> Option<Self> {
Some(match *inner {
crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => {
LocalType::Numeric(NumericType::Scalar(scalar))
}
crate::TypeInner::Vector { size, scalar } => {
LocalType::Numeric(NumericType::Vector { size, scalar })
crate::TypeInner::Scalar(_)
| crate::TypeInner::Atomic(_)
| crate::TypeInner::Vector { .. }
| crate::TypeInner::Matrix { .. } => {
// We expect `NumericType::from_inner` to handle all
// these cases, so unwrap.
LocalType::Numeric(NumericType::from_inner(inner).unwrap())
}
crate::TypeInner::Matrix {
columns,
rows,
scalar,
} => LocalType::Numeric(NumericType::Matrix {
columns,
rows,
scalar,
}),
crate::TypeInner::Pointer { base, space } => LocalType::Pointer {
base,
class: helpers::map_storage_class(space),
Expand Down
Loading

0 comments on commit c279c58

Please sign in to comment.