Skip to content

Commit

Permalink
Simplify flags codegen for Vulkan
Browse files Browse the repository at this point in the history
  • Loading branch information
xoofx committed Jun 30, 2024
1 parent f5e9ed0 commit 7b60cdd
Show file tree
Hide file tree
Showing 10 changed files with 1,392 additions and 5,193 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="CppAst.CodeGen" Version="0.19.0" />
<PackageReference Include="CppAst.CodeGen" Version="0.20.0" />
<!--<ProjectReference Include="..\..\..\..\..\CppAst.CodeGen\src\CppAst.CodeGen\CppAst.CodeGen.csproj" />-->
</ItemGroup>

Expand Down
119 changes: 50 additions & 69 deletions src/codegen/XenoAtom.Interop.CodeGen/vulkan/VulkanGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using System.Threading.Tasks;
using System.Xml;
using System.Xml.Linq;
using ClangSharp.Interop;
using CppAst;
using CppAst.CodeGen.CSharp;

Expand All @@ -31,7 +32,6 @@ internal partial class VulkanGenerator(LibDescriptor descriptor) : GeneratorBase
private readonly Dictionary<string, VulkanCommand> _functionRegistry = new();
private readonly Dictionary<VulkanDocTypeKind, VulkanDocDefinitions> _docDefinitions = new();
private readonly Dictionary<string, CSharpStruct> _structFunctionPointers = new();
private readonly Dictionary<string, CSharpStruct> _structAsEnumFlags = new();
private readonly Dictionary<string, VulkanElementInfo> _vulkanElementInfos = new();
private readonly List<int> _tempOptionalParameterIndexList = new();
private readonly Dictionary<string, Dictionary<string, string>> _mapStructToFieldsWithDefaultValue = new();
Expand Down Expand Up @@ -145,6 +145,8 @@ public override async Task Initialize(ApkManager apkHelper)
Path.Combine(vulkanInclude, "vulkan/vk_layer.h"),
};

csOptions.Plugins.Add(new CSharpConverterVulkanTypedefFlags());

var csCompilation = CSharpConverter.Convert(files, csOptions);

{
Expand All @@ -160,36 +162,25 @@ public override async Task Initialize(ApkManager apkHelper)
}
}

foreach (var csEnum in csCompilation.AllEnums)
{
ApplyDocumentation(csEnum);
}

foreach (var csStruct in csCompilation.AllStructs)
foreach (var csStruct in csCompilation.AllStructs.Distinct())
{
ApplyDocumentation(csStruct);
AddVulkanVersionAndExtensionInfoToCSharpElement(csStruct);
ProcessStruct(csStruct);

// Associate Enum XXXFlagBits with Struct XXXFlags
if (csStruct.Name.Contains("Flags", StringComparison.Ordinal))
{
_structAsEnumFlags.Add(csStruct.Name, csStruct);
}

// Collect PFN function pointers
if (csStruct.Name.Contains("PFN_vk", StringComparison.Ordinal))
{
_structFunctionPointers.Add(csStruct.Name["PFN_".Length..], csStruct);
}
}

foreach (var csFunction in csCompilation.AllFunctions)
foreach (var csFunction in csCompilation.AllFunctions.Distinct())
{
ProcessVulkanFunction(csFunction);
}

foreach (var csEnum in csCompilation.AllEnums)
foreach (var csEnum in csCompilation.AllEnums.Distinct())
{
ProcessVulkanEnum(csEnum);
}
Expand Down Expand Up @@ -295,59 +286,8 @@ private void ProcessStruct(CSharpStruct csStruct)

private void ProcessVulkanEnum(CSharpEnum csEnum)
{
ApplyDocumentation(csEnum);
ApplyApiVersion(csEnum);

// We only need to modify flags in this method
if (!csEnum.Name.Contains("FlagBits", StringComparison.Ordinal))
{
return;
}

csEnum.Attributes.Add(new CSharpFreeAttribute("Flags"));

var structName = csEnum.Name.Replace("FlagBits", "Flags", StringComparison.Ordinal);
if (_structAsEnumFlags.TryGetValue(structName, out var csStruct))
{
// Add implicit operators between XXXFlagBits and Struct XXXFlags
csStruct.Members.Add(new CSharpMethod(string.Empty)
{
Kind = CSharpMethodKind.Operator,
ReturnType = csEnum,
Modifiers = CSharpModifiers.Static | CSharpModifiers.Implicit,
Parameters =
{
new CSharpParameter("from") {ParameterType = csStruct},
},
BodyInline = ((writer, _) =>
{
writer.Write("(");
csEnum.DumpReferenceTo(writer);
writer.Write(")(uint)from.Value");
}),
Visibility = CSharpVisibility.Public
});
csStruct.Members.Add(new CSharpMethod(string.Empty)
{
Kind = CSharpMethodKind.Operator,
ReturnType = csStruct,
Modifiers = CSharpModifiers.Static | CSharpModifiers.Implicit,
Parameters =
{
new CSharpParameter("from") {ParameterType = csEnum},
},
BodyInline = (writer, element) =>
{
writer.Write("new ");
csStruct.DumpReferenceTo(writer);
writer.Write("((uint)from)");
},
Visibility = CSharpVisibility.Public
});
}
else
{
Console.Error.WriteLine($"Cannot find struct {structName} for enum {csEnum.Name}");
}
}

[GeneratedRegex($"{CommonVkExt}")]
Expand Down Expand Up @@ -1005,9 +945,15 @@ private void ApplyDocumentation(CSharpElement element)
}
}
}
else if (element is CSharpEnum csEnum && element.CppElement is CppEnum cppEnum)
else if (element is CSharpEnum csEnum && element.CppElement is ICppMember cppEnum)
{
if (_docDefinitions.TryGetValue(VulkanDocTypeKind.Enum, out var definitions) && definitions.TryGetValue(cppEnum.Name, out var definition))
var cppEnumName = cppEnum.Name;
if (csEnum.IsFlags)
{
cppEnumName = csEnum.Name.Replace("Flags", "FlagBits");
}

if (_docDefinitions.TryGetValue(VulkanDocTypeKind.Enum, out var definitions) && definitions.TryGetValue(cppEnumName, out var definition))
{
description = definition.Description;

Expand Down Expand Up @@ -1966,6 +1912,41 @@ private enum VulkanCommandOptional
Both,
}

private class CSharpConverterVulkanTypedefFlags : ICSharpConverterPlugin
{
private readonly Dictionary<string, CSharpEnum> _flags = new();

public void Register(CSharpConverter converter, CSharpConverterPipeline pipeline)
{
pipeline.Converted.Add(ProcessConverted);
pipeline.TypedefConverters.Add(ProcessTypeDef);
}

private void ProcessConverted(CSharpConverter converter, CSharpElement element, CSharpElement context)
{
if (element is CSharpEnum cSharpEnum && cSharpEnum.Name.Contains("FlagBits"))
{
cSharpEnum.IsFlags = true;
cSharpEnum.Name = cSharpEnum.Name.Replace("FlagBits", "Flags", StringComparison.Ordinal);
_flags.Add(cSharpEnum.Name, cSharpEnum);
}
}

private CSharpElement? ProcessTypeDef(CSharpConverter converter, CppTypedef cpptypedef, CSharpElement context)
{
if (cpptypedef.Name.Contains("Flags"))
{
if (_flags.TryGetValue(cpptypedef.Name, out var csEnum))
{
return csEnum;
}
}
return null;
}
}



private record VulkanElementInfo
{
public string? ApiVersion { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ public partial struct VkIcdSurfaceDisplay

public uint planeStackIndex;

public vulkan.VkSurfaceTransformFlagBitsKHR transform;
public vulkan.VkSurfaceTransformFlagsKHR transform;

public float globalAlpha;

public vulkan.VkDisplayPlaneAlphaFlagBitsKHR alphaMode;
public vulkan.VkDisplayPlaneAlphaFlagsKHR alphaMode;

public vulkan.VkExtent2D imageExtent;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ public enum VkLayerFunction_ : uint
public const vulkan.VkLayerFunction_ VK_LOADER_FEATURES = VkLayerFunction_.VK_LOADER_FEATURES;

[Flags]
public enum VkLoaderFeatureFlagBits : uint
public enum VkLoaderFeatureFlags : uint
{
VK_LOADER_FEATURE_PHYSICAL_DEVICE_SORTING = unchecked((uint)1),
}

public const vulkan.VkLoaderFeatureFlagBits VK_LOADER_FEATURE_PHYSICAL_DEVICE_SORTING = VkLoaderFeatureFlagBits.VK_LOADER_FEATURE_PHYSICAL_DEVICE_SORTING;
public const vulkan.VkLoaderFeatureFlags VK_LOADER_FEATURE_PHYSICAL_DEVICE_SORTING = VkLoaderFeatureFlags.VK_LOADER_FEATURE_PHYSICAL_DEVICE_SORTING;

public enum VkChainType : uint
{
Expand Down Expand Up @@ -253,33 +253,6 @@ public partial struct VkLayerInstanceCreateInfo_layerDevice
public static bool operator !=(PFN_vkSetInstanceLoaderData left, PFN_vkSetInstanceLoaderData right) => !left.Equals(right);
}

public readonly partial struct VkLoaderFeatureFlags : IEquatable<vulkan.VkLoaderFeatureFlags>
{
public VkLoaderFeatureFlags(vulkan.VkFlags value) => this.Value = value;

public vulkan.VkFlags Value { get; }

public override bool Equals(object obj) => obj is VkLoaderFeatureFlags other && Equals(other);

public bool Equals(VkLoaderFeatureFlags other) => Value.Equals(other.Value);

public override int GetHashCode() => Value.GetHashCode();

public override string ToString() => Value.ToString();

public static implicit operator vulkan.VkFlags (vulkan.VkLoaderFeatureFlags from) => from.Value;

public static implicit operator vulkan.VkLoaderFeatureFlags (vulkan.VkFlags from) => new vulkan.VkLoaderFeatureFlags(from);

public static bool operator ==(VkLoaderFeatureFlags left, VkLoaderFeatureFlags right) => left.Equals(right);

public static bool operator !=(VkLoaderFeatureFlags left, VkLoaderFeatureFlags right) => !left.Equals(right);

public static implicit operator vulkan.VkLoaderFeatureFlagBits (vulkan.VkLoaderFeatureFlags from) => (vulkan.VkLoaderFeatureFlagBits)(uint)from.Value;

public static implicit operator vulkan.VkLoaderFeatureFlags (vulkan.VkLoaderFeatureFlagBits from) => new vulkan.VkLoaderFeatureFlags((uint)from);
}

/// <summary>
/// Sub type of structure for instance and device loader ext of CreateInfo.
/// When sType == VK_STRUCTURE_TYPE_LOADER_INSTANCE_CREATE_INFO
Expand Down Expand Up @@ -522,9 +495,9 @@ public vulkan.VkResult Invoke(vulkan.VkNegotiateLayerInterface* pVersionStruct)

public readonly partial struct VkLoaderFlagBits : IEquatable<vulkan.VkLoaderFlagBits>
{
public VkLoaderFlagBits(vulkan.VkLoaderFeatureFlagBits value) => this.Value = value;
public VkLoaderFlagBits(vulkan.VkLoaderFeatureFlags value) => this.Value = value;

public vulkan.VkLoaderFeatureFlagBits Value { get; }
public vulkan.VkLoaderFeatureFlags Value { get; }

public override bool Equals(object obj) => obj is VkLoaderFlagBits other && Equals(other);

Expand All @@ -534,9 +507,9 @@ public vulkan.VkResult Invoke(vulkan.VkNegotiateLayerInterface* pVersionStruct)

public override string ToString() => Value.ToString();

public static implicit operator vulkan.VkLoaderFeatureFlagBits (vulkan.VkLoaderFlagBits from) => from.Value;
public static implicit operator vulkan.VkLoaderFeatureFlags (vulkan.VkLoaderFlagBits from) => from.Value;

public static implicit operator vulkan.VkLoaderFlagBits (vulkan.VkLoaderFeatureFlagBits from) => new vulkan.VkLoaderFlagBits(from);
public static implicit operator vulkan.VkLoaderFlagBits (vulkan.VkLoaderFeatureFlags from) => new vulkan.VkLoaderFlagBits(from);

public static bool operator ==(VkLoaderFlagBits left, VkLoaderFlagBits right) => left.Equals(right);

Expand Down
Loading

0 comments on commit 7b60cdd

Please sign in to comment.