Skip to content

Commit

Permalink
add copy() and copy_async() for cloning contents of Buffer/Tex2d
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 19, 2023
1 parent a2bddd1 commit 299eefb
Show file tree
Hide file tree
Showing 8 changed files with 400 additions and 347 deletions.
26 changes: 15 additions & 11 deletions luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ use std::any::Any;
use std::cell::{Cell, RefCell};
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::rc::Rc;
use std::rc::{Rc, Weak};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use std::sync::{Arc, Weak as WeakArc};
use std::{env, unreachable};

use crate::internal_prelude::*;

use bumpalo::Bump;
use indexmap::IndexMap;

use crate::runtime::WeakDevice;
use crate::runtime::{RawCallable, WeakDevice};

pub mod ir {
pub use luisa_compute_ir::context::register_type;
Expand Down Expand Up @@ -298,7 +298,7 @@ pub(crate) struct FnRecorder {
/// Once a basicblock is finished, all nodes in it are added to this set
pub(crate) inaccessible: Rc<RefCell<HashSet<NodeRef>>>,
pub(crate) kernel_id: usize,
pub(crate) captured_resources: IndexMap<Binding, (usize, NodeRef, Binding, Arc<dyn Any>)>,
pub(crate) captured_resources: IndexMap<Binding, (usize, NodeRef, Binding, WeakArc<dyn Any>)>,
pub(crate) cpu_custom_ops: IndexMap<u64, (usize, CArc<CpuCustomOp>)>,
pub(crate) callables: IndexMap<u64, CallableModuleRef>,
pub(crate) captured_vars: IndexMap<NodeRef, (NodeRef, SafeNodeRef)>,
Expand All @@ -311,6 +311,7 @@ pub(crate) struct FnRecorder {
pub(crate) callable_ret_type: Option<CArc<Type>>,
pub(crate) const_builder: IrBuilder,
pub(crate) index_const_pool: IndexMap<i32, NodeRef>,
pub(crate) rt: ResourceTracker,
}
pub(crate) type FnRecorderPtr = Rc<RefCell<FnRecorder>>;
impl FnRecorder {
Expand Down Expand Up @@ -372,7 +373,7 @@ impl FnRecorder {
pub(crate) fn capture_or_get<T: Any>(
&mut self,
binding: ir::Binding,
handle: &Arc<T>,
handle: &WeakArc<T>,
create_node: impl FnOnce() -> Node,
) -> NodeRef {
if let Some((_, node, _, _)) = self.captured_resources.get(&binding) {
Expand Down Expand Up @@ -425,6 +426,7 @@ impl FnRecorder {
parent,
index_const_pool: IndexMap::new(),
const_builder: IrBuilder::new(pools.clone()),
rt: ResourceTracker::new(),
}
}
pub(crate) fn map_captured_vars(&mut self, node0: SafeNodeRef) -> SafeNodeRef {
Expand Down Expand Up @@ -660,21 +662,23 @@ pub(crate) fn check_arg_alias(args: &[NodeRef]) {
}
}
}
pub(crate) fn __invoke_callable(callable: &CallableModuleRef, args: &[NodeRef]) -> NodeRef {
pub(crate) fn __invoke_callable(callable: &RawCallable, args: &[NodeRef]) -> NodeRef {
let inner = &callable.module;
with_recorder(|r| {
let id = CArc::as_ptr(&callable.0) as u64;
let id = CArc::as_ptr(&inner.0) as u64;
if let Some(c) = r.callables.get(&id) {
assert_eq!(CArc::as_ptr(&c.0), CArc::as_ptr(&callable.0));
assert_eq!(CArc::as_ptr(&c.0), CArc::as_ptr(&inner.0));
} else {
r.callables.insert(id, callable.clone());
r.callables.insert(id, inner.clone());
r.rt.merge(callable.resource_tracker.clone());
}
});
check_arg_alias(args);
__current_scope(|b| {
b.call(
Func::Callable(callable.clone()),
Func::Callable(inner.clone()),
args,
callable.0.ret_type.clone(),
inner.0.ret_type.clone(),
)
})
}
Expand Down
4 changes: 2 additions & 2 deletions luisa_compute/src/lang/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub trait PolymorphicImpl<T: ?Sized + 'static>: Value {
#[macro_export]
macro_rules! impl_new_poly_array {
($buffer:expr, $tag:expr, $key:expr) => {{
let buffer = unsafe { $buffer.shallow_clone() };
let buffer = $buffer.view(..);
luisa_compute::PolyArray::new(
$tag,
$key,
Expand All @@ -53,7 +53,7 @@ macro_rules! impl_polymorphic {
tag: i32,
key: K,
) -> luisa_compute::lang::poly::PolyArray<K, dyn $trait_> {
let buffer = unsafe { buffer.shallow_clone() };
let buffer = buffer.view(..);
luisa_compute::lang::poly::PolyArray::new(
tag,
key,
Expand Down
35 changes: 31 additions & 4 deletions luisa_compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,47 @@ impl Context {

#[derive(Clone)]
pub struct ResourceTracker {
resources: Vec<Arc<dyn Any>>,
strong_refs: Vec<Arc<dyn Any>>,
weak_refs: Vec<Weak<dyn Any>>,
}

impl ResourceTracker {
pub fn add<T: Any>(&mut self, ptr: Arc<T>) -> &mut Self {
self.resources.push(ptr);
self.strong_refs.push(ptr);
self
}
pub fn add_any(&mut self, ptr: Arc<dyn Any>) -> &mut Self {
self.resources.push(ptr);
self.strong_refs.push(ptr);
self
}
pub fn add_weak<T: Any>(&mut self, ptr: Weak<T>) -> &mut Self {
self.weak_refs.push(ptr);
self
}
pub fn add_weak_any(&mut self, ptr: Weak<dyn Any>) -> &mut Self {
self.weak_refs.push(ptr);
self
}
pub fn merge(&mut self, other: Self) {
self.strong_refs.extend(other.strong_refs);
self.weak_refs.extend(other.weak_refs);
}
pub fn upgrade(&self) -> Self {
let mut strong_refs = vec![];
for r in self.weak_refs.iter() {
strong_refs.push(r.upgrade().unwrap_or_else(|| panic!("Bad weak ref. Kernel captured resources might be dropped.")));
}
strong_refs.extend(self.strong_refs.iter().cloned());
Self {
strong_refs,
weak_refs: vec![],
}
}
pub fn new() -> Self {
Self { resources: vec![] }
Self {
strong_refs: vec![],
weak_refs: vec![],
}
}
}

Expand Down
Loading

0 comments on commit 299eefb

Please sign in to comment.