Skip to content

Commit

Permalink
Multimodal Agent with complete API (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
bcherry authored Sep 28, 2024
1 parent eb7e731 commit d703265
Show file tree
Hide file tree
Showing 21 changed files with 1,183 additions and 927 deletions.
7 changes: 7 additions & 0 deletions .changeset/gold-actors-rest.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@livekit/agents": minor
"@livekit/agents-plugin-openai": minor
"livekit-agents-examples": patch
---

Rename to MultimodalAgent, move to main package
6 changes: 4 additions & 2 deletions agents/package.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
{
"name": "@livekit/agents",
"version": "0.2.0",
"description": "LiveKit Node Agents",
"description": "LiveKit Agents - Node.js",
"main": "dist/index.js",
"types": "dist/index.d.ts",
"author": "aoife cassidy <[email protected]>",
"author": "LiveKit",
"type": "module",
"scripts": {
"build": "tsc",
"clean": "rm -rf dist",
"clean:build": "pnpm clean && pnpm build",
"lint": "eslint -f unix \"src/**/*.ts\"",
"api:check": "api-extractor run --typescript-compiler-folder ../node_modules/typescript",
"api:update": "api-extractor run --local --typescript-compiler-folder ../node_modules/typescript --verbose"
Expand Down
38 changes: 19 additions & 19 deletions agents/src/audio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,36 @@ import { log } from './log.js';

/** AudioByteStream translates between LiveKit AudioFrame packets and raw byte data. */
export class AudioByteStream {
private sampleRate: number;
private numChannels: number;
private bytesPerFrame: number;
private buf: Int8Array;
#sampleRate: number;
#numChannels: number;
#bytesPerFrame: number;
#buf: Int8Array;

constructor(sampleRate: number, numChannels: number, samplesPerChannel: number | null = null) {
this.sampleRate = sampleRate;
this.numChannels = numChannels;
this.#sampleRate = sampleRate;
this.#numChannels = numChannels;

if (samplesPerChannel === null) {
samplesPerChannel = Math.floor(sampleRate / 50); // 20ms by default
}

this.bytesPerFrame = numChannels * samplesPerChannel * 2; // 2 bytes per sample (Int16)
this.buf = new Int8Array();
this.#bytesPerFrame = numChannels * samplesPerChannel * 2; // 2 bytes per sample (Int16)
this.#buf = new Int8Array();
}

write(data: ArrayBuffer): AudioFrame[] {
this.buf = new Int8Array([...this.buf, ...new Int8Array(data)]);
this.#buf = new Int8Array([...this.#buf, ...new Int8Array(data)]);

const frames: AudioFrame[] = [];
while (this.buf.length >= this.bytesPerFrame) {
const frameData = this.buf.slice(0, this.bytesPerFrame);
this.buf = this.buf.slice(this.bytesPerFrame);
while (this.#buf.length >= this.#bytesPerFrame) {
const frameData = this.#buf.slice(0, this.#bytesPerFrame);
this.#buf = this.#buf.slice(this.#bytesPerFrame);

frames.push(
new AudioFrame(
new Int16Array(frameData.buffer),
this.sampleRate,
this.numChannels,
this.#sampleRate,
this.#numChannels,
frameData.length / 2,
),
);
Expand All @@ -45,17 +45,17 @@ export class AudioByteStream {
}

flush(): AudioFrame[] {
if (this.buf.length % (2 * this.numChannels) !== 0) {
if (this.#buf.length % (2 * this.#numChannels) !== 0) {
log().warn('AudioByteStream: incomplete frame during flush, dropping');
return [];
}

return [
new AudioFrame(
new Int16Array(this.buf.buffer),
this.sampleRate,
this.numChannels,
this.buf.length / 2,
new Int16Array(this.#buf.buffer),
this.#sampleRate,
this.#numChannels,
this.#buf.length / 2,
),
];
}
Expand Down
3 changes: 2 additions & 1 deletion agents/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
*/
import * as cli from './cli.js';
import * as llm from './llm/index.js';
import * as multimodal from './multimodal/index.js';
import * as stt from './stt/index.js';
import * as tts from './tts/index.js';

Expand All @@ -26,4 +27,4 @@ export * from './tokenize.js';
export * from './audio.js';
export * from './transcription.js';

export { cli, stt, tts, llm };
export { cli, stt, tts, llm, multimodal };
8 changes: 6 additions & 2 deletions agents/src/job.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,18 @@ export class JobContext {
throw new Error('room is not connected');
}

console.log(this.#room.remoteParticipants.values());

for (const p of this.#room.remoteParticipants.values()) {
if (p.identity === identity && p.info.kind != ParticipantKind.AGENT) {
if ((!identity || p.identity === identity) && p.info.kind != ParticipantKind.AGENT) {
return p;
}
}

return new Promise((resolve) => {
this.#room.once(RoomEvent.ParticipantConnected, resolve);
this.#room.once(RoomEvent.ParticipantConnected, () => {
resolve(this.#room.remoteParticipants.values().next().value);
});
});
}

Expand Down
254 changes: 254 additions & 0 deletions agents/src/multimodal/agent_playout.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
import type { AudioFrame } from '@livekit/rtc-node';
import { type AudioSource } from '@livekit/rtc-node';
import { EventEmitter } from 'events';
import { AudioByteStream } from '../audio.js';
import type { TranscriptionForwarder } from '../transcription.js';
import { type AsyncIterableQueue, CancellablePromise, Future, gracefullyCancel } from '../utils.js';

export const proto = {};

export class PlayoutHandle extends EventEmitter {
#audioSource: AudioSource;
#sampleRate: number;
#itemId: string;
#contentIndex: number;
/** @internal */
transcriptionFwd: TranscriptionForwarder;
/** @internal */
doneFut: Future;
/** @internal */
intFut: Future;
/** @internal */
#interrupted: boolean;
/** @internal */
pushedDuration: number;
/** @internal */
totalPlayedTime: number | undefined; // Set when playout is done

constructor(
audioSource: AudioSource,
sampleRate: number,
itemId: string,
contentIndex: number,
transcriptionFwd: TranscriptionForwarder,
) {
super();
this.#audioSource = audioSource;
this.#sampleRate = sampleRate;
this.#itemId = itemId;
this.#contentIndex = contentIndex;
this.transcriptionFwd = transcriptionFwd;
this.doneFut = new Future();
this.intFut = new Future();
this.#interrupted = false;
this.pushedDuration = 0;
this.totalPlayedTime = undefined;
}

get itemId(): string {
return this.#itemId;
}

get audioSamples(): number {
if (this.totalPlayedTime !== undefined) {
return Math.floor(this.totalPlayedTime * this.#sampleRate);
}

return Math.floor(this.pushedDuration - this.#audioSource.queuedDuration * this.#sampleRate);
}

get textChars(): number {
return this.transcriptionFwd.currentCharacterIndex;
}

get contentIndex(): number {
return this.#contentIndex;
}

get interrupted(): boolean {
return this.#interrupted;
}

get done(): boolean {
return this.doneFut.done || this.#interrupted;
}

interrupt() {
if (this.doneFut.done) return;
this.intFut.resolve();
this.#interrupted = true;
}
}

export class AgentPlayout {
#audioSource: AudioSource;
#playoutTask: CancellablePromise<void> | null;
#sampleRate: number;
#numChannels: number;
#inFrameSize: number;
#outFrameSize: number;
constructor(
audioSource: AudioSource,
sampleRate: number,
numChannels: number,
inFrameSize: number,
outFrameSize: number,
) {
this.#audioSource = audioSource;
this.#playoutTask = null;
this.#sampleRate = sampleRate;
this.#numChannels = numChannels;
this.#inFrameSize = inFrameSize;
this.#outFrameSize = outFrameSize;
}

play(
itemId: string,
contentIndex: number,
transcriptionFwd: TranscriptionForwarder,
textStream: AsyncIterableQueue<string>,
audioStream: AsyncIterableQueue<AudioFrame>,
): PlayoutHandle {
const handle = new PlayoutHandle(
this.#audioSource,
this.#sampleRate,
itemId,
contentIndex,
transcriptionFwd,
);
this.#playoutTask = this.#makePlayoutTask(this.#playoutTask, handle, textStream, audioStream);
return handle;
}

#makePlayoutTask(
oldTask: CancellablePromise<void> | null,
handle: PlayoutHandle,
textStream: AsyncIterableQueue<string>,
audioStream: AsyncIterableQueue<AudioFrame>,
): CancellablePromise<void> {
return new CancellablePromise<void>((resolve, reject, onCancel) => {
let cancelled = false;
onCancel(() => {
cancelled = true;
});

(async () => {
try {
if (oldTask) {
await gracefullyCancel(oldTask);
}

let firstFrame = true;

const readText = () =>
new CancellablePromise<void>((resolveText, rejectText, onCancelText) => {
let cancelledText = false;
onCancelText(() => {
cancelledText = true;
});

(async () => {
try {
for await (const text of textStream) {
if (cancelledText || cancelled) {
break;
}
handle.transcriptionFwd.pushText(text);
}
resolveText();
} catch (error) {
rejectText(error);
}
})();
});

const capture = () =>
new CancellablePromise<void>((resolveCapture, rejectCapture, onCancelCapture) => {
let cancelledCapture = false;
onCancelCapture(() => {
cancelledCapture = true;
});

(async () => {
try {
const samplesPerChannel = this.#outFrameSize;
const bstream = new AudioByteStream(
this.#sampleRate,
this.#numChannels,
samplesPerChannel,
);

for await (const frame of audioStream) {
if (cancelledCapture || cancelled) {
break;
}
if (firstFrame) {
handle.transcriptionFwd.start();
firstFrame = false;
}

handle.transcriptionFwd.pushAudio(frame);

for (const f of bstream.write(frame.data.buffer)) {
handle.pushedDuration += f.samplesPerChannel / f.sampleRate;
await this.#audioSource.captureFrame(f);
}
}

if (!cancelledCapture && !cancelled) {
for (const f of bstream.flush()) {
handle.pushedDuration += f.samplesPerChannel / f.sampleRate;
await this.#audioSource.captureFrame(f);
}

handle.transcriptionFwd.markAudioComplete();

await this.#audioSource.waitForPlayout();
}

resolveCapture();
} catch (error) {
rejectCapture(error);
}
})();
});

const readTextTask = readText();
const captureTask = capture();

try {
await Promise.race([captureTask, handle.intFut.await]);
} finally {
if (!captureTask.isCancelled) {
await gracefullyCancel(captureTask);
}

handle.totalPlayedTime = handle.pushedDuration - this.#audioSource.queuedDuration;

if (handle.interrupted || captureTask.error) {
this.#audioSource.clearQueue(); // make sure to remove any queued frames
}

if (!readTextTask.isCancelled) {
await gracefullyCancel(readTextTask);
}

if (!firstFrame && !handle.interrupted) {
handle.transcriptionFwd.markTextComplete();
}

handle.doneFut.resolve();
await handle.transcriptionFwd.close(handle.interrupted);
}

resolve();
} catch (error) {
reject(error);
}
})();
});
}
}
5 changes: 5 additions & 0 deletions agents/src/multimodal/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
//
// SPDX-License-Identifier: Apache-2.0
export * from './multimodal_agent.js';
export * from './agent_playout.js';
Loading

0 comments on commit d703265

Please sign in to comment.