From c4f379dbd7780a5c6f953781de3aa6d5c86f5c1c Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Fri, 15 Mar 2024 05:55:33 -0400 Subject: [PATCH] added signum for floats --- luisa_compute/src/lang/ops/impls.rs | 10 ++++++++++ luisa_compute/src/lang/ops/spread.rs | 1 + luisa_compute/src/lang/ops/traits.rs | 1 + luisa_compute/src/lib.rs | 6 +++++- luisa_compute_sys/LuisaCompute | 2 +- 5 files changed, 18 insertions(+), 2 deletions(-) diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index 36f828e..0d8fd89 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -236,6 +236,16 @@ where fn sin_cos(&self) -> (Self, Self) { (self.sin(), self.cos()) } + fn signum(&self) -> Self { + let self_node = self.node().get(); + ::from_node( + __current_scope(|b| { + let one = b.const_(Const::One(::type_())); + b.call(Func::Copysign, &[one, self_node], ::type_()) + }) + .into(), + ) + } } impl NormExpr for Expr> where diff --git a/luisa_compute/src/lang/ops/spread.rs b/luisa_compute/src/lang/ops/spread.rs index 26db436..9b55ba2 100644 --- a/luisa_compute/src/lang/ops/spread.rs +++ b/luisa_compute/src/lang/ops/spread.rs @@ -286,6 +286,7 @@ where Expr::::_mul_add(self, S::lift_self(mul), S::lift_other(add)) } } + impl FloatCopySignExpr for T where T: SpreadOps, diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs index 690a718..5bb5323 100644 --- a/luisa_compute/src/lang/ops/traits.rs +++ b/luisa_compute/src/lang/ops/traits.rs @@ -244,6 +244,7 @@ pub trait FloatExpr: Sized { fn cube(&self) -> Self; fn recip(&self) -> Self; fn sin_cos(&self) -> (Self, Self); + fn signum(&self) -> Self; } pub trait ReduceExpr: Sized { diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 9aad0ab..74e59e9 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -198,7 +198,11 @@ impl Context { pub fn create_device(&self, device: D) -> Device { self.create_device_with_config(device, serde_json::json!({})) } - pub fn create_device_with_config(&self, device: D, config: serde_json::Value) -> Device { + pub fn create_device_with_config( + &self, + device: D, + config: serde_json::Value, + ) -> Device { let backend = self.inner.create_device(&device.into_device_name(), config); let default_stream = backend.create_stream(api::StreamTag::Graphics); Device { diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index c26cdb6..33ce816 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit c26cdb6918d5eeaf9bf39f191c9b347d8f41ea52 +Subproject commit 33ce816b586808a202bc5f4e5f0728ddd803934a