From c7752660010fae6e0c3bb60f4e81f0cf973105c8 Mon Sep 17 00:00:00 2001 From: Darko Date: Wed, 6 Nov 2024 10:32:27 -0400 Subject: [PATCH] I18N-1336 - Improve AI checker perf Strings are checked concurrently --- .../service/ai/openai/OpenAILLMService.java | 39 ++++++++++- .../ai/openai/OpenAILLMServiceTest.java | 64 +++++++++++++++++++ 2 files changed, 100 insertions(+), 3 deletions(-) diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMService.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMService.java index db646dccb..1d1adb00e 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMService.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMService.java @@ -32,6 +32,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -101,15 +103,46 @@ public AICheckResponse executeAIChecks(AICheckRequest aiCheckRequest) { textUnit -> textUnit, (existing, replacement) -> existing, HashMap::new)); - Map> results = new HashMap<>(); + + Map>> checkResultsBySource = new HashMap<>(); textUnitsUniqueSource .values() .forEach( textUnit -> { - List aiCheckResults = checkString(textUnit, prompts, repository); - results.put(textUnit.getSource(), aiCheckResults); + CompletableFuture> checkResultFuture = + CompletableFuture.supplyAsync(() -> checkString(textUnit, prompts, repository)); + checkResultsBySource.put(textUnit.getSource(), checkResultFuture); + }); + + CompletableFuture combinedCheckResults = + CompletableFuture.allOf(checkResultsBySource.values().toArray(new CompletableFuture[0])); + + CompletableFuture>> checkResults = + combinedCheckResults.thenApply( + v -> { + Map> allResultsList = new HashMap<>(); + for (Map.Entry>> + checkResultFutureBySource : checkResultsBySource.entrySet()) { + try { + allResultsList.put( + checkResultFutureBySource.getKey(), + checkResultFutureBySource.getValue().get()); + } catch (InterruptedException | ExecutionException e) { + logger.error("Error while running a completable future", e); + } + } + return allResultsList; }); + Map> results = new HashMap<>(); + checkResults.thenAccept(results::putAll); + + try { + checkResults.get(); + } catch (InterruptedException | ExecutionException e) { + logger.error("Error while running completable futures", e); + } + AICheckResponse aiCheckResponse = new AICheckResponse(); aiCheckResponse.setResults(results); diff --git a/webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMServiceTest.java b/webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMServiceTest.java index 529e9a99d..65a3c54c8 100644 --- a/webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMServiceTest.java +++ b/webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMServiceTest.java @@ -830,4 +830,68 @@ void testPromptTemplatingInlineSentence() { The comment is: A friendly greeting. The plural form is: one.""", prompt); } + + @Test + void testExecuteAIChecksWithSleepTime() { + AIPrompt prompt = new AIPrompt(); + prompt.setId(1L); + prompt.setUserPrompt("Check strings for spelling"); + prompt.setModelName("gtp-3.5-turbo"); + prompt.setPromptTemperature(0.0F); + prompt.setContextMessages(new ArrayList<>()); + AssetExtractorTextUnit assetExtractorTextUnit = new AssetExtractorTextUnit(); + assetExtractorTextUnit.setSource("A test string"); + assetExtractorTextUnit.setName("A test string --- A test context"); + assetExtractorTextUnit.setComments("A test comment"); + AssetExtractorTextUnit assetExtractorTextUnit2 = new AssetExtractorTextUnit(); + assetExtractorTextUnit2.setSource("A test string 2"); + assetExtractorTextUnit2.setName("A test string --- A test context 2"); + assetExtractorTextUnit2.setComments("A test comment 2"); + List textUnits = + List.of(assetExtractorTextUnit, assetExtractorTextUnit2); + AICheckRequest aiCheckRequest = new AICheckRequest(); + aiCheckRequest.setRepositoryName("testRepo"); + aiCheckRequest.setTextUnits(textUnits); + List choices = + List.of( + new OpenAIClient.ChatCompletionsResponse.Choice( + 0, + new OpenAIClient.ChatCompletionsResponse.Choice.Message( + "test", "{\"success\": true, \"suggestedFix\": \"\"}"), + null)); + Repository repository = new Repository(); + repository.setName("testRepo"); + repository.setId(1L); + OpenAIClient.ChatCompletionsResponse chatCompletionsResponse = + new OpenAIClient.ChatCompletionsResponse(null, null, null, null, choices, null, null); + CompletableFuture futureResponse = + CompletableFuture.completedFuture(chatCompletionsResponse); + List prompts = List.of(prompt); + + when(repositoryRepository.findByName("testRepo")).thenReturn(repository); + when(promptService.getPromptsByRepositoryAndPromptType( + repository, PromptType.SOURCE_STRING_CHECKER)) + .thenReturn(prompts); + doAnswer( + invocation -> { + Thread.sleep(100); + return futureResponse; + }) + .when(openAIClient) + .getChatCompletions(any(OpenAIClient.ChatCompletionsRequest.class)); + + AICheckResponse response = this.openAILLMService.executeAIChecks(aiCheckRequest); + + assertNotNull(response); + assertEquals(2, response.getResults().size()); + assertTrue(response.getResults().containsKey("A test string")); + assertTrue(response.getResults().get("A test string").getFirst().isSuccess()); + + assertTrue(response.getResults().containsKey("A test string 2")); + assertTrue(response.getResults().get("A test string 2").getFirst().isSuccess()); + + verify(aiStringCheckRepository, times(2)).save(any()); + verify(meterRegistry, times(2)) + .counter("OpenAILLMService.checks.result", "success", "true", "repository", "testRepo"); + } }