From a64f3201a8f72d96ba28f9bec16d8ccd53614b6d Mon Sep 17 00:00:00 2001 From: Love Leifland Date: Sun, 8 Dec 2024 13:36:39 +0100 Subject: [PATCH] Support lazy loading --- build.gradle | 2 +- procedure-collector/build.gradle | 14 -- .../gds/proc/ProcedureCollectorProcessor.java | 126 ------------------ .../gds/proc/ProcedureCollectorStep.java | 83 ------------ .../java/apoc/processor/ApocProcessor.java | 67 ++-------- .../apoc/processor/ExtensionClassWriter.java | 51 +++---- .../processor/ProcedureServiceWriter.java | 30 ++++- .../java/apoc/processor/SignatureVisitor.java | 66 +++++---- 8 files changed, 92 insertions(+), 347 deletions(-) delete mode 100644 procedure-collector/build.gradle delete mode 100644 procedure-collector/src/main/java/org/neo4j/gds/proc/ProcedureCollectorProcessor.java delete mode 100644 procedure-collector/src/main/java/org/neo4j/gds/proc/ProcedureCollectorStep.java diff --git a/build.gradle b/build.gradle index b56dcbd5b..60750608e 100644 --- a/build.gradle +++ b/build.gradle @@ -170,7 +170,7 @@ apply from: "licenses-source-header.gradle" ext { publicDir = "${project.rootDir}" - neo4jVersionEffective = project.hasProperty("neo4jVersionOverride") ? project.getProperty("neo4jVersionOverride") : "2024.12.0-SNAPSHOT" + neo4jVersionEffective = project.hasProperty("neo4jVersionOverride") ? project.getProperty("neo4jVersionOverride") : "2024.12.0" testContainersVersion = '1.20.2' apacheArrowVersion = '15.0.0' } diff --git a/procedure-collector/build.gradle b/procedure-collector/build.gradle deleted file mode 100644 index afb4cbae1..000000000 --- a/procedure-collector/build.gradle +++ /dev/null @@ -1,14 +0,0 @@ -apply plugin: 'java-library' - -description = 'APOC Procedure Collector' - -group = 'apoc' - -dependencies { - annotationProcessor 'com.google.auto.service:auto-service:1.1.0' - - compileOnly 'com.google.auto.service:auto-service:1.1.0' - - implementation 'org.neo4j:neo4j-procedure-api:2024.12.0-SNAPSHOT' - implementation 'com.google.auto.service:auto-service:1.1.0' -} diff --git a/procedure-collector/src/main/java/org/neo4j/gds/proc/ProcedureCollectorProcessor.java b/procedure-collector/src/main/java/org/neo4j/gds/proc/ProcedureCollectorProcessor.java deleted file mode 100644 index ead381210..000000000 --- a/procedure-collector/src/main/java/org/neo4j/gds/proc/ProcedureCollectorProcessor.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.proc; - -import com.google.auto.common.BasicAnnotationProcessor; -import com.google.auto.service.AutoService; - -import javax.annotation.processing.Processor; -import javax.annotation.processing.RoundEnvironment; -import javax.lang.model.SourceVersion; -import javax.lang.model.element.TypeElement; -import javax.tools.Diagnostic; -import java.io.BufferedOutputStream; -import java.io.IOException; -import java.io.PrintWriter; -import java.nio.charset.StandardCharsets; -import java.util.HashSet; -import java.util.List; -import java.util.Locale; -import java.util.Set; - -import static javax.tools.StandardLocation.CLASS_OUTPUT; - -/** - * An annotation processor that creates files to enable service loading for procedures, user functions and aggregations. - *

- * Only things listed in the session allow list will be written (and thus loaded). - */ -public class ProcedureCollectorProcessor extends BasicAnnotationProcessor { - - private final Set proceduresWithAnnotations = new HashSet<>(); - private final Set functionsWithAnnotations = new HashSet<>(); - private final Set aggregationsWithAnnotations = new HashSet<>(); - - @Override - public SourceVersion getSupportedSourceVersion() { - return SourceVersion.RELEASE_21; - } - - @Override - protected Iterable steps() { - return List.of(new ProcedureCollectorStep( - proceduresWithAnnotations, - functionsWithAnnotations, - aggregationsWithAnnotations - )); - } - - @Override - protected void postRound(RoundEnvironment roundEnv) { - if (roundEnv.processingOver()) { - tryWriteElements(); - } - } - - private void tryWriteElements() { - try { - writeElements(); - proceduresWithAnnotations.clear(); - } catch (IOException e) { - logError(e, - String.format( - Locale.ENGLISH, - "Failed to write procedures for service loading. First: %s", - proceduresWithAnnotations - ) - ); - } - } - - private void writeElements() throws IOException { - if (!proceduresWithAnnotations.isEmpty()) { - writeElementsOfType( ProcedureCollectorStep.PROCEDURE, proceduresWithAnnotations); - } - if (!functionsWithAnnotations.isEmpty()) { - writeElementsOfType( ProcedureCollectorStep.USER_FUNCTION, functionsWithAnnotations); - } - if (!aggregationsWithAnnotations.isEmpty()) { - writeElementsOfType( ProcedureCollectorStep.USER_AGGREGATION, aggregationsWithAnnotations); - } - } - - private void writeElementsOfType(String typeName, Iterable elements) throws IOException { - // we fake being a service so that we get properly merged in the shadow jar - var path = "META-INF/services/" + typeName; - var file = processingEnv.getFiler().createResource(CLASS_OUTPUT, "", path); - - try (var writer = new PrintWriter( - new BufferedOutputStream(file.openOutputStream()), - true, - StandardCharsets.UTF_8 - )) { - for (var element : elements) { - writer.println(element.getQualifiedName()); - } - } - } - - private void logError(Exception e, String message) { - processingEnv.getMessager().printMessage( - Diagnostic.Kind.ERROR, - String.format( - Locale.ENGLISH, - message, - e.getMessage() - ) - ); - } -} diff --git a/procedure-collector/src/main/java/org/neo4j/gds/proc/ProcedureCollectorStep.java b/procedure-collector/src/main/java/org/neo4j/gds/proc/ProcedureCollectorStep.java deleted file mode 100644 index 6ad036fee..000000000 --- a/procedure-collector/src/main/java/org/neo4j/gds/proc/ProcedureCollectorStep.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.proc; - -import com.google.auto.common.BasicAnnotationProcessor; -import com.google.auto.common.MoreElements; -import com.google.common.collect.ImmutableSetMultimap; -import org.neo4j.procedure.Procedure; -import org.neo4j.procedure.UserAggregationFunction; -import org.neo4j.procedure.UserFunction; - -import javax.lang.model.element.Element; -import javax.lang.model.element.TypeElement; -import java.util.Set; - -public class ProcedureCollectorStep implements BasicAnnotationProcessor.Step { - static final String PROCEDURE = Procedure.class.getCanonicalName(); - static final String USER_FUNCTION = UserFunction.class.getCanonicalName(); - static final String USER_AGGREGATION = UserAggregationFunction.class.getCanonicalName(); - - private final Set procedures; - private final Set functions; - private final Set aggregations; - - ProcedureCollectorStep( - Set outProcedures, - Set outFunctions, - Set outAggregations - ) { - this.procedures = outProcedures; - this.functions = outFunctions; - this.aggregations = outAggregations; - } - - @Override - public Set annotations() { - return Set.of(PROCEDURE, USER_FUNCTION, USER_AGGREGATION); - } - - @Override - public Set process(ImmutableSetMultimap elementsByAnnotation) { - for (var procedure : elementsByAnnotation.get(PROCEDURE)) { - if(isInPackage(procedure)) { - procedures.add(MoreElements.asType(procedure.getEnclosingElement())); - } - } - - for (var function : elementsByAnnotation.get(USER_FUNCTION)) { - if(isInPackage(function)) { - functions.add(MoreElements.asType(function.getEnclosingElement())); - } - } - - for (var aggregation : elementsByAnnotation.get(USER_AGGREGATION)) { - if(isInPackage(aggregation)) { - aggregations.add(MoreElements.asType(aggregation.getEnclosingElement())); - } - } - - return Set.of(); - } - - private boolean isInPackage(Element element) { - return MoreElements.getPackage(element).getQualifiedName().toString().startsWith("apoc."); - } -} diff --git a/processor/src/main/java/apoc/processor/ApocProcessor.java b/processor/src/main/java/apoc/processor/ApocProcessor.java index 871582a7d..a22bf7e5c 100644 --- a/processor/src/main/java/apoc/processor/ApocProcessor.java +++ b/processor/src/main/java/apoc/processor/ApocProcessor.java @@ -19,9 +19,7 @@ package apoc.processor; import java.util.ArrayList; -import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Set; import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.ProcessingEnvironment; @@ -29,7 +27,6 @@ import javax.annotation.processing.SupportedSourceVersion; import javax.lang.model.SourceVersion; import javax.lang.model.element.TypeElement; -import org.neo4j.kernel.api.QueryLanguage; import org.neo4j.procedure.Procedure; import org.neo4j.procedure.UserAggregationFunction; import org.neo4j.procedure.UserFunction; @@ -37,13 +34,12 @@ @SupportedSourceVersion(SourceVersion.RELEASE_21) public class ApocProcessor extends AbstractProcessor { - private List>> procedureSignatures; - private List>> userFunctionSignatures; - private Set procedureClassNames; + private List signatures; private SignatureVisitor signatureVisitor; private ExtensionClassWriter extensionClassWriter; + private ProcedureServiceWriter procedureServiceWriter; @Override public Set getSupportedAnnotationTypes() { @@ -52,65 +48,24 @@ public Set getSupportedAnnotationTypes() { @Override public synchronized void init(ProcessingEnvironment processingEnv) { - procedureSignatures = new ArrayList<>(); - userFunctionSignatures = new ArrayList<>(); - procedureClassNames = new HashSet<>(); + signatures = new ArrayList<>(); extensionClassWriter = new ExtensionClassWriter(processingEnv.getFiler()); + procedureServiceWriter = new ProcedureServiceWriter(processingEnv.getFiler()); signatureVisitor = new SignatureVisitor(processingEnv.getElementUtils(), processingEnv.getMessager()); } @Override public boolean process(Set annotations, RoundEnvironment roundEnv) { - annotations.forEach(annotation -> extractSignature(annotation, roundEnv)); - annotations.forEach(annotation -> procedureClassNames.add(annotation.getQualifiedName().toString())); - - List procedureSignaturesCypher5 = new ArrayList<>(); - List userFunctionSignaturesCypher5 = new ArrayList<>(); - List procedureSignaturesCypher25 = new ArrayList<>(); - List userFunctionSignaturesCypher25 = new ArrayList<>(); - - separateKeysByQueryLanguage(procedureSignatures, procedureSignaturesCypher5, procedureSignaturesCypher25); - separateKeysByQueryLanguage( - userFunctionSignatures, userFunctionSignaturesCypher5, userFunctionSignaturesCypher25); + for (final var annotation : annotations) { + for (final var method : roundEnv.getElementsAnnotatedWith(annotation)) { + signatures.add(signatureVisitor.visit(method)); + } + } if (roundEnv.processingOver()) { - extensionClassWriter.write( - procedureSignaturesCypher5, - userFunctionSignaturesCypher5, - procedureSignaturesCypher25, - userFunctionSignaturesCypher25); - + extensionClassWriter.write(signatures); + procedureServiceWriter.write(signatures); } return false; } - - private void extractSignature(TypeElement annotation, RoundEnvironment roundEnv) { - List>> signatures = accumulator(annotation); - roundEnv.getElementsAnnotatedWith(annotation) - .forEach(annotatedElement -> signatures.add(signatureVisitor.visit(annotatedElement))); - } - - private List>> accumulator(TypeElement annotation) { - if (annotation.getQualifiedName().contentEquals(Procedure.class.getName())) { - return procedureSignatures; - } - return userFunctionSignatures; - } - - public static void separateKeysByQueryLanguage( - List>> list, List c5Keys, List c6Keys) { - for (Map> map : list) { - for (Map.Entry> entry : map.entrySet()) { - String key = entry.getKey(); - List values = entry.getValue(); - - if (values.contains(QueryLanguage.CYPHER_5)) { - c5Keys.add(key); - } - if (values.contains(QueryLanguage.CYPHER_25)) { - c6Keys.add(key); - } - } - } - } } diff --git a/processor/src/main/java/apoc/processor/ExtensionClassWriter.java b/processor/src/main/java/apoc/processor/ExtensionClassWriter.java index 191f417c9..ee6ac600c 100644 --- a/processor/src/main/java/apoc/processor/ExtensionClassWriter.java +++ b/processor/src/main/java/apoc/processor/ExtensionClassWriter.java @@ -25,6 +25,7 @@ import com.squareup.javapoet.ParameterizedTypeName; import com.squareup.javapoet.TypeSpec; import java.io.IOException; +import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; import javax.annotation.processing.Filer; @@ -32,6 +33,7 @@ import javax.lang.model.element.Modifier; import javax.tools.FileObject; import javax.tools.StandardLocation; +import org.neo4j.kernel.api.QueryLanguage; public class ExtensionClassWriter { @@ -41,22 +43,10 @@ public ExtensionClassWriter(Filer filer) { this.filer = filer; } - public void write( - List procedureSignaturesCypher5, - List userFunctionSignaturesCypher5, - List procedureSignaturesCypher25, - List userFunctionSignaturesCypher25) { - + public void write(List signatures) { try { - String suffix = isExtendedProject() ? "Extended" : ""; - final TypeSpec typeSpec = defineClass( - procedureSignaturesCypher5, - userFunctionSignaturesCypher5, - procedureSignaturesCypher25, - userFunctionSignaturesCypher25, - suffix); - - JavaFile.builder("apoc", typeSpec).build().writeTo(filer); + final var suffix = isExtendedProject() ? "Extended" : ""; + JavaFile.builder("apoc", defineClass(signatures, suffix)).build().writeTo(filer); } catch (IOException e) { throw new RuntimeException(e); } @@ -71,33 +61,36 @@ private boolean isExtendedProject() throws IOException { return projectPath.contains("extended/build/generated"); } - private TypeSpec defineClass( - List procedureSignaturesCypher5, - List userFunctionSignaturesCypher5, - List procedureSignaturesCypher25, - List userFunctionSignaturesCypher25, - String suffix) { + private TypeSpec defineClass(List signatures, String suffix) { return TypeSpec.classBuilder("ApocSignatures" + suffix) .addModifiers(Modifier.PUBLIC) - .addField(signatureListField("PROCEDURES_CYPHER_5", procedureSignaturesCypher5)) - .addField(signatureListField("FUNCTIONS_CYPHER_5", userFunctionSignaturesCypher5)) - .addField(signatureListField("PROCEDURES_CYPHER_25", procedureSignaturesCypher25)) - .addField(signatureListField("FUNCTIONS_CYPHER_25", userFunctionSignaturesCypher25)) + .addField(signatureListField("PROCEDURES_CYPHER_5", names(signatures, true, QueryLanguage.CYPHER_5))) + .addField(signatureListField("FUNCTIONS_CYPHER_5", names(signatures, false, QueryLanguage.CYPHER_5))) + .addField(signatureListField("PROCEDURES_CYPHER_25", names(signatures, true, QueryLanguage.CYPHER_25))) + .addField(signatureListField("FUNCTIONS_CYPHER_25", names(signatures, false, QueryLanguage.CYPHER_25))) .build(); } - private FieldSpec signatureListField(String fieldName, List signatures) { + private String[] names(List signatures, boolean procedure, QueryLanguage lang) { + return signatures.stream() + .filter(s -> s.isProcedure() == procedure) + .filter(s -> s.scope().contains(lang)) + .map(Signature::methodName) + .toArray(String[]::new); + } + + private FieldSpec signatureListField(String fieldName, String[] signatures) { ParameterizedTypeName fieldType = ParameterizedTypeName.get(ClassName.get(List.class), ClassName.get(String.class)); return FieldSpec.builder(fieldType, fieldName, Modifier.PUBLIC, Modifier.STATIC, Modifier.FINAL) .initializer(CodeBlock.builder() - .addStatement(String.format("List.of(%s)", placeholders(signatures)), signatures.toArray()) + .addStatement(String.format("List.of(%s)", placeholders(signatures)), (Object[]) signatures) .build()) .build(); } - private String placeholders(List signatures) { + private String placeholders(String[] signatures) { // FIXME: find a way to manage the indentation automatically - return signatures.stream().map((ignored) -> "$S").collect(Collectors.joining(",\n\t\t")); + return Arrays.stream(signatures).map((ignored) -> "$S").collect(Collectors.joining(",\n\t\t")); } } diff --git a/processor/src/main/java/apoc/processor/ProcedureServiceWriter.java b/processor/src/main/java/apoc/processor/ProcedureServiceWriter.java index dd5b4a6f8..71f5208d0 100644 --- a/processor/src/main/java/apoc/processor/ProcedureServiceWriter.java +++ b/processor/src/main/java/apoc/processor/ProcedureServiceWriter.java @@ -1,16 +1,42 @@ package apoc.processor; +import static javax.tools.StandardLocation.CLASS_OUTPUT; + +import java.io.BufferedOutputStream; +import java.io.IOException; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; import java.util.List; import javax.annotation.processing.Filer; +import org.neo4j.procedure.Procedure; public class ProcedureServiceWriter { private final Filer filer; - public ProcedureServiceWriter( Filer filer ) { + public ProcedureServiceWriter(Filer filer) { this.filer = filer; } - public void write( List procedureNames) { + public void write(List signatures) { + final var classNames = signatures.stream() + .map(Signature::className) + .distinct() + .sorted() + .toList(); + try { + writeProcedureClasses(classNames); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private void writeProcedureClasses(Iterable classNames) throws IOException { + final var path = "META-INF/services/" + Procedure.class.getCanonicalName(); + var file = filer.createResource(CLASS_OUTPUT, "", path); + try (var writer = + new PrintWriter(new BufferedOutputStream(file.openOutputStream()), true, StandardCharsets.UTF_8)) { + for (final var name : classNames) writer.println(name); + } } } diff --git a/processor/src/main/java/apoc/processor/SignatureVisitor.java b/processor/src/main/java/apoc/processor/SignatureVisitor.java index f8373f231..392cbd8e8 100644 --- a/processor/src/main/java/apoc/processor/SignatureVisitor.java +++ b/processor/src/main/java/apoc/processor/SignatureVisitor.java @@ -19,12 +19,13 @@ package apoc.processor; import java.util.Arrays; -import java.util.List; -import java.util.Map; +import java.util.EnumSet; import java.util.Optional; +import java.util.Set; import javax.annotation.processing.Messager; import javax.lang.model.element.Element; import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.TypeElement; import javax.lang.model.util.Elements; import javax.lang.model.util.SimpleElementVisitor9; import javax.tools.Diagnostic; @@ -34,7 +35,7 @@ import org.neo4j.procedure.UserAggregationFunction; import org.neo4j.procedure.UserFunction; -public class SignatureVisitor extends SimpleElementVisitor9>, Void> { +public class SignatureVisitor extends SimpleElementVisitor9 { private final Elements elementUtils; @@ -47,59 +48,52 @@ public SignatureVisitor(Elements elementUtils, Messager messager) { } @Override - public Map> visitExecutable(ExecutableElement method, Void unused) { - return Map.of( - getAnnotationName(method) - .orElse(String.format("%s.%s", elementUtils.getPackageOf(method), method.getSimpleName())), - getCypherScopes(method)); + public Signature visitExecutable(ExecutableElement method, Void unused) { + final var isProcedure = method.getAnnotation(Procedure.class) != null; + final var className = + ((TypeElement) method.getEnclosingElement()).getQualifiedName().toString(); + final var methodName = getProcedureName(method) + .or(() -> getUserFunctionName(method)) + .or(() -> getUserAggregationFunctionName(method)) + .orElse("%s.%s".formatted(elementUtils.getPackageOf(method), method.getSimpleName())); + return new Signature(methodName, isProcedure, cypherScopes(method), className); } @Override - public Map> visitUnknown(Element e, Void unused) { + public Signature visitUnknown(Element e, Void unused) { messager.printMessage(Diagnostic.Kind.ERROR, "unexpected ....."); return super.visitUnknown(e, unused); } - private Optional getAnnotationName(ExecutableElement method) { - return getProcedureName(method) - .or(() -> getUserFunctionName(method)) - .or(() -> getUserAggregationFunctionName(method)); - } - - private List getCypherScopes(ExecutableElement method) { - return Optional.ofNullable(method.getAnnotation(QueryLanguageScope.class)) - .map(annotation -> { - QueryLanguage[] scope = annotation.scope(); - return scope.length > 0 - ? Arrays.asList(scope) - : List.of(QueryLanguage.CYPHER_5, QueryLanguage.CYPHER_25); - }) - .orElse(List.of(QueryLanguage.CYPHER_5, QueryLanguage.CYPHER_25)); + private Set cypherScopes(ExecutableElement method) { + final var annotation = method.getAnnotation(QueryLanguageScope.class); + if (annotation != null && annotation.scope().length > 0) { + return EnumSet.copyOf(Arrays.asList(annotation.scope())); + } else { + return QueryLanguage.ALL; + } } private Optional getProcedureName(ExecutableElement method) { return Optional.ofNullable(method.getAnnotation(Procedure.class)) - .map((annotation) -> pickFirstNonBlank(annotation.name(), annotation.value())) - .flatMap(this::blankToEmpty); + .flatMap((annotation) -> pickFirstNonBlank(annotation.name(), annotation.value())); } private Optional getUserFunctionName(ExecutableElement method) { return Optional.ofNullable(method.getAnnotation(UserFunction.class)) - .map((annotation) -> pickFirstNonBlank(annotation.name(), annotation.value())) - .flatMap(this::blankToEmpty); + .flatMap((annotation) -> pickFirstNonBlank(annotation.name(), annotation.value())); } private Optional getUserAggregationFunctionName(ExecutableElement method) { return Optional.ofNullable(method.getAnnotation(UserAggregationFunction.class)) - .map((annotation) -> pickFirstNonBlank(annotation.name(), annotation.value())) - .flatMap(this::blankToEmpty); + .flatMap((annotation) -> pickFirstNonBlank(annotation.name(), annotation.value())); } - private Optional blankToEmpty(String s) { - return s.isBlank() ? Optional.empty() : Optional.of(s); - } - - private String pickFirstNonBlank(String name, String value) { - return name.isBlank() ? value : name; + private Optional pickFirstNonBlank(String name, String value) { + if (!name.isBlank()) return Optional.of(name); + else if (!value.isBlank()) return Optional.of(value); + else return Optional.empty(); } } + +record Signature(String methodName, boolean isProcedure, Set scope, String className) {}