diff --git a/src/commands/chat.ts b/src/commands/chat.ts index ff234d8..898010f 100644 --- a/src/commands/chat.ts +++ b/src/commands/chat.ts @@ -94,5 +94,5 @@ async function getResponse({ const iterableStream = streamToIterable(stream); - return { readResponse: readData(iterableStream, () => true) }; + return { readResponse: readData(iterableStream) }; } diff --git a/src/helpers/completion.ts b/src/helpers/completion.ts index e726fee..fe98323 100644 --- a/src/helpers/completion.ts +++ b/src/helpers/completion.ts @@ -8,6 +8,7 @@ import type { AxiosError } from 'axios'; import { streamToString } from './stream-to-string'; import './replace-all-polyfill'; import i18n from './i18n'; +import { stripRegexPatterns } from './strip-regex-patterns'; const explainInSecondRequest = true; @@ -20,7 +21,7 @@ function getOpenAi(key: string, apiEndpoint: string) { // Openai outputs markdown format for code blocks. It oftne uses // a github style like: "```bash" -const shellCodeStartRegex = /```[^\n]*/gi; +const shellCodeExclusions = [/```[a-zA-Z]*\n/gi, /```[a-zA-Z]*/gi, '\n']; export async function getScriptAndInfo({ prompt, @@ -42,14 +43,9 @@ export async function getScriptAndInfo({ apiEndpoint, }); const iterableStream = streamToIterable(stream); - const codeBlock = '```'; return { - readScript: readData(iterableStream, () => true, shellCodeStartRegex), - readInfo: readData( - iterableStream, - (content) => content.endsWith(codeBlock), - shellCodeStartRegex - ), + readScript: readData(iterableStream, ...shellCodeExclusions), + readInfo: readData(iterableStream, ...shellCodeExclusions), }; } @@ -154,7 +150,7 @@ export async function getExplanation({ apiEndpoint, }); const iterableStream = streamToIterable(stream); - return { readExplanation: readData(iterableStream, () => true) }; + return { readExplanation: readData(iterableStream) }; } export async function getRevision({ @@ -180,22 +176,24 @@ export async function getRevision({ }); const iterableStream = streamToIterable(stream); return { - readScript: readData(iterableStream, () => true), + readScript: readData(iterableStream, ...shellCodeExclusions), }; } export const readData = ( iterableStream: AsyncGenerator, - startSignal: (content: string) => boolean, - excluded?: RegExp + ...excluded: (RegExp | string | undefined)[] ) => (writer: (data: string) => void): Promise => new Promise(async (resolve) => { let data = ''; let content = ''; let dataStart = false; - let waitUntilNewline = false; + // This buffer will temporarily hold incoming data only for detecting the start + let buffer = ''; + + const [excludedPrefix] = excluded; for await (const chunk of iterableStream) { const payloads = chunk.toString().split('\n\n'); @@ -209,25 +207,24 @@ export const readData = if (payload.startsWith('data:')) { content = parseContent(payload); - if (!dataStart && content.match(excluded ?? '')) { - dataStart = startSignal(content); - if (!content.includes('\n')) { - waitUntilNewline = true; + // Use buffer only for start detection + if (!dataStart) { + // Append content to the buffer + buffer += content; + if (buffer.match(excludedPrefix ?? '')) { + dataStart = true; + // Clear the buffer once it has served its purpose + buffer = ''; + if (excludedPrefix) break; } - if (excluded) break; - } - - if (content && waitUntilNewline) { - if (!content.includes('\n')) { - continue; - } - waitUntilNewline = false; } if (dataStart && content) { - const contentWithoutExcluded = excluded - ? content.replaceAll(excluded, '') - : content; + const contentWithoutExcluded = stripRegexPatterns( + content, + excluded + ); + data += contentWithoutExcluded; writer(contentWithoutExcluded); } diff --git a/src/helpers/strip-regex-patterns.ts b/src/helpers/strip-regex-patterns.ts new file mode 100644 index 0000000..3654186 --- /dev/null +++ b/src/helpers/strip-regex-patterns.ts @@ -0,0 +1,9 @@ +export const stripRegexPatterns = ( + inputString: string, + patternList: (RegExp | string | undefined)[] +) => + patternList.reduce( + (currentString: string, pattern) => + pattern ? currentString.replaceAll(pattern, '') : currentString, + inputString + );