Skip to content

Commit

Permalink
Fix serialization of Java objects in step results/errors (#41)
Browse files Browse the repository at this point in the history
Step/Function result classes are no longer required to be public
with strictly public properties to serialize correctly when
sending response payloads to Inngest.
  • Loading branch information
KiKoS0 authored Feb 28, 2024
1 parent 5492d1d commit 3392c2d
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 15 deletions.
14 changes: 10 additions & 4 deletions inngest-core/src/main/kotlin/com/inngest/Comm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.inngest
import com.beust.klaxon.Json
import com.beust.klaxon.Klaxon
import com.inngest.signingkey.getAuthorizationHeader
import com.fasterxml.jackson.databind.ObjectMapper
import java.io.IOException

data class ExecutionRequestPayload(
Expand Down Expand Up @@ -83,7 +84,7 @@ class CommHandler(
body = result.data
}
return CommResponse(
body = Klaxon().toJsonString(body),
body = parseRequestBody(body),
statusCode = result.statusCode,
headers = headers,
)
Expand All @@ -98,13 +99,18 @@ class CommHandler(
stack = e.stackTrace.joinToString(separator = "\n"),
)
return CommResponse(
body = Klaxon().toJsonString(err),
body = parseRequestBody(err),
statusCode = statusCode,
headers = headers.plus(retryDecision.headers),
)
}
}

private fun parseRequestBody(requestBody: Any?): String {
val mapper = ObjectMapper()
return mapper.writeValueAsString(requestBody)
}

private fun getFunctionConfigs(): List<FunctionConfig> {
val configs: MutableList<FunctionConfig> = mutableListOf()
functions.forEach { entry -> configs.add(entry.value.getFunctionConfig(getServeUrl())) }
Expand Down Expand Up @@ -133,7 +139,7 @@ class CommHandler(

// TODO - Add headers to output
val body: Map<String, Any?> = mapOf()
return Klaxon().toJsonString(body)
return parseRequestBody(body)
}

fun sync(): Result<InngestSyncResult> {
Expand All @@ -142,7 +148,7 @@ class CommHandler(

fun introspect(): String {
val requestPayload = getRegistrationRequestPayload()
return Klaxon().toJsonString(requestPayload)
return parseRequestBody(requestPayload)
}

private fun getRegistrationRequestPayload(): RegistrationRequestPayload {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.inngest.springbootdemo;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import okhttp3.Request;

Expand Down Expand Up @@ -52,12 +53,28 @@ EventRunsResponse<Object> runsByEvent(String eventId) throws Exception {
});
}

<T> RunResponse<T> runById(String eventId) throws Exception {
<T> RunResponse<T> runById(String eventId, Class<T> outputType) throws Exception {
Request request = new Request.Builder()
.url(String.format("%s/v1/runs/%S", baseUrl, eventId))
.build();
return makeRequest(request, new TypeReference<RunResponse<T>>() {
});
try (Response response = httpClient.newCall(request).execute()) {
if (response.code() == 200) {
assert response.body() != null;

String strResponse = response.body().string();
ObjectMapper mapper = new ObjectMapper();

JsonNode node = mapper.readTree(strResponse);
JsonNode dataResult = node.path("data").path("output");

T output = mapper.treeToValue(dataResult, outputType);
RunResponse<T> result = mapper.readValue(strResponse, new TypeReference<RunResponse<T>>() {
});
result.getData().setOutput(output);
return result;
}
}
return null;
}

private <T> T makeRequest(Request request, TypeReference<T> typeReference) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package com.inngest.springbootdemo;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

public class Result {
@JsonProperty("sum")
public final int sum;
@Getter
@Setter
@JsonIgnoreProperties(ignoreUnknown = true)
@NoArgsConstructor
class Result {
int sum;

public Result(@JsonProperty("sum") int sum) {
Result(int sum) {
this.sum = sum;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.inngest.springbootdemo;

import com.inngest.CommHandler;
import com.inngest.Inngest;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import org.springframework.beans.factory.annotation.Autowired;

import static org.junit.jupiter.api.Assertions.assertEquals;

@IntegrationTest
@Execution(ExecutionMode.CONCURRENT)
class CustomStepResultIntegrationTest {
@BeforeAll
static void setup(@Autowired CommHandler handler) {
handler.register();
}

@Autowired
private DevServerComponent devServer;

static int sleepTime = 5000;

@Autowired
private Inngest client;


@Test
void testMultiStepsFunctionWithClassResultStep() throws Exception {
String eventId = InngestFunctionTestHelpers.sendEvent(client, "test/custom.result.step").first();

Thread.sleep(sleepTime);

RunEntry<Object> run = devServer.runsByEvent(eventId).first();
RunEntry<Result> runWithOutput = devServer.runById(run.getRun_id(), Result.class).getData();

assertEquals(runWithOutput.getStatus(), "Completed");
assertEquals(runWithOutput.getOutput().getSum(), (new Result(5).getSum()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ protected HashMap<String, InngestFunction> functions() {
addInngestFunction(functions, InngestFunctionTestHelpers.emptyStepFunction());
addInngestFunction(functions, InngestFunctionTestHelpers.sleepStepFunction());
addInngestFunction(functions, InngestFunctionTestHelpers.twoStepsFunction());
addInngestFunction(functions, InngestFunctionTestHelpers.customStepResultFunction());
addInngestFunction(functions, InngestFunctionTestHelpers.waitForEventFunction());
addInngestFunction(functions, InngestFunctionTestHelpers.sendEventFunction());
addInngestFunction(functions, InngestFunctionTestHelpers.nonRetriableErrorFunction());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,29 @@ static InngestFunction twoStepsFunction() {
return new InngestFunction(fnConfig, handler);
}

static InngestFunction customStepResultFunction() {
FunctionTrigger fnTrigger = new FunctionTrigger("test/custom.result.step");
FunctionTrigger[] triggers = {fnTrigger};
FunctionOptions fnConfig = new FunctionOptions("custom-result-fn", "Custom Result Function", triggers);

int count = 0;

BiFunction<FunctionContext, Step, Result> handler = (ctx, step) -> {
int step1 = step.run("step1", () -> count + 1, Integer.class);
int tmp1 = step1 + 1;

int step2 = step.run("step2", () -> tmp1 + 1, Integer.class);
int tmp2 = step2 + 1;

return step.run("cast-to-type-add-one", () -> {
System.out.println("-> running step 1!! " + tmp2);
return new Result(tmp2 + 1);
}, Result.class);
};

return new InngestFunction(fnConfig, handler);
}

static InngestFunction waitForEventFunction() {
FunctionTrigger fnTrigger = new FunctionTrigger("test/wait-for-event");
FunctionTrigger[] triggers = {fnTrigger};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void testSleepFunctionRunningSuccessfully() throws Exception {

Thread.sleep(10000);

RunEntry<Integer> updatedRun = devServer.<Integer>runById(run.getRun_id()).getData();
RunEntry<Integer> updatedRun = devServer.runById(run.getRun_id(), Integer.class).getData();

assertEquals(updatedRun.getEvent_id(), eventId);
assertEquals(updatedRun.getStatus(), "Completed");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void testWaitForEventFunctionWhenFullFilled() throws Exception {

Thread.sleep(sleepTime);

RunEntry<Object> updatedRun = devServer.runById(run.getRun_id()).getData();
RunEntry<Object> updatedRun = devServer.runById(run.getRun_id(), Object.class).getData();

assertEquals(updatedRun.getEvent_id(), eventId);
assertEquals(updatedRun.getRun_id(), run.getRun_id());
Expand All @@ -65,7 +65,7 @@ void testWaitForEventFunctionWhenTimeOut() throws Exception {

Thread.sleep(sleepTime);

RunEntry<String> updatedRun = devServer.<String>runById(run.getRun_id()).getData();
RunEntry<String> updatedRun = devServer.runById(run.getRun_id(), String.class).getData();

assertEquals(updatedRun.getEvent_id(), eventId);
assertEquals(updatedRun.getRun_id(), run.getRun_id());
Expand Down

0 comments on commit 3392c2d

Please sign in to comment.