Skip to content

Commit

Permalink
looks good
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 10, 2023
1 parent 5b898e8 commit 89031ef
Show file tree
Hide file tree
Showing 14 changed files with 540 additions and 336 deletions.
45 changes: 27 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -432,19 +432,20 @@ let a = 1.0f32.var();
pass_by_ref.call(a);
cpu_dbg!(*a); // prints 2.0
```
***Note***: You cannot record a callable when recording another kernel or callables. This is because a callable can capture outer variables such as buffers. However, capturing local variables define in another callable is undefined behavior. To avoid this, we disallow recording a callable when recording another callable or kernel.
***Note***: You create callables within another callable and even capture variables from the outer callable. The only limitation is that you cannot return a callable from another callable.
```rust
let add = Callable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>::new(&device, track!(|a, b| {
// runtime error!
let another_add = Callable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>::new(&device, track!(|a, b| {
a + b
}));
a + b
}));
let add = track!(Callable::<fn(Expr<f32>, Expr<f32>) -> Expr<f32>>::new(
&device,
|a, b| {
// callables can be defined within callables
let partial_add = Callable::<fn(Expr<f32>) -> Expr<f32>>::new(&device, |y| a + y);
partial_add.call(b)
}
));
```

***However, we acknowledge that recording a callable inside another callable/kernel is a useful feature***. Thus we provide two ways to workaround this limitation:
1. Use static callables. A static callable does not capture any resources and thus can be safely recorded inside any callable/kernel. To create a static callable, use `create_static_callable(fn)`. For example,
#### Static callables
A static callable does not capture any resources and thus can be safely recorded inside any callable/kernel. To create a static callable, use `create_static_callable(fn)`. For example,
```rust
lazy_static! {
static ref ADD:Callable<fn(Expr<f32>, Expr<f32>)->Expr<f32>> = Callable::<fn(Expr<f32>, Expr<f32>)->Expr<f32>>::new_static(|a, b| {
Expand All @@ -454,16 +455,24 @@ lazy_static! {
ADD.call(x, y);
```

2. Use `DynCallable`. These are callables that defer recording until being called. As a result, it requires you to pass a `'static` closure, avoiding the capture issue. To create a `DynCallable`, use `Device::create_dyn_callable(Box::new(fn))`. The syntax is the same as `create_callable`. Furthermore, `DynCallable` supports `DynExpr` and `DynVar`, which provides some capablitiy of implementing template/overloading inside EDSL.
#### Dynamic callables (templates)
Use `DynCallable`. These are callables that defer recording until being called. As a result, it requires you to pass a `'static` closure, avoiding the capture issue. To create a `DynCallable`, use `Device::create_dyn_callable(Box::new(fn))`. The syntax is the same as `create_callable`. Furthermore, `DynCallable` supports `DynExpr` and `DynVar`, which provides some capablitiy of implementing template/overloading inside EDSL.

```rust
let add = Callable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>::new(&device, track!(|a, b| {
// no error!
let another_add = DynCallable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>::new(&device, track!(Box::new(|a, b| {
a + b
})));
a + b
}));
let add = DynCallable::<fn(DynExpr, DynExpr) -> DynExpr>::new(
&device,
Box::new(|a: DynExpr, b: DynExpr| -> DynExpr {
if let Some(a) = a.downcast::<f32>() {
let b = b.downcast::<f32>().unwrap();
return DynExpr::new(track!(a + b));
} else if let Some(a) = a.downcast::<i32>() {
let b = b.downcast::<i32>().unwrap();
return DynExpr::new(track!(a + b));
} else {
unreachable!()
}
}),
);
```

### Kernel
Expand Down
49 changes: 49 additions & 0 deletions luisa_compute/examples/nested_callable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use luisa::prelude::*;
use luisa_compute as luisa;
use std::env::current_exe;

fn main() {
luisa::init_logger();
let args: Vec<String> = std::env::args().collect();
assert!(
args.len() <= 2,
"Usage: {} <backend>. <backend>: cpu, cuda, dx, metal, remote",
args[0]
);

let ctx = Context::new(current_exe().unwrap());
let device = ctx.create_device(if args.len() == 2 {
args[1].as_str()
} else {
"cpu"
});

let add = track!(Callable::<fn(Expr<f32>, Expr<f32>) -> Expr<f32>>::new(
&device,
|a, b| {
// callables can be defined within callables
let partial_add = Callable::<fn(Expr<f32>) -> Expr<f32>>::new(&device, |y| a + y);
partial_add.call(b)
}
));
let x = device.create_buffer::<f32>(1024);
let y = device.create_buffer::<f32>(1024);
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = Kernel::<fn(Buffer<f32>)>::new(
&device,
&track!(|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);

buf_z.write(tid, add.call(x, y));
}),
);
kernel.dispatch([1024, 1, 1], &z);
let z_data = z.view(..).copy_to_vec();
println!("{:?}", &z_data[0..16]);
}
51 changes: 29 additions & 22 deletions luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ impl<T: Aggregate> Aggregate for Vec<T> {
}

fn from_nodes<I: Iterator<Item = SafeNodeRef>>(iter: &mut I) -> Self {
let len_node = iter.next().unwrap();
let len = len_node.get().unwrap_user_data::<usize>();
let len_node = iter.next().unwrap().get();
let len = len_node.unwrap_user_data::<usize>();
let mut ret = Vec::with_capacity(*len);
for _ in 0..*len {
ret.push(T::from_nodes(iter));
Expand Down Expand Up @@ -206,8 +206,8 @@ impl<T: Aggregate> Aggregate for Option<T> {
}

fn from_nodes<I: Iterator<Item = SafeNodeRef>>(iter: &mut I) -> Self {
let node = iter.next().unwrap();
let tag = node.get().unwrap_user_data::<usize>();
let node = iter.next().unwrap().get();
let tag = node.unwrap_user_data::<usize>();
match *tag {
0 => None,
1 => Some(T::from_nodes(iter)),
Expand Down Expand Up @@ -292,7 +292,6 @@ impl_aggregate_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15)
pub(crate) struct FnRecorder {
pub(crate) parent: Option<FnRecorderPtr>,
pub(crate) scopes: Vec<IrBuilder>,
pub(crate) kernel_id: Option<usize>,
pub(crate) captured_resources: IndexMap<Binding, (usize, NodeRef, Binding, Arc<dyn Any>)>,
pub(crate) cpu_custom_ops: IndexMap<u64, (usize, CArc<CpuCustomOp>)>,
pub(crate) callables: IndexMap<u64, CallableModuleRef>,
Expand Down Expand Up @@ -330,7 +329,7 @@ impl FnRecorder {
if let Some((_, node, _, _)) = self.captured_resources.get(&binding) {
*node
} else {
let node = new_node(self.pools.as_ref(), create_node());
let node = new_node(&self.pools, create_node());
let i = self.captured_resources.len();
self.captured_resources
.insert(binding, (i, node, binding, handle.clone()));
Expand All @@ -350,7 +349,6 @@ impl FnRecorder {
pools: CArc::new(ModulePools::new()),
arena: Bump::new(),
building_kernel: false,
kernel_id: None,
callable_ret_type: None,
parent: None,
}
Expand All @@ -362,6 +360,7 @@ impl FnRecorder {
if self.captured_vars.contains_key(&node) {
return self.captured_vars[&node];
}
let ptr = self as *mut _;
let parent = self
.parent
.as_mut()
Expand All @@ -375,9 +374,9 @@ impl FnRecorder {
}
}
let arg = SafeNodeRef {
recorder: self as *mut _,
recorder: ptr,
node: new_node(
self.pools.as_ref(),
&self.pools,
Node::new(
CArc::new(Instruction::Argument {
by_value: !node.node.is_lvalue(),
Expand Down Expand Up @@ -647,13 +646,12 @@ pub fn pack_to<V: Value, B>(expr: Expr<V>, buffer: &B, index: impl AsExpr<Value
where
B: IndexWrite<Element = u32> + ToNode,
{
let index = index.as_expr();
let index = index.as_expr().node().get();
let expor = expr.node().get();
let buffer = buffer.node().get();
let expr = expr.node().get();
__current_scope(|b| {
b.call(
Func::Pack,
&[expr.node(), buffer.node(), index.node()],
Type::void(),
);
b.call(Func::Pack, &[expr, buffer, index], Type::void());
});
}

Expand All @@ -666,13 +664,9 @@ where
{
let index = index.into().node().get();
let buffer = buffer.node().get();
Expr::<T>::from_node(__current_scope(|b| {
b.call(
Func::Unpack,
&[buffer, index],
<T as TypeOf>::type_(),
)
}).into())
Expr::<T>::from_node(
__current_scope(|b| b.call(Func::Unpack, &[buffer, index], <T as TypeOf>::type_())).into(),
)
}

pub(crate) fn need_runtime_check() -> bool {
Expand Down Expand Up @@ -721,3 +715,16 @@ pub(crate) fn check_index_lt_usize(index: impl IntoIndex, size: usize) {
lc_assert!(index.lt(size as u64));
}
}

/// Outline a code snippet.
/// Snippets that have the same code will be deduplicated.
/// It helps reduce compilation time.
pub fn outline<F: Fn()>(f: F) {
let device = RECORDER.with(|r| {
let r = r.borrow();
let r = r.as_ref().unwrap();
let r = r.borrow();
r.device.clone().map(|x| x.upgrade().unwrap())
});
Callable::<fn()>::new_maybe_device(device, f).call();
}
39 changes: 22 additions & 17 deletions luisa_compute/src/lang/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,20 @@ pub fn requires_grad<V: Value>(var: Expr<V>) {
pub fn backward<V: Value>(out: Expr<V>) {
backward_with_grad(
out,
FromNode::from_node(__current_scope(|b| {
let one = new_node(
b.pools(),
Node::new(
CArc::new(Instruction::Const(Const::One(V::type_()))),
V::type_(),
),
);
b.append(one);
one.into()
})),
FromNode::from_node(
__current_scope(|b| {
let one = new_node(
b.pools(),
Node::new(
CArc::new(Instruction::Const(Const::One(V::type_()))),
V::type_(),
),
);
b.append(one);
one
})
.into(),
),
);
}

Expand Down Expand Up @@ -95,8 +98,7 @@ pub fn gradient<V: Value>(var: Expr<V>) -> Expr<V> {
});
let var = var.node().get();
Expr::<V>::from_node(
__current_scope(|b| b.call(Func::Gradient, &[var], var.type_().clone()))
.into(),
__current_scope(|b| b.call(Func::Gradient, &[var], var.type_().clone())).into(),
)
}
/// Gradient of a value in *Reverse mode* AD
Expand Down Expand Up @@ -187,10 +189,13 @@ pub fn output_gradients<V: Value>(v: Expr<V>) -> Vec<Expr<V>> {
let mut grads = vec![];
let v = v.node().get();
for i in 0..n {
grads.push(Expr::<V>::from_node(__current_scope(|b| {
let idx = b.const_(Const::Int32(i as i32));
b.call(Func::OutputGrad, &[v, idx], v.type_().clone())
}).into()));
grads.push(Expr::<V>::from_node(
__current_scope(|b| {
let idx = b.const_(Const::Int32(i as i32));
b.call(Func::OutputGrad, &[v, idx], v.type_().clone())
})
.into(),
));
}
grads
}
Expand Down
40 changes: 19 additions & 21 deletions luisa_compute/src/lang/control_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,27 +129,25 @@ pub fn if_then_else<R: Aggregate>(
b.if_(cond, then_block, else_block);
});
assert_eq!(then_nodes.len(), else_nodes.len());
let phis = __current_scope(|b| {
then_nodes
.iter()
.zip(else_nodes.iter())
.map(|(then, else_)| {
let incomings = vec![
PhiIncoming {
value: *then,
block: then_block,
},
PhiIncoming {
value: *else_,
block: else_block,
},
];
assert_eq!(then.type_(), else_.type_());
let phi = b.phi(&incomings, then.type_().clone());
phi.into()
})
.collect::<Vec<_>>()
});
let phis = then_nodes
.iter()
.zip(else_nodes.iter())
.map(|(then, else_)| {
let incomings = vec![
PhiIncoming {
value: *then,
block: then_block,
},
PhiIncoming {
value: *else_,
block: else_block,
},
];
assert_eq!(then.type_(), else_.type_());
let phi = __current_scope(|b| b.phi(&incomings, then.type_().clone()));
phi.into()
})
.collect::<Vec<_>>();
R::from_vec_nodes(phis)
}

Expand Down
20 changes: 12 additions & 8 deletions luisa_compute/src/lang/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ impl<T: Value, const N: usize, X: IntoIndex> Index<X> for ArrayAtomicRef<T, N> {
if need_runtime_check() {
check_index_lt_usize(i, N);
}
let inst = self.0.node.get().get().instruction.as_ref();
let node = self.0.node.get();
let inst = node.get().instruction.as_ref();
let mut args = match inst {
Instruction::Call(f, args) => match f {
Func::AtomicRef => args.to_vec(),
Expand Down Expand Up @@ -256,10 +257,13 @@ impl<T: Value> IndexRead for VLArrayVar<T> {
}
let self_node = self.node.get();
let i = i.node().get();
Expr::<T>::from_node(__current_scope(|b| {
let gep = b.call(Func::GetElementPtr, &[self_node, i], T::type_());
b.call(Func::Load, &[gep], T::type_())
}).into())
Expr::<T>::from_node(
__current_scope(|b| {
let gep = b.call(Func::GetElementPtr, &[self_node, i], T::type_());
b.call(Func::Load, &[gep], T::type_())
})
.into(),
)
}
}
impl<T: Value> IndexWrite for VLArrayVar<T> {
Expand Down Expand Up @@ -326,9 +330,9 @@ impl<T: Value> IndexRead for VLArrayExpr<T> {
}
let node = self.node.get();
let i = i.node().get();
Expr::<T>::from_node(__current_scope(|b| {
b.call(Func::ExtractElement, &[node, i], T::type_())
}).into())
Expr::<T>::from_node(
__current_scope(|b| b.call(Func::ExtractElement, &[node, i], T::type_())).into(),
)
}
}
impl<T: Value> VLArrayExpr<T> {
Expand Down
Loading

0 comments on commit 89031ef

Please sign in to comment.