Skip to content

Commit

Permalink
Merge pull request #9 from ssmmoo1/main
Browse files Browse the repository at this point in the history
Add thresholds per object for owl_predictor
  • Loading branch information
jaybdub authored Jan 24, 2024
2 parents cca8017 + d8fa78b commit 60c5d9f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 12 deletions.
18 changes: 11 additions & 7 deletions examples/owl_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,25 @@

parser = argparse.ArgumentParser()
parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--threshold", type=float, default=0.1)
parser.add_argument("--prompt", type=str, default="an owl, a glove")
parser.add_argument("--threshold", type=str, default="0.1,0.1")
parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
parser.add_argument("--image_encoder_engine", type=str, default="../data/owlvit_image_encoder_patch32.engine")
parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num_profiling_runs", type=int, default=30)
args = parser.parse_args()

prompt = args.prompt.strip("][()")

text = prompt.split(',')

print(text)

thresholds = args.threshold.strip("][()")
thresholds = thresholds.split(',')
thresholds = [float(x) for x in thresholds]
print(thresholds)


predictor = OwlPredictor(
args.model,
image_encoder_engine=args.image_encoder_engine
Expand All @@ -58,7 +62,7 @@
image=image,
text=text,
text_encodings=text_encodings,
threshold=args.threshold,
threshold=thresholds,
pad_square=False
)

Expand All @@ -70,7 +74,7 @@
image=image,
text=text,
text_encodings=text_encodings,
threshold=args.threshold,
threshold=thresholds,
pad_square=False
)
torch.cuda.current_stream().synchronize()
Expand Down
2 changes: 1 addition & 1 deletion examples/tree_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
parser.add_argument("--threshold", type=float, default=0.1)
parser.add_argument("--output", type=str, default="../data/tree_predict_out.jpg")
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
parser.add_argument("--image_encoder_engine", type=str, default="../data/owlvit_image_encoder_patch32.engine")
parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
args = parser.parse_args()

predictor = TreePredictor(
Expand Down
20 changes: 16 additions & 4 deletions nanoowl/owl_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,12 @@ def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool
def decode(self,
image_output: OwlEncodeImageOutput,
text_output: OwlEncodeTextOutput,
threshold: float = 0.1
threshold: Union[int, float, List[Union[int, float]]] = 0.1,
) -> OwlDecodeOutput:

if isinstance(threshold, (int, float)):
threshold = [threshold]

num_input_images = image_output.image_class_embeds.shape[0]

image_class_embeds = image_output.image_class_embeds
Expand All @@ -290,8 +293,16 @@ def decode(self,
scores_max = scores_sigmoid.max(dim=-1)
labels = scores_max.indices
scores = scores_max.values

mask = (scores > threshold)
masks = []
for i, thresh in enumerate(threshold):
label_mask = labels == i
score_mask = scores > thresh
obj_mask = torch.logical_and(label_mask,score_mask)
masks.append(obj_mask)

mask = masks[0]
for mask_t in masks[1:]:
mask = torch.logical_or(mask, mask_t)

input_indices = torch.arange(0, num_input_images, dtype=labels.dtype, device=labels.device)
input_indices = input_indices[:, None].repeat(1, self.num_patches)
Expand Down Expand Up @@ -447,8 +458,9 @@ def predict(self,
image: PIL.Image,
text: List[str],
text_encodings: Optional[OwlEncodeTextOutput],
threshold: Union[int, float, List[Union[int, float]]] = 0.1,
pad_square: bool = True,
threshold: float = 0.1

) -> OwlDecodeOutput:

image_tensor = self.image_preprocessor.preprocess_pil_image(image)
Expand Down

0 comments on commit 60c5d9f

Please sign in to comment.