diff --git a/demo/src/main/java/overrun/marshal/test/CMarshalTest.java b/demo/src/main/java/overrun/marshal/test/CMarshalTest.java index eccbefa..72abc02 100644 --- a/demo/src/main/java/overrun/marshal/test/CMarshalTest.java +++ b/demo/src/main/java/overrun/marshal/test/CMarshalTest.java @@ -51,6 +51,8 @@ public interface CMarshalTest { MemorySegment testWithArgAndReturnValue(MemorySegment segment); + MemorySegment testWithCustomBody(); + @Native(scope = MarshalScope.PRIVATE) int testWithPrivate(int i); @@ -59,10 +61,10 @@ public interface CMarshalTest { @Overload void testWithArray(int[] arr); - void testWithRefArray(MemorySegment arr0, MemorySegment arr1, MemorySegment arr2, MemorySegment arr3, MemorySegment arr4); + void testWithRefArray(MemorySegment arr0, MemorySegment arr1, MemorySegment arr2, MemorySegment arr3, MemorySegment arr4, MemorySegment arr5); @Overload - void testWithRefArray(int[] arr0, @Ref int[] arr1, @Ref(nullable = true) int[] arr2, boolean[] arr3, @Ref boolean[] arr4); + void testWithRefArray(int[] arr0, @Ref int[] arr1, @Ref(nullable = true) int[] arr2, boolean[] arr3, @Ref boolean[] arr4, int[] arr5); void testWithString(MemorySegment str); diff --git a/demo/src/main/java/overrun/marshal/test/MarshalTestOverload.java b/demo/src/main/java/overrun/marshal/test/MarshalTestOverload.java deleted file mode 100644 index 6859b54..0000000 --- a/demo/src/main/java/overrun/marshal/test/MarshalTestOverload.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2023 Overrun Organization - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - */ - -package overrun.marshal.test; - -import java.lang.foreign.MemorySegment; - -/** - * @author squid233 - * @since 0.1.0 - */ -public final class MarshalTestOverload { - public static MemorySegment testWithDefaultArg() { - return MarshalTest.testWithArgAndReturnValue(MemorySegment.NULL); - } -} diff --git a/gradle.properties b/gradle.properties index 2f09899..7ca950a 100644 --- a/gradle.properties +++ b/gradle.properties @@ -28,7 +28,7 @@ orgUrl=https://over-run.github.io/ # JDK Options jdkVersion=22 -jdkEnablePreview=true +jdkEnablePreview=false # javadoc link of JDK early access build # https://download.java.net/java/early_access/$jdkEarlyAccessDoc/docs/api/ # Uncomment it if you need to use EA build of JDK. diff --git a/src/main/java/overrun/marshal/BoolHelper.java b/src/main/java/overrun/marshal/BoolHelper.java index eb5bba0..5c1d8bb 100644 --- a/src/main/java/overrun/marshal/BoolHelper.java +++ b/src/main/java/overrun/marshal/BoolHelper.java @@ -27,6 +27,10 @@ * @since 0.1.0 */ public final class BoolHelper { + private BoolHelper() { + //no instance + } + public static MemorySegment of(SegmentAllocator allocator, boolean[] arr) { final MemorySegment segment = allocator.allocate(ValueLayout.JAVA_BOOLEAN, arr.length); for (int i = 0; i < arr.length; i++) { diff --git a/src/main/java/overrun/marshal/Checks.java b/src/main/java/overrun/marshal/Checks.java new file mode 100644 index 0000000..6288dbe --- /dev/null +++ b/src/main/java/overrun/marshal/Checks.java @@ -0,0 +1,57 @@ +/* + * MIT License + * + * Copyright (c) 2023 Overrun Organization + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + */ + +package overrun.marshal; + +import java.util.function.Supplier; + +/** + * @author squid233 + * @since 0.1.0 + */ +public final class Checks { + public static final Entry CHECK_ARRAY_SIZE = new Entry<>(() -> true); + + private Checks() { + //no instance + } + + public static void checkArraySize(int expected, int got) { + if (CHECK_ARRAY_SIZE.get() && expected != got) { + throw new IllegalArgumentException("Expected array of size " + expected + ", got " + got); + } + } + + /** + * @author squid233 + * @since 0.1.0 + */ + public static final class Entry { + private final Supplier supplier; + private T value; + + public Entry(Supplier supplier) { + this.supplier = supplier; + } + + public T get() { + if (value == null) { + value = supplier.get(); + } + return value; + } + } +} diff --git a/src/main/java/overrun/marshal/NativeApiProcessor.java b/src/main/java/overrun/marshal/NativeApiProcessor.java index c643fd7..719d372 100644 --- a/src/main/java/overrun/marshal/NativeApiProcessor.java +++ b/src/main/java/overrun/marshal/NativeApiProcessor.java @@ -45,8 +45,6 @@ @SupportedAnnotationTypes("overrun.marshal.NativeApi") @SupportedSourceVersion(SourceVersion.RELEASE_22) public class NativeApiProcessor extends AbstractProcessor { - private static final String PACKAGE_NAME = "overrun.marshal"; - private void processClasses(RoundEnvironment roundEnv) { final Set marshalApis = roundEnv.getElementsAnnotatedWith(NativeApi.class); final Set types = ElementFilter.typesIn(marshalApis); @@ -89,6 +87,11 @@ private void writeFile( // import sb.append(""" + import overrun.marshal.BoolHelper; + import overrun.marshal.Checks; + import overrun.marshal.FixedSize; + import overrun.marshal.Ref; + import java.lang.foreign.*; import java.lang.invoke.MethodHandle; @@ -104,225 +107,18 @@ private void writeFile( .append(" {\n"); // fields - entry.getKey().forEach(variable -> { - final Object constantValue = variable.getConstantValue(); - if (constantValue == null) return; - final String typeStr = variable.asType().toString(); - sb.append(" public static final ") - .append(isString(typeStr) ? "String" : typeStr) - .append(' ') - .append(variable.getSimpleName()) - .append(" = "); - if (constantValue instanceof String) { - sb.append('"').append(constantValue).append('"'); - } else { - sb.append(constantValue); - } - sb.append(";\n"); - }); - sb.append('\n'); + appendFields(entry.getKey(), sb); final List methods = entry.getValue(); if (!methods.isEmpty()) { // load library - final String libname = nativeApi.libname(); - final AnnotationMirror annotationMirror = type.getAnnotationMirrors().stream() - .filter(m -> NativeApi.class.getName().equals(m.getAnnotationType().toString())) - .findFirst() - .orElseThrow(); - final String selector = annotationMirror.getElementValues().entrySet().stream() - .filter(entry1 -> "selector()".equals(entry1.getKey().toString())) - .findFirst() - .map(entry2 -> entry2.getValue().getValue().toString()) - .orElse(null); - sb.append(" private static final SymbolLookup _LOOKUP = SymbolLookup.libraryLookup("); - if (selector != null) { - sb.append("new ") - .append(selector) - .append("().select(\"") - .append(libname) - .append("\"), Arena.global()"); - } else { - sb.append('"').append(libname).append("\", Arena.global()"); - } - sb.append(");\n private static final Linker _LINKER = Linker.nativeLinker();\n\n"); + appendLoader(type, nativeApi, sb); // method handles - // try to collect function names - sb.append(methods.stream() - .collect(Collectors.toMap(NativeApiProcessor::methodEntrypoint, Function.identity(), (e1, e2) -> { - // conflict - final Overload o1 = e1.getAnnotation(Overload.class); - final Overload o2 = e2.getAnnotation(Overload.class); - if (o1 != null) { - if (o2 != null) { - return e1; - } - return e2; - } - if (o2 != null) { - return e1; - } - // neither e1 nor e2 is marked as overload - if (e1.getSimpleName().contentEquals(e2.getSimpleName())) { - throw new IllegalStateException("Overload not supported"); - } - return e1; - }, LinkedHashMap::new)) - .entrySet() - .stream() - .map(entry1 -> { - final ExecutableElement method = entry1.getValue(); - final Native annotation = method.getAnnotation(Native.class); - final StringBuilder sb1 = new StringBuilder(256); - sb1.append(" ") - .append(annotation != null ? annotation.scope() : "public") - .append(" static final MethodHandle ") - .append(entry1.getKey()) - .append(" = _LOOKUP.find(\"") - .append(entry1.getKey()) - .append("\").map(_s -> _LINKER.downcallHandle(_s, FunctionDescriptor.of"); - // append function descriptor - if (method.getReturnType().getKind() == TypeKind.VOID) { - sb1.append("Void("); - } else { - sb1.append('(').append(toValueLayout(method.getReturnType())); - if (!method.getParameters().isEmpty()) sb1.append(", "); - } - sb1.append(method.getParameters().stream() - .map(e -> toValueLayout(e.asType())) - .collect(Collectors.joining(", "))); - sb1.append("))).orElse"); - if (annotation != null && annotation.optional()) { - sb1.append("(null)"); - } else { - sb1.append("Throw()"); - } - return sb1.toString(); - }) - .collect(Collectors.joining(";\n"))).append(";\n\n"); + appendMethodHandles(type, sb, methods); // method declarations - methods.forEach(method -> { - final TypeMirror returnType = method.getReturnType(); - final TypeKind returnTypeKind = returnType.getKind(); - final boolean array = returnTypeKind == TypeKind.ARRAY; - final boolean booleanArray = array && isBooleanArray(returnType); - final Native annotation = method.getAnnotation(Native.class); - final Overload overloadAnnotation = method.getAnnotation(Overload.class); - final var parameters = method.getParameters(); - // check @Ref - if (parameters.stream() - .filter(e -> e.getAnnotation(Ref.class) != null) - .anyMatch(e -> e.asType().getKind() != TypeKind.ARRAY)) { - throw new IllegalStateException("@Ref must be used on array types"); - } - final boolean shouldInsertAllocator = parameters.stream().anyMatch(e -> { - final TypeMirror type1 = e.asType(); - final TypeKind kind = type1.getKind(); - return kind == TypeKind.ARRAY || (kind == TypeKind.DECLARED && isString(type1)); // TODO: 2023/12/2 Struct - }); - final boolean notVoid = returnTypeKind != TypeKind.VOID; - final boolean shouldStoreResult = notVoid && - parameters.stream().anyMatch(e -> e.getAnnotation(Ref.class) != null); - final String entrypoint = methodEntrypoint(method); - final boolean isOverload = overloadAnnotation != null; - final String overload; - if (isOverload) { - final String value = overloadAnnotation.value(); - overload = value.isBlank() ? method.getSimpleName().toString() : value; - } else { - overload = null; - } - final String javaReturnType = toTargetType(returnType, overload); - // javadoc - if (annotation != null) { - final String doc = annotation.doc(); - if (!doc.isBlank()) { - sb.append(" /**\n"); - doc.lines().map(s -> " * " + s).forEach(s -> { - sb.append(s); - sb.append('\n'); - }); - sb.append(" */\n"); - } - } - sb.append(" ") - // access modifier - .append(annotation != null ? annotation.scope() : "public") - .append(" static ") - // return type - .append(javaReturnType) - .append(' ') - .append(method.getSimpleName()) - .append('('); - // parameters - if (shouldInsertAllocator) { - sb.append("SegmentAllocator _segmentAllocator, "); - } - sb.append(parameters.stream() - .map(e -> { - final Ref ref = e.getAnnotation(Ref.class); - final String refString = ref != null ? - "@" + PACKAGE_NAME + ".Ref" + (ref.nullable() ? "(nullable = true) " : " ") : - ""; - return refString + toTargetType(e.asType(), overload) + " " + e.getSimpleName(); - }) - .collect(Collectors.joining(", "))).append(") {\n"); - // method body - if (isOverload) { - // send a warning if using any boolean array - if (booleanArray || parameters.stream().anyMatch(e -> { - final TypeMirror type1 = e.asType(); - return type1.getKind() == TypeKind.ARRAY && isBooleanArray(type1); - })) { - processingEnv.getMessager().printWarning(type + "::" + method + ": Marshalling boolean array"); - } - prependOverloadArgs(sb, parameters); - if (notVoid) { - if (shouldStoreResult) { - sb.append(" var $_marshalResult = "); - } else { - sb.append(" return "); - } - } else { - sb.append(" "); - } - if (array && booleanArray) { - sb.append(PACKAGE_NAME) - .append(".BoolHelper.toArray("); - } - sb.append(overload).append('('); - appendOverloadArgs(sb, parameters); - sb.append(')'); - if (array) { - if (booleanArray) { - sb.append(')'); - } else { - sb.append(".toArray(") - .append(toValueLayout(returnType)) - .append(')'); - } - } - sb.append(";\n"); - appendRefArgs(sb, parameters); - if (shouldStoreResult) { - sb.append(" return $_marshalResult;\n"); - } - } else { - sb.append(" try {\n "); - if (notVoid) { - sb.append("return (").append(javaReturnType).append(") "); - } - sb.append(entrypoint).append(".invokeExact("); - // parameters - sb.append(parameters.stream() - .map(VariableElement::getSimpleName) - .collect(Collectors.joining(", "))); - sb.append(");\n } catch (Throwable e) { throw new AssertionError(\"should not reach here\", e); }\n"); - } - sb.append(" }\n\n"); - }); + appendMethodDecl(type, methods, sb); } // end @@ -331,6 +127,300 @@ private void writeFile( } } + private void appendMethodDecl(TypeElement type, List methods, StringBuilder sb) { + methods.forEach(method -> { + final TypeMirror returnType = method.getReturnType(); + final TypeKind returnTypeKind = returnType.getKind(); + final var parameters = method.getParameters(); + // check @Ref and @FixedSize + checkParamAnnotation(type, method, parameters); + + final boolean array = isArray(returnType); + final boolean booleanArray = array && isBooleanArray(returnType); + final Native annotation = method.getAnnotation(Native.class); + final Overload overloadAnnotation = method.getAnnotation(Overload.class); + final Custom customAnnotation = method.getAnnotation(Custom.class); + final boolean shouldInsertAllocator = parameters.stream().anyMatch(e -> { + final TypeMirror type1 = e.asType(); + final TypeKind kind = type1.getKind(); + return kind == TypeKind.ARRAY || (kind == TypeKind.DECLARED && isString(type1)); + // TODO: 2023/12/2 Struct + }); + final boolean notVoid = returnTypeKind != TypeKind.VOID; + final boolean shouldStoreResult = notVoid && + parameters.stream().anyMatch(e -> e.getAnnotation(Ref.class) != null); + final boolean isOverload = overloadAnnotation != null; + final String overload; + if (isOverload) { + final String value = overloadAnnotation.value(); + overload = value.isBlank() ? method.getSimpleName().toString() : value; + } else { + overload = null; + } + final String javaReturnType = toTargetType(returnType, overload); + + // javadoc + appendJavadoc(sb, annotation); + + sb.append(" ") + // access modifier + .append(annotation != null ? annotation.scope() : "public") + .append(" static ") + // return type + .append(javaReturnType) + .append(' ') + .append(method.getSimpleName()); + + // parameters + sb.append('('); + appendParameters(sb, shouldInsertAllocator, parameters, overload); + sb.append(") {\n"); + + // method body + if (customAnnotation == null) { + appendMethodBody(type, + sb, + method, + isOverload, + booleanArray, + parameters, + notVoid, + shouldStoreResult, + array, + overload, + returnType, + javaReturnType); + } else { + sb.append(customAnnotation.value().indent(8)); + } + sb.append(" }\n\n"); + }); + } + + private void appendMethodBody(TypeElement type, + StringBuilder sb, + ExecutableElement method, + boolean isOverload, + boolean booleanArray, + List parameters, + boolean notVoid, + boolean shouldStoreResult, + boolean array, + String overload, + TypeMirror returnType, + String javaReturnType) { + if (isOverload) { + // send a warning if using any boolean array + if (booleanArray || parameters.stream().anyMatch(e -> { + final TypeMirror type1 = e.asType(); + return isArray(type1) && isBooleanArray(type1); + })) { + processingEnv.getMessager().printWarning(type + "::" + method + ": Marshalling boolean array"); + } + // check array size + parameters.stream() + .filter(e -> isArray(e.asType()) && e.getAnnotation(FixedSize.class) != null) + .forEach(e -> { + final FixedSize fixedSize = e.getAnnotation(FixedSize.class); + sb.append(" ") + .append("Checks.checkArraySize(") + .append(fixedSize.value()) + .append(", ") + .append(e.getSimpleName()) + .append(".length);\n"); + }); + prependOverloadArgs(sb, parameters); + if (notVoid) { + if (shouldStoreResult) { + sb.append(" var $_marshalResult = "); + } else { + sb.append(" return "); + } + } else { + sb.append(" "); + } + if (array && booleanArray) { + sb.append("BoolHelper.toArray("); + } + sb.append(overload).append('('); + appendOverloadArgs(sb, parameters); + sb.append(')'); + if (array) { + if (booleanArray) { + sb.append(')'); + } else { + sb.append(".toArray(") + .append(toValueLayout(returnType)) + .append(')'); + } + } + sb.append(";\n"); + appendRefArgs(sb, parameters); + if (shouldStoreResult) { + sb.append(" return $_marshalResult;\n"); + } + } else { + sb.append(" try {\n "); + if (notVoid) { + sb.append("return (").append(javaReturnType).append(") "); + } + sb.append(methodEntrypoint(method)).append(".invokeExact("); + // parameters + sb.append(parameters.stream() + .map(VariableElement::getSimpleName) + .collect(Collectors.joining(", "))); + sb.append(");\n } catch (Throwable e) { throw new AssertionError(\"should not reach here\", e); }\n"); + } + } + + private static void appendParameters(StringBuilder sb, boolean shouldInsertAllocator, List parameters, String overload) { + if (shouldInsertAllocator) { + sb.append("SegmentAllocator _segmentAllocator, "); + } + sb.append(parameters.stream() + .map(e -> { + final Ref ref = e.getAnnotation(Ref.class); + final String refString = ref != null ? + "@Ref" + (ref.nullable() ? "(nullable = true) " : " ") : + ""; + final FixedSize fixedSize = e.getAnnotation(FixedSize.class); + final String fixedSizeString = fixedSize != null ? + "@FixedSize(" + fixedSize.value() + ") " : + ""; + return refString + fixedSizeString + toTargetType(e.asType(), overload) + " " + e.getSimpleName(); + }) + .collect(Collectors.joining(", "))); + } + + private static void appendJavadoc(StringBuilder sb, Native annotation) { + if (annotation != null) { + final String doc = annotation.doc(); + if (!doc.isBlank()) { + sb.append(" /**\n"); + doc.lines().map(s -> " * " + s).forEach(s -> { + sb.append(s); + sb.append('\n'); + }); + sb.append(" */\n"); + } + } + } + + private void checkParamAnnotation(TypeElement type, ExecutableElement method, List parameters) { + parameters.stream().filter(e -> !isArray(e.asType())).forEach(e -> { + final Ref ref = e.getAnnotation(Ref.class); + final FixedSize fixedSize = e.getAnnotation(FixedSize.class); + final String prefix = type + "::" + method + ": Using "; + final String suffix = " on non-array parameter " + e.asType() + " " + e.getSimpleName(); + if (ref != null) { + printError(prefix + "@Ref" + suffix); + } + if (fixedSize != null) { + printError(prefix + "@FixedSize" + suffix); + } + }); + } + + private void appendMethodHandles(TypeElement type, StringBuilder sb, List methods) { + // try to collect function names + sb.append(methods.stream() + .collect(Collectors.toMap(NativeApiProcessor::methodEntrypoint, Function.identity(), (e1, e2) -> { + // conflict + final Overload o1 = e1.getAnnotation(Overload.class); + final Overload o2 = e2.getAnnotation(Overload.class); + if (o1 != null) { + if (o2 != null) { + return e1; + } + return e2; + } + if (o2 != null) { + return e1; + } + // neither e1 nor e2 is marked as overload + if (e1.getSimpleName().contentEquals(e2.getSimpleName())) { + printError("Overload not supported between " + type + "::" + e1 + " and " + e2); + } + return e1; + }, LinkedHashMap::new)) + .entrySet() + .stream() + .map(entry1 -> { + final ExecutableElement method = entry1.getValue(); + final Native annotation = method.getAnnotation(Native.class); + final StringBuilder sb1 = new StringBuilder(256); + sb1.append(" ") + .append(annotation != null ? annotation.scope() : "public") + .append(" static final MethodHandle ") + .append(entry1.getKey()) + .append(" = _LOOKUP.find(\"") + .append(entry1.getKey()) + .append("\").map(_s -> _LINKER.downcallHandle(_s, FunctionDescriptor.of"); + // append function descriptor + if (method.getReturnType().getKind() == TypeKind.VOID) { + sb1.append("Void("); + } else { + sb1.append('(').append(toValueLayout(method.getReturnType())); + if (!method.getParameters().isEmpty()) sb1.append(", "); + } + sb1.append(method.getParameters().stream() + .map(e -> toValueLayout(e.asType())) + .collect(Collectors.joining(", "))); + sb1.append("))).orElse"); + if (annotation != null && annotation.optional()) { + sb1.append("(null)"); + } else { + sb1.append("Throw()"); + } + return sb1.toString(); + }) + .collect(Collectors.joining(";\n"))).append(";\n\n"); + } + + private static void appendLoader(TypeElement type, NativeApi nativeApi, StringBuilder sb) { + final String libname = nativeApi.libname(); + final AnnotationMirror annotationMirror = type.getAnnotationMirrors().stream() + .filter(m -> NativeApi.class.getName().equals(m.getAnnotationType().toString())) + .findFirst() + .orElseThrow(); + final String selector = annotationMirror.getElementValues().entrySet().stream() + .filter(entry1 -> "selector()".equals(entry1.getKey().toString())) + .findFirst() + .map(entry2 -> entry2.getValue().getValue().toString()) + .orElse(null); + sb.append(" private static final SymbolLookup _LOOKUP = SymbolLookup.libraryLookup("); + if (selector != null) { + sb.append("new ") + .append(selector) + .append("().select(\"") + .append(libname) + .append("\"), Arena.global()"); + } else { + sb.append('"').append(libname).append("\", Arena.global()"); + } + sb.append(");\n private static final Linker _LINKER = Linker.nativeLinker();\n\n"); + } + + private static void appendFields(List fields, StringBuilder sb) { + fields.forEach(e -> { + final Object constantValue = e.getConstantValue(); + if (constantValue == null) return; + final String typeStr = e.asType().toString(); + sb.append(" public static final ") + .append(isString(typeStr) ? "String" : typeStr) + .append(' ') + .append(e.getSimpleName()) + .append(" = "); + if (constantValue instanceof String) { + sb.append('"').append(constantValue).append('"'); + } else { + sb.append(constantValue); + } + sb.append(";\n"); + }); + sb.append('\n'); + } + @Override public boolean process(Set annotations, RoundEnvironment roundEnv) { try { @@ -345,8 +435,7 @@ private static void prependOverloadArgs(StringBuilder sb, List { final TypeMirror type = e.asType(); final Ref ref = e.getAnnotation(Ref.class); - final boolean array = type.getKind() == TypeKind.ARRAY; - if (array && isPrimitiveArray(type) && ref != null) { + if (isArray(type) && isPrimitiveArray(type) && ref != null) { final Name name = e.getSimpleName(); final TypeMirror arrayComponentType = getArrayComponentType(type); sb.append(" ") @@ -356,8 +445,7 @@ private static void prependOverloadArgs(StringBuilder sb, List