diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index a6fada1..8e22085 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -282,7 +282,7 @@ impl Recorder { self.callable_ret_type = None; } - pub(crate) fn check_on_same_device(&self, other: &Device) -> Option<(String, String)> { + pub(crate) fn check_on_same_device(&mut self, other: &Device) -> Option<(String, String)> { if let Some(device) = &self.device { let device = device.upgrade().unwrap(); if !Arc::ptr_eq(&device.inner, &other.inner) { @@ -292,7 +292,7 @@ impl Recorder { )); } } else { - // @FIXME: What should we do? + self.device = Some(WeakDevice::new(other)); } None } diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index a1409b1..f83aecf 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1254,7 +1254,7 @@ pub struct RawCallable { impl RawCallable { pub(crate) fn check_on_same_device(&self) { RECORDER.with(|r| { - let r =r.borrow(); + let mut r =r.borrow_mut(); if let Some(device) = &self.device { if let Some((a,b)) = r.check_on_same_device(device) { panic!("Callable created on a different device than the one it is called on: {:?} vs {:?}", a,b);