More improvements to occlusion estimation for point tracking (#285)
* Fix half-pixel error in query points for point tracking

* More improvements to occlusion estimation. Depth-map interpolation used raster coordinates where it should have used pixel coordinates. We also flag occlusion when the point appears on the wrong segment in 2D, and when its surface normal faces away from the camera; this helps with very thin objects, where the depth map alone is unreliable.
cdoersch authored Mar 31, 2023
1 parent 6e4fe4c commit 95aff56
Showing 1 changed file: challenges/point_tracking/dataset.py (175 additions, 18 deletions)
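Two of the changes are easier to grasp outside the diff. The following is a minimal NumPy sketch of the ideas (all names are illustrative, not functions from dataset.py): raster coordinates put the origin at the corner of the top-left pixel, so converting to pixel-center coordinates is a half-pixel shift; and a point counts as occluded if any of three signals fires: the scene depth puts a surface in front of it, it lands on the wrong segment in 2D, or its surface normal faces away from the camera.

import numpy as np

def raster_to_pixel(xy):
  # Pixel centers sit at half-integer raster positions, so subtracting 0.5
  # makes integer coordinates land on pixel centers.
  return np.asarray(xy) - 0.5

def pixel_to_raster(xy):
  return np.asarray(xy) + 0.5

def is_occluded(scene_depth, point_depth, seg_at_point, query_seg,
                normal_dot_view, depth_slack=0.99):
  behind = scene_depth < point_depth * depth_slack  # something in front
  wrong_segment = seg_at_point != query_seg         # wrong object in 2D
  back_facing = normal_dot_view > 0                 # normal points away
  return behind or wrong_segment or back_facing

print(raster_to_pixel([64.5, 32.5]))      # -> [64. 32.]
print(is_occluded(1.0, 2.0, 3, 3, -0.2))  # -> True (the depth test fires)

In the real pipeline each test is computed per frame with TensorFlow ops, and the segment test additionally requires all four bilinear neighbors to disagree (see estimate_occlusion_by_depth_and_segment in the diff below).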
@@ -138,24 +138,44 @@ def reproject(coords, camera, camera_pos, num_frames, bbox=None):
# Project each point back to the image using the camera.
projections = project_point(camera, world_coords, num_frames)

-return tf.transpose(projections, (1, 0, 2)), tf.transpose(depths)
+return (
+    tf.transpose(projections, (1, 0, 2)),
+    tf.transpose(depths),
+    tf.transpose(world_coords, (1, 0, 2)),
+)


-def estimate_scene_depth_for_point(data, x, y, num_frames):
+def estimate_occlusion_by_depth_and_segment(
+    data,
+    segments,
+    x,
+    y,
+    num_frames,
+    thresh,
+    seg_id,
+):
"""Estimate depth at a (floating point) x,y position.
We prefer overestimating depth at the point, so we take the max over the 4
neighboring pixels.
Args:
data: depth map. First axis is num_frames.
segments: segmentation map. First axis is num_frames.
x: x coordinate. First axis is num_frames.
y: y coordinate. First axis is num_frames.
num_frames: number of frames.
thresh: Depth threshold at which we consider the point occluded.
seg_id: Original segment id. Assume occlusion if there's a mismatch.
Returns:
  Boolean tensor indicating, for each point, whether it is occluded.
"""

# need to convert from raster to pixel coordinates
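# (In raster coordinates the origin is the corner of the top-left pixel and
# pixel centers sit at half-integer positions; subtracting 0.5 makes integer
# coordinates fall on pixel centers, which is what the bilinear neighborhood
# below assumes.)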
x = x - 0.5
y = y - 0.5

x0 = tf.cast(tf.floor(x), tf.int32)
x1 = x0 + 1
y0 = tf.cast(tf.floor(y), tf.int32)
@@ -175,7 +195,21 @@ def estimate_scene_depth_for_point(data, x, y, num_frames):
i3 = tf.gather(data, rng * shp[1] * shp[2] + y0 * shp[2] + x1)
i4 = tf.gather(data, rng * shp[1] * shp[2] + y1 * shp[2] + x1)

-return tf.maximum(tf.maximum(tf.maximum(i1, i2), i3), i4)
+depth = tf.maximum(tf.maximum(tf.maximum(i1, i2), i3), i4)

segments = tf.reshape(segments, [-1])
i1 = tf.gather(segments, rng * shp[1] * shp[2] + y0 * shp[2] + x0)
i2 = tf.gather(segments, rng * shp[1] * shp[2] + y1 * shp[2] + x0)
i3 = tf.gather(segments, rng * shp[1] * shp[2] + y0 * shp[2] + x1)
i4 = tf.gather(segments, rng * shp[1] * shp[2] + y1 * shp[2] + x1)

depth_occluded = tf.less(tf.transpose(depth), thresh)
seg_occluded = True
for i in [i1, i2, i3, i4]:
i = tf.cast(i, tf.int32)
seg_occluded = tf.logical_and(seg_occluded, tf.not_equal(seg_id, i))
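# The segment test fires only if all four neighbors disagree with seg_id;
# a single matching neighbor means the point may merely straddle a segment
# boundary, so it is not counted as occluded.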

return tf.logical_or(depth_occluded, tf.transpose(seg_occluded))


def get_camera_matrices(
@@ -219,32 +253,93 @@
)
matrix_world.append(transformation)

-return tf.cast(tf.stack(intrinsics),
-               tf.float32), tf.cast(tf.stack(matrix_world), tf.float32)
+return (
+    tf.cast(tf.stack(intrinsics), tf.float32),
+    tf.cast(tf.stack(matrix_world), tf.float32),
+)


def quat2rot(quats):
"""Convert a list of quaternions to rotation matrices."""
rotation_matrices = []
for frame_idx in range(quats.shape[0]):
quat = quats[frame_idx]
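# The stored quaternions are ordered (w, x, y, z), while from_quaternion
# expects (x, y, z, w); the concat below performs the reordering.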
rotation_matrix = rotation_matrix_3d.from_quaternion(
tf.concat([quat[1:], quat[0:1]], axis=0))
rotation_matrices.append(rotation_matrix)
return tf.cast(tf.stack(rotation_matrices), tf.float32)


def rotate_surface_normals(
world_frame_normals,
point_3d,
cam_pos,
obj_rot_mats,
frame_for_query,
):
"""Points are occluded if the surface normal points away from the camera."""
query_obj_rot_mat = tf.gather(obj_rot_mats, frame_for_query)
obj_frame_normals = tf.einsum(
'boi,bi->bo',
tf.linalg.inv(query_obj_rot_mat),
world_frame_normals,
)
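# The normals are given in world coordinates at the query frame; undoing the
# query frame's object rotation yields object-frame normals, which are then
# re-rotated into every frame below.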
world_frame_normals_frames = tf.einsum(
'foi,bi->bfo',
obj_rot_mats,
obj_frame_normals,
)
cam_to_pt = point_3d - cam_pos[tf.newaxis, :, :]
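# A positive dot product between the normal and the camera-to-point ray
# means the surface faces away from the camera in that frame.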
dots = tf.reduce_sum(world_frame_normals_frames * cam_to_pt, axis=-1)
faces_away = dots > 0

# If the query point also faces away, it's probably a bug in the meshes, so
# ignore the result of the test.
faces_away_query = tf.reduce_sum(
tf.cast(faces_away, tf.int32)
* tf.one_hot(frame_for_query, tf.shape(faces_away)[1], dtype=tf.int32),
axis=1,
keepdims=True,
)
faces_away = tf.logical_and(faces_away, tf.logical_not(faces_away_query > 0))
return faces_away


def single_object_reproject(
bbox_3d=None,
pt=None,
pt_segments=None,
camera=None,
cam_positions=None,
num_frames=None,
depth_map=None,
segments=None,
window=None,
input_size=None,
quat=None,
normals=None,
frame_for_pt=None,
trust_normals=None,
):
"""Reproject points for a single object.
Args:
bbox_3d: The object bounding box from Kubric. If none, assume it's
background.
pt: The set of points in 3D, with shape [num_points, 3]
pt_segments: The segment each point came from, with shape [num_points]
camera: Camera intrinsic parameters
cam_positions: Camera positions, with shape [num_frames, 3]
num_frames: Number of frames
depth_map: Depth map video for the camera
segments: Segmentation map video for the camera
window: the window inside which we're sampling points
input_size: [height, width] of the input images.
quat: Object quaternion [num_frames, 4]
normals: Point normals on the query frame [num_points, 3]
frame_for_pt: Integer frame where the query point came from [num_points]
trust_normals: Boolean flag for whether the surface normals for each query
are trustworthy [num_points]
Returns:
Position for each point, of shape [num_points, num_frames, 2], in pixel
@@ -254,7 +349,7 @@
"""
# Finally, reproject
-reproj, depth_proj = reproject(
+reproj, depth_proj, world_pos = reproject(
pt,
camera,
cam_positions,
@@ -267,20 +362,38 @@
np.newaxis, :]
occluded = tf.logical_or(
occluded,
-tf.less(
-    tf.transpose(
-        estimate_scene_depth_for_point(depth_map[:, :, :, 0],
-                                       tf.transpose(reproj[:, :, 0]),
-                                       tf.transpose(reproj[:, :, 1]),
-                                       num_frames)), depth_proj * .99))
+estimate_occlusion_by_depth_and_segment(
+    depth_map[:, :, :, 0],
+    segments[:, :, :, 0],
+    tf.transpose(reproj[:, :, 0]),
+    tf.transpose(reproj[:, :, 1]),
+    num_frames,
+    depth_proj * .99,
+    pt_segments,
+),
+)
obj_occ = occluded
obj_reproj = reproj

obj_occ = tf.logical_or(obj_occ, tf.less(obj_reproj[:, :, 1], window[0]))
obj_occ = tf.logical_or(obj_occ, tf.less(obj_reproj[:, :, 0], window[1]))
obj_occ = tf.logical_or(obj_occ, tf.greater(obj_reproj[:, :, 1], window[2]))
obj_occ = tf.logical_or(obj_occ, tf.greater(obj_reproj[:, :, 0], window[3]))
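# (The four tests above mark points that reproject outside the crop window,
# ordered [y_min, x_min, y_max, x_max], as occluded.)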
-return obj_reproj, obj_occ

if quat is not None:
faces_away = rotate_surface_normals(
normals,
world_pos,
cam_positions,
quat2rot(quat),
frame_for_pt,
)
faces_away = tf.logical_and(faces_away, trust_normals)
else:
# The background "world" is treated as convex; it can't face away from cam.
faces_away = tf.zeros([tf.shape(pt)[0], num_frames], dtype=tf.bool)

return obj_reproj, tf.logical_or(faces_away, obj_occ)


def get_num_to_sample(counts, max_seg_id, max_sampled_frac, tracks_to_sample):
@@ -340,7 +453,9 @@ def track_points(
depth,
depth_range,
segmentations,
surface_normals,
bboxes_3d,
obj_quat,
cam_focal_length,
cam_positions,
cam_quaternions,
@@ -361,7 +476,11 @@
metric depth.
segmentations: Integer object id for each pixel. Shape
[num_frames, height, width]
surface_normals: uint16 surface normal map. Shape
[num_frames, height, width, 3]
bboxes_3d: The set of all object bounding boxes from Kubric
obj_quat: Quaternion rotation for each object. Shape
[num_objects, num_frames, 4]
cam_focal_length: Camera focal length
cam_positions: Camera positions, with shape [num_frames, 3]
cam_quaternions: Camera orientations, with shape [num_frames, 4]
@@ -398,6 +517,8 @@ def track_points(
depth_f32 = tf.cast(depth, tf.float32)
depth_map = depth_min + depth_f32 * (depth_max-depth_min) / 65535

surface_normal_map = surface_normals / 65535 * 2. - 1.
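# (The line above maps uint16 normals from [0, 65535] into [-1, 1].)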

input_size = object_coordinates.shape.as_list()[1:3]
num_frames = object_coordinates.shape.as_list()[0]

@@ -442,6 +563,24 @@ def extract_box(x):
num_frames=num_frames,
)

# If the normals are very rough, it's often because they come from a texture
# normal map rather than the mesh geometry. Such normals aren't trustworthy,
# and the normal test may fail (i.e. the normal points away from the camera
# even though the point is still visible), so we skip the normal test when
# inferring occlusion for those points.
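# Concretely, each normal is compared with its four diagonal neighbors, and
# it is trusted only if at most two of them differ by more than 0.05.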
trust_sn = True
sn_pad = tf.pad(surface_normal_map, [(0, 0), (1, 1), (1, 1), (0, 0)])
shp = surface_normal_map.shape
sum_thresh = 0
for i in [0, 2]:
for j in [0, 2]:
diff = sn_pad[:, i : shp[1] + i, j : shp[2] + j, :] - surface_normal_map
diff = tf.reduce_sum(tf.square(diff), axis=-1)
sum_thresh += tf.cast(diff > 0.05 * 0.05, tf.int32)
trust_sn = tf.logical_and(trust_sn, (sum_thresh <= 2))[..., tf.newaxis]
surface_normals_box = extract_box(surface_normal_map)
trust_sn_box = extract_box(trust_sn)

def get_camera(fr=None):
if fr is None:
return {'intrinsics': intrinsics, 'matrix_world': matrix_world}
@@ -465,6 +604,8 @@ def get_camera(fr=None):
obj_id = i - 1
mask = tf.equal(tf.reshape(segmentations_box, [-1]), i)
pt = tf.boolean_mask(tf.reshape(object_coordinates_box, [-1, 3]), mask)
normals = tf.boolean_mask(tf.reshape(surface_normals_box, [-1, 3]), mask)
trust_sn_mask = tf.boolean_mask(tf.reshape(trust_sn_box, [-1, 1]), mask)
idx = tf.cond(
tf.shape(pt)[0] > 0,
lambda: tf.multinomial( # pylint: disable=g-long-lambda
@@ -473,7 +614,9 @@ def get_camera(fr=None):
lambda: tf.zeros([0], dtype=tf.int64))
# note: pt_coords is in pixel coordinates, not raster coordinates.
pt_coords = tf.gather(tf.boolean_mask(pix_coords, mask), idx)

normals = tf.gather(normals, idx)
trust_sn_gather = tf.gather(trust_sn_mask, idx)

pixel_to_raster = tf.constant([0.0, 0.5, 0.5])[tf.newaxis,:]
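# (Adding [0.0, 0.5, 0.5] shifts the (frame, y, x) chosen points from pixel
# back to raster coordinates when they are appended below.)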

if obj_id == -1:
@@ -493,9 +636,11 @@ def get_camera(fr=None):
unproject(pt_coords_chunk[:, 1:], get_camera(fr), depth_map[fr]))
pt = tf.concat(pt_3d, axis=0)
chosen_points.append(
tf.cast(tf.concat(pt_coords_reorder, axis=0), tf.float32) +
pixel_to_raster)
bbox = None
quat = None
frame_for_pt = None
else:
# For any other object, we just use the point coordinates supplied by
# kubric.
@@ -506,6 +651,9 @@ def get_camera(fr=None):
# points, so just use a dummy to prevent tf from crashing.
bbox = tf.cond(obj_id >= tf.shape(bboxes_3d)[0], lambda: bboxes_3d[0, :],
lambda: bboxes_3d[obj_id, :])
quat = tf.cond(obj_id >= tf.shape(obj_quat)[0], lambda: obj_quat[0, :],
lambda: obj_quat[obj_id, :])
frame_for_pt = pt_coords[..., 0]

# Finally, compute the reprojections for this particular object.
obj_reproj, obj_occ = tf.cond(
@@ -514,12 +662,18 @@ def get_camera(fr=None):
single_object_reproject,
bbox_3d=bbox,
pt=pt,
pt_segments=i,
camera=get_camera(),
cam_positions=cam_positions,
num_frames=num_frames,
depth_map=depth_map,
segments=segmentations,
window=window,
input_size=input_size,
quat=quat,
normals=normals,
frame_for_pt=frame_for_pt,
trust_normals=trust_sn_gather,
),
lambda: # pylint: disable=g-long-lambda
(tf.zeros([0, num_frames, 2], dtype=tf.float32),
@@ -652,7 +806,9 @@ def add_tracks(data,
query_points, target_points, occluded = track_points(
data['object_coordinates'], data['depth'],
data['metadata']['depth_range'], data['segmentations'],
-data['instances']['bboxes_3d'], data['camera']['focal_length'],
+data['normal'],
+data['instances']['bboxes_3d'], data['instances']['quaternions'],
+data['camera']['focal_length'],
data['camera']['positions'], data['camera']['quaternions'],
data['camera']['sensor_width'], crop_window, tracks_to_sample,
sampling_stride, max_seg_id, max_sampled_frac)
@@ -786,9 +942,10 @@ def plot_tracks(rgb, points, occluded, trackgroup=None):

colalpha = np.concatenate([colors[:, :-1], 1 - occluded[:, i:i + 1]],
axis=1)
# Note: matplotlib uses pixel coordinates, not raster.
plt.scatter(
-points[valid, i, 0],
-points[valid, i, 1],
+points[valid, i, 0] - 0.5,
+points[valid, i, 1] - 0.5,
s=3,
c=colalpha[valid],
)