Skip to content

Commit

Permalink
Improve segmentation postprocess
Browse files: browse the repository at this point in the history
  • Loading branch information
dme-compunet committed Aug 21, 2023
1 parent 42aa0f9 commit 525803c
Showing 1 changed file with 55 additions and 33 deletions.
88 changes: 55 additions & 33 deletions Source/YoloV8/Parsers/SegmentationOutputParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ internal readonly struct SegmentationOutputParser
private readonly YoloV8Metadata _metadata;
private readonly YoloV8Parameters _parameters;

/// <summary>
/// Intermediate detection record built while scanning the first output tensor.
/// <c>Index</c> is the loop index of the detection along the tensor's last
/// dimension; it is kept so the mask weights for this detection can be read
/// from the same tensor later, once the final set of boxes has been selected.
/// </summary>
private record struct IndexedBoundingBox(int Index, YoloV8Class Name, Rectangle Rectangle, float Confidence);

public SegmentationOutputParser(YoloV8Metadata metadata, YoloV8Parameters parameters)
{
_metadata = metadata;
Expand All @@ -31,7 +33,9 @@ public IReadOnlyList<ISegmentationBoundingBox> Parse(IReadOnlyList<Tensor<float>
var output0 = outputs[0];
var output1 = outputs[1];

var boxes = new List<SegmentationBoundingBox>(output0.Dimensions[2]);
var maskChannelCount = output0.Dimensions[1] - 4 - metadata.Classes.Count;

var boxes = new List<IndexedBoundingBox>(output0.Dimensions[2]);

Parallel.For(0, output0.Dimensions[2], i =>
{
Expand Down Expand Up @@ -60,19 +64,7 @@ public IReadOnlyList<ISegmentationBoundingBox> Parse(IReadOnlyList<Tensor<float>
var rectangle = Rectangle.FromLTRB(xMin, yMin, xMax, yMax);
var name = metadata.Classes[j];
var maskChannelCount = output0.Dimensions[1] - 4 - metadata.Classes.Count;
var maskWeights = new float[maskChannelCount];
for (int k = 0; k < maskChannelCount; k++)
{
var offset = 4 + metadata.Classes.Count + k;
maskWeights[k] = output0[0, offset, i];
}
var mask = ProcessMask(output1, maskWeights, rectangle, originSize, metadata.ImageSize, xPadding, yPadding);
var box = new SegmentationBoundingBox(name, rectangle, confidence, mask);
var box = new IndexedBoundingBox(i, name, rectangle, confidence);
boxes.Add(box);
}
});
Expand All @@ -81,28 +73,43 @@ public IReadOnlyList<ISegmentationBoundingBox> Parse(IReadOnlyList<Tensor<float>
x => x.Confidence,
_parameters.IoU);

return selected;
var result = new SegmentationBoundingBox[selected.Count];

Parallel.For(0, selected.Count, index =>
{
var box = selected[index];
var maskWeights = GetMaskWeights(output0, box.Index, maskChannelCount, metadata.Classes.Count + 4);
var mask = ProcessMask(output1, maskWeights, box.Rectangle, originSize, metadata.ImageSize, xPadding, yPadding);
var value = new SegmentationBoundingBox(box.Name, box.Rectangle, box.Confidence, mask);
result[index] = value;
});

return result;
}

private static IMask ProcessMask(Tensor<float> prototypes, float[] weights, Rectangle rectangle, Size originSize, Size modelSize, int xPadding, int yPadding)
private static IMask ProcessMask(Tensor<float> maskPrototypes, ReadOnlySpan<float> maskWeights, Rectangle rectangle, Size originSize, Size modelSize, int xPadding, int yPadding)
{
var maskChannels = prototypes.Dimensions[1];
var maskHeight = prototypes.Dimensions[2];
var maskWidth = prototypes.Dimensions[3];
var maskChannels = maskPrototypes.Dimensions[1];
var maskHeight = maskPrototypes.Dimensions[2];
var maskWidth = maskPrototypes.Dimensions[3];

if (maskChannels != weights.Length)
if (maskChannels != maskWeights.Length)
throw new InvalidOperationException();

using var bitmap = new Image<L8>(maskWidth, maskHeight);

for (int x = 0; x < maskWidth; x++)
for (int y = 0; y < maskHeight; y++)
{
for (int y = 0; y < maskHeight; y++)
for (int x = 0; x < maskWidth; x++)
{
var value = 0F;

for (int i = 0; i < maskChannels; i++)
value += prototypes[0, i, x, y] * weights[i];
value += maskPrototypes[0, i, y, x] * maskWeights[i];

value = Sigmoid(value);

Expand All @@ -113,20 +120,23 @@ private static IMask ProcessMask(Tensor<float> prototypes, float[] weights, Rect
}
}

bitmap.Mutate(x =>
{
x.RotateFlip(RotateMode.Rotate90, FlipMode.Horizontal);
var xPad = xPadding * maskWidth / modelSize.Width;
var yPad = yPadding * maskHeight / modelSize.Height;

var xPad = xPadding * maskWidth / modelSize.Width;
var yPad = yPadding * maskHeight / modelSize.Height;
var paddingCropRectangle = new Rectangle(xPad,
yPad,
maskWidth - xPad * 2,
maskHeight - yPad * 2);

var crop = new Rectangle(xPad,
yPad,
maskWidth - xPad * 2,
maskHeight - yPad * 2);
x.Crop(crop);
bitmap.Mutate(x =>
{
// crop for preprocess resize padding
x.Crop(paddingCropRectangle);
// resize to original image size
x.Resize(originSize);
// crop for getting the object segmentation only
x.Crop(rectangle);
});

Expand All @@ -141,6 +151,18 @@ private static IMask ProcessMask(Tensor<float> prototypes, float[] weights, Rect
return new Mask(final);
}

/// <summary>
/// Reads the mask weights of a single detection out of <paramref name="output"/>.
/// </summary>
/// <param name="output">Tensor indexed here as <c>[0, row, boxIndex]</c>.</param>
/// <param name="boxIndex">Index of the detection along the tensor's last dimension.</param>
/// <param name="maskChannelCount">Number of consecutive weights to read.</param>
/// <param name="maskWeightsOffset">Row (second dimension) at which the first weight is located.</param>
/// <returns>
/// A span over a newly allocated array of <paramref name="maskChannelCount"/> values, where
/// element <c>k</c> equals <c>output[0, maskWeightsOffset + k, boxIndex]</c>.
/// </returns>
private static ReadOnlySpan<float> GetMaskWeights(Tensor<float> output, int boxIndex, int maskChannelCount, int maskWeightsOffset)
{
    var coefficients = new float[maskChannelCount];

    // Walk the destination slot and the source row in lockstep.
    for (int channel = 0, row = maskWeightsOffset; channel < coefficients.Length; channel++, row++)
    {
        coefficients[channel] = output[0, row, boxIndex];
    }

    return new ReadOnlySpan<float>(coefficients);
}

#region Helpers

private static float Sigmoid(float value)
Expand Down

0 comments on commit 525803c

Please sign in to comment.