diff --git a/src/main/java/ai/nets/samj/models/EfficientSamJ.java b/src/main/java/ai/nets/samj/models/EfficientSamJ.java index 8669345..042a559 100644 --- a/src/main/java/ai/nets/samj/models/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientSamJ.java @@ -64,6 +64,7 @@ public class EfficientSamJ extends AbstractSamJ { + "from skimage import measure" + System.lineSeparator() + "measure.label(np.ones((10, 10)), connectivity=1)" + System.lineSeparator() + "import torch" + System.lineSeparator() + + "from scipy.ndimage import binary_fill_holes" + System.lineSeparator() + "import sys" + System.lineSeparator() + "sys.path.append(r'%s')" + System.lineSeparator() + "from multiprocessing import shared_memory" + System.lineSeparator() @@ -79,6 +80,7 @@ public class EfficientSamJ extends AbstractSamJ { + "globals()['measure'] = measure" + System.lineSeparator() + "globals()['np'] = np" + System.lineSeparator() + "globals()['torch'] = torch" + System.lineSeparator() + + "globals()['binary_fill_holes'] = binary_fill_holes" + System.lineSeparator() + "globals()['predictor'] = predictor" + System.lineSeparator(); /** diff --git a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java index 233b340..22e6939 100644 --- a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java @@ -79,6 +79,7 @@ public class EfficientViTSamJ extends AbstractSamJ { + "from skimage import measure" + System.lineSeparator() + "measure.label(np.ones((10, 10)), connectivity=1)" + System.lineSeparator() + "import torch" + System.lineSeparator() + + "from scipy.ndimage import binary_fill_holes" + System.lineSeparator() + "import sys" + System.lineSeparator() + "import os" + System.lineSeparator() + "os.chdir(r'%s')" + System.lineSeparator() @@ -107,6 +108,7 @@ public class EfficientViTSamJ extends AbstractSamJ { + "globals()['measure'] = measure" + System.lineSeparator() + "globals()['np'] = np" + System.lineSeparator() + "globals()['torch'] = torch" + System.lineSeparator() + + "globals()['binary_fill_holes'] = binary_fill_holes" + System.lineSeparator() + "globals()['predictor'] = predictor" + System.lineSeparator(); /** * String containing the Python imports code after it has been formatted with the correct diff --git a/src/main/java/ai/nets/samj/models/PythonMethods.java b/src/main/java/ai/nets/samj/models/PythonMethods.java index 7cb10d8..bb30c2b 100644 --- a/src/main/java/ai/nets/samj/models/PythonMethods.java +++ b/src/main/java/ai/nets/samj/models/PythonMethods.java @@ -89,8 +89,7 @@ public class PythonMethods { + " for obj in labels:" + System.lineSeparator() + " if obj.num_pixels >= at_least_of_this_size:" + System.lineSeparator() + " x_coords,y_coords = trace_contour(obj.image, obj.num_pixels, obj.bbox[1],obj.bbox[0])" + System.lineSeparator() - + " rle = encode_rle(obj.image * 1)" + System.lineSeparator() - + " print(np.array(rle)[1::2])" + System.lineSeparator() + + " rle = encode_rle(binary_fill_holes(obj.image))" + System.lineSeparator() + " bbox_w = obj.bbox[3] - obj.bbox[1]" + System.lineSeparator() + " for i in range(0, len(rle), 2):" + System.lineSeparator() + " rle[i] = sam_result.shape[1] * (obj.bbox[0] + rle[i] // bbox_w) + obj.bbox[1] + rle[i] % bbox_w" + System.lineSeparator() diff --git a/src/main/java/ai/nets/samj/models/Sam2.java b/src/main/java/ai/nets/samj/models/Sam2.java index 79938c6..adea456 100644 --- a/src/main/java/ai/nets/samj/models/Sam2.java +++ b/src/main/java/ai/nets/samj/models/Sam2.java @@ -75,6 +75,7 @@ public class Sam2 extends AbstractSamJ { + "from skimage import measure" + System.lineSeparator() + "measure.label(np.ones((10, 10)), connectivity=1)" + System.lineSeparator() + "import torch" + System.lineSeparator() + + "from scipy.ndimage import binary_fill_holes" + System.lineSeparator() + "import sys" + System.lineSeparator() + "import os" + System.lineSeparator() + "from multiprocessing import shared_memory" + System.lineSeparator() @@ -92,6 +93,7 @@ public class Sam2 extends AbstractSamJ { + "globals()['measure'] = measure" + System.lineSeparator() + "globals()['np'] = np" + System.lineSeparator() + "globals()['torch'] = torch" + System.lineSeparator() + + "globals()['binary_fill_holes'] = binary_fill_holes" + System.lineSeparator() + "globals()['predictor'] = predictor" + System.lineSeparator(); /** * String containing the Python imports code after it has been formated with the correct