From e9ddd15164746c4268fa878e4e1fa4ebfb8c46de Mon Sep 17 00:00:00 2001 From: Mike Date: Fri, 8 Dec 2023 15:48:29 +0800 Subject: [PATCH] add AD support for SampledSpectrum --- src/util/spec.h | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/util/spec.h b/src/util/spec.h index 65855db1..d49be310 100644 --- a/src/util/spec.h +++ b/src/util/spec.h @@ -55,6 +55,10 @@ class SampledSpectrum { private: Local _samples; +private: + explicit SampledSpectrum(Local &&samples) noexcept + : _samples{std::move(samples)} {} + public: SampledSpectrum(uint n, Expr value) noexcept : _samples{n} { compute::outline([&] { @@ -84,6 +88,12 @@ class SampledSpectrum { } [[nodiscard]] Local &values() noexcept { return _samples; } [[nodiscard]] const Local &values() const noexcept { return _samples; } + + void requires_grad() const noexcept { _samples.requires_grad(); } + void backward() const noexcept { _samples.backward(); } + void backward(const SampledSpectrum &grad) const noexcept { _samples.backward(grad._samples); } + [[nodiscard]] auto grad() const noexcept { return SampledSpectrum{_samples.grad()}; } + [[nodiscard]] Float &operator[](Expr i) noexcept { return dimension() == 1u ? _samples[0u] : _samples[i]; } @@ -179,7 +189,7 @@ class SampledSpectrum { }); \ return *this; \ } \ - auto &operator op##=(const SampledSpectrum &rhs) noexcept { \ + auto &operator op##=(const SampledSpectrum & rhs) noexcept { \ LUISA_ASSERT(rhs.dimension() == 1u || dimension() == rhs.dimension(), \ "Invalid sampled spectrum dimension for operator" #op "=: {} vs {}.", \ dimension(), rhs.dimension()); \