Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add arrays of buffer descriptors #1148

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -933,8 +933,23 @@ fn trans_intrinsic_type<'tcx>(
// We use a generic to indicate the underlying element type.
// The spirv type of it will be generated by querying the type of the first generic.
if let Some(elem_ty) = args.types().next() {
let element = cx.layout_of(elem_ty).spirv_type(span, cx);
Ok(SpirvType::RuntimeArray { element }.def(span, cx))
let layout = cx.layout_of(elem_ty);
let element = layout.spirv_type(span, cx);
let element_ty = cx.lookup_type(element);

if element_ty.is_uniform_constant() {
// array of image, sampler, SampledImage etc. descriptors
Ok(SpirvType::RuntimeArray { element }.def(span, cx))
} else {
// array of buffer descriptors
Ok(SpirvType::RuntimeArray {
element: SpirvType::InterfaceBlock {
inner_type: element,
}
.def(span, cx),
}
.def(span, cx))
}
} else {
Err(cx
.tcx
Expand Down
20 changes: 14 additions & 6 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ pub enum SpirvAttribute {
// `fn`/closure attributes:
BufferLoadIntrinsic,
BufferStoreIntrinsic,
RuntimeArrayIndexIntrinsic,
}

// HACK(eddyb) this is similar to `rustc_span::Spanned` but with `value` as the
Expand Down Expand Up @@ -133,6 +134,7 @@ pub struct AggregatedSpirvAttributes {
// `fn`/closure attributes:
pub buffer_load_intrinsic: Option<Spanned<()>>,
pub buffer_store_intrinsic: Option<Spanned<()>>,
pub runtime_array_index_intrinsic: Option<Spanned<()>>,
}

struct MultipleAttrs {
Expand Down Expand Up @@ -237,6 +239,12 @@ impl AggregatedSpirvAttributes {
span,
"#[spirv(buffer_store_intrinsic)]",
),
RuntimeArrayIndexIntrinsic => try_insert(
&mut self.runtime_array_index_intrinsic,
(),
span,
"#[spirv(runtime_array_index_intrinsic)]",
),
}
}
}
Expand Down Expand Up @@ -358,12 +366,12 @@ impl CheckSpirvAttrVisitor<'_> {

_ => Err(Expected("function parameter")),
},
SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
match target {
Target::Fn => Ok(()),
_ => Err(Expected("function")),
}
}
SpirvAttribute::BufferLoadIntrinsic
| SpirvAttribute::BufferStoreIntrinsic
| SpirvAttribute::RuntimeArrayIndexIntrinsic => match target {
Target::Fn => Ok(()),
_ => Err(Expected("function")),
},
};
match valid_target {
Err(Expected(expected_target)) => {
Expand Down
7 changes: 7 additions & 0 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2531,6 +2531,11 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.borrow()
.get(&callee_val)
.copied();
let runtime_array_index_intrinsic = self
.runtime_array_index_intrinsic_fn_id
.borrow()
.get(&callee_val)
.copied();
if let Some(libm_intrinsic) = libm_intrinsic {
let result = self.call_libm_intrinsic(libm_intrinsic, result_type, args);
if result_type != result.ty {
Expand Down Expand Up @@ -3024,6 +3029,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
kind: SpirvValueKind::IllegalTypeUsed(void_ty),
ty: void_ty,
}
} else if let Some(mode) = runtime_array_index_intrinsic {
self.codegen_runtime_array_index_intrinsic(result_type, args, mode)
} else {
let args = args.iter().map(|arg| arg.def(self)).collect::<Vec<_>>();
self.emit()
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod byte_addressable_buffer;
mod ext_inst;
mod intrinsics;
pub mod libm_intrinsics;
mod runtime_array;
mod spirv_asm;

pub use ext_inst::ExtInst;
Expand Down
130 changes: 130 additions & 0 deletions crates/rustc_codegen_spirv/src/builder/runtime_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use crate::builder::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt, SpirvValueKind};
use crate::spirv_type::SpirvType;
use rspirv::spirv::Word;
use rustc_codegen_ssa::traits::{BaseTypeMethods, BuilderMethods};
use rustc_target::abi::call::PassMode;

impl<'a, 'tcx> Builder<'a, 'tcx> {
/// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller.
pub fn codegen_runtime_array_index_intrinsic(
&mut self,
result_type: Word,
args: &[SpirvValue],
pass_mode: &PassMode,
) -> SpirvValue {
match pass_mode {
PassMode::Ignore => {
return SpirvValue {
kind: SpirvValueKind::IllegalTypeUsed(result_type),
ty: result_type,
};
}
// PassMode::Pair is identical to PassMode::Direct - it's returned as a struct
PassMode::Direct(_) | PassMode::Pair(_, _) => (),
PassMode::Cast { .. } => {
self.fatal("PassMode::Cast not supported in codegen_runtime_array_index_intrinsic")
}
PassMode::Indirect { .. } => self
.fatal("PassMode::Indirect not supported in codegen_runtime_array_index_intrinsic"),
}

// Signatures:
// fn <T: ?Sized>(runtime_array: &RuntimeArray<T>, index: usize) -> &T
// fn <T: ?Sized>(runtime_array: &mut RuntimeArray<T>, index: usize) -> &mut T
if args.len() != 2 {
self.fatal(format!(
"runtime_array_index_intrinsic should have 3 args, it has {}",
args.len()
));
}
let runtime_array = args[0];
let index = args[1];

let runtime_array_type = self.lookup_type(runtime_array.ty);
let element_ty = match runtime_array_type {
SpirvType::Pointer { pointee } => {
match self.lookup_type(pointee) {
SpirvType::RuntimeArray { element } => {
element
}
_ => self.fatal(format!(
"runtime_array_index_intrinsic args[0] is {:?} and not a Pointer to a RuntimeArray!",
runtime_array_type
)),
}
}
_ => self.fatal(format!(
"runtime_array_index_intrinsic args[0] is {:?} and not a Pointer!",
runtime_array_type
)),
};

let ptr_element = self.type_ptr_to(element_ty);
let element = self
.emit()
.access_chain(
ptr_element,
None,
runtime_array.def(self),
[index.def(self)],
)
.unwrap()
.with_type(ptr_element);

match self.lookup_type(element_ty) {
SpirvType::InterfaceBlock { .. } => {
// array of buffer descriptors
let inner = self.struct_gep(element_ty, element, 0);
match pass_mode {
PassMode::Direct(_) => {
// element is sized
if inner.ty == result_type {
inner
} else {
self.fatal(format!(
"runtime_array_index_intrinsic expected result_type to equal RuntimeArray's InterfaceBlock's inner_type: {:?} == {:?}",
self.lookup_type(result_type).debug(result_type, self),
self.lookup_type(inner.ty).debug(inner.ty, self)
))
}
}
PassMode::Pair(_, _) => {
// element is a slice
match self.lookup_type(result_type) {
SpirvType::Adt { field_types, .. } if field_types.len() == 2
&& matches!(self.lookup_type(field_types[0]), SpirvType::Pointer {..})
&& field_types[1] == self.type_isize() => {
}
_ => self.fatal(format!(
"Expected element of RuntimeArray to be a plain slice, like `&RuntimeArray<[u32]>`, but got {:?}!",
self.lookup_type(result_type).debug(result_type, self)
))
};
let len = self
.emit()
.array_length(self.type_isize(), None, element.def(self), 0)
.unwrap();
self.emit()
.composite_construct(result_type, None, [inner.def(self), len])
.unwrap()
.with_type(result_type)
}
_ => unreachable!(),
}
}
_ => {
// array of UniformConstant (image, sampler, etc.) descriptors
if ptr_element == result_type {
element
} else {
self.fatal(format!(
"runtime_array_index_intrinsic expected result_type to equal RuntimeArray's element_ty: {:?} == {:?}",
self.lookup_type(result_type).debug(result_type, self),
self.lookup_type(ptr_element).debug(ptr_element, self)
))
}
}
}
}
}
6 changes: 6 additions & 0 deletions crates/rustc_codegen_spirv/src/codegen_cx/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ impl<'tcx> CodegenCx<'tcx> {
.borrow_mut()
.insert(fn_id, mode);
}
if attrs.runtime_array_index_intrinsic.is_some() {
let mode = &fn_abi.ret.mode;
self.runtime_array_index_intrinsic_fn_id
.borrow_mut()
.insert(fn_id, mode);
}

let instance_def_id = instance.def_id();

Expand Down
117 changes: 60 additions & 57 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,25 +266,21 @@ impl<'tcx> CodegenCx<'tcx> {
}
ty => ty,
};
let deduced_storage_class_from_ty = match element_ty {
SpirvType::Image { .. }
| SpirvType::Sampler
| SpirvType::SampledImage { .. }
| SpirvType::AccelerationStructureKhr { .. } => {
if is_ref {
Some(StorageClass::UniformConstant)
} else {
self.tcx.sess.span_err(
hir_param.ty_span,
format!(
"entry parameter type must be by-reference: `&{}`",
value_layout.ty,
),
);
None
}
let deduced_storage_class_from_ty = if element_ty.is_uniform_constant() {
if is_ref {
Some(StorageClass::UniformConstant)
} else {
self.tcx.sess.span_err(
hir_param.ty_span,
format!(
"entry parameter type must be by-reference: `&{}`",
value_layout.ty,
),
);
None
}
_ => None,
} else {
None
};
// Storage classes can be specified via attribute. Compute that here, and emit diagnostics.
let attr_storage_class = attrs.storage_class.map(|storage_class_attr| {
Expand Down Expand Up @@ -502,49 +498,56 @@ impl<'tcx> CodegenCx<'tcx> {
Ok(
StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer,
) => {
let var_spirv_type = SpirvType::InterfaceBlock {
inner_type: value_spirv_type,
}
.def(hir_param.span, self);
var_ptr_spirv_type = self.type_ptr_to(var_spirv_type);

let value_ptr = bx.struct_gep(
var_spirv_type,
var_id.unwrap().with_type(var_ptr_spirv_type),
0,
);

let value_len = if is_unsized_with_len {
match self.lookup_type(value_spirv_type) {
SpirvType::RuntimeArray { .. } => {}
_ => {
self.tcx.sess.span_err(
hir_param.ty_span,
"only plain slices are supported as unsized types",
);
}
match self.lookup_type(value_spirv_type) {
SpirvType::RuntimeArray { element }
if matches!(
self.lookup_type(element),
SpirvType::InterfaceBlock { .. }
) =>
{
// array of buffer descriptors
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
(Ok(var_id.unwrap().with_type(var_ptr_spirv_type)), None)
}
_ => {
// single buffer descriptor
let var_spirv_type = SpirvType::InterfaceBlock {
inner_type: value_spirv_type,
}
.def(hir_param.span, self);
var_ptr_spirv_type = self.type_ptr_to(var_spirv_type);

// FIXME(eddyb) shouldn't this be `usize`?
let len_spirv_type = self.type_isize();
let len = bx
.emit()
.array_length(len_spirv_type, None, var_id.unwrap(), 0)
.unwrap();
let value_ptr = bx.struct_gep(
var_spirv_type,
var_id.unwrap().with_type(var_ptr_spirv_type),
0,
);

Some(len.with_type(len_spirv_type))
} else {
if is_unsized {
// It's OK to use a RuntimeArray<u32> and not have a length parameter, but
// it's just nicer ergonomics to use a slice.
self.tcx
.sess
.span_warn(hir_param.ty_span, "use &[T] instead of &RuntimeArray<T>");
}
None
};
let value_len = if is_unsized_with_len {
match self.lookup_type(value_spirv_type) {
SpirvType::RuntimeArray { .. } => {}
_ => {
self.tcx.sess.span_err(
hir_param.ty_span,
"only plain slices are supported as unsized types",
);
}
}

// FIXME(eddyb) shouldn't this be `usize`?
let len_spirv_type = self.type_isize();
let len = bx
.emit()
.array_length(len_spirv_type, None, var_id.unwrap(), 0)
.unwrap();
Some(len.with_type(len_spirv_type))
} else {
None
};

(Ok(value_ptr), value_len)
(Ok(value_ptr), value_len)
}
}
}
Ok(StorageClass::UniformConstant) => {
var_ptr_spirv_type = self.type_ptr_to(value_spirv_type);
Expand Down
3 changes: 3 additions & 0 deletions crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ pub struct CodegenCx<'tcx> {
pub buffer_load_intrinsic_fn_id: RefCell<FxHashMap<Word, &'tcx PassMode>>,
/// Intrinsic for storing a <T> into a &[u32]. The PassMode is the mode of the <T>.
pub buffer_store_intrinsic_fn_id: RefCell<FxHashMap<Word, &'tcx PassMode>>,
/// Intrinsic for loading a descriptor from a `RuntimeArray`. The PassMode is the mode of the <T>.
pub runtime_array_index_intrinsic_fn_id: RefCell<FxHashMap<Word, &'tcx PassMode>>,

/// Some runtimes (e.g. intel-compute-runtime) disallow atomics on i8 and i16, even though it's allowed by the spec.
/// This enables/disables them.
Expand Down Expand Up @@ -132,6 +134,7 @@ impl<'tcx> CodegenCx<'tcx> {
fmt_rt_arg_new_fn_ids_to_ty_and_spec: Default::default(),
buffer_load_intrinsic_fn_id: Default::default(),
buffer_store_intrinsic_fn_id: Default::default(),
runtime_array_index_intrinsic_fn_id: Default::default(),
i8_i16_atomics_allowed: false,
codegen_args,
}
Expand Down
Loading
Loading