Skip to content

Commit

Permalink
implementing autodiff in MicrofacetReflection: BUG in $if
Browse files Browse the repository at this point in the history
  • Loading branch information
W-Solaris committed Dec 10, 2023
1 parent 17dc35b commit 5816f61
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 28 deletions.
82 changes: 54 additions & 28 deletions src/util/scattering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ Float MicrofacetDistribution::G1(Expr<float3> w) const noexcept {
}

Float MicrofacetDistribution::G(Expr<float3> wo, Expr<float3> wi) const noexcept {
return 1.0f / (1.0f + Lambda(wo) + Lambda(wi));
return forward_compute_G(wo, wi, alpha());
}

Float MicrofacetDistribution::forward_compute_G(Expr<float3> wo, Expr<float3> wi, Float2 alpha) const noexcept {
return 1.0f / (1.0f + forward_compute_Lambda(wo, alpha) + forward_compute_Lambda(wi, alpha));
}

Float MicrofacetDistribution::pdf(Expr<float3> wo, Expr<float3> wh) const noexcept {
Expand Down Expand Up @@ -163,6 +167,10 @@ Float2 TrowbridgeReitzDistribution::grad_alpha_roughness(Expr<float2> roughness)
}

Float TrowbridgeReitzDistribution::D(Expr<float3> wh) const noexcept {
return forward_compute_D(wh, alpha());
}

Float TrowbridgeReitzDistribution::forward_compute_D(Expr<float3> wh, Float2 alpha) const noexcept {
using compute::isinf;
static Callable impl = [](Float3 wh, Float2 alpha) noexcept {
auto tan2Theta = tan2_theta(wh);
Expand All @@ -172,10 +180,14 @@ Float TrowbridgeReitzDistribution::D(Expr<float3> wh) const noexcept {
auto d = 1.0f / (pi * alpha.x * alpha.y * cos4Theta * sqr(1.f + e));
return ite(isinf(tan2Theta), 0.f, d);
};
return impl(wh, alpha());
return impl(wh, alpha);
}

Float TrowbridgeReitzDistribution::Lambda(Expr<float3> w) const noexcept {
return forward_compute_Lambda(w, alpha());
}

Float TrowbridgeReitzDistribution::forward_compute_Lambda(Expr<float3> w, Float2 alpha) const noexcept {
using compute::isinf;
static Callable impl = [](Float3 w, Float2 alpha) noexcept {
auto tanTheta = abs(tan_theta(w));
Expand All @@ -186,7 +198,7 @@ Float TrowbridgeReitzDistribution::Lambda(Expr<float3> w) const noexcept {
auto L = (-1.f + sqrt(1.f + alpha2Tan2Theta)) * .5f;
return ite(isinf(tanTheta), 0.f, L);
};
return impl(w, alpha());
return impl(w, alpha);
}

[[nodiscard]] inline Float2 TrowbridgeReitzSample11(Expr<float> cosTheta, Expr<float2> U) noexcept {
Expand Down Expand Up @@ -274,19 +286,14 @@ MicrofacetDistribution::Gradient TrowbridgeReitzDistribution::grad_Lambda(Expr<f
}

MicrofacetDistribution::Gradient TrowbridgeReitzDistribution::grad_D(Expr<float3> wh) const noexcept {
using compute::isinf;
auto tan2Theta = tan2_theta(wh);
auto cos4Theta = sqr(cos2_theta(wh));

auto e0 = tan2Theta * sqr(cos_phi(wh) / alpha().x);
auto e1 = tan2Theta * sqr(sin_phi(wh) / alpha().y);
auto e = e0 + e1;
auto D = 1.0f / (pi * alpha().x * alpha().y * cos4Theta * sqr(1.f + e));

auto d_D = ite(isinf(tan2Theta), 0.f, 1.f);
auto d_e = -d_D * 2.f / (1.f + e) * D;
auto d_alpha = -d_D * D + d_e * 2.f * make_float2(e0, e1) / alpha();

auto d_alpha = alpha();
$autodiff {
auto alpha_back = alpha();
requires_grad(alpha_back);
auto y = forward_compute_D(wh, alpha_back);
backward(y);
d_alpha = grad(alpha_back);
};
return {.dAlpha = d_alpha};
}

Expand Down Expand Up @@ -349,20 +356,26 @@ LambertianTransmission::Gradient LambertianTransmission::backward(

SampledSpectrum MicrofacetReflection::evaluate(
Expr<float3> wo, Expr<float3> wi, TransportMode mode) const noexcept {
return forward_compute(wo, wi, mode, _r, _distribution->alpha());
}

SampledSpectrum MicrofacetReflection::forward_compute(
Expr<float3> wo, Expr<float3> wi, TransportMode mode, SampledSpectrum r, Float2 alpha) const noexcept {
using compute::any;
using compute::normalize;
auto wh = wi + wo;
SampledSpectrum f{_r.dimension()};
// TODO: autodiff do not support $if ?
$if(same_hemisphere(wo, wi) & any(wh != 0.f)) {
wh = normalize(wh);
// wh = normalize(wh);
// For the Fresnel call, make sure that wh is in the same hemisphere
// as the surface normal, so that TIR is handled correctly.
auto F = _fresnel->evaluate(dot(wi, face_forward(wh, make_float3(0.f, 0.f, 1.f))));
auto D = _distribution->D(wh);
auto G = _distribution->G(wo, wi);
auto D = _distribution->forward_compute_D(wh, alpha);
auto G = _distribution->forward_compute_G(wo, wi, alpha);
auto cos_o = cos_theta(wo);
auto cos_i = cos_theta(wi);
f = _r * F * abs(0.25f * D * G / (cos_i * cos_o));
f = r * F * abs(0.25f * D * G / (cos_i * cos_o));
};
return f;
}
Expand All @@ -385,6 +398,19 @@ Float MicrofacetReflection::pdf(Expr<float3> wo, Expr<float3> wi, TransportMode

MicrofacetReflection::Gradient MicrofacetReflection::backward(
Expr<float3> wo, Expr<float3> wi, const SampledSpectrum &df, TransportMode mode) const noexcept {
auto d_alpha = _distribution->alpha();
auto d_r = _r;
$autodiff {
auto alpha_back = _distribution->alpha();
auto r = _r;
requires_grad(alpha_back);
r.requires_grad();
auto y = forward_compute(wo, wi, mode, r, alpha_back);
y.backward(df);
d_alpha = grad(alpha_back);
d_r = r.grad();
};

using compute::any;
using compute::normalize;
auto wh = wi + wo;
Expand All @@ -400,15 +426,15 @@ MicrofacetReflection::Gradient MicrofacetReflection::backward(
auto k1 = k0 * D * G;
auto h = abs(k1);

// backward
auto d_h = (df * _r * F).sum() * ite(valid, 1.f, 0.f);
// // backward
// auto d_h = (df * _r * F).sum() * ite(valid, 1.f, 0.f);
auto d_F = df * _r * ite(valid, h, 0.f);
auto k2 = d_h * sign(k1) * k0;
auto d_D = k2 * G;
auto d_G = k2 * D;
auto d_alpha = d_D * _distribution->grad_D(wh).dAlpha +
d_G * _distribution->grad_G(wo, wi).dAlpha;
auto d_r = df * F * ite(valid, h, 0.f);
// auto k2 = d_h * sign(k1) * k0;
// auto d_D = k2 * G;
// auto d_G = k2 * D;
// auto d_alpha = d_D * _distribution->grad_D(wh).dAlpha +
// d_G * _distribution->grad_G(wo, wi).dAlpha;
// auto d_r = df * F * ite(valid, h, 0.f);

return {.dR = d_r, .dAlpha = d_alpha, .dFresnel = _fresnel->backward(cosI_eval, d_F)};
}
Expand Down
9 changes: 9 additions & 0 deletions src/util/scattering.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class MicrofacetDistribution {
[[nodiscard]] Float pdf(Expr<float3> wo, Expr<float3> wh) const noexcept;
[[nodiscard]] auto alpha() const noexcept { return _alpha; }

[[nodiscard]] virtual Float forward_compute_Lambda(Expr<float3> w, Float2 alpha) const noexcept = 0;
[[nodiscard]] virtual Float forward_compute_D(Expr<float3> wh, Float2 alpha) const noexcept = 0;
[[nodiscard]] virtual Float forward_compute_G(Expr<float3> wo, Expr<float3> wi, Float2 alpha) const noexcept;

[[nodiscard]] virtual Gradient grad_G1(Expr<float3> w) const noexcept;
[[nodiscard]] virtual Gradient grad_G(Expr<float3> wo, Expr<float3> wi) const noexcept;
[[nodiscard]] virtual Gradient grad_D(Expr<float3> wh) const noexcept = 0;
Expand All @@ -67,6 +71,9 @@ struct TrowbridgeReitzDistribution : public MicrofacetDistribution {
[[nodiscard]] static Float alpha_to_roughness(Expr<float> alpha) noexcept;
[[nodiscard]] static Float2 alpha_to_roughness(Expr<float2> alpha) noexcept;

[[nodiscard]] Float forward_compute_Lambda(Expr<float3> w, Float2 alpha) const noexcept override;
[[nodiscard]] Float forward_compute_D(Expr<float3> wh, Float2 alpha) const noexcept override;

[[nodiscard]] Gradient grad_D(Expr<float3> wh) const noexcept override;
[[nodiscard]] Gradient grad_Lambda(Expr<float3> w) const noexcept override;
[[nodiscard]] static Float2 grad_alpha_roughness(Expr<float2> roughness) noexcept;
Expand Down Expand Up @@ -195,6 +202,7 @@ class MicrofacetReflection : public BxDF {
: _r{R}, _distribution{d}, _fresnel{f} {}
[[nodiscard]] SampledSpectrum evaluate(Expr<float3> wo, Expr<float3> wi, TransportMode mode) const noexcept override;
[[nodiscard]] SampledDirection sample_wi(Expr<float3> wo, Expr<float2> u, TransportMode mode) const noexcept override;
[[nodiscard]] SampledSpectrum forward_compute(Expr<float3> wo, Expr<float3> wi, TransportMode mode, SampledSpectrum r, Float2 alpha) const noexcept;
[[nodiscard]] Float pdf(Expr<float3> wo, Expr<float3> wi, TransportMode mode) const noexcept override;
[[nodiscard]] Gradient backward(Expr<float3> wo, Expr<float3> wi, const SampledSpectrum &df, TransportMode mode) const noexcept;
[[nodiscard]] SampledSpectrum albedo() const noexcept override { return _r; }
Expand Down Expand Up @@ -222,6 +230,7 @@ class MicrofacetTransmission : public BxDF {
Expr<float> etaA, Expr<float> etaB) noexcept
: _t{T}, _distribution{d}, _eta_a{etaA}, _eta_b{etaB} {}
[[nodiscard]] SampledSpectrum evaluate(Expr<float3> wo, Expr<float3> wi, TransportMode mode) const noexcept override;
[[nodiscard]] SampledSpectrum forward_compute(Expr<float3> wo, Expr<float3> wi, TransportMode mode, SampledSpectrum r, Float2 alpha) const noexcept;
[[nodiscard]] SampledDirection sample_wi(Expr<float3> wo, Expr<float2> u, TransportMode mode) const noexcept override;
[[nodiscard]] Float pdf(Expr<float3> wo, Expr<float3> wi, TransportMode mode) const noexcept override;
[[nodiscard]] Gradient backward(Expr<float3> wo, Expr<float3> wi, const SampledSpectrum &df, TransportMode mode) const noexcept;
Expand Down

0 comments on commit 5816f61

Please sign in to comment.