diff --git a/koala/grammar.ts b/koala/grammar.ts new file mode 100644 index 0000000..08637c3 --- /dev/null +++ b/koala/grammar.ts @@ -0,0 +1,100 @@ +import OpenAI from "openai"; +import { ChatCompletionCreateParamsNonStreaming } from "openai/resources"; +import { errorReport } from "./error-report"; +import { YesNo } from "./shared-types"; +import { z } from "zod"; +import { zodResponseFormat } from "openai/helpers/zod"; + +const apiKey = process.env.OPENAI_API_KEY; + +if (!apiKey) { + errorReport("Missing ENV Var: OPENAI_API_KEY"); +} + +const configuration = { apiKey }; + +export const openai = new OpenAI(configuration); + +export async function gptCall(opts: ChatCompletionCreateParamsNonStreaming) { + return await openai.chat.completions.create(opts); +} + +const zodYesOrNo = z.object({ + response: z.union([ + z.object({ + userWasCorrect: z.literal(true), + }), + z.object({ + userWasCorrect: z.literal(false), + correctedSentence: z.string(), + }), + ]), +}); + +export type Explanation = { response: YesNo; whyNot?: string }; + +type GrammarCorrrectionProps = { + /** The Korean phrase. */ + term: string; + /** An English translation */ + definition: string; + /** Language code like KO */ + langCode: string; + /** What the user said. */ + userInput: string; +}; + +const getLangcode = (lang: string) => { + const names: Record = { + EN: "English", + IT: "Italian", + FR: "French", + ES: "Spanish", + KO: "Korean", + }; + const key = lang.slice(0, 2).toUpperCase(); + return names[key] || lang; +}; + +export const grammarCorrection = async ( + props: GrammarCorrrectionProps, +): Promise => { + // Latest snapshot that supports Structured Outputs + // TODO: Get on mainline 4o when it supports Structured Outputs + const model = "gpt-4o-2024-08-06"; + const { userInput } = props; + const lang = getLangcode(props.langCode); + const prompt = [ + `You are a language assistant helping users improve their ${lang} sentences.`, + `The user wants to say: '${props.definition}' in ${lang}.`, + `They provided: '${userInput}'.`, + `Your task is to determine if the user's input is an acceptable way to express the intended meaning in ${lang}.`, + `If the response is acceptable by ${lang} native speakers, respond with:`, + `{ "response": { "userWasCorrect": true } }`, + `If it is not, respond with:`, + `{ "response": { "userWasCorrect": false, "correctedSentence": "corrected sentence here" } }`, + `Do not include any additional commentary or explanations.`, + `Ensure your response is in valid JSON format.`, + ].join("\n"); + + const resp = await openai.beta.chat.completions.parse({ + messages: [ + { + role: "user", + content: prompt, + }, + ], + model, + max_tokens: 125, + // top_p: 1, + // frequency_penalty: 0, + temperature: 0.1, + response_format: zodResponseFormat(zodYesOrNo, "correct_sentence"), + }); + const correct_sentence = resp.choices[0].message.parsed; + if (correct_sentence) { + if (!correct_sentence.response.userWasCorrect) { + return correct_sentence.response.correctedSentence; + } + } +}; diff --git a/koala/openai.ts b/koala/openai.ts index b6a587b..ff99b49 100644 --- a/koala/openai.ts +++ b/koala/openai.ts @@ -2,8 +2,6 @@ import OpenAI from "openai"; import { ChatCompletionCreateParamsNonStreaming } from "openai/resources"; import { errorReport } from "./error-report"; import { YesNo } from "./shared-types"; -import { z } from "zod"; -import { zodResponseFormat } from "openai/helpers/zod"; const apiKey = process.env.OPENAI_API_KEY; @@ -49,18 +47,6 @@ const SIMPLE_YES_OR_NO = { description: "Answer a yes or no question.", }; -const zodYesOrNo = z.object({ - response: z.union([ - z.object({ - userWasCorrect: z.literal(true), - }), - z.object({ - userWasCorrect: z.literal(false), - correctedSentence: z.string(), - }), - ]), -}); - export type Explanation = { response: YesNo; whyNot?: string }; export const testEquivalence = async ( @@ -97,49 +83,6 @@ export const testEquivalence = async ( return raw.response as YesNo; }; -type GrammarCorrrectionProps = { - term: string; - definition: string; - langCode: string; - userInput: string; -}; - -export const grammarCorrection = async ( - props: GrammarCorrrectionProps, -): Promise => { - // Latest snapshot that supports Structured Outputs - // TODO: Get on mainline 4o when it supports Structured Outputs - const model = "gpt-4o-2024-08-06"; - const { userInput } = props; - const prompt = [ - `I want to say '${props.definition}' in language: ${props.langCode}.`, - `Is '${userInput}' OK?`, - `Correct awkwardness or major grammatical issues, if any.`, - ].join("\n"); - - const resp = await openai.beta.chat.completions.parse({ - messages: [ - { - role: "user", - content: prompt, - }, - ], - model, - max_tokens: 150, - top_p: 1, - frequency_penalty: 0, - stop: ["\n"], - temperature: 0.2, - response_format: zodResponseFormat(zodYesOrNo, "correct_sentence"), - }); - const correct_sentence = resp.choices[0].message.parsed; - if (correct_sentence) { - if (!correct_sentence.response.userWasCorrect) { - return correct_sentence.response.correctedSentence; - } - } -}; - export const translateToEnglish = async (content: string, langCode: string) => { const prompt = `You will be provided with a foreign language sentence (lang code: ${langCode}), and your task is to translate it into English.`; const hm = await gptCall({ diff --git a/koala/play-audio.tsx b/koala/play-audio.tsx index 7d9e9fb..b1f1513 100644 --- a/koala/play-audio.tsx +++ b/koala/play-audio.tsx @@ -1,59 +1,28 @@ let currentlyPlaying = false; -const playAudioBuffer = ( - buffer: AudioBuffer, - context: AudioContext, -): Promise => { - return new Promise((resolve) => { - const source = context.createBufferSource(); - source.buffer = buffer; - source.connect(context.destination); - source.onended = () => { - currentlyPlaying = false; - resolve(); - }; - source.start(0); - }); -}; - -export const playAudio = async (urlOrDataURI: string): Promise => { - if (!urlOrDataURI) { - return; - } - - if (currentlyPlaying) { - return; - } +export const playAudio = (urlOrDataURI: string) => { + return new Promise((resolve, reject) => { + if (!urlOrDataURI) { + return; + } - currentlyPlaying = true; - const audioContext = new AudioContext(); + if (currentlyPlaying) { + return; + } - try { - let arrayBuffer: ArrayBuffer; + currentlyPlaying = true; - if (urlOrDataURI.startsWith("data:")) { - // Handle Base64 Data URI - const base64Data = urlOrDataURI.split(",")[1]; - const binaryString = atob(base64Data); - const len = binaryString.length; - const bytes = new Uint8Array(len); - for (let i = 0; i < len; i++) { - bytes[i] = binaryString.charCodeAt(i); - } - arrayBuffer = bytes.buffer; - } else { - // Handle external URL - const response = await fetch(urlOrDataURI); - if (!response.ok) { - throw new Error("Network response was not ok"); - } - arrayBuffer = await response.arrayBuffer(); - } + const ok = () => { + currentlyPlaying = false; + resolve(""); + }; - const audioBuffer = await audioContext.decodeAudioData(arrayBuffer); - await playAudioBuffer(audioBuffer, audioContext); - } catch (e) { - currentlyPlaying = false; - throw e; - } + const audio = new Audio(urlOrDataURI); + audio.onended = ok; + audio.onerror = ok; + audio.play().catch((e) => { + reject(e); + console.error("Audio playback failed:", e); + }); + }); }; diff --git a/koala/quiz-evaluators/speaking.ts b/koala/quiz-evaluators/speaking.ts index 4fe786d..8937cdc 100644 --- a/koala/quiz-evaluators/speaking.ts +++ b/koala/quiz-evaluators/speaking.ts @@ -1,6 +1,5 @@ import { Explanation, - grammarCorrection, testEquivalence, translateToEnglish, } from "@/koala/openai"; @@ -8,6 +7,7 @@ import { QuizEvaluator, QuizEvaluatorOutput } from "./types"; import { strip } from "./evaluator-utils"; import { captureTrainingData } from "./capture-training-data"; import { prismaClient } from "../prisma-client"; +import { grammarCorrection } from "../grammar"; const doGrade = async ( userInput: string, @@ -63,7 +63,7 @@ function gradeWithGrammarCorrection(i: X, what: 1 | 2): QuizEvaluatorOutput { } else { return { result: "fail", - userMessage: `Correct, but say "${i.correction}" instead of "${i.userInput}" (${what}).`, + userMessage: `Say "${i.correction}" instead of "${i.userInput}" (${what}).`, }; } } diff --git a/koala/routes/get-mirroring-cards.ts b/koala/routes/get-mirroring-cards.ts new file mode 100644 index 0000000..1893522 --- /dev/null +++ b/koala/routes/get-mirroring-cards.ts @@ -0,0 +1,46 @@ +import { z } from "zod"; +import { prismaClient } from "../prisma-client"; +import { procedure } from "../trpc-procedure"; +import { generateLessonAudio } from "../speech"; +import { map, shuffle } from "radash"; + +export const getMirrorCards = procedure + .input(z.object({})) + .output( + z.array( + z.object({ + id: z.number(), + term: z.string(), + definition: z.string(), + audioUrl: z.string(), + translationAudioUrl: z.string(), + langCode: z.string(), + }), + ), + ) + .mutation(async ({ ctx }) => { + const cards = await prismaClient.card.findMany({ + where: { userId: ctx.user?.id || "000", flagged: false }, + orderBy: [{ mirrorRepetitionCount: "asc" }], + take: 200, + }); + // Order by length of 'term' field: + cards.sort((a, b) => b.term.length - a.term.length); + const shortList = shuffle(cards.slice(0, 100)).slice(0, 5); + return await map(shortList, async (card) => { + return { + id: card.id, + term: card.term, + definition: card.definition, + langCode: card.langCode, + translationAudioUrl: await generateLessonAudio({ + card, + lessonType: "speaking", + }), + audioUrl: await generateLessonAudio({ + card, + lessonType: "listening", + }), + }; + }); + }); diff --git a/koala/routes/main.ts b/koala/routes/main.ts index 34ddbf6..abf5cee 100644 --- a/koala/routes/main.ts +++ b/koala/routes/main.ts @@ -8,6 +8,7 @@ import { exportCards } from "./export-cards"; import { faucet } from "./faucet"; import { flagCard } from "./flag-card"; import { getAllCards } from "./get-all-cards"; +import { getMirrorCards } from "./get-mirroring-cards"; import { getNextQuizzes } from "./get-next-quizzes"; import { getOneCard } from "./get-one-card"; import { getPlaybackAudio } from "./get-playback-audio"; @@ -47,6 +48,7 @@ export const appRouter = router({ transcribeAudio, translateText, viewTrainingData, + getMirrorCards, }); export type AppRouter = typeof appRouter; diff --git a/pages/mirror.tsx b/pages/mirror.tsx index 042e869..2aa813f 100644 --- a/pages/mirror.tsx +++ b/pages/mirror.tsx @@ -1,97 +1,204 @@ -import { useState } from "react"; +import React, { useState, useEffect } from "react"; +import { useVoiceRecorder } from "@/koala/use-recorder"; import { playAudio } from "@/koala/play-audio"; import { blobToBase64, convertBlobToWav } from "@/koala/record-button"; -import { useVoiceRecorder } from "@/koala/use-recorder"; import { trpc } from "@/koala/trpc-config"; +import { useHotkeys } from "@mantine/hooks"; -export default function Mirror() { - const [transcription, setTranscription] = useState(null); - const [translation, setTranslation] = useState(null); - const [isProcessing, setIsProcessing] = useState(false); +type Card = { + term: string; + definition: string; + translationAudioUrl: string; + audioUrl: string; +}; + +const RecordingControls = ({ + isRecording, + successfulAttempts, + failedAttempts, + isProcessingRecording, + handleClick, +}: { + isRecording: boolean; + successfulAttempts: number; + failedAttempts: number; + isProcessingRecording: boolean; + handleClick: () => void; +}) => { + let message: string; + if (isRecording) { + message = "Recording..."; + } else { + message = "Start recording"; + } + + return ( +
+ +

{successfulAttempts} repetitions correct.

+

{isProcessingRecording ? 1 : 0} repetitions awaiting grade.

+

{failedAttempts} repetitions failed.

+

{3 - successfulAttempts} repetitions left.

+
+ ); +}; + +// Component to handle quizzing for a single sentence +export const SentenceQuiz = ({ + card, + setErrorMessage, + onCardCompleted, +}: { + card: Card; + setErrorMessage: (error: string) => void; + onCardCompleted: () => void; +}) => { + // State variables + const [successfulAttempts, setSuccessfulAttempts] = useState(0); + const [failedAttempts, setFailedAttempts] = useState(0); + const [isRecording, setIsRecording] = useState(false); + const [isProcessingRecording, setIsProcessingRecording] = useState(false); + + // TRPC mutations const transcribeAudio = trpc.transcribeAudio.useMutation(); - const translateText = trpc.translateText.useMutation(); - const speakText = trpc.speakText.useMutation(); - const vr = useVoiceRecorder(async (result: Blob) => { - setIsProcessing(true); + // Voice recorder hook + const voiceRecorder = useVoiceRecorder(handleRecordingResult); + // Handle button click + const handleClick = async () => { + if (successfulAttempts >= 3) { + // Do nothing if already completed + return; + } + + if (isRecording) { + // Stop recording + voiceRecorder.stop(); + setIsRecording(false); + } else { + playAudio(card.audioUrl); + // Start recording + setIsRecording(true); + voiceRecorder.start(); + } + }; + + // Use hotkeys to trigger handleClick on space bar press + useHotkeys([["space", handleClick]]); + + // Reset state variables when the term changes + useEffect(() => { + setSuccessfulAttempts(0); + setFailedAttempts(0); + setIsRecording(false); + setIsProcessingRecording(false); + }, [card.term]); + + // Handle the result after recording is finished + async function handleRecordingResult(audioBlob: Blob) { + setIsProcessingRecording(true); try { - // Convert blob to base64 WAV format - const wavBlob = await convertBlobToWav(result); + // Convert the recorded audio blob to WAV and then to base64 + const wavBlob = await convertBlobToWav(audioBlob); const base64Audio = await blobToBase64(wavBlob); // Transcribe the audio - const { result: korean } = await transcribeAudio.mutateAsync({ + const { result: transcription } = await transcribeAudio.mutateAsync({ audio: base64Audio, lang: "ko", - targetText: "사람들이 한국말로 말해요.", - }); - setTranscription(korean); - - // Translate the transcription to Korean - const { result: translatedText } = await translateText.mutateAsync({ - text: korean, - lang: "ko", + targetText: card.term, }); - setTranslation(translatedText); + // Compare the transcription with the target sentence + if (transcription.trim() === card.term.trim()) { + setSuccessfulAttempts((prev) => prev + 1); + } else { + setFailedAttempts((prev) => prev + 1); + } + } catch (err) { + setErrorMessage("Error processing the recording."); + } finally { + setIsProcessingRecording(false); + } + } - // Speak the Korean translation out loud - const { url: audioKO } = await speakText.mutateAsync({ - lang: "ko", - text: korean, + // Effect to handle successful completion + useEffect(() => { + if (successfulAttempts >= 3) { + // Play the translation audio + playAudio(card.translationAudioUrl).then(() => { + // After audio finishes, proceed to next sentence + onCardCompleted(); }); + } + }, [successfulAttempts]); - // Speak the original English transcription via TTS - const { url: audioEN } = await speakText.mutateAsync({ - lang: "en", - text: translatedText, - }); + return ( +
+ + {successfulAttempts === 0 &&

{card.term}

} +
+ ); +}; - // Play back the original voice recording - await playAudio(base64Audio); - await playAudio(audioKO); - await playAudio(audioEN); - } catch (error) { - console.error("Error processing voice recording:", error); - } finally { - setIsProcessing(false); - } - }); +export default function Mirror() { + // State variables + const [terms, setTerms] = useState([]); + const [currentIndex, setCurrentIndex] = useState(0); + const [errorMessage, setErrorMessage] = useState(null); - const handleClick = () => { - if (vr.isRecording) { - vr.stop(); - } else { - vr.start(); + const getMirrorCards = trpc.getMirrorCards.useMutation({}); + // Fetch translations for all sentences + const fetchCards = async () => { + try { + const translatedTerms = await getMirrorCards.mutateAsync({}); + setTerms(translatedTerms); + setCurrentIndex(0); + } catch (err) { + setErrorMessage("Failed to fetch sentences."); } }; - if (vr.error) { - return
Recording error: {JSON.stringify(vr.error)}
; + useEffect(() => { + fetchCards(); + }, []); + + const handleCardCompleted = () => { + setCurrentIndex((prevIndex) => prevIndex + 1); + }; + + // When the list is emptied, re-fetch more cards from the server + useEffect(() => { + if (currentIndex >= terms.length) { + // Re-fetch more cards from the server + fetchCards(); + } + }, [currentIndex, terms.length]); + + if (errorMessage) { + return
Error: {errorMessage}
; } - return ( -
- + if (terms.length === 0) { + return
Loading sentences...
; + } - {isProcessing &&
Processing your voice...
} - - {transcription && ( -
-

Transcription:

-

{transcription}

-
- )} - - {translation && ( -
-

Translation (Korean):

-

{translation}

-
- )} -
+ const currentTerm = terms[currentIndex]; + + return ( + ); } diff --git a/prisma/migrations/20241001010920_add_mirror_repetition_count/migration.sql b/prisma/migrations/20241001010920_add_mirror_repetition_count/migration.sql new file mode 100644 index 0000000..0b4fa90 --- /dev/null +++ b/prisma/migrations/20241001010920_add_mirror_repetition_count/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "Card" ADD COLUMN "mirrorRepetitionCount" INTEGER DEFAULT 0; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 7cdf453..05a63b7 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -75,19 +75,20 @@ model VerificationToken { } model Card { - id Int @id @default(autoincrement()) - user User @relation(fields: [userId], references: [id], onDelete: Cascade) - userId String - flagged Boolean @default(false) - term String - definition String - langCode String + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String + flagged Boolean @default(false) + term String + definition String + langCode String // Gender can be "M"ale, "F"emale, or "N"eutral - gender String @default("N") - createdAt DateTime @default(now()) - Quiz Quiz[] - imageBlobId String? - SpeakingCorrection SpeakingCorrection[] + gender String @default("N") + createdAt DateTime @default(now()) + Quiz Quiz[] + imageBlobId String? + SpeakingCorrection SpeakingCorrection[] + mirrorRepetitionCount Int? @default(0) @@unique([userId, term]) } @@ -137,13 +138,13 @@ model TrainingData { } model SpeakingCorrection { - id Int @id @default(autoincrement()) - cardId Int - Card Card @relation(fields: [cardId], references: [id], onDelete: Cascade) - createdAt DateTime @default(now()) - isCorrect Boolean - term String - definition String - userInput String - correction String @default("") + id Int @id @default(autoincrement()) + cardId Int + Card Card @relation(fields: [cardId], references: [id], onDelete: Cascade) + createdAt DateTime @default(now()) + isCorrect Boolean + term String + definition String + userInput String + correction String @default("") }