From 3af9655d2c2cdb50cbd166ea971702cc9be37e11 Mon Sep 17 00:00:00 2001
From: Carl Doersch
Date: Tue, 26 Mar 2024 13:56:39 +0000
Subject: [PATCH] Add support for relative depth in point tracking reader.

---
 challenges/point_tracking/dataset.py | 34 ++++++++++++++++++++++------
 1 file changed, 27 insertions(+), 7 deletions(-)

diff --git a/challenges/point_tracking/dataset.py b/challenges/point_tracking/dataset.py
index 6760a221..8e669db5 100644
--- a/challenges/point_tracking/dataset.py
+++ b/challenges/point_tracking/dataset.py
@@ -407,7 +407,7 @@ def single_object_reproject(
     # world is convex; can't face away from cam.
     faces_away = tf.zeros([tf.shape(pt)[0], num_frames], dtype=tf.bool)
 
-  return obj_reproj, tf.logical_or(faces_away, obj_occ)
+  return obj_reproj, tf.logical_or(faces_away, obj_occ), depth_proj
 
 
 def get_num_to_sample(counts, max_seg_id, max_sampled_frac, tracks_to_sample):
@@ -522,6 +522,8 @@ def track_points(
   chosen_points = []
   all_reproj = []
   all_occ = []
+  chosen_points_depth = []
+  all_reproj_depth = []
 
   # Convert to metric depth
@@ -669,8 +671,18 @@ def get_camera(fr=None):
         lambda: obj_quat[obj_id, :])
     frame_for_pt = pt_coords[..., 0]
 
+    pt_depth = []
+    for fr in range(num_frames):
+      pt_coords_chunk = tf.boolean_mask(
+          pt_coords, tf.equal(pt_coords[:, 0], fr)
+      )
+      shp = tf.convert_to_tensor(tf.shape(depth_map[fr]))
+      idx = pt_coords_chunk[:, 1] * shp[1] + pt_coords_chunk[:, 2]
+      pt_depth.append(tf.gather(tf.reshape(depth_map[fr], [-1]), idx))
+    chosen_points_depth.append(tf.concat(pt_depth, axis=0))
+
     # Finally, compute the reprojections for this particular object.
-    obj_reproj, obj_occ = tf.cond(
+    obj_reproj, obj_occ, reproj_depth = tf.cond(
         tf.shape(pt)[0] > 0,
         functools.partial(
             single_object_reproject,
@@ -689,11 +701,13 @@ def get_camera(fr=None):
             frame_for_pt=frame_for_pt,
             trust_normals=trust_sn_gather,
         ),
-        lambda:  # pylint: disable=g-long-lambda
-        (tf.zeros([0, num_frames, 2], dtype=tf.float32),
-         tf.zeros([0, num_frames], dtype=tf.bool)))
+        lambda: (  # pylint: disable=g-long-lambda
+            tf.zeros([0, num_frames, 2], dtype=tf.float32),
+            tf.zeros([0, num_frames], dtype=tf.bool),
+            tf.zeros([0, num_frames], dtype=tf.float32)))
     all_reproj.append(obj_reproj)
     all_occ.append(obj_occ)
+    all_reproj_depth.append(reproj_depth)
 
   # Points are currently in pixel coordinates of the original video.  We now
   # convert them to coordinates within the window frame, and rescale to
@@ -712,6 +726,8 @@ def get_camera(fr=None):
   all_reproj = (all_reproj - window_top_left) / window_size
   all_reproj = all_reproj * coord_multiplier[2:0:-1]
   all_occ = tf.concat(all_occ, axis=0)
+  chosen_points_depth = tf.concat(chosen_points_depth, axis=0)
+  all_reproj_depth = tf.concat(all_reproj_depth, axis=0)
 
   # chosen_points is [num_points, (z,y,x)]
   chosen_points = tf.concat(chosen_points, axis=0)
@@ -723,8 +739,10 @@ def get_camera(fr=None):
   chosen_points = chosen_points * coord_multiplier
 
   # Note: all_reproj is in (x,y) format, but chosen_points is in (z,y,x) format
+  all_relative_depth = all_reproj_depth / chosen_points_depth[..., tf.newaxis]
+
   return tf.cast(chosen_points, tf.float32), tf.cast(all_reproj,
-                                                     tf.float32), all_occ
+                                                     tf.float32), all_occ, all_relative_depth
 
 
 def _get_distorted_bounding_box(
@@ -817,7 +835,7 @@ def add_tracks(data,
           dtype=tf.int32,
           shape=[4])
 
-  query_points, target_points, occluded = track_points(
+  query_points, target_points, occluded, relative_depth = track_points(
       data['object_coordinates'], data['depth'],
       data['metadata']['depth_range'], data['segmentations'],
       data['normal'],
@@ -831,6 +849,7 @@ def add_tracks(data,
   shp = video.shape.as_list()
   query_points.set_shape([tracks_to_sample, 3])
   target_points.set_shape([tracks_to_sample, num_frames, 2])
+  relative_depth.set_shape([tracks_to_sample, num_frames])
   occluded.set_shape([tracks_to_sample, num_frames])
 
   # Crop the video to the sampled window, in a way which matches the coordinate
@@ -851,6 +870,7 @@ def add_tracks(data,
   res = {
       'query_points': query_points,
       'target_points': target_points,
+      'relative_depth': relative_depth,
      'occluded': occluded,
       'video': video / (255. / 2.) - 1.,
   }
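
For reference, the relative-depth computation this patch adds reduces to the
sketch below. It is a minimal illustration, not part of dataset.py: the helper
name query_point_depth is hypothetical, and it assumes depth_map is a
[num_frames, height, width] float32 tensor of metric depth and query_points
holds [num_points, 3] int32 (frame, y, x) pixel coordinates.

  import tensorflow as tf

  def query_point_depth(depth_map, query_points):
    # Static frame count assumed, as in the reader's own per-frame loop.
    num_frames = depth_map.shape[0]
    per_frame = []
    for fr in range(num_frames):
      # Keep only the queries that live on frame fr.
      chunk = tf.boolean_mask(query_points, tf.equal(query_points[:, 0], fr))
      width = tf.shape(depth_map[fr])[1]
      # Row-major flattened index y * width + x, then gather from the
      # flattened frame: the same trick the patch uses per object.
      idx = chunk[:, 1] * width + chunk[:, 2]
      per_frame.append(tf.gather(tf.reshape(depth_map[fr], [-1]), idx))
    # Points come back grouped by frame, mirroring the patch's per-object loop.
    return tf.concat(per_frame, axis=0)

  # reproj_depth: [num_points, num_frames] depth of each reprojected track.
  # relative_depth =
  #     reproj_depth / query_point_depth(depth_map, query_points)[:, tf.newaxis]

A relative_depth of 1.0 therefore means the track sits at the same metric depth
as it did at its query point; values above 1.0 mean it has moved farther from
the camera.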