diff --git a/nerface_code/nerf-pytorch/nerf/train_utils.py b/nerface_code/nerf-pytorch/nerf/train_utils.py index 1b9d56a6..abc88547 100644 --- a/nerface_code/nerf-pytorch/nerf/train_utils.py +++ b/nerface_code/nerf-pytorch/nerf/train_utils.py @@ -78,7 +78,8 @@ def predict_and_render_radiance( pts = ro[..., None, :] + rd[..., None, :] * z_vals[..., :, None] # Uncomment to dump a ply file visualizing camera rays and sampling points #dump_rays(ro.detach().cpu().numpy(), pts.detach().cpu().numpy()) - ray_batch[...,3:6] = ray_dirs_fake[0][...,3:6] # TODO remove this this is for ablation of ray dir + if ray_dirs_fake: + ray_batch[...,3:6] = ray_dirs_fake[0][...,3:6] # TODO remove this this is for ablation of ray dir radiance_field = run_network( model_coarse, @@ -178,6 +179,7 @@ def run_one_iter_of_nerf( latent_code = None, ray_directions_ablation = None ): + is_rad = torch.is_tensor(ray_directions_ablation) viewdirs = None if options.nerf.use_viewdirs: # Provide ray directions as input @@ -203,44 +205,66 @@ def run_one_iter_of_nerf( #"caling normal rays (not NDC)" ro = ray_origins.view((-1, 3)) rd = ray_directions.view((-1, 3)) - rd_ablations = ray_directions_ablation.view((-1, 3)) + if is_rad: + rd_ablations = ray_directions_ablation.view((-1, 3)) near = options.dataset.near * torch.ones_like(rd[..., :1]) far = options.dataset.far * torch.ones_like(rd[..., :1]) rays = torch.cat((ro, rd, near, far), dim=-1) - rays_ablation = torch.cat((ro, rd_ablations, near, far), dim=-1) + if is_rad: + rays_ablation = torch.cat((ro, rd_ablations, near, far), dim=-1) # if options.nerf.use_viewdirs: # TODO uncomment # rays = torch.cat((rays, viewdirs), dim=-1) # viewdirs = None # TODO remove this paragraph if options.nerf.use_viewdirs: # Provide ray directions as input - viewdirs = ray_directions_ablation - viewdirs = viewdirs / viewdirs.norm(p=2, dim=-1).unsqueeze(-1) - viewdirs = viewdirs.view((-1, 3)) + if is_rad: + viewdirs = ray_directions_ablation + viewdirs = viewdirs / viewdirs.norm(p=2, dim=-1).unsqueeze(-1) + viewdirs = viewdirs.view((-1, 3)) - batches_ablation = get_minibatches(rays_ablation, chunksize=getattr(options.nerf, mode).chunksize) + if is_rad: + batches_ablation = get_minibatches(rays_ablation, chunksize=getattr(options.nerf, mode).chunksize) batches = get_minibatches(rays, chunksize=getattr(options.nerf, mode).chunksize) assert(batches[0].shape == batches[0].shape) background_prior = get_minibatches(background_prior, chunksize=getattr(options.nerf, mode).chunksize) if\ background_prior is not None else background_prior #print("predicting") - pred = [ - predict_and_render_radiance( - batch, - model_coarse, - model_fine, - options, - mode, - encode_position_fn=encode_position_fn, - encode_direction_fn=encode_direction_fn, - expressions = expressions, - background_prior = background_prior[i] if background_prior is not None else background_prior, - latent_code = latent_code, - ray_dirs_fake = batches_ablation - ) - for i,batch in enumerate(batches) - ] + if is_rad: + pred = [ + predict_and_render_radiance( + batch, + model_coarse, + model_fine, + options, + mode, + encode_position_fn=encode_position_fn, + encode_direction_fn=encode_direction_fn, + expressions = expressions, + background_prior = background_prior[i] if background_prior is not None else background_prior, + latent_code = latent_code, + ray_dirs_fake = batches_ablation + ) + for i,batch in enumerate(batches) + ] + else: + pred = [ + predict_and_render_radiance( + batch, + model_coarse, + model_fine, + options, + mode, + encode_position_fn=encode_position_fn, + encode_direction_fn=encode_direction_fn, + expressions = expressions, + background_prior = background_prior[i] if background_prior is not None else background_prior, + latent_code = latent_code, + ray_dirs_fake = None + ) + for i,batch in enumerate(batches) + ] #print("predicted") synthesized_images = list(zip(*pred))