diff --git a/src/run_nerf.py b/src/run_nerf.py index 0019d23..38bd35b 100644 --- a/src/run_nerf.py +++ b/src/run_nerf.py @@ -460,7 +460,8 @@ def train(args): else: T = poses_tensor.clone() - depth_confidences = cal_depth_confidences(torch.from_numpy(depth_priors).to(device), T, K, args.topk) + depth_confidences = cal_depth_confidences(torch.from_numpy(depth_priors).to(device), + T, K, i_train, args.topk) print('DEFINING BOUNDS') if args.no_ndc: diff --git a/utils/nerf_utils.py b/utils/nerf_utils.py index 095482f..929a73b 100644 --- a/utils/nerf_utils.py +++ b/utils/nerf_utils.py @@ -20,28 +20,31 @@ def align_scales(depth_priors, colmap_depths, colmap_masks, poses, sc, i_train, depth_priors = depth_priors * sc * ratio_priors #align scales return depth_priors -def cal_depth_confidences(depths, T, K, topk=4): - view_num, H, W = depths.shape +def cal_depth_confidences(depths, T, K, i_train, topk=4): + _, H, W = depths.shape + view_num = len(i_train) invK = torch.inverse(K) batch_K = torch.unsqueeze(K, 0).repeat(view_num, 1, 1) - batch_invK = torch.unsqueeze(invK, 0).repeat(view_num, 1, 1) - invT = torch.inverse(T) + batch_invK = torch.unsqueeze(invK, 0).repeat(depths.shape[0], 1, 1) + T_train = T[i_train] + invT = torch.inverse(T_train) pix_coords = calculate_coords(W, H) cam_points = BackprojectDepth(depths, batch_invK, pix_coords) depth_confidences = [] - for i in range(view_num): + for i in range(depths.shape[0]): cam_points_i = cam_points[i:i+1].repeat(view_num, 1, 1) T_i = torch.matmul(invT, T[i:i+1].repeat(view_num, 1, 1)) pix_coords_ref = Project3D(cam_points_i, batch_K, T_i, H, W) depths_ = Project3D_depth(cam_points_i, batch_K, T_i, H, W) - depths_proj = F.grid_sample(depths.unsqueeze(1), pix_coords_ref, + depths_proj = F.grid_sample(depths[i_train].unsqueeze(1), pix_coords_ref, padding_mode="zeros").squeeze() error = torch.abs(depths_proj - depths_) / (depths_ + 1e-7) depth_confidence, _ = error.topk(k=topk, dim=0, largest=False) depth_confidence = depth_confidence.mean(0).cpu().numpy() depth_confidences.append(depth_confidence) return np.stack(depth_confidences, 0) + return np.stack(depth_confidences, 0) def calculate_coords(W, H): meshgrid = np.meshgrid(range(W), range(H), indexing='xy')