How to take two gradients for two separate flax.nn optimizers in one loss function? #390
Unanswered
BoyuanJackChen asked this question in General
Replies: 1 comment
As far as I can tell, you're doing things correctly from the point of view of the flax and jax APIs (based on a first quick look). If you fix the "fine model" params, does the coarse model still degenerate? My guess, if I had to make one, would be a subtle bug in the loss function or the model.
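One way to run that experiment is to block gradient flow into the fine model with `jax.lax.stop_gradient` and see whether the coarse model still collapses. A minimal sketch of the idea follows; `apply_model`, `init_params`, and the shapes are hypothetical stand-ins, not the original NeRF code:

```python
import jax
import jax.numpy as jnp

# Toy stand-ins for the coarse and fine networks (hypothetical; the
# real models are NeRF MLPs): a single linear layer each.
def apply_model(params, x):
    return x @ params["w"] + params["b"]

def init_params(key, d_in, d_out):
    return {"w": jax.random.normal(key, (d_in, d_out)) * 0.1,
            "b": jnp.zeros((d_out,))}

def frozen_fine_loss(coarse_params, fine_params, x, target):
    # stop_gradient blocks all gradient flow into the fine model,
    # so this loss trains only the coarse parameters.
    fine_params = jax.lax.stop_gradient(fine_params)
    rgb_c = apply_model(coarse_params, x)
    rgb_f = apply_model(fine_params, x)
    return jnp.mean((rgb_c - target) ** 2) + jnp.mean((rgb_f - target) ** 2)

x = jnp.ones((8, 3))
target = jnp.zeros((8, 3))
coarse_params = init_params(jax.random.PRNGKey(0), 3, 3)
fine_params = init_params(jax.random.PRNGKey(1), 3, 3)

g_coarse, g_fine = jax.grad(frozen_fine_loss, argnums=(0, 1))(
    coarse_params, fine_params, x, target)
# g_fine is all zeros: the fine model is effectively frozen, while
# the coarse model still trains on the full summed loss.
```

If the coarse model degenerates even with the fine params frozen, that points at the coarse branch itself rather than any interaction between the two optimizers.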
Description of the model to be implemented
I am implementing NeRF with flax. The paper describes hierarchical volume rendering with a coarse network and a fine network, both updated simultaneously from a single loss value. Below is my attempt at doing it. As you can see, rgb_c is generated by the coarse model, while rgb_f is generated by the fine model. The total loss is the sum of the mean squared error of each rendering. The code runs, and the fine model trains well. Nonetheless, the coarse model keeps getting worse until it renders everything completely black. I wonder if there is a problem in my loss function. I tried giving the coarse network a lower learning rate, but it didn't help; the gradients still push it toward black.
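For reference, here is a minimal sketch of the intended pattern, with hypothetical names (plain SGD via `jax.tree_util.tree_map` stands in for the flax optimizers, and `apply_model` for the NeRF MLPs): differentiate the single summed loss with `argnums=(0, 1)` so each parameter pytree gets its own gradient, then update each model only from its own gradient.

```python
import jax
import jax.numpy as jnp

LR_COARSE = 5e-4  # hypothetical learning rates; the real run may
LR_FINE = 5e-4    # use a schedule or different values per model

def apply_model(params, x):
    # Stand-in for the coarse/fine NeRF MLPs: one linear layer.
    return x @ params["w"] + params["b"]

def total_loss(coarse_params, fine_params, x, target):
    rgb_c = apply_model(coarse_params, x)  # coarse rendering
    rgb_f = apply_model(fine_params, x)    # fine rendering
    # Single scalar loss: sum of the two MSE terms, as in the paper.
    return jnp.mean((rgb_c - target) ** 2) + jnp.mean((rgb_f - target) ** 2)

@jax.jit
def train_step(coarse_params, fine_params, x, target):
    # argnums=(0, 1): one backward pass yields a separate gradient
    # pytree for EACH model from the same scalar loss.
    loss, (g_c, g_f) = jax.value_and_grad(total_loss, argnums=(0, 1))(
        coarse_params, fine_params, x, target)
    # Each model is updated only from its own gradients (with flax
    # optimizers, the analogue is one apply_gradient per optimizer).
    coarse_params = jax.tree_util.tree_map(
        lambda p, g: p - LR_COARSE * g, coarse_params, g_c)
    fine_params = jax.tree_util.tree_map(
        lambda p, g: p - LR_FINE * g, fine_params, g_f)
    return coarse_params, fine_params, loss
```

Wired this way, neither optimizer ever sees the other model's gradients, so if the coarse branch still turns black, the issue is more likely in how rgb_c is computed (e.g. the coarse sampling or compositing) than in the gradient plumbing.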
Dataset the model could be trained on
Image data
Specific points to consider
Below is the render_rays_cnf function
Reference implementations in other frameworks