From 02e5698fcbb020cc6f53fd7d08e039c1292f0d4f Mon Sep 17 00:00:00 2001 From: Vincent Biret Date: Wed, 29 Nov 2023 09:55:34 -0500 Subject: [PATCH] - fixes null ref exception in types trimming --- src/Kiota.Builder/CodeDOM/CodeBlock.cs | 2 +- src/Kiota.Builder/CodeDOM/ICodeElement.cs | 1 + src/Kiota.Builder/KiotaBuilder.cs | 17 +++++++++-------- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/Kiota.Builder/CodeDOM/CodeBlock.cs b/src/Kiota.Builder/CodeDOM/CodeBlock.cs index 9b9ad01bfe..6a7e80c9cb 100644 --- a/src/Kiota.Builder/CodeDOM/CodeBlock.cs +++ b/src/Kiota.Builder/CodeDOM/CodeBlock.cs @@ -39,7 +39,7 @@ public virtual void RenameChildElement(string oldName, string newName) } else throw new InvalidOperationException($"The element to rename was not found {oldName}"); } - public void RemoveChildElement(params T[] elements) where T : CodeElement + public void RemoveChildElement(params T[] elements) where T : ICodeElement { if (elements == null) return; RemoveChildElementByName(elements.Select(static x => x.Name).ToArray()); diff --git a/src/Kiota.Builder/CodeDOM/ICodeElement.cs b/src/Kiota.Builder/CodeDOM/ICodeElement.cs index 671cd8dbe4..4cbfd90c4b 100644 --- a/src/Kiota.Builder/CodeDOM/ICodeElement.cs +++ b/src/Kiota.Builder/CodeDOM/ICodeElement.cs @@ -9,4 +9,5 @@ string Name { get; set; } + T GetImmediateParentOfType(CodeElement? item = null); } diff --git a/src/Kiota.Builder/KiotaBuilder.cs b/src/Kiota.Builder/KiotaBuilder.cs index af0c2f02db..376242ffe9 100644 --- a/src/Kiota.Builder/KiotaBuilder.cs +++ b/src/Kiota.Builder/KiotaBuilder.cs @@ -1991,12 +1991,12 @@ private IEnumerable> GetDiscriminatorMappings(Ope .Where(static x => x.Value != null) .Select(static x => KeyValuePair.Create(x.Key, x.Value!)); } - private static IEnumerable GetAllModels(CodeNamespace currentNamespace) + private static IEnumerable GetAllModels(CodeNamespace currentNamespace) { var classes = currentNamespace.Classes.ToArray(); return classes.Union(classes.SelectMany(GetAllInnerClasses)) .Where(static x => x.IsOfKind(CodeClassKind.Model)) - .OfType() + .OfType() .Union(currentNamespace.Enums) .Union(currentNamespace.Namespaces.SelectMany(static x => GetAllModels(x))); } @@ -2016,7 +2016,7 @@ private void TrimInheritedModels() var classesInUse = derivedClassesInUse.Union(classesDirectlyInUse).Union(baseOfModelsInUse).ToHashSet(); var reusableClassesDerivationIndex = GetDerivationIndex(reusableModels.OfType()); var reusableClassesInheritanceIndex = GetInheritanceIndex(allModelClassesIndex); - var relatedModels = classesInUse.SelectMany(x => GetRelatedDefinitions(x, reusableClassesDerivationIndex, reusableClassesInheritanceIndex)).Union(modelsDirectlyInUse.Where(x => x is CodeEnum)).ToHashSet();// re-including models directly in use for enums + var relatedModels = classesInUse.SelectMany(x => GetRelatedDefinitions(x, reusableClassesDerivationIndex, reusableClassesInheritanceIndex)).Union(modelsDirectlyInUse.OfType()).ToHashSet();// re-including models directly in use for enums Parallel.ForEach(reusableModels, parallelOptions, x => { if (relatedModels.Contains(x) || classesInUse.Contains(x)) return; @@ -2074,23 +2074,24 @@ private static IEnumerable GetDerivedDefinitions(ConcurrentDictionary var currentDerived = modelsInUse.SelectMany(x => models.TryGetValue(x, out var res) ? res : Enumerable.Empty()).ToArray(); return currentDerived.Union(currentDerived.SelectMany(x => GetDerivedDefinitions(models, [x]))); } - private static IEnumerable GetRelatedDefinitions(CodeElement currentElement, ConcurrentDictionary> derivedIndex, ConcurrentDictionary> inheritanceIndex, ConcurrentDictionary? visited = null) + private static IEnumerable GetRelatedDefinitions(ITypeDefinition currentElement, ConcurrentDictionary> derivedIndex, ConcurrentDictionary> inheritanceIndex, ConcurrentDictionary? visited = null) { visited ??= new(); - if (currentElement is not CodeClass currentClass || !visited.TryAdd(currentClass, true)) return Enumerable.Empty(); + if (currentElement is not CodeClass currentClass || !visited.TryAdd(currentClass, true)) return Enumerable.Empty(); var propertiesDefinitions = currentClass.Properties .SelectMany(static x => x.Type.AllTypes) - .Select(static x => x.TypeDefinition!) + .Select(static x => x.TypeDefinition) + .OfType() .Where(static x => x is CodeClass || x is CodeEnum) .SelectMany(x => x is CodeClass classDefinition ? (inheritanceIndex.TryGetValue(classDefinition, out var res) ? res : Enumerable.Empty()) .Union(GetDerivedDefinitions(derivedIndex, [classDefinition])) .Union(new[] { classDefinition }) - .OfType() : + .OfType() : new[] { x }) .Distinct() .ToArray(); - var propertiesParentTypes = propertiesDefinitions.OfType().SelectMany(static x => x.GetInheritanceTree(false, false)).ToArray(); + var propertiesParentTypes = propertiesDefinitions.OfType().SelectMany(static x => x.GetInheritanceTree(false, false)).OfType().ToArray(); return propertiesDefinitions .Union(propertiesParentTypes) .Union(propertiesParentTypes.SelectMany(x => GetRelatedDefinitions(x, derivedIndex, inheritanceIndex, visited)))