From 104f618971bdffabca37725ac533a546b0089568 Mon Sep 17 00:00:00 2001 From: healthonrails Date: Wed, 13 Nov 2024 20:34:46 -0500 Subject: [PATCH] Write predicted results to JSON files in batches during inference to enable real-time tracking visualization in the UI --- annolid/tracker/cotracker/track.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/annolid/tracker/cotracker/track.py b/annolid/tracker/cotracker/track.py index 4f35f18..88eb70b 100644 --- a/annolid/tracker/cotracker/track.py +++ b/annolid/tracker/cotracker/track.py @@ -209,8 +209,15 @@ def process_video(self, pred_tracks, pred_visibility = self.process_step( window_frames, is_first_step, grid_size, grid_query_frame) if pred_tracks is not None: - logger.info( - f"Tracking frame {i}, {pred_tracks.shape}, {pred_visibility.shape}") + if i % 100 == 0: + logger.info( + f"Tracking frame {i}, {pred_tracks.shape}, {pred_visibility.shape}") + if is_first_step: + batch_size = self.model.step * 2 + else: + batch_size = self.model.step + self.extract_frame_points( + pred_tracks, pred_visibility, batch_size, i) is_first_step = False window_frames.append(frame) if len(window_frames) > self.model.step * 2: @@ -288,9 +295,11 @@ def save_current_frame_tracked_points_to_json(self, frame_number, points): def extract_frame_points(self, tracks: torch.Tensor, visibility: torch.Tensor = None, query_frame: int = 0, - start_frame: int = 0): + start_frame: int = 0 + ): tracks = tracks[0].long().detach().cpu().numpy() - for t in range(query_frame, tracks.shape[0]): + batch_start = start_frame-query_frame + for t in range(batch_start, tracks.shape[0]): _points = [] for i in range(tracks.shape[1]): coord = (tracks[t, i, 0], tracks[t, i, 1]) @@ -300,7 +309,7 @@ def extract_frame_points(self, tracks: torch.Tensor, visible = visibility[0, t, i].item() _points.append((coord, visible)) self.save_current_frame_tracked_points_to_json( - t + start_frame, _points) + t, _points) message = f"Saved all json file #{tracks.shape[0]}" logger.info(message) return message