# utils.py
import json

import numpy as np
import torch

from ray_utils import get_rays, get_ray_directions, get_ndc_rays
# Offsets from a voxel's bottom-left corner to its 8 vertices.
# Note: allocated on the CUDA device at import time, so this module requires a GPU.
BOX_OFFSETS = torch.tensor([[[i, j, k] for i in [0, 1] for j in [0, 1] for k in [0, 1]]],
                           device='cuda')


def hash(coords, log2_hashmap_size):
    '''
    XOR-based spatial hash of integer grid coordinates.
    Note: this shadows Python's built-in hash().

    coords: this function can process up to 7-dimensional coordinates
    log2_hashmap_size: log2 of the hash table size T
    '''
    # The first "prime" is 1, so the first coordinate axis passes through unmixed.
    primes = [1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737]

    xor_result = torch.zeros_like(coords)[..., 0]
    for i in range(coords.shape[-1]):
        xor_result ^= coords[..., i] * primes[i]

    # Mask down to log2_hashmap_size bits, i.e. map into [0, 2**log2_hashmap_size - 1].
    return torch.tensor((1 << log2_hashmap_size) - 1).to(xor_result.device) & xor_result
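
# A minimal usage sketch (hypothetical shapes and values, not from the repo):
# hashing the 8 corner indices of a batch of 1024 voxels into a table of size 2**19.
#   coords = torch.randint(0, 512, (1024, 8, 3))    # B x 8 x 3 integer vertices
#   indices = hash(coords, log2_hashmap_size=19)    # B x 8, values in [0, 2**19 - 1]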


def get_bbox3d_for_blenderobj(camera_transforms, H, W, near=2.0, far=6.0):
    camera_angle_x = float(camera_transforms['camera_angle_x'])
    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)

    # ray directions in camera coordinates
    directions = get_ray_directions(H, W, focal)

    min_bound = [100, 100, 100]
    max_bound = [-100, -100, -100]

    def find_min_max(pt):
        for i in range(3):
            if min_bound[i] > pt[i]:
                min_bound[i] = pt[i]
            if max_bound[i] < pt[i]:
                max_bound[i] = pt[i]

    for frame in camera_transforms["frames"]:
        c2w = torch.FloatTensor(frame["transform_matrix"])
        rays_o, rays_d = get_rays(directions, c2w)

        # The four image-corner rays at the near and far planes bound each frustum.
        for i in [0, W-1, H*W-W, H*W-1]:
            find_min_max(rays_o[i] + near * rays_d[i])
            find_min_max(rays_o[i] + far * rays_d[i])

    # Pad the box by 1 unit on each side.
    return (torch.tensor(min_bound) - torch.tensor([1.0, 1.0, 1.0]),
            torch.tensor(max_bound) + torch.tensor([1.0, 1.0, 1.0]))
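
# Expected shape of camera_transforms for the function above, per the keys it reads
# (illustrative values, in the Blender / NeRF-synthetic transforms JSON layout):
#   {
#     "camera_angle_x": 0.6911,
#     "frames": [{"transform_matrix": [[...], [...], [...], [...]]}, ...]
#   }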


def get_bbox3d_for_llff(poses, hwf, near=0.0, far=1.0):
    H, W, focal = hwf
    H, W = int(H), int(W)

    # ray directions in camera coordinates
    directions = get_ray_directions(H, W, focal)

    min_bound = [100, 100, 100]
    max_bound = [-100, -100, -100]

    def find_min_max(pt):
        for i in range(3):
            if min_bound[i] > pt[i]:
                min_bound[i] = pt[i]
            if max_bound[i] < pt[i]:
                max_bound[i] = pt[i]

    poses = torch.FloatTensor(poses)
    for pose in poses:
        rays_o, rays_d = get_rays(directions, pose)
        # LLFF scenes are bounded in NDC space, so convert the rays first.
        rays_o, rays_d = get_ndc_rays(H, W, focal, 1.0, rays_o, rays_d)

        for i in [0, W-1, H*W-W, H*W-1]:
            find_min_max(rays_o[i] + near * rays_d[i])
            find_min_max(rays_o[i] + far * rays_d[i])

    # Pad generously in x/y and only slightly in z (NDC depth is already in [0, 1]).
    return (torch.tensor(min_bound) - torch.tensor([0.1, 0.1, 0.0001]),
            torch.tensor(max_bound) + torch.tensor([0.1, 0.1, 0.0001]))
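
# A minimal usage sketch (hypothetical shapes, inferred from the loop above):
#   poses: array of per-image camera-to-world matrices, hwf = [H, W, focal]
#   bounding_box = get_bbox3d_for_llff(poses, hwf, near=0.0, far=1.0)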


def get_voxel_vertices(xyz, bounding_box, resolution, log2_hashmap_size):
    '''
    xyz: 3D coordinates of samples. B x 3
    bounding_box: min and max x, y, z coordinates of object bbox
    resolution: number of voxels per axis
    log2_hashmap_size: log2 of the hash table size
    '''
    box_min, box_max = bounding_box

    # Mask of samples already inside the box; out-of-box samples are clamped below.
    keep_mask = xyz == torch.max(torch.min(xyz, box_max), box_min)
    if not torch.all(xyz <= box_max) or not torch.all(xyz >= box_min):
        # print("ALERT: some points are outside bounding box. Clipping them!")
        xyz = torch.clamp(xyz, min=box_min, max=box_max)

    grid_size = (box_max - box_min) / resolution

    bottom_left_idx = torch.floor((xyz - box_min) / grid_size).int()
    voxel_min_vertex = bottom_left_idx * grid_size + box_min
    voxel_max_vertex = voxel_min_vertex + grid_size

    # B x 8 x 3 integer indices of each voxel's corners, hashed to table slots.
    voxel_indices = bottom_left_idx.unsqueeze(1) + BOX_OFFSETS
    hashed_voxel_indices = hash(voxel_indices, log2_hashmap_size)

    return voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices, keep_mask
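
# A minimal downstream sketch (hypothetical table and feature width of 2), showing
# how the returned values are typically consumed for trilinear interpolation:
#   embeddings = torch.nn.Embedding(2**log2_hashmap_size, 2)
#   voxel_embedds = embeddings(hashed_voxel_indices)    # B x 8 x 2 corner features
#   weights = (xyz - voxel_min_vertex) / (voxel_max_vertex - voxel_min_vertex)  # B x 3 in [0, 1]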


if __name__ == "__main__":
    with open("data/nerf_synthetic/chair/transforms_train.json", "r") as f:
        camera_transforms = json.load(f)

    bounding_box = get_bbox3d_for_blenderobj(camera_transforms, 800, 800)
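
    # Illustrative follow-up (hypothetical; assumes a CUDA device, since BOX_OFFSETS
    # is allocated on the GPU at import time):
    #   box_min, box_max = (t.cuda() for t in bounding_box)
    #   xyz = box_min + torch.rand(16, 3, device='cuda') * (box_max - box_min)
    #   out = get_voxel_vertices(xyz, (box_min, box_max), resolution=64, log2_hashmap_size=19)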