diff --git a/examples/async/src/app/sse-infinite/StreamedTimeSSE.tsx b/examples/async/src/app/sse-infinite/StreamedTimeSSE.tsx index 3660860..c153642 100644 --- a/examples/async/src/app/sse-infinite/StreamedTimeSSE.tsx +++ b/examples/async/src/app/sse-infinite/StreamedTimeSSE.tsx @@ -11,6 +11,7 @@ export function StreamedTimeSSE() { useEffect(() => { const abortSignal = new AbortController(); createEventSource("/sse-infinite", { + reconnect: true, signal: abortSignal.signal, }) .then(async (shape) => { diff --git a/src/async/deserializeAsync.ts b/src/async/deserializeAsync.ts index 61d8ad9..f7cdfac 100644 --- a/src/async/deserializeAsync.ts +++ b/src/async/deserializeAsync.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-unsafe-assignment */ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import { TsonError } from "../errors.js"; @@ -31,10 +32,20 @@ type AnyTsonTransformerSerializeDeserialize = | TsonTransformerSerializeDeserialize; export interface TsonParseAsyncOptions { + /** + * Event handler for when the stream reconnects + * You can use this to do extra actions to ensure no messages were lost + */ + onReconnect?: () => void; /** * On stream error */ onStreamError?: (err: TsonStreamInterruptedError) => void; + /** + * Allow reconnecting to the stream if it's interrupted + * @default false + */ + reconnect?: boolean; } type TsonParseAsync = ( @@ -62,10 +73,11 @@ function createTsonDeserializer(opts: TsonAsyncOptions) { iterable: TsonDeserializeIterable, parseOptions: TsonParseAsyncOptions, ) => { - const cache = new Map< + const controllers = new Map< TsonAsyncIndex, ReadableStreamDefaultController >(); + const cache = new Map(); const iterator = iterable[Symbol.asyncIterator](); const walker: WalkerFactory = (nonce) => { @@ -83,6 +95,15 @@ function createTsonDeserializer(opts: TsonAsyncOptions) { const idx = serializedValue as TsonAsyncIndex; + if (cache.has(idx)) { + // We already have this async value in the cache - so this is probably a reconnect + assert( + parseOptions.reconnect, + "Duplicate index found but reconnect is off", + ); + return cache.get(idx); + } + const [readable, controller] = createReadableStream(); // the `start` method is called "immediately when the object is constructed" @@ -90,15 +111,18 @@ function createTsonDeserializer(opts: TsonAsyncOptions) { // so we're guaranteed that the controller is set in the cache assert(controller, "Controller not set - this is a bug"); - cache.set(idx, controller); + controllers.set(idx, controller); - return transformer.deserialize({ + const result = transformer.deserialize({ close() { controller.close(); - cache.delete(idx); + controllers.delete(idx); }, reader: readable.getReader(), }); + + cache.set(idx, result); + return result; } return mapOrReturn(value, walk); @@ -117,16 +141,33 @@ function createTsonDeserializer(opts: TsonAsyncOptions) { const { value } = nextValue; + if (!Array.isArray(value)) { + // we got the beginning of a new stream - probably because a reconnect + // we assume this new stream will have the same shape and restart the walker with the nonce + + parseOptions.onReconnect?.(); + + assert( + parseOptions.reconnect, + "Stream got beginning of results but reconnecting is not enabled", + ); + + await getStreamedValues(walker(value.nonce)); + return; + } + const [index, result] = value as TsonAsyncValueTuple; - const controller = cache.get(index); + const controller = controllers.get(index); const walkedResult = walk(result); - assert(controller, `No stream found for index ${index}`); + if (!parseOptions.reconnect) { + assert(controller, `No stream found for index ${index}`); + } // resolving deferred - controller.enqueue(walkedResult); + controller?.enqueue(walkedResult); } } @@ -152,7 +193,7 @@ function createTsonDeserializer(opts: TsonAsyncOptions) { const err = new TsonStreamInterruptedError(cause); // enqueue the error to all the streams - for (const controller of cache.values()) { + for (const controller of controllers.values()) { controller.enqueue(err); } diff --git a/src/async/sse.test.ts b/src/async/sse.test.ts index e1228d4..24a21ef 100644 --- a/src/async/sse.test.ts +++ b/src/async/sse.test.ts @@ -1,9 +1,14 @@ /* eslint-disable @typescript-eslint/no-unnecessary-condition */ import { EventSourcePolyfill, NativeEventSource } from "event-source-polyfill"; -import { expect, test } from "vitest"; +import { expect, test, vi } from "vitest"; (global as any).EventSource = NativeEventSource || EventSourcePolyfill; -import { TsonAsyncOptions, tsonAsyncIterable, tsonPromise } from "../index.js"; +import { + TsonAsyncOptions, + tsonAsyncIterable, + tsonBigint, + tsonPromise, +} from "../index.js"; import { createTestServer, sleep } from "../internals/testUtils.js"; import { createTsonAsync } from "./createTsonAsync.js"; @@ -13,15 +18,12 @@ test("SSE response test", async () => { let i = 0; while (true) { yield i++; - await sleep(100); + await sleep(10); } } return { - foo: "bar", iterable: generator(), - promise: Promise.resolve(42), - rejectedPromise: Promise.reject(new Error("rejected promise")), }; } @@ -73,14 +75,14 @@ test("SSE response test", async () => { }); expect(messages).toMatchInlineSnapshot(` - [ - "{\\"json\\":{\\"foo\\":\\"bar\\",\\"iterable\\":[\\"AsyncIterable\\",0,\\"__tson\\"],\\"promise\\":[\\"Promise\\",1,\\"__tson\\"],\\"rejectedPromise\\":[\\"Promise\\",2,\\"__tson\\"]},\\"nonce\\":\\"__tson\\"}", - "[0,[0,0]]", - "[1,[0,42]]", - "[2,[1,{}]]", - "[0,[0,1]]", - ] - `); + [ + "{\\"json\\":{\\"iterable\\":[\\"AsyncIterable\\",0,\\"__tson\\"]},\\"nonce\\":\\"__tson\\"}", + "[0,[0,0]]", + "[0,[0,1]]", + "[0,[0,2]]", + "[0,[0,3]]", + ] + `); } { @@ -110,3 +112,99 @@ test("SSE response test", async () => { `); } }); + +test("handle reconnects - iterator wrapped in Promise", async () => { + let i = 0; + + let kill = false; + function createMockObj() { + async function* generator() { + while (true) { + await sleep(10); + yield BigInt(i); + i++; + + if (i % 5 === 0) { + kill = true; + } + + if (i > 11) { + // done + return; + } + } + } + + return { + iterable: Promise.resolve(generator()), + }; + } + + type MockObj = ReturnType; + + // ------------- server ------------------- + const opts = { + nonce: () => "__tson" + i, // add index to nonce to make sure it's not cached + types: [tsonPromise, tsonAsyncIterable, tsonBigint], + } satisfies TsonAsyncOptions; + + const server = await createTestServer({ + handleRequest: async (_req, res) => { + const tson = createTsonAsync(opts); + + const obj = createMockObj(); + const response = tson.toSSEResponse(obj); + + for (const [key, value] of response.headers) { + res.setHeader(key, value); + } + + for await (const value of response.body as any) { + res.write(value); + if (kill) { + // interrupt the stream + res.end(); + kill = false; + return; + } + } + + res.end(); + }, + }); + + // ------------- client ------------------- + const tson = createTsonAsync(opts); + + // e2e + const ac = new AbortController(); + const onReconnect = vi.fn(); + const shape = await tson.createEventSource(server.url, { + onReconnect, + reconnect: true, + signal: ac.signal, + }); + + const messages: bigint[] = []; + + for await (const value of await shape.iterable) { + messages.push(value); + } + + expect(messages).toMatchInlineSnapshot(` + [ + 0n, + 1n, + 2n, + 3n, + 4n, + 6n, + 7n, + 8n, + 9n, + 11n, + ] + `); + + expect(onReconnect).toHaveBeenCalledTimes(2); +}); diff --git a/vitest.config.ts b/vitest.config.ts index 27b62ac..77279c4 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -10,6 +10,9 @@ export default defineConfig({ reporter: ["html", "lcov"], }, exclude: ["lib", "node_modules", "examples", "benchmark"], - setupFiles: ["console-fail-test/setup"], + setupFiles: [ + // this is useful to comment out sometimes + "console-fail-test/setup", + ], }, });