Skip to content

Commit

Permalink
correct two small bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 10, 2024
1 parent f0b1d32 commit 21b3ad1
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/main/java/ai/nets/samj/models/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,6 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
code = String.format(code, size);
code += "])" + System.lineSeparator();
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
code += "num_features -= 1" + System.lineSeparator();
}
code += ""
+ "contours_x = []" + System.lineSeparator()
Expand All @@ -442,7 +441,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
+ " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator()
+ " for pp in range(n_points):" + System.lineSeparator()
+ " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
+ " extracted_point_labels += [n_feat]" + System.lineSeparator()
+ " extracted_point_labels += [1]" + System.lineSeparator()
+ " ip = torch.reshape(torch.tensor(np.array(extracted_point_prompts).reshape(len(extracted_point_prompts), 2)), [1, 1, -1, 2])" + System.lineSeparator()
+ " il = torch.reshape(torch.tensor(np.array(extracted_point_labels)), [1, 1, -1])" + System.lineSeparator()
+ " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator()
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/ai/nets/samj/models/EfficientViTSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,6 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
code = String.format(code, size);
code += "])" + System.lineSeparator();
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
code += "num_features -= 1" + System.lineSeparator();
}
code += ""
+ "contours_x = []" + System.lineSeparator()
Expand All @@ -523,7 +522,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
+ " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator()
+ " for pp in range(n_points):" + System.lineSeparator()
+ " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
+ " extracted_point_labels += [n_feat]" + System.lineSeparator()
+ " extracted_point_labels += [1]" + System.lineSeparator()
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=np.array(extracted_point_prompts)," + System.lineSeparator()
+ " point_labels=np.array(extracted_point_labels)," + System.lineSeparator()
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,6 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu
code = String.format(code, size);
code += "])" + System.lineSeparator();
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
code += "num_features -= 1" + System.lineSeparator();
}
code += ""
+ "contours_x = []" + System.lineSeparator()
Expand All @@ -497,7 +496,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu
+ " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator()
+ " for pp in range(n_points):" + System.lineSeparator()
+ " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
+ " extracted_point_labels += [n_feat]" + System.lineSeparator()
+ " extracted_point_labels += [1]" + System.lineSeparator()
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=np.array(extracted_point_prompts)," + System.lineSeparator()
+ " point_labels=np.array(extracted_point_labels)," + System.lineSeparator()
Expand Down

0 comments on commit 21b3ad1

Please sign in to comment.