diff --git a/Source/YoloV8.Tests/YoloMetadataTests.cs b/Source/YoloV8.Tests/YoloMetadataTests.cs index a332dae..900ca93 100644 --- a/Source/YoloV8.Tests/YoloMetadataTests.cs +++ b/Source/YoloV8.Tests/YoloMetadataTests.cs @@ -38,7 +38,7 @@ public void MetadataParsingTest() { nameof(names), $"{{{names[0]}, {names[1]}}}" } }; - var metadata = new YoloMetadata(dictionary); + var metadata = new YoloMetadata(dictionary, YoloArchitecture.YoloV8Or11); Assert.Equal(author, metadata.Author); Assert.Equal(description, metadata.Description); @@ -51,6 +51,6 @@ public void MetadataParsingTest() Assert.Equal($"{names[0]}", $"{metadata.Names[0]}"); Assert.Equal($"{names[1]}", $"{metadata.Names[1]}"); - Assert.Equal(YoloArchitecture.YoloV8, metadata.Architecture); + Assert.Equal(YoloArchitecture.YoloV8Or11, metadata.Architecture); } } \ No newline at end of file diff --git a/Source/YoloV8/Metadata/YoloArchitecture.cs b/Source/YoloV8/Metadata/YoloArchitecture.cs index ad510ff..3fe9ab9 100644 --- a/Source/YoloV8/Metadata/YoloArchitecture.cs +++ b/Source/YoloV8/Metadata/YoloArchitecture.cs @@ -2,7 +2,6 @@ public enum YoloArchitecture { - YoloV8, - YoloV10, - Yolo11, -} + YoloV8Or11, + YoloV10 +} \ No newline at end of file diff --git a/Source/YoloV8/Metadata/YoloMetadata.cs b/Source/YoloV8/Metadata/YoloMetadata.cs index b368747..ebfbd7e 100644 --- a/Source/YoloV8/Metadata/YoloMetadata.cs +++ b/Source/YoloV8/Metadata/YoloMetadata.cs @@ -19,10 +19,11 @@ public class YoloMetadata public YoloArchitecture Architecture { get; } internal YoloMetadata(InferenceSession session) - : - this(session.ModelMetadata.CustomMetadataMap) + : + this(session.ModelMetadata.CustomMetadataMap, ParseYoloArchitecture(session)) { } - internal YoloMetadata(Dictionary metadata) + + internal YoloMetadata(Dictionary metadata, YoloArchitecture architecture) { Author = metadata["author"]; Description = metadata["description"]; @@ -38,7 +39,7 @@ internal YoloMetadata(Dictionary metadata) _ => throw new InvalidOperationException("Unknow YoloV8 'task' value") }; - Architecture = GetYoloArchitecture(Description); + Architecture = architecture; BatchSize = int.Parse(metadata["batch"]); ImageSize = ParseSize(metadata["imgsz"]); Names = ParseNames(metadata["names"]); @@ -61,24 +62,26 @@ public static YoloMetadata Parse(InferenceSession session) } } - private static YoloArchitecture GetYoloArchitecture(string description) + private static YoloArchitecture ParseYoloArchitecture(InferenceSession session) { - if (description.Contains("yolov8", StringComparison.CurrentCultureIgnoreCase)) - { - return YoloArchitecture.YoloV8; - } + var metadata = session.ModelMetadata.CustomMetadataMap; - if (description.Contains("yolov10", StringComparison.CurrentCultureIgnoreCase)) + if (metadata.TryGetValue("task", out var task) == false) { - return YoloArchitecture.YoloV10; + throw new InvalidOperationException(); } - if (description.Contains("yolo11", StringComparison.CurrentCultureIgnoreCase)) + if (task == "detect") { - return YoloArchitecture.Yolo11; + var output0 = session.OutputMetadata["output0"]; + + if (output0.Dimensions[2] == 6) // YOLOv10 output0: [1, 300, 6] + { + return YoloArchitecture.YoloV10; + } } - throw new NotSupportedException("Unrecognized YOLO model architecture"); + return YoloArchitecture.YoloV8Or11; } #region Parsers diff --git a/Source/YoloV8/Services/Parsers/RawBoundingBoxParser.cs b/Source/YoloV8/Services/Parsers/RawBoundingBoxParser.cs index 15d414a..24e350f 100644 --- a/Source/YoloV8/Services/Parsers/RawBoundingBoxParser.cs +++ b/Source/YoloV8/Services/Parsers/RawBoundingBoxParser.cs @@ -18,7 +18,7 @@ private T[] ParseYoloV8(DenseTensor tensor) where T : IRawBoundingBox< var context = new RawParsingContext { - Architecture = YoloArchitecture.YoloV8, + Architecture = YoloArchitecture.YoloV8Or11, Tensor = tensor, Stride1 = stride1, NameCount = namesCount,