From 3acb74864892b86460a15d235db50b8f5eccbd97 Mon Sep 17 00:00:00 2001
From: Sammy Ochoa
Date: Thu, 28 Dec 2023 11:03:16 -0600
Subject: [PATCH 1/3] add thresholds per object for owl_predictor

---
 examples/owl_predict.py  | 18 +++++++++++-------
 nanoowl/owl_predictor.py | 19 ++++++++++++++-----
 2 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/examples/owl_predict.py b/examples/owl_predict.py
index 06480ac..736fd58 100644
--- a/examples/owl_predict.py
+++ b/examples/owl_predict.py
@@ -30,21 +30,25 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
-    parser.add_argument("--prompt", type=str, default="")
-    parser.add_argument("--threshold", type=float, default=0.1)
+    parser.add_argument("--prompt", type=str, default="an owl, a glove")
+    parser.add_argument("--thresholds", type=str, default="0.1,0.1")
     parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
     parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
-    parser.add_argument("--image_encoder_engine", type=str, default="../data/owlvit_image_encoder_patch32.engine")
+    parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
     parser.add_argument("--profile", action="store_true")
     parser.add_argument("--num_profiling_runs", type=int, default=30)
     args = parser.parse_args()
 
     prompt = args.prompt.strip("][()")
-    text = prompt.split(',')
-    print(text)
+    thresholds = args.thresholds.strip("][()")
+    thresholds = thresholds.split(',')
+    thresholds = [float(x) for x in thresholds]
+    print(thresholds)
+
+
     predictor = OwlPredictor(
         args.model,
         image_encoder_engine=args.image_encoder_engine
@@ -58,7 +62,7 @@
         image=image,
         text=text,
         text_encodings=text_encodings,
-        threshold=args.threshold,
+        thresholds=thresholds,
         pad_square=False
     )
 
@@ -70,7 +74,7 @@
             image=image,
             text=text,
             text_encodings=text_encodings,
-            threshold=args.threshold,
+            thresholds=thresholds,
             pad_square=False
         )
         torch.cuda.current_stream().synchronize()
diff --git a/nanoowl/owl_predictor.py b/nanoowl/owl_predictor.py
index a8d8c7a..88418de 100644
--- a/nanoowl/owl_predictor.py
+++ b/nanoowl/owl_predictor.py
@@ -274,7 +274,7 @@ def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
     def decode(self,
             image_output: OwlEncodeImageOutput,
             text_output: OwlEncodeTextOutput,
-            threshold: float = 0.1
+            thresholds: List[float],
         ) -> OwlDecodeOutput:
 
         num_input_images = image_output.image_class_embeds.shape[0]
 
         image_class_embeds = image_output.image_class_embeds
@@ -290,8 +290,16 @@ def decode(self,
         scores_max = scores_sigmoid.max(dim=-1)
         labels = scores_max.indices
         scores = scores_max.values
-
-        mask = (scores > threshold)
+        masks = []
+        for i, threshold in enumerate(thresholds):
+            label_mask = labels == i
+            score_mask = scores > threshold
+            obj_mask = torch.logical_and(label_mask, score_mask)
+            masks.append(obj_mask)
+
+        mask = masks[0]
+        for mask_t in masks[1:]:
+            mask = torch.logical_or(mask, mask_t)
 
         input_indices = torch.arange(0, num_input_images, dtype=labels.dtype, device=labels.device)
         input_indices = input_indices[:, None].repeat(1, self.num_patches)
@@ -447,8 +455,9 @@ def predict(self,
             image: PIL.Image,
             text: List[str],
             text_encodings: Optional[OwlEncodeTextOutput],
+            thresholds: List[float],
             pad_square: bool = True,
-            threshold: float = 0.1
+
         ) -> OwlDecodeOutput:
 
         image_tensor = self.image_preprocessor.preprocess_pil_image(image)
@@ -460,5 +469,5 @@ def predict(self,
 
         image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
 
-        return self.decode(image_encodings, text_encodings, threshold)
+        return self.decode(image_encodings, text_encodings, thresholds)
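The per-object masking this first patch adds to `decode` can be exercised in isolation. The sketch below uses made-up `labels`/`scores` tensors rather than real model output; inside `decode`, `labels` holds the best-scoring class index per predicted box and `scores` its sigmoid score:

```python
import torch

labels = torch.tensor([0, 1, 1, 0])             # best class per box (dummy values)
scores = torch.tensor([0.05, 0.20, 0.08, 0.30])  # sigmoid score per box (dummy values)
thresholds = [0.1, 0.15]                         # one threshold per prompt

masks = []
for i, threshold in enumerate(thresholds):
    # keep a box only if it was assigned class i AND beats class i's threshold
    masks.append((labels == i) & (scores > threshold))

# union over classes; same result as the patch's chained torch.logical_or
mask = torch.stack(masks).any(dim=0)
print(mask)  # tensor([False,  True, False,  True])
```

When every label has a matching threshold, the loop is equivalent to indexing a per-class threshold tensor, e.g. `scores > torch.tensor(thresholds, device=scores.device)[labels]`, which avoids building one boolean mask per class.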
From b9381678503b2dbaa77b384ee95e0db79209f3fc Mon Sep 17 00:00:00 2001
From: Sammy Ochoa
Date: Fri, 12 Jan 2024 16:40:56 -0600
Subject: [PATCH 2/3] make thresholds backwards compatible

---
 examples/owl_predict.py  |  6 +++---
 examples/tree_predict.py |  2 +-
 nanoowl/owl_predictor.py | 13 ++++++++-----
 3 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/examples/owl_predict.py b/examples/owl_predict.py
index 736fd58..647edbf 100644
--- a/examples/owl_predict.py
+++ b/examples/owl_predict.py
@@ -31,7 +31,7 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
     parser.add_argument("--prompt", type=str, default="an owl, a glove")
-    parser.add_argument("--thresholds", type=str, default="0.1,0.1")
+    parser.add_argument("--threshold", type=str, default="0.1,0.1")
     parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
     parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
     parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
@@ -43,7 +43,7 @@
     text = prompt.split(',')
     print(text)
 
-    thresholds = args.thresholds.strip("][()")
+    thresholds = args.threshold.strip("][()")
     thresholds = thresholds.split(',')
     thresholds = [float(x) for x in thresholds]
     print(thresholds)
@@ -62,7 +62,7 @@
         image=image,
         text=text,
         text_encodings=text_encodings,
-        thresholds=thresholds,
+        threshold=thresholds,
         pad_square=False
     )
diff --git a/examples/tree_predict.py b/examples/tree_predict.py
index 90162a9..abf2491 100644
--- a/examples/tree_predict.py
+++ b/examples/tree_predict.py
@@ -33,7 +33,7 @@
     parser.add_argument("--threshold", type=float, default=0.1)
     parser.add_argument("--output", type=str, default="../data/tree_predict_out.jpg")
     parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
-    parser.add_argument("--image_encoder_engine", type=str, default="../data/owlvit_image_encoder_patch32.engine")
+    parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
     args = parser.parse_args()
 
     predictor = TreePredictor(
diff --git a/nanoowl/owl_predictor.py b/nanoowl/owl_predictor.py
index 88418de..a462a04 100644
--- a/nanoowl/owl_predictor.py
+++ b/nanoowl/owl_predictor.py
@@ -274,9 +274,12 @@ def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
     def decode(self,
             image_output: OwlEncodeImageOutput,
             text_output: OwlEncodeTextOutput,
-            thresholds: List[float],
+            threshold: Union[int, float, List[Union[int, float]]] = 0.1,
         ) -> OwlDecodeOutput:
 
+        if isinstance(threshold, (int, float)):
+            threshold = [threshold]
+
         num_input_images = image_output.image_class_embeds.shape[0]
 
         image_class_embeds = image_output.image_class_embeds
@@ -291,9 +294,9 @@ def decode(self,
         labels = scores_max.indices
         scores = scores_max.values
         masks = []
-        for i, threshold in enumerate(thresholds):
+        for i, thresh in enumerate(threshold):
             label_mask = labels == i
-            score_mask = scores > threshold
+            score_mask = scores > thresh
             obj_mask = torch.logical_and(label_mask, score_mask)
             masks.append(obj_mask)
 
@@ -455,7 +458,7 @@ def predict(self,
             image: PIL.Image,
             text: List[str],
             text_encodings: Optional[OwlEncodeTextOutput],
-            thresholds: List[float],
+            threshold: Union[int, float, List[Union[int, float]]] = 0.1,
             pad_square: bool = True,
 
         ) -> OwlDecodeOutput:
@@ -469,5 +472,5 @@ def predict(self,
 
         image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
 
-        return self.decode(image_encodings, text_encodings, thresholds)
+        return self.decode(image_encodings, text_encodings, threshold)
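The backwards compatibility in patch 2 hinges on the `isinstance` check at the top of `decode`. A minimal sketch of that normalization as a standalone helper; `normalize_threshold` is hypothetical and not part of nanoowl:

```python
from typing import List, Union

def normalize_threshold(threshold: Union[int, float, List[Union[int, float]]]) -> List[float]:
    # mirrors the isinstance check decode now performs: wrap a bare number in a list
    if isinstance(threshold, (int, float)):
        threshold = [threshold]
    return [float(t) for t in threshold]

print(normalize_threshold(0.1))         # [0.1]      -- old scalar call sites keep working
print(normalize_threshold([0.1, 0.3]))  # [0.1, 0.3] -- one threshold per prompt
```

Two caveats follow from the diff: the new annotation requires `Union` to be imported from `typing` in owl_predictor.py, and a scalar is wrapped as a single-element list, so with multiple prompts the masking loop only enumerates class 0 and detections for the remaining classes are filtered out.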
From d8fa78bb7d9dcfd35f79ea9e31ff56c8c7656669 Mon Sep 17 00:00:00 2001
From: Sammy Ochoa
Date: Tue, 16 Jan 2024 10:37:40 -0600
Subject: [PATCH 3/3] update threshold

---
 examples/owl_predict.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/owl_predict.py b/examples/owl_predict.py
index 647edbf..c3d1b2c 100644
--- a/examples/owl_predict.py
+++ b/examples/owl_predict.py
@@ -74,7 +74,7 @@
             image=image,
             text=text,
             text_encodings=text_encodings,
-            thresholds=thresholds,
+            threshold=thresholds,
             pad_square=False
         )
         torch.cuda.current_stream().synchronize()
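With the full series applied, the example script parses `--threshold` as a comma-separated list and forwards it to `OwlPredictor.predict`. A sketch of the equivalent API usage under stated assumptions: the model name, engine path, and image are the defaults from examples/owl_predict.py, `encode_text` is assumed to be the predictor's text-encoding entry point as used by that script, and the `labels`/`scores` fields are assumed from the `decode` body:

```python
import PIL.Image
from nanoowl.owl_predictor import OwlPredictor

predictor = OwlPredictor(
    "google/owlvit-base-patch32",
    image_encoder_engine="../data/owl_image_encoder_patch32.engine",  # prebuilt TensorRT engine
)

image = PIL.Image.open("../assets/owl_glove_small.jpg")
text = ["an owl", "a glove"]
text_encodings = predictor.encode_text(text)  # assumed helper; precomputes text embeddings

output = predictor.predict(
    image=image,
    text=text,
    text_encodings=text_encodings,
    threshold=[0.1, 0.3],  # per-object: owl at 0.1, glove at 0.3
    pad_square=False,
)
print(output.labels, output.scores)  # OwlDecodeOutput fields assumed: labels, scores
```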