diff --git a/src/main/java/ai/nets/samj/annotation/Mask.java b/src/main/java/ai/nets/samj/annotation/Mask.java index c8b1e25..65ac588 100644 --- a/src/main/java/ai/nets/samj/annotation/Mask.java +++ b/src/main/java/ai/nets/samj/annotation/Mask.java @@ -20,7 +20,6 @@ package ai.nets.samj.annotation; import java.awt.Polygon; -import java.awt.Rectangle; import java.util.Arrays; import java.util.List; @@ -41,16 +40,13 @@ public class Mask { private final long[] rleEncoding; - private final Rectangle crop; - - private Mask(Polygon contour, long[] rleEncoding, Rectangle crop) { + private Mask(Polygon contour, long[] rleEncoding) { this.contour = contour; this.rleEncoding = rleEncoding; - this.crop = crop; } - public static Mask build(Polygon contour, long[] rleEncoding, Rectangle crop) { - return new Mask(contour, rleEncoding, crop); + public static Mask build(Polygon contour, long[] rleEncoding) { + return new Mask(contour, rleEncoding); } public Polygon getContour() { @@ -77,9 +73,7 @@ public static RandomAccessibleInterval getMask(long width, lon for (Mask mask : masks) { long[] rle = mask.getRLEMask(); for (int i = 0; i < rle.length; i += 2) { - int cropStartx = mask.crop.x; - int cropStarty = mask.crop.y; - int start = (int) (width * (cropStarty + i / 2) + cropStartx + mask.getRLEMask()[i]); + int start = (int) mask.getRLEMask()[i]; int len = (int) mask.getRLEMask()[i+ 1]; Arrays.fill(arr, start, start + len, (byte) 1); } diff --git a/src/main/java/ai/nets/samj/models/AbstractSamJ.java b/src/main/java/ai/nets/samj/models/AbstractSamJ.java index 86628cb..3713134 100644 --- a/src/main/java/ai/nets/samj/models/AbstractSamJ.java +++ b/src/main/java/ai/nets/samj/models/AbstractSamJ.java @@ -371,7 +371,7 @@ else if (task.outputs.get("contours_x") == null) throw new RuntimeException(); else if (task.outputs.get("contours_y") == null) throw new RuntimeException(); - else if (task.outputs.get("rles") == null) + else if (task.outputs.get("rle") == null) throw new RuntimeException(); results = task.outputs; } catch (IOException | InterruptedException | RuntimeException e) { @@ -393,7 +393,7 @@ else if (task.outputs.get("rles") == null) int[] xArr = contours_x.next().stream().mapToInt(Number::intValue).toArray(); int[] yArr = contours_y.next().stream().mapToInt(Number::intValue).toArray(); long[] rle = rles.next().stream().mapToLong(Number::longValue).toArray(); - masks.add(Mask.build(new Polygon(xArr, yArr, xArr.length), rle, cropRect)); + masks.add(Mask.build(new Polygon(xArr, yArr, xArr.length), rle)); } return masks; } @@ -1034,15 +1034,19 @@ protected long[] calculateEncodingNewCoords(int[] boundingBox, long[] imageSize) * to detect small objects compared to the size of the whole image, SAMJ might encode crops of * the total image, thus the coordinates of the polygons obtained need to be shifted in order * to match the original image. - * @param polys - * polys obtained by SAMJ on the encoded crop + * @param masks + * masks obtained by SAMJ on the encoded crop * @param encodeCoords * position of the crop in the total image */ - protected void recalculatePolys(List polys, long[] encodeCoords) { - polys.stream().forEach(pp -> { + protected void recalculatePolys(List masks, long[] encodeCoords) { + masks.stream().forEach(pp -> { pp.getContour().xpoints = Arrays.stream(pp.getContour().xpoints).map(x -> x + (int) encodeCoords[0]).toArray(); pp.getContour().ypoints = Arrays.stream(pp.getContour().ypoints).map(y -> y + (int) encodeCoords[1]).toArray(); + for (int i = 0; i < pp.getRLEMask().length; i += 2) { + pp.getRLEMask()[i] = encodeCoords[0] + pp.getRLEMask()[i] % this.targetDims[0] + + (((int) (pp.getRLEMask()[i] / this.targetDims[0])) + encodeCoords[1]) * this.targetDims[0]; + } }); } diff --git a/src/main/java/ai/nets/samj/models/PythonMethods.java b/src/main/java/ai/nets/samj/models/PythonMethods.java index d478f85..13fc3e8 100644 --- a/src/main/java/ai/nets/samj/models/PythonMethods.java +++ b/src/main/java/ai/nets/samj/models/PythonMethods.java @@ -92,7 +92,7 @@ public class PythonMethods { + " rle = encode_rle(obj.image)" + System.lineSeparator() + " bbox_w = obj.bbox[3] - obj.bbox[1]" + System.lineSeparator() + " for i in range(0, len(rle), 2):" + System.lineSeparator() - + " rle[i] = sam_result.shape[1] * (obj.bbox[0] - 1 + rle[i] // bbox_w) + obj.bbox[1] + rle[i] % bbox_w" + System.lineSeparator() + + " rle[i] = sam_result.shape[1] * (obj.bbox[0] + rle[i] // bbox_w) + obj.bbox[1] + rle[i] % bbox_w" + System.lineSeparator() + " rles.append(rle)" + System.lineSeparator() + " x_contours.append(x_coords)" + System.lineSeparator() + " y_contours.append(y_coords)" + System.lineSeparator()