Skip to content

Commit

Permalink
Merge pull request #20 from ssmmoo1/main
Browse files Browse the repository at this point in the history
Fix bug with non square images
  • Loading branch information
jaybdub authored Mar 12, 2024
2 parents 9ae3d83 + 995f6d2 commit cfef75a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 5 additions & 2 deletions examples/owl_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

parser = argparse.ArgumentParser()
parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
parser.add_argument("--prompt", type=str, default="an owl, a glove")
parser.add_argument("--prompt", type=str, default="[an owl, a glove]")
parser.add_argument("--threshold", type=str, default="0.1,0.1")
parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
Expand All @@ -45,7 +45,10 @@

thresholds = args.threshold.strip("][()")
thresholds = thresholds.split(',')
thresholds = [float(x) for x in thresholds]
if len(thresholds) == 1:
thresholds = float(thresholds[0])
else:
thresholds = [float(x) for x in thresholds]
print(thresholds)


Expand Down
4 changes: 2 additions & 2 deletions nanoowl/owl_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def decode(self,
) -> OwlDecodeOutput:

if isinstance(threshold, (int, float)):
threshold = [threshold]
threshold = [threshold] * len(text_output.text_embeds) #apply single threshold to all labels

num_input_images = image_output.image_class_embeds.shape[0]

Expand Down Expand Up @@ -468,7 +468,7 @@ def predict(self,
if text_encodings is None:
text_encodings = self.encode_text(text)

rois = torch.tensor([[0, 0, image.height, image.width]], dtype=image_tensor.dtype, device=image_tensor.device)
rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)

image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)

Expand Down

0 comments on commit cfef75a

Please sign in to comment.