Skip to content

Commit

Permalink
Merge pull request #47 from seriousran/patch-3
Browse files Browse the repository at this point in the history
Skip for the case ray_directions_ablation is None
  • Loading branch information
gafniguy authored Oct 27, 2022
2 parents 8720e8e + 18e83b2 commit 6129c57
Showing 1 changed file with 47 additions and 23 deletions.
70 changes: 47 additions & 23 deletions nerface_code/nerf-pytorch/nerf/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def predict_and_render_radiance(
pts = ro[..., None, :] + rd[..., None, :] * z_vals[..., :, None]
# Uncomment to dump a ply file visualizing camera rays and sampling points
#dump_rays(ro.detach().cpu().numpy(), pts.detach().cpu().numpy())
ray_batch[...,3:6] = ray_dirs_fake[0][...,3:6] # TODO remove this this is for ablation of ray dir
if ray_dirs_fake:
ray_batch[...,3:6] = ray_dirs_fake[0][...,3:6] # TODO remove this this is for ablation of ray dir

radiance_field = run_network(
model_coarse,
Expand Down Expand Up @@ -178,6 +179,7 @@ def run_one_iter_of_nerf(
latent_code = None,
ray_directions_ablation = None
):
is_rad = torch.is_tensor(ray_directions_ablation)
viewdirs = None
if options.nerf.use_viewdirs:
# Provide ray directions as input
Expand All @@ -203,44 +205,66 @@ def run_one_iter_of_nerf(
#"caling normal rays (not NDC)"
ro = ray_origins.view((-1, 3))
rd = ray_directions.view((-1, 3))
rd_ablations = ray_directions_ablation.view((-1, 3))
if is_rad:
rd_ablations = ray_directions_ablation.view((-1, 3))
near = options.dataset.near * torch.ones_like(rd[..., :1])
far = options.dataset.far * torch.ones_like(rd[..., :1])
rays = torch.cat((ro, rd, near, far), dim=-1)
rays_ablation = torch.cat((ro, rd_ablations, near, far), dim=-1)
if is_rad:
rays_ablation = torch.cat((ro, rd_ablations, near, far), dim=-1)
# if options.nerf.use_viewdirs: # TODO uncomment
# rays = torch.cat((rays, viewdirs), dim=-1)
#
viewdirs = None # TODO remove this paragraph
if options.nerf.use_viewdirs:
# Provide ray directions as input
viewdirs = ray_directions_ablation
viewdirs = viewdirs / viewdirs.norm(p=2, dim=-1).unsqueeze(-1)
viewdirs = viewdirs.view((-1, 3))
if is_rad:
viewdirs = ray_directions_ablation
viewdirs = viewdirs / viewdirs.norm(p=2, dim=-1).unsqueeze(-1)
viewdirs = viewdirs.view((-1, 3))


batches_ablation = get_minibatches(rays_ablation, chunksize=getattr(options.nerf, mode).chunksize)
if is_rad:
batches_ablation = get_minibatches(rays_ablation, chunksize=getattr(options.nerf, mode).chunksize)
batches = get_minibatches(rays, chunksize=getattr(options.nerf, mode).chunksize)
assert(batches[0].shape == batches[0].shape)
background_prior = get_minibatches(background_prior, chunksize=getattr(options.nerf, mode).chunksize) if\
background_prior is not None else background_prior
#print("predicting")
pred = [
predict_and_render_radiance(
batch,
model_coarse,
model_fine,
options,
mode,
encode_position_fn=encode_position_fn,
encode_direction_fn=encode_direction_fn,
expressions = expressions,
background_prior = background_prior[i] if background_prior is not None else background_prior,
latent_code = latent_code,
ray_dirs_fake = batches_ablation
)
for i,batch in enumerate(batches)
]
if is_rad:
pred = [
predict_and_render_radiance(
batch,
model_coarse,
model_fine,
options,
mode,
encode_position_fn=encode_position_fn,
encode_direction_fn=encode_direction_fn,
expressions = expressions,
background_prior = background_prior[i] if background_prior is not None else background_prior,
latent_code = latent_code,
ray_dirs_fake = batches_ablation
)
for i,batch in enumerate(batches)
]
else:
pred = [
predict_and_render_radiance(
batch,
model_coarse,
model_fine,
options,
mode,
encode_position_fn=encode_position_fn,
encode_direction_fn=encode_direction_fn,
expressions = expressions,
background_prior = background_prior[i] if background_prior is not None else background_prior,
latent_code = latent_code,
ray_dirs_fake = None
)
for i,batch in enumerate(batches)
]
#print("predicted")

synthesized_images = list(zip(*pred))
Expand Down

0 comments on commit 6129c57

Please sign in to comment.