Skip to content

Commit

Permalink
omniassistant overhaul (#65)
Browse files Browse the repository at this point in the history
Co-authored-by: Ben Cherry <[email protected]>
Co-authored-by: Ben Cherry <[email protected]>
  • Loading branch information
3 people authored Sep 27, 2024
1 parent 9cb2313 commit eb7e731
Show file tree
Hide file tree
Showing 12 changed files with 1,602 additions and 832 deletions.
7 changes: 7 additions & 0 deletions .changeset/wild-rabbits-teach.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@livekit/agents": minor
"@livekit/agents-plugin-openai": minor
"livekit-agents-examples": minor
---

omniassistant overhaul
1 change: 1 addition & 0 deletions agents/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ export * from './log.js';
export * from './generator.js';
export * from './tokenize.js';
export * from './audio.js';
export * from './transcription.js';

export { cli, stt, tts, llm };
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import { log } from '@livekit/agents';
import type { AudioFrame, Room } from '@livekit/rtc-node';
import { log } from './log.js';

export interface TranscriptionForwarder {
start(): void;
Expand Down
41 changes: 22 additions & 19 deletions examples/src/minimal_assistant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,41 @@
//
// SPDX-License-Identifier: Apache-2.0
import { type JobContext, WorkerOptions, cli, defineAgent } from '@livekit/agents';
import { OmniAssistant, defaultSessionConfig } from '@livekit/agents-plugin-openai';
import { OmniAssistant, RealtimeModel } from '@livekit/agents-plugin-openai';
import { fileURLToPath } from 'node:url';
import { z } from 'zod';

// import { z } from 'zod';

export default defineAgent({
entry: async (ctx: JobContext) => {
await ctx.connect();

console.log('starting assistant example agent');

const model = new RealtimeModel({
instructions: 'You are a helpful assistant.',
});
// functions: {
// weather: {
// description: 'Get the weather in a location',
// parameters: z.object({
// location: z.string().describe('The location to get the weather for'),
// }),
// execute: async ({ location }) =>
// await fetch(`https://wttr.in/${location}?format=%C+%t`)
// .then((data) => data.text())
// .then((data) => `The weather in ${location} right now is ${data}.`),
// },
// },
// });

const assistant = new OmniAssistant({
sessionConfig: {
...defaultSessionConfig,
instructions: 'You are a helpful assistant.',
},
functions: {
weather: {
description: 'Get the weather in a location',
parameters: z.object({
location: z.string().describe('The location to get the weather for'),
}),
execute: async ({ location }) =>
await fetch(`https://wttr.in/${location}?format=%C+%t`)
.then((data) => data.text())
.then((data) => `The weather in ${location} right now is ${data}.`),
},
},
model,
});

await assistant.start(ctx.room);

assistant.addUserMessage('Hello! Can you share a very short story?');
// assistant.addUserMessage('Hello! Can you share a very short story?');
},
});

Expand Down
201 changes: 201 additions & 0 deletions plugins/openai/src/agent_playout.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import { AudioByteStream } from '@livekit/agents';
import type { TranscriptionForwarder } from '@livekit/agents';
import type { Queue } from '@livekit/agents';
import type { AudioFrame } from '@livekit/rtc-node';
import { type AudioSource } from '@livekit/rtc-node';
import { EventEmitter } from 'events';
import { NUM_CHANNELS, OUTPUT_PCM_FRAME_SIZE, SAMPLE_RATE } from './realtime/api_proto.js';

/**
 * Drives playout of agent output into a single `AudioSource`.
 *
 * Each call to {@link AgentPlayout.play} creates a {@link PlayoutHandle} and starts a background
 * task that drains a text queue and an audio queue, feeding both into the handle's transcription
 * forwarder and capturing the audio frames on the audio source.
 */
export class AgentPlayout {
  #audioSource: AudioSource;
  #playoutPromise: Promise<void> | null = null;

  constructor(audioSource: AudioSource) {
    this.#audioSource = audioSource;
  }

  /**
   * Starts a playout for one content part of a conversation item.
   *
   * @param itemId - identifier stored on the returned handle
   * @param contentIndex - content index stored on the returned handle
   * @param transcriptionFwd - forwarder that receives the text and audio of this playout
   * @param textStream - queue of text chunks; a `null` entry ends the text
   * @param audioStream - queue of audio frames; a `null` entry ends the audio
   * @returns the handle for the playout that was just started
   */
  play(
    itemId: string,
    contentIndex: number,
    transcriptionFwd: TranscriptionForwarder,
    textStream: Queue<string | null>,
    audioStream: Queue<AudioFrame | null>,
  ): PlayoutHandle {
    const handle = new PlayoutHandle(this.#audioSource, itemId, contentIndex, transcriptionFwd);
    const previous = this.#playoutPromise;
    this.#playoutPromise = this.playoutTask(previous, handle, textStream, audioStream);
    return handle;
  }

  /**
   * Runs one playout until the audio has been captured and played out, or until the handle is
   * interrupted, whichever happens first. The handle then gets its `totalPlayedTime`, a 'done'
   * event is emitted on it, and the transcription forwarder is closed.
   *
   * Neither the previous playout nor the inner text/audio loops are cancelled yet (see TODOs).
   */
  private async playoutTask(
    oldPromise: Promise<void> | null,
    handle: PlayoutHandle,
    textStream: Queue<string | null>,
    audioStream: Queue<AudioFrame | null>,
  ): Promise<void> {
    if (oldPromise) {
      // TODO: cancel old task
      // oldPromise.cancel();
    }

    const source = this.#audioSource;
    // Flips to true once the first audio frame has started the transcription forwarder.
    let audioStarted = false;

    // Forward text chunks until the null sentinel, then mark the text as complete.
    const forwardText = async () => {
      for (let chunk = await textStream.get(); chunk !== null; chunk = await textStream.get()) {
        handle.transcriptionFwd.pushText(chunk);
      }
      handle.transcriptionFwd.markTextComplete();
    };

    // Account for the frame's duration on the handle, then hand it to the audio source.
    const enqueue = (f: AudioFrame) => {
      handle.pushedDuration += f.samplesPerChannel / f.sampleRate;
      return source.captureFrame(f);
    };

    // Re-chunk incoming audio into OUTPUT_PCM_FRAME_SIZE frames and capture them in order.
    const forwardAudio = async () => {
      const chunker = new AudioByteStream(SAMPLE_RATE, NUM_CHANNELS, OUTPUT_PCM_FRAME_SIZE);

      for (let frame = await audioStream.get(); frame !== null; frame = await audioStream.get()) {
        if (!audioStarted) {
          handle.transcriptionFwd.start();
          audioStarted = true;
        }

        handle.transcriptionFwd.pushAudio(frame);

        for (const f of chunker.write(frame.data.buffer)) {
          await enqueue(f);
        }
      }

      // Push whatever the byte stream is still holding after the last frame.
      for (const f of chunker.flush()) {
        await enqueue(f);
      }

      handle.transcriptionFwd.markAudioComplete();

      await source.waitForPlayout();
    };

    // The text task is started first and is intentionally not awaited.
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    const textTask = forwardText();
    const audioTask = forwardAudio();

    try {
      await Promise.race([audioTask, handle.intPromise]);
    } finally {
      // TODO: cancel tasks
      // if (!audioTask.isCancelled) {
      //   audioTask.cancel();
      // }

      // Pushed time minus what is still queued on the source.
      handle.totalPlayedTime = handle.pushedDuration - source.queuedDuration;

      // TODO: handle errors
      // if (handle.interrupted || audioTask.error) {
      //   source.clearQueue(); // make sure to remove any queued frames
      // }

      // TODO: cancel tasks
      // if (!textTask.isCancelled) {
      //   textTask.cancel();
      // }

      if (audioStarted && !handle.interrupted) {
        handle.transcriptionFwd.markTextComplete();
      }

      handle.emit('done');
      await handle.transcriptionFwd.close(handle.interrupted);
    }
  }
}

/**
 * Handle for a single playout.
 *
 * The handle is an event emitter: `interrupt()` emits 'interrupt', and a 'done' event emitted on
 * the handle resolves `donePromise`. It also tracks how much audio has been pushed so that the
 * amount actually played can be derived.
 */
export class PlayoutHandle extends EventEmitter {
  #audioSource: AudioSource;
  #itemId: string;
  #contentIndex: number;
  /** @internal */
  transcriptionFwd: TranscriptionForwarder;
  #donePromiseResolved: boolean;
  /** @internal Resolves the first time 'done' is emitted on this handle. */
  donePromise: Promise<void>;
  #intPromiseResolved: boolean;
  /** @internal Resolves the first time 'interrupt' is emitted on this handle. */
  intPromise: Promise<void>;
  #interrupted: boolean;
  /** @internal Duration of audio pushed so far (samplesPerChannel / sampleRate per frame). */
  pushedDuration: number;
  /** @internal */
  totalPlayedTime: number | undefined; // Set when playout is done

  /**
   * @param audioSource - source whose `queuedDuration` is used to estimate played audio
   * @param itemId - identifier exposed through {@link PlayoutHandle.itemId}
   * @param contentIndex - index exposed through {@link PlayoutHandle.contentIndex}
   * @param transcriptionFwd - forwarder associated with this playout
   */
  constructor(
    audioSource: AudioSource,
    itemId: string,
    contentIndex: number,
    transcriptionFwd: TranscriptionForwarder,
  ) {
    super();
    this.#audioSource = audioSource;
    this.#itemId = itemId;
    this.#contentIndex = contentIndex;
    this.transcriptionFwd = transcriptionFwd;
    this.#donePromiseResolved = false;
    this.donePromise = new Promise((resolve) => {
      this.once('done', () => {
        this.#donePromiseResolved = true;
        resolve();
      });
    });
    this.#intPromiseResolved = false;
    this.intPromise = new Promise((resolve) => {
      this.once('interrupt', () => {
        this.#intPromiseResolved = true;
        resolve();
      });
    });
    this.#interrupted = false;
    this.pushedDuration = 0;
    this.totalPlayedTime = undefined;
  }

  get itemId(): string {
    return this.#itemId;
  }

  /**
   * Number of audio samples played so far.
   *
   * Uses `totalPlayedTime` once it has been set; otherwise estimates from the pushed duration
   * minus the duration still queued on the audio source.
   */
  get audioSamples(): number {
    if (this.totalPlayedTime !== undefined) {
      return Math.floor(this.totalPlayedTime * SAMPLE_RATE);
    }

    // Subtract the durations first, then convert to samples. Without the parentheses only the
    // queued duration was scaled by SAMPLE_RATE, mixing durations with sample counts.
    return Math.floor((this.pushedDuration - this.#audioSource.queuedDuration) * SAMPLE_RATE);
  }

  get textChars(): number {
    return this.transcriptionFwd.currentCharacterIndex; // TODO: length of played text
  }

  get contentIndex(): number {
    return this.#contentIndex;
  }

  get interrupted(): boolean {
    return this.#interrupted;
  }

  /** True once 'done' has been emitted or the handle has been interrupted. */
  get done(): boolean {
    return this.#donePromiseResolved || this.#interrupted;
  }

  /** Marks the handle as interrupted and emits 'interrupt'; does nothing once 'done' has fired. */
  interrupt() {
    if (this.#donePromiseResolved) return;
    this.#interrupted = true;
    this.emit('interrupt');
  }
}
4 changes: 3 additions & 1 deletion plugins/openai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
//
// SPDX-License-Identifier: Apache-2.0

export * from './omni_assistant/index.js';
export * from './omni_assistant.js';
export * from './agent_playout.js';
export * from './realtime/index.js';
Loading

0 comments on commit eb7e731

Please sign in to comment.