diff --git a/.changeset/wild-rabbits-teach.md b/.changeset/wild-rabbits-teach.md
new file mode 100644
index 00000000..e021c861
--- /dev/null
+++ b/.changeset/wild-rabbits-teach.md
@@ -0,0 +1,7 @@
+---
+"@livekit/agents": minor
+"@livekit/agents-plugin-openai": minor
+"livekit-agents-examples": minor
+---
+
+omniassistant overhaul
diff --git a/agents/src/index.ts b/agents/src/index.ts
index 1572c6cf..07e3525b 100644
--- a/agents/src/index.ts
+++ b/agents/src/index.ts
@@ -24,5 +24,6 @@ export * from './log.js';
 export * from './generator.js';
 export * from './tokenize.js';
 export * from './audio.js';
+export * from './transcription.js';

 export { cli, stt, tts, llm };
diff --git a/plugins/openai/src/omni_assistant/transcription_forwarder.ts b/agents/src/transcription.ts
similarity index 98%
rename from plugins/openai/src/omni_assistant/transcription_forwarder.ts
rename to agents/src/transcription.ts
index a7440730..c51ace44 100644
--- a/plugins/openai/src/omni_assistant/transcription_forwarder.ts
+++ b/agents/src/transcription.ts
@@ -1,8 +1,8 @@
 // SPDX-FileCopyrightText: 2024 LiveKit, Inc.
 //
 // SPDX-License-Identifier: Apache-2.0
-import { log } from '@livekit/agents';
 import type { AudioFrame, Room } from '@livekit/rtc-node';
+import { log } from './log.js';

 export interface TranscriptionForwarder {
   start(): void;
diff --git a/examples/src/minimal_assistant.ts b/examples/src/minimal_assistant.ts
index bcb44f67..69864089 100644
--- a/examples/src/minimal_assistant.ts
+++ b/examples/src/minimal_assistant.ts
@@ -2,9 +2,10 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 import { type JobContext, WorkerOptions, cli, defineAgent } from '@livekit/agents';
-import { OmniAssistant, defaultSessionConfig } from '@livekit/agents-plugin-openai';
+import { OmniAssistant, RealtimeModel } from '@livekit/agents-plugin-openai';
 import { fileURLToPath } from 'node:url';
-import { z } from 'zod';
+
+// import { z } from 'zod';

 export default defineAgent({
   entry: async (ctx: JobContext) => {
@@ -12,28 +13,30 @@ export default defineAgent({

     console.log('starting assistant example agent');

+    const model = new RealtimeModel({
+      instructions: 'You are a helpful assistant.',
+    });
+    // functions: {
+    //   weather: {
+    //     description: 'Get the weather in a location',
+    //     parameters: z.object({
+    //       location: z.string().describe('The location to get the weather for'),
+    //     }),
+    //     execute: async ({ location }) =>
+    //       await fetch(`https://wttr.in/${location}?format=%C+%t`)
+    //         .then((data) => data.text())
+    //         .then((data) => `The weather in ${location} right now is ${data}.`),
+    //   },
+    // },
+    // });
+
     const assistant = new OmniAssistant({
-      sessionConfig: {
-        ...defaultSessionConfig,
-        instructions: 'You are a helpful assistant.',
-      },
-      functions: {
-        weather: {
-          description: 'Get the weather in a location',
-          parameters: z.object({
-            location: z.string().describe('The location to get the weather for'),
-          }),
-          execute: async ({ location }) =>
-            await fetch(`https://wttr.in/${location}?format=%C+%t`)
-              .then((data) => data.text())
-              .then((data) => `The weather in ${location} right now is ${data}.`),
-        },
-      },
+      model,
     });

     await assistant.start(ctx.room);

-    assistant.addUserMessage('Hello! Can you share a very short story?');
+    // assistant.addUserMessage('Hello! Can you share a very short story?');
   },
 });

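For quick reference, the entry function in examples/src/minimal_assistant.ts after this change reduces to the following sketch, reconstructed from the added lines of the two hunks above. The ctx.connect() call and the cli.runApp(...) bootstrap at the bottom of the file sit outside these hunks and are assumed unchanged here; the zod-based weather function survives only as a comment and is not passed to the new API.

    // sketch of the new example entry body (not part of the diff)
    const model = new RealtimeModel({
      instructions: 'You are a helpful assistant.',
    });

    const assistant = new OmniAssistant({ model });
    await assistant.start(ctx.room);

The session options that were previously spread from defaultSessionConfig (voice, turn detection, input transcription, and so on) now live on RealtimeModel, and OmniAssistant is constructed from that model rather than from a raw sessionConfig.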
diff --git a/plugins/openai/src/agent_playout.ts b/plugins/openai/src/agent_playout.ts
new file mode 100644
index 00000000..881f143f
--- /dev/null
+++ b/plugins/openai/src/agent_playout.ts
@@ -0,0 +1,201 @@
+// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+import { AudioByteStream } from '@livekit/agents';
+import type { TranscriptionForwarder } from '@livekit/agents';
+import type { Queue } from '@livekit/agents';
+import type { AudioFrame } from '@livekit/rtc-node';
+import { type AudioSource } from '@livekit/rtc-node';
+import { EventEmitter } from 'events';
+import { NUM_CHANNELS, OUTPUT_PCM_FRAME_SIZE, SAMPLE_RATE } from './realtime/api_proto.js';
+
+export class AgentPlayout {
+  #audioSource: AudioSource;
+  #playoutPromise: Promise<void> | null;
+
+  constructor(audioSource: AudioSource) {
+    this.#audioSource = audioSource;
+    this.#playoutPromise = null;
+  }
+
+  play(
+    itemId: string,
+    contentIndex: number,
+    transcriptionFwd: TranscriptionForwarder,
+    textStream: Queue<string | null>,
+    audioStream: Queue<AudioFrame | null>,
+  ): PlayoutHandle {
+    const handle = new PlayoutHandle(this.#audioSource, itemId, contentIndex, transcriptionFwd);
+    this.#playoutPromise = this.playoutTask(this.#playoutPromise, handle, textStream, audioStream);
+    return handle;
+  }
+
+  private async playoutTask(
+    oldPromise: Promise<void> | null,
+    handle: PlayoutHandle,
+    textStream: Queue<string | null>,
+    audioStream: Queue<AudioFrame | null>,
+  ): Promise<void> {
+    if (oldPromise) {
+      // TODO: cancel old task
+      // oldPromise.cancel();
+    }
+
+    let firstFrame = true;
+
+    const playTextStream = async () => {
+      while (true) {
+        const text = await textStream.get();
+        if (text === null) break;
+        handle.transcriptionFwd.pushText(text);
+      }
+      handle.transcriptionFwd.markTextComplete();
+    };
+
+    const captureTask = async () => {
+      const samplesPerChannel = OUTPUT_PCM_FRAME_SIZE;
+      const bstream = new AudioByteStream(SAMPLE_RATE, NUM_CHANNELS, samplesPerChannel);
+
+      while (true) {
+        const frame = await audioStream.get();
+        if (frame === null) break;
+
+        if (firstFrame) {
+          handle.transcriptionFwd.start();
+          firstFrame = false;
+        }
+
+        handle.transcriptionFwd.pushAudio(frame);
+
+        for (const f of bstream.write(frame.data.buffer)) {
+          handle.pushedDuration += f.samplesPerChannel / f.sampleRate;
+          await this.#audioSource.captureFrame(f);
+        }
+      }
+
+      for (const f of bstream.flush()) {
+        handle.pushedDuration += f.samplesPerChannel / f.sampleRate;
+        await this.#audioSource.captureFrame(f);
+      }
+
+      handle.transcriptionFwd.markAudioComplete();
+
+      await this.#audioSource.waitForPlayout();
+    };
+
+    // eslint-disable-next-line @typescript-eslint/no-unused-vars
+    const readTextTaskPromise = playTextStream();
+    const captureTaskPromise = captureTask();
+
+    try {
+      await Promise.race([captureTaskPromise, handle.intPromise]);
+    } finally {
+      // TODO: cancel tasks
+      // if (!captureTaskPromise.isCancelled) {
+      //   captureTaskPromise.cancel();
+      // }
+
+      handle.totalPlayedTime = handle.pushedDuration - this.#audioSource.queuedDuration;
+
+      // TODO: handle errors
+      // if (handle.interrupted || captureTaskPromise.error) {
+      //   this.#audioSource.clearQueue(); // make sure to remove any queued frames
+      // }
+
+      // TODO: cancel tasks
+      // if (!readTextTask.isCancelled) {
+      //   readTextTask.cancel();
+      // }
+
+      if (!firstFrame && !handle.interrupted) {
+        handle.transcriptionFwd.markTextComplete();
+      }
+
+      handle.emit('done');
+      await handle.transcriptionFwd.close(handle.interrupted);
+    }
+  }
+}
+
+export class PlayoutHandle extends EventEmitter {
+  #audioSource: AudioSource;
+  #itemId: string;
+  #contentIndex: number;
+  /** @internal */
+  transcriptionFwd: TranscriptionForwarder;
+  #donePromiseResolved: boolean;
+  /** @internal */
+  donePromise: Promise<void>;
+  #intPromiseResolved: boolean;
+  /** @internal */
+  intPromise: Promise<void>;
+  #interrupted: boolean;
+  /** @internal */
+  pushedDuration: number;
+  /** @internal */
+  totalPlayedTime: number | undefined; // Set when playout is done
+
+  constructor(
+    audioSource: AudioSource,
+    itemId: string,
+    contentIndex: number,
+    transcriptionFwd: TranscriptionForwarder,
+  ) {
+    super();
+    this.#audioSource = audioSource;
+    this.#itemId = itemId;
+    this.#contentIndex = contentIndex;
+    this.transcriptionFwd = transcriptionFwd;
+    this.#donePromiseResolved = false;
+    this.donePromise = new Promise<void>((resolve) => {
+      this.once('done', () => {
+        this.#donePromiseResolved = true;
+        resolve();
+      });
+    });
+    this.#intPromiseResolved = false;
+    this.intPromise = new Promise<void>((resolve) => {
+      this.once('interrupt', () => {
+        this.#intPromiseResolved = true;
+        resolve();
+      });
+    });
+    this.#interrupted = false;
+    this.pushedDuration = 0;
+    this.totalPlayedTime = undefined;
+  }
+
+  get itemId(): string {
+    return this.#itemId;
+  }
+
+  get audioSamples(): number {
+    if (this.totalPlayedTime !== undefined) {
+      return Math.floor(this.totalPlayedTime * SAMPLE_RATE);
+    }
+
+    return Math.floor((this.pushedDuration - this.#audioSource.queuedDuration) * SAMPLE_RATE);
+  }
+
+  get textChars(): number {
+    return this.transcriptionFwd.currentCharacterIndex; // TODO: length of played text
+  }
+
+  get contentIndex(): number {
+    return this.#contentIndex;
+  }
+
+  get interrupted(): boolean {
+    return this.#interrupted;
+  }
+
+  get done(): boolean {
+    return this.#donePromiseResolved || this.#interrupted;
+  }
+
+  interrupt() {
+    if (this.#donePromiseResolved) return;
+    this.#interrupted = true;
+    this.emit('interrupt');
+  }
+}
diff --git a/plugins/openai/src/index.ts b/plugins/openai/src/index.ts
index 8de40314..407f5553 100644
--- a/plugins/openai/src/index.ts
+++ b/plugins/openai/src/index.ts
@@ -2,4 +2,6 @@
 //
 // SPDX-License-Identifier: Apache-2.0

-export * from './omni_assistant/index.js';
+export * from './omni_assistant.js';
+export * from './agent_playout.js';
+export * from './realtime/index.js';
diff --git a/plugins/openai/src/omni_assistant.ts b/plugins/openai/src/omni_assistant.ts
new file mode 100644
index 00000000..6e875ec3
--- /dev/null
+++ b/plugins/openai/src/omni_assistant.ts
@@ -0,0 +1,354 @@
+// SPDX-FileCopyrightText: 2024 LiveKit, Inc.
+// +// SPDX-License-Identifier: Apache-2.0 +import { AudioByteStream } from '@livekit/agents'; +import { findMicroTrackId } from '@livekit/agents'; +import { type llm, log } from '@livekit/agents'; +import { BasicTranscriptionForwarder } from '@livekit/agents'; +import type { + LocalTrackPublication, + RemoteAudioTrack, + RemoteParticipant, + Room, +} from '@livekit/rtc-node'; +import { + AudioSource, + AudioStream, + LocalAudioTrack, + RoomEvent, + TrackPublishOptions, + TrackSource, +} from '@livekit/rtc-node'; +import { AgentPlayout, type PlayoutHandle } from './agent_playout.js'; +import * as api_proto from './realtime/api_proto.js'; +import type { + InputSpeechCommitted, + InputTranscriptionCompleted, + RealtimeContent, + RealtimeModel, + RealtimeSession, +} from './realtime/realtime_model.js'; +import { EventTypes } from './realtime/realtime_model.js'; + +type ImplOptions = { + // functions: llm.FunctionContext; +}; + +/** @alpha */ +export class OmniAssistant { + model: RealtimeModel; + options: ImplOptions; + room: Room | null = null; + linkedParticipant: RemoteParticipant | null = null; + subscribedTrack: RemoteAudioTrack | null = null; + readMicroTask: { promise: Promise; cancel: () => void } | null = null; + + constructor({ + model, + // functions = {}, + }: { + model: RealtimeModel; + functions?: llm.FunctionContext; + }) { + this.model = model; + + this.options = { + // functions, + }; + } + + private started: boolean = false; + private participant: RemoteParticipant | string | null = null; + private agentPublication: LocalTrackPublication | null = null; + private localTrackSid: string | null = null; + private localSource: AudioSource | null = null; + private agentPlayout: AgentPlayout | null = null; + private playingHandle: PlayoutHandle | undefined = undefined; + private logger = log(); + private session: RealtimeSession | null = null; + + // get funcCtx(): llm.FunctionContext { + // return this.options.functions; + // } + // set funcCtx(ctx: llm.FunctionContext) { + // this.options.functions = ctx; + // this.options.sessionConfig.tools = tools(ctx); + // this.sendClientCommand({ + // type: proto.ClientEventType.SessionUpdate, + // session: this.options.sessionConfig, + // }); + // } + + start(room: Room, participant: RemoteParticipant | string | null = null): Promise { + return new Promise(async (resolve, reject) => { + if (this.started) { + this.logger.warn('OmniAssistant already started'); + resolve(); // TODO: throw error? 
+ return; + } + + room.on(RoomEvent.ParticipantConnected, (participant: RemoteParticipant) => { + if (!this.linkedParticipant) { + return; + } + + this.linkParticipant(participant.identity); + }); + this.room = room; + this.participant = participant; + + this.localSource = new AudioSource(api_proto.SAMPLE_RATE, api_proto.NUM_CHANNELS); + this.agentPlayout = new AgentPlayout(this.localSource); + const track = LocalAudioTrack.createAudioTrack('assistant_voice', this.localSource); + const options = new TrackPublishOptions(); + options.source = TrackSource.SOURCE_MICROPHONE; + this.agentPublication = (await room.localParticipant?.publishTrack(track, options)) || null; + if (!this.agentPublication) { + this.logger.error('Failed to publish track'); + reject(new Error('Failed to publish track')); + return; + } + + await this.agentPublication.waitForSubscription(); + + if (participant) { + if (typeof participant === 'string') { + this.linkParticipant(participant); + } else { + this.linkParticipant(participant.identity); + } + } else { + // No participant specified, try to find the first participant in the room + for (const participant of room.remoteParticipants.values()) { + this.linkParticipant(participant.identity); + break; + } + } + + this.session = this.model.session({}); + + this.session.on(EventTypes.ResponseContentAdded, (message: RealtimeContent) => { + const trFwd = new BasicTranscriptionForwarder( + this.room!, + this.room!.localParticipant!.identity, + this.getLocalTrackSid()!, + message.responseId, + ); + + this.playingHandle = this.agentPlayout?.play( + message.itemId, + message.contentIndex, + trFwd, + message.textStream, + message.audioStream, + ); + }); + + this.session.on(EventTypes.InputSpeechCommitted, (ev: InputSpeechCommitted) => { + const participantIdentity = this.linkedParticipant?.identity; + const trackSid = this.subscribedTrack?.sid; + if (participantIdentity && trackSid) { + this.publishTranscription(participantIdentity, trackSid, '', true, ev.itemId); + } else { + this.logger.error('Participant or track not set'); + } + }); + + this.session.on( + EventTypes.InputSpeechTranscriptionCompleted, + (ev: InputTranscriptionCompleted) => { + const transcription = ev.transcript; + const participantIdentity = this.linkedParticipant?.identity; + const trackSid = this.subscribedTrack?.sid; + if (participantIdentity && trackSid) { + this.publishTranscription( + participantIdentity, + trackSid, + transcription, + true, + ev.itemId, + ); + } else { + this.logger.error('Participant or track not set'); + } + }, + ); + + this.session.on(EventTypes.InputSpeechStarted, () => { + if (this.playingHandle && !this.playingHandle.done) { + this.playingHandle.interrupt(); + + this.session!.defaultConversation.item.truncate( + this.playingHandle.itemId, + this.playingHandle.contentIndex, + Math.floor((this.playingHandle.audioSamples / 24000) * 1000), + ); + + this.playingHandle = undefined; + } + }); + }); + } + + // close() { + // if (!this.connected || !this.ws) return; + // this.logger.debug('stopping assistant'); + // this.ws.close(); + // } + + // addUserMessage(text: string, generate: boolean = true): void { + // this.sendClientCommand({ + // type: proto.ClientEventType.ConversationItemCreate, + // item: { + // type: 'message', + // role: 'user', + // content: [ + // { + // type: 'text', + // text: text, + // }, + // ], + // }, + // }); + // if (generate) { + // this.sendClientCommand({ + // type: proto.ClientEventType.ResponseCreate, + // response: {}, + // }); + // } + // } + + // 
private setState(state: proto.State) { + // // don't override thinking until done + // if (this.thinking) return; + // if (this.room?.isConnected && this.room.localParticipant) { + // const currentState = this.room.localParticipant.attributes['lk.agent.state']; + // if (currentState !== state) { + // this.room.localParticipant!.setAttributes({ + // 'lk.agent.state': state, + // }); + // this.logger.debug(`lk.agent.state updated from ${currentState} to ${state}`); + // } + // } + // } + + private linkParticipant(participantIdentity: string): void { + if (!this.room) { + this.logger.error('Room is not set'); + return; + } + + this.linkedParticipant = this.room.remoteParticipants.get(participantIdentity) || null; + if (!this.linkedParticipant) { + this.logger.error(`Participant with identity ${participantIdentity} not found`); + return; + } + + if (this.linkedParticipant.trackPublications.size > 0) { + this.subscribeToMicrophone(); + } else { + this.room.on(RoomEvent.TrackPublished, () => { + this.subscribeToMicrophone(); + }); + } + } + + private subscribeToMicrophone(): void { + const readAudioStreamTask = async (audioStream: AudioStream) => { + const bstream = new AudioByteStream( + api_proto.SAMPLE_RATE, + api_proto.NUM_CHANNELS, + api_proto.INPUT_PCM_FRAME_SIZE, + ); + + for await (const frame of audioStream) { + const audioData = frame.data; + for (const frame of bstream.write(audioData.buffer)) { + this.model.sessions[0].queueMsg({ + type: 'input_audio_buffer.append', + audio: Buffer.from(frame.data.buffer).toString('base64'), + }); + } + } + }; + + if (!this.linkedParticipant) { + this.logger.error('Participant is not set'); + return; + } + + for (const publication of this.linkedParticipant.trackPublications.values()) { + if (publication.source !== TrackSource.SOURCE_MICROPHONE) { + continue; + } + + if (!publication.subscribed) { + publication.setSubscribed(true); + } + + const track = publication.track; + + if (track && track !== this.subscribedTrack) { + this.subscribedTrack = track!; + + if (this.readMicroTask) { + this.readMicroTask.cancel(); + } + + let cancel: () => void; + this.readMicroTask = { + promise: new Promise((resolve, reject) => { + cancel = () => { + // Cleanup logic here + reject(new Error('Task cancelled')); + }; + readAudioStreamTask( + new AudioStream(track, api_proto.SAMPLE_RATE, api_proto.NUM_CHANNELS), + ) + .then(resolve) + .catch(reject); + }), + cancel: () => cancel(), + }; + } + } + } + + private getLocalTrackSid(): string | null { + if (!this.localTrackSid && this.room && this.room.localParticipant) { + this.localTrackSid = findMicroTrackId(this.room, this.room.localParticipant?.identity); + } + return this.localTrackSid; + } + + private publishTranscription( + participantIdentity: string, + trackSid: string, + text: string, + isFinal: boolean, + id: string, + ): void { + this.logger.info( + `Publishing transcription ${participantIdentity} ${trackSid} ${text} ${isFinal} ${id}`, + ); + if (!this.room?.localParticipant) { + this.logger.error('Room or local participant not set'); + return; + } + + this.room.localParticipant.publishTranscription({ + participantIdentity, + trackSid, + segments: [ + { + text, + final: isFinal, + id, + startTime: BigInt(0), + endTime: BigInt(0), + language: '', + }, + ], + }); + } +} diff --git a/plugins/openai/src/omni_assistant/agent_playout.ts b/plugins/openai/src/omni_assistant/agent_playout.ts deleted file mode 100644 index 13261383..00000000 --- a/plugins/openai/src/omni_assistant/agent_playout.ts +++ /dev/null @@ -1,127 
+0,0 @@ -// SPDX-FileCopyrightText: 2024 LiveKit, Inc. -// -// SPDX-License-Identifier: Apache-2.0 -import { AudioByteStream } from '@livekit/agents'; -import { Queue } from '@livekit/agents'; -import { AudioFrame, type AudioSource } from '@livekit/rtc-node'; -import { EventEmitter } from 'events'; -import * as proto from './proto.js'; -import type { TranscriptionForwarder } from './transcription_forwarder'; - -export class AgentPlayout { - #audioSource: AudioSource; - #currentPlayoutHandle: PlayoutHandle | null; - #currentPlayoutTask: Promise | null; - - constructor(audioSource: AudioSource) { - this.#audioSource = audioSource; - this.#currentPlayoutHandle = null; - this.#currentPlayoutTask = null; - } - - play(messageId: string, transcriptionFwd: TranscriptionForwarder): PlayoutHandle { - if (this.#currentPlayoutHandle) { - this.#currentPlayoutHandle.interrupt(); - } - this.#currentPlayoutHandle = new PlayoutHandle(messageId, transcriptionFwd); - this.#currentPlayoutTask = this.playoutTask( - this.#currentPlayoutTask, - this.#currentPlayoutHandle, - ); - return this.#currentPlayoutHandle; - } - - private async playoutTask(oldTask: Promise | null, handle: PlayoutHandle): Promise { - let firstFrame = true; - try { - const bstream = new AudioByteStream( - proto.SAMPLE_RATE, - proto.NUM_CHANNELS, - proto.OUTPUT_PCM_FRAME_SIZE, - ); - - while (!handle.interrupted) { - const frame = await handle.playoutQueue.get(); - if (frame === null) break; - if (firstFrame) { - handle.transcriptionFwd.start(); - firstFrame = false; - } - - for (const f of bstream.write(frame.data.buffer)) { - handle.playedAudioSamples += f.samplesPerChannel; - if (handle.interrupted) break; - - await this.#audioSource.captureFrame(f); - } - } - - if (!handle.interrupted) { - for (const f of bstream.flush()) { - await this.#audioSource.captureFrame(f); - } - } - } finally { - if (!firstFrame && !handle.interrupted) { - handle.transcriptionFwd.markTextComplete(); - } - await handle.transcriptionFwd.close(handle.interrupted); - handle.complete(); - } - } -} - -export class PlayoutHandle extends EventEmitter { - messageId: string; - transcriptionFwd: TranscriptionForwarder; - playedAudioSamples: number; - done: boolean; - interrupted: boolean; - playoutQueue: Queue; - - constructor(messageId: string, transcriptionFwd: TranscriptionForwarder) { - super(); - this.messageId = messageId; - this.transcriptionFwd = transcriptionFwd; - this.playedAudioSamples = 0; - this.done = false; - this.interrupted = false; - this.playoutQueue = new Queue(); - } - - pushAudio(data: Uint8Array) { - const frame = new AudioFrame( - new Int16Array(data.buffer), - proto.SAMPLE_RATE, - proto.NUM_CHANNELS, - data.length / 2, - ); - this.transcriptionFwd.pushAudio(frame); - this.playoutQueue.put(frame); - } - - pushText(text: string) { - this.transcriptionFwd.pushText(text); - } - - endInput() { - this.transcriptionFwd.markAudioComplete(); - this.transcriptionFwd.markTextComplete(); - this.playoutQueue.put(null); - } - - interrupt() { - if (this.done) return; - this.interrupted = true; - } - - publishedTextChars(): number { - return this.transcriptionFwd.currentCharacterIndex; - } - - complete() { - if (this.done) return; - this.done = true; - this.emit('complete', this.interrupted); - } -} diff --git a/plugins/openai/src/omni_assistant/index.ts b/plugins/openai/src/omni_assistant/index.ts deleted file mode 100644 index 4528fec5..00000000 --- a/plugins/openai/src/omni_assistant/index.ts +++ /dev/null @@ -1,542 +0,0 @@ -// SPDX-FileCopyrightText: 2024 
LiveKit, Inc. -// -// SPDX-License-Identifier: Apache-2.0 -import { AudioByteStream } from '@livekit/agents'; -import { findMicroTrackId } from '@livekit/agents'; -import { llm, log } from '@livekit/agents'; -import type { - LocalTrackPublication, - RemoteAudioTrack, - RemoteParticipant, - Room, -} from '@livekit/rtc-node'; -import { - AudioSource, - AudioStream, - LocalAudioTrack, - RoomEvent, - TrackPublishOptions, - TrackSource, -} from '@livekit/rtc-node'; -import { WebSocket } from 'ws'; -import { AgentPlayout, type PlayoutHandle } from './agent_playout.js'; -import * as proto from './proto.js'; -import { BasicTranscriptionForwarder } from './transcription_forwarder.js'; - -/** @hidden */ -export const defaultSessionConfig: Partial = { - turn_detection: { - type: 'server_vad', - threshold: 0.5, - prefix_padding_ms: 300, - silence_duration_ms: 200, - }, - input_audio_format: proto.AudioFormat.PCM16, - input_audio_transcription: { - model: 'whisper-1', - }, - modalities: ['text', 'audio'], - instructions: 'You are a helpful assistant.', - voice: proto.Voice.ALLOY, - output_audio_format: proto.AudioFormat.PCM16, - tools: [], - tool_choice: proto.ToolChoice.AUTO, - temperature: 0.8, - // max_output_tokens: 2048, -}; - -type ImplOptions = { - apiKey: string; - sessionConfig: Partial; - functions: llm.FunctionContext; -}; - -/** @alpha */ -export class OmniAssistant { - options: ImplOptions; - room: Room | null = null; - linkedParticipant: RemoteParticipant | null = null; - subscribedTrack: RemoteAudioTrack | null = null; - readMicroTask: { promise: Promise; cancel: () => void } | null = null; - - constructor({ - sessionConfig = defaultSessionConfig, - functions = {}, - apiKey = process.env.OPENAI_API_KEY || '', - }: { - sessionConfig?: Partial; - functions?: llm.FunctionContext; - apiKey?: string; - }) { - if (!apiKey) { - throw new Error('OpenAI API key is required, whether as an argument or as $OPENAI_API_KEY'); - } - - sessionConfig.tools = tools(functions); - this.options = { - apiKey, - sessionConfig, - functions, - }; - } - - private ws: WebSocket | null = null; - private connected: boolean = false; - private thinking: boolean = false; - private participant: RemoteParticipant | string | null = null; - private agentPublication: LocalTrackPublication | null = null; - private localTrackSid: string | null = null; - private localSource: AudioSource | null = null; - private agentPlayout: AgentPlayout | null = null; - private playingHandle: PlayoutHandle | null = null; - private logger = log(); - - get funcCtx(): llm.FunctionContext { - return this.options.functions; - } - set funcCtx(ctx: llm.FunctionContext) { - this.options.functions = ctx; - this.options.sessionConfig.tools = tools(ctx); - this.sendClientCommand({ - type: proto.ClientEventType.SessionUpdate, - session: this.options.sessionConfig, - }); - } - - start(room: Room, participant: RemoteParticipant | string | null = null): Promise { - return new Promise(async (resolve, reject) => { - if (this.ws !== null) { - this.logger.warn('VoiceAssistant already started'); - resolve(); - return; - } - - room.on(RoomEvent.ParticipantConnected, (participant: RemoteParticipant) => { - if (!this.linkedParticipant) { - return; - } - - this.linkParticipant(participant.identity); - }); - this.room = room; - this.participant = participant; - this.setState(proto.State.INITIALIZING); - - this.localSource = new AudioSource(proto.SAMPLE_RATE, proto.NUM_CHANNELS); - this.agentPlayout = new AgentPlayout(this.localSource); - const track = 
LocalAudioTrack.createAudioTrack('assistant_voice', this.localSource); - const options = new TrackPublishOptions(); - options.source = TrackSource.SOURCE_MICROPHONE; - this.agentPublication = (await room.localParticipant?.publishTrack(track, options)) || null; - if (!this.agentPublication) { - this.logger.error('Failed to publish track'); - reject(new Error('Failed to publish track')); - return; - } - - await this.agentPublication.waitForSubscription(); - - if (participant) { - if (typeof participant === 'string') { - this.linkParticipant(participant); - } else { - this.linkParticipant(participant.identity); - } - } else { - // No participant specified, try to find the first participant in the room - for (const participant of room.remoteParticipants.values()) { - this.linkParticipant(participant.identity); - break; - } - } - - this.ws = new WebSocket(proto.API_URL, { - headers: { - Authorization: `Bearer ${this.options.apiKey}`, - 'OpenAI-Beta': 'realtime=v1', - }, - }); - - this.ws.onopen = () => { - this.connected = true; - }; - - this.ws.onerror = (error) => { - reject(error); - }; - - this.ws.onclose = () => { - this.connected = false; - this.ws = null; - }; - - this.ws.onmessage = (message) => { - const event = JSON.parse(message.data as string); - this.handleServerEvent(event); - - if (event.type === 'session.created') { - this.sendClientCommand({ - type: proto.ClientEventType.SessionUpdate, - session: this.options.sessionConfig, - }); - resolve(); - } - }; - }); - } - - close() { - if (!this.connected || !this.ws) return; - this.logger.debug('stopping assistant'); - this.ws.close(); - } - - addUserMessage(text: string, generate: boolean = true): void { - this.sendClientCommand({ - type: proto.ClientEventType.ConversationItemCreate, - item: { - type: 'message', - role: 'user', - content: [ - { - type: 'text', - text: text, - }, - ], - }, - }); - if (generate) { - this.sendClientCommand({ - type: proto.ClientEventType.ResponseCreate, - response: {}, - }); - } - } - - private setState(state: proto.State) { - // don't override thinking until done - if (this.thinking) return; - if (this.room?.isConnected && this.room.localParticipant) { - const currentState = this.room.localParticipant.attributes['lk.agent.state']; - if (currentState !== state) { - this.room.localParticipant!.setAttributes({ - 'lk.agent.state': state, - }); - this.logger.debug(`lk.agent.state updated from ${currentState} to ${state}`); - } - } - } - - /// Truncates the data field of the event to the specified maxLength to avoid overwhelming logs - /// with large amounts of base64 audio data. - private loggableEvent( - event: proto.ClientEvent | proto.ServerEvent, - maxLength: number = 30, - ): Record { - const untypedEvent: Record = {}; - for (const [key, value] of Object.entries(event)) { - if (value !== undefined) { - untypedEvent[key] = value; - } - } - - if (untypedEvent.audio && typeof untypedEvent.audio === 'string') { - const truncatedData = - untypedEvent.audio.slice(0, maxLength) + (untypedEvent.audio.length > maxLength ? '…' : ''); - return { ...untypedEvent, audio: truncatedData }; - } - if ( - untypedEvent.delta && - typeof untypedEvent.delta === 'string' && - event.type === 'response.audio.delta' - ) { - const truncatedDelta = - untypedEvent.delta.slice(0, maxLength) + (untypedEvent.delta.length > maxLength ? 
'…' : ''); - return { ...untypedEvent, delta: truncatedDelta }; - } - return untypedEvent; - } - - private sendClientCommand(command: proto.ClientEvent): void { - const isAudio = command.type === 'input_audio_buffer.append'; - - if (!this.connected || !this.ws) { - if (!isAudio) this.logger.error('WebSocket is not connected'); - return; - } - - if (!isAudio) { - this.logger.debug(`-> ${JSON.stringify(this.loggableEvent(command))}`); - } - this.ws.send(JSON.stringify(command)); - } - - private handleServerEvent(event: proto.ServerEvent): void { - this.logger.debug(`<- ${JSON.stringify(this.loggableEvent(event))}`); - - switch (event.type) { - case 'session.created': - this.setState(proto.State.LISTENING); - break; - case 'conversation.item.created': - break; - case 'response.audio_transcript.delta': - case 'response.audio.delta': - this.handleAddContent(event); - break; - case 'conversation.item.created': - this.handleMessageAdded(event); - break; - case 'input_audio_buffer.speech_started': - this.handleVadSpeechStarted(event); - break; - // case 'input_audio_transcription.stopped': - // break; - case 'conversation.item.input_audio_transcription.completed': - this.handleInputTranscribed(event); - break; - // case 'response.canceled': - // this.handleGenerationCanceled(); - // break; - case 'response.done': - this.handleGenerationFinished(event); - break; - default: - this.logger.warn(`Unknown server event: ${JSON.stringify(event)}`); - } - } - - private handleAddContent( - event: proto.ResponseAudioDeltaEvent | proto.ResponseAudioTranscriptDeltaEvent, - ): void { - const trackSid = this.getLocalTrackSid(); - if (!this.room || !this.room.localParticipant || !trackSid || !this.agentPlayout) { - log().error('Room or local participant not set'); - return; - } - - if (!this.playingHandle || this.playingHandle.done) { - const trFwd = new BasicTranscriptionForwarder( - this.room, - this.room?.localParticipant?.identity, - trackSid, - event.response_id, - ); - - this.setState(proto.State.SPEAKING); - this.playingHandle = this.agentPlayout.play(event.response_id, trFwd); - this.playingHandle.on('complete', () => { - this.setState(proto.State.LISTENING); - }); - } - if (event.type === 'response.audio.delta') { - this.playingHandle?.pushAudio(Buffer.from(event.delta, 'base64')); - } else if (event.type === 'response.audio_transcript.delta') { - this.playingHandle?.pushText(event.delta); - } - } - - private handleMessageAdded(event: proto.ConversationItemCreatedEvent): void { - if (event.item.type === 'function_call') { - const toolCall = event.item; - this.options.functions[toolCall.name].execute(toolCall.arguments).then((content) => { - this.thinking = false; - this.sendClientCommand({ - type: proto.ClientEventType.ConversationItemCreate, - item: { - type: 'function_call_output', - call_id: toolCall.call_id, - output: content, - }, - }); - this.sendClientCommand({ - type: proto.ClientEventType.ResponseCreate, - response: {}, - }); - }); - } - } - - private handleInputTranscribed( - event: proto.ConversationItemInputAudioTranscriptionCompletedEvent, - ): void { - const messageId = event.item_id; - const transcription = event.transcript; - if (!messageId || transcription === undefined) { - this.logger.error('Message ID or transcription not set'); - return; - } - const participantIdentity = this.linkedParticipant?.identity; - const trackSid = this.subscribedTrack?.sid; - if (participantIdentity && trackSid) { - this.publishTranscription(participantIdentity, trackSid, transcription, true, messageId); - } 
else { - this.logger.error('Participant or track not set'); - } - } - - private handleGenerationFinished(event: proto.ResponseDoneEvent): void { - if ( - event.response.status === 'cancelled' && - event.response.status_details?.type === 'cancelled' && - event.response.status_details?.reason === 'turn_detected' - ) { - if (this.playingHandle && !this.playingHandle.done) { - this.playingHandle.interrupt(); - this.sendClientCommand({ - type: proto.ClientEventType.ConversationItemTruncate, - item_id: this.playingHandle.messageId, - content_index: 0, // ignored for now (see OAI docs) - audio_end_ms: (this.playingHandle.playedAudioSamples * 1000) / proto.SAMPLE_RATE, - }); - } - } else if (event.response.status !== 'completed') { - log().warn(`assistant turn finished unexpectedly reason ${event.response.status}`); - } - - if (this.playingHandle && !this.playingHandle.interrupted) { - this.playingHandle.endInput(); - } - } - - private handleVadSpeechStarted(event: proto.InputAudioBufferSpeechStartedEvent): void { - const messageId = event.item_id; - const participantIdentity = this.linkedParticipant?.identity; - const trackSid = this.subscribedTrack?.sid; - if (participantIdentity && trackSid && messageId) { - this.publishTranscription(participantIdentity, trackSid, '', false, messageId); - } else { - this.logger.error('Participant or track or itemId not set'); - } - } - - private linkParticipant(participantIdentity: string): void { - if (!this.room) { - this.logger.error('Room is not set'); - return; - } - - this.linkedParticipant = this.room.remoteParticipants.get(participantIdentity) || null; - if (!this.linkedParticipant) { - this.logger.error(`Participant with identity ${participantIdentity} not found`); - return; - } - - if (this.linkedParticipant.trackPublications.size > 0) { - this.subscribeToMicrophone(); - } else { - this.room.on(RoomEvent.TrackPublished, () => { - this.subscribeToMicrophone(); - }); - } - } - - private subscribeToMicrophone(): void { - const readAudioStreamTask = async (audioStream: AudioStream) => { - const bstream = new AudioByteStream( - proto.SAMPLE_RATE, - proto.NUM_CHANNELS, - proto.INPUT_PCM_FRAME_SIZE, - ); - - for await (const frame of audioStream) { - const audioData = frame.data; - for (const frame of bstream.write(audioData.buffer)) { - this.sendClientCommand({ - type: proto.ClientEventType.InputAudioBufferAppend, - audio: Buffer.from(frame.data.buffer).toString('base64'), - }); - } - } - }; - - if (!this.linkedParticipant) { - this.logger.error('Participant is not set'); - return; - } - - for (const publication of this.linkedParticipant.trackPublications.values()) { - if (publication.source !== TrackSource.SOURCE_MICROPHONE) { - continue; - } - - if (!publication.subscribed) { - publication.setSubscribed(true); - } - - const track = publication.track; - - if (track && track !== this.subscribedTrack) { - this.subscribedTrack = track!; - if (this.readMicroTask) { - this.readMicroTask.cancel(); - } - - let cancel: () => void; - this.readMicroTask = { - promise: new Promise((resolve, reject) => { - cancel = () => { - // Cleanup logic here - reject(new Error('Task cancelled')); - }; - readAudioStreamTask(new AudioStream(track, proto.SAMPLE_RATE, proto.NUM_CHANNELS)) - .then(resolve) - .catch(reject); - }), - cancel: () => cancel(), - }; - } - } - } - - private getLocalTrackSid(): string | null { - if (!this.localTrackSid && this.room && this.room.localParticipant) { - this.localTrackSid = findMicroTrackId(this.room, this.room.localParticipant?.identity); - } - 
return this.localTrackSid; - } - - private publishTranscription( - participantIdentity: string, - trackSid: string, - text: string, - isFinal: boolean, - id: string, - ): void { - // Log all parameters - log().info('Publishing transcription', { - participantIdentity, - trackSid, - text, - isFinal, - id, - }); - if (!this.room?.localParticipant) { - log().error('Room or local participant not set'); - return; - } - - this.room.localParticipant.publishTranscription({ - participantIdentity, - trackSid, - segments: [ - { - text, - final: isFinal, - id, - startTime: BigInt(0), - endTime: BigInt(0), - language: '', - }, - ], - }); - } -} - -const tools = (ctx: llm.FunctionContext): proto.Tool[] => - Object.entries(ctx).map(([name, func]) => ({ - name, - description: func.description, - parameters: llm.oaiParams(func.parameters), - type: 'function', - })); diff --git a/plugins/openai/src/omni_assistant/proto.ts b/plugins/openai/src/realtime/api_proto.ts similarity index 64% rename from plugins/openai/src/omni_assistant/proto.ts rename to plugins/openai/src/realtime/api_proto.ts index d81e171e..9a02578d 100644 --- a/plugins/openai/src/omni_assistant/proto.ts +++ b/plugins/openai/src/realtime/api_proto.ts @@ -2,23 +2,70 @@ // // SPDX-License-Identifier: Apache-2.0 -export const API_URL = 'wss://api.openai.com/v1/realtime'; export const SAMPLE_RATE = 24000; export const NUM_CHANNELS = 1; export const INPUT_PCM_FRAME_SIZE = 2400; // 100ms export const OUTPUT_PCM_FRAME_SIZE = 1200; // 50ms -export enum Voice { - ALLOY = 'alloy', - SHIMMER = 'shimmer', - ECHO = 'echo', -} +export const API_URL = 'wss://api.openai.com/v1/realtime'; -export enum AudioFormat { - PCM16 = 'pcm16', - // G711_ULAW = 'g711-ulaw', - // G711_ALAW = 'g711-alaw', -} +export type Model = 'gpt-4o-realtime-preview-2024-10-01' | string; +export type Voice = 'alloy' | 'shimmer' | 'echo' | string; +export type AudioFormat = 'pcm16'; // TODO: 'g711-ulaw' | 'g711-alaw' +export type Role = 'system' | 'assistant' | 'user' | 'tool'; +export type GenerationFinishedReason = 'stop' | 'max_tokens' | 'content_filter' | 'interrupt'; +export type InputTranscriptionModel = 'whisper-1'; +export type Modality = 'text' | 'audio'; +export type ToolChoice = 'auto' | 'none' | 'required' | string; +export type State = 'initializing' | 'listening' | 'thinking' | 'speaking' | string; +export type ResponseStatus = + | 'in_progress' + | 'completed' + | 'incomplete' + | 'cancelled' + | 'failed' + | string; +export type ClientEventType = + | 'session.update' + | 'input_audio_buffer.append' + | 'input_audio_buffer.commit' + | 'input_audio_buffer.clear' + | 'conversation.item.create' + | 'conversation.item.truncate' + | 'conversation.item.delete' + | 'response.create' + | 'response.cancel'; +export type ServerEventType = + | 'error' + | 'session.created' + | 'session.updated' + | 'conversation.created' + | 'input_audio_buffer.committed' + | 'input_audio_buffer.cleared' + | 'input_audio_buffer.speech_started' + | 'input_audio_buffer.speech_stopped' + | 'conversation.item.created' + | 'conversation.item.input_audio_transcription.completed' + | 'conversation.item.input_audio_transcription.failed' + | 'conversation.item.truncated' + | 'conversation.item.deleted' + | 'response.created' + | 'response.done' + | 'response.output_item.added' + | 'response.output_item.done' + | 'response.content_part.added' + | 'response.content_part.done' + | 'response.text.delta' + | 'response.text.done' + | 'response.audio_transcript.delta' + | 'response.audio_transcript.done' + | 
'response.audio.delta' + | 'response.audio.done' + | 'response.function_call_arguments.delta' + | 'response.function_call_arguments.done' + | 'rate_limits.updated'; + +export type AudioBase64Bytes = string; export interface Tool { type: 'function'; @@ -28,6 +75,7 @@ export interface Tool { type: 'object'; properties: { [prop: string]: { + // eslint-disable-next-line @typescript-eslint/no-explicit-any [prop: string]: any; }; }; @@ -35,20 +83,14 @@ export interface Tool { }; } -export enum ToolChoice { - AUTO = 'auto', - NONE = 'none', - REQUIRED = 'required', -} - -export enum State { - INITIALIZING = 'initializing', - LISTENING = 'listening', - THINKING = 'thinking', - SPEAKING = 'speaking', -} - -export type AudioBase64Bytes = string; +export type TurnDetectionType = + | { + type: 'server_vad'; + threshold?: number; // 0.0 to 1.0, default: 0.5 + prefix_padding_ms?: number; // default: 300 + silence_duration_ms?: number; // default: 200 + } + | 'none'; // Content Part Types export interface InputTextContent { @@ -136,18 +178,11 @@ export interface SessionResource { input_audio_transcription?: { model: 'whisper-1'; }; // default: null - turn_detection: - | { - type: 'server_vad'; - threshold: number; // 0.0 to 1.0, default: 0.5 - prefix_padding_ms: number; // default: 300 - silence_duration_ms: number; // default: 200 - } - | 'none'; + turn_detection: TurnDetectionType; tools: Tool[]; tool_choice: ToolChoice; // default: "auto" temperature: number; // default: 0.8 - // max_output_tokens: number | null; // FIXME: currently rejected by OpenAI and fails the whole update + max_response_output_tokens: number | null; } // Conversation Resource @@ -156,30 +191,21 @@ export interface ConversationResource { object: 'realtime.conversation'; } -// Response Resource -export enum ResponseStatus { - IN_PROGRESS = 'in_progress', - COMPLETED = 'completed', - INCOMPLETE = 'incomplete', - CANCELLED = 'cancelled', - FAILED = 'failed', -} - export type ResponseStatusDetails = | { - type: ResponseStatus.INCOMPLETE; - reason: 'max_output_tokens' | 'content_filter'; + type: 'incomplete'; + reason: 'max_output_tokens' | 'content_filter' | string; } | { - type: ResponseStatus.FAILED; + type: 'failed'; error?: { code: 'server_error' | 'rate_limit_exceeded' | string; message: string; }; } | { - type: ResponseStatus.CANCELLED; - reason: 'turn_detected' | 'client_cancelled'; + type: 'cancelled'; + reason: 'turn_detected' | 'client_cancelled' | string; }; export interface ResponseResource { @@ -202,7 +228,7 @@ interface BaseClientEvent { } export interface SessionUpdateEvent extends BaseClientEvent { - type: ClientEventType.SessionUpdate; + type: 'session.update'; session: Partial<{ modalities: ['text', 'audio'] | ['text']; instructions: string; @@ -223,25 +249,25 @@ export interface SessionUpdateEvent extends BaseClientEvent { tools: Tool[]; tool_choice: ToolChoice; temperature: number; - max_output_tokens: number; + max_response_output_tokens: number; }>; } export interface InputAudioBufferAppendEvent extends BaseClientEvent { - type: ClientEventType.InputAudioBufferAppend; + type: 'input_audio_buffer.append'; audio: AudioBase64Bytes; } export interface InputAudioBufferCommitEvent extends BaseClientEvent { - type: ClientEventType.InputAudioBufferCommit; + type: 'input_audio_buffer.commit'; } export interface InputAudioBufferClearEvent extends BaseClientEvent { - type: ClientEventType.InputAudioBufferClear; + type: 'input_audio_buffer.clear'; } export interface ConversationItemCreateEvent extends BaseClientEvent { - 
type: ClientEventType.ConversationItemCreate; + type: 'conversation.item.create'; item: | { type: 'message'; @@ -266,20 +292,20 @@ export interface ConversationItemCreateEvent extends BaseClientEvent { } export interface ConversationItemTruncateEvent extends BaseClientEvent { - type: ClientEventType.ConversationItemTruncate; + type: 'conversation.item.truncate'; item_id: string; content_index: number; audio_end_ms: number; } export interface ConversationItemDeleteEvent extends BaseClientEvent { - type: ClientEventType.ConversationItemDelete; + type: 'conversation.item.delete'; item_id: string; } export interface ResponseCreateEvent extends BaseClientEvent { - type: ClientEventType.ResponseCreate; - response: Partial<{ + type: 'response.create'; + response?: Partial<{ modalities: ['text', 'audio'] | ['text']; instructions: string; voice: Voice; @@ -287,24 +313,12 @@ export interface ResponseCreateEvent extends BaseClientEvent { tools?: Tool[]; tool_choice: ToolChoice; temperature: number; - max_output_tokens: number; + max_response_output_tokens: number; }>; } export interface ResponseCancelEvent extends BaseClientEvent { - type: ClientEventType.ResponseCancel; -} - -export enum ClientEventType { - SessionUpdate = 'session.update', - InputAudioBufferAppend = 'input_audio_buffer.append', - InputAudioBufferCommit = 'input_audio_buffer.commit', - InputAudioBufferClear = 'input_audio_buffer.clear', - ConversationItemCreate = 'conversation.item.create', - ConversationItemTruncate = 'conversation.item.truncate', - ConversationItemDelete = 'conversation.item.delete', - ResponseCreate = 'response.create', - ResponseCancel = 'response.cancel', + type: 'response.cancel'; } export type ClientEvent = @@ -318,14 +332,13 @@ export type ClientEvent = | ResponseCreateEvent | ResponseCancelEvent; -// Server Events interface BaseServerEvent { event_id: string; type: ServerEventType; } export interface ErrorEvent extends BaseServerEvent { - type: ServerEventType.Error; + type: 'error'; error: { type: 'invalid_request_error' | 'server_error' | string; code?: string; @@ -336,55 +349,55 @@ export interface ErrorEvent extends BaseServerEvent { } export interface SessionCreatedEvent extends BaseServerEvent { - type: ServerEventType.SessionCreated; + type: 'session.created'; session: SessionResource; } export interface SessionUpdatedEvent extends BaseServerEvent { - type: ServerEventType.SessionUpdated; + type: 'session.updated'; session: SessionResource; } export interface ConversationCreatedEvent extends BaseServerEvent { - type: ServerEventType.ConversationCreated; + type: 'conversation.created'; conversation: ConversationResource; } export interface InputAudioBufferCommittedEvent extends BaseServerEvent { - type: ServerEventType.InputAudioBufferCommitted; + type: 'input_audio_buffer.committed'; item_id: string; } export interface InputAudioBufferClearedEvent extends BaseServerEvent { - type: ServerEventType.InputAudioBufferCleared; + type: 'input_audio_buffer.cleared'; } export interface InputAudioBufferSpeechStartedEvent extends BaseServerEvent { - type: ServerEventType.InputAudioBufferSpeechStarted; + type: 'input_audio_buffer.speech_started'; audio_start_ms: number; item_id: string; } export interface InputAudioBufferSpeechStoppedEvent extends BaseServerEvent { - type: ServerEventType.InputAudioBufferSpeechStopped; + type: 'input_audio_buffer.speech_stopped'; audio_end_ms: number; item_id: string; } export interface ConversationItemCreatedEvent extends BaseServerEvent { - type: 
ServerEventType.ConversationItemCreated; + type: 'conversation.item.created'; item: ItemResource; } export interface ConversationItemInputAudioTranscriptionCompletedEvent extends BaseServerEvent { - type: ServerEventType.ConversationItemInputAudioTranscriptionCompleted; + type: 'conversation.item.input_audio_transcription.completed'; item_id: string; content_index: number; transcript: string; } export interface ConversationItemInputAudioTranscriptionFailedEvent extends BaseServerEvent { - type: ServerEventType.ConversationItemInputAudioTranscriptionFailed; + type: 'conversation.item.input_audio_transcription.failed'; item_id: string; content_index: number; error: { @@ -396,51 +409,52 @@ export interface ConversationItemInputAudioTranscriptionFailedEvent extends Base } export interface ConversationItemTruncatedEvent extends BaseServerEvent { - type: ServerEventType.ConversationItemTruncated; + type: 'conversation.item.truncated'; item_id: string; content_index: number; audio_end_ms: number; } export interface ConversationItemDeletedEvent extends BaseServerEvent { - type: ServerEventType.ConversationItemDeleted; + type: 'conversation.item.deleted'; item_id: string; } export interface ResponseCreatedEvent extends BaseServerEvent { - type: ServerEventType.ResponseCreated; + type: 'response.created'; response: ResponseResource; } export interface ResponseDoneEvent extends BaseServerEvent { - type: ServerEventType.ResponseDone; + type: 'response.done'; response: ResponseResource; } -export interface ResponseOutputAddedEvent extends BaseServerEvent { - type: ServerEventType.ResponseOutputAdded; +export interface ResponseOutputItemAddedEvent extends BaseServerEvent { + type: 'response.output_item.added'; response_id: string; output_index: number; item: ItemResource; } -export interface ResponseOutputDoneEvent extends BaseServerEvent { - type: ServerEventType.ResponseOutputDone; +export interface ResponseOutputItemDoneEvent extends BaseServerEvent { + type: 'response.output_item.done'; response_id: string; output_index: number; item: ItemResource; } -export interface ResponseContentAddedEvent extends BaseServerEvent { - type: ServerEventType.ResponseContentAdded; +export interface ResponseContentPartAddedEvent extends BaseServerEvent { + type: 'response.content_part.added'; response_id: string; + item_id: string; output_index: number; content_index: number; part: ContentPart; } -export interface ResponseContentDoneEvent extends BaseServerEvent { - type: ServerEventType.ResponseContentDone; +export interface ResponseContentPartDoneEvent extends BaseServerEvent { + type: 'response.content_part.done'; response_id: string; output_index: number; content_index: number; @@ -448,7 +462,7 @@ export interface ResponseContentDoneEvent extends BaseServerEvent { } export interface ResponseTextDeltaEvent extends BaseServerEvent { - type: ServerEventType.ResponseTextDelta; + type: 'response.text.delta'; response_id: string; output_index: number; content_index: number; @@ -456,7 +470,7 @@ export interface ResponseTextDeltaEvent extends BaseServerEvent { } export interface ResponseTextDoneEvent extends BaseServerEvent { - type: ServerEventType.ResponseTextDone; + type: 'response.text.done'; response_id: string; output_index: number; content_index: number; @@ -464,7 +478,7 @@ export interface ResponseTextDoneEvent extends BaseServerEvent { } export interface ResponseAudioTranscriptDeltaEvent extends BaseServerEvent { - type: ServerEventType.ResponseAudioTranscriptDelta; + type: 'response.audio_transcript.delta'; 
response_id: string; output_index: number; content_index: number; @@ -472,7 +486,7 @@ export interface ResponseAudioTranscriptDeltaEvent extends BaseServerEvent { } export interface ResponseAudioTranscriptDoneEvent extends BaseServerEvent { - type: ServerEventType.ResponseAudioTranscriptDone; + type: 'response.audio_transcript.done'; response_id: string; output_index: number; content_index: number; @@ -480,7 +494,7 @@ export interface ResponseAudioTranscriptDoneEvent extends BaseServerEvent { } export interface ResponseAudioDeltaEvent extends BaseServerEvent { - type: ServerEventType.ResponseAudioDelta; + type: 'response.audio.delta'; response_id: string; output_index: number; content_index: number; @@ -488,68 +502,36 @@ export interface ResponseAudioDeltaEvent extends BaseServerEvent { } export interface ResponseAudioDoneEvent extends BaseServerEvent { - type: ServerEventType.ResponseAudioDone; + type: 'response.audio.done'; response_id: string; output_index: number; content_index: number; - // 'audio' field is excluded from rendering } export interface ResponseFunctionCallArgumentsDeltaEvent extends BaseServerEvent { - type: ServerEventType.ResponseFunctionCallArgumentsDelta; + type: 'response.function_call_arguments.delta'; response_id: string; output_index: number; delta: string; } export interface ResponseFunctionCallArgumentsDoneEvent extends BaseServerEvent { - type: ServerEventType.ResponseFunctionCallArgumentsDone; + type: 'response.function_call_arguments.done'; response_id: string; output_index: number; arguments: string; } export interface RateLimitsUpdatedEvent extends BaseServerEvent { - type: ServerEventType.RateLimitsUpdated; + type: 'rate_limits.updated'; rate_limits: { - name: 'requests' | 'tokens' | 'input_tokens' | 'output_tokens'; + name: 'requests' | 'tokens' | 'input_tokens' | 'output_tokens' | string; limit: number; remaining: number; reset_seconds: number; }[]; } -export enum ServerEventType { - Error = 'error', - SessionCreated = 'session.created', - SessionUpdated = 'session.updated', - ConversationCreated = 'conversation.created', - InputAudioBufferCommitted = 'input_audio_buffer.committed', - InputAudioBufferCleared = 'input_audio_buffer.cleared', - InputAudioBufferSpeechStarted = 'input_audio_buffer.speech_started', - InputAudioBufferSpeechStopped = 'input_audio_buffer.speech_stopped', - ConversationItemCreated = 'conversation.item.created', - ConversationItemInputAudioTranscriptionCompleted = 'conversation.item.input_audio_transcription.completed', - ConversationItemInputAudioTranscriptionFailed = 'conversation.item.input_audio_transcription.failed', - ConversationItemTruncated = 'conversation.item.truncated', - ConversationItemDeleted = 'conversation.item.deleted', - ResponseCreated = 'response.created', - ResponseDone = 'response.done', - ResponseOutputAdded = 'response.output.added', - ResponseOutputDone = 'response.output.done', - ResponseContentAdded = 'response.content.added', - ResponseContentDone = 'response.content.done', - ResponseTextDelta = 'response.text.delta', - ResponseTextDone = 'response.text.done', - ResponseAudioTranscriptDelta = 'response.audio_transcript.delta', - ResponseAudioTranscriptDone = 'response.audio_transcript.done', - ResponseAudioDelta = 'response.audio.delta', - ResponseAudioDone = 'response.audio.done', - ResponseFunctionCallArgumentsDelta = 'response.function_call_arguments.delta', - ResponseFunctionCallArgumentsDone = 'response.function_call_arguments.done', - RateLimitsUpdated = 'response.rate_limits.updated', -} - 
export type ServerEvent = | ErrorEvent | SessionCreatedEvent @@ -566,10 +548,10 @@ export type ServerEvent = | ConversationItemDeletedEvent | ResponseCreatedEvent | ResponseDoneEvent - | ResponseOutputAddedEvent - | ResponseOutputDoneEvent - | ResponseContentAddedEvent - | ResponseContentDoneEvent + | ResponseOutputItemAddedEvent + | ResponseOutputItemDoneEvent + | ResponseContentPartAddedEvent + | ResponseContentPartDoneEvent | ResponseTextDeltaEvent | ResponseTextDoneEvent | ResponseAudioTranscriptDeltaEvent diff --git a/plugins/openai/src/realtime/index.ts b/plugins/openai/src/realtime/index.ts new file mode 100644 index 00000000..29a91473 --- /dev/null +++ b/plugins/openai/src/realtime/index.ts @@ -0,0 +1,5 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +export * from './api_proto.js'; +export * from './realtime_model.js'; diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts new file mode 100644 index 00000000..7a5d8681 --- /dev/null +++ b/plugins/openai/src/realtime/realtime_model.ts @@ -0,0 +1,884 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { Queue } from '@livekit/agents'; +import { llm, log } from '@livekit/agents'; +import { AudioFrame } from '@livekit/rtc-node'; +import { EventEmitter, once } from 'events'; +import { WebSocket } from 'ws'; +import * as api_proto from './api_proto.js'; + +export enum EventTypes { + Error = 'error', + InputSpeechCommitted = 'input_speech_committed', + InputSpeechStarted = 'input_speech_started', + InputSpeechStopped = 'input_speech_stopped', + InputSpeechTranscriptionCompleted = 'input_speech_transcription_completed', + InputSpeechTranscriptionFailed = 'input_speech_transcription_failed', + ResponseContentAdded = 'response_content_added', + ResponseContentDone = 'response_content_done', + ResponseCreated = 'response_created', + ResponseDone = 'response_done', + ResponseOutputAdded = 'response_output_added', + ResponseOutputDone = 'response_output_done', + StartSession = 'start_session', +} + +interface ModelOptions { + modalities: ['text', 'audio'] | ['text']; + instructions?: string; + voice: api_proto.Voice; + inputAudioFormat: api_proto.AudioFormat; + outputAudioFormat: api_proto.AudioFormat; + inputAudioTranscription?: { + model: 'whisper-1'; + }; + turnDetection: + | { + type: 'server_vad'; + threshold?: number; + prefix_padding_ms?: number; + silence_duration_ms?: number; + } + | 'none'; + temperature: number; + maxResponseOutputTokens: number; + model: api_proto.Model; + apiKey: string; + baseURL: string; +} + +export interface RealtimeResponse { + /** ID of the message */ + id: string; + /** Status of the response */ + status: api_proto.ResponseStatus; + /** List of outputs */ + output: RealtimeOutput[]; + /** Promise that will be executed when the response is completed */ + donePromise: () => Promise; +} + +export interface RealtimeOutput { + /** ID of the response */ + responseId: string; + /** ID of the item */ + itemId: string; + /** Index of the output */ + outputIndex: number; + /** Role of the message */ + role: api_proto.Role; + /** Type of the output */ + type: 'message' | 'function_call'; + /** List of content */ + content: RealtimeContent[]; + /** Promise that will be executed when the response is completed */ + donePromise: () => Promise; +} + +export interface RealtimeContent { + /** ID of the response */ + responseId: string; + /** ID of the item */ + itemId: string; + 
/** Index of the output */ + outputIndex: number; + /** Index of the content */ + contentIndex: number; + /** Accumulated text content */ + text: string; + /** Accumulated audio content */ + audio: AudioFrame[]; + /** Stream of text content */ + textStream: Queue; + /** Stream of audio content */ + audioStream: Queue; + /** Pending tool calls */ + toolCalls: RealtimeToolCall[]; +} + +export interface RealtimeToolCall { + /** Name of the function */ + name: string; + /** Accumulated arguments */ + arguments: string; + /** ID of the tool call */ + toolCallID: string; +} + +export interface InputTranscriptionCompleted { + itemId: string; + transcript: string; +} + +export interface InputTranscriptionFailed { + itemId: string; + message: string; +} + +export interface InputSpeechCommitted { + itemId: string; +} + +class InputAudioBuffer { + #session: RealtimeSession; + + constructor(session: RealtimeSession) { + this.#session = session; + } + + append(frame: AudioFrame) { + this.#session.queueMsg({ + type: 'input_audio_buffer.append', + audio: Buffer.from(frame.data.buffer).toString('base64'), + }); + } + + clear() { + this.#session.queueMsg({ + type: 'input_audio_buffer.clear', + }); + } + + commit() { + this.#session.queueMsg({ + type: 'input_audio_buffer.commit', + }); + } +} + +class ConversationItem { + #session: RealtimeSession; + + constructor(session: RealtimeSession) { + this.#session = session; + } + + // create(message: llm.ChatMessage, previousItemId?: string) { + // // TODO: Implement create method + // throw new Error('Not implemented'); + // } + + truncate(itemId: string, contentIndex: number, audioEnd: number) { + this.#session.queueMsg({ + type: 'conversation.item.truncate', + item_id: itemId, + content_index: contentIndex, + audio_end_ms: audioEnd, + }); + } + + delete(itemId: string) { + this.#session.queueMsg({ + type: 'conversation.item.delete', + item_id: itemId, + }); + } +} + +class Conversation { + #session: RealtimeSession; + + constructor(session: RealtimeSession) { + this.#session = session; + } + + get item(): ConversationItem { + return new ConversationItem(this.#session); + } +} + +class Response { + #session: RealtimeSession; + + constructor(session: RealtimeSession) { + this.#session = session; + } + + create() { + this.#session.queueMsg({ + type: 'response.create', + }); + } + + cancel() { + this.#session.queueMsg({ + type: 'response.cancel', + }); + } +} + +interface ContentPtr { + response_id: string; + output_index: number; + content_index: number; +} + +export class RealtimeModel { + #defaultOpts: ModelOptions; + #sessions: RealtimeSession[] = []; + + constructor({ + modalities = ['text', 'audio'], + instructions = undefined, + voice = 'alloy', + inputAudioFormat = 'pcm16', + outputAudioFormat = 'pcm16', + inputAudioTranscription = { model: 'whisper-1' }, + turnDetection = { type: 'server_vad' }, + temperature = 0.8, + maxResponseOutputTokens = 2048, + model = 'gpt-4o-realtime-preview-2024-10-01', + apiKey = process.env.OPENAI_API_KEY || '', + baseURL = api_proto.API_URL, + }: { + modalities?: ['text', 'audio'] | ['text']; + instructions?: string; + voice?: api_proto.Voice; + inputAudioFormat?: api_proto.AudioFormat; + outputAudioFormat?: api_proto.AudioFormat; + inputAudioTranscription?: { model: 'whisper-1' }; + turnDetection?: api_proto.TurnDetectionType; + temperature?: number; + maxResponseOutputTokens?: number; + model?: api_proto.Model; + apiKey?: string; + baseURL?: string; + }) { + if (apiKey === '') { + throw new Error( + 'OpenAI API key is 
required, either using the argument or by setting the OPENAI_API_KEY environmental variable', + ); + } + + this.#defaultOpts = { + modalities, + instructions, + voice, + inputAudioFormat, + outputAudioFormat, + inputAudioTranscription, + turnDetection, + temperature, + maxResponseOutputTokens, + model, + apiKey, + baseURL, + }; + } + + get sessions(): RealtimeSession[] { + return this.#sessions; + } + + session({ + funcCtx = {}, + modalities = this.#defaultOpts.modalities, + instructions = this.#defaultOpts.instructions, + voice = this.#defaultOpts.voice, + inputAudioFormat = this.#defaultOpts.inputAudioFormat, + outputAudioFormat = this.#defaultOpts.outputAudioFormat, + inputAudioTranscription = this.#defaultOpts.inputAudioTranscription, + turnDetection = this.#defaultOpts.turnDetection, + temperature = this.#defaultOpts.temperature, + maxResponseOutputTokens = this.#defaultOpts.maxResponseOutputTokens, + }: { + funcCtx?: llm.FunctionContext; + modalities?: ['text', 'audio'] | ['text']; + instructions?: string; + voice?: api_proto.Voice; + inputAudioFormat?: api_proto.AudioFormat; + outputAudioFormat?: api_proto.AudioFormat; + inputAudioTranscription?: { model: 'whisper-1' }; + turnDetection?: api_proto.TurnDetectionType; + temperature?: number; + maxResponseOutputTokens?: number; + }): RealtimeSession { + const opts: ModelOptions = { + modalities, + instructions, + voice, + inputAudioFormat, + outputAudioFormat, + inputAudioTranscription, + turnDetection, + temperature, + maxResponseOutputTokens, + model: this.#defaultOpts.model, + apiKey: this.#defaultOpts.apiKey, + baseURL: this.#defaultOpts.baseURL, + }; + + const newSession = new RealtimeSession(funcCtx, opts); + this.#sessions.push(newSession); + return newSession; + } + + async close(): Promise { + // TODO: Implement close method + throw new Error('Not implemented'); + } +} + +export class RealtimeSession extends EventEmitter { + #funcCtx: llm.FunctionContext; + #opts: ModelOptions; + #pendingResponses: { [id: string]: RealtimeResponse } = {}; + #sessionId = 'not-connected'; + #ws: WebSocket | null = null; + #logger = log(); + #task: Promise; + #closing = true; + #sendQueue = new Queue(); + + constructor(funcCtx: llm.FunctionContext, opts: ModelOptions) { + super(); + + this.#funcCtx = funcCtx; + this.#opts = opts; + + this.#task = this.#start(); + + this.sessionUpdate({ + modalities: this.#opts.modalities, + instructions: this.#opts.instructions, + voice: this.#opts.voice, + inputAudioFormat: this.#opts.inputAudioFormat, + outputAudioFormat: this.#opts.outputAudioFormat, + inputAudioTranscription: this.#opts.inputAudioTranscription, + turnDetection: this.#opts.turnDetection, + temperature: this.#opts.temperature, + maxResponseOutputTokens: this.#opts.maxResponseOutputTokens, + toolChoice: 'auto', + }); + } + + get funcCtx(): llm.FunctionContext { + return this.#funcCtx; + } + + set funcCtx(ctx: llm.FunctionContext) { + this.#funcCtx = ctx; + } + + get defaultConversation(): Conversation { + return new Conversation(this); + } + + get inputAudioBuffer(): InputAudioBuffer { + return new InputAudioBuffer(this); + } + + get response(): Response { + return new Response(this); + } + + queueMsg(command: api_proto.ClientEvent): void { + this.#sendQueue.put(command); + } + + /// Truncates the data field of the event to the specified maxLength to avoid overwhelming logs + /// with large amounts of base64 audio data. 
+ #loggableEvent( + event: api_proto.ClientEvent | api_proto.ServerEvent, + maxLength: number = 30, + ): Record { + const untypedEvent: Record = {}; + for (const [key, value] of Object.entries(event)) { + if (value !== undefined) { + untypedEvent[key] = value; + } + } + + if (untypedEvent.audio && typeof untypedEvent.audio === 'string') { + const truncatedData = + untypedEvent.audio.slice(0, maxLength) + (untypedEvent.audio.length > maxLength ? '…' : ''); + return { ...untypedEvent, audio: truncatedData }; + } + if ( + untypedEvent.delta && + typeof untypedEvent.delta === 'string' && + event.type === 'response.audio.delta' + ) { + const truncatedDelta = + untypedEvent.delta.slice(0, maxLength) + (untypedEvent.delta.length > maxLength ? '…' : ''); + return { ...untypedEvent, delta: truncatedDelta }; + } + return untypedEvent; + } + + sessionUpdate({ + modalities = this.#opts.modalities, + instructions = this.#opts.instructions, + voice = this.#opts.voice, + inputAudioFormat = this.#opts.inputAudioFormat, + outputAudioFormat = this.#opts.outputAudioFormat, + inputAudioTranscription = this.#opts.inputAudioTranscription, + turnDetection = this.#opts.turnDetection, + temperature = this.#opts.temperature, + maxResponseOutputTokens = this.#opts.maxResponseOutputTokens, + toolChoice = 'auto', + }: { + modalities: ['text', 'audio'] | ['text']; + instructions?: string; + voice: api_proto.Voice; + inputAudioFormat: api_proto.AudioFormat; + outputAudioFormat: api_proto.AudioFormat; + inputAudioTranscription?: { model: 'whisper-1' }; + turnDetection: api_proto.TurnDetectionType; + temperature: number; + maxResponseOutputTokens: number; + toolChoice: api_proto.ToolChoice; + }) { + this.#opts = { + modalities, + instructions, + voice, + inputAudioFormat, + outputAudioFormat, + inputAudioTranscription, + turnDetection, + temperature, + maxResponseOutputTokens, + model: this.#opts.model, + apiKey: this.#opts.apiKey, + baseURL: this.#opts.baseURL, + }; + + const tools = Object.entries(this.#funcCtx).map(([name, func]) => ({ + type: 'function' as const, + name, + description: func.description, + parameters: llm.oaiParams(func.parameters), + })); + + this.queueMsg({ + type: 'session.update', + session: { + modalities: this.#opts.modalities, + instructions: this.#opts.instructions, + voice: this.#opts.voice, + input_audio_format: this.#opts.inputAudioFormat, + output_audio_format: this.#opts.outputAudioFormat, + input_audio_transcription: this.#opts.inputAudioTranscription, + turn_detection: this.#opts.turnDetection, + temperature: this.#opts.temperature, + max_response_output_tokens: this.#opts.maxResponseOutputTokens, + tools, + tool_choice: toolChoice, + }, + }); + } + + #start(): Promise { + return new Promise(async (resolve, reject) => { + this.#ws = new WebSocket(`${this.#opts.baseURL}?model=gpt-4-turbo-preview`, { + headers: { + Authorization: `Bearer ${this.#opts.apiKey}`, + 'OpenAI-Beta': 'realtime=v1', + }, + }); + + this.#ws.onerror = (error) => { + reject(error.message); + }; + + await once(this.#ws, 'open'); + this.#closing = false; + + this.#ws.onmessage = (message) => { + const event: api_proto.ServerEvent = JSON.parse(message.data as string); + this.#logger.debug(`<- ${JSON.stringify(this.#loggableEvent(event))}`); + switch (event.type) { + case 'error': + this.handleError(event); + break; + case 'session.created': + this.handleSessionCreated(event); + break; + case 'session.updated': + this.handleSessionUpdated(event); + break; + case 'conversation.created': + 
this.handleConversationCreated(event); + break; + case 'input_audio_buffer.committed': + this.handleInputAudioBufferCommitted(event); + break; + case 'input_audio_buffer.cleared': + this.handleInputAudioBufferCleared(event); + break; + case 'input_audio_buffer.speech_started': + this.handleInputAudioBufferSpeechStarted(event); + break; + case 'input_audio_buffer.speech_stopped': + this.handleInputAudioBufferSpeechStopped(event); + break; + case 'conversation.item.created': + this.handleConversationItemCreated(event); + break; + case 'conversation.item.input_audio_transcription.completed': + this.handleConversationItemInputAudioTranscriptionCompleted(event); + break; + case 'conversation.item.input_audio_transcription.failed': + this.handleConversationItemInputAudioTranscriptionFailed(event); + break; + case 'conversation.item.truncated': + this.handleConversationItemTruncated(event); + break; + case 'conversation.item.deleted': + this.handleConversationItemDeleted(event); + break; + case 'response.created': + this.handleResponseCreated(event); + break; + case 'response.done': + this.handleResponseDone(event); + break; + case 'response.output_item.added': + this.handleResponseOutputItemAdded(event); + break; + case 'response.output_item.done': + this.handleResponseOutputItemDone(event); + break; + case 'response.content_part.added': + this.handleResponseContentPartAdded(event); + break; + case 'response.content_part.done': + this.handleResponseContentPartDone(event); + break; + case 'response.text.delta': + this.handleResponseTextDelta(event); + break; + case 'response.text.done': + this.handleResponseTextDone(event); + break; + case 'response.audio_transcript.delta': + this.handleResponseAudioTranscriptDelta(event); + break; + case 'response.audio_transcript.done': + this.handleResponseAudioTranscriptDone(event); + break; + case 'response.audio.delta': + this.handleResponseAudioDelta(event); + break; + case 'response.audio.done': + this.handleResponseAudioDone(event); + break; + case 'response.function_call_arguments.delta': + this.handleResponseFunctionCallArgumentsDelta(event); + break; + case 'response.function_call_arguments.done': + this.handleResponseFunctionCallArgumentsDone(event); + break; + case 'rate_limits.updated': + this.handleRateLimitsUpdated(event); + break; + } + }; + + const sendTask = async () => { + while (this.#ws && !this.#closing && this.#ws.readyState === WebSocket.OPEN) { + try { + const event = await this.#sendQueue.get(); + if (event.type !== 'input_audio_buffer.append') { + this.#logger.debug(`-> ${JSON.stringify(this.#loggableEvent(event))}`); + } + this.#ws.send(JSON.stringify(event)); + } catch (error) { + this.#logger.error('Error sending event:', error); + } + } + }; + + sendTask(); + + this.#ws.onclose = () => { + if (!this.#closing) { + reject('OpenAI S2S connection closed unexpectedly'); + } + this.#ws = null; + resolve(); + }; + }); + } + + async close(): Promise { + // TODO: Implement close method + throw new Error('Not implemented'); + } + + private getContent(ptr: ContentPtr): RealtimeContent { + const response = this.#pendingResponses[ptr.response_id]; + const output = response.output[ptr.output_index]; + const content = output.content[ptr.content_index]; + return content; + } + + private handleError(event: api_proto.ErrorEvent): void { + this.#logger.error(`OpenAI S2S error ${event.error}`); + } + + private handleSessionCreated(event: api_proto.SessionCreatedEvent): void { + this.#sessionId = event.session.id; + } + + // eslint-disable-next-line 
@typescript-eslint/no-unused-vars + private handleSessionUpdated(event: api_proto.SessionUpdatedEvent): void {} + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private handleConversationCreated(event: api_proto.ConversationCreatedEvent): void {} + + private handleInputAudioBufferCommitted(event: api_proto.InputAudioBufferCommittedEvent): void { + this.emit(EventTypes.InputSpeechCommitted, { + itemId: event.item_id, + }); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private handleInputAudioBufferCleared(event: api_proto.InputAudioBufferClearedEvent): void {} + + private handleInputAudioBufferSpeechStarted( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + event: api_proto.InputAudioBufferSpeechStartedEvent, + ): void { + this.emit(EventTypes.InputSpeechStarted); + } + + private handleInputAudioBufferSpeechStopped( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + event: api_proto.InputAudioBufferSpeechStoppedEvent, + ): void { + this.emit(EventTypes.InputSpeechStopped); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private handleConversationItemCreated(event: api_proto.ConversationItemCreatedEvent): void {} + + private handleConversationItemInputAudioTranscriptionCompleted( + event: api_proto.ConversationItemInputAudioTranscriptionCompletedEvent, + ): void { + const transcript = event.transcript; + this.emit(EventTypes.InputSpeechTranscriptionCompleted, { + itemId: event.item_id, + transcript: transcript, + }); + } + + private handleConversationItemInputAudioTranscriptionFailed( + event: api_proto.ConversationItemInputAudioTranscriptionFailedEvent, + ): void { + const error = event.error; + this.#logger.error(`OAI S2S failed to transcribe input audio: ${error.message}`); + this.emit(EventTypes.InputSpeechTranscriptionFailed, { + itemId: event.item_id, + message: error.message, + }); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private handleConversationItemTruncated(event: api_proto.ConversationItemTruncatedEvent): void {} + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private handleConversationItemDeleted(event: api_proto.ConversationItemDeletedEvent): void {} + + private handleResponseCreated(responseCreated: api_proto.ResponseCreatedEvent): void { + const response = responseCreated.response; + const donePromise = new Promise((resolve) => { + this.once('response_done', () => resolve()); + }); + const newResponse: RealtimeResponse = { + id: response.id, + status: response.status, + output: [], + donePromise: () => donePromise, + }; + this.#pendingResponses[newResponse.id] = newResponse; + this.emit(EventTypes.ResponseCreated, newResponse); + } + + private handleResponseDone(event: api_proto.ResponseDoneEvent): void { + const responseData = event.response; + const responseId = responseData.id; + const response = this.#pendingResponses[responseId]; + response.donePromise(); + this.emit(EventTypes.ResponseDone, response); + } + + private handleResponseOutputItemAdded(event: api_proto.ResponseOutputItemAddedEvent): void { + const responseId = event.response_id; + const response = this.#pendingResponses[responseId]; + const itemData = event.item; + + if (itemData.type !== 'message' && itemData.type !== 'function_call') { + throw new Error(`Unexpected item type: ${itemData.type}`); + } + + let role: api_proto.Role; + if (itemData.type === 'function_call') { + role = 'assistant'; // function_call doesn't have a role field, defaulting it to 
assistant + } else { + role = itemData.role; + } + + const newOutput: RealtimeOutput = { + responseId: responseId, + itemId: itemData.id, + outputIndex: event.output_index, + type: itemData.type, + role: role, + content: [], + donePromise: () => + new Promise((resolve) => { + this.once('response_output_done', (output: RealtimeOutput) => { + if (output.itemId === itemData.id) { + resolve(); + } + }); + }), + }; + response.output.push(newOutput); + this.emit(EventTypes.ResponseOutputAdded, newOutput); + } + + private handleResponseOutputItemDone(event: api_proto.ResponseOutputItemDoneEvent): void { + const responseId = event.response_id; + const response = this.#pendingResponses[responseId]; + const outputIndex = event.output_index; + const output = response.output[outputIndex]; + + // TODO: finish implementing + // if (output.type === "function_call") { + // if (!this.#funcCtx) { + // this.#logger.error( + // "function call received but no funcCtx is available" + // ); + // return; + // } + + // // parse the arguments and call the function inside the fnc_ctx + // const item = event.item; + // if (item.type !== "function_call") { + // throw new Error("Expected function_call item"); + // } + + // const funcCallInfo = this.#oai_api.createAiFunctionInfo( + // this.#funcCtx, + // item.call_id, + // item.name, + // item.arguments + // ); + + // this.#fnc_tasks.createTask( + // this.#runFncTask(fnc_call_info, output.item_id) + // ); + // } + + output.donePromise(); + this.emit(EventTypes.ResponseOutputDone, output); + } + + private handleResponseContentPartAdded(event: api_proto.ResponseContentPartAddedEvent): void { + const responseId = event.response_id; + const response = this.#pendingResponses[responseId]; + const outputIndex = event.output_index; + const output = response.output[outputIndex]; + + const textStream = new Queue(); + const audioStream = new Queue(); + + const newContent: RealtimeContent = { + responseId: responseId, + itemId: event.item_id, + outputIndex: outputIndex, + contentIndex: event.content_index, + text: '', + audio: [], + textStream: textStream, + audioStream: audioStream, + toolCalls: [], + }; + output.content.push(newContent); + this.emit(EventTypes.ResponseContentAdded, newContent); + } + + private handleResponseContentPartDone(event: api_proto.ResponseContentPartDoneEvent): void { + const content = this.getContent(event); + this.emit(EventTypes.ResponseContentDone, content); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private handleResponseTextDelta(event: api_proto.ResponseTextDeltaEvent): void {} + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private handleResponseTextDone(event: api_proto.ResponseTextDoneEvent): void {} + + private handleResponseAudioTranscriptDelta( + event: api_proto.ResponseAudioTranscriptDeltaEvent, + ): void { + const content = this.getContent(event); + const transcript = event.delta; + content.text += transcript; + + content.textStream.put(transcript); + } + + private handleResponseAudioTranscriptDone( + event: api_proto.ResponseAudioTranscriptDoneEvent, + ): void { + const content = this.getContent(event); + content.textStream.put(null); + } + + private handleResponseAudioDelta(event: api_proto.ResponseAudioDeltaEvent): void { + const content = this.getContent(event); + const data = Buffer.from(event.delta, 'base64'); + const audio = new AudioFrame( + new Int16Array(data.buffer), + api_proto.SAMPLE_RATE, + api_proto.NUM_CHANNELS, + data.length / 2, + ); + content.audio.push(audio); + + 
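// mirror the decoded frame into the streaming queue so the playout task can consume it; content.audio above keeps the accumulated copy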
content.audioStream.put(audio);
+  }
+
+  private handleResponseAudioDone(event: api_proto.ResponseAudioDoneEvent): void {
+    const content = this.getContent(event);
+    content.audioStream.put(null);
+  }
+
+  private handleResponseFunctionCallArgumentsDelta(
+    // eslint-disable-next-line @typescript-eslint/no-unused-vars
+    event: api_proto.ResponseFunctionCallArgumentsDeltaEvent,
+  ): void {}
+
+  private handleResponseFunctionCallArgumentsDone(
+    // eslint-disable-next-line @typescript-eslint/no-unused-vars
+    event: api_proto.ResponseFunctionCallArgumentsDoneEvent,
+  ): void {}
+
+  // eslint-disable-next-line @typescript-eslint/no-unused-vars
+  private handleRateLimitsUpdated(event: api_proto.RateLimitsUpdatedEvent): void {}
+}
+
+// TODO function init
+// if (event.item.type === 'function_call') {
+//   const toolCall = event.item;
+//   this.options.functions[toolCall.name].execute(toolCall.arguments).then((content) => {
+//     this.thinking = false;
+//     this.sendClientCommand({
+//       type: 'conversation.item.create',
+//       item: {
+//         type: 'function_call_output',
+//         call_id: toolCall.call_id,
+//         output: content,
+//       },
+//     });
+//     this.sendClientCommand({
+//       type: 'response.create',
+//       response: {},
+//     });
+//   });
+// }
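A possible shape for the function-call TODO above, expressed against the new string-literal client events; `handleFunctionCall`, `session`, and `funcCtx` are illustrative stand-ins rather than anything introduced by this change:

```typescript
// Illustrative sketch only: wiring the function-call TODO against the new
// string-literal client events. `session` and `funcCtx` stand in for a
// RealtimeSession and its llm.FunctionContext; nothing here is part of this PR.
async function handleFunctionCall(
  session: { queueMsg(msg: Record<string, unknown>): void },
  funcCtx: Record<string, { execute: (args: unknown) => Promise<string> }>,
  call: { name: string; call_id: string; arguments: string },
): Promise<void> {
  // run the registered function with the model-supplied JSON arguments
  const output = await funcCtx[call.name].execute(JSON.parse(call.arguments));
  // feed the result back to the model as a function_call_output item...
  session.queueMsg({
    type: 'conversation.item.create',
    item: { type: 'function_call_output', call_id: call.call_id, output },
  });
  // ...then ask the model to continue generating with that result
  session.queueMsg({ type: 'response.create' });
}
```

Queueing `response.create` right after the `function_call_output` item is what lets the model resume the conversation with the tool result, mirroring the flow sketched in the commented-out block above.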