/*
 * NOTE(review): the text below is a unified diff of
 * Source/YoloV8/Parsers/SegmentationOutputParser.cs whose line breaks have been
 * collapsed, so removed ("-") and added ("+") lines are interleaved with context
 * and the text is not compilable as-is. Generic type arguments appear to be
 * missing (e.g. "new List(...)", "ReadOnlySpan maskWeights") - presumably lost
 * during extraction; confirm against the repository.
 *
 * What the patch changes, as far as the visible hunks show:
 *
 *  - Adds a private record struct IndexedBoundingBox(Index, Name, Rectangle,
 *    Confidence).
 *
 *  - Parse: maskChannelCount (output0.Dimensions[1] - 4 - metadata.Classes.Count)
 *    is now computed once before the first Parallel.For instead of inside it.
 *    The first loop no longer builds a mask per candidate; it only adds an
 *    IndexedBoundingBox carrying the candidate index i. After the selection step
 *    (keyed on x.Confidence and _parameters.IoU), a second Parallel.For over
 *    "selected" reads the mask weights for box.Index, calls ProcessMask and
 *    stores a SegmentationBoundingBox into result[index].
 *
 *  - ProcessMask: parameters are renamed to maskPrototypes / maskWeights and the
 *    weights parameter changes from float[] to ReadOnlySpan. The loops now run
 *    y in the outer loop and x in the inner loop, and the prototype tensor is
 *    indexed as [0, i, y, x] instead of [0, i, x, y]. The
 *    RotateFlip(RotateMode.Rotate90, FlipMode.Horizontal) step is removed, and
 *    the padding crop rectangle is computed before bitmap.Mutate rather than
 *    inside it.
 *
 *  - Adds GetMaskWeights, which allocates a float[maskChannelCount], fills it
 *    from output[0, maskWeightsOffset + i, boxIndex] and returns it.
 *
 * NOTE(review): boxes is created with "new List(...)" and boxes.Add(box) is
 * called inside Parallel.For; no synchronization is visible in these hunks.
 * Verify against the full file that concurrent Add calls are guarded.
 */
diff --git a/Source/YoloV8/Parsers/SegmentationOutputParser.cs b/Source/YoloV8/Parsers/SegmentationOutputParser.cs index 4c04cd9..0e24f9c 100644 --- a/Source/YoloV8/Parsers/SegmentationOutputParser.cs +++ b/Source/YoloV8/Parsers/SegmentationOutputParser.cs @@ -10,6 +10,8 @@ internal readonly struct SegmentationOutputParser private readonly YoloV8Metadata _metadata; private readonly YoloV8Parameters _parameters; + private record struct IndexedBoundingBox(int Index, YoloV8Class Name, Rectangle Rectangle, float Confidence); + public SegmentationOutputParser(YoloV8Metadata metadata, YoloV8Parameters parameters) { _metadata = metadata; @@ -31,7 +33,9 @@ public IReadOnlyList Parse(IReadOnlyList var output0 = outputs[0]; var output1 = outputs[1]; - var boxes = new List(output0.Dimensions[2]); + var maskChannelCount = output0.Dimensions[1] - 4 - metadata.Classes.Count; + + var boxes = new List(output0.Dimensions[2]); Parallel.For(0, output0.Dimensions[2], i => { @@ -60,19 +64,7 @@ public IReadOnlyList Parse(IReadOnlyList var rectangle = Rectangle.FromLTRB(xMin, yMin, xMax, yMax); var name = metadata.Classes[j]; - var maskChannelCount = output0.Dimensions[1] - 4 - metadata.Classes.Count; - - var maskWeights = new float[maskChannelCount]; - - for (int k = 0; k < maskChannelCount; k++) - { - var offset = 4 + metadata.Classes.Count + k; - maskWeights[k] = output0[0, offset, i]; - } - - var mask = ProcessMask(output1, maskWeights, rectangle, originSize, metadata.ImageSize, xPadding, yPadding); - - var box = new SegmentationBoundingBox(name, rectangle, confidence, mask); + var box = new IndexedBoundingBox(i, name, rectangle, confidence); boxes.Add(box); } }); @@ -81,28 +73,43 @@ public IReadOnlyList Parse(IReadOnlyList x => x.Confidence, _parameters.IoU); - return selected; + var result = new SegmentationBoundingBox[selected.Count]; + + Parallel.For(0, selected.Count, index => + { + var box = selected[index]; + + var maskWeights = GetMaskWeights(output0, box.Index, 
maskChannelCount, metadata.Classes.Count + 4); + + var mask = ProcessMask(output1, maskWeights, box.Rectangle, originSize, metadata.ImageSize, xPadding, yPadding); + + var value = new SegmentationBoundingBox(box.Name, box.Rectangle, box.Confidence, mask); + + result[index] = value; + }); + + return result; } - private static IMask ProcessMask(Tensor prototypes, float[] weights, Rectangle rectangle, Size originSize, Size modelSize, int xPadding, int yPadding) + private static IMask ProcessMask(Tensor maskPrototypes, ReadOnlySpan maskWeights, Rectangle rectangle, Size originSize, Size modelSize, int xPadding, int yPadding) { - var maskChannels = prototypes.Dimensions[1]; - var maskHeight = prototypes.Dimensions[2]; - var maskWidth = prototypes.Dimensions[3]; + var maskChannels = maskPrototypes.Dimensions[1]; + var maskHeight = maskPrototypes.Dimensions[2]; + var maskWidth = maskPrototypes.Dimensions[3]; - if (maskChannels != weights.Length) + if (maskChannels != maskWeights.Length) throw new InvalidOperationException(); using var bitmap = new Image(maskWidth, maskHeight); - for (int x = 0; x < maskWidth; x++) + for (int y = 0; y < maskHeight; y++) { - for (int y = 0; y < maskHeight; y++) + for (int x = 0; x < maskWidth; x++) { var value = 0F; for (int i = 0; i < maskChannels; i++) - value += prototypes[0, i, x, y] * weights[i]; + value += maskPrototypes[0, i, y, x] * maskWeights[i]; value = Sigmoid(value); @@ -113,20 +120,23 @@ private static IMask ProcessMask(Tensor prototypes, float[] weights, Rect } } - bitmap.Mutate(x => - { - x.RotateFlip(RotateMode.Rotate90, FlipMode.Horizontal); + var xPad = xPadding * maskWidth / modelSize.Width; + var yPad = yPadding * maskHeight / modelSize.Height; - var xPad = xPadding * maskWidth / modelSize.Width; - var yPad = yPadding * maskHeight / modelSize.Height; + var paddingCropRectangle = new Rectangle(xPad, + yPad, + maskWidth - xPad * 2, + maskHeight - yPad * 2); - var crop = new Rectangle(xPad, - yPad, - maskWidth - xPad * 2, 
- maskHeight - yPad * 2); - x.Crop(crop); + bitmap.Mutate(x => + { + // crop for preprocess resize padding + x.Crop(paddingCropRectangle); + // resize to original image size x.Resize(originSize); + + // crop for getting the object segmentation only x.Crop(rectangle); }); @@ -141,6 +151,18 @@ private static IMask ProcessMask(Tensor prototypes, float[] weights, Rect return new Mask(final); } + private static ReadOnlySpan GetMaskWeights(Tensor output, int boxIndex, int maskChannelCount, int maskWeightsOffset) + { + var maskWeights = new float[maskChannelCount]; + + for (int i = 0; i < maskChannelCount; i++) + { + maskWeights[i] = output[0, maskWeightsOffset + i, boxIndex]; + } + + return maskWeights; + } + #region Helpers private static float Sigmoid(float value)