Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AsyncStream #704

Merged
merged 3 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/tidy-planes-pay.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@xmtp/node-sdk": patch
---

Refactor AsyncStream
31 changes: 11 additions & 20 deletions sdks/node-sdk/src/AsyncStream.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
type ResolveValue<T> = {
value: T | undefined;
done: boolean;
error: Error | null;
};

type ResolveNext<T> = (resolveValue: ResolveValue<T>) => void;
Expand All @@ -13,7 +12,7 @@ export class AsyncStream<T> {
#resolveNext: ResolveNext<T> | null;
#queue: T[];

stopCallback: (() => void) | undefined = undefined;
onReturn: (() => void) | undefined = undefined;

constructor() {
this.#queue = [];
Expand All @@ -27,9 +26,7 @@ export class AsyncStream<T> {

callback: StreamCallback<T> = (error, value) => {
if (error) {
console.error("stream error", error);
this.stop(error);
return;
throw error;
}

if (this.#done) {
Expand All @@ -39,7 +36,6 @@ export class AsyncStream<T> {
if (this.#resolveNext) {
this.#resolveNext({
done: false,
error: null,
value,
});
this.#resolveNext = null;
Expand All @@ -48,29 +44,15 @@ export class AsyncStream<T> {
}
};

stop = (error?: Error) => {
this.#done = true;
if (this.#resolveNext) {
this.#resolveNext({
done: true,
error: error ?? null,
value: undefined,
});
}
this.stopCallback?.();
};

next = (): Promise<ResolveValue<T>> => {
if (this.#queue.length > 0) {
return Promise.resolve({
done: false,
error: null,
value: this.#queue.shift(),
});
} else if (this.#done) {
return Promise.resolve({
done: true,
error: null,
value: undefined,
});
} else {
Expand All @@ -80,6 +62,15 @@ export class AsyncStream<T> {
}
};

return = (value: T) => {
this.#done = true;
this.onReturn?.();
return Promise.resolve({
done: true,
value,
});
};

[Symbol.asyncIterator]() {
return this;
}
Expand Down
2 changes: 1 addition & 1 deletion sdks/node-sdk/src/Conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ export class Conversation {
callback?.(err, decodedMessage);
});

asyncStream.stopCallback = stream.end.bind(stream);
asyncStream.onReturn = stream.end.bind(stream);

return asyncStream;
}
Expand Down
12 changes: 6 additions & 6 deletions sdks/node-sdk/src/Conversations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ export class Conversations {
callback?.(err, conversation);
});

asyncStream.stopCallback = stream.end.bind(stream);
asyncStream.onReturn = stream.end.bind(stream);

return asyncStream;
}
Expand All @@ -120,7 +120,7 @@ export class Conversations {
callback?.(err, conversation);
});

asyncStream.stopCallback = stream.end.bind(stream);
asyncStream.onReturn = stream.end.bind(stream);

return asyncStream;
}
Expand All @@ -134,7 +134,7 @@ export class Conversations {
callback?.(err, conversation);
});

asyncStream.stopCallback = stream.end.bind(stream);
asyncStream.onReturn = stream.end.bind(stream);

return asyncStream;
}
Expand All @@ -151,7 +151,7 @@ export class Conversations {
callback?.(err, decodedMessage);
});

asyncStream.stopCallback = stream.end.bind(stream);
asyncStream.onReturn = stream.end.bind(stream);

return asyncStream;
}
Expand All @@ -170,7 +170,7 @@ export class Conversations {
},
);

asyncStream.stopCallback = stream.end.bind(stream);
asyncStream.onReturn = stream.end.bind(stream);

return asyncStream;
}
Expand All @@ -187,7 +187,7 @@ export class Conversations {
callback?.(err, decodedMessage);
});

asyncStream.stopCallback = stream.end.bind(stream);
asyncStream.onReturn = stream.end.bind(stream);

return asyncStream;
}
Expand Down
76 changes: 76 additions & 0 deletions sdks/node-sdk/test/AsyncStream.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import { describe, expect, it } from "vitest";
import { AsyncStream } from "@/AsyncStream";

const testError = new Error("test");

describe("AsyncStream", () => {
it("should return values from callbacks", async () => {
const stream = new AsyncStream<number>();
let onReturnCalled = false;
stream.onReturn = () => {
onReturnCalled = true;
};
stream.callback(null, 1);
stream.callback(null, 2);
stream.callback(null, 3);

let count = 0;

for await (const value of stream) {
if (count === 0) {
expect(value).toBe(1);
}
if (count === 1) {
expect(value).toBe(2);
}
if (count === 2) {
expect(value).toBe(3);
break;
}
count++;
}
expect(onReturnCalled).toBe(true);
});

it("should catch an error thrown in the for..await loop", async () => {
const stream = new AsyncStream<number>();
let onReturnCalled = false;
stream.onReturn = () => {
onReturnCalled = true;
};
stream.callback(null, 1);

try {
for await (const value of stream) {
expect(value).toBe(1);
throw testError;
}
} catch (error) {
expect(error).toBe(testError);
expect((error as Error).message).toBe("test");
}
expect(onReturnCalled).toBe(true);
});

it("should catch an error passed to callback", async () => {
const runTest = async () => {
const stream = new AsyncStream<number>();
let onReturnCalled = false;
stream.onReturn = () => {
onReturnCalled = true;
};
stream.callback(testError, 1);
try {
for await (const _value of stream) {
// this block should never be reached
}
} catch (error) {
expect(error).toBe(testError);
expect((error as Error).message).toBe("test");
}
expect(onReturnCalled).toBe(true);
};

await expect(runTest()).rejects.toThrow(testError);
});
});
1 change: 0 additions & 1 deletion sdks/node-sdk/test/Conversation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,6 @@ describe("Conversation", () => {
break;
}
}
stream.stop();
});

it("should add and remove admins", async () => {
Expand Down
18 changes: 4 additions & 14 deletions sdks/node-sdk/test/Conversations.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ import {
NapiGroupPermissionsOptions,
} from "@xmtp/node-bindings";
import { describe, expect, it } from "vitest";
import { AsyncStream } from "@/AsyncStream";
import type { Conversation } from "@/Conversation";
import { createRegisteredClient, createUser } from "@test/helpers";

describe("Conversations", () => {
Expand Down Expand Up @@ -287,7 +285,6 @@ describe("Conversations", () => {
break;
}
}
stream.stop();
expect(
client3.conversations.getConversationById(conversation1.id)?.id,
).toBe(conversation1.id);
Expand All @@ -305,8 +302,7 @@ describe("Conversations", () => {
const client2 = await createRegisteredClient(user2);
const client3 = await createRegisteredClient(user3);
const client4 = await createRegisteredClient(user4);
const asyncStream = new AsyncStream<Conversation>();
const stream = client3.conversations.streamGroups(asyncStream.callback);
const stream = client3.conversations.streamGroups();
await client4.conversations.newDm(user3.account.address);
const group1 = await client1.conversations.newConversation([
user3.account.address,
Expand All @@ -315,7 +311,7 @@ describe("Conversations", () => {
user3.account.address,
]);
let count = 0;
for await (const convo of asyncStream) {
for await (const convo of stream) {
count++;
expect(convo).toBeDefined();
if (count === 1) {
Expand All @@ -326,7 +322,6 @@ describe("Conversations", () => {
break;
}
}
stream.stop();
});

it("should only stream dm conversations", async () => {
Expand All @@ -338,13 +333,12 @@ describe("Conversations", () => {
const client2 = await createRegisteredClient(user2);
const client3 = await createRegisteredClient(user3);
const client4 = await createRegisteredClient(user4);
const asyncStream = new AsyncStream<Conversation>();
const stream = client3.conversations.streamDms(asyncStream.callback);
const stream = client3.conversations.streamDms();
await client1.conversations.newConversation([user3.account.address]);
await client2.conversations.newConversation([user3.account.address]);
const group3 = await client4.conversations.newDm(user3.account.address);
let count = 0;
for await (const convo of asyncStream) {
for await (const convo of stream) {
count++;
expect(convo).toBeDefined();
if (count === 1) {
Expand All @@ -353,7 +347,6 @@ describe("Conversations", () => {
}
}
expect(count).toBe(1);
stream.stop();
});

it("should stream all messages", async () => {
Expand Down Expand Up @@ -390,7 +383,6 @@ describe("Conversations", () => {
break;
}
}
stream.stop();
});

it("should only stream group conversation messages", async () => {
Expand Down Expand Up @@ -437,7 +429,6 @@ describe("Conversations", () => {
break;
}
}
stream.stop();
});

it("should only stream dm messages", async () => {
Expand Down Expand Up @@ -481,6 +472,5 @@ describe("Conversations", () => {
break;
}
}
stream.stop();
});
});
Loading