diff --git a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/Handler.java b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/Handler.java index a755d7b1..67bbf5df 100644 --- a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/Handler.java +++ b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/model/Handler.java @@ -9,18 +9,25 @@ package dev.restate.sdk.gen.model; import java.util.Objects; +import org.jspecify.annotations.Nullable; public class Handler { private final CharSequence name; private final HandlerType handlerType; + private final @Nullable String inputAccept; private final PayloadType inputType; private final PayloadType outputType; public Handler( - CharSequence name, HandlerType handlerType, PayloadType inputType, PayloadType outputType) { + CharSequence name, + HandlerType handlerType, + @Nullable String inputAccept, + PayloadType inputType, + PayloadType outputType) { this.name = name; this.handlerType = handlerType; + this.inputAccept = inputAccept; this.inputType = inputType; this.outputType = outputType; } @@ -33,6 +40,10 @@ public HandlerType getHandlerType() { return handlerType; } + public String getInputAccept() { + return inputAccept; + } + public PayloadType getInputType() { return inputType; } @@ -48,6 +59,7 @@ public static Builder builder() { public static class Builder { private CharSequence name; private HandlerType handlerType; + private String inputAccept; private PayloadType inputType; private PayloadType outputType; @@ -61,6 +73,11 @@ public Builder withHandlerType(HandlerType handlerType) { return this; } + public Builder withInputAccept(String inputAccept) { + this.inputAccept = inputAccept; + return this; + } + public Builder withInputType(PayloadType inputType) { this.inputType = inputType; return this; @@ -96,7 +113,11 @@ public Handler validateAndBuild() { } return new Handler( - Objects.requireNonNull(name), Objects.requireNonNull(handlerType), inputType, outputType); + Objects.requireNonNull(name), + Objects.requireNonNull(handlerType), + inputAccept, + inputType, + outputType); } } } diff --git a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/template/HandlebarsTemplateEngine.java b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/template/HandlebarsTemplateEngine.java index 687f68b5..eddf5f10 100644 --- a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/template/HandlebarsTemplateEngine.java +++ b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/template/HandlebarsTemplateEngine.java @@ -107,7 +107,12 @@ private ServiceTemplateModel( this.handlers = inner.getMethods().stream() - .map(h -> new HandlerTemplateModel(h, handlerNamesToPrefix)) + .map( + h -> + new HandlerTemplateModel( + h, + this.generatedClassSimpleNamePrefix + "Definitions.Serde", + handlerNamesToPrefix)) .collect(Collectors.toList()); } } @@ -126,14 +131,18 @@ static class HandlerTemplateModel { public final String inputSerdeDecl; public final String boxedInputFqcn; public final String inputSerdeFieldName; + public final String inputAcceptContentType; + public final String inputSerdeRef; public final boolean outputEmpty; public final String outputFqcn; public final String outputSerdeDecl; public final String boxedOutputFqcn; public final String outputSerdeFieldName; + public final String outputSerdeRef; - private HandlerTemplateModel(Handler inner, Set handlerNamesToPrefix) { + private HandlerTemplateModel( + Handler inner, String definitionsClass, Set handlerNamesToPrefix) { this.name = inner.getName().toString(); this.methodName = (handlerNamesToPrefix.contains(this.name) ? "_" : "") + this.name; this.handlerType = inner.getHandlerType().toString(); @@ -146,13 +155,16 @@ private HandlerTemplateModel(Handler inner, Set handlerNamesToPrefix) { this.inputFqcn = inner.getInputType().getName(); this.inputSerdeDecl = inner.getInputType().getSerdeDecl(); this.boxedInputFqcn = inner.getInputType().getBoxed(); - this.inputSerdeFieldName = "SERDE_" + this.name.toUpperCase() + "_INPUT"; + this.inputSerdeFieldName = this.name.toUpperCase() + "_INPUT"; + this.inputAcceptContentType = inner.getInputAccept(); + this.inputSerdeRef = definitionsClass + "." + this.inputSerdeFieldName; this.outputEmpty = inner.getOutputType().isEmpty(); this.outputFqcn = inner.getOutputType().getName(); this.outputSerdeDecl = inner.getOutputType().getSerdeDecl(); this.boxedOutputFqcn = inner.getOutputType().getBoxed(); - this.outputSerdeFieldName = "SERDE_" + this.name.toUpperCase() + "_OUTPUT"; + this.outputSerdeFieldName = this.name.toUpperCase() + "_OUTPUT"; + this.outputSerdeRef = definitionsClass + "." + this.outputSerdeFieldName; } } } diff --git a/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/utils/AnnotationUtils.java b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/utils/AnnotationUtils.java new file mode 100644 index 00000000..b48f0ca9 --- /dev/null +++ b/sdk-api-gen-common/src/main/java/dev/restate/sdk/gen/utils/AnnotationUtils.java @@ -0,0 +1,23 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.gen.utils; + +import java.lang.annotation.Annotation; +import java.util.Objects; + +public class AnnotationUtils { + public static Object getAnnotationDefaultValue( + Class annotation, String name) { + try { + return Objects.requireNonNull(annotation.getMethod(name).getDefaultValue()); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } +} diff --git a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ElementConverter.java b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ElementConverter.java index 9b37f3b7..7a86734f 100644 --- a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ElementConverter.java +++ b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ElementConverter.java @@ -11,32 +11,31 @@ import dev.restate.sdk.Context; import dev.restate.sdk.ObjectContext; import dev.restate.sdk.SharedObjectContext; -import dev.restate.sdk.annotation.Exclusive; -import dev.restate.sdk.annotation.Shared; -import dev.restate.sdk.annotation.Workflow; +import dev.restate.sdk.annotation.*; import dev.restate.sdk.common.ServiceType; import dev.restate.sdk.gen.model.*; +import dev.restate.sdk.gen.model.Handler; +import dev.restate.sdk.gen.model.Service; +import dev.restate.sdk.gen.utils.AnnotationUtils; import dev.restate.sdk.workflow.WorkflowContext; import dev.restate.sdk.workflow.WorkflowSharedContext; import java.util.List; -import java.util.Objects; import java.util.regex.Pattern; import java.util.stream.Collectors; import javax.annotation.processing.Messager; -import javax.lang.model.element.ElementKind; -import javax.lang.model.element.ExecutableElement; -import javax.lang.model.element.Modifier; -import javax.lang.model.element.TypeElement; +import javax.lang.model.element.*; import javax.lang.model.type.TypeKind; import javax.lang.model.type.TypeMirror; import javax.lang.model.util.Elements; import javax.lang.model.util.Types; import javax.tools.Diagnostic; +import org.jspecify.annotations.Nullable; public class ElementConverter { private static final PayloadType EMPTY_PAYLOAD = new PayloadType(true, "", "Void", "dev.restate.sdk.common.CoreSerdes.VOID"); + private static final String RAW_SERDE = "dev.restate.sdk.common.CoreSerdes.RAW"; private final Messager messager; private final Elements elements; @@ -197,18 +196,12 @@ private Handler fromExecutableElement(ServiceType serviceType, ExecutableElement return new Handler.Builder() .withName(element.getSimpleName()) .withHandlerType(handlerType) - .withInputType( - element.getParameters().size() > 1 - ? payloadFromType(element.getParameters().get(1).asType()) - : EMPTY_PAYLOAD) - .withOutputType( - !element.getReturnType().getKind().equals(TypeKind.VOID) - ? payloadFromType(element.getReturnType()) - : EMPTY_PAYLOAD) + .withInputAccept(inputAcceptFromParameterList(element.getParameters())) + .withInputType(inputPayloadFromParameterList(element.getParameters())) + .withOutputType(outputPayloadFromExecutableElement(element)) .validateAndBuild(); } catch (Exception e) { - messager.printMessage( - Diagnostic.Kind.ERROR, "Error when building handler: " + e.getMessage(), element); + messager.printMessage(Diagnostic.Kind.ERROR, "Error when building handler: " + e, element); return null; } } @@ -280,13 +273,89 @@ private void validateFirstParameterType(Class clazz, ExecutableElement elemen } } - private PayloadType payloadFromType(TypeMirror typeMirror) { - Objects.requireNonNull(typeMirror); - return new PayloadType( - false, typeMirror.toString(), boxedType(typeMirror), serdeDecl(typeMirror)); + private String inputAcceptFromParameterList(List element) { + if (element.size() <= 1) { + return null; + } + + Accept accept = element.get(1).getAnnotation(Accept.class); + if (accept == null) { + return null; + } + return accept.value(); + } + + private PayloadType inputPayloadFromParameterList(List element) { + if (element.size() <= 1) { + return EMPTY_PAYLOAD; + } + + Element parameterElement = element.get(1); + return payloadFromTypeMirrorAndAnnotations( + parameterElement.asType(), + parameterElement.getAnnotation(Json.class), + parameterElement.getAnnotation(Raw.class), + parameterElement); + } + + private PayloadType outputPayloadFromExecutableElement(ExecutableElement element) { + return payloadFromTypeMirrorAndAnnotations( + element.getReturnType(), + element.getAnnotation(Json.class), + element.getAnnotation(Raw.class), + element); + } + + private PayloadType payloadFromTypeMirrorAndAnnotations( + TypeMirror ty, @Nullable Json jsonAnnotation, @Nullable Raw rawAnnotation, Element element) { + if (ty.getKind().equals(TypeKind.VOID)) { + if (rawAnnotation != null || jsonAnnotation != null) { + messager.printMessage( + Diagnostic.Kind.ERROR, "Unexpected annotation for void type.", element); + } + return EMPTY_PAYLOAD; + } + // Some validation + if (rawAnnotation != null && jsonAnnotation != null) { + messager.printMessage( + Diagnostic.Kind.ERROR, + "A parameter cannot be annotated both with @Raw and @Json.", + element); + } + if (rawAnnotation != null + && !types.isSameType(ty, types.getArrayType(types.getPrimitiveType(TypeKind.BYTE)))) { + messager.printMessage( + Diagnostic.Kind.ERROR, + "A parameter annotated with @Raw MUST be of type byte[], was " + ty, + element); + } + + String serdeDecl = rawAnnotation != null ? RAW_SERDE : jsonSerdeDecl(ty); + if (rawAnnotation != null + && !rawAnnotation + .contentType() + .equals(AnnotationUtils.getAnnotationDefaultValue(Raw.class, "contentType"))) { + serdeDecl = contentTypeDecoratedSerdeDecl(serdeDecl, rawAnnotation.contentType()); + } + if (jsonAnnotation != null + && !jsonAnnotation + .contentType() + .equals(AnnotationUtils.getAnnotationDefaultValue(Json.class, "contentType"))) { + serdeDecl = contentTypeDecoratedSerdeDecl(serdeDecl, jsonAnnotation.contentType()); + } + + return new PayloadType(false, ty.toString(), boxedType(ty), serdeDecl); + } + + private static String contentTypeDecoratedSerdeDecl(String serdeDecl, String contentType) { + return "dev.restate.sdk.common.Serde.withContentType(\"" + + contentType + + "\", " + + serdeDecl + + ")"; } - private static String serdeDecl(TypeMirror ty) { + private static String jsonSerdeDecl(TypeMirror ty) { switch (ty.getKind()) { case BOOLEAN: return "dev.restate.sdk.common.CoreSerdes.JSON_BOOLEAN"; diff --git a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java index f8893c76..e6c1a71d 100644 --- a/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java +++ b/sdk-api-gen/src/main/java/dev/restate/sdk/gen/ServiceProcessor.java @@ -35,6 +35,7 @@ @SupportedSourceVersion(SourceVersion.RELEASE_11) public class ServiceProcessor extends AbstractProcessor { + private HandlebarsTemplateEngine definitionsCodegen; private HandlebarsTemplateEngine bindableServiceFactoryCodegen; private HandlebarsTemplateEngine bindableServiceCodegen; private HandlebarsTemplateEngine clientCodegen; @@ -47,6 +48,18 @@ public synchronized void init(ProcessingEnvironment processingEnv) { FilerTemplateLoader filerTemplateLoader = new FilerTemplateLoader(processingEnv.getFiler()); + this.definitionsCodegen = + new HandlebarsTemplateEngine( + "Definitions", + filerTemplateLoader, + Map.of( + ServiceType.WORKFLOW, + "templates/Definitions.hbs", + ServiceType.SERVICE, + "templates/Definitions.hbs", + ServiceType.VIRTUAL_OBJECT, + "templates/Definitions.hbs"), + RESERVED_METHOD_NAMES); this.bindableServiceFactoryCodegen = new HandlebarsTemplateEngine( "BindableServiceFactory", @@ -108,6 +121,7 @@ public boolean process(Set annotations, RoundEnvironment try { ThrowingFunction fileCreator = name -> filer.createSourceFile(name, e.getKey()).openWriter(); + this.definitionsCodegen.generate(fileCreator, e.getValue()); this.bindableServiceFactoryCodegen.generate(fileCreator, e.getValue()); this.bindableServiceCodegen.generate(fileCreator, e.getValue()); this.clientCodegen.generate(fileCreator, e.getValue()); diff --git a/sdk-api-gen/src/main/resources/templates/BindableService.hbs b/sdk-api-gen/src/main/resources/templates/BindableService.hbs index 132adc65..cde5c7df 100644 --- a/sdk-api-gen/src/main/resources/templates/BindableService.hbs +++ b/sdk-api-gen/src/main/resources/templates/BindableService.hbs @@ -2,38 +2,41 @@ public class {{generatedClassSimpleName}} implements dev.restate.sdk.common.BindableService { - public static final String SERVICE_NAME = "{{serviceName}}"; - - private final dev.restate.sdk.Service service; + private final dev.restate.sdk.common.syscalls.ServiceDefinition service; + private final dev.restate.sdk.Service.Options options; public {{generatedClassSimpleName}}({{originalClassFqcn}} bindableService) { this(bindableService, dev.restate.sdk.Service.Options.DEFAULT); } public {{generatedClassSimpleName}}({{originalClassFqcn}} bindableService, dev.restate.sdk.Service.Options options) { - this.service = dev.restate.sdk.Service.{{#if isObject}}virtualObject{{else}}service{{/if}}(SERVICE_NAME) + this.service = dev.restate.sdk.common.syscalls.ServiceDefinition.of( + {{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, + {{#if isObject}}dev.restate.sdk.common.ServiceType.VIRTUAL_OBJECT{{else}}dev.restate.sdk.common.ServiceType.SERVICE{{/if}}, + java.util.List.of( {{#handlers}} - .{{#if isShared}}withShared{{else if isExclusive}}withExclusive{{else}}with{{/if}}( - dev.restate.sdk.Service.HandlerSignature.of("{{name}}", {{{inputSerdeDecl}}}, {{{outputSerdeDecl}}}), - (ctx, req) -> { - {{#if outputEmpty}} - {{#if inputEmpty}}bindableService.{{name}}(ctx){{else}}bindableService.{{name}}(ctx, req){{/if}}; - return null; - {{else}} - return {{#if inputEmpty}}bindableService.{{name}}(ctx){{else}}bindableService.{{name}}(ctx, req){{/if}}; - {{/if}} - }) + dev.restate.sdk.common.syscalls.HandlerDefinition.of( + dev.restate.sdk.common.syscalls.HandlerSpecification.of( + "{{name}}", + {{#if isExclusive}}dev.restate.sdk.common.HandlerType.EXCLUSIVE{{else}}dev.restate.sdk.common.HandlerType.SHARED{{/if}}, + {{inputSerdeRef}}, + {{outputSerdeRef}} + ){{#if inputAcceptContentType}}.withAcceptContentType("{{inputAcceptContentType}}"){{/if}}, + dev.restate.sdk.Service.Handler.of(bindableService::{{name}}) + ){{#unless @last}},{{/unless}} {{/handlers}} - .build(options); + ) + ); + this.options = options; } @Override public dev.restate.sdk.Service.Options options() { - return this.service.options(); + return this.options; } @Override public java.util.List> definitions() { - return this.service.definitions(); + return java.util.List.of(this.service); } } \ No newline at end of file diff --git a/sdk-api-gen/src/main/resources/templates/Client.hbs b/sdk-api-gen/src/main/resources/templates/Client.hbs index f58b3c65..ac3c80fc 100644 --- a/sdk-api-gen/src/main/resources/templates/Client.hbs +++ b/sdk-api-gen/src/main/resources/templates/Client.hbs @@ -10,13 +10,6 @@ import java.time.Duration; public class {{generatedClassSimpleName}} { - public static final String SERVICE_NAME = "{{serviceName}}"; - - {{#handlers}} - private static final Serde<{{{boxedInputFqcn}}}> {{inputSerdeFieldName}} = {{{inputSerdeDecl}}}; - private static final Serde<{{{boxedOutputFqcn}}}> {{outputSerdeFieldName}} = {{{outputSerdeDecl}}}; - {{/handlers}} - public static ContextClient fromContext(Context ctx{{#isObject}}, String key{{/isObject}}) { return new ContextClient(ctx{{#isObject}}, key{{/isObject}}); } @@ -42,9 +35,9 @@ public class {{generatedClassSimpleName}} { {{#handlers}} public Awaitable<{{{boxedOutputFqcn}}}> {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { return this.ctx.call( - {{#if isObject}}Target.virtualObject(SERVICE_NAME, this.key, "{{name}}"){{else}}Target.service(SERVICE_NAME, "{{name}}"){{/if}}, - {{inputSerdeFieldName}}, - {{outputSerdeFieldName}}, + {{#if isObject}}Target.virtualObject({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, this.key, "{{name}}"){{else}}Target.service({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, "{{name}}"){{/if}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, {{#if inputEmpty}}null{{else}}req{{/if}}); }{{/handlers}} @@ -67,8 +60,8 @@ public class {{generatedClassSimpleName}} { {{#handlers}} public void {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req{{/inputEmpty}}) { ContextClient.this.ctx.send( - {{#if isObject}}Target.virtualObject(SERVICE_NAME, ContextClient.this.key, "{{name}}"){{else}}Target.service(SERVICE_NAME, "{{name}}"){{/if}}, - {{inputSerdeFieldName}}, + {{#if isObject}}Target.virtualObject({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, ContextClient.this.key, "{{name}}"){{else}}Target.service({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, "{{name}}"){{/if}}, + {{inputSerdeRef}}, {{#if inputEmpty}}null{{else}}req{{/if}}, delay); }{{/handlers}} @@ -94,9 +87,9 @@ public class {{generatedClassSimpleName}} { public {{#if outputEmpty}}void{{else}}{{{outputFqcn}}}{{/if}} {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) { {{^outputEmpty}}return {{/outputEmpty}}this.ingressClient.call( - {{#if isObject}}Target.virtualObject(SERVICE_NAME, this.key, "{{name}}"){{else}}Target.service(SERVICE_NAME, "{{name}}"){{/if}}, - {{inputSerdeFieldName}}, - {{outputSerdeFieldName}}, + {{#if isObject}}Target.virtualObject({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, this.key, "{{name}}"){{else}}Target.service({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, "{{name}}"){{/if}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, {{#if inputEmpty}}null{{else}}req{{/if}}, requestOptions); } @@ -109,9 +102,9 @@ public class {{generatedClassSimpleName}} { public {{#if outputEmpty}}java.util.concurrent.CompletableFuture{{else}}java.util.concurrent.CompletableFuture<{{{boxedOutputFqcn}}}>{{/if}} {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) { return this.ingressClient.callAsync( - {{#if isObject}}Target.virtualObject(SERVICE_NAME, this.key, "{{name}}"){{else}}Target.service(SERVICE_NAME, "{{name}}"){{/if}}, - {{inputSerdeFieldName}}, - {{outputSerdeFieldName}}, + {{#if isObject}}Target.virtualObject({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, this.key, "{{name}}"){{else}}Target.service({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, "{{name}}"){{/if}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, {{#if inputEmpty}}null{{else}}req{{/if}}, requestOptions); }{{/handlers}} @@ -141,8 +134,8 @@ public class {{generatedClassSimpleName}} { public String {{methodName}}({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) { return IngressClient.this.ingressClient.send( - {{#if isObject}}Target.virtualObject(SERVICE_NAME, IngressClient.this.key, "{{name}}"){{else}}Target.service(SERVICE_NAME, "{{name}}"){{/if}}, - {{inputSerdeFieldName}}, + {{#if isObject}}Target.virtualObject({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, IngressClient.this.key, "{{name}}"){{else}}Target.service({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, "{{name}}"){{/if}}, + {{inputSerdeRef}}, {{#if inputEmpty}}null{{else}}req{{/if}}, this.delay, requestOptions); @@ -156,8 +149,8 @@ public class {{generatedClassSimpleName}} { public java.util.concurrent.CompletableFuture {{methodName}}Async({{^inputEmpty}}{{{inputFqcn}}} req, {{/inputEmpty}}dev.restate.sdk.client.RequestOptions requestOptions) { return IngressClient.this.ingressClient.sendAsync( - {{#if isObject}}Target.virtualObject(SERVICE_NAME, IngressClient.this.key, "{{name}}"){{else}}Target.service(SERVICE_NAME, "{{name}}"){{/if}}, - {{inputSerdeFieldName}}, + {{#if isObject}}Target.virtualObject({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, IngressClient.this.key, "{{name}}"){{else}}Target.service({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, "{{name}}"){{/if}}, + {{inputSerdeRef}}, {{#if inputEmpty}}null{{else}}req{{/if}}, this.delay, requestOptions); diff --git a/sdk-api-gen/src/main/resources/templates/Definitions.hbs b/sdk-api-gen/src/main/resources/templates/Definitions.hbs new file mode 100644 index 00000000..bb3932b4 --- /dev/null +++ b/sdk-api-gen/src/main/resources/templates/Definitions.hbs @@ -0,0 +1,18 @@ +{{#if originalClassPkg}}package {{originalClassPkg}};{{/if}} + +public final class {{generatedClassSimpleName}} { + + public static final String SERVICE_NAME = "{{serviceName}}"; + + private {{generatedClassSimpleName}}() {} + + public final static class Serde { + {{#handlers}} + public static final dev.restate.sdk.common.Serde<{{{boxedInputFqcn}}}> {{inputSerdeFieldName}} = {{{inputSerdeDecl}}}; + public static final dev.restate.sdk.common.Serde<{{{boxedOutputFqcn}}}> {{outputSerdeFieldName}} = {{{outputSerdeDecl}}}; + {{/handlers}} + + private Serde() {} + } + +} \ No newline at end of file diff --git a/sdk-api-gen/src/main/resources/templates/workflow/BindableService.hbs b/sdk-api-gen/src/main/resources/templates/workflow/BindableService.hbs index 0b113786..64f277d2 100644 --- a/sdk-api-gen/src/main/resources/templates/workflow/BindableService.hbs +++ b/sdk-api-gen/src/main/resources/templates/workflow/BindableService.hbs @@ -14,7 +14,8 @@ public class {{generatedClassSimpleName}} implements dev.restate.sdk.common.Bind this.service = dev.restate.sdk.workflow.WorkflowBuilder.named( SERVICE_NAME, {{#handlers}}{{#if isWorkflow}} - dev.restate.sdk.Service.HandlerSignature.of("{{name}}", {{{inputSerdeDecl}}}, {{{outputSerdeDecl}}}), + {{{inputSerdeDecl}}}, + {{{outputSerdeDecl}}}, (ctx, req) -> { {{#if outputEmpty}} {{#if inputEmpty}}bindableService.{{name}}(ctx){{else}}bindableService.{{name}}(ctx, req){{/if}}; @@ -26,7 +27,9 @@ public class {{generatedClassSimpleName}} implements dev.restate.sdk.common.Bind {{/if}}{{/handlers}}) {{#handlers}}{{#if isShared}} .with{{capitalizeFirst (lower handlerType)}}( - dev.restate.sdk.Service.HandlerSignature.of("{{name}}", {{{inputSerdeDecl}}}, {{{outputSerdeDecl}}}), + "{{name}}", + {{{inputSerdeDecl}}}, + {{{outputSerdeDecl}}}, (ctx, req) -> { {{#if outputEmpty}} {{#if inputEmpty}}bindableService.{{name}}(ctx){{else}}bindableService.{{name}}(ctx, req){{/if}}; diff --git a/sdk-api-gen/src/test/java/dev/restate/sdk/CodegenTest.java b/sdk-api-gen/src/test/java/dev/restate/sdk/CodegenTest.java index eeae2221..17086cb7 100644 --- a/sdk-api-gen/src/test/java/dev/restate/sdk/CodegenTest.java +++ b/sdk-api-gen/src/test/java/dev/restate/sdk/CodegenTest.java @@ -19,6 +19,7 @@ import dev.restate.sdk.core.ProtoUtils; import dev.restate.sdk.core.TestDefinitions; import dev.restate.sdk.core.TestDefinitions.TestSuite; +import java.nio.charset.StandardCharsets; import java.util.stream.Stream; public class CodegenTest implements TestSuite { @@ -106,6 +107,45 @@ public String send(ObjectContext context, String request) { } } + @Service(name = "RawInputOutput") + static class RawInputOutput { + + @Handler + @Raw + public byte[] rawOutput(Context context) { + var client = RawInputOutputClient.fromContext(context); + return client.rawOutput().await(); + } + + @Handler + @Raw(contentType = "application/vnd.my.custom") + public byte[] rawOutputWithCustomCT(Context context) { + var client = RawInputOutputClient.fromContext(context); + return client.rawOutputWithCustomCT().await(); + } + + @Handler + public void rawInput(Context context, @Raw byte[] input) { + var client = RawInputOutputClient.fromContext(context); + client.rawInput(input).await(); + } + + @Handler + public void rawInputWithCustomCt( + Context context, @Raw(contentType = "application/vnd.my.custom") byte[] input) { + var client = RawInputOutputClient.fromContext(context); + client.rawInputWithCustomCt(input).await(); + } + + @Handler + public void rawInputWithCustomAccept( + Context context, + @Accept("application/*") @Raw(contentType = "application/vnd.my.custom") byte[] input) { + var client = RawInputOutputClient.fromContext(context); + client.rawInputWithCustomCt(input).await(); + } + } + @Override public Stream definitions() { return Stream.of( @@ -174,6 +214,53 @@ public Stream definitions() { Target.service("PrimitiveTypes", "primitiveInput"), CoreSerdes.JSON_INT, 10), outputMessage(), END_MESSAGE) - .named("primitive input")); + .named("primitive input"), + testInvocation(RawInputOutput::new, "rawInput") + .withInput( + startMessage(1), + inputMessage("{{".getBytes(StandardCharsets.UTF_8)), + completionMessage(1, CoreSerdes.VOID, null)) + .onlyUnbuffered() + .expectingOutput( + invokeMessage( + Target.service("RawInputOutput", "rawInput"), + "{{".getBytes(StandardCharsets.UTF_8)), + outputMessage(), + END_MESSAGE), + testInvocation(RawInputOutput::new, "rawInputWithCustomCt") + .withInput( + startMessage(1), + inputMessage("{{".getBytes(StandardCharsets.UTF_8)), + completionMessage(1, CoreSerdes.VOID, null)) + .onlyUnbuffered() + .expectingOutput( + invokeMessage( + Target.service("RawInputOutput", "rawInputWithCustomCt"), + "{{".getBytes(StandardCharsets.UTF_8)), + outputMessage(), + END_MESSAGE), + testInvocation(RawInputOutput::new, "rawOutput") + .withInput( + startMessage(1), + inputMessage(), + completionMessage(1, CoreSerdes.RAW, "{{".getBytes(StandardCharsets.UTF_8))) + .onlyUnbuffered() + .expectingOutput( + invokeMessage(Target.service("RawInputOutput", "rawOutput"), CoreSerdes.VOID, null), + outputMessage("{{".getBytes(StandardCharsets.UTF_8)), + END_MESSAGE), + testInvocation(RawInputOutput::new, "rawOutputWithCustomCT") + .withInput( + startMessage(1), + inputMessage(), + completionMessage(1, CoreSerdes.RAW, "{{".getBytes(StandardCharsets.UTF_8))) + .onlyUnbuffered() + .expectingOutput( + invokeMessage( + Target.service("RawInputOutput", "rawOutputWithCustomCT"), + CoreSerdes.VOID, + null), + outputMessage("{{".getBytes(StandardCharsets.UTF_8)), + END_MESSAGE)); } } diff --git a/sdk-api-gen/src/test/java/dev/restate/sdk/JavaCodegenTests.java b/sdk-api-gen/src/test/java/dev/restate/sdk/JavaCodegenTests.java index 039c56d2..291a6d93 100644 --- a/sdk-api-gen/src/test/java/dev/restate/sdk/JavaCodegenTests.java +++ b/sdk-api-gen/src/test/java/dev/restate/sdk/JavaCodegenTests.java @@ -8,12 +8,18 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk; +import static dev.restate.sdk.core.AssertUtils.assertThatDiscovery; +import static org.assertj.core.api.InstanceOfAssertFactories.type; + import dev.restate.sdk.core.MockMultiThreaded; import dev.restate.sdk.core.MockSingleThread; import dev.restate.sdk.core.TestDefinitions.TestExecutor; import dev.restate.sdk.core.TestDefinitions.TestSuite; import dev.restate.sdk.core.TestRunner; +import dev.restate.sdk.core.manifest.Input; +import dev.restate.sdk.core.manifest.Output; import java.util.stream.Stream; +import org.junit.jupiter.api.Test; public class JavaCodegenTests extends TestRunner { @@ -26,4 +32,34 @@ protected Stream executors() { public Stream definitions() { return Stream.of(new CodegenTest()); } + + @Test + void checkCustomInputContentType() { + assertThatDiscovery(new CodegenTest.RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInputWithCustomCt") + .extracting(dev.restate.sdk.core.manifest.Handler::getInput, type(Input.class)) + .extracting(Input::getContentType) + .isEqualTo("application/vnd.my.custom"); + } + + @Test + void checkCustomInputAcceptContentType() { + assertThatDiscovery(new CodegenTest.RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInputWithCustomAccept") + .extracting(dev.restate.sdk.core.manifest.Handler::getInput, type(Input.class)) + .extracting(Input::getContentType) + .isEqualTo("application/*"); + } + + @Test + void checkCustomOutputContentType() { + assertThatDiscovery(new CodegenTest.RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawOutputWithCustomCT") + .extracting(dev.restate.sdk.core.manifest.Handler::getOutput, type(Output.class)) + .extracting(Output::getContentType) + .isEqualTo("application/vnd.my.custom"); + } } diff --git a/sdk-api-gen/src/test/java/dev/restate/sdk/NameInferenceTest.java b/sdk-api-gen/src/test/java/dev/restate/sdk/NameInferenceTest.java index a8b49955..7fbdf76e 100644 --- a/sdk-api-gen/src/test/java/dev/restate/sdk/NameInferenceTest.java +++ b/sdk-api-gen/src/test/java/dev/restate/sdk/NameInferenceTest.java @@ -16,9 +16,10 @@ public class NameInferenceTest { @Test void expectedName() { - assertThat(CodegenTestServiceGreeterClient.SERVICE_NAME).isEqualTo("CodegenTestServiceGreeter"); - assertThat(GreeterWithoutExplicitNameClient.SERVICE_NAME) + assertThat(CodegenTestServiceGreeterDefinitions.SERVICE_NAME) + .isEqualTo("CodegenTestServiceGreeter"); + assertThat(GreeterWithoutExplicitNameDefinitions.SERVICE_NAME) .isEqualTo("GreeterWithoutExplicitName"); - assertThat(MyExplicitNameClient.SERVICE_NAME).isEqualTo("MyExplicitName"); + assertThat(MyExplicitNameDefinitions.SERVICE_NAME).isEqualTo("MyExplicitName"); } } diff --git a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt index 5c335f23..7c6031e0 100644 --- a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt +++ b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/KElementConverter.kt @@ -16,23 +16,31 @@ import com.google.devtools.ksp.processing.KSBuiltIns import com.google.devtools.ksp.processing.KSPLogger import com.google.devtools.ksp.symbol.* import com.google.devtools.ksp.visitor.KSDefaultVisitor +import dev.restate.sdk.annotation.Accept +import dev.restate.sdk.annotation.Json +import dev.restate.sdk.annotation.Raw import dev.restate.sdk.common.ServiceType import dev.restate.sdk.gen.model.Handler import dev.restate.sdk.gen.model.HandlerType import dev.restate.sdk.gen.model.PayloadType import dev.restate.sdk.gen.model.Service +import dev.restate.sdk.gen.utils.AnnotationUtils.getAnnotationDefaultValue import dev.restate.sdk.kotlin.Context import dev.restate.sdk.kotlin.ObjectContext import dev.restate.sdk.kotlin.SharedObjectContext import java.util.regex.Pattern import kotlin.reflect.KClass -class KElementConverter(private val logger: KSPLogger, private val builtIns: KSBuiltIns) : - KSDefaultVisitor() { +class KElementConverter( + private val logger: KSPLogger, + private val builtIns: KSBuiltIns, + private val byteArrayType: KSType +) : KSDefaultVisitor() { companion object { private val SUPPORTED_CLASS_KIND: Set = setOf(ClassKind.CLASS, ClassKind.INTERFACE) private val EMPTY_PAYLOAD: PayloadType = PayloadType(true, "", "Unit", "dev.restate.sdk.kotlin.KtSerdes.UNIT") + private const val RAW_SERDE: String = "dev.restate.sdk.common.CoreSerdes.RAW" } override fun defaultHandler(node: KSNode, data: Service.Builder) {} @@ -157,18 +165,88 @@ class KElementConverter(private val logger: KSPLogger, private val builtIns: KSB handlerBuilder .withName(function.simpleName.asString()) .withHandlerType(handlerType) - .withInputType( - if (function.parameters.size == 2) payloadFromType(function.parameters[1].type) - else EMPTY_PAYLOAD) - .withOutputType( - if (function.returnType != null) payloadFromType(function.returnType!!) - else EMPTY_PAYLOAD) + .withInputAccept(inputAcceptFromParameterList(function.parameters)) + .withInputType(inputPayloadFromParameterList(function.parameters)) + .withOutputType(outputPayloadFromExecutableElement(function)) .validateAndBuild()) } catch (e: Exception) { logger.error("Error when building handler: $e", function) } } + @OptIn(KspExperimental::class) + private fun inputAcceptFromParameterList(paramList: List): String? { + if (paramList.size <= 1) { + return null + } + + return paramList[1].getAnnotationsByType(Accept::class).firstOrNull()?.value + } + + @OptIn(KspExperimental::class) + private fun inputPayloadFromParameterList(paramList: List): PayloadType { + if (paramList.size <= 1) { + return EMPTY_PAYLOAD + } + + val parameterElement: KSValueParameter = paramList[1] + return payloadFromTypeMirrorAndAnnotations( + parameterElement.type.resolve(), + parameterElement.getAnnotationsByType(Json::class).firstOrNull(), + parameterElement.getAnnotationsByType(Raw::class).firstOrNull(), + parameterElement) + } + + @OptIn(KspExperimental::class) + private fun outputPayloadFromExecutableElement(fn: KSFunctionDeclaration): PayloadType { + return payloadFromTypeMirrorAndAnnotations( + fn.returnType?.resolve() ?: builtIns.unitType, + fn.getAnnotationsByType(Json::class).firstOrNull(), + fn.getAnnotationsByType(Raw::class).firstOrNull(), + fn) + } + + private fun payloadFromTypeMirrorAndAnnotations( + ty: KSType, + jsonAnnotation: Json?, + rawAnnotation: Raw?, + relatedNode: KSNode + ): PayloadType { + if (ty == builtIns.unitType) { + if (rawAnnotation != null || jsonAnnotation != null) { + logger.error("Unexpected annotation for void type.", relatedNode) + } + return EMPTY_PAYLOAD + } + // Some validation + if (rawAnnotation != null && jsonAnnotation != null) { + logger.error("A parameter cannot be annotated both with @Raw and @Json.", relatedNode) + } + if (rawAnnotation != null && ty != byteArrayType) { + logger.error("A parameter annotated with @Raw MUST be of type byte[], was $ty", relatedNode) + } + + var serdeDecl: String = if (rawAnnotation != null) RAW_SERDE else jsonSerdeDecl(ty) + if (rawAnnotation != null && + rawAnnotation.contentType != getAnnotationDefaultValue(Raw::class.java, "contentType")) { + serdeDecl = contentTypeDecoratedSerdeDecl(serdeDecl, rawAnnotation.contentType) + } + if (jsonAnnotation != null && + jsonAnnotation.contentType != getAnnotationDefaultValue(Json::class.java, "contentType")) { + serdeDecl = contentTypeDecoratedSerdeDecl(serdeDecl, jsonAnnotation.contentType) + } + + return PayloadType(false, ty.toString(), boxedType(ty), serdeDecl) + } + + private fun contentTypeDecoratedSerdeDecl(serdeDecl: String, contentType: String): String { + return ("dev.restate.sdk.common.Serde.withContentType(\"" + + contentType + + "\", " + + serdeDecl + + ")") + } + private fun defaultHandlerType(serviceType: ServiceType, node: KSNode): HandlerType { when (serviceType) { ServiceType.SERVICE -> return HandlerType.STATELESS @@ -222,12 +300,7 @@ class KElementConverter(private val logger: KSPLogger, private val builtIns: KSB } } - private fun payloadFromType(typeRef: KSTypeReference): PayloadType { - val ty = typeRef.resolve() - return PayloadType(false, typeRef.toString(), boxedType(ty), serdeDecl(ty)) - } - - private fun serdeDecl(ty: KSType): String { + private fun jsonSerdeDecl(ty: KSType): String { return when (ty) { builtIns.unitType -> "dev.restate.sdk.kotlin.KtSerdes.UNIT" else -> "dev.restate.sdk.kotlin.KtSerdes.json<${boxedType(ty)}>()" diff --git a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt index 12ce57f0..975a3e8e 100644 --- a/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt +++ b/sdk-api-kotlin-gen/src/main/kotlin/dev/restate/sdk/kotlin/gen/ServiceProcessor.kt @@ -9,7 +9,9 @@ package dev.restate.sdk.kotlin.gen import com.github.jknack.handlebars.io.ClassPathTemplateLoader +import com.google.devtools.ksp.KspExperimental import com.google.devtools.ksp.containingFile +import com.google.devtools.ksp.getKotlinClassByName import com.google.devtools.ksp.processing.* import com.google.devtools.ksp.symbol.KSAnnotated import com.google.devtools.ksp.symbol.Origin @@ -53,9 +55,22 @@ class ServiceProcessor(private val logger: KSPLogger, private val codeGenerator: ServiceType.SERVICE to "templates/Client", ServiceType.VIRTUAL_OBJECT to "templates/Client"), RESERVED_METHOD_NAMES) + private val definitionsCodegen: HandlebarsTemplateEngine = + HandlebarsTemplateEngine( + "Definitions", + ClassPathTemplateLoader(), + mapOf( + ServiceType.SERVICE to "templates/Definitions", + ServiceType.VIRTUAL_OBJECT to "templates/Definitions"), + RESERVED_METHOD_NAMES) + @OptIn(KspExperimental::class) override fun process(resolver: Resolver): List { - val converter = KElementConverter(logger, resolver.builtIns) + val converter = + KElementConverter( + logger, + resolver.builtIns, + resolver.getKotlinClassByName(ByteArray::class.qualifiedName!!)!!.asType(listOf())) val resolved = resolver @@ -101,6 +116,7 @@ class ServiceProcessor(private val logger: KSPLogger, private val codeGenerator: this.bindableServiceFactoryCodegen.generate(fileCreator, service.second) this.bindableServiceCodegen.generate(fileCreator, service.second) this.clientCodegen.generate(fileCreator, service.second) + this.definitionsCodegen.generate(fileCreator, service.second) } catch (ex: Throwable) { throw RuntimeException(ex) } diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/BindableService.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/BindableService.hbs index bdbb106d..5e01053b 100644 --- a/sdk-api-kotlin-gen/src/main/resources/templates/BindableService.hbs +++ b/sdk-api-kotlin-gen/src/main/resources/templates/BindableService.hbs @@ -2,26 +2,33 @@ class {{generatedClassSimpleName}}( bindableService: {{originalClassFqcn}}, - options: dev.restate.sdk.kotlin.Service.Options = dev.restate.sdk.kotlin.Service.Options.DEFAULT + private val options: dev.restate.sdk.kotlin.Service.Options = dev.restate.sdk.kotlin.Service.Options.DEFAULT ): dev.restate.sdk.common.BindableService { - companion object { - const val SERVICE_NAME: String = "{{serviceName}}"; - } - - val service: dev.restate.sdk.kotlin.Service = dev.restate.sdk.kotlin.Service.{{#if isObject}}virtualObject{{else}}service{{/if}}(SERVICE_NAME, options) { - {{#handlers}} - {{#if isShared}}sharedHandler{{else if isExclusive}}exclusiveHandler{{else}}handler{{/if}}(dev.restate.sdk.kotlin.Service.HandlerSignature("{{name}}", {{{inputSerdeDecl}}}, {{{outputSerdeDecl}}})) { ctx, req -> - {{#if inputEmpty}}bindableService.{{name}}(ctx){{else}}bindableService.{{name}}(ctx, req){{/if}} - } - {{/handlers}} - } + val service: dev.restate.sdk.common.syscalls.ServiceDefinition = + dev.restate.sdk.common.syscalls.ServiceDefinition.of( + {{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, + {{#if isObject}}dev.restate.sdk.common.ServiceType.VIRTUAL_OBJECT{{else}}dev.restate.sdk.common.ServiceType.SERVICE{{/if}}, + listOf( + {{#handlers}} + dev.restate.sdk.common.syscalls.HandlerDefinition.of( + dev.restate.sdk.common.syscalls.HandlerSpecification.of( + "{{name}}", + {{#if isExclusive}}dev.restate.sdk.common.HandlerType.EXCLUSIVE{{else}}dev.restate.sdk.common.HandlerType.SHARED{{/if}}, + {{inputSerdeRef}}, + {{outputSerdeRef}} + ){{#if inputAcceptContentType}}.withAcceptContentType("{{inputAcceptContentType}}"){{/if}}, + dev.restate.sdk.kotlin.Service.Handler.of(bindableService::{{name}}) + ){{#unless @last}},{{/unless}} + {{/handlers}} + ) + ) override fun options(): dev.restate.sdk.kotlin.Service.Options { - return service.options() + return this.options } override fun definitions(): List> { - return service.definitions() + return listOf(service) } } \ No newline at end of file diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/BindableServiceFactory.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/BindableServiceFactory.hbs index c6d64bd9..bf715177 100644 --- a/sdk-api-kotlin-gen/src/main/resources/templates/BindableServiceFactory.hbs +++ b/sdk-api-kotlin-gen/src/main/resources/templates/BindableServiceFactory.hbs @@ -4,10 +4,6 @@ import dev.restate.sdk.kotlin.Service.Options class {{generatedClassSimpleName}}: dev.restate.sdk.common.BindableServiceFactory<{{originalClassFqcn}}, Options> { - companion object { - const val SERVICE_NAME: String = "{{serviceName}}"; - } - override fun create(bindableService: {{originalClassFqcn}}): dev.restate.sdk.common.BindableService { return {{generatedClassSimpleNamePrefix}}BindableService(bindableService); } diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs index 21762586..5d5bfd4a 100644 --- a/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs +++ b/sdk-api-kotlin-gen/src/main/resources/templates/Client.hbs @@ -11,13 +11,6 @@ import dev.restate.sdk.kotlin.sendSuspend object {{generatedClassSimpleName}} { - const val SERVICE_NAME: String = "{{serviceName}}" - - {{#handlers}} - private val {{inputSerdeFieldName}}: Serde<{{{boxedInputFqcn}}}> = {{{inputSerdeDecl}}} - private val {{outputSerdeFieldName}}: Serde<{{{boxedOutputFqcn}}}> = {{{outputSerdeDecl}}} - {{/handlers}} - fun fromContext(ctx: Context{{#isObject}}, key: String{{/isObject}}): ContextClient { return ContextClient(ctx{{#isObject}}, key{{/isObject}}) } @@ -34,9 +27,9 @@ object {{generatedClassSimpleName}} { {{#handlers}} suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}{{/inputEmpty}}): Awaitable<{{{boxedOutputFqcn}}}> { return this.ctx.callAsync( - {{#if isObject}}Target.virtualObject(SERVICE_NAME, this.key, "{{name}}"){{else}}Target.service(SERVICE_NAME, "{{name}}"){{/if}}, - {{inputSerdeFieldName}}, - {{outputSerdeFieldName}}, + {{#if isObject}}Target.virtualObject({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, this.key, "{{name}}"){{else}}Target.service({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, "{{name}}"){{/if}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, {{#if inputEmpty}}Unit{{else}}req{{/if}}) }{{/handlers}} @@ -48,8 +41,8 @@ object {{generatedClassSimpleName}} { {{#handlers}} suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}{{/inputEmpty}}) { this@ContextClient.ctx.send( - {{#if isObject}}Target.virtualObject(SERVICE_NAME, this@ContextClient.key, "{{name}}"){{else}}Target.service(SERVICE_NAME, "{{name}}"){{/if}}, - {{inputSerdeFieldName}}, + {{#if isObject}}Target.virtualObject({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, this@ContextClient.key, "{{name}}"){{else}}Target.service({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, "{{name}}"){{/if}}, + {{inputSerdeRef}}, {{#if inputEmpty}}Unit{{else}}req{{/if}}, delay); }{{/handlers}} @@ -61,9 +54,9 @@ object {{generatedClassSimpleName}} { {{#handlers}} suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): {{{boxedOutputFqcn}}} { return this.ingressClient.callSuspend( - {{#if isObject}}Target.virtualObject(SERVICE_NAME, this.key, "{{name}}"){{else}}Target.service(SERVICE_NAME, "{{name}}"){{/if}}, - {{inputSerdeFieldName}}, - {{outputSerdeFieldName}}, + {{#if isObject}}Target.virtualObject({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, this.key, "{{name}}"){{else}}Target.service({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, "{{name}}"){{/if}}, + {{inputSerdeRef}}, + {{outputSerdeRef}}, {{#if inputEmpty}}Unit{{else}}req{{/if}}, requestOptions); }{{/handlers}} @@ -76,8 +69,8 @@ object {{generatedClassSimpleName}} { {{#handlers}} suspend fun {{methodName}}({{^inputEmpty}}req: {{{inputFqcn}}}, {{/inputEmpty}}requestOptions: dev.restate.sdk.client.RequestOptions = dev.restate.sdk.client.RequestOptions.DEFAULT): String { return this@IngressClient.ingressClient.sendSuspend( - {{#if isObject}}Target.virtualObject(SERVICE_NAME, this@IngressClient.key, "{{name}}"){{else}}Target.service(SERVICE_NAME, "{{name}}"){{/if}}, - {{inputSerdeFieldName}}, + {{#if isObject}}Target.virtualObject({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, this@IngressClient.key, "{{name}}"){{else}}Target.service({{generatedClassSimpleNamePrefix}}Definitions.SERVICE_NAME, "{{name}}"){{/if}}, + {{inputSerdeRef}}, {{#if inputEmpty}}Unit{{else}}req{{/if}}, delay, requestOptions); diff --git a/sdk-api-kotlin-gen/src/main/resources/templates/Definitions.hbs b/sdk-api-kotlin-gen/src/main/resources/templates/Definitions.hbs new file mode 100644 index 00000000..1da7a7fe --- /dev/null +++ b/sdk-api-kotlin-gen/src/main/resources/templates/Definitions.hbs @@ -0,0 +1,14 @@ +{{#if originalClassPkg}}package {{originalClassPkg}}{{/if}} + +object {{generatedClassSimpleName}} { + + const val SERVICE_NAME: String = "{{serviceName}}" + + object Serde { + {{#handlers}} + val {{inputSerdeFieldName}}: dev.restate.sdk.common.Serde<{{{boxedInputFqcn}}}> = {{{inputSerdeDecl}}} + val {{outputSerdeFieldName}}: dev.restate.sdk.common.Serde<{{{boxedOutputFqcn}}}> = {{{outputSerdeDecl}}} + {{/handlers}} + } + +} \ No newline at end of file diff --git a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt b/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt index d2c2f5b4..cb8051bf 100644 --- a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt +++ b/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt @@ -98,6 +98,47 @@ class CodegenTest : TestDefinitions.TestSuite { } } + @Service(name = "RawInputOutput") + class RawInputOutput { + @Handler + @Raw + suspend fun rawOutput(context: Context): ByteArray { + val client: RawInputOutputClient.ContextClient = RawInputOutputClient.fromContext(context) + return client.rawOutput().await() + } + + @Handler + @Raw(contentType = "application/vnd.my.custom") + suspend fun rawOutputWithCustomCT(context: Context): ByteArray { + val client: RawInputOutputClient.ContextClient = RawInputOutputClient.fromContext(context) + return client.rawOutputWithCustomCT().await() + } + + @Handler + suspend fun rawInput(context: Context, @Raw input: ByteArray) { + val client: RawInputOutputClient.ContextClient = RawInputOutputClient.fromContext(context) + client.rawInput(input).await() + } + + @Handler + suspend fun rawInputWithCustomCt( + context: Context, + @Raw(contentType = "application/vnd.my.custom") input: ByteArray + ) { + val client: RawInputOutputClient.ContextClient = RawInputOutputClient.fromContext(context) + client.rawInputWithCustomCt(input).await() + } + + @Handler + suspend fun rawInputWithCustomAccept( + context: Context, + @Accept("application/*") @Raw(contentType = "application/vnd.my.custom") input: ByteArray + ) { + val client: RawInputOutputClient.ContextClient = RawInputOutputClient.fromContext(context) + client.rawInputWithCustomCt(input).await() + } + } + override fun definitions(): Stream { return Stream.of( testInvocation({ ServiceGreeter() }, "greet") @@ -165,6 +206,48 @@ class CodegenTest : TestDefinitions.TestSuite { Target.service("PrimitiveTypes", "primitiveInput"), CoreSerdes.JSON_INT, 10), outputMessage(), END_MESSAGE) - .named("primitive input")) + .named("primitive input"), + testInvocation({ RawInputOutput() }, "rawInput") + .withInput( + startMessage(1), + inputMessage("{{".toByteArray()), + completionMessage(1, KtSerdes.UNIT, null)) + .onlyUnbuffered() + .expectingOutput( + invokeMessage(Target.service("RawInputOutput", "rawInput"), "{{".toByteArray()), + outputMessage(), + END_MESSAGE), + testInvocation({ RawInputOutput() }, "rawInputWithCustomCt") + .withInput( + startMessage(1), + inputMessage("{{".toByteArray()), + completionMessage(1, KtSerdes.UNIT, null)) + .onlyUnbuffered() + .expectingOutput( + invokeMessage( + Target.service("RawInputOutput", "rawInputWithCustomCt"), "{{".toByteArray()), + outputMessage(), + END_MESSAGE), + testInvocation({ RawInputOutput() }, "rawOutput") + .withInput( + startMessage(1), + inputMessage(), + completionMessage(1, CoreSerdes.RAW, "{{".toByteArray())) + .onlyUnbuffered() + .expectingOutput( + invokeMessage(Target.service("RawInputOutput", "rawOutput"), KtSerdes.UNIT, null), + outputMessage("{{".toByteArray()), + END_MESSAGE), + testInvocation({ RawInputOutput() }, "rawOutputWithCustomCT") + .withInput( + startMessage(1), + inputMessage(), + completionMessage(1, CoreSerdes.RAW, "{{".toByteArray())) + .onlyUnbuffered() + .expectingOutput( + invokeMessage( + Target.service("RawInputOutput", "rawOutputWithCustomCT"), KtSerdes.UNIT, null), + outputMessage("{{".toByteArray()), + END_MESSAGE)) } } diff --git a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/KtCodegenTests.kt b/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/KtCodegenTests.kt index f76303ae..3dbde7a4 100644 --- a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/KtCodegenTests.kt +++ b/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/KtCodegenTests.kt @@ -8,12 +8,17 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.kotlin +import dev.restate.sdk.core.AssertUtils.assertThatDiscovery import dev.restate.sdk.core.MockMultiThreaded import dev.restate.sdk.core.MockSingleThread import dev.restate.sdk.core.TestDefinitions import dev.restate.sdk.core.TestDefinitions.TestExecutor import dev.restate.sdk.core.TestRunner +import dev.restate.sdk.core.manifest.Input +import dev.restate.sdk.core.manifest.Output import java.util.stream.Stream +import org.assertj.core.api.InstanceOfAssertFactories.type +import org.junit.jupiter.api.Test class KtCodegenTests : TestRunner() { override fun executors(): Stream { @@ -23,4 +28,34 @@ class KtCodegenTests : TestRunner() { public override fun definitions(): Stream { return Stream.of(CodegenTest()) } + + @Test + fun checkCustomInputContentType() { + assertThatDiscovery(CodegenTest.RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInputWithCustomCt") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo("application/vnd.my.custom") + } + + @Test + fun checkCustomInputAcceptContentType() { + assertThatDiscovery(CodegenTest.RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawInputWithCustomAccept") + .extracting({ it.input }, type(Input::class.java)) + .extracting { it.contentType } + .isEqualTo("application/*") + } + + @Test + fun checkCustomOutputContentType() { + assertThatDiscovery(CodegenTest.RawInputOutput()) + .extractingService("RawInputOutput") + .extractingHandler("rawOutputWithCustomCT") + .extracting({ it.output }, type(Output::class.java)) + .extracting { it.contentType } + .isEqualTo("application/vnd.my.custom") + } } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Service.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Service.kt index 4a144af7..101eb3a3 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Service.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/Service.kt @@ -23,14 +23,14 @@ class Service private constructor( fqsn: String, isKeyed: Boolean, - handlers: Map>, + handlers: Map>, private val options: Options ) : BindableService { private val serviceDefinition = - ServiceDefinition( + ServiceDefinition.of( fqsn, if (isKeyed) ServiceType.VIRTUAL_OBJECT else ServiceType.SERVICE, - handlers.values.map { obj: Handler<*, *, *> -> obj.toHandlerDefinition() }) + handlers.values.toList()) override fun options(): Options { return this.options @@ -61,76 +61,90 @@ private constructor( } class VirtualObjectBuilder internal constructor(private val name: String) { - private val handlers: MutableMap> = mutableMapOf() + private val handlers: MutableMap> = mutableMapOf() fun sharedHandler( - sig: HandlerSignature, - runner: suspend (ObjectContext, REQ) -> RES + name: String, + requestSerde: Serde, + responseSerde: Serde, + runner: suspend (SharedObjectContext, REQ) -> RES ): VirtualObjectBuilder { - handlers[sig.name] = Handler(sig, HandlerType.SHARED, runner) + handlers[name] = + HandlerDefinition.of( + HandlerSpecification.of(name, HandlerType.SHARED, requestSerde, responseSerde), + Handler(runner)) return this } inline fun sharedHandler( name: String, - noinline runner: suspend (ObjectContext, REQ) -> RES - ) = this.sharedHandler(HandlerSignature(name, KtSerdes.json(), KtSerdes.json()), runner) + noinline runner: suspend (SharedObjectContext, REQ) -> RES + ) = this.sharedHandler(name, KtSerdes.json(), KtSerdes.json(), runner) fun exclusiveHandler( - sig: HandlerSignature, + name: String, + requestSerde: Serde, + responseSerde: Serde, runner: suspend (ObjectContext, REQ) -> RES ): VirtualObjectBuilder { - handlers[sig.name] = Handler(sig, HandlerType.EXCLUSIVE, runner) + handlers[name] = + HandlerDefinition.of( + HandlerSpecification.of(name, HandlerType.EXCLUSIVE, requestSerde, responseSerde), + Handler(runner)) return this } inline fun exclusiveHandler( name: String, noinline runner: suspend (ObjectContext, REQ) -> RES - ) = this.exclusiveHandler(HandlerSignature(name, KtSerdes.json(), KtSerdes.json()), runner) + ) = this.exclusiveHandler(name, KtSerdes.json(), KtSerdes.json(), runner) fun build(options: Options) = Service(this.name, true, this.handlers, options) } class ServiceBuilder internal constructor(private val name: String) { - private val handlers: MutableMap> = mutableMapOf() + private val handlers: MutableMap> = mutableMapOf() fun handler( - sig: HandlerSignature, - runner: suspend (Context, REQ) -> RES + name: String, + requestSerde: Serde, + responseSerde: Serde, + runner: suspend (SharedObjectContext, REQ) -> RES ): ServiceBuilder { - handlers[sig.name] = Handler(sig, HandlerType.SHARED, runner) + handlers[name] = + HandlerDefinition.of( + HandlerSpecification.of(name, HandlerType.SHARED, requestSerde, responseSerde), + Handler(runner)) return this } inline fun handler( name: String, noinline runner: suspend (Context, REQ) -> RES - ) = this.handler(HandlerSignature(name, KtSerdes.json(), KtSerdes.json()), runner) + ) = this.handler(name, KtSerdes.json(), KtSerdes.json(), runner) fun build(options: Options) = Service(this.name, false, this.handlers, options) } - class Handler( - private val handlerSignature: HandlerSignature, - private val handlerType: HandlerType, + class Handler + internal constructor( private val runner: suspend (CTX, REQ) -> RES, - ) : InvocationHandler { + ) : InvocationHandler { companion object { private val LOG = LogManager.getLogger() - } - fun toHandlerDefinition() = - HandlerDefinition( - handlerSignature.name, - handlerType, - handlerSignature.requestSerde.contentType() != null, - handlerSignature.requestSerde.contentType(), - handlerSignature.responseSerde.contentType(), - this) + fun of(runner: suspend (CTX, REQ) -> RES): Handler { + return Handler(runner) + } + + fun of(runner: suspend (CTX) -> RES): Handler { + return Handler { ctx: CTX, _: Unit -> runner(ctx) } + } + } override fun handle( + handlerSpecification: HandlerSpecification, syscalls: Syscalls, options: Options, callback: SyscallCallback @@ -149,7 +163,7 @@ private constructor( // Parse input val req: REQ try { - req = handlerSignature.requestSerde.deserialize(syscalls.request().bodyBuffer()) + req = handlerSpecification.requestSerde.deserialize(syscalls.request().bodyBuffer()) } catch (e: Error) { throw e } catch (e: Throwable) { @@ -163,7 +177,7 @@ private constructor( // Serialize output try { - serializedResult = handlerSignature.responseSerde.serializeToByteString(res) + serializedResult = handlerSpecification.responseSerde.serializeToByteString(res) } catch (e: Error) { throw e } catch (e: Throwable) { @@ -182,12 +196,6 @@ private constructor( } } - class HandlerSignature( - val name: String, - val requestSerde: Serde, - val responseSerde: Serde - ) - class Options(val coroutineContext: CoroutineContext) { companion object { val DEFAULT: Options = Options(Dispatchers.Default) diff --git a/sdk-api/src/main/java/dev/restate/sdk/Service.java b/sdk-api/src/main/java/dev/restate/sdk/Service.java index 916244b4..396fb7c0 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Service.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Service.java @@ -16,8 +16,7 @@ import java.util.*; import java.util.concurrent.Executor; import java.util.concurrent.Executors; -import java.util.function.BiFunction; -import java.util.stream.Collectors; +import java.util.function.*; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -26,14 +25,15 @@ public final class Service implements BindableService { private final Service.Options options; private Service( - String fqsn, boolean isKeyed, HashMap> handlers, Options options) { + String fqsn, + boolean isKeyed, + HashMap> handlers, + Options options) { this.serviceDefinition = - new ServiceDefinition<>( + ServiceDefinition.of( fqsn, isKeyed ? ServiceType.VIRTUAL_OBJECT : ServiceType.SERVICE, - handlers.values().stream() - .map(Handler::toHandlerDefinition) - .collect(Collectors.toList())); + new ArrayList<>(handlers.values())); this.options = options; } @@ -58,7 +58,7 @@ public static VirtualObjectBuilder virtualObject(String name) { public static class AbstractServiceBuilder { protected final String name; - protected final HashMap> handlers; + protected final HashMap> handlers; public AbstractServiceBuilder(String name) { this.name = name; @@ -73,14 +73,28 @@ public static class VirtualObjectBuilder extends AbstractServiceBuilder { } public VirtualObjectBuilder withShared( - HandlerSignature sig, BiFunction runner) { - this.handlers.put(sig.getName(), new Handler<>(sig, HandlerType.SHARED, runner)); + String name, + Serde requestSerde, + Serde responseSerde, + BiFunction runner) { + this.handlers.put( + name, + HandlerDefinition.of( + HandlerSpecification.of(name, HandlerType.SHARED, requestSerde, responseSerde), + new Handler<>(runner))); return this; } public VirtualObjectBuilder withExclusive( - HandlerSignature sig, BiFunction runner) { - this.handlers.put(sig.getName(), new Handler<>(sig, HandlerType.EXCLUSIVE, runner)); + String name, + Serde requestSerde, + Serde responseSerde, + BiFunction runner) { + this.handlers.put( + name, + HandlerDefinition.of( + HandlerSpecification.of(name, HandlerType.EXCLUSIVE, requestSerde, responseSerde), + new Handler<>(runner))); return this; } @@ -96,8 +110,15 @@ public static class ServiceBuilder extends AbstractServiceBuilder { } public ServiceBuilder with( - HandlerSignature sig, BiFunction runner) { - this.handlers.put(sig.getName(), new Handler<>(sig, HandlerType.SHARED, runner)); + String name, + Serde requestSerde, + Serde responseSerde, + BiFunction runner) { + this.handlers.put( + name, + HandlerDefinition.of( + HandlerSpecification.of(name, HandlerType.SHARED, requestSerde, responseSerde), + new Handler<>(runner))); return this; } @@ -106,44 +127,26 @@ public Service build(Service.Options options) { } } - public static class Handler implements InvocationHandler { - private final HandlerSignature handlerSignature; - private final HandlerType handlerType; + public static class Handler implements InvocationHandler { private final BiFunction runner; private static final Logger LOG = LogManager.getLogger(Handler.class); - public Handler( - HandlerSignature handlerSignature, - HandlerType handlerType, - BiFunction runner) { - this.handlerSignature = handlerSignature; - this.handlerType = handlerType; + Handler(BiFunction runner) { //noinspection unchecked this.runner = (BiFunction) runner; } - public HandlerSignature getHandlerSignature() { - return handlerSignature; - } - public BiFunction getRunner() { return runner; } - public HandlerDefinition toHandlerDefinition() { - return new HandlerDefinition<>( - this.handlerSignature.name, - this.handlerType, - this.handlerSignature.requestSerde.contentType() != null, - this.handlerSignature.requestSerde.contentType(), - this.handlerSignature.responseSerde.contentType(), - this); - } - @Override public void handle( - Syscalls syscalls, Service.Options options, SyscallCallback callback) { + HandlerSpecification handlerSpecification, + Syscalls syscalls, + Service.Options options, + SyscallCallback callback) { // Wrap the executor for setting/unsetting the thread local Executor wrapped = runnable -> @@ -164,7 +167,10 @@ public void handle( // Parse input REQ req; try { - req = this.handlerSignature.requestSerde.deserialize(syscalls.request().bodyBuffer()); + req = + handlerSpecification + .getRequestSerde() + .deserialize(syscalls.request().bodyBuffer()); } catch (Error e) { throw e; } catch (Throwable e) { @@ -190,7 +196,7 @@ public void handle( // Serialize output ByteString serializedResult; try { - serializedResult = this.handlerSignature.responseSerde.serializeToByteString(res); + serializedResult = handlerSpecification.getResponseSerde().serializeToByteString(res); } catch (Error e) { throw e; } catch (Throwable e) { @@ -206,35 +212,33 @@ public void handle( callback.onSuccess(serializedResult); }); } - } - - public static class HandlerSignature { - - private final String name; - private final Serde requestSerde; - private final Serde responseSerde; - - HandlerSignature(String name, Serde requestSerde, Serde responseSerde) { - this.name = name; - this.requestSerde = requestSerde; - this.responseSerde = responseSerde; - } - public static HandlerSignature of( - String method, Serde requestSerde, Serde responseSerde) { - return new HandlerSignature<>(method, requestSerde, responseSerde); + public static Handler of( + BiFunction runner) { + return new Handler<>(runner); } - public String getName() { - return name; + @SuppressWarnings("unchecked") + public static Handler of(Function runner) { + return new Handler<>((context, o) -> runner.apply((CTX) context)); } - public Serde getRequestSerde() { - return requestSerde; + @SuppressWarnings("unchecked") + public static Handler of(BiConsumer runner) { + return new Handler<>( + (context, o) -> { + runner.accept((CTX) context, o); + return null; + }); } - public Serde getResponseSerde() { - return responseSerde; + @SuppressWarnings("unchecked") + public static Handler of(Consumer runner) { + return new Handler<>( + (ctx, o) -> { + runner.accept((CTX) ctx); + return null; + }); } } diff --git a/sdk-api/src/test/java/dev/restate/sdk/JavaBlockingTests.java b/sdk-api/src/test/java/dev/restate/sdk/JavaBlockingTests.java index 82d0e023..97a83028 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/JavaBlockingTests.java +++ b/sdk-api/src/test/java/dev/restate/sdk/JavaBlockingTests.java @@ -49,7 +49,7 @@ public static TestInvocationBuilder testDefinitionForService( String name, Serde reqSerde, Serde resSerde, BiFunction runner) { return TestDefinitions.testInvocation( Service.service(name) - .with(Service.HandlerSignature.of("run", reqSerde, resSerde), runner) + .with("run", reqSerde, resSerde, runner) .build(Service.Options.DEFAULT), "run"); } @@ -58,7 +58,7 @@ public static TestInvocationBuilder testDefinitionForVirtualObject( String name, Serde reqSerde, Serde resSerde, BiFunction runner) { return TestDefinitions.testInvocation( Service.virtualObject(name) - .withExclusive(Service.HandlerSignature.of("run", reqSerde, resSerde), runner) + .withExclusive("run", reqSerde, resSerde, runner) .build(Service.Options.DEFAULT), "run"); } diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Accept.java b/sdk-common/src/main/java/dev/restate/sdk/annotation/Accept.java new file mode 100644 index 00000000..d905f8da --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/annotation/Accept.java @@ -0,0 +1,20 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.SOURCE) +public @interface Accept { + String value(); +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Json.java b/sdk-common/src/main/java/dev/restate/sdk/annotation/Json.java new file mode 100644 index 00000000..df02468f --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/annotation/Json.java @@ -0,0 +1,21 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target({ElementType.METHOD, ElementType.PARAMETER}) +@Retention(RetentionPolicy.SOURCE) +public @interface Json { + /** Content-type to use in request/responses. */ + String contentType() default "application/json"; +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/annotation/Raw.java b/sdk-common/src/main/java/dev/restate/sdk/annotation/Raw.java new file mode 100644 index 00000000..fa548492 --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/annotation/Raw.java @@ -0,0 +1,21 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target({ElementType.METHOD, ElementType.PARAMETER}) +@Retention(RetentionPolicy.SOURCE) +public @interface Raw { + /** Content-type to use in request/responses. */ + String contentType() default "application/octet-stream"; +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerDefinition.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerDefinition.java index 0381dc04..0cca6200 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerDefinition.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerDefinition.java @@ -8,67 +8,49 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.common.syscalls; -import dev.restate.sdk.common.HandlerType; import java.util.Objects; -import org.jspecify.annotations.Nullable; -public final class HandlerDefinition { - private final String name; - private final HandlerType handlerType; - private final boolean inputRequired; - private final @Nullable String acceptInputContentType; - private final @Nullable String returnedContentType; - private final InvocationHandler handler; +public final class HandlerDefinition { - public HandlerDefinition( - String name, - HandlerType handlerType, - boolean inputRequired, - @Nullable String acceptInputContentType, - @Nullable String returnedContentType, - InvocationHandler handler) { - this.name = name; - this.handlerType = handlerType; - this.inputRequired = inputRequired; - this.acceptInputContentType = acceptInputContentType; - this.returnedContentType = returnedContentType; - this.handler = handler; - } + private final HandlerSpecification spec; + private final InvocationHandler handler; - public String getName() { - return name; + HandlerDefinition(HandlerSpecification spec, InvocationHandler handler) { + this.spec = spec; + this.handler = handler; } - public HandlerType getHandlerType() { - return handlerType; + public HandlerSpecification getSpec() { + return spec; } - public boolean isInputRequired() { - return inputRequired; + public InvocationHandler getHandler() { + return handler; } - public String getAcceptInputContentType() { - return acceptInputContentType; - } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; - public String getReturnedContentType() { - return returnedContentType; + HandlerDefinition that = (HandlerDefinition) o; + return Objects.equals(spec, that.spec) && Objects.equals(handler, that.handler); } - public InvocationHandler getHandler() { - return handler; + @Override + public int hashCode() { + int result = Objects.hashCode(spec); + result = 31 * result + Objects.hashCode(handler); + return result; } @Override - public boolean equals(Object object) { - if (this == object) return true; - if (object == null || getClass() != object.getClass()) return false; - HandlerDefinition that = (HandlerDefinition) object; - return Objects.equals(name, that.name) && Objects.equals(handler, that.handler); + public String toString() { + return "HandlerDefinition{" + "spec=" + spec + ", handler=" + handler + '}'; } - @Override - public int hashCode() { - return Objects.hash(name, handler); + public static HandlerDefinition of( + HandlerSpecification spec, InvocationHandler handler) { + return new HandlerDefinition<>(spec, handler); } } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerSpecification.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerSpecification.java new file mode 100644 index 00000000..73ab9b9e --- /dev/null +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/HandlerSpecification.java @@ -0,0 +1,107 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk.common.syscalls; + +import dev.restate.sdk.common.HandlerType; +import dev.restate.sdk.common.Serde; +import java.util.Objects; +import org.jspecify.annotations.Nullable; + +public final class HandlerSpecification { + + private final String name; + private final HandlerType handlerType; + private final @Nullable String acceptContentType; + private final Serde requestSerde; + private final Serde responseSerde; + + HandlerSpecification( + String name, + HandlerType handlerType, + @Nullable String acceptContentType, + Serde requestSerde, + Serde responseSerde) { + this.name = name; + this.handlerType = handlerType; + this.acceptContentType = acceptContentType; + this.requestSerde = requestSerde; + this.responseSerde = responseSerde; + } + + public static HandlerSpecification of( + String method, HandlerType handlerType, Serde requestSerde, Serde responseSerde) { + return new HandlerSpecification<>(method, handlerType, null, requestSerde, responseSerde); + } + + public String getName() { + return name; + } + + public HandlerType getHandlerType() { + return handlerType; + } + + public @Nullable String getAcceptContentType() { + return acceptContentType; + } + + public Serde getRequestSerde() { + return requestSerde; + } + + public Serde getResponseSerde() { + return responseSerde; + } + + public HandlerSpecification withAcceptContentType(String acceptContentType) { + return new HandlerSpecification<>( + name, handlerType, acceptContentType, requestSerde, responseSerde); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + HandlerSpecification that = (HandlerSpecification) o; + return Objects.equals(name, that.name) + && handlerType == that.handlerType + && Objects.equals(acceptContentType, that.acceptContentType) + && Objects.equals(requestSerde, that.requestSerde) + && Objects.equals(responseSerde, that.responseSerde); + } + + @Override + public int hashCode() { + int result = Objects.hashCode(name); + result = 31 * result + Objects.hashCode(handlerType); + result = 31 * result + Objects.hashCode(acceptContentType); + result = 31 * result + Objects.hashCode(requestSerde); + result = 31 * result + Objects.hashCode(responseSerde); + return result; + } + + @Override + public String toString() { + return "HandlerSpecification{" + + "name='" + + name + + '\'' + + ", handlerType=" + + handlerType + + ", acceptContentType='" + + acceptContentType + + '\'' + + ", requestContentType=" + + requestSerde.contentType() + + ", responseContentType=" + + responseSerde.contentType() + + '}'; + } +} diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/InvocationHandler.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/InvocationHandler.java index 2a141fdc..c4aa7235 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/InvocationHandler.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/InvocationHandler.java @@ -10,7 +10,7 @@ import com.google.protobuf.ByteString; -public interface InvocationHandler { +public interface InvocationHandler { /** * Thread local to store {@link Syscalls}. * @@ -21,5 +21,9 @@ public interface InvocationHandler { */ ThreadLocal SYSCALLS_THREAD_LOCAL = new ThreadLocal<>(); - void handle(Syscalls syscalls, O options, SyscallCallback callback); + void handle( + HandlerSpecification handlerSpecification, + Syscalls syscalls, + O options, + SyscallCallback callback); } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ServiceDefinition.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ServiceDefinition.java index e22cde40..a4f644cb 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ServiceDefinition.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/ServiceDefinition.java @@ -17,17 +17,14 @@ public final class ServiceDefinition { private final String serviceName; private final ServiceType serviceType; - private final Map> handlers; - - public ServiceDefinition( - String fullyQualifiedComponentName, - ServiceType serviceType, - Collection> handlers) { - this.serviceName = fullyQualifiedComponentName; - this.serviceType = serviceType; + private final Map> handlers; + + ServiceDefinition(String name, ServiceType ty, Collection> handlers) { + this.serviceName = name; + this.serviceType = ty; this.handlers = handlers.stream() - .collect(Collectors.toMap(HandlerDefinition::getName, Function.identity())); + .collect(Collectors.toMap(h -> h.getSpec().getName(), Function.identity())); } public String getServiceName() { @@ -38,11 +35,11 @@ public ServiceType getServiceType() { return serviceType; } - public Collection> getHandlers() { + public Collection> getHandlers() { return handlers.values(); } - public HandlerDefinition getHandler(String name) { + public HandlerDefinition getHandler(String name) { return handlers.get(name); } @@ -60,4 +57,9 @@ public boolean equals(Object object) { public int hashCode() { return Objects.hash(serviceName, serviceType, handlers); } + + public static ServiceDefinition of( + String name, ServiceType ty, Collection> handlers) { + return new ServiceDefinition<>(name, ty, handlers); + } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/DeploymentManifest.java b/sdk-core/src/main/java/dev/restate/sdk/core/DeploymentManifest.java index b6a8cfa0..11d05dbe 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/DeploymentManifest.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/DeploymentManifest.java @@ -11,6 +11,7 @@ import dev.restate.sdk.common.HandlerType; import dev.restate.sdk.common.ServiceType; import dev.restate.sdk.common.syscalls.HandlerDefinition; +import dev.restate.sdk.common.syscalls.HandlerSpecification; import dev.restate.sdk.common.syscalls.ServiceDefinition; import dev.restate.sdk.core.manifest.*; import java.util.stream.Collectors; @@ -58,21 +59,25 @@ private static Service.Ty convertServiceType(ServiceType serviceType) { throw new IllegalStateException(); } - private static Handler convertHandler(HandlerDefinition handler) { + private static Handler convertHandler(HandlerDefinition handler) { + HandlerSpecification spec = handler.getSpec(); + String acceptContentType = + spec.getAcceptContentType() != null + ? spec.getAcceptContentType() + : spec.getRequestSerde().contentType(); + return new Handler() - .withName(handler.getName()) - .withTy(convertHandlerType(handler.getHandlerType())) + .withName(spec.getName()) + .withTy(convertHandlerType(spec.getHandlerType())) .withInput( - handler.getAcceptInputContentType() == null + acceptContentType == null ? EMPTY_INPUT - : new Input() - .withRequired(handler.isInputRequired()) - .withContentType(handler.getAcceptInputContentType())) + : new Input().withRequired(true).withContentType(acceptContentType)) .withOutput( - handler.getReturnedContentType() == null + spec.getResponseSerde().contentType() == null ? EMPTY_OUTPUT : new Output() - .withContentType(handler.getReturnedContentType()) + .withContentType(spec.getResponseSerde().contentType()) .withSetContentTypeIfEmpty(false)); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java index c9f01018..40a3a535 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ResolvedEndpointHandlerImpl.java @@ -10,9 +10,7 @@ import com.google.protobuf.ByteString; import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.common.syscalls.InvocationHandler; -import dev.restate.sdk.common.syscalls.SyscallCallback; -import dev.restate.sdk.common.syscalls.Syscalls; +import dev.restate.sdk.common.syscalls.*; import java.util.concurrent.Executor; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -23,17 +21,22 @@ final class ResolvedEndpointHandlerImpl implements ResolvedEndpointHandler { private static final Logger LOG = LogManager.getLogger(ResolvedEndpointHandlerImpl.class); private final InvocationStateMachine stateMachine; - private final InvocationHandler wrappedHandler; + private final HandlerSpecification spec; + private final InvocationHandler wrappedHandler; private final Object componentOptions; private final @Nullable Executor syscallsExecutor; + @SuppressWarnings("unchecked") public ResolvedEndpointHandlerImpl( InvocationStateMachine stateMachine, - InvocationHandler handler, + HandlerDefinition handler, Object serviceOptions, @Nullable Executor syscallExecutor) { this.stateMachine = stateMachine; - this.wrappedHandler = new InvocationHandlerWrapper<>(handler); + this.spec = (HandlerSpecification) handler.getSpec(); + this.wrappedHandler = + new InvocationHandlerWrapper<>( + (InvocationHandler) handler.getHandler()); this.componentOptions = serviceOptions; this.syscallsExecutor = syscallExecutor; } @@ -63,6 +66,7 @@ public void start() { // pollInput then invoke the wrappedHandler wrappedHandler.handle( + spec, syscalls, componentOptions, SyscallCallback.of( @@ -102,18 +106,23 @@ private void end(SyscallsInternal syscalls, @Nullable Throwable exception) { } } - private static class InvocationHandlerWrapper implements InvocationHandler { + private static class InvocationHandlerWrapper + implements InvocationHandler { - private final InvocationHandler handler; + private final InvocationHandler handler; - private InvocationHandlerWrapper(InvocationHandler handler) { + private InvocationHandlerWrapper(InvocationHandler handler) { this.handler = handler; } @Override - public void handle(Syscalls syscalls, O options, SyscallCallback callback) { + public void handle( + HandlerSpecification spec, + Syscalls syscalls, + O options, + SyscallCallback callback) { try { - this.handler.handle(syscalls, options, callback); + this.handler.handle(spec, syscalls, options, callback); } catch (Throwable e) { callback.onCancel(e); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java b/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java index 6441d76b..f481933c 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java @@ -65,7 +65,7 @@ public ResolvedEndpointHandler resolve( throw ProtocolException.methodNotFound(componentName, handlerName); } String fullyQualifiedServiceMethod = componentName + "/" + handlerName; - HandlerDefinition handler = svc.service.getHandler(handlerName); + HandlerDefinition handler = svc.service.getHandler(handlerName); if (handler == null) { throw ProtocolException.methodNotFound(componentName, handlerName); } @@ -96,8 +96,7 @@ public ResolvedEndpointHandler resolve( new InvocationStateMachine( componentName, fullyQualifiedServiceMethod, span, loggingContextSetter); - return new ResolvedEndpointHandlerImpl( - stateMachine, handler.getHandler(), svc.options, syscallExecutor); + return new ResolvedEndpointHandlerImpl(stateMachine, handler, svc.options, syscallExecutor); } public DeploymentManifestSchema handleDiscoveryRequest() { diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java b/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java index a17c60ab..327130b0 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/AssertUtils.java @@ -9,14 +9,24 @@ package dev.restate.sdk.core; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.InstanceOfAssertFactories.STRING; import static org.assertj.core.api.InstanceOfAssertFactories.type; import com.google.protobuf.MessageLite; import dev.restate.generated.service.protocol.Protocol; +import dev.restate.sdk.common.BindableService; import dev.restate.sdk.common.TerminalException; +import dev.restate.sdk.core.manifest.DeploymentManifestSchema; +import dev.restate.sdk.core.manifest.Handler; +import dev.restate.sdk.core.manifest.Service; +import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.assertj.core.api.AbstractObjectAssert; +import org.assertj.core.api.ObjectAssert; public class AssertUtils { @@ -57,4 +67,69 @@ public static Consumer protocolExceptionErrorMessage(int co .extracting(Protocol.ErrorMessage::getMessage, STRING) .startsWith(ProtocolException.class.getCanonicalName())); } + + public static DeploymentManifestSchemaAssert assertThatDiscovery(Object... services) { + return new DeploymentManifestSchemaAssert( + new DeploymentManifest( + DeploymentManifestSchema.ProtocolMode.BIDI_STREAM, + Arrays.stream(services) + .flatMap( + svc -> { + if (svc instanceof BindableService) { + return ((BindableService) svc).definitions().stream(); + } + + return RestateEndpoint.discoverBindableServiceFactory(svc) + .create(svc) + .definitions() + .stream(); + })) + .manifest(), + DeploymentManifestSchemaAssert.class); + } + + public static class DeploymentManifestSchemaAssert + extends AbstractObjectAssert { + public DeploymentManifestSchemaAssert( + DeploymentManifestSchema deploymentManifestSchema, Class selfType) { + super(deploymentManifestSchema, selfType); + } + + public ServiceAssert extractingService(String service) { + Optional svc = + this.actual.getServices().stream().filter(s -> s.getName().equals(service)).findFirst(); + + if (svc.isEmpty()) { + fail( + "Expecting deployment manifest to contain service {}. Available services: {}", + service, + this.actual.getServices().stream().map(Service::getName).collect(Collectors.toList())); + } + + return new ServiceAssert(svc.get(), ServiceAssert.class); + } + } + + public static class ServiceAssert extends AbstractObjectAssert { + public ServiceAssert(Service svc, Class selfType) { + super(svc, selfType); + } + + public ObjectAssert extractingHandler(String handlerName) { + Optional handler = + this.actual.getHandlers().stream() + .filter(s -> s.getName().equals(handlerName)) + .findFirst(); + + if (handler.isEmpty()) { + fail( + "Expecting service {} manifest to contain handler {}. Available handler: {}", + this.actual.getName(), + handlerName, + this.actual.getHandlers().stream().map(Handler::getName).collect(Collectors.toList())); + } + + return assertThat(handler.get()); + } + } } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java b/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java index 4a9332a7..351981b9 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/ComponentDiscoveryHandlerTest.java @@ -10,9 +10,11 @@ import static org.assertj.core.api.Assertions.assertThat; +import dev.restate.sdk.common.CoreSerdes; import dev.restate.sdk.common.HandlerType; import dev.restate.sdk.common.ServiceType; import dev.restate.sdk.common.syscalls.HandlerDefinition; +import dev.restate.sdk.common.syscalls.HandlerSpecification; import dev.restate.sdk.common.syscalls.ServiceDefinition; import dev.restate.sdk.core.manifest.DeploymentManifestSchema; import dev.restate.sdk.core.manifest.DeploymentManifestSchema.ProtocolMode; @@ -29,16 +31,13 @@ void handleWithMultipleServices() { new DeploymentManifest( ProtocolMode.REQUEST_RESPONSE, Stream.of( - new ServiceDefinition<>( + ServiceDefinition.of( "MyGreeter", ServiceType.SERVICE, List.of( - new HandlerDefinition<>( - "greet", - HandlerType.EXCLUSIVE, - false, - "application/json", - "application/json", + HandlerDefinition.of( + HandlerSpecification.of( + "greet", HandlerType.EXCLUSIVE, CoreSerdes.VOID, CoreSerdes.VOID), null))))); DeploymentManifestSchema manifest = deploymentManifest.manifest(); diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java b/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java index b96d5289..74003448 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java @@ -106,6 +106,10 @@ public static Protocol.InputEntryMessage inputMessage() { return Protocol.InputEntryMessage.newBuilder().setValue(ByteString.EMPTY).build(); } + public static Protocol.InputEntryMessage inputMessage(byte[] value) { + return Protocol.InputEntryMessage.newBuilder().setValue(ByteString.copyFrom(value)).build(); + } + public static Protocol.InputEntryMessage inputMessage(Serde serde, T value) { return Protocol.InputEntryMessage.newBuilder() .setValue(serde.serializeToByteString(value)) @@ -134,6 +138,10 @@ public static Protocol.OutputEntryMessage outputMessage(int value) { return outputMessage(CoreSerdes.JSON_INT, value); } + public static Protocol.OutputEntryMessage outputMessage(byte[] b) { + return outputMessage(CoreSerdes.RAW, b); + } + public static Protocol.OutputEntryMessage outputMessage() { return Protocol.OutputEntryMessage.newBuilder().setValue(ByteString.EMPTY).build(); } @@ -202,6 +210,10 @@ public static Protocol.CallEntryMessage.Builder invokeMessage(Target target) { return builder; } + public static Protocol.CallEntryMessage.Builder invokeMessage(Target target, byte[] parameter) { + return invokeMessage(target, CoreSerdes.RAW, parameter); + } + public static Protocol.CallEntryMessage.Builder invokeMessage( Target target, Serde reqSerde, T parameter) { return invokeMessage(target).setParameter(reqSerde.serializeToByteString(parameter)); diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java b/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java index 9d0d4b22..802f637c 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/TestDefinitions.java @@ -236,7 +236,9 @@ private ExpectingOutputMessages( method, input, onlyUnbuffered, - service != null ? service.definitions().get(0).getServiceName() : "Unknown"); + service != null + ? service.definitions().get(0).getServiceName() + "#" + method + : "Unknown"); this.messagesAssert = messagesAssert; } diff --git a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/VertxExecutorsTest.kt b/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/VertxExecutorsTest.kt index 0040a7b6..ff8136a1 100644 --- a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/VertxExecutorsTest.kt +++ b/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/VertxExecutorsTest.kt @@ -10,6 +10,7 @@ package dev.restate.sdk.http.vertx import com.google.protobuf.ByteString import dev.restate.generated.service.protocol.Protocol +import dev.restate.sdk.Service import dev.restate.sdk.common.CoreSerdes import dev.restate.sdk.core.ProtoUtils.* import dev.restate.sdk.core.TestDefinitions @@ -82,10 +83,11 @@ class VertxExecutorsTest : TestDefinitions.TestSuite { testInvocation( dev.restate.sdk.Service.service("CheckBlockingComponentTrampolineExecutor") .with( - dev.restate.sdk.Service.HandlerSignature.of( - "do", CoreSerdes.VOID, CoreSerdes.VOID), + "do", + CoreSerdes.VOID, + CoreSerdes.VOID, this::checkBlockingComponentTrampolineExecutor) - .build(dev.restate.sdk.Service.Options.DEFAULT), + .build(Service.Options.DEFAULT), "do") .withInput(startMessage(1), inputMessage(), ackMessage(1)) .onlyUnbuffered() diff --git a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java b/sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java index 6333ef2c..5a5a8f80 100644 --- a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java +++ b/sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java @@ -23,7 +23,7 @@ import dev.restate.sdk.core.ProtoUtils; import dev.restate.sdk.core.manifest.DeploymentManifestSchema; import dev.restate.sdk.core.manifest.Service; -import dev.restate.sdk.lambda.testservices.JavaCounterClient; +import dev.restate.sdk.lambda.testservices.JavaCounterDefinitions; import dev.restate.sdk.lambda.testservices.MyServicesHandler; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -35,7 +35,7 @@ class LambdaHandlerTest { - @ValueSource(strings = {JavaCounterClient.SERVICE_NAME, "KtCounter"}) + @ValueSource(strings = {JavaCounterDefinitions.SERVICE_NAME, "KtCounter"}) @ParameterizedTest public void testInvoke(String serviceName) throws IOException { MyServicesHandler handler = new MyServicesHandler(); @@ -97,7 +97,7 @@ public void testDiscovery() throws IOException { assertThat(discoveryResponse.getServices()) .map(Service::getName) - .containsOnly(JavaCounterClient.SERVICE_NAME, "KtCounter"); + .containsOnly(JavaCounterDefinitions.SERVICE_NAME, "KtCounter"); } private static byte[] serializeEntries(MessageLite... msgs) throws IOException { diff --git a/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/WorkflowBuilder.java b/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/WorkflowBuilder.java index da2749bf..3a2a27de 100644 --- a/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/WorkflowBuilder.java +++ b/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/WorkflowBuilder.java @@ -11,6 +11,9 @@ import dev.restate.sdk.Service; import dev.restate.sdk.common.BindableService; import dev.restate.sdk.common.HandlerType; +import dev.restate.sdk.common.Serde; +import dev.restate.sdk.common.syscalls.HandlerDefinition; +import dev.restate.sdk.common.syscalls.HandlerSpecification; import dev.restate.sdk.workflow.impl.WorkflowImpl; import java.util.HashMap; import java.util.function.BiFunction; @@ -18,18 +21,25 @@ public final class WorkflowBuilder { private final String name; - private final Service.Handler workflowMethod; - private final HashMap> sharedMethods; + private final HandlerDefinition workflowMethod; + private final HashMap> sharedMethods; - private WorkflowBuilder(String name, Service.Handler workflowMethod) { + private WorkflowBuilder(String name, HandlerDefinition workflowMethod) { this.name = name; this.workflowMethod = workflowMethod; this.sharedMethods = new HashMap<>(); } public WorkflowBuilder withShared( - Service.HandlerSignature sig, BiFunction runner) { - this.sharedMethods.put(sig.getName(), new Service.Handler<>(sig, HandlerType.SHARED, runner)); + String name, + Serde requestSerde, + Serde responseSerde, + BiFunction runner) { + this.sharedMethods.put( + name, + HandlerDefinition.of( + HandlerSpecification.of(name, HandlerType.SHARED, requestSerde, responseSerde), + Service.Handler.of(runner))); return this; } @@ -39,8 +49,13 @@ public BindableService build(Service.Options options) { public static WorkflowBuilder named( String name, - Service.HandlerSignature sig, + Serde requestSerde, + Serde responseSerde, BiFunction runner) { - return new WorkflowBuilder(name, new Service.Handler<>(sig, HandlerType.SHARED, runner)); + return new WorkflowBuilder( + name, + HandlerDefinition.of( + HandlerSpecification.of("run", HandlerType.EXCLUSIVE, requestSerde, responseSerde), + Service.Handler.of(runner))); } } diff --git a/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowImpl.java b/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowImpl.java index e5484dd4..23d658d4 100644 --- a/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowImpl.java +++ b/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowImpl.java @@ -13,8 +13,9 @@ import dev.restate.sdk.Context; import dev.restate.sdk.ObjectContext; import dev.restate.sdk.Service; -import dev.restate.sdk.Service.HandlerSignature; import dev.restate.sdk.common.*; +import dev.restate.sdk.common.syscalls.HandlerDefinition; +import dev.restate.sdk.common.syscalls.HandlerSpecification; import dev.restate.sdk.common.syscalls.ServiceDefinition; import dev.restate.sdk.serde.jackson.JacksonSerdes; import dev.restate.sdk.serde.protobuf.ProtobufSerdes; @@ -61,14 +62,14 @@ public class WorkflowImpl implements BindableService { private final String name; private final Service.Options options; - private final Service.Handler workflowMethod; - private final HashMap> sharedHandlers; + private final HandlerDefinition workflowMethod; + private final HashMap> sharedHandlers; public WorkflowImpl( String name, Service.Options options, - Service.Handler workflowMethod, - HashMap> sharedHandlers) { + HandlerDefinition workflowMethod, + HashMap> sharedHandlers) { this.name = name; this.options = options; this.workflowMethod = workflowMethod; @@ -103,7 +104,7 @@ private void internalStart(Context context, InvokeRequest invokeRequest) { // Convert input Object input = this.workflowMethod - .getHandlerSignature() + .getSpec() .getRequestSerde() .deserialize(invokeRequest.getPayload().toString().getBytes(StandardCharsets.UTF_8)); @@ -111,12 +112,13 @@ private void internalStart(Context context, InvokeRequest invokeRequest) { WorkflowContext ctx = new WorkflowContextImpl(context, name, invokeRequest.getKey(), true); @SuppressWarnings("unchecked") Object output = - ((BiFunction) this.workflowMethod.getRunner()).apply(ctx, input); + ((Service.Handler) this.workflowMethod.getHandler()) + .getRunner() + .apply(ctx, input); //noinspection unchecked valueOutput = - ((Serde) this.workflowMethod.getHandlerSignature().getResponseSerde()) - .serialize(output); + ((Serde) this.workflowMethod.getSpec().getResponseSerde()).serialize(output); } catch (TerminalException e) { // Intercept TerminalException to record it context.send( @@ -144,8 +146,8 @@ private void internalStart(Context context, InvokeRequest invokeRequest) { private byte[] invokeSharedMethod(String handlerName, Context context, InvokeRequest request) { // Lookup the method @SuppressWarnings("unchecked") - Service.Handler method = - (Service.Handler) sharedHandlers.get(handlerName); + HandlerDefinition method = + (HandlerDefinition) sharedHandlers.get(handlerName); if (method == null) { throw new TerminalException(404, "Method " + handlerName + " not found"); } @@ -153,16 +155,17 @@ private byte[] invokeSharedMethod(String handlerName, Context context, InvokeReq // Convert input Object input = method - .getHandlerSignature() + .getSpec() .getRequestSerde() .deserialize(request.getPayload().toString().getBytes(StandardCharsets.UTF_8)); // Invoke method WorkflowContext ctx = new WorkflowContextImpl(context, name, request.getKey(), false); // We let the sdk core to manage the failures - Object output = method.getRunner().apply(ctx, input); + Object output = + ((Service.Handler) method.getHandler()).getRunner().apply(ctx, input); - return method.getHandlerSignature().getResponseSerde().serialize(output); + return method.getSpec().getResponseSerde().serialize(output); } // --- Workflow manager methods @@ -329,84 +332,98 @@ public Service.Options options() { @Override public List> definitions() { // Prepare workflow service - Service.ServiceBuilder workflowBuilder = - Service.service(name) - .with( - HandlerSignature.of("submit", INVOKE_REQUEST_SERDE, WORKFLOW_EXECUTION_STATE_SERDE), - this::submit) - .with( - HandlerSignature.of(START_HANDLER, INVOKE_REQUEST_SERDE, CoreSerdes.VOID), - (context, invokeRequest) -> { - this.internalStart(context, invokeRequest); - return null; - }); + List> workflowHandlers = new ArrayList<>(); + workflowHandlers.add( + HandlerDefinition.of( + HandlerSpecification.of( + "submit", HandlerType.SHARED, INVOKE_REQUEST_SERDE, WORKFLOW_EXECUTION_STATE_SERDE), + Service.Handler.of(this::submit))); + workflowHandlers.add( + HandlerDefinition.of( + HandlerSpecification.of( + START_HANDLER, HandlerType.SHARED, INVOKE_REQUEST_SERDE, CoreSerdes.VOID), + Service.Handler.of(this::internalStart))); // Append shared methods - for (var sharedMethod : sharedHandlers.values()) { - workflowBuilder.with( - HandlerSignature.of( - sharedMethod.getHandlerSignature().getName(), INVOKE_REQUEST_SERDE, CoreSerdes.RAW), - (context, invokeRequest) -> - this.invokeSharedMethod( - sharedMethod.getHandlerSignature().getName(), context, invokeRequest)); + for (HandlerDefinition sharedMethod : sharedHandlers.values()) { + workflowHandlers.add( + HandlerDefinition.of( + HandlerSpecification.of( + sharedMethod.getSpec().getName(), + HandlerType.SHARED, + INVOKE_REQUEST_SERDE, + CoreSerdes.RAW), + Service.Handler.of( + (BiFunction) + (context, invokeRequest) -> + this.invokeSharedMethod( + sharedMethod.getSpec().getName(), context, invokeRequest)))); } // Prepare workflow manager service Service workflowManager = Service.virtualObject(workflowManagerObjectName(name)) .withExclusive( - HandlerSignature.of("getState", CoreSerdes.JSON_STRING, GET_STATE_RESPONSE_SERDE), - this::getState) + "getState", CoreSerdes.JSON_STRING, GET_STATE_RESPONSE_SERDE, this::getState) .withExclusive( - HandlerSignature.of("setState", SET_STATE_REQUEST_SERDE, CoreSerdes.VOID), + "setState", + SET_STATE_REQUEST_SERDE, + CoreSerdes.VOID, (context, setStateRequest) -> { this.setState(context, setStateRequest); return null; }) .withExclusive( - HandlerSignature.of("clearState", CoreSerdes.JSON_STRING, CoreSerdes.VOID), + "clearState", + CoreSerdes.JSON_STRING, + CoreSerdes.VOID, (context, s) -> { this.clearState(context, s); return null; }) .withExclusive( - HandlerSignature.of( - "waitDurablePromiseCompletion", - WAIT_DURABLE_PROMISE_COMPLETION_REQUEST_SERDE, - CoreSerdes.VOID), + "waitDurablePromiseCompletion", + WAIT_DURABLE_PROMISE_COMPLETION_REQUEST_SERDE, + CoreSerdes.VOID, (context, waitDurablePromiseCompletionRequest) -> { this.waitDurablePromiseCompletion(context, waitDurablePromiseCompletionRequest); return null; }) .withExclusive( - HandlerSignature.of( - "getDurablePromiseCompletion", - CoreSerdes.JSON_STRING, - MAYBE_DURABLE_PROMISE_COMPLETION_SERDE), + "getDurablePromiseCompletion", + CoreSerdes.JSON_STRING, + MAYBE_DURABLE_PROMISE_COMPLETION_SERDE, this::getDurablePromiseCompletion) .withExclusive( - HandlerSignature.of( - "completeDurablePromise", - COMPLETE_DURABLE_PROMISE_REQUEST_SERDE, - CoreSerdes.VOID), + "completeDurablePromise", + COMPLETE_DURABLE_PROMISE_REQUEST_SERDE, + CoreSerdes.VOID, (context, completeDurablePromiseRequest) -> { this.completeDurablePromise(context, completeDurablePromiseRequest); return null; }) .withExclusive( - HandlerSignature.of("tryStart", CoreSerdes.VOID, WORKFLOW_EXECUTION_STATE_SERDE), + "tryStart", + CoreSerdes.VOID, + WORKFLOW_EXECUTION_STATE_SERDE, (context, unused) -> this.tryStart(context)) .withExclusive( - HandlerSignature.of("getOutput", CoreSerdes.VOID, GET_OUTPUT_RESPONSE_SERDE), + "getOutput", + CoreSerdes.VOID, + GET_OUTPUT_RESPONSE_SERDE, (context, unused) -> this.getOutput(context)) .withExclusive( - HandlerSignature.of("setOutput", SET_OUTPUT_REQUEST_SERDE, CoreSerdes.VOID), + "setOutput", + SET_OUTPUT_REQUEST_SERDE, + CoreSerdes.VOID, (context, setOutputRequest) -> { this.setOutput(context, setOutputRequest); return null; }) .withExclusive( - HandlerSignature.of("cleanup", CoreSerdes.VOID, CoreSerdes.VOID), + "cleanup", + CoreSerdes.VOID, + CoreSerdes.VOID, (context, unused) -> { this.cleanup(context); return null; @@ -414,6 +431,7 @@ public List> definitions() { .build(options); return List.of( - workflowBuilder.build(options).definitions().get(0), workflowManager.definitions().get(0)); + ServiceDefinition.of(name, ServiceType.SERVICE, workflowHandlers), + workflowManager.definitions().get(0)); } }