Skip to content

Commit

Permalink
fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
weiyithu committed Sep 18, 2021
1 parent 6f23ffc commit 5890f15
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/run_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,8 @@ def train(args):
else:
T = poses_tensor.clone()

depth_confidences = cal_depth_confidences(torch.from_numpy(depth_priors).to(device), T, K, args.topk)
depth_confidences = cal_depth_confidences(torch.from_numpy(depth_priors).to(device),
T, K, i_train, args.topk)

print('DEFINING BOUNDS')
if args.no_ndc:
Expand Down
15 changes: 9 additions & 6 deletions utils/nerf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,31 @@ def align_scales(depth_priors, colmap_depths, colmap_masks, poses, sc, i_train,
depth_priors = depth_priors * sc * ratio_priors #align scales
return depth_priors

def cal_depth_confidences(depths, T, K, topk=4):
view_num, H, W = depths.shape
def cal_depth_confidences(depths, T, K, i_train, topk=4):
_, H, W = depths.shape
view_num = len(i_train)
invK = torch.inverse(K)
batch_K = torch.unsqueeze(K, 0).repeat(view_num, 1, 1)
batch_invK = torch.unsqueeze(invK, 0).repeat(view_num, 1, 1)
invT = torch.inverse(T)
batch_invK = torch.unsqueeze(invK, 0).repeat(depths.shape[0], 1, 1)
T_train = T[i_train]
invT = torch.inverse(T_train)
pix_coords = calculate_coords(W, H)
cam_points = BackprojectDepth(depths, batch_invK, pix_coords)
depth_confidences = []

for i in range(view_num):
for i in range(depths.shape[0]):
cam_points_i = cam_points[i:i+1].repeat(view_num, 1, 1)
T_i = torch.matmul(invT, T[i:i+1].repeat(view_num, 1, 1))
pix_coords_ref = Project3D(cam_points_i, batch_K, T_i, H, W)
depths_ = Project3D_depth(cam_points_i, batch_K, T_i, H, W)
depths_proj = F.grid_sample(depths.unsqueeze(1), pix_coords_ref,
depths_proj = F.grid_sample(depths[i_train].unsqueeze(1), pix_coords_ref,
padding_mode="zeros").squeeze()
error = torch.abs(depths_proj - depths_) / (depths_ + 1e-7)
depth_confidence, _ = error.topk(k=topk, dim=0, largest=False)
depth_confidence = depth_confidence.mean(0).cpu().numpy()
depth_confidences.append(depth_confidence)
return np.stack(depth_confidences, 0)
return np.stack(depth_confidences, 0)

def calculate_coords(W, H):
meshgrid = np.meshgrid(range(W), range(H), indexing='xy')
Expand Down

0 comments on commit 5890f15

Please sign in to comment.