diff --git a/.changeset/tidy-planes-pay.md b/.changeset/tidy-planes-pay.md new file mode 100644 index 000000000..0986ade89 --- /dev/null +++ b/.changeset/tidy-planes-pay.md @@ -0,0 +1,5 @@ +--- +"@xmtp/node-sdk": patch +--- + +Refactor AsyncStream diff --git a/sdks/node-sdk/src/AsyncStream.ts b/sdks/node-sdk/src/AsyncStream.ts index 4065d8b49..e0f488299 100644 --- a/sdks/node-sdk/src/AsyncStream.ts +++ b/sdks/node-sdk/src/AsyncStream.ts @@ -1,7 +1,6 @@ type ResolveValue = { value: T | undefined; done: boolean; - error: Error | null; }; type ResolveNext = (resolveValue: ResolveValue) => void; @@ -13,7 +12,7 @@ export class AsyncStream { #resolveNext: ResolveNext | null; #queue: T[]; - stopCallback: (() => void) | undefined = undefined; + onReturn: (() => void) | undefined = undefined; constructor() { this.#queue = []; @@ -27,9 +26,7 @@ export class AsyncStream { callback: StreamCallback = (error, value) => { if (error) { - console.error("stream error", error); - this.stop(error); - return; + throw error; } if (this.#done) { @@ -39,7 +36,6 @@ export class AsyncStream { if (this.#resolveNext) { this.#resolveNext({ done: false, - error: null, value, }); this.#resolveNext = null; @@ -48,29 +44,15 @@ export class AsyncStream { } }; - stop = (error?: Error) => { - this.#done = true; - if (this.#resolveNext) { - this.#resolveNext({ - done: true, - error: error ?? null, - value: undefined, - }); - } - this.stopCallback?.(); - }; - next = (): Promise> => { if (this.#queue.length > 0) { return Promise.resolve({ done: false, - error: null, value: this.#queue.shift(), }); } else if (this.#done) { return Promise.resolve({ done: true, - error: null, value: undefined, }); } else { @@ -80,6 +62,15 @@ export class AsyncStream { } }; + return = (value: T) => { + this.#done = true; + this.onReturn?.(); + return Promise.resolve({ + done: true, + value, + }); + }; + [Symbol.asyncIterator]() { return this; } diff --git a/sdks/node-sdk/src/Conversation.ts b/sdks/node-sdk/src/Conversation.ts index f33da7d82..f5d8715bc 100644 --- a/sdks/node-sdk/src/Conversation.ts +++ b/sdks/node-sdk/src/Conversation.ts @@ -119,7 +119,7 @@ export class Conversation { callback?.(err, decodedMessage); }); - asyncStream.stopCallback = stream.end.bind(stream); + asyncStream.onReturn = stream.end.bind(stream); return asyncStream; } diff --git a/sdks/node-sdk/src/Conversations.ts b/sdks/node-sdk/src/Conversations.ts index 71b59ddfb..295e9a87b 100644 --- a/sdks/node-sdk/src/Conversations.ts +++ b/sdks/node-sdk/src/Conversations.ts @@ -106,7 +106,7 @@ export class Conversations { callback?.(err, conversation); }); - asyncStream.stopCallback = stream.end.bind(stream); + asyncStream.onReturn = stream.end.bind(stream); return asyncStream; } @@ -120,7 +120,7 @@ export class Conversations { callback?.(err, conversation); }); - asyncStream.stopCallback = stream.end.bind(stream); + asyncStream.onReturn = stream.end.bind(stream); return asyncStream; } @@ -134,7 +134,7 @@ export class Conversations { callback?.(err, conversation); }); - asyncStream.stopCallback = stream.end.bind(stream); + asyncStream.onReturn = stream.end.bind(stream); return asyncStream; } @@ -151,7 +151,7 @@ export class Conversations { callback?.(err, decodedMessage); }); - asyncStream.stopCallback = stream.end.bind(stream); + asyncStream.onReturn = stream.end.bind(stream); return asyncStream; } @@ -170,7 +170,7 @@ export class Conversations { }, ); - asyncStream.stopCallback = stream.end.bind(stream); + asyncStream.onReturn = stream.end.bind(stream); return asyncStream; } @@ -187,7 +187,7 @@ export class Conversations { callback?.(err, decodedMessage); }); - asyncStream.stopCallback = stream.end.bind(stream); + asyncStream.onReturn = stream.end.bind(stream); return asyncStream; } diff --git a/sdks/node-sdk/test/AsyncStream.test.ts b/sdks/node-sdk/test/AsyncStream.test.ts new file mode 100644 index 000000000..81939e0dc --- /dev/null +++ b/sdks/node-sdk/test/AsyncStream.test.ts @@ -0,0 +1,76 @@ +import { describe, expect, it } from "vitest"; +import { AsyncStream } from "@/AsyncStream"; + +const testError = new Error("test"); + +describe("AsyncStream", () => { + it("should return values from callbacks", async () => { + const stream = new AsyncStream(); + let onReturnCalled = false; + stream.onReturn = () => { + onReturnCalled = true; + }; + stream.callback(null, 1); + stream.callback(null, 2); + stream.callback(null, 3); + + let count = 0; + + for await (const value of stream) { + if (count === 0) { + expect(value).toBe(1); + } + if (count === 1) { + expect(value).toBe(2); + } + if (count === 2) { + expect(value).toBe(3); + break; + } + count++; + } + expect(onReturnCalled).toBe(true); + }); + + it("should catch an error thrown in the for..await loop", async () => { + const stream = new AsyncStream(); + let onReturnCalled = false; + stream.onReturn = () => { + onReturnCalled = true; + }; + stream.callback(null, 1); + + try { + for await (const value of stream) { + expect(value).toBe(1); + throw testError; + } + } catch (error) { + expect(error).toBe(testError); + expect((error as Error).message).toBe("test"); + } + expect(onReturnCalled).toBe(true); + }); + + it("should catch an error passed to callback", async () => { + const runTest = async () => { + const stream = new AsyncStream(); + let onReturnCalled = false; + stream.onReturn = () => { + onReturnCalled = true; + }; + stream.callback(testError, 1); + try { + for await (const _value of stream) { + // this block should never be reached + } + } catch (error) { + expect(error).toBe(testError); + expect((error as Error).message).toBe("test"); + } + expect(onReturnCalled).toBe(true); + }; + + await expect(runTest()).rejects.toThrow(testError); + }); +}); diff --git a/sdks/node-sdk/test/Conversation.test.ts b/sdks/node-sdk/test/Conversation.test.ts index 93843eeb1..b2aced232 100644 --- a/sdks/node-sdk/test/Conversation.test.ts +++ b/sdks/node-sdk/test/Conversation.test.ts @@ -337,7 +337,6 @@ describe("Conversation", () => { break; } } - stream.stop(); }); it("should add and remove admins", async () => { diff --git a/sdks/node-sdk/test/Conversations.test.ts b/sdks/node-sdk/test/Conversations.test.ts index 403049672..b041d8bd7 100644 --- a/sdks/node-sdk/test/Conversations.test.ts +++ b/sdks/node-sdk/test/Conversations.test.ts @@ -3,8 +3,6 @@ import { NapiGroupPermissionsOptions, } from "@xmtp/node-bindings"; import { describe, expect, it } from "vitest"; -import { AsyncStream } from "@/AsyncStream"; -import type { Conversation } from "@/Conversation"; import { createRegisteredClient, createUser } from "@test/helpers"; describe("Conversations", () => { @@ -287,7 +285,6 @@ describe("Conversations", () => { break; } } - stream.stop(); expect( client3.conversations.getConversationById(conversation1.id)?.id, ).toBe(conversation1.id); @@ -305,8 +302,7 @@ describe("Conversations", () => { const client2 = await createRegisteredClient(user2); const client3 = await createRegisteredClient(user3); const client4 = await createRegisteredClient(user4); - const asyncStream = new AsyncStream(); - const stream = client3.conversations.streamGroups(asyncStream.callback); + const stream = client3.conversations.streamGroups(); await client4.conversations.newDm(user3.account.address); const group1 = await client1.conversations.newConversation([ user3.account.address, @@ -315,7 +311,7 @@ describe("Conversations", () => { user3.account.address, ]); let count = 0; - for await (const convo of asyncStream) { + for await (const convo of stream) { count++; expect(convo).toBeDefined(); if (count === 1) { @@ -326,7 +322,6 @@ describe("Conversations", () => { break; } } - stream.stop(); }); it("should only stream dm conversations", async () => { @@ -338,13 +333,12 @@ describe("Conversations", () => { const client2 = await createRegisteredClient(user2); const client3 = await createRegisteredClient(user3); const client4 = await createRegisteredClient(user4); - const asyncStream = new AsyncStream(); - const stream = client3.conversations.streamDms(asyncStream.callback); + const stream = client3.conversations.streamDms(); await client1.conversations.newConversation([user3.account.address]); await client2.conversations.newConversation([user3.account.address]); const group3 = await client4.conversations.newDm(user3.account.address); let count = 0; - for await (const convo of asyncStream) { + for await (const convo of stream) { count++; expect(convo).toBeDefined(); if (count === 1) { @@ -353,7 +347,6 @@ describe("Conversations", () => { } } expect(count).toBe(1); - stream.stop(); }); it("should stream all messages", async () => { @@ -390,7 +383,6 @@ describe("Conversations", () => { break; } } - stream.stop(); }); it("should only stream group conversation messages", async () => { @@ -437,7 +429,6 @@ describe("Conversations", () => { break; } } - stream.stop(); }); it("should only stream dm messages", async () => { @@ -481,6 +472,5 @@ describe("Conversations", () => { break; } } - stream.stop(); }); });