diff --git a/examples/owl_predict.py b/examples/owl_predict.py
index c3d1b2c..e145247 100644
--- a/examples/owl_predict.py
+++ b/examples/owl_predict.py
@@ -30,7 +30,7 @@
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
-    parser.add_argument("--prompt", type=str, default="an owl, a glove")
+    parser.add_argument("--prompt", type=str, default="[an owl, a glove]")
     parser.add_argument("--threshold", type=str, default="0.1,0.1")
     parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
     parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
@@ -45,7 +45,10 @@
 
     thresholds = args.threshold.strip("][()")
     thresholds = thresholds.split(',')
-    thresholds = [float(x) for x in thresholds]
+    if len(thresholds) == 1:
+        thresholds = float(thresholds[0])
+    else:
+        thresholds = [float(x) for x in thresholds]
     print(thresholds)
 
 
diff --git a/nanoowl/owl_predictor.py b/nanoowl/owl_predictor.py
index a462a04..1afb897 100644
--- a/nanoowl/owl_predictor.py
+++ b/nanoowl/owl_predictor.py
@@ -278,7 +278,7 @@ def decode(self,
         ) -> OwlDecodeOutput:
 
         if isinstance(threshold, (int, float)):
-            threshold = [threshold]
+            threshold = [threshold] * len(text_output.text_embeds) #apply single threshold to all labels
 
         num_input_images = image_output.image_class_embeds.shape[0]
 
@@ -468,7 +468,7 @@ def predict(self,
 
         if text_encodings is None:
             text_encodings = self.encode_text(text)
 
-        rois = torch.tensor([[0, 0, image.height, image.width]], dtype=image_tensor.dtype, device=image_tensor.device)
+        rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)
 
         image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)