Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 14, 2023
1 parent 07ed711 commit d11eba3
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 25 deletions.
150 changes: 142 additions & 8 deletions luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::any::Any;
use std::cell::{Cell, RefCell};
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::rc::Rc;
use std::sync::atomic::AtomicUsize;
Expand Down Expand Up @@ -292,6 +292,8 @@ impl_aggregate_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15)
pub(crate) struct FnRecorder {
pub(crate) parent: Option<FnRecorderPtr>,
pub(crate) scopes: Vec<IrBuilder>,
/// Nodes that are defined in the current [`FnRecorder`]
pub(crate) defined: HashMap<NodeRef, bool>,
/// Nodes that are should not be acess
/// Once a basicblock is finished, all nodes in it are added to this set
pub(crate) inaccessible: Rc<RefCell<HashSet<NodeRef>>>,
Expand Down Expand Up @@ -330,6 +332,31 @@ impl FnRecorder {
}
None
}
pub(crate) fn defined_in_cur_recorder(&mut self, node: NodeRef) -> bool {
// fast path
if self.defined.contains_key(&node) {
return self.defined[&node];
}
// slow path
for b in &self.scopes {
let bb = b.bb();
for n in bb.iter() {
if !self.defined.contains_key(&n) {
self.defined.insert(n, true);
}
if n == node {
self.defined.insert(node, true);
return true;
}
}
}
self.defined.insert(node, false);
if let Some(p) = &self.parent {
// also update parent
p.borrow_mut().defined_in_cur_recorder(node);
}
false
}
pub(crate) fn capture_or_get<T: Any>(
&mut self,
binding: ir::Binding,
Expand All @@ -346,6 +373,22 @@ impl FnRecorder {
node
}
}
pub(crate) fn get_defined_recorder(&mut self, node: NodeRef) -> *mut FnRecorder {
if self.defined_in_cur_recorder(node) {
self as *mut _
} else {
self.parent
.as_mut()
.unwrap_or_else(|| {
panic!(
"Node {:?} not defined in any kernel",
node.get().instruction
)
})
.borrow_mut()
.get_defined_recorder(node)
}
}
pub(crate) fn new(kernel_id: usize, parent: Option<FnRecorderPtr>) -> Self {
FnRecorder {
inaccessible: parent
Expand All @@ -357,6 +400,7 @@ impl FnRecorder {
cpu_custom_ops: IndexMap::new(),
callables: IndexMap::new(),
captured_vars: IndexMap::new(),
defined: HashMap::new(),
shared: vec![],
device: None,
block_size: None,
Expand Down Expand Up @@ -385,11 +429,36 @@ impl FnRecorder {
});
let mut parent = parent.borrow_mut();
let node = parent.map_captured_vars(node0);
if self.captured_vars.contains_key(&node.node) {
return self.captured_vars[&node.node].1;
}
assert_eq!(node.recorder, &mut *parent as *mut _);
assert_ne!(node.recorder, ptr);
node
};

match node.node.get().instruction.as_ref() {
Instruction::Call(f, args) if *f == Func::GetElementPtr => {
let ancestor = args[0];
let r = self.get_defined_recorder(ancestor);
// now we capture the ancestor
let ancestor_node = SafeNodeRef {
recorder: r,
node: ancestor,
kernel_id: self.kernel_id,
};
let ancestor_node = self.map_captured_vars(ancestor_node);
// create a new gep node
// this is a bit ugly
let mut gep = ancestor_node;
for idx in args[1..].iter() {
let ty = gep.node.type_();
let idx = idx.get_i32().try_into().unwrap();
gep = __extract_impl(gep, idx, ty.extract(idx));
}
return gep;
}
_ => {}
}
let arg = match node.node.get().instruction.as_ref() {
Instruction::Local { .. }
| Instruction::Call { .. }
Expand Down Expand Up @@ -431,7 +500,15 @@ impl FnRecorder {
panic!("cannot capture node {:?}", node.node.get().instruction)
}
};
self.defined.insert(arg.node, true);
// eprintln!("FnRecorder: {:?}", ptr);
// eprintln!("Captured {:?} -> {:?} {:?}", node0, node, arg);
self.captured_vars.insert(node0.node, (node.node, arg));
#[cfg(debug_assertions)]
{
let captured = self.captured_vars.values().map(|x| x.0).collect::<Vec<_>>();
check_arg_alias(&captured);
}
arg
}
}
Expand Down Expand Up @@ -505,10 +582,63 @@ pub(crate) fn with_recorder<R>(f: impl FnOnce(&mut FnRecorder) -> R) -> R {
pub fn __current_scope<F: FnOnce(&mut IrBuilder) -> R, R>(f: F) -> R {
with_recorder(|r| {
let s = &mut r.scopes;
f(s.last_mut().unwrap())
let b = s.last_mut().unwrap();
let cur_insert_point = b.get_insert_point();
let ret = f(b);
let new_insert_point = b.get_insert_point();

// this is conservative
{
let mut p = cur_insert_point;
let defined = &mut r.defined;
loop {
defined.insert(p, true);
if p == new_insert_point {
break;
}
let next = p.get().next;
p = next;
}
}
ret
})
}

pub(crate) fn check_arg_alias(args: &[NodeRef]) {
let lvalues = args.iter().filter(|x| x.is_lvalue()).collect::<Vec<_>>();
let mut ancestor: HashMap<NodeRef, NodeRef> = HashMap::new();
macro_rules! check_and_insert {
($an:expr, $v:expr) => {
if ancestor.contains_key(&$v) {
eprintln!("Aliasing detected!");
for a in args.iter() {
eprintln!("{:?}", a);
}
panic!("Alias detected in callable arguments! Multiple Var<T> are referencing (maybe indirectly) the same var. Aliasing is not allowed in callable arguments.");
} else {
ancestor.insert($v, $an);
}
};
}
for v in &lvalues {
match v.get().instruction.as_ref() {
Instruction::Local { .. } => {
check_and_insert!(**v, **v);
}
Instruction::Argument { .. } => {
check_and_insert!(**v, **v);
}
Instruction::Shared => {
check_and_insert!(**v, **v);
}
Instruction::Call(f, args) => {
if *f == Func::GetElementPtr {
check_and_insert!(args[0], **v);
}
}
_ => {}
}
}
}
pub(crate) fn __invoke_callable(callable: &CallableModuleRef, args: &[NodeRef]) -> NodeRef {
with_recorder(|r| {
let id = CArc::as_ptr(&callable.0) as u64;
Expand All @@ -518,6 +648,7 @@ pub(crate) fn __invoke_callable(callable: &CallableModuleRef, args: &[NodeRef])
r.callables.insert(id, callable.clone());
}
});
check_arg_alias(args);
__current_scope(|b| {
b.call(
Func::Callable(callable.clone()),
Expand Down Expand Up @@ -573,14 +704,17 @@ pub fn __module_pools() -> &'static CArc<ModulePools> {
}

/// Don't call this function directly unless you know what you are doing
pub fn __extract<T: Value>(safe_node: SafeNodeRef, index: usize) -> SafeNodeRef {
__extract_impl(safe_node, index, <T as TypeOf>::type_())
}
/** This function is soley for constructing proxies
* Given a node, __extract selects the correct Func based on the node's
* type It then inserts the extract(node, i) call *at where the node is
* defined* *Note*, after insertion, the IrBuilder in the correct/parent
* scope might not be up to date Thus, for IrBuilder of each scope, it
* updates the insertion point to the end of the current basic block
*/
pub fn __extract<T: Value>(safe_node: SafeNodeRef, index: usize) -> SafeNodeRef {
fn __extract_impl(safe_node: SafeNodeRef, index: usize, ty: CArc<Type>) -> SafeNodeRef {
let node = unsafe { safe_node.get_raw() };
let inst = &node.get().instruction;
let r = unsafe {
Expand Down Expand Up @@ -640,22 +774,22 @@ pub fn __extract<T: Value>(safe_node: SafeNodeRef, index: usize) -> SafeNodeRef
Func::AtomicRef => {
let mut indices = args.to_vec();
indices.push(i);
let n = b.call_no_append(Func::AtomicRef, &indices, <T as TypeOf>::type_());
let n = b.call_no_append(Func::AtomicRef, &indices, ty);
update_builders!();
return wrap_up!(n);
}
Func::GetElementPtr => {
let mut indices = args.to_vec();
indices.push(i);
let n = b.call(Func::GetElementPtr, &indices, <T as TypeOf>::type_());
let n = b.call(Func::GetElementPtr, &indices, ty);
update_builders!();
return wrap_up!(n);
}
_ => Func::ExtractElement,
},
_ => Func::ExtractElement,
};
let node = b.call(op, &[node, i], <T as TypeOf>::type_());
let node = b.call(op, &[node, i], ty);

update_builders!();
wrap_up!(node)
Expand Down
36 changes: 33 additions & 3 deletions luisa_compute/src/runtime/kernel.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::lang::{pop_recorder, push_recorder, soa::SoaMetadata, KERNEL_ID};
use crate::lang::{check_arg_alias, pop_recorder, push_recorder, soa::SoaMetadata, KERNEL_ID};

use super::*;

Expand Down Expand Up @@ -221,14 +221,23 @@ impl KernelBuilder {
Node::new(CArc::new(Instruction::Argument { by_value }), ty),
);
self.args.push(node);
with_recorder(|r| {
r.defined.insert(node, true);
});
node
}
pub fn value<T: Value>(&mut self) -> Expr<T> {
let node = self.arg(T::type_(), true);
with_recorder(|r| {
r.defined.insert(node, true);
});
FromNode::from_node(node.into())
}
pub fn var<T: Value>(&mut self) -> Var<T> {
let node = self.arg(T::type_(), false);
with_recorder(|r| {
r.defined.insert(node, true);
});
FromNode::from_node(node.into())
}
pub fn uniform<T: Value>(&mut self) -> Expr<T> {
Expand All @@ -237,6 +246,9 @@ impl KernelBuilder {
Node::new(CArc::new(Instruction::Uniform), T::type_()),
);
self.args.push(node);
with_recorder(|r| {
r.defined.insert(node, true);
});
FromNode::from_node(node.into())
}
// pub fn byte_buffer(&mut self) -> ByteBufferVar {
Expand All @@ -253,6 +265,9 @@ impl KernelBuilder {
Node::new(CArc::new(Instruction::Buffer), T::type_()),
);
self.args.push(node);
with_recorder(|r| {
r.defined.insert(node, true);
});
BufferVar {
node: node.into(),
marker: PhantomData,
Expand All @@ -272,6 +287,9 @@ impl KernelBuilder {
Node::new(CArc::new(Instruction::Texture2D), T::type_()),
);
self.args.push(node);
with_recorder(|r| {
r.defined.insert(node, true);
});
Tex2dVar {
node: node.into(),
marker: PhantomData,
Expand All @@ -285,6 +303,9 @@ impl KernelBuilder {
Node::new(CArc::new(Instruction::Texture3D), T::type_()),
);
self.args.push(node);
with_recorder(|r| {
r.defined.insert(node, true);
});
Tex3dVar {
node: node.into(),
marker: PhantomData,
Expand All @@ -298,6 +319,9 @@ impl KernelBuilder {
Node::new(CArc::new(Instruction::Bindless), Type::void()),
);
self.args.push(node);
with_recorder(|r| {
r.defined.insert(node, true);
});
BindlessArrayVar {
node: node.into(),
handle: None,
Expand All @@ -309,6 +333,9 @@ impl KernelBuilder {
Node::new(CArc::new(Instruction::Accel), Type::void()),
);
self.args.push(node);
with_recorder(|r| {
r.defined.insert(node, true);
});
rtx::AccelVar {
node: node.into(),
handle: None,
Expand Down Expand Up @@ -396,7 +423,9 @@ impl KernelBuilder {
let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module);

let mut args = self.args.clone();

args.extend(r.captured_vars.values().map(|x| unsafe { x.1.get_raw() }));

for a in &args {
r.inaccessible.borrow_mut().insert(*a);
}
Expand All @@ -409,12 +438,13 @@ impl KernelBuilder {
pools: r.pools.clone(),
};
let module = CallableModuleRef(CArc::new(module));

let captured = r.captured_vars.values().map(|x| x.0).collect::<Vec<_>>();
check_arg_alias(&captured);
RawCallable {
device: self.device.clone(),
module,
resource_tracker: rt,
captured_args: r.captured_vars.values().map(|x| x.0).collect(),
captured_args: captured,
}
});
pop_recorder();
Expand Down
11 changes: 6 additions & 5 deletions luisa_compute/tests/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ fn nested_callable_capture_by_value() {
}
}

// this is broken on dx!!!
#[test]
fn nested_callable_capture_by_ref_alias() {
let device = get_device();
Expand All @@ -102,10 +101,12 @@ fn nested_callable_capture_by_ref_alias() {
let u = Var::<Float2>::zeroed();
let acc = |x: Expr<f32>| {
outline(|| {
*v.x += x;
*u = Float2::expr(v.x, v.x);
buf_z.write(tid, v.load().x);
})
outline(|| {
*v.x += x;
*u = Float2::expr(v.x, v.x);
buf_z.write(tid, v.load().x);
});
});
};
acc(x.read(tid));
acc(y.read(tid));
Expand Down
Loading

0 comments on commit d11eba3

Please sign in to comment.