Skip to content

Commit

Permalink
[shards] Implement atomic compare and exchange for msl
Browse files Browse the repository at this point in the history
  • Loading branch information
guusw committed May 27, 2024
1 parent c6a603f commit 484503a
Showing 1 changed file with 49 additions and 2 deletions.
51 changes: 49 additions & 2 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2986,8 +2986,44 @@ impl<W: Write> Writer<W> {
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_baking_expression(result, &context.expression, &res_name)?;
self.named_expressions.insert(result, res_name);
let fun_str = fun.to_msl()?;
self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
match *fun {
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
let policy = context.expression.choose_bounds_check_policy(pointer);

let result_type_name = match context.expression.info[result].ty {
TypeResolution::Handle(ty_handle) => {
let ty_name = TypeContext {
handle: ty_handle,
gctx: context.expression.module.to_ctx(),
names: &self.names,
access: crate::StorageAccess::empty(),
binding: None,
first_time: false,
};
ty_name
}
_ => {
return Err(Error::FeatureNotImplemented(
"atomic CompareExchange".to_string(),
));
}
};

write!(
self.out,
"__make_atomic_cas_result<{result_type_name}>({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, &context.expression)?;
write!(self.out, ", ")?;
self.put_expression(cmp, &context.expression, true)?;
write!(self.out, ", ")?;
self.put_expression(value, &context.expression, true)?;
write!(self.out, ")")?;
}
_ => { let fun_str = fun.to_msl()?;
self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
}
}
// done
writeln!(self.out, ";")?;
}
Expand Down Expand Up @@ -3356,6 +3392,17 @@ impl<W: Write> Writer<W> {
}
writeln!(self.out)?;

writeln!(
self.out,
r#"
template<typename R, typename T, typename A> R __make_atomic_cas_result(device A* ptr, T cmp_, T v) {{
T cmp = cmp_;
bool exchanged = {NAMESPACE}::atomic_compare_exchange_weak_explicit(ptr, &cmp, v, {NAMESPACE}::memory_order_relaxed, {NAMESPACE}::memory_order_relaxed);
return R{{cmp, exchanged}};
}}
"#
)?;

{
let mut indices = vec![];
for (handle, var) in module.global_variables.iter() {
Expand Down

0 comments on commit 484503a

Please sign in to comment.