Skip to content

Commit

Permalink
add AD support for SampledSpectrum
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Dec 8, 2023
1 parent acdb3f8 commit e9ddd15
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/util/spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class SampledSpectrum {
private:
Local<float> _samples;

private:
explicit SampledSpectrum(Local<float> &&samples) noexcept
: _samples{std::move(samples)} {}

public:
SampledSpectrum(uint n, Expr<float> value) noexcept : _samples{n} {
compute::outline([&] {
Expand Down Expand Up @@ -84,6 +88,12 @@ class SampledSpectrum {
}
[[nodiscard]] Local<float> &values() noexcept { return _samples; }
[[nodiscard]] const Local<float> &values() const noexcept { return _samples; }

void requires_grad() const noexcept { _samples.requires_grad(); }
void backward() const noexcept { _samples.backward(); }
void backward(const SampledSpectrum &grad) const noexcept { _samples.backward(grad._samples); }
[[nodiscard]] auto grad() const noexcept { return SampledSpectrum{_samples.grad()}; }

[[nodiscard]] Float &operator[](Expr<uint> i) noexcept {
return dimension() == 1u ? _samples[0u] : _samples[i];
}
Expand Down Expand Up @@ -179,7 +189,7 @@ class SampledSpectrum {
}); \
return *this; \
} \
auto &operator op##=(const SampledSpectrum &rhs) noexcept { \
auto &operator op##=(const SampledSpectrum & rhs) noexcept { \
LUISA_ASSERT(rhs.dimension() == 1u || dimension() == rhs.dimension(), \
"Invalid sampled spectrum dimension for operator" #op "=: {} vs {}.", \
dimension(), rhs.dimension()); \
Expand Down

0 comments on commit e9ddd15

Please sign in to comment.