diff --git a/cli/src/main/java/com/box/l10n/mojito/cli/command/CreateAIPromptCommand.java b/cli/src/main/java/com/box/l10n/mojito/cli/command/CreateAIPromptCommand.java index cabc1733c5..f2872dc0c5 100644 --- a/cli/src/main/java/com/box/l10n/mojito/cli/command/CreateAIPromptCommand.java +++ b/cli/src/main/java/com/box/l10n/mojito/cli/command/CreateAIPromptCommand.java @@ -20,7 +20,7 @@ public class CreateAIPromptCommand extends Command { static Logger logger = LoggerFactory.getLogger(CreateAIPromptCommand.class); - @Autowired AIServiceClient AIServiceClient; + @Autowired AIServiceClient aiServiceClient; @Parameter( names = {"--repository-name", "-r"}, @@ -58,6 +58,18 @@ public class CreateAIPromptCommand extends Command { description = "The temperature to use for the prompt") float promptTemperature = 0.0F; + @Parameter( + names = {"--is-json-response", "-ijr"}, + required = false, + description = "The prompt response is expected to be in JSON format from the LLM") + boolean isJsonResponse = false; + + @Parameter( + names = {"--json-response-key", "-jrk"}, + required = false, + description = "The key to use to extract the translation from the JSON response") + String jsonResponseKey; + @Autowired private ConsoleWriter consoleWriter; @Override @@ -67,14 +79,19 @@ protected void execute() throws CommandException { private void createPrompt() { logger.debug("Received request to create prompt"); - AIPromptCreateRequest AIPromptCreateRequest = new AIPromptCreateRequest(); - AIPromptCreateRequest.setRepositoryName(repository); - AIPromptCreateRequest.setSystemPrompt(systemPromptText); - AIPromptCreateRequest.setUserPrompt(userPromptText); - AIPromptCreateRequest.setModelName(modelName); - AIPromptCreateRequest.setPromptType(promptType); - AIPromptCreateRequest.setPromptTemperature(promptTemperature); - long promptId = AIServiceClient.createPrompt(AIPromptCreateRequest); + AIPromptCreateRequest aiPromptCreateRequest = new AIPromptCreateRequest(); + 
aiPromptCreateRequest.setRepositoryName(repository); + aiPromptCreateRequest.setSystemPrompt(systemPromptText); + aiPromptCreateRequest.setUserPrompt(userPromptText); + aiPromptCreateRequest.setModelName(modelName); + aiPromptCreateRequest.setPromptType(promptType); + aiPromptCreateRequest.setPromptTemperature(promptTemperature); + aiPromptCreateRequest.setJsonResponse(isJsonResponse); + if (isJsonResponse && jsonResponseKey == null) { + throw new CommandException("jsonResponseKey is required when isJsonResponse is true"); + } + aiPromptCreateRequest.setJsonResponseKey(jsonResponseKey); + long promptId = aiServiceClient.createPrompt(aiPromptCreateRequest); consoleWriter.newLine().a("Prompt created with id: " + promptId).println(); } } diff --git a/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java b/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java index dce7dd05a7..0cdaacc248 100644 --- a/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java +++ b/common/src/main/java/com/box/l10n/mojito/openai/OpenAIClient.java @@ -428,7 +428,11 @@ public record Choice( int index, Delta delta, @JsonProperty("finish_reason") String finishReason) { public enum FinishReasons { - STOP("stop"); + STOP("stop"), + LENGTH("length"), + FUNCTION_CALL("function_call"), + CONTENT_FILTER("content_filter"), + NULL("null"); String value; diff --git a/restclient/src/main/java/com/box/l10n/mojito/rest/entity/AIPromptCreateRequest.java b/restclient/src/main/java/com/box/l10n/mojito/rest/entity/AIPromptCreateRequest.java index 465ecb4559..95d8512ea5 100644 --- a/restclient/src/main/java/com/box/l10n/mojito/rest/entity/AIPromptCreateRequest.java +++ b/restclient/src/main/java/com/box/l10n/mojito/rest/entity/AIPromptCreateRequest.java @@ -12,6 +12,8 @@ public class AIPromptCreateRequest { private boolean deleted; private String repositoryName; private String promptType; + private boolean isJsonResponse; + private String jsonResponseKey; public boolean 
isDeleted() { return deleted; @@ -68,4 +70,20 @@ public String getRepositoryName() { public void setRepositoryName(String repositoryName) { this.repositoryName = repositoryName; } + + public boolean isJsonResponse() { + return isJsonResponse; + } + + public void setJsonResponse(boolean jsonResponse) { + isJsonResponse = jsonResponse; + } + + public String getJsonResponseKey() { + return jsonResponseKey; + } + + public void setJsonResponseKey(String jsonResponseKey) { + this.jsonResponseKey = jsonResponseKey; + } } diff --git a/webapp/src/main/java/com/box/l10n/mojito/entity/AIPrompt.java b/webapp/src/main/java/com/box/l10n/mojito/entity/AIPrompt.java index 42c4e19149..f1ec51f374 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/entity/AIPrompt.java +++ b/webapp/src/main/java/com/box/l10n/mojito/entity/AIPrompt.java @@ -3,6 +3,8 @@ import jakarta.persistence.Column; import jakarta.persistence.Entity; import jakarta.persistence.FetchType; +import jakarta.persistence.JoinColumn; +import jakarta.persistence.ManyToOne; import jakarta.persistence.OneToMany; import jakarta.persistence.OrderBy; import jakarta.persistence.Table; @@ -31,6 +33,10 @@ public class AIPrompt extends BaseEntity { @Column(name = "deleted") private boolean deleted; + @ManyToOne + @JoinColumn(name = "prompt_type_id") + private AIPromptType promptType; + @CreatedDate @Column(name = "created_date") private ZonedDateTime createdDate; @@ -44,6 +50,12 @@ public class AIPrompt extends BaseEntity { @OrderBy("orderIndex ASC") List contextMessages; + @Column(name = "json_response") + private boolean jsonResponse; + + @Column(name = "json_response_key") + private String jsonResponseKey; + public String getModelName() { return modelName; } @@ -107,4 +119,28 @@ public ZonedDateTime getLastModifiedDate() { public void setLastModifiedDate(ZonedDateTime lastModifiedDate) { this.lastModifiedDate = lastModifiedDate; } + + public AIPromptType getPromptType() { + return promptType; + } + + public void 
/** Types of AI prompt. The constant name is also the type's string form. */
public enum PromptType {
  SOURCE_STRING_CHECKER,
  TRANSLATION;

  /**
   * Returns the declared constant name (same as {@link #name()}).
   *
   * @return the constant name, e.g. {@code "TRANSLATION"}
   */
  @Override
  public String toString() {
    return name();
  }
}
{ - this.aiPromptId = aiPromptId; - } - - public long getPromptTypeId() { - return promptTypeId; - } - - public void setPromptTypeId(Long promptTypeId) { - this.promptTypeId = promptTypeId; - } - - public long getRepositoryId() { - return repositoryId; - } - - public void setRepositoryId(long repositoryId) { - this.repositoryId = repositoryId; - } -} diff --git a/webapp/src/main/java/com/box/l10n/mojito/entity/RepositoryLocaleAIPrompt.java b/webapp/src/main/java/com/box/l10n/mojito/entity/RepositoryLocaleAIPrompt.java new file mode 100644 index 0000000000..0d78a52fa5 --- /dev/null +++ b/webapp/src/main/java/com/box/l10n/mojito/entity/RepositoryLocaleAIPrompt.java @@ -0,0 +1,66 @@ +package com.box.l10n.mojito.entity; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.JoinColumn; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.Table; +import jakarta.persistence.UniqueConstraint; + +@Entity +@Table( + name = "repository_locale_ai_prompt", + uniqueConstraints = { + @UniqueConstraint( + name = "UK__REPOSITORY_LOCALE_AI_PROMPT__REPO_ID__LOCALE_ID__AI_PROMPT", + columnNames = {"repository_id", "locale_id", "ai_prompt_id"}) + }) +public class RepositoryLocaleAIPrompt extends BaseEntity { + + @ManyToOne + @JoinColumn(name = "repository_id", nullable = false) + private Repository repository; + + @ManyToOne + @JoinColumn(name = "locale_id", nullable = true) + private Locale locale; + + @ManyToOne + @JoinColumn(name = "ai_prompt_id", nullable = false) + private AIPrompt aiPrompt; + + @Column(name = "disabled") + private boolean disabled; + + public AIPrompt getAiPrompt() { + return aiPrompt; + } + + public void setAiPrompt(AIPrompt aiPrompt) { + this.aiPrompt = aiPrompt; + } + + public Repository getRepository() { + return repository; + } + + public void setRepository(Repository repository) { + this.repository = repository; + } + + public Locale getLocale() { + return locale; + } + + public void 
setLocale(Locale locale) { + this.locale = locale; + } + + public boolean isDisabled() { + return disabled; + } + + public void setDisabled(boolean disabled) { + this.disabled = disabled; + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/entity/TMTextUnitVariant.java b/webapp/src/main/java/com/box/l10n/mojito/entity/TMTextUnitVariant.java index 0089633c76..20ebb94406 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/entity/TMTextUnitVariant.java +++ b/webapp/src/main/java/com/box/l10n/mojito/entity/TMTextUnitVariant.java @@ -68,6 +68,10 @@ public enum Status { * TRANSLATION_NEEDED status along with a comment. */ REVIEW_NEEDED, + + MT_TRANSLATED, + + MT_REVIEW, /** A string that doesn't need any work to be performed on it. */ APPROVED; }; diff --git a/webapp/src/main/java/com/box/l10n/mojito/entity/TmTextUnitPendingMT.java b/webapp/src/main/java/com/box/l10n/mojito/entity/TmTextUnitPendingMT.java new file mode 100644 index 0000000000..b622fa2d75 --- /dev/null +++ b/webapp/src/main/java/com/box/l10n/mojito/entity/TmTextUnitPendingMT.java @@ -0,0 +1,33 @@ +package com.box.l10n.mojito.entity; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Table; +import java.time.ZonedDateTime; + +@Entity +@Table(name = "tm_text_unit_pending_mt") +public class TmTextUnitPendingMT extends BaseEntity { + + @Column(name = "tm_text_unit_id") + private Long tmTextUnitId; + + @Column(name = "created_date") + private ZonedDateTime createdDate; + + public ZonedDateTime getCreatedDate() { + return createdDate; + } + + public void setCreatedDate(ZonedDateTime createdDate) { + this.createdDate = createdDate; + } + + public Long getTmTextUnitId() { + return tmTextUnitId; + } + + public void setTmTextUnitId(Long tmTextUnitId) { + this.tmTextUnitId = tmTextUnitId; + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIException.java b/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIException.java index 
a968f7efd5..7e5bbecaec 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIException.java +++ b/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIException.java @@ -9,4 +9,8 @@ public AIException(String message) { public AIException(String message, Exception e) { super(message, e); } + + public AIException(String message, Throwable t) { + super(message, t); + } } diff --git a/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIPromptCreateRequest.java b/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIPromptCreateRequest.java index db77ec6175..00f3ce6078 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIPromptCreateRequest.java +++ b/webapp/src/main/java/com/box/l10n/mojito/rest/ai/AIPromptCreateRequest.java @@ -12,6 +12,8 @@ public class AIPromptCreateRequest { private boolean deleted; private String repositoryName; private String promptType; + private boolean isJsonResponse; + private String jsonResponseKey; public boolean isDeleted() { return deleted; @@ -68,4 +70,20 @@ public String getRepositoryName() { public void setRepositoryName(String repositoryName) { this.repositoryName = repositoryName; } + + public boolean isJsonResponse() { + return isJsonResponse; + } + + public void setJsonResponse(boolean jsonResponse) { + isJsonResponse = jsonResponse; + } + + public String getJsonResponseKey() { + return jsonResponseKey; + } + + public void setJsonResponseKey(String jsonResponseKey) { + this.jsonResponseKey = jsonResponseKey; + } } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/AIPromptRepository.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/AIPromptRepository.java index a987013aaf..ec96be6c79 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/ai/AIPromptRepository.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/AIPromptRepository.java @@ -12,9 +12,9 @@ public interface AIPromptRepository extends JpaRepository { @Query( "SELECT ap FROM AIPrompt ap " - + "JOIN RepositoryAIPrompt 
rap ON ap.id = rap.aiPromptId " - + "JOIN AIPromptType apt ON rap.promptTypeId = apt.id " - + "WHERE rap.repositoryId = :repositoryId AND apt.name = :promptTypeName AND ap.deleted = false") + + "JOIN RepositoryLocaleAIPrompt rlap ON ap.id = rlap.aiPrompt.id " + + "JOIN AIPromptType apt ON ap.promptType.id = apt.id " + + "WHERE rlap.repository.id = :repositoryId AND apt.name = :promptTypeName AND ap.deleted = false") List findByRepositoryIdAndPromptTypeName( @Param("repositoryId") Long repositoryId, @Param("promptTypeName") String promptTypeName); @@ -22,7 +22,7 @@ List findByRepositoryIdAndPromptTypeName( @Query( "SELECT ap FROM AIPrompt ap " - + "JOIN RepositoryAIPrompt rap ON ap.id = rap.aiPromptId " - + "WHERE rap.repositoryId = :repositoryId AND ap.deleted = false") + + "JOIN RepositoryLocaleAIPrompt rlap ON ap.id = rlap.aiPrompt.id " + + "WHERE rlap.repository.id = :repositoryId AND ap.deleted = false") List findByRepositoryIdAndDeletedFalse(Long repositoryId); } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/openai/LLMPromptService.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMPromptService.java similarity index 87% rename from webapp/src/main/java/com/box/l10n/mojito/service/ai/openai/LLMPromptService.java rename to webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMPromptService.java index 022a570e22..8c072cf0b9 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/ai/openai/LLMPromptService.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMPromptService.java @@ -1,4 +1,4 @@ -package com.box.l10n.mojito.service.ai.openai; +package com.box.l10n.mojito.service.ai; import com.box.l10n.mojito.JSR310Migration; import com.box.l10n.mojito.entity.AIPrompt; @@ -6,15 +6,11 @@ import com.box.l10n.mojito.entity.AIPromptType; import com.box.l10n.mojito.entity.PromptType; import com.box.l10n.mojito.entity.Repository; -import com.box.l10n.mojito.entity.RepositoryAIPrompt; +import 
com.box.l10n.mojito.entity.RepositoryLocaleAIPrompt; import com.box.l10n.mojito.rest.ai.AIException; import com.box.l10n.mojito.rest.ai.AIPromptContextMessageCreateRequest; import com.box.l10n.mojito.rest.ai.AIPromptCreateRequest; -import com.box.l10n.mojito.service.ai.AIPromptContextMessageRepository; -import com.box.l10n.mojito.service.ai.AIPromptRepository; -import com.box.l10n.mojito.service.ai.AIPromptTypeRepository; -import com.box.l10n.mojito.service.ai.PromptService; -import com.box.l10n.mojito.service.ai.RepositoryAIPromptRepository; +import com.box.l10n.mojito.service.ai.openai.OpenAIPromptContextMessageType; import com.box.l10n.mojito.service.repository.RepositoryRepository; import io.micrometer.core.annotation.Timed; import jakarta.transaction.Transactional; @@ -38,7 +34,7 @@ public class LLMPromptService implements PromptService { @Autowired AIPromptTypeRepository aiPromptTypeRepository; - @Autowired RepositoryAIPromptRepository repositoryAIPromptRepository; + @Autowired RepositoryLocaleAIPromptRepository repositoryLocaleAIPromptRepository; @Autowired AIPromptContextMessageRepository aiPromptContextMessageRepository; @@ -66,18 +62,20 @@ public Long createPrompt(AIPromptCreateRequest AIPromptCreateRequest) { aiPrompt.setUserPrompt(AIPromptCreateRequest.getUserPrompt()); aiPrompt.setPromptTemperature(AIPromptCreateRequest.getPromptTemperature()); aiPrompt.setModelName(AIPromptCreateRequest.getModelName()); + aiPrompt.setPromptType(aiPromptType); ZonedDateTime now = JSR310Migration.dateTimeNow(); aiPrompt.setCreatedDate(now); aiPrompt.setLastModifiedDate(now); + aiPrompt.setJsonResponse(AIPromptCreateRequest.isJsonResponse()); + aiPrompt.setJsonResponseKey(AIPromptCreateRequest.getJsonResponseKey()); aiPromptRepository.save(aiPrompt); logger.debug("Created prompt with id: {}", aiPrompt.getId()); - RepositoryAIPrompt repositoryAIPrompt = new RepositoryAIPrompt(); - repositoryAIPrompt.setRepositoryId(repository.getId()); - 
repositoryAIPrompt.setAiPromptId(aiPrompt.getId()); - repositoryAIPrompt.setPromptTypeId(aiPromptType.getId()); - repositoryAIPromptRepository.save(repositoryAIPrompt); - logger.debug("Created repository prompt with id: {}", repositoryAIPrompt.getId()); + RepositoryLocaleAIPrompt repositoryLocaleAIPrompt = new RepositoryLocaleAIPrompt(); + repositoryLocaleAIPrompt.setRepository(repository); + repositoryLocaleAIPrompt.setAiPrompt(aiPrompt); + repositoryLocaleAIPromptRepository.save(repositoryLocaleAIPrompt); + logger.debug("Created repository prompt with id: {}", repositoryLocaleAIPrompt.getId()); return aiPrompt.getId(); } @@ -102,12 +100,11 @@ public void addPromptToRepository(Long promptId, String repositoryName, String p .findById(promptId) .orElseThrow(() -> new AIException("Prompt not found: " + promptId)); - RepositoryAIPrompt repositoryAIPrompt = new RepositoryAIPrompt(); - repositoryAIPrompt.setRepositoryId(repository.getId()); - repositoryAIPrompt.setAiPromptId(aiPrompt.getId()); - repositoryAIPrompt.setPromptTypeId(aiPromptType.getId()); - repositoryAIPromptRepository.save(repositoryAIPrompt); - logger.debug("Created repository prompt with id: {}", repositoryAIPrompt.getId()); + RepositoryLocaleAIPrompt repositoryLocaleAIPrompt = new RepositoryLocaleAIPrompt(); + repositoryLocaleAIPrompt.setRepository(repository); + repositoryLocaleAIPrompt.setAiPrompt(aiPrompt); + repositoryLocaleAIPromptRepository.save(repositoryLocaleAIPrompt); + logger.debug("Created repository prompt with id: {}", repositoryLocaleAIPrompt.getId()); } @Timed("LLMPromptService.getPromptsByRepositoryAndPromptType") diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMService.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMService.java index ba47823985..5e63d90c82 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMService.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/LLMService.java @@ -4,6 +4,7 @@ import 
com.box.l10n.mojito.entity.AIPrompt; import com.box.l10n.mojito.entity.AIStringCheck; import com.box.l10n.mojito.entity.Repository; +import com.box.l10n.mojito.entity.TMTextUnit; import com.box.l10n.mojito.okapi.extractor.AssetExtractorTextUnit; import com.box.l10n.mojito.rest.ai.AICheckRequest; import com.box.l10n.mojito.rest.ai.AICheckResponse; @@ -14,6 +15,9 @@ public interface LLMService { String SOURCE_STRING_PLACEHOLDER = "[mojito_source_string]"; String COMMENT_STRING_PLACEHOLDER = "[mojito_comment_string]"; String CONTEXT_STRING_PLACEHOLDER = "[mojito_context_string]"; + String SOURCE_LOCALE_PLACEHOLDER = "[mojito_source_locale]"; + String TARGET_LOCALE_PLACEHOLDER = "[mojito_target_locale]"; + String PLURAL_FORM_PLACEHOLDER = "[mojito_plural_form]"; /** * Executes AI checks on the provided text units. @@ -51,4 +55,7 @@ default void persistCheckResult( aiStringCheck.setStringName(textUnit.getName()); aiStringCheckRepository.save(aiStringCheck); } + + String translate( + TMTextUnit tmTextUnit, String sourceBcp47Tag, String targetBcp47Tag, AIPrompt prompt); } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/RepositoryAIPromptRepository.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/RepositoryAIPromptRepository.java deleted file mode 100644 index 0f08b6faf5..0000000000 --- a/webapp/src/main/java/com/box/l10n/mojito/service/ai/RepositoryAIPromptRepository.java +++ /dev/null @@ -1,15 +0,0 @@ -package com.box.l10n.mojito.service.ai; - -import com.box.l10n.mojito.entity.RepositoryAIPrompt; -import java.util.List; -import org.springframework.data.jpa.repository.JpaRepository; -import org.springframework.data.jpa.repository.Query; -import org.springframework.data.repository.query.Param; - -public interface RepositoryAIPromptRepository extends JpaRepository { - - List findByRepositoryIdAndPromptTypeId(Long repositoryId, Long promptTypeId); - - @Query("SELECT rap.repositoryId FROM RepositoryAIPrompt rap WHERE rap.aiPromptId = :aiPromptId") 
/**
 * Spring Data repository for {@link RepositoryLocaleAIPrompt} links.
 *
 * <p>Both queries select links by repository id and by the name of the linked prompt's type, and
 * both ignore prompts flagged as deleted.
 */
// NOTE(review): JpaRepository and the List return type below appear without type arguments here;
// presumably they are <RepositoryLocaleAIPrompt, Long> and List<RepositoryLocaleAIPrompt>. Confirm.
public interface RepositoryLocaleAIPromptRepository
    extends JpaRepository {

  /**
   * Counts the links of a repository whose prompt has the given type name, where the link is not
   * disabled and the prompt is not deleted.
   *
   * @param repositoryId id of the repository the links belong to
   * @param promptType name of the prompt type to match
   * @return number of matching links
   */
  @Query(
      "SELECT count(rlap.id) FROM RepositoryLocaleAIPrompt rlap "
          + "JOIN AIPrompt aip ON rlap.aiPrompt.id = aip.id "
          + "JOIN AIPromptType aipt ON aip.promptType.id = aipt.id "
          + "WHERE rlap.repository.id = :repositoryId AND rlap.disabled = false AND aip.deleted = false AND aipt.name = :promptType")
  Long findCountOfActiveRepositoryPromptsByType(
      @Param("repositoryId") Long repositoryId, @Param("promptType") String promptType);

  /**
   * Returns the links of a repository whose prompt has the given type name and is not deleted.
   *
   * <p>NOTE(review): unlike the count query above, this query does not filter on
   * {@code rlap.disabled}, so disabled links are returned too. Presumably callers inspect
   * {@code isDisabled()} themselves; confirm before relying on "active" in the method name.
   *
   * @param repositoryId id of the repository the links belong to
   * @param promptType name of the prompt type to match
   * @return the matching links
   */
  @Query(
      "SELECT rlap FROM RepositoryLocaleAIPrompt rlap "
          + "JOIN rlap.aiPrompt aip "
          + "JOIN aip.promptType aipt "
          + "WHERE rlap.repository.id = :repositoryId AND aip.deleted = false AND aipt.name = :promptType")
  List getActivePromptsByRepositoryAndPromptType(
      @Param("repositoryId") Long repositoryId, @Param("promptType") String promptType);
}
b/webapp/src/main/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMService.java @@ -9,6 +9,7 @@ import com.box.l10n.mojito.entity.AIPrompt; import com.box.l10n.mojito.entity.AIPromptContextMessage; import com.box.l10n.mojito.entity.Repository; +import com.box.l10n.mojito.entity.TMTextUnit; import com.box.l10n.mojito.json.ObjectMapper; import com.box.l10n.mojito.okapi.extractor.AssetExtractorTextUnit; import com.box.l10n.mojito.openai.OpenAIClient; @@ -17,16 +18,22 @@ import com.box.l10n.mojito.rest.ai.AICheckResult; import com.box.l10n.mojito.rest.ai.AIException; import com.box.l10n.mojito.service.ai.AIStringCheckRepository; +import com.box.l10n.mojito.service.ai.LLMPromptService; import com.box.l10n.mojito.service.ai.LLMService; import com.box.l10n.mojito.service.repository.RepositoryRepository; import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.base.Strings; import io.micrometer.core.annotation.Timed; import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tags; +import jakarta.annotation.PostConstruct; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,6 +41,9 @@ import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.stereotype.Service; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; +import reactor.util.retry.RetryBackoffSpec; @Service @ConditionalOnProperty(value = "l10n.ai.service.type", havingValue = "OpenAI") @@ -56,6 +66,19 @@ public class OpenAILLMService implements LLMService { @Value("${l10n.ai.checks.persistResults:true}") boolean persistResults; + @Value("${l10n.ai.translate.retry.maxAttempts:10}") + int 
retryMaxAttempts; + + @Value("${l10n.ai.translate.retry.minDurationSeconds:5}") + int retryMinDurationSeconds; + + @Value("${l10n.ai.translate.retry.maxBackoffDurationSeconds:60}") + int retryMaxBackoffDurationSeconds; + + RetryBackoffSpec llmTranslateRetryConfig; + + Map patternCache = new HashMap<>(); + @Timed("OpenAILLMService.executeAIChecks") public AICheckResponse executeAIChecks(AICheckRequest aiCheckRequest) { @@ -93,6 +116,105 @@ public AICheckResponse executeAIChecks(AICheckRequest aiCheckRequest) { return aiCheckResponse; } + @Override + @Timed("OpenAILLMService.translate") + public String translate( + TMTextUnit tmTextUnit, String sourceBcp47Tag, String targetBcp47Tag, AIPrompt prompt) { + logger.debug( + "Translating text unit {} from {} to {} using prompt {}", + tmTextUnit.getId(), + sourceBcp47Tag, + targetBcp47Tag, + prompt.getId()); + String systemPrompt = + getTranslationFormattedPrompt( + prompt.getSystemPrompt(), tmTextUnit, sourceBcp47Tag, targetBcp47Tag); + String userPrompt = + getTranslationFormattedPrompt( + prompt.getUserPrompt(), tmTextUnit, sourceBcp47Tag, targetBcp47Tag); + + OpenAIClient.ChatCompletionsRequest chatCompletionsRequest = + buildChatCompletionsRequest( + prompt, systemPrompt, userPrompt, prompt.getContextMessages(), prompt.isJsonResponse()); + + return Mono.fromCallable( + () -> { + OpenAIClient.ChatCompletionsResponse chatCompletionsResponse = + openAIClient.getChatCompletions(chatCompletionsRequest).join(); + if (chatCompletionsResponse.choices().size() > 1) { + logger.error( + "Multiple choices returned for text unit {}, expected only one", + tmTextUnit.getId()); + meterRegistry + .counter("OpenAILLMService.translate.error.multiChoiceResponse") + .increment(); + throw new AIException( + "Multiple response choices returned for text unit " + + tmTextUnit.getId() + + ", expected only one"); + } + if (chatCompletionsResponse + .choices() + .getFirst() + .finishReason() + .equals( + 
OpenAIClient.ChatCompletionsStreamResponse.Choice.FinishReasons.STOP + .getValue())) { + String response = chatCompletionsResponse.choices().getFirst().message().content(); + logger.debug( + "TmTextUnit id: {}, {} translation response: {}", + tmTextUnit.getId(), + targetBcp47Tag, + response); + if (prompt.isJsonResponse()) { + try { + logger.debug("Parsing JSON response for key: {}", prompt.getJsonResponseKey()); + response = + objectMapper.readTree(response).get(prompt.getJsonResponseKey()).asText(); + logger.debug("Parsed translation: {}", response); + } catch (JsonProcessingException e) { + logger.error("Error parsing JSON response: {}", response, e); + throw new AIException("Error parsing JSON response: " + response); + } + } + meterRegistry + .counter("OpenAILLMService.translate.result", "success", "true") + .increment(); + return response; + } + String message = + String.format( + "Error translating text unit %d from %s to %s, response finish_reason: %s", + tmTextUnit.getId(), + sourceBcp47Tag, + targetBcp47Tag, + chatCompletionsResponse.choices().getFirst().finishReason()); + logger.error(message); + throw new AIException(message); + }) + .doOnError( + e -> { + logger.error("Error translating text unit {}", tmTextUnit.getId(), e); + meterRegistry + .counter( + "OpenAILLMService.translate.result", + Tags.of("success", "false", "retryable", "true")) + .increment(); + }) + .retryWhen(llmTranslateRetryConfig) + .doOnError( + e -> { + logger.error("Error translating text unit {}", tmTextUnit.getId(), e); + meterRegistry + .counter( + "OpenAILLMService.translate.result", + Tags.of("success", "false", "retryable", "false")) + .increment(); + }) + .blockOptional() + .orElseThrow(() -> new AIException("Error translating text unit " + tmTextUnit.getId())); + } + @Timed("OpenAILLMService.checkString") private List checkString( AssetExtractorTextUnit textUnit, List prompts, Repository repository) { @@ -134,9 +256,11 @@ private void executePromptChecks( prompt.getId()); 
continue; } - String systemPrompt = getFormattedPrompt(prompt.getSystemPrompt(), sourceString, comment); + String systemPrompt = + getStringChecksFormattedPrompt(prompt.getSystemPrompt(), sourceString, comment); - String userPrompt = getFormattedPrompt(prompt.getUserPrompt(), sourceString, comment); + String userPrompt = + getStringChecksFormattedPrompt(prompt.getUserPrompt(), sourceString, comment); if (nameSplit.length > 1 && (systemPrompt.contains(CONTEXT_STRING_PLACEHOLDER) @@ -159,11 +283,11 @@ private void executePromptChecks( chatCompletionsResponse.choices().getFirst().message().content(), aiStringCheckRepository); } - results.add(parseResponse(chatCompletionsResponse, repository)); + results.add(parseAICheckPromptResponse(chatCompletionsResponse, repository)); } } - private AICheckResult parseResponse( + private AICheckResult parseAICheckPromptResponse( OpenAIClient.ChatCompletionsResponse chatCompletionsResponse, Repository repository) { AICheckResult result; String response = chatCompletionsResponse.choices().getFirst().message().content(); @@ -190,15 +314,68 @@ private AICheckResult parseResponse( return result; } - private static String getFormattedPrompt(String prompt, String sourceString, String comment) { - String systemPrompt = ""; + private static String getStringChecksFormattedPrompt( + String prompt, String sourceString, String comment) { + String formattedPrompt = ""; if (prompt != null) { - systemPrompt = + formattedPrompt = prompt .replace(SOURCE_STRING_PLACEHOLDER, sourceString) .replace(COMMENT_STRING_PLACEHOLDER, comment); } - return systemPrompt; + return formattedPrompt; + } + + protected String getTranslationFormattedPrompt( + String prompt, TMTextUnit tmTextUnit, String sourceBcp47Tag, String targetBcp47Tag) { + String formattedPrompt = ""; + if (prompt != null) { + formattedPrompt = + prompt + .replace(SOURCE_STRING_PLACEHOLDER, tmTextUnit.getContent()) + .replace(SOURCE_LOCALE_PLACEHOLDER, sourceBcp47Tag) + 
.replace(TARGET_LOCALE_PLACEHOLDER, targetBcp47Tag); + formattedPrompt = + processOptionalPlaceholderText( + formattedPrompt, COMMENT_STRING_PLACEHOLDER, tmTextUnit.getComment()); + formattedPrompt = + processOptionalPlaceholderText( + formattedPrompt, + PLURAL_FORM_PLACEHOLDER, + tmTextUnit.getPluralForm() != null ? tmTextUnit.getPluralForm().getName() : null); + formattedPrompt = + processOptionalPlaceholderText( + formattedPrompt, + CONTEXT_STRING_PLACEHOLDER, + tmTextUnit.getName() != null && tmTextUnit.getName().split(" --- ").length > 1 + ? tmTextUnit.getName().split(" --- ")[1] + : null); + } + return formattedPrompt.trim(); + } + + private String processOptionalPlaceholderText( + String promptText, String placeholder, String placeholderValue) { + if (placeholderValue != null && !placeholderValue.isEmpty()) { + Pattern pattern = patternCache.get(placeholder); + Matcher matcher = pattern.matcher(promptText); + if (matcher.find()) { + String optionalContent = matcher.group(1) + placeholderValue + matcher.group(2); + if (matcher.groupCount() > 2) { + optionalContent += matcher.group(3); + } + promptText = matcher.replaceFirst(optionalContent); + } + } else { + // Remove the entire template block from the prompt if we have no value for the placeholder, + // also removing new line characters if they exist immediately after the template ends + String regex = + "\\{\\{optional: [^\\{\\}]*" + + Pattern.quote(placeholder) + + "[^\\{\\}]*\\}\\}\\s*(?:\\r?\\n)?"; + promptText = promptText.replaceAll(regex, ""); + } + return promptText; } private static OpenAIClient.ChatCompletionsRequest buildChatCompletionsRequest( @@ -236,4 +413,41 @@ private static List buildPromptMess return messages; } + + @PostConstruct + public void init() { + String commentPattern = + "\\{\\{optional: ([^\\{\\}]*)" + + Pattern.quote(COMMENT_STRING_PLACEHOLDER) + + "([^\\{\\}]*)\\}\\}(\\s*(?:\\r?\\n)?)"; + String pluralFormPattern = + "\\{\\{optional: ([^\\{\\}]*)" + + 
Pattern.quote(PLURAL_FORM_PLACEHOLDER) + + "([^\\{\\}]*)\\}\\}(\\s*(?:\\r?\\n)?)"; + String contextPattern = + "\\{\\{optional: ([^\\{\\}]*)" + + Pattern.quote(CONTEXT_STRING_PLACEHOLDER) + + "([^\\{\\}]*)\\}\\}(\\s*(?:\\r?\\n)?)"; + patternCache.put(COMMENT_STRING_PLACEHOLDER, Pattern.compile(commentPattern)); + patternCache.put(PLURAL_FORM_PLACEHOLDER, Pattern.compile(pluralFormPattern)); + patternCache.put(CONTEXT_STRING_PLACEHOLDER, Pattern.compile(contextPattern)); + llmTranslateRetryConfig = + Retry.backoff(retryMaxAttempts, Duration.ofSeconds(retryMinDurationSeconds)) + .maxBackoff(Duration.ofSeconds(retryMaxBackoffDurationSeconds)) + .onRetryExhaustedThrow( + (retryBackoffSpec, retrySignal) -> { + Throwable error = retrySignal.failure(); + logger.error( + "Retry exhausted after {} attempts: {}", + retryMaxAttempts, + error.getMessage(), + error); + return new AIException( + "Retry exhausted after " + + retryMaxAttempts + + " attempts: " + + error.getMessage(), + error); + }); + } } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslateCronJob.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslateCronJob.java new file mode 100644 index 0000000000..5fcef410f9 --- /dev/null +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslateCronJob.java @@ -0,0 +1,353 @@ +package com.box.l10n.mojito.service.ai.translation; + +import com.box.l10n.mojito.JSR310Migration; +import com.box.l10n.mojito.entity.Locale; +import com.box.l10n.mojito.entity.PromptType; +import com.box.l10n.mojito.entity.Repository; +import com.box.l10n.mojito.entity.RepositoryLocale; +import com.box.l10n.mojito.entity.RepositoryLocaleAIPrompt; +import com.box.l10n.mojito.entity.TMTextUnit; +import com.box.l10n.mojito.entity.TMTextUnitVariant; +import com.box.l10n.mojito.entity.TmTextUnitPendingMT; +import com.box.l10n.mojito.rest.ai.AIException; +import com.box.l10n.mojito.service.ai.LLMService; +import 
com.box.l10n.mojito.service.ai.RepositoryLocaleAIPromptRepository; +import com.box.l10n.mojito.service.tm.TMTextUnitRepository; +import com.box.l10n.mojito.service.tm.TMTextUnitVariantRepository; +import com.google.common.collect.Lists; +import io.micrometer.core.annotation.Timed; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tags; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.commons.codec.digest.DigestUtils; +import org.quartz.DisallowConcurrentExecution; +import org.quartz.Job; +import org.quartz.JobDetail; +import org.quartz.JobExecutionContext; +import org.quartz.JobExecutionException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.scheduling.quartz.CronTriggerFactoryBean; +import org.springframework.scheduling.quartz.JobDetailFactoryBean; +import org.springframework.stereotype.Component; + +/** + * Quartz job that translates text units in batches via AI. 
+ * + * @author maallen + */ +@Component +@Configuration +@ConditionalOnProperty(value = "l10n.ai.translation.enabled", havingValue = "true") +@DisallowConcurrentExecution +public class AITranslateCronJob implements Job { + + static Logger logger = LoggerFactory.getLogger(AITranslateCronJob.class); + + private static final String REPOSITORY_DEFAULT_PROMPT = "repository_default_prompt"; + + @Autowired TMTextUnitRepository tmTextUnitRepository; + + @Autowired TMTextUnitVariantRepository tmTextUnitVariantRepository; + + @Autowired LLMService llmService; + + @Autowired MeterRegistry meterRegistry; + + @Autowired RepositoryLocaleAIPromptRepository repositoryLocaleAIPromptRepository; + + @Autowired AITranslationTextUnitFilterService aiTranslationTextUnitFilterService; + + @Autowired AITranslationConfiguration aiTranslationConfiguration; + + @Autowired AITranslationService aiTranslationService; + + @Autowired TmTextUnitPendingMTRepository tmTextUnitPendingMTRepository; + + @Autowired JdbcTemplate jdbcTemplate; + + @Value("${l10n.ai.translation.job.threads:1}") + int threads; + + @Timed("AITranslateCronJob.translate") + public void translate(Repository repository, TMTextUnit tmTextUnit, TmTextUnitPendingMT pendingMT) + throws AIException { + + try { + if (pendingMT != null) { + if (!isExpired(pendingMT)) { + if (aiTranslationTextUnitFilterService.isTranslatable(tmTextUnit, repository)) { + translateLocales(tmTextUnit, repository, getLocalesForMT(repository, tmTextUnit)); + meterRegistry + .timer("AITranslateCronJob.timeToMT", Tags.of("repository", repository.getName())) + .record( + Duration.between(JSR310Migration.dateTimeNow(), pendingMT.getCreatedDate())); + } else { + logger.debug( + "Text unit with name: {} should not be translated, skipping AI translation.", + tmTextUnit.getName()); + meterRegistry.counter( + "AITranslateCronJob.translate.notTranslatable", + Tags.of("repository", repository.getName())); + } + } else { + // If the pending MT is expired, log an error 
and delete it + logger.error("Pending MT for tmTextUnitId: {} is expired", tmTextUnit.getId()); + meterRegistry.counter( + "AITranslateCronJob.expired", Tags.of("repository", repository.getName())); + } + } + } catch (Exception e) { + logger.error("Error running job for text unit id {}", tmTextUnit.getId(), e); + meterRegistry.counter( + "AITranslateCronJob.error", Tags.of("repository", repository.getName())); + } + } + + private Set getLocalesForMT(Repository repository, TMTextUnit tmTextUnit) { + Set localesWithVariants = + tmTextUnitVariantRepository.findLocalesWithVariantByTmTextUnit_Id(tmTextUnit.getId()); + return repository.getRepositoryLocales().stream() + .map(RepositoryLocale::getLocale) + .filter( + locale -> + !localesWithVariants.contains(locale) + && !locale.equals(repository.getSourceLocale())) + .collect(Collectors.toSet()); + } + + private void translateLocales( + TMTextUnit tmTextUnit, Repository repository, Set localesForMT) { + + Map repositoryLocaleAIPrompts = + repositoryLocaleAIPromptRepository + .getActivePromptsByRepositoryAndPromptType( + repository.getId(), PromptType.TRANSLATION.toString()) + .stream() + .collect( + Collectors.toMap( + rlap -> + rlap.getLocale() != null + ? 
rlap.getLocale().getBcp47Tag() + : REPOSITORY_DEFAULT_PROMPT, + Function.identity())); + List aiTranslations = Lists.newArrayList(); + localesForMT.forEach( + targetLocale -> { + try { + String sourceLang = repository.getSourceLocale().getBcp47Tag().split("-")[0]; + if (aiTranslationConfiguration + .getRepositorySettings(repository.getName()) + .isReuseSourceOnLanguageMatch() + && targetLocale.getBcp47Tag().startsWith(sourceLang)) { + aiTranslations.add( + reuseSourceStringAsTranslation(tmTextUnit, repository, targetLocale, sourceLang)); + return; + } + // Get the prompt override for this locale if it exists, otherwise use the + // repository default + RepositoryLocaleAIPrompt repositoryLocaleAIPrompt = + repositoryLocaleAIPrompts.get(targetLocale.getBcp47Tag()) != null + ? repositoryLocaleAIPrompts.get(targetLocale.getBcp47Tag()) + : repositoryLocaleAIPrompts.get(REPOSITORY_DEFAULT_PROMPT); + if (repositoryLocaleAIPrompt != null && !repositoryLocaleAIPrompt.isDisabled()) { + logger.info( + "Translating text unit id {} for locale: {} using prompt: {}", + tmTextUnit.getId(), + targetLocale.getBcp47Tag(), + repositoryLocaleAIPrompt.getAiPrompt().getId()); + aiTranslations.add( + executeTranslationPrompt( + tmTextUnit, repository, targetLocale, repositoryLocaleAIPrompt)); + } else { + logger.debug( + "No active translation prompt found for locale: {}, skipping AI translation.", + targetLocale.getBcp47Tag()); + meterRegistry.counter( + "AITranslateCronJob.translate.noActivePrompt", + Tags.of( + "repository", repository.getName(), "locale", targetLocale.getBcp47Tag())); + } + } catch (Exception e) { + logger.error( + "Error translating text unit id {} for locale: {}", + tmTextUnit.getId(), + targetLocale.getBcp47Tag(), + e); + meterRegistry.counter( + "AITranslateCronJob.translate.error", + Tags.of("repository", repository.getName(), "locale", targetLocale.getBcp47Tag())); + } + }); + aiTranslationService.insertMultiRowAITranslationVariant(tmTextUnit.getId(), 
aiTranslations); + } + + /** + * Builds a translation for the target locale that simply reuses the source content of the text unit, and counts the reuse. + */ + private AITranslation reuseSourceStringAsTranslation( + TMTextUnit tmTextUnit, Repository repository, Locale targetLocale, String sourceLang) { + logger.debug( + "Target language {} matches source language {}, re-using source string as translation.", + targetLocale.getBcp47Tag(), + sourceLang); + meterRegistry + .counter( + "AITranslateCronJob.translate.reuseSourceAsTranslation", + Tags.of("repository", repository.getName(), "locale", targetLocale.getBcp47Tag())) + .increment(); + + return createAITranslationDTO(tmTextUnit, targetLocale, tmTextUnit.getContent()); + } + + /** + * Requests a translation from the LLM service with the given prompt, counts the success and wraps the result in an {@link AITranslation}. + */ + private AITranslation executeTranslationPrompt( + TMTextUnit tmTextUnit, + Repository repository, + Locale targetLocale, + RepositoryLocaleAIPrompt repositoryLocaleAIPrompt) { + String translation = + llmService.translate( + tmTextUnit, + repository.getSourceLocale().getBcp47Tag(), + targetLocale.getBcp47Tag(), + repositoryLocaleAIPrompt.getAiPrompt()); + meterRegistry + .counter( + "AITranslateCronJob.translate.success", + Tags.of("repository", repository.getName(), "locale", targetLocale.getBcp47Tag())) + .increment(); + return createAITranslationDTO(tmTextUnit, targetLocale, translation); + } + + private AITranslation createAITranslationDTO( + TMTextUnit tmTextUnit, Locale locale, String translation) { + AITranslation aiTranslation = new AITranslation(); + aiTranslation.setTmTextUnit(tmTextUnit); + aiTranslation.setContentMd5(DigestUtils.md5Hex(translation)); + aiTranslation.setLocaleId(locale.getId()); + aiTranslation.setTranslation(translation); + aiTranslation.setIncludedInLocalizedFile(false); + aiTranslation.setStatus(TMTextUnitVariant.Status.MT_TRANSLATED); + aiTranslation.setCreatedDate(JSR310Migration.dateTimeNow()); + return aiTranslation; + } + + private boolean isExpired(TmTextUnitPendingMT pendingMT) { + return pendingMT + .getCreatedDate() + .isBefore( + JSR310Migration.newDateTimeEmptyCtor() + .minus(aiTranslationConfiguration.getExpiryDuration())); + } + + /** + * Iterates over all pending
MTs and translates them. + * + *

As each individual {@link TMTextUnit} is translated into all locales, the associated {@link + * TmTextUnitPendingMT} is deleted. + * + * @param jobExecutionContext + * @throws JobExecutionException + */ + @Override + @Timed("AITranslateCronJob.execute") + public void execute(JobExecutionContext jobExecutionContext) throws JobExecutionException { + logger.info("Executing AITranslateCronJob"); + + ExecutorService executorService = Executors.newFixedThreadPool(threads); + + List pendingMTs; + try { + do { + pendingMTs = + tmTextUnitPendingMTRepository.findBatch(aiTranslationConfiguration.getBatchSize()); + logger.info("Processing {} pending MTs", pendingMTs.size()); + + List> futures = + pendingMTs.stream() + .map( + pendingMT -> + CompletableFuture.runAsync( + () -> { + try { + TMTextUnit tmTextUnit = getTmTextUnit(pendingMT); + Repository repository = tmTextUnit.getAsset().getRepository(); + translate(repository, tmTextUnit, pendingMT); + } catch (Exception e) { + logger.error( + "Error processing pending MT for text unit id: {}", + pendingMT.getTmTextUnitId(), + e); + meterRegistry + .counter("AITranslateCronJob.pendingMT.error") + .increment(); + } finally { + if (pendingMT != null) { + logger.debug( + "Sending pending MT for tmTextUnitId: {} for deletion", + pendingMT.getTmTextUnitId()); + aiTranslationService.sendForDeletion(pendingMT); + } + } + }, + executorService)) + .toList(); + + // Wait for all tasks in this batch to complete + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + } while (!pendingMTs.isEmpty()); + } finally { + shutdownExecutor(executorService); + } + + logger.info("Finished executing AITranslateCronJob"); + } + + private static void shutdownExecutor(ExecutorService executorService) { + try { + executorService.shutdown(); + if (!executorService.awaitTermination(1, TimeUnit.MINUTES)) { + logger.error("Thread pool tasks didn't finish in the expected time."); + executorService.shutdownNow(); + } + } catch 
(InterruptedException e) { + executorService.shutdownNow(); + /* Restore the interrupt flag so the calling thread can still observe the interruption. */ + Thread.currentThread().interrupt(); + } + } + + private TMTextUnit getTmTextUnit(TmTextUnitPendingMT pendingMT) { + return tmTextUnitRepository + .findByIdWithAssetAndRepositoryAndTMFetched(pendingMT.getTmTextUnitId()) + .orElseThrow( + () -> new AIException("TMTextUnit not found for id: " + pendingMT.getTmTextUnitId())); + } + + @Bean(name = "aiTranslateCron") + public JobDetailFactoryBean jobDetailAiTranslateCronJob() { + JobDetailFactoryBean jobDetailFactory = new JobDetailFactoryBean(); + jobDetailFactory.setJobClass(AITranslateCronJob.class); + jobDetailFactory.setDescription("Translate text units in batches via AI"); + jobDetailFactory.setDurability(true); + jobDetailFactory.setName("aiTranslateCron"); + return jobDetailFactory; + } + + /* NOTE(review): the method name says "SlaChecker" although it builds the trigger of the AI translate job; this looks like a copy/paste leftover. Confirm it does not clash with another bean definition before renaming. */ + @Bean + public CronTriggerFactoryBean triggerSlaCheckerCronJob( + @Qualifier("aiTranslateCron") JobDetail job, + AITranslationConfiguration aiTranslationConfiguration) { + CronTriggerFactoryBean trigger = new CronTriggerFactoryBean(); + trigger.setJobDetail(job); + trigger.setCronExpression(aiTranslationConfiguration.getCron()); + return trigger; + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslateJobException.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslateJobException.java new file mode 100644 index 0000000000..14bb36a26b --- /dev/null +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslateJobException.java @@ -0,0 +1,12 @@ +package com.box.l10n.mojito.service.ai.translation; + +public class AITranslateJobException extends RuntimeException { + + public AITranslateJobException(String message) { + super(message); + } + + public AITranslateJobException(String message, Throwable t) { + super(message, t); + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslation.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslation.java new file mode
100644 index 0000000000..bb67f5cdf0 --- /dev/null +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslation.java @@ -0,0 +1,72 @@ +package com.box.l10n.mojito.service.ai.translation; + +import com.box.l10n.mojito.entity.TMTextUnit; +import com.box.l10n.mojito.entity.TMTextUnitVariant; +import java.time.ZonedDateTime; + +public class AITranslation { + + TMTextUnit tmTextUnit; + Long localeId; + String translation; + String contentMd5; + TMTextUnitVariant.Status status; + boolean includedInLocalizedFile; + ZonedDateTime createdDate; + + public TMTextUnit getTmTextUnit() { + return tmTextUnit; + } + + public void setTmTextUnit(TMTextUnit tmTextUnit) { + this.tmTextUnit = tmTextUnit; + } + + public Long getLocaleId() { + return localeId; + } + + public void setLocaleId(Long localeId) { + this.localeId = localeId; + } + + public String getTranslation() { + return translation; + } + + public void setTranslation(String translation) { + this.translation = translation; + } + + public TMTextUnitVariant.Status getStatus() { + return status; + } + + public void setStatus(TMTextUnitVariant.Status status) { + this.status = status; + } + + public boolean isIncludedInLocalizedFile() { + return includedInLocalizedFile; + } + + public void setIncludedInLocalizedFile(boolean includedInLocalizedFile) { + this.includedInLocalizedFile = includedInLocalizedFile; + } + + public ZonedDateTime getCreatedDate() { + return createdDate; + } + + public void setCreatedDate(ZonedDateTime createdDate) { + this.createdDate = createdDate; + } + + public String getContentMd5() { + return contentMd5; + } + + public void setContentMd5(String content_md5) { + this.contentMd5 = content_md5; + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationConfiguration.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationConfiguration.java new file mode 100644 index 0000000000..472c72086a --- /dev/null +++ 
b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationConfiguration.java @@ -0,0 +1,95 @@ +package com.box.l10n.mojito.service.ai.translation; + +import com.google.common.collect.Maps; +import java.time.Duration; +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; + +@Component +@ConfigurationProperties("l10n.ai.translation") +public class AITranslationConfiguration { + + private boolean enabled = false; + private int batchSize = 10; + + /** + * Duration after which a pending MT is considered expired and will not be processed in AI + * translation (as it will be eligible for third party syncs once the entity is older than the + * expiry period). + * + *

If the pending MT is expired, it will be deleted which will remove it from AI translation + * flow. + */ + private Duration expiryDuration = Duration.ofHours(3); + + private String cron = "0 0/10 * * * ?"; + + private Map repositorySettings = Maps.newHashMap(); + + public static class RepositorySettings { + /** + * If true, reuse the source text if the language of the source text matches the target + * language. + * + *

Uses the language piece of the BCP47 tag to determine if the language matches. For + * example, en-US and en-GB would match. + * + *

If a match is found, the source text will be used as the translation for the target locale + * and no AI translation will be requested for the target locale. + */ + private boolean reuseSourceOnLanguageMatch = false; + + public Boolean isReuseSourceOnLanguageMatch() { + return reuseSourceOnLanguageMatch; + } + + public void setReuseSourceOnLanguageMatch(Boolean reuseSourceOnLanguageMatch) { + this.reuseSourceOnLanguageMatch = reuseSourceOnLanguageMatch; + } + } + + public int getBatchSize() { + return batchSize; + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + public Duration getExpiryDuration() { + return expiryDuration; + } + + public void setExpiryDuration(Duration expiryDuration) { + this.expiryDuration = expiryDuration; + } + + public String getCron() { + return cron; + } + + public void setCron(String cron) { + this.cron = cron; + } + + public Map getRepositorySettings() { + return repositorySettings; + } + + public void setRepositorySettings(Map repositorySettings) { + this.repositorySettings = repositorySettings; + } + + public boolean isEnabled() { + return enabled; + } + + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + public RepositorySettings getRepositorySettings(String repositoryName) { + return repositorySettings.get(repositoryName); + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationFilterConfiguration.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationFilterConfiguration.java new file mode 100644 index 0000000000..8570760178 --- /dev/null +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationFilterConfiguration.java @@ -0,0 +1,59 @@ +package com.box.l10n.mojito.service.ai.translation; + +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.context.annotation.Configuration; + +@Configuration 
+@ConfigurationProperties(prefix = "l10n.ai.translation.filter") +public class AITranslationFilterConfiguration { + + private Map repositoryConfig; + + public Map getRepositoryConfig() { + return repositoryConfig; + } + + public void setRepositoryConfig(Map repositoryConfig) { + this.repositoryConfig = repositoryConfig; + } + + public static class RepositoryConfig { + private boolean excludePlurals; + private boolean excludePlaceholders; + private boolean excludeHtmlTags; + private String excludePlaceholdersRegex = "\\{[^\\}]*\\}"; + + public boolean shouldExcludePlurals() { + return excludePlurals; + } + + public void setExcludePlurals(boolean excludePlurals) { + this.excludePlurals = excludePlurals; + } + + public boolean shouldExcludePlaceholders() { + return excludePlaceholders; + } + + public void setExcludePlaceholders(boolean excludePlaceholders) { + this.excludePlaceholders = excludePlaceholders; + } + + public boolean shouldExcludeHtmlTags() { + return excludeHtmlTags; + } + + public void setExcludeHtmlTags(boolean excludeHtmlTags) { + this.excludeHtmlTags = excludeHtmlTags; + } + + public String getExcludePlaceholdersRegex() { + return excludePlaceholdersRegex; + } + + public void setExcludePlaceholdersRegex(String excludePlaceholdersRegex) { + this.excludePlaceholdersRegex = excludePlaceholdersRegex; + } + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationService.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationService.java new file mode 100644 index 0000000000..e509861be2 --- /dev/null +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationService.java @@ -0,0 +1,218 @@ +package com.box.l10n.mojito.service.ai.translation; + +import com.box.l10n.mojito.JSR310Migration; +import com.box.l10n.mojito.entity.PromptType; +import com.box.l10n.mojito.entity.TmTextUnitPendingMT; +import com.box.l10n.mojito.service.ai.RepositoryLocaleAIPromptRepository; 
+import com.box.l10n.mojito.service.repository.RepositoryRepository; +import com.box.l10n.mojito.service.tm.TMTextUnitVariantRepository; +import com.google.common.collect.Lists; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.transaction.Transactional; +import java.sql.Timestamp; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.stereotype.Component; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; + +@Component +@ConditionalOnProperty(value = "l10n.ai.translation.enabled", havingValue = "true") +public class AITranslationService { + + private static final Logger logger = LoggerFactory.getLogger(AITranslationService.class); + + @Autowired RepositoryLocaleAIPromptRepository repositoryLocaleAIPromptRepository; + + @Autowired TmTextUnitPendingMTRepository tmTextUnitPendingMTRepository; + + @Autowired TMTextUnitVariantRepository tmTextUnitVariantRepository; + + @Autowired RepositoryRepository repositoryRepository; + + @Autowired JdbcTemplate jdbcTemplate; + + @Value("${l10n.ai.translation.pendingMT.batchSize:1000}") + int batchSize; + + @Value("${l10n.ai.translation.pendingMT.timeout:PT10S}") + Duration timeout; + + @Value("${l10n.ai.translation.maxTextUnitsAIRequest:1000}") + int maxTextUnitsAIRequest; + + private final Sinks.Many pendingMTDeletionSink = + Sinks.many().multicast().onBackpressureBuffer(); + + @Transactional + public void createPendingMTEntitiesInBatches(Long repositoryId, Set tmTextUnitIds) { + if (tmTextUnitIds.size() > maxTextUnitsAIRequest) { + 
logger.warn( + "Number of text units ({}) exceeds the maximum number of text units that can be sent for AI translation per request ({}). AI translation will be skipped.", + tmTextUnitIds.size(), + maxTextUnitsAIRequest); + return; + } + if (repositoryLocaleAIPromptRepository.findCountOfActiveRepositoryPromptsByType( + repositoryId, PromptType.TRANSLATION.toString()) + > 0) { + createPendingMTEntitiesInBatches(tmTextUnitIds); + } else { + logger.debug("No active prompts for repository: {}, no job scheduled", repositoryId); + } + } + + protected void sendForDeletion(TmTextUnitPendingMT pendingMT) { + logger.debug("Sending pending MT for deletion: {}", pendingMT); + pendingMTDeletionSink.tryEmitNext(pendingMT); + } + + private void createPendingMTEntitiesInBatches(Set tmTextUnitIds) { + List pendingMTs = + tmTextUnitIds.stream() + .map(AITranslationService::createTmTextUnitPendingMT) + .collect(Collectors.toList()); + logger.debug("Persisting {} pending MTs", pendingMTs.size()); + Lists.partition(pendingMTs, batchSize).forEach(this::savePendingMTsMultiRowBatch); + } + + private void savePendingMTsMultiRowBatch(List pendingMTs) { + String sql = + "INSERT INTO tm_text_unit_pending_mt(tm_text_unit_id, created_date) VALUES" + + pendingMTs.stream() + .map( + tmTextUnitPendingMT -> + String.format( + "(%d, '%s')", + tmTextUnitPendingMT.getTmTextUnitId(), + Timestamp.from(tmTextUnitPendingMT.getCreatedDate().toInstant()))) + .collect(Collectors.joining(",")); + logger.debug("Executing batch insert for {} pending MTs", pendingMTs.size()); + jdbcTemplate.update(sql); + } + + /** + * Inserts AI translations into the database. 
+ * + * @param translationDTOs + */ + @Transactional + protected void insertMultiRowAITranslationVariant( + Long tmTextUnitId, List translationDTOs) { + insertMultiRowTextUnitVariants(tmTextUnitId, translationDTOs); + insertMultiRowTextUnitCurrentVariants(tmTextUnitId, translationDTOs); + } + + private void insertMultiRowTextUnitVariants( + Long textUnitId, List translationDTOs) { + logger.debug( + "Inserting {} translation variants for text unit ID: {}", + translationDTOs.size(), + textUnitId); + + String sql = + "INSERT INTO tm_text_unit_variant (tm_text_unit_id, locale_id, content, content_md5, status, included_in_localized_file, created_date) VALUES (?, ?, ?, ?, ?, ?, ?)"; + List batchArgs = + translationDTOs.stream() + .map( + translationDTO -> + new Object[] { + translationDTO.getTmTextUnit().getId(), + translationDTO.getLocaleId(), + translationDTO.getTranslation(), + translationDTO.getContentMd5(), + translationDTO.getStatus().toString(), + translationDTO.isIncludedInLocalizedFile(), + Timestamp.from(translationDTO.getCreatedDate().toInstant()) + }) + .collect(Collectors.toList()); + + logger.debug("Executing batch insert for {} translation variants", translationDTOs.size()); + + jdbcTemplate.batchUpdate(sql, batchArgs); + } + + private void insertMultiRowTextUnitCurrentVariants( + Long textUnitId, List translationDTOs) { + logger.debug( + "Inserting {} current variants for text unit ID: {}", translationDTOs.size(), textUnitId); + + String sql = + "INSERT INTO tm_text_unit_current_variant (tm_id, asset_id, tm_text_unit_id, tm_text_unit_variant_id, locale_id, created_date, last_modified_date) VALUES (?, ?, ?, ?, ?, ?, ?)"; + + Map localeToVariantIds = + tmTextUnitVariantRepository.findLocaleVariantDTOsByTmTextUnitId(textUnitId).stream() + .collect( + Collectors.toMap(LocaleVariantDTO::getLocaleId, LocaleVariantDTO::getVariantId)); + + List batchArgs = + translationDTOs.stream() + .map( + translationDTO -> + new Object[] { + 
translationDTO.getTmTextUnit().getTm().getId(), + translationDTO.getTmTextUnit().getAsset().getId(), + translationDTO.getTmTextUnit().getId(), + localeToVariantIds.get(translationDTO.getLocaleId()), + translationDTO.getLocaleId(), + Timestamp.from(translationDTO.getCreatedDate().toInstant()), + Timestamp.from(translationDTO.getCreatedDate().toInstant()) + }) + .toList(); + + logger.debug( + "Executing batch insert for {} translation current variants", translationDTOs.size()); + + jdbcTemplate.batchUpdate(sql, batchArgs); + } + + @Transactional + private void deleteBatch(List batch) { + if (batch.isEmpty()) { + logger.debug("No pending MTs to delete"); + return; + } + + String sql = + "DELETE FROM tm_text_unit_pending_mt WHERE id IN (" + + batch.stream() + .map(tmTextUnitPendingMT -> tmTextUnitPendingMT.getId().toString()) + .collect(Collectors.joining(",")) + + ")"; + logger.debug( + "Executing batch delete for IDs: {}", + batch.stream().map(TmTextUnitPendingMT::getId).collect(Collectors.toList())); + jdbcTemplate.update(sql); + } + + private static TmTextUnitPendingMT createTmTextUnitPendingMT(Long tmTextUnitId) { + TmTextUnitPendingMT tmTextUnitPendingMT = new TmTextUnitPendingMT(); + tmTextUnitPendingMT.setTmTextUnitId(tmTextUnitId); + tmTextUnitPendingMT.setCreatedDate(JSR310Migration.newDateTimeEmptyCtor()); + return tmTextUnitPendingMT; + } + + @PostConstruct + public void init() { + Flux flux = pendingMTDeletionSink.asFlux(); + + flux.bufferTimeout(batchSize, timeout) + .filter(batch -> !batch.isEmpty()) + .subscribe(this::deleteBatch); + } + + @PreDestroy + public void destroy() { + pendingMTDeletionSink.tryEmitComplete(); + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationTextUnitFilterService.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationTextUnitFilterService.java new file mode 100644 index 0000000000..754fdd3e87 --- /dev/null +++ 
b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/AITranslationTextUnitFilterService.java @@ -0,0 +1,103 @@ +package com.box.l10n.mojito.service.ai.translation; + +import com.box.l10n.mojito.entity.Repository; +import com.box.l10n.mojito.entity.TMTextUnit; +import io.micrometer.core.instrument.MeterRegistry; +import jakarta.annotation.PostConstruct; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +@Component +public class AITranslationTextUnitFilterService { + + private static final Logger logger = + LoggerFactory.getLogger(AITranslationTextUnitFilterService.class); + private static final String HTML_TAG_REGEX = "<[^>]*>"; + private static final Pattern HTML_TAG_PATTERN = Pattern.compile(HTML_TAG_REGEX); + + protected Map excludePlaceholdersPatternMap; + + @Autowired AITranslationFilterConfiguration aiTranslationFilterConfiguration; + + @Autowired MeterRegistry meterRegistry; + + public boolean isTranslatable(TMTextUnit tmTextUnit, Repository repository) { + boolean isTranslatable = true; + + if (repository == null) { + logger.warn( + "Repository is null for text unit with id: {}, filtering will be skipped", + tmTextUnit.getId()); + return isTranslatable; + } + + if (aiTranslationFilterConfiguration.getRepositoryConfig() == null + || aiTranslationFilterConfiguration.getRepositoryConfig().get(repository.getName()) + == null) { + logger.debug( + "No configuration found for repository: {}, filtering will be skipped", + repository.getName()); + return isTranslatable; + } + + AITranslationFilterConfiguration.RepositoryConfig repositoryConfig = + aiTranslationFilterConfiguration.getRepositoryConfig().get(repository.getName()); + + if (repositoryConfig.shouldExcludePlurals()) { + isTranslatable = !isPlural(tmTextUnit); + } + + if 
(repositoryConfig.shouldExcludePlaceholders()) { + isTranslatable = isTranslatable && !containsPlaceholder(repository.getName(), tmTextUnit); + } + + if (repositoryConfig.shouldExcludeHtmlTags()) { + isTranslatable = isTranslatable && !containsHtmlTag(tmTextUnit); + } + + logger.debug( + "Text unit with name: {}, should be translated: {}", tmTextUnit.getName(), isTranslatable); + return isTranslatable; + } + + private boolean containsPlaceholder(String repositoryName, TMTextUnit tmTextUnit) { + Pattern pattern = excludePlaceholdersPatternMap.get(repositoryName); + if (pattern != null) { + Matcher matcher = + excludePlaceholdersPatternMap.get(repositoryName).matcher(tmTextUnit.getContent()); + return matcher.find(); + } else { + logger.debug("No exclude placeholders pattern found for repository: {}", repositoryName); + return false; + } + } + + private boolean isPlural(TMTextUnit tmTextUnit) { + return tmTextUnit.getPluralForm() != null; + } + + private boolean containsHtmlTag(TMTextUnit tmTextUnit) { + Matcher matcher = HTML_TAG_PATTERN.matcher(tmTextUnit.getContent()); + return matcher.find(); + } + + /** + * Compiles, once at startup, the placeholder exclusion pattern of every configured repository that has placeholder exclusion enabled. + */ + @PostConstruct + public void init() { + /* Must be a mutable map: Map.of() is immutable and the put below would throw UnsupportedOperationException. */ + excludePlaceholdersPatternMap = new java.util.HashMap<>(); + if (aiTranslationFilterConfiguration.getRepositoryConfig() != null) { + for (Map.Entry entry : + aiTranslationFilterConfiguration.getRepositoryConfig().entrySet()) { + AITranslationFilterConfiguration.RepositoryConfig repositoryConfig = entry.getValue(); + if (repositoryConfig.shouldExcludePlaceholders()) { + excludePlaceholdersPatternMap.put( + entry.getKey(), Pattern.compile(repositoryConfig.getExcludePlaceholdersRegex())); + } + } + } + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/LocaleVariantDTO.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/LocaleVariantDTO.java new file mode 100644 index 0000000000..3cd7bc6136 --- /dev/null +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/LocaleVariantDTO.java @@
-0,0 +1,36 @@ +package com.box.l10n.mojito.service.ai.translation; + +/** + * DTO class that is used as part of the AI Translation flow, specifically for the mapping of locale + * and variant IDs. + * + *

It is used when setting the current variants for a given locale and text unit. + * + * @author maallen + */ +public class LocaleVariantDTO { + + private Long localeId; + private Long variantId; + + public LocaleVariantDTO(Long localeId, Long variantId) { + this.localeId = localeId; + this.variantId = variantId; + } + + public Long getLocaleId() { + return localeId; + } + + public void setLocaleId(Long localeId) { + this.localeId = localeId; + } + + public Long getVariantId() { + return variantId; + } + + public void setVariantId(Long variantId) { + this.variantId = variantId; + } +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/TmTextUnitPendingMTRepository.java b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/TmTextUnitPendingMTRepository.java new file mode 100644 index 0000000000..b9c45a1b77 --- /dev/null +++ b/webapp/src/main/java/com/box/l10n/mojito/service/ai/translation/TmTextUnitPendingMTRepository.java @@ -0,0 +1,17 @@ +package com.box.l10n.mojito.service.ai.translation; + +import com.box.l10n.mojito.entity.TmTextUnitPendingMT; +import java.util.List; +import org.springframework.data.jpa.repository.JpaRepository; +import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; + +public interface TmTextUnitPendingMTRepository extends JpaRepository { + + TmTextUnitPendingMT findByTmTextUnitId(Long tmTextUnitId); + + @Query( + value = "SELECT * FROM tm_text_unit_pending_mt ORDER BY id LIMIT :batchSize", + nativeQuery = true) + List findBatch(@Param("batchSize") int batchSize); +} diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/asset/VirtualTextUnitBatchUpdaterService.java b/webapp/src/main/java/com/box/l10n/mojito/service/asset/VirtualTextUnitBatchUpdaterService.java index c2bacb5669..28f173927c 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/asset/VirtualTextUnitBatchUpdaterService.java +++ 
b/webapp/src/main/java/com/box/l10n/mojito/service/asset/VirtualTextUnitBatchUpdaterService.java @@ -6,8 +6,10 @@ import com.box.l10n.mojito.entity.AssetTextUnit; import com.box.l10n.mojito.entity.AssetTextUnitToTMTextUnit; import com.box.l10n.mojito.entity.PluralForm; +import com.box.l10n.mojito.entity.Repository; import com.box.l10n.mojito.entity.TMTextUnit; import com.box.l10n.mojito.okapi.TextUnitUtils; +import com.box.l10n.mojito.service.ai.translation.AITranslationService; import com.box.l10n.mojito.service.assetExtraction.AssetExtractionRepository; import com.box.l10n.mojito.service.assetExtraction.AssetExtractionService; import com.box.l10n.mojito.service.assetExtraction.AssetTextUnitToTMTextUnitRepository; @@ -31,7 +33,10 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Component; import org.springframework.transaction.annotation.Transactional; @@ -69,6 +74,9 @@ public class VirtualTextUnitBatchUpdaterService { @Autowired LocaleService localeService; + @Autowired(required = false) + AITranslationService aiTranslationService; + @Transactional public void updateTextUnits( Asset asset, List virtualAssetTextUnits, boolean replace) @@ -115,6 +123,10 @@ public void updateTextUnits( performLeveraging(savedTextUnits, nameToUsedtextUnitDTOs, contentToTextUnitDTOs); + if (aiTranslationService != null) { + scheduleAITranslation(savedTextUnits.keySet(), asset.getRepository()); + } + if (replace) { deleteOldAssetTextUnits(md5ToTextUnitDTOs, md5ToVirtualTextUnits); } @@ -313,4 +325,10 @@ void createMappedAssetTextUnit( assetTextUnitToTMTextUnit.setTmTextUnit(tmTextUnitRepository.getOne(tmTextUnitId)); assetTextUnitToTMTextUnitRepository.save(assetTextUnitToTMTextUnit); } + + @Async + void scheduleAITranslation(Set textUnits, 
Repository repository) { + Set tmTextUnitIds = textUnits.stream().map(TMTextUnit::getId).collect(Collectors.toSet()); + aiTranslationService.createPendingMTEntitiesInBatches(repository.getId(), tmTextUnitIds); + } } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/assetExtraction/AssetExtractionService.java b/webapp/src/main/java/com/box/l10n/mojito/service/assetExtraction/AssetExtractionService.java index 920b04ab76..dc8fc316c6 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/assetExtraction/AssetExtractionService.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/assetExtraction/AssetExtractionService.java @@ -31,6 +31,7 @@ import com.box.l10n.mojito.okapi.extractor.AssetExtractorTextUnit; import com.box.l10n.mojito.quartz.QuartzJobInfo; import com.box.l10n.mojito.quartz.QuartzPollableTaskScheduler; +import com.box.l10n.mojito.service.ai.translation.AITranslationService; import com.box.l10n.mojito.service.asset.AssetRepository; import com.box.l10n.mojito.service.asset.FilterOptionsMd5Builder; import com.box.l10n.mojito.service.assetTextUnit.AssetTextUnitRepository; @@ -92,6 +93,7 @@ import org.springframework.retry.RetryContext; import org.springframework.retry.annotation.Retryable; import org.springframework.retry.support.RetryTemplate; +import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -167,6 +169,9 @@ public class AssetExtractionService { @Autowired LocalBranchToEntityBranchConverter localBranchToEntityBranchConverter; + @Autowired(required = false) + AITranslationService aiTranslationService; + private RepositoryStatisticsJobScheduler repositoryStatisticsJobScheduler; @Value("${l10n.assetExtraction.quartz.schedulerName:" + DEFAULT_SCHEDULER_NAME + "}") @@ -214,7 +219,9 @@ public PollableFuture processAsset( asset, createdTextUnitsResult.getUpdatedState(), currentTask); updatePushRun(asset, 
createdTextUnitsResult.getUpdatedState(), pushRunId, currentTask); performLeveraging(createdTextUnitsResult.getLeveragingMatches(), currentTask); - + if (aiTranslationService != null) { + scheduleAITranslation(asset.getRepository().getId(), createdTextUnitsResult); + } logger.info("Done processing asset content id: {}", assetContentId); return new PollableFutureTaskResult<>(asset); } @@ -268,6 +275,15 @@ void updatePushRun( pushRunService.associatePushRunToTextUnitIds(pushRun, asset, textUnitIds); } + @Async + void scheduleAITranslation(Long repositoryId, CreateTextUnitsResult createTextUnitsResult) { + Set tmTextUnitIds = + createTextUnitsResult.getCreatedTextUnits().stream() + .map(BranchStateTextUnit::getTmTextUnitId) + .collect(Collectors.toSet()); + aiTranslationService.createPendingMTEntitiesInBatches(repositoryId, tmTextUnitIds); + } + MultiBranchState updateAssetExtractionWithState( Long assetExtractionId, MultiBranchState currentState, AssetContentMd5s assetContentMd5s) { return retryTemplate.execute( diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitCurrentVariantRepository.java b/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitCurrentVariantRepository.java index 483b5546a0..8d6f9cd800 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitCurrentVariantRepository.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitCurrentVariantRepository.java @@ -1,11 +1,14 @@ package com.box.l10n.mojito.service.tm; +import com.box.l10n.mojito.entity.Locale; import com.box.l10n.mojito.entity.TMTextUnitCurrentVariant; import java.util.List; +import java.util.Set; import org.springframework.data.jpa.repository.EntityGraph; import org.springframework.data.jpa.repository.EntityGraph.EntityGraphType; import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; import 
org.springframework.data.rest.core.annotation.RepositoryRestResource; /** @@ -29,4 +32,12 @@ public interface TMTextUnitCurrentVariantRepository where ttucv.asset.id = ?1 and ttucv.locale.id = ?2 """) List findByAsset_idAndLocale_Id(Long assetId, Long localeId); + + @Query( + """ + select l from TMTextUnitCurrentVariant ttucv + join ttucv.locale l + where ttucv.tmTextUnit.id = :tmTextUnitId + """) + Set findLocalesWithVariantByTmTextUnit_Id(@Param("tmTextUnitId") Long tmTextUnitId); } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitRepository.java b/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitRepository.java index acebf27387..7780771e4b 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitRepository.java +++ b/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitRepository.java @@ -10,6 +10,7 @@ import org.springframework.data.jpa.repository.EntityGraph.EntityGraphType; import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; import org.springframework.data.rest.core.annotation.RepositoryRestResource; /** @@ -48,4 +49,8 @@ public interface TMTextUnitRepository extends JpaRepository { List getTextUnitIdsByAssetId(Long assetId); TMTextUnit findByMd5AndTmIdAndAssetId(String contentMd5, Long tmId, Long assetId); + + @Query( + "SELECT t FROM TMTextUnit t JOIN FETCH t.asset a JOIN FETCH a.repository JOIN FETCH t.tm WHERE t.id = :id") + Optional findByIdWithAssetAndRepositoryAndTMFetched(@Param("id") Long id); } diff --git a/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitVariantRepository.java b/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitVariantRepository.java index e7fea53da5..ebc04f56ae 100644 --- a/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitVariantRepository.java +++ 
b/webapp/src/main/java/com/box/l10n/mojito/service/tm/TMTextUnitVariantRepository.java @@ -5,9 +5,11 @@ import com.box.l10n.mojito.entity.PushRun; import com.box.l10n.mojito.entity.Repository; import com.box.l10n.mojito.entity.TMTextUnitVariant; +import com.box.l10n.mojito.service.ai.translation.LocaleVariantDTO; import com.google.common.annotations.VisibleForTesting; import java.time.ZonedDateTime; import java.util.List; +import java.util.Set; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.jpa.repository.EntityGraph; @@ -143,4 +145,21 @@ List findDeltasForRuns( @Param("pushRunIds") List pushRunIds, @Param("pullRunIds") List pullRunIds, @Param("translationsFromDate") ZonedDateTime translationsFromDate); + + @Query( + """ + select l from TMTextUnitVariant ttuv + join ttuv.locale l + where ttuv.tmTextUnit.id = :tmTextUnitId + """) + Set findLocalesWithVariantByTmTextUnit_Id(@Param("tmTextUnitId") Long tmTextUnitId); + + @Query( + """ + SELECT new com.box.l10n.mojito.service.ai.translation.LocaleVariantDTO(ttuv.locale.id, ttuv.id) + FROM TMTextUnitVariant ttuv + WHERE ttuv.tmTextUnit.id = :tmTextUnitId + """) + List findLocaleVariantDTOsByTmTextUnitId( + @Param("tmTextUnitId") Long tmTextUnitId); } diff --git a/webapp/src/main/resources/db/migration/V68__GPT_translation.sql b/webapp/src/main/resources/db/migration/V68__GPT_translation.sql new file mode 100644 index 0000000000..de458783c8 --- /dev/null +++ b/webapp/src/main/resources/db/migration/V68__GPT_translation.sql @@ -0,0 +1,54 @@ +CREATE TABLE repository_locale_ai_prompt ( + id bigint AUTO_INCREMENT PRIMARY KEY, + repository_id bigint NOT NULL, + locale_id bigint NULL, + ai_prompt_id bigint NOT NULL, + disabled boolean DEFAULT FALSE +); + +ALTER TABLE repository_locale_ai_prompt +ADD CONSTRAINT FK__REPOSITORY_LOCALE_AI_PROMPT__REPOSITORY_ID FOREIGN KEY (repository_id) REFERENCES repository(id); + +ALTER TABLE 
repository_locale_ai_prompt +ADD CONSTRAINT FK__REPOSITORY_LOCALE_AI_PROMPT__LOCALE_ID FOREIGN KEY (locale_id) REFERENCES locale(id); + +ALTER TABLE repository_locale_ai_prompt +ADD CONSTRAINT FK__REPOSITORY_LOCALE_AI_PROMPT__AI_PROMPT_ID FOREIGN KEY (ai_prompt_id) REFERENCES ai_prompt(id); + +START TRANSACTION; +# Migrate existing repository_ai_prompt mapping entries to new table +INSERT INTO repository_locale_ai_prompt (repository_id, ai_prompt_id) +SELECT repository_id, ai_prompt_id +FROM repository_ai_prompt; +COMMIT; + +ALTER TABLE ai_prompt +ADD COLUMN prompt_type_id bigint NULL; + +ALTER TABLE ai_prompt +ADD CONSTRAINT FK__AI_PROMPT__PROMPT_TYPE_ID FOREIGN KEY (prompt_type_id) REFERENCES ai_prompt_type(id); + +START TRANSACTION; +# Update existing ai_prompt rows with prompt_type_id +UPDATE ai_prompt ap +JOIN repository_ai_prompt rap ON ap.id = rap.ai_prompt_id +SET ap.prompt_type_id = rap.prompt_type_id; +COMMIT; + +# Remove old mapping table +DROP TABLE repository_ai_prompt; + +INSERT INTO ai_prompt_type (name) VALUES ('TRANSLATION'); + +CREATE TABLE tm_text_unit_pending_mt ( + id bigint AUTO_INCREMENT PRIMARY KEY, + tm_text_unit_id bigint NOT NULL, + created_date DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +ALTER TABLE tm_text_unit_pending_mt +ADD CONSTRAINT FK__TM_TEXT_UNIT_PENDING_MT__TM_TEXT_UNIT_ID FOREIGN KEY (tm_text_unit_id) REFERENCES tm_text_unit(id); + +ALTER TABLE third_party_text_unit +ADD COLUMN uploaded_file_uri varchar(255) NULL; + diff --git a/webapp/src/main/resources/db/migration/V69__Repo_locale_ai_unique_key.sql b/webapp/src/main/resources/db/migration/V69__Repo_locale_ai_unique_key.sql new file mode 100644 index 0000000000..7e1a553507 --- /dev/null +++ b/webapp/src/main/resources/db/migration/V69__Repo_locale_ai_unique_key.sql @@ -0,0 +1,2 @@ +ALTER TABLE repository_locale_ai_prompt +ADD UNIQUE KEY UK__REPOSITORY_LOCALE_AI_PROMPT__REPO_ID__LOCALE_ID__AI_PROMPT (repository_id, locale_id, ai_prompt_id); \ No newline at end of file 
diff --git a/webapp/src/main/resources/db/migration/V70__Add_json_fields_to_prompt_table.sql b/webapp/src/main/resources/db/migration/V70__Add_json_fields_to_prompt_table.sql new file mode 100644 index 0000000000..d59458c204 --- /dev/null +++ b/webapp/src/main/resources/db/migration/V70__Add_json_fields_to_prompt_table.sql @@ -0,0 +1,3 @@ +ALTER TABLE ai_prompt +ADD COLUMN json_response BOOLEAN DEFAULT FALSE, +ADD COLUMN json_response_key VARCHAR(255) DEFAULT NULL; \ No newline at end of file diff --git a/webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/LLMPromptServiceTest.java b/webapp/src/test/java/com/box/l10n/mojito/service/ai/LLMPromptServiceTest.java similarity index 85% rename from webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/LLMPromptServiceTest.java rename to webapp/src/test/java/com/box/l10n/mojito/service/ai/LLMPromptServiceTest.java index 462cc5800a..717c058e4d 100644 --- a/webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/LLMPromptServiceTest.java +++ b/webapp/src/test/java/com/box/l10n/mojito/service/ai/LLMPromptServiceTest.java @@ -1,4 +1,4 @@ -package com.box.l10n.mojito.service.ai.openai; +package com.box.l10n.mojito.service.ai; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -12,12 +12,9 @@ import com.box.l10n.mojito.entity.AIPrompt; import com.box.l10n.mojito.entity.AIPromptType; import com.box.l10n.mojito.entity.Repository; -import com.box.l10n.mojito.entity.RepositoryAIPrompt; +import com.box.l10n.mojito.entity.RepositoryLocaleAIPrompt; import com.box.l10n.mojito.rest.ai.AIException; import com.box.l10n.mojito.rest.ai.AIPromptCreateRequest; -import com.box.l10n.mojito.service.ai.AIPromptRepository; -import com.box.l10n.mojito.service.ai.AIPromptTypeRepository; -import com.box.l10n.mojito.service.ai.RepositoryAIPromptRepository; import com.box.l10n.mojito.service.repository.RepositoryRepository; import java.util.Optional; import 
org.junit.jupiter.api.BeforeEach; @@ -34,11 +31,11 @@ public class LLMPromptServiceTest { @Mock AIPromptTypeRepository aiPromptTypeRepository; - @Mock RepositoryAIPromptRepository repositoryAIPromptRepository; + @Mock RepositoryLocaleAIPromptRepository repositoryLocaleAIPromptRepository; @Mock RepositoryRepository repositoryRepository; - @Captor ArgumentCaptor repositoryAIPromptCaptor; + @Captor ArgumentCaptor repositoryAIPromptCaptor; @InjectMocks LLMPromptService LLMPromptService; @@ -51,8 +48,8 @@ void setUp() { void testPromptCreation() { Repository repository = new Repository(); repository.setId(1L); - RepositoryAIPrompt repositoryAIPrompt = new RepositoryAIPrompt(); - repositoryAIPrompt.setId(1L); + RepositoryLocaleAIPrompt repositoryLocaleAIPrompt = new RepositoryLocaleAIPrompt(); + repositoryLocaleAIPrompt.setId(1L); AIPromptType promptType = new AIPromptType(); promptType.setId(1L); AIPrompt prompt = new AIPrompt(); @@ -69,13 +66,13 @@ void testPromptCreation() { when(aiPromptRepository.save(any())).thenReturn(prompt); when(repositoryRepository.findByName("testRepo")).thenReturn(repository); when(aiPromptTypeRepository.findByName("SOURCE_STRING_CHECKER")).thenReturn(promptType); - when(repositoryAIPromptRepository.save(any())).thenReturn(repositoryAIPrompt); + when(repositoryLocaleAIPromptRepository.save(any())).thenReturn(repositoryLocaleAIPrompt); LLMPromptService.createPrompt(AIPromptCreateRequest); verify(aiPromptTypeRepository, times(1)).findByName("SOURCE_STRING_CHECKER"); verify(aiPromptRepository, times(1)).save(any()); - verify(repositoryAIPromptRepository, times(1)).save(any()); + verify(repositoryLocaleAIPromptRepository, times(1)).save(any()); } @Test @@ -95,7 +92,7 @@ void testPromptCreationNoPromptType() { when(aiPromptRepository.save(any())).thenReturn(prompt); when(repositoryRepository.findByName("testRepo")).thenReturn(new Repository()); when(aiPromptTypeRepository.findByName("SOURCE_STRING_CHECKER")).thenReturn(null); - 
when(repositoryAIPromptRepository.save(any())).thenReturn(1L); + when(repositoryLocaleAIPromptRepository.save(any())).thenReturn(1L); AIException exception = assertThrows(AIException.class, () -> LLMPromptService.createPrompt(AIPromptCreateRequest)); @@ -103,7 +100,7 @@ void testPromptCreationNoPromptType() { verify(aiPromptTypeRepository, times(1)).findByName("SOURCE_STRING_CHECKER"); verify(aiPromptTypeRepository, times(0)).save(any()); - verify(repositoryAIPromptRepository, times(0)).save(any()); + verify(repositoryLocaleAIPromptRepository, times(0)).save(any()); } @Test @@ -153,9 +150,8 @@ void testAddPromptToRepository() { when(aiPromptRepository.findById(1L)).thenReturn(Optional.of(aiPrompt)); LLMPromptService.addPromptToRepository(1L, "testRepo", "SOURCE_STRING_CHECKER"); verify(aiPromptTypeRepository, times(1)).findByName("SOURCE_STRING_CHECKER"); - verify(repositoryAIPromptRepository, times(1)).save(repositoryAIPromptCaptor.capture()); - assertEquals(1L, repositoryAIPromptCaptor.getValue().getAiPromptId()); - assertEquals(2L, repositoryAIPromptCaptor.getValue().getRepositoryId()); - assertEquals(3L, repositoryAIPromptCaptor.getValue().getPromptTypeId()); + verify(repositoryLocaleAIPromptRepository, times(1)).save(repositoryAIPromptCaptor.capture()); + assertEquals(1L, repositoryAIPromptCaptor.getValue().getAiPrompt().getId()); + assertEquals(2L, repositoryAIPromptCaptor.getValue().getRepository().getId()); } } diff --git a/webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMServiceTest.java b/webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMServiceTest.java index 8b97e647ba..529e9a99d3 100644 --- a/webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMServiceTest.java +++ b/webapp/src/test/java/com/box/l10n/mojito/service/ai/openai/OpenAILLMServiceTest.java @@ -5,14 +5,19 @@ import com.box.l10n.mojito.entity.AIPrompt; import com.box.l10n.mojito.entity.AIPromptContextMessage; +import 
com.box.l10n.mojito.entity.AIPromptType; +import com.box.l10n.mojito.entity.PluralForm; import com.box.l10n.mojito.entity.PromptType; import com.box.l10n.mojito.entity.Repository; +import com.box.l10n.mojito.entity.TMTextUnit; import com.box.l10n.mojito.json.ObjectMapper; import com.box.l10n.mojito.okapi.extractor.AssetExtractorTextUnit; import com.box.l10n.mojito.openai.OpenAIClient; import com.box.l10n.mojito.rest.ai.AICheckRequest; import com.box.l10n.mojito.rest.ai.AICheckResponse; +import com.box.l10n.mojito.rest.ai.AIException; import com.box.l10n.mojito.service.ai.AIStringCheckRepository; +import com.box.l10n.mojito.service.ai.LLMPromptService; import com.box.l10n.mojito.service.repository.RepositoryRepository; import io.micrometer.core.instrument.MeterRegistry; import java.util.ArrayList; @@ -49,8 +54,14 @@ class OpenAILLMServiceTest { void setUp() { MockitoAnnotations.openMocks(this); openAILLMService.persistResults = true; + openAILLMService.retryMaxAttempts = 1; + openAILLMService.retryMinDurationSeconds = 0; + openAILLMService.retryMaxBackoffDurationSeconds = 0; + openAILLMService.init(); when(meterRegistry.counter(anyString(), any(String[].class))) .thenReturn(mock(io.micrometer.core.instrument.Counter.class)); + when(meterRegistry.counter(anyString(), any(Iterable.class))) + .thenReturn(mock(io.micrometer.core.instrument.Counter.class)); } @Test @@ -299,6 +310,9 @@ public void testPromptContextMessagesIncluded() { prompt.setUserPrompt("Check strings for spelling"); prompt.setModelName("gtp-3.5-turbo"); prompt.setPromptTemperature(0.0F); + AIPromptType promptType = new AIPromptType(); + promptType.setName("SOURCE_STRING_CHECKER"); + prompt.setPromptType(promptType); AIPromptContextMessage testSystemContextMessage = new AIPromptContextMessage(); testSystemContextMessage.setContent("A test system context message"); @@ -415,4 +429,405 @@ public void testSourceOnlyCheckedOnce() { verify(meterRegistry, times(1)) .counter("OpenAILLMService.checks.result", 
"success", "true", "repository", "testRepo"); } + + @Test + void testTranslateSuccess() { + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("Greeting"); + + AIPrompt prompt = new AIPrompt(); + prompt.setId(1L); + prompt.setSystemPrompt("Translate the following text:"); + prompt.setUserPrompt("Translate this text to French:"); + prompt.setModelName("gtp-3.5-turbo"); + prompt.setPromptTemperature(0.0F); + prompt.setContextMessages(new ArrayList<>()); + + OpenAIClient.ChatCompletionsResponse.Choice choice = + new OpenAIClient.ChatCompletionsResponse.Choice( + 0, new OpenAIClient.ChatCompletionsResponse.Choice.Message("test", "Bonjour"), "stop"); + OpenAIClient.ChatCompletionsResponse chatCompletionsResponse = + new OpenAIClient.ChatCompletionsResponse( + null, null, null, null, List.of(choice), null, null); + CompletableFuture futureResponse = + CompletableFuture.completedFuture(chatCompletionsResponse); + when(openAIClient.getChatCompletions(any(OpenAIClient.ChatCompletionsRequest.class))) + .thenReturn(futureResponse); + + String translation = openAILLMService.translate(tmTextUnit, "en", "fr", prompt); + assertEquals("Bonjour", translation); + } + + @Test + void testTranslateResponseNonStopFinishReason() { + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("Greeting"); + + AIPrompt prompt = new AIPrompt(); + prompt.setId(1L); + prompt.setSystemPrompt("Translate the following text:"); + prompt.setUserPrompt("Translate this text to French:"); + prompt.setModelName("gtp-3.5-turbo"); + prompt.setPromptTemperature(0.0F); + prompt.setContextMessages(new ArrayList<>()); + + OpenAIClient.ChatCompletionsResponse.Choice choice = + new OpenAIClient.ChatCompletionsResponse.Choice( + 0, + new OpenAIClient.ChatCompletionsResponse.Choice.Message("test", "Bonjour"), + "length"); + OpenAIClient.ChatCompletionsResponse 
chatCompletionsResponse = + new OpenAIClient.ChatCompletionsResponse( + null, null, null, null, List.of(choice), null, null); + CompletableFuture futureResponse = + CompletableFuture.completedFuture(chatCompletionsResponse); + when(openAIClient.getChatCompletions(any(OpenAIClient.ChatCompletionsRequest.class))) + .thenReturn(futureResponse); + + assertThrows( + AIException.class, () -> openAILLMService.translate(tmTextUnit, "en", "fr", prompt)); + } + + @Test + void testTranslateStripTranslationFromJsonKey() { + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("Greeting"); + + AIPrompt prompt = new AIPrompt(); + prompt.setId(1L); + prompt.setSystemPrompt("Translate the following text:"); + prompt.setUserPrompt("Translate this text to French:"); + prompt.setModelName("gtp-3.5-turbo"); + prompt.setPromptTemperature(0.0F); + prompt.setContextMessages(new ArrayList<>()); + prompt.setJsonResponseKey("translation"); + prompt.setJsonResponse(true); + + OpenAIClient.ChatCompletionsResponse.Choice choice = + new OpenAIClient.ChatCompletionsResponse.Choice( + 0, + new OpenAIClient.ChatCompletionsResponse.Choice.Message( + "test", "{\"translation\": \"Bonjour\"}"), + "stop"); + OpenAIClient.ChatCompletionsResponse chatCompletionsResponse = + new OpenAIClient.ChatCompletionsResponse( + null, null, null, null, List.of(choice), null, null); + CompletableFuture futureResponse = + CompletableFuture.completedFuture(chatCompletionsResponse); + when(openAIClient.getChatCompletions(any(OpenAIClient.ChatCompletionsRequest.class))) + .thenReturn(futureResponse); + + String translation = openAILLMService.translate(tmTextUnit, "en", "fr", prompt); + assertEquals("Bonjour", translation); + } + + @Test + void testTranslateStripTranslationFromInvalidJson() { + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("Greeting"); + + AIPrompt prompt = 
new AIPrompt(); + prompt.setId(1L); + prompt.setSystemPrompt("Translate the following text:"); + prompt.setUserPrompt("Translate this text to French:"); + prompt.setModelName("gtp-3.5-turbo"); + prompt.setPromptTemperature(0.0F); + prompt.setContextMessages(new ArrayList<>()); + prompt.setJsonResponseKey("translation"); + prompt.setJsonResponse(true); + + OpenAIClient.ChatCompletionsResponse.Choice choice = + new OpenAIClient.ChatCompletionsResponse.Choice( + 0, + new OpenAIClient.ChatCompletionsResponse.Choice.Message( + "test", "invalid: {\"translation\": \"Bonjour\"}"), + null); + OpenAIClient.ChatCompletionsResponse chatCompletionsResponse = + new OpenAIClient.ChatCompletionsResponse( + null, null, null, null, List.of(choice), null, null); + CompletableFuture futureResponse = + CompletableFuture.completedFuture(chatCompletionsResponse); + when(openAIClient.getChatCompletions(any(OpenAIClient.ChatCompletionsRequest.class))) + .thenReturn(futureResponse); + + assertThrows(Exception.class, () -> openAILLMService.translate(tmTextUnit, "en", "fr", prompt)); + } + + @Test + void testTranslateError() { + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("Greeting"); + + AIPrompt prompt = new AIPrompt(); + prompt.setId(1L); + prompt.setSystemPrompt("Translate the following text:"); + prompt.setUserPrompt("Translate this text to French:"); + prompt.setModelName("gtp-3.5-turbo"); + prompt.setPromptTemperature(0.0F); + prompt.setContextMessages(new ArrayList<>()); + + when(openAIClient.getChatCompletions(any(OpenAIClient.ChatCompletionsRequest.class))) + .thenThrow(new RuntimeException("OpenAI service error")); + + assertThrows(Exception.class, () -> openAILLMService.translate(tmTextUnit, "en", "fr", prompt)); + } + + @Test + void testPromptTemplatingAllValuesInjected() { + String promptText = + """ + Translate the following source string from [mojito_source_locale] to [mojito_target_locale]: + 
Source string: [mojito_source_string] + {{optional: Comment: [mojito_comment_string]}} + {{optional: Context: [mojito_context_string]}} + {{optional: Plural form: [mojito_plural_form]}} + """; + + PluralForm one = new PluralForm(); + one.setName("one"); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("A friendly greeting"); + tmTextUnit.setName("Hello --- some.id"); + tmTextUnit.setPluralForm(one); + String prompt = + openAILLMService.getTranslationFormattedPrompt(promptText, tmTextUnit, "en", "fr"); + assertEquals( + """ + Translate the following source string from en to fr: + Source string: Hello + Comment: A friendly greeting + Context: some.id + Plural form: one""", + prompt); + } + + @Test + void testPromptTemplatingNoContextValue() { + String promptText = + """ + Translate the following source string from [mojito_source_locale] to [mojito_target_locale]: + Source string: [mojito_source_string] + {{optional: Comment: [mojito_comment_string]}} + {{optional: Context: [mojito_context_string]}} + {{optional: Plural form: [mojito_plural_form]}} + """; + + PluralForm one = new PluralForm(); + one.setName("one"); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("A friendly greeting"); + tmTextUnit.setName("Hello"); + tmTextUnit.setPluralForm(one); + String prompt = + openAILLMService.getTranslationFormattedPrompt(promptText, tmTextUnit, "en", "fr"); + assertEquals( + """ + Translate the following source string from en to fr: + Source string: Hello + Comment: A friendly greeting + Plural form: one""", + prompt); + } + + @Test + void testPromptTemplatingNoPluralValue() { + String promptText = + """ + Translate the following source string from [mojito_source_locale] to [mojito_target_locale]: + Source string: [mojito_source_string] + {{optional: Comment: [mojito_comment_string]}} + {{optional: Context: 
[mojito_context_string]}} + {{optional: Plural form: [mojito_plural_form]}} + """; + + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("A friendly greeting"); + tmTextUnit.setName("Hello --- some.id"); + String prompt = + openAILLMService.getTranslationFormattedPrompt(promptText, tmTextUnit, "en", "fr"); + assertEquals( + """ + Translate the following source string from en to fr: + Source string: Hello + Comment: A friendly greeting + Context: some.id""", + prompt); + } + + @Test + void testPromptTemplatingNoCommentValue() { + String promptText = + """ + Translate the following source string from [mojito_source_locale] to [mojito_target_locale]: + Source string: [mojito_source_string] + {{optional: Comment: [mojito_comment_string]}} + {{optional: Context: [mojito_context_string]}} + {{optional: Plural form: [mojito_plural_form]}} + """; + + PluralForm one = new PluralForm(); + one.setName("one"); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setName("Hello --- some.id"); + tmTextUnit.setPluralForm(one); + String prompt = + openAILLMService.getTranslationFormattedPrompt(promptText, tmTextUnit, "en", "fr"); + assertEquals( + """ + Translate the following source string from en to fr: + Source string: Hello + Context: some.id + Plural form: one""", + prompt); + } + + @Test + void testPromptTemplatingInlineNoCommentValue() { + String promptText = + """ + Translate the following source string from [mojito_source_locale] to [mojito_target_locale]: + Source string: [mojito_source_string] + {{optional: Comment: [mojito_comment_string]}} {{optional: Context: [mojito_context_string]}} {{optional: Plural form: [mojito_plural_form]}} + """; + + PluralForm one = new PluralForm(); + one.setName("one"); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setName("Hello --- 
some.id"); + tmTextUnit.setPluralForm(one); + String prompt = + openAILLMService.getTranslationFormattedPrompt(promptText, tmTextUnit, "en", "fr"); + assertEquals( + """ + Translate the following source string from en to fr: + Source string: Hello + Context: some.id Plural form: one""", + prompt); + } + + @Test + void testPromptTemplatingInlineNoContextValue() { + String promptText = + """ + Translate the following source string from [mojito_source_locale] to [mojito_target_locale]: + Source string: [mojito_source_string] + {{optional: Comment: [mojito_comment_string]}} {{optional: Context: [mojito_context_string]}} {{optional: Plural form: [mojito_plural_form]}} + """; + + PluralForm one = new PluralForm(); + one.setName("one"); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("A friendly greeting"); + tmTextUnit.setPluralForm(one); + String prompt = + openAILLMService.getTranslationFormattedPrompt(promptText, tmTextUnit, "en", "fr"); + assertEquals( + """ + Translate the following source string from en to fr: + Source string: Hello + Comment: A friendly greeting Plural form: one""", + prompt); + } + + @Test + void testPromptTemplatingJsonInPrompt() { + String promptText = + """ + Translate the following source string from [mojito_source_locale] to [mojito_target_locale]: + Source string: [mojito_source_string] + { {{optional: "comment": "[mojito_comment_string]",}} {{optional: "context": "[mojito_context_string]",}} {{optional: "plural_form": "[mojito_plural_form]"}} } + """; + + PluralForm one = new PluralForm(); + one.setName("one"); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("A friendly greeting"); + tmTextUnit.setName("Hello --- some.id"); + tmTextUnit.setPluralForm(one); + String prompt = + openAILLMService.getTranslationFormattedPrompt(promptText, tmTextUnit, "en", "fr"); + assertEquals( + """ + 
Translate the following source string from en to fr: + Source string: Hello + { "comment": "A friendly greeting", "context": "some.id", "plural_form": "one" }""", + prompt); + } + + @Test + void testPromptTemplatingJsonInPromptContextMissing() { + String promptText = + """ + Translate the following source string from [mojito_source_locale] to [mojito_target_locale]: + Source string: [mojito_source_string] + {{{optional: "comment": "[mojito_comment_string]",}} {{optional: "context": "[mojito_context_string]",}} {{optional: "plural_form": "[mojito_plural_form]"}}} + """; + + PluralForm one = new PluralForm(); + one.setName("one"); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("A friendly greeting"); + tmTextUnit.setPluralForm(one); + String prompt = + openAILLMService.getTranslationFormattedPrompt(promptText, tmTextUnit, "en", "fr"); + assertEquals( + """ + Translate the following source string from en to fr: + Source string: Hello + {"comment": "A friendly greeting", "plural_form": "one"}""", + prompt); + } + + @Test + void testPromptTemplatingInlineSentence() { + String promptText = + """ + Translate the following source string from [mojito_source_locale] to [mojito_target_locale]: + Source string: [mojito_source_string] + {{optional: The comment is: [mojito_comment_string]. }}{{optional: The context is: [mojito_context_string]. }}{{optional: The plural form is: [mojito_plural_form]. 
}} + """; + + PluralForm one = new PluralForm(); + one.setName("one"); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("Hello"); + tmTextUnit.setComment("A friendly greeting"); + tmTextUnit.setName("Hello"); + tmTextUnit.setPluralForm(one); + String prompt = + openAILLMService.getTranslationFormattedPrompt(promptText, tmTextUnit, "en", "fr"); + assertEquals( + """ + Translate the following source string from en to fr: + Source string: Hello + The comment is: A friendly greeting. The plural form is: one.""", + prompt); + } } diff --git a/webapp/src/test/java/com/box/l10n/mojito/service/ai/translation/AITranslateCronJobTest.java b/webapp/src/test/java/com/box/l10n/mojito/service/ai/translation/AITranslateCronJobTest.java new file mode 100644 index 0000000000..7f40786d40 --- /dev/null +++ b/webapp/src/test/java/com/box/l10n/mojito/service/ai/translation/AITranslateCronJobTest.java @@ -0,0 +1,442 @@ +package com.box.l10n.mojito.service.ai.translation; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.box.l10n.mojito.JSR310Migration; +import com.box.l10n.mojito.entity.AIPrompt; +import com.box.l10n.mojito.entity.Asset; +import com.box.l10n.mojito.entity.Locale; +import com.box.l10n.mojito.entity.PromptType; +import com.box.l10n.mojito.entity.Repository; +import com.box.l10n.mojito.entity.RepositoryLocale; +import com.box.l10n.mojito.entity.RepositoryLocaleAIPrompt; +import 
com.box.l10n.mojito.entity.TMTextUnit; +import com.box.l10n.mojito.entity.TMTextUnitCurrentVariant; +import com.box.l10n.mojito.entity.TMTextUnitVariant; +import com.box.l10n.mojito.entity.TmTextUnitPendingMT; +import com.box.l10n.mojito.service.ai.LLMService; +import com.box.l10n.mojito.service.ai.RepositoryLocaleAIPromptRepository; +import com.box.l10n.mojito.service.repository.RepositoryRepository; +import com.box.l10n.mojito.service.tm.TMService; +import com.box.l10n.mojito.service.tm.TMTextUnitRepository; +import com.box.l10n.mojito.service.tm.TMTextUnitVariantRepository; +import com.google.common.collect.Sets; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tags; +import java.time.Duration; +import java.time.ZonedDateTime; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.assertj.core.util.Lists; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.quartz.JobExecutionContext; +import org.quartz.JobExecutionException; + +public class AITranslateCronJobTest { + + @Mock TMService tmService; + + @Mock MeterRegistry meterRegistry; + + @Mock LLMService llmService; + + @Mock TMTextUnitRepository tmTextUnitRepository; + + @Mock TmTextUnitPendingMTRepository tmTextUnitPendingMTRepository; + + @Mock RepositoryRepository repositoryRepository; + + @Mock RepositoryLocaleAIPromptRepository repositoryLocaleAIPromptRepository; + + @Mock AITranslationTextUnitFilterService aiTranslationTextUnitFilterService; + + @Mock AIPrompt aiPrompt; + + @Mock TMTextUnitVariantRepository tmTextUnitVariantRepository; + + @Mock Repository repository; + + @Mock AITranslationService aiTranslationService; + + AITranslateCronJob aiTranslateCronJob; + + TMTextUnit tmTextUnit; + + TmTextUnitPendingMT tmTextUnitPendingMT; + + Locale german; + + 
AITranslationConfiguration aITranslationConfiguration; + + ArgumentCaptor> aiTranslationCaptor = ArgumentCaptor.forClass(List.class); + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + aiTranslateCronJob = new AITranslateCronJob(); + aiTranslateCronJob.meterRegistry = meterRegistry; + aiTranslateCronJob.llmService = llmService; + aiTranslateCronJob.tmTextUnitRepository = tmTextUnitRepository; + aiTranslateCronJob.tmTextUnitPendingMTRepository = tmTextUnitPendingMTRepository; + aiTranslateCronJob.repositoryLocaleAIPromptRepository = repositoryLocaleAIPromptRepository; + aiTranslateCronJob.aiTranslationTextUnitFilterService = aiTranslationTextUnitFilterService; + aiTranslateCronJob.tmTextUnitVariantRepository = tmTextUnitVariantRepository; + aiTranslateCronJob.aiTranslationService = aiTranslationService; + aITranslationConfiguration = new AITranslationConfiguration(); + aITranslationConfiguration.setEnabled(true); + aITranslationConfiguration.setCron("0 0/10 * * * ?"); + aITranslationConfiguration.setBatchSize(5); + aITranslationConfiguration.setExpiryDuration(Duration.ofHours(3)); + aiTranslateCronJob.threads = 1; + AITranslationConfiguration.RepositorySettings repositorySettings = + new AITranslationConfiguration.RepositorySettings(); + repositorySettings.setReuseSourceOnLanguageMatch(false); + Map repositorySettingsMap = + Collections.singletonMap("testRepo", repositorySettings); + aITranslationConfiguration.setRepositorySettings(repositorySettingsMap); + aiTranslateCronJob.aiTranslationConfiguration = aITranslationConfiguration; + Repository testRepo = new Repository(); + testRepo.setId(1L); + testRepo.setName("testRepo"); + Locale english = new Locale(); + english.setBcp47Tag("en-GB"); + english.setId(1L); + RepositoryLocale englishRepoLocale = new RepositoryLocale(testRepo, english, true, null); + Locale french = new Locale(); + french.setBcp47Tag("fr-FR"); + french.setId(2L); + german = new Locale(); + german.setBcp47Tag("de-DE"); + 
german.setId(3L); + Locale hibernoEnglish = new Locale(); + hibernoEnglish.setBcp47Tag("en-IE"); + hibernoEnglish.setId(4L); + RepositoryLocale frenchRepoLocale = + new RepositoryLocale(testRepo, french, true, englishRepoLocale); + frenchRepoLocale.setId(2L); + RepositoryLocale germanRepoLocale = + new RepositoryLocale(testRepo, german, true, englishRepoLocale); + germanRepoLocale.setId(3L); + RepositoryLocale hibernoEnglishRepoLocale = + new RepositoryLocale(testRepo, hibernoEnglish, true, englishRepoLocale); + hibernoEnglishRepoLocale.setId(4L); + testRepo.setRepositoryLocales( + Sets.newHashSet(frenchRepoLocale, germanRepoLocale, hibernoEnglishRepoLocale)); + testRepo.setSourceLocale(english); + + when(repository.getSourceLocale()).thenReturn(english); + when(repositoryRepository.findById(1L)).thenReturn(Optional.of(testRepo)); + tmTextUnitPendingMT = new TmTextUnitPendingMT(); + tmTextUnitPendingMT.setTmTextUnitId(1L); + tmTextUnitPendingMT.setId(1L); + tmTextUnitPendingMT.setCreatedDate(JSR310Migration.dateTimeNow()); + when(tmTextUnitPendingMTRepository.findByTmTextUnitId(1L)).thenReturn(tmTextUnitPendingMT); + when(aiTranslationTextUnitFilterService.isTranslatable( + isA(TMTextUnit.class), isA(Repository.class))) + .thenReturn(true); + + RepositoryLocaleAIPrompt testPrompt1 = new RepositoryLocaleAIPrompt(); + testPrompt1.setId(1L); + testPrompt1.setRepository(testRepo); + testPrompt1.setLocale(french); + testPrompt1.setAiPrompt(aiPrompt); + + RepositoryLocaleAIPrompt testPrompt2 = new RepositoryLocaleAIPrompt(); + testPrompt2.setId(2L); + testPrompt2.setRepository(testRepo); + testPrompt2.setLocale(german); + testPrompt2.setAiPrompt(aiPrompt); + + RepositoryLocaleAIPrompt testPrompt3 = new RepositoryLocaleAIPrompt(); + testPrompt3.setId(3L); + testPrompt3.setRepository(testRepo); + testPrompt3.setLocale(null); + testPrompt3.setAiPrompt(aiPrompt); + + tmTextUnit = new TMTextUnit(); + tmTextUnit.setId(1L); + tmTextUnit.setContent("content"); + 
tmTextUnit.setComment("comment"); + tmTextUnit.setName("name"); + Asset asset = new Asset(); + asset.setRepository(repository); + tmTextUnit.setAsset(asset); + when(tmTextUnitRepository.findById(1L)).thenReturn(Optional.of(tmTextUnit)); + when(repositoryLocaleAIPromptRepository.getActivePromptsByRepositoryAndPromptType( + testRepo.getId(), PromptType.TRANSLATION.toString())) + .thenReturn(Lists.list(testPrompt1, testPrompt2, testPrompt3)); + when(llmService.translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class))) + .thenReturn("translated"); + TMTextUnitCurrentVariant tmTextUnitCurrentVariant = new TMTextUnitCurrentVariant(); + tmTextUnitCurrentVariant.setLocale(english); + when(repository.getId()).thenReturn(1L); + when(repository.getName()).thenReturn("testRepo"); + when(repository.getRepositoryLocales()) + .thenReturn( + Sets.newHashSet( + englishRepoLocale, frenchRepoLocale, germanRepoLocale, hibernoEnglishRepoLocale)); + when(repository.getSourceLocale()).thenReturn(english); + when(meterRegistry.timer(anyString(), isA((Iterable.class)))) + .thenReturn(mock(io.micrometer.core.instrument.Timer.class)); + when(tmTextUnitRepository.findByIdWithAssetAndRepositoryAndTMFetched(1L)) + .thenReturn(Optional.of(tmTextUnit)); + when(meterRegistry.counter(anyString())) + .thenReturn(mock(io.micrometer.core.instrument.Counter.class)); + } + + @Test + public void testTranslateSuccess() throws Exception { + aiTranslateCronJob.translate(repository, tmTextUnit, tmTextUnitPendingMT); + verify(llmService, times(3)) + .translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class)); + verify(aiTranslationService, times(1)) + .insertMultiRowAITranslationVariant(anyLong(), aiTranslationCaptor.capture()); + List aiTranslations = aiTranslationCaptor.getValue(); + assertEquals(3, aiTranslations.size()); + assertThat(aiTranslations).extracting("localeId").containsExactlyInAnyOrder(2L, 3L, 4L); + } + + @Test + public 
void testTranslateFailure() throws Exception { + aiTranslateCronJob.translate(repository, tmTextUnit, tmTextUnitPendingMT); + when(llmService.translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class))) + .thenThrow(new RuntimeException("test")); + verify(llmService, times(3)) + .translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class)); + } + + @Test + public void testTranslateReuseSource() throws Exception { + AITranslationConfiguration.RepositorySettings repositorySettings = + new AITranslationConfiguration.RepositorySettings(); + repositorySettings.setReuseSourceOnLanguageMatch(true); + aITranslationConfiguration.setRepositorySettings( + Collections.singletonMap("testRepo", repositorySettings)); + aiTranslateCronJob.translate(repository, tmTextUnit, tmTextUnitPendingMT); + verify(llmService, times(2)) + .translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class)); + verify(aiTranslationService, times(1)) + .insertMultiRowAITranslationVariant(anyLong(), aiTranslationCaptor.capture()); + List aiTranslations = aiTranslationCaptor.getValue(); + assertEquals(3, aiTranslations.size()); + assertThat(aiTranslations).extracting("localeId").containsExactlyInAnyOrder(2L, 3L, 4L); + assertThat(aiTranslations) + .extracting("translation") + .containsExactlyInAnyOrder("content", "translated", "translated"); + } + + @Test + public void testNoTranslationIfLeveragedVariantExistsForLocale() { + TMTextUnitCurrentVariant tmTextUnitCurrentVariant = new TMTextUnitCurrentVariant(); + tmTextUnitCurrentVariant.setLocale(german); + Set variants = Sets.newHashSet(tmTextUnitCurrentVariant.getLocale()); + when(tmTextUnitVariantRepository.findLocalesWithVariantByTmTextUnit_Id(1L)) + .thenReturn(variants); + aiTranslateCronJob.translate(repository, tmTextUnit, tmTextUnitPendingMT); + verify(llmService, times(2)) + .translate( + isA(TMTextUnit.class), isA(String.class), 
isA(String.class), isA(AIPrompt.class)); + verify(aiTranslationService, times(1)) + .insertMultiRowAITranslationVariant(anyLong(), aiTranslationCaptor.capture()); + List aiTranslations = aiTranslationCaptor.getValue(); + assertEquals(2, aiTranslations.size()); + assertThat(aiTranslations).extracting("localeId").containsExactlyInAnyOrder(2L, 4L); + } + + @Test + public void testPendingMTEntityIsExpired() throws Exception { + TmTextUnitPendingMT expiredPendingMT = new TmTextUnitPendingMT(); + expiredPendingMT.setCreatedDate(ZonedDateTime.now().minusHours(4)); + when(tmTextUnitPendingMTRepository.findByTmTextUnitId(1L)).thenReturn(expiredPendingMT); + aiTranslateCronJob.translate(repository, tmTextUnit, expiredPendingMT); + verify(llmService, never()) + .translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class)); + verify(tmService, never()) + .addTMTextUnitVariant( + eq(1L), + eq(2L), + eq("translated"), + eq("comment"), + eq(TMTextUnitVariant.Status.MT_TRANSLATED), + eq(false), + isA(ZonedDateTime.class)); + verify(tmService, never()) + .addTMTextUnitVariant( + eq(1L), + eq(3L), + eq("translated"), + eq("comment"), + eq(TMTextUnitVariant.Status.MT_TRANSLATED), + eq(false), + isA(ZonedDateTime.class)); + verify(tmService, never()) + .addTMTextUnitVariant( + eq(1L), + eq(4L), + eq("content"), + eq("comment"), + eq(TMTextUnitVariant.Status.MT_TRANSLATED), + eq(false), + isA(ZonedDateTime.class)); + verify(meterRegistry, times(1)).counter(eq("AITranslateCronJob.expired"), any(Tags.class)); + } + + @Test + public void testFilterMatchNoTranslation() throws Exception { + when(aiTranslationTextUnitFilterService.isTranslatable( + isA(TMTextUnit.class), isA(Repository.class))) + .thenReturn(false); + aiTranslateCronJob.translate(repository, tmTextUnit, tmTextUnitPendingMT); + verify(llmService, never()) + .translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class)); + verify(tmService, never()) + 
.addTMTextUnitVariant( + eq(1L), + eq(2L), + eq("translated"), + eq("comment"), + eq(TMTextUnitVariant.Status.MT_TRANSLATED), + eq(false), + isA(ZonedDateTime.class)); + verify(tmService, never()) + .addTMTextUnitVariant( + eq(1L), + eq(3L), + eq("translated"), + eq("comment"), + eq(TMTextUnitVariant.Status.MT_TRANSLATED), + eq(false), + isA(ZonedDateTime.class)); + verify(tmService, never()) + .addTMTextUnitVariant( + eq(1L), + eq(4L), + eq("content"), + eq("comment"), + eq(TMTextUnitVariant.Status.MT_TRANSLATED), + eq(false), + isA(ZonedDateTime.class)); + } + + @Test + public void testBatchLogic() throws JobExecutionException { + List pendingMTList = + Lists.list( + tmTextUnitPendingMT, + tmTextUnitPendingMT, + tmTextUnitPendingMT, + tmTextUnitPendingMT, + tmTextUnitPendingMT); + when(tmTextUnitPendingMTRepository.findBatch(5)) + .thenReturn(pendingMTList) + .thenReturn(pendingMTList) + .thenReturn(Collections.emptyList()); + aiTranslateCronJob.execute(mock(JobExecutionContext.class)); + verify(llmService, times(30)) + .translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class)); + verify(aiTranslationService, times(10)) + .insertMultiRowAITranslationVariant(anyLong(), aiTranslationCaptor.capture()); + List aiTranslations = aiTranslationCaptor.getValue(); + for (int i = 0; i < aiTranslationCaptor.getAllValues().size(); i++) { + aiTranslations = aiTranslationCaptor.getAllValues().get(i); + assertEquals(3, aiTranslations.size()); + assertThat(aiTranslations).extracting("localeId").containsExactlyInAnyOrder(2L, 3L, 4L); + } + verify(aiTranslationService, times(10)).sendForDeletion(isA(TmTextUnitPendingMT.class)); + } + + @Test + public void testBatchRequestFailLogic() throws JobExecutionException { + // Test verifies that if a single locale fails to translate, the rest of the batch is still + // processed + List pendingMTList = + Lists.list( + tmTextUnitPendingMT, + tmTextUnitPendingMT, + tmTextUnitPendingMT, + 
tmTextUnitPendingMT, + tmTextUnitPendingMT); + when(tmTextUnitPendingMTRepository.findBatch(5)) + .thenReturn(pendingMTList) + .thenReturn(pendingMTList) + .thenReturn(Collections.emptyList()); + when(llmService.translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class))) + .thenThrow(new RuntimeException("test")) + .thenReturn("translated"); + aiTranslateCronJob.execute(mock(JobExecutionContext.class)); + verify(llmService, times(30)) + .translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class)); + verify(aiTranslationService, times(10)) + .insertMultiRowAITranslationVariant(anyLong(), aiTranslationCaptor.capture()); + List aiTranslations = aiTranslationCaptor.getAllValues().getFirst(); + assertEquals(2, aiTranslations.size()); + for (int i = 1; i < aiTranslationCaptor.getAllValues().size(); i++) { + aiTranslations = aiTranslationCaptor.getAllValues().get(i); + assertEquals(3, aiTranslations.size()); + assertThat(aiTranslations).extracting("localeId").containsExactlyInAnyOrder(2L, 3L, 4L); + } + verify(aiTranslationService, times(10)).sendForDeletion(isA(TmTextUnitPendingMT.class)); + } + + @Test + public void testBatchLogicFailureToRetrieveTextUnit() throws JobExecutionException { + // Test verifies that if an exception is thrown for a single text unit, the rest of the batch is + // still processed + TmTextUnitPendingMT tmTextUnitPendingMT2 = new TmTextUnitPendingMT(); + tmTextUnitPendingMT2.setTmTextUnitId(2L); + List pendingMTList = + Lists.list( + tmTextUnitPendingMT2, + tmTextUnitPendingMT, + tmTextUnitPendingMT, + tmTextUnitPendingMT, + tmTextUnitPendingMT); + when(tmTextUnitPendingMTRepository.findBatch(5)) + .thenReturn(pendingMTList) + .thenReturn(Collections.emptyList()); + when(tmTextUnitRepository.findByIdWithAssetAndRepositoryAndTMFetched(2L)) + .thenReturn(Optional.empty()); + aiTranslateCronJob.execute(mock(JobExecutionContext.class)); + verify(llmService, times(12)) + 
.translate( + isA(TMTextUnit.class), isA(String.class), isA(String.class), isA(AIPrompt.class)); + verify(aiTranslationService, times(4)) + .insertMultiRowAITranslationVariant(anyLong(), aiTranslationCaptor.capture()); + List aiTranslations = aiTranslationCaptor.getAllValues().getFirst(); + assertEquals(3, aiTranslations.size()); + for (int i = 0; i < aiTranslationCaptor.getAllValues().size(); i++) { + aiTranslations = aiTranslationCaptor.getAllValues().get(i); + assertEquals(3, aiTranslations.size()); + assertThat(aiTranslations).extracting("localeId").containsExactlyInAnyOrder(2L, 3L, 4L); + } + verify(aiTranslationService, times(5)).sendForDeletion(isA(TmTextUnitPendingMT.class)); + } +} diff --git a/webapp/src/test/java/com/box/l10n/mojito/service/ai/translation/AITranslationServiceTest.java b/webapp/src/test/java/com/box/l10n/mojito/service/ai/translation/AITranslationServiceTest.java new file mode 100644 index 0000000000..7dc35ce2eb --- /dev/null +++ b/webapp/src/test/java/com/box/l10n/mojito/service/ai/translation/AITranslationServiceTest.java @@ -0,0 +1,144 @@ +package com.box.l10n.mojito.service.ai.translation; + +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.box.l10n.mojito.JSR310Migration; +import com.box.l10n.mojito.entity.Asset; +import com.box.l10n.mojito.entity.PromptType; +import com.box.l10n.mojito.entity.TM; +import com.box.l10n.mojito.entity.TMTextUnit; +import com.box.l10n.mojito.entity.TMTextUnitVariant; +import com.box.l10n.mojito.service.ai.RepositoryLocaleAIPromptRepository; +import com.box.l10n.mojito.service.repository.RepositoryRepository; +import com.box.l10n.mojito.service.tm.TMTextUnitVariantRepository; +import java.time.Duration; +import java.util.List; +import java.util.Set; +import org.junit.Before; +import org.junit.Test; +import 
org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.springframework.jdbc.core.JdbcTemplate; + +public class AITranslationServiceTest { + + @Mock JdbcTemplate jdbcTemplate; + + @Mock TmTextUnitPendingMTRepository tmTextUnitPendingMTRepository; + + @Mock TMTextUnitVariantRepository tmTextUnitVariantRepository; + + @Mock RepositoryRepository repositoryRepository; + + @Mock RepositoryLocaleAIPromptRepository repositoryLocaleAIPromptRepository; + + AITranslationService aiTranslationService; + + ArgumentCaptor stringArgumentCaptor = ArgumentCaptor.forClass(String.class); + + @Before + public void before() { + MockitoAnnotations.openMocks(this); + aiTranslationService = new AITranslationService(); + aiTranslationService.jdbcTemplate = jdbcTemplate; + aiTranslationService.tmTextUnitPendingMTRepository = tmTextUnitPendingMTRepository; + aiTranslationService.repositoryRepository = repositoryRepository; + aiTranslationService.tmTextUnitVariantRepository = tmTextUnitVariantRepository; + aiTranslationService.repositoryLocaleAIPromptRepository = repositoryLocaleAIPromptRepository; + aiTranslationService.batchSize = 5; + aiTranslationService.timeout = Duration.ofSeconds(10); + aiTranslationService.maxTextUnitsAIRequest = 10; + + when(repositoryLocaleAIPromptRepository.findCountOfActiveRepositoryPromptsByType( + 1L, PromptType.TRANSLATION.toString())) + .thenReturn(1L); + } + + @Test + public void testSuccessfulCreatePendingMTEntitiesInBatches() { + aiTranslationService.createPendingMTEntitiesInBatches( + 1L, Set.of(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L)); + verify(jdbcTemplate, times(2)).update(stringArgumentCaptor.capture()); + } + + @Test + public void testCreatePendingMTEntitiesInBatchesSkipsWhenTooManyTextUnits() { + Set tmTextUnitIds = Set.of(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L); + aiTranslationService.maxTextUnitsAIRequest = 10; + + aiTranslationService.createPendingMTEntitiesInBatches(1L, tmTextUnitIds); + + 
verify(jdbcTemplate, times(0)).update(anyString()); + } + + @Test + public void testCreatePendingMTEntitiesInBatchesCreatesWhenActivePromptsExist() { + Set tmTextUnitIds = Set.of(1L, 2L, 3L); + when(repositoryLocaleAIPromptRepository.findCountOfActiveRepositoryPromptsByType( + 1L, PromptType.TRANSLATION.toString())) + .thenReturn(1L); + + aiTranslationService.createPendingMTEntitiesInBatches(1L, tmTextUnitIds); + + verify(jdbcTemplate, times(1)).update(anyString()); + } + + @Test + public void testCreatePendingMTEntitiesInBatchesSkipsWhenNoActivePrompts() { + Set tmTextUnitIds = Set.of(1L, 2L, 3L); + when(repositoryLocaleAIPromptRepository.findCountOfActiveRepositoryPromptsByType( + 1L, PromptType.TRANSLATION.toString())) + .thenReturn(0L); + + aiTranslationService.createPendingMTEntitiesInBatches(1L, tmTextUnitIds); + + verify(jdbcTemplate, times(0)).update(anyString()); + } + + @Test + public void testMultiRowAITranslationVariantInserts() { + TM tm = new TM(); + tm.setId(1L); + Asset asset = new Asset(); + asset.setId(1L); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setAsset(asset); + tmTextUnit.setTm(tm); + tmTextUnit.setId(1L); + tmTextUnit.setContent("content"); + tmTextUnit.setContentMd5("md5"); + AITranslation aiTranslation = new AITranslation(); + aiTranslation.setTmTextUnit(tmTextUnit); + aiTranslation.setLocaleId(1L); + aiTranslation.setTranslation("translation"); + aiTranslation.setStatus(TMTextUnitVariant.Status.MT_TRANSLATED); + aiTranslation.setCreatedDate(JSR310Migration.newDateTimeEmptyCtor()); + aiTranslation.setIncludedInLocalizedFile(false); + + AITranslation aiTranslation2 = new AITranslation(); + aiTranslation2.setTmTextUnit(tmTextUnit); + aiTranslation2.setLocaleId(2L); + aiTranslation2.setTranslation("translation2"); + aiTranslation2.setStatus(TMTextUnitVariant.Status.MT_TRANSLATED); + aiTranslation2.setCreatedDate(JSR310Migration.newDateTimeEmptyCtor()); + aiTranslation2.setIncludedInLocalizedFile(false); + + AITranslation 
aiTranslation3 = new AITranslation(); + aiTranslation3.setTmTextUnit(tmTextUnit); + aiTranslation3.setLocaleId(3L); + aiTranslation3.setTranslation("translation3"); + aiTranslation3.setStatus(TMTextUnitVariant.Status.MT_TRANSLATED); + aiTranslation3.setCreatedDate(JSR310Migration.newDateTimeEmptyCtor()); + aiTranslation3.setIncludedInLocalizedFile(false); + List translations = List.of(aiTranslation, aiTranslation2, aiTranslation3); + + aiTranslationService.insertMultiRowAITranslationVariant(1L, translations); + + verify(jdbcTemplate, times(2)).batchUpdate(anyString(), anyList()); + } +} diff --git a/webapp/src/test/java/com/box/l10n/mojito/service/ai/translation/AITranslationTextUnitFilterServiceTest.java b/webapp/src/test/java/com/box/l10n/mojito/service/ai/translation/AITranslationTextUnitFilterServiceTest.java new file mode 100644 index 0000000000..412af82663 --- /dev/null +++ b/webapp/src/test/java/com/box/l10n/mojito/service/ai/translation/AITranslationTextUnitFilterServiceTest.java @@ -0,0 +1,147 @@ +package com.box.l10n.mojito.service.ai.translation; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.when; + +import com.box.l10n.mojito.entity.Asset; +import com.box.l10n.mojito.entity.PluralForm; +import com.box.l10n.mojito.entity.Repository; +import com.box.l10n.mojito.entity.TMTextUnit; +import java.util.Map; +import java.util.regex.Pattern; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +public class AITranslationTextUnitFilterServiceTest { + + AITranslationTextUnitFilterService textUnitFilterService; + + Asset testAsset; + + Repository repository; + + @Before + public void setUp() { + testAsset = Mockito.mock(Asset.class); + repository = Mockito.mock(Repository.class); + when(testAsset.getRepository()).thenReturn(repository); + when(repository.getName()).thenReturn("test"); + textUnitFilterService = new AITranslationTextUnitFilterService(); + 
AITranslationFilterConfiguration translationFilterConfiguration = + new AITranslationFilterConfiguration(); + setTestParameters(true, true, true); + textUnitFilterService.excludePlaceholdersPatternMap = + Map.of("test", Pattern.compile("\\{[^\\}]*\\}")); + } + + @Test + public void testIsTranslatable() { + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setName("test"); + tmTextUnit.setContent("test content"); + tmTextUnit.setAsset(testAsset); + assertTrue(textUnitFilterService.isTranslatable(tmTextUnit, repository)); + } + + @Test + public void testIsTranslatableWithPlural() { + setTestParameters(true, false, false); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setName("test"); + tmTextUnit.setContent("test content"); + tmTextUnit.setAsset(testAsset); + + PluralForm otherPluralForm = new PluralForm(); + otherPluralForm.setName("other"); + tmTextUnit.setPluralForm(otherPluralForm); + assertFalse(textUnitFilterService.isTranslatable(tmTextUnit, repository)); + } + + @Test + public void testIsTranslatableWithPlaceholder() { + setTestParameters(false, true, false); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setName("test"); + tmTextUnit.setContent("test {content}"); + tmTextUnit.setAsset(testAsset); + assertFalse(textUnitFilterService.isTranslatable(tmTextUnit, repository)); + } + + @Test + public void testIsTranslatableWithHtmlTags() { + setTestParameters(false, false, true); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setName("test"); + tmTextUnit.setContent("test content"); + tmTextUnit.setAsset(testAsset); + assertFalse(textUnitFilterService.isTranslatable(tmTextUnit, repository)); + } + + @Test + public void testIsTranslatableWithHtmlTagsAndPlaceholders() { + setTestParameters(false, true, true); + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setName("test"); + tmTextUnit.setContent("test {content}"); + tmTextUnit.setAsset(testAsset); + assertFalse(textUnitFilterService.isTranslatable(tmTextUnit, 
repository)); + } + + @Test + public void isTranslatableShouldReturnTrueWhenAllExclusionsAreFalse() { + setTestParameters(false, false, false); + + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setContent("Text with html and {placeholder}, including plurals."); + tmTextUnit.setAsset(testAsset); + assertTrue(textUnitFilterService.isTranslatable(tmTextUnit, repository)); + } + + @Test + public void isTranslatableShouldReturnFalseForTextWithMultipleExclusions() { + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setContent("This text has html, {placeholder}, and could be a plural form."); + tmTextUnit.setAsset(testAsset); + tmTextUnit.setPluralForm(new PluralForm()); + assertFalse(textUnitFilterService.isTranslatable(tmTextUnit, repository)); + } + + @Test + public void isTranslatableShouldReturnTrueWhenNoExclusionEnabled() { + setTestParameters(false, false, false); + + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setPluralForm(new PluralForm()); + tmTextUnit.setContent("Text with html and {placeholder}"); + tmTextUnit.setAsset(testAsset); + assertTrue(textUnitFilterService.isTranslatable(tmTextUnit, repository)); + } + + @Test + public void isTranslatableTrueWhenNoRepositoryConfig() { + TMTextUnit tmTextUnit = new TMTextUnit(); + tmTextUnit.setContent("Text with html and {placeholder}"); + tmTextUnit.setAsset(testAsset); + textUnitFilterService.aiTranslationFilterConfiguration = new AITranslationFilterConfiguration(); + assertTrue(textUnitFilterService.isTranslatable(tmTextUnit, repository)); + } + + private void setTestParameters( + boolean excludePlurals, boolean excludePlaceholders, boolean excludeHtmlTags) { + AITranslationFilterConfiguration translationFilterConfiguration = + new AITranslationFilterConfiguration(); + AITranslationFilterConfiguration.RepositoryConfig repositoryConfig = + new AITranslationFilterConfiguration.RepositoryConfig(); + repositoryConfig.setExcludePlurals(excludePlurals); + 
repositoryConfig.setExcludePlaceholders(excludePlaceholders); + repositoryConfig.setExcludeHtmlTags(excludeHtmlTags); + repositoryConfig.setExcludePlaceholdersRegex("\\{[^\\}]*\\}"); + translationFilterConfiguration.setRepositoryConfig(Map.of("test", repositoryConfig)); + + textUnitFilterService.aiTranslationFilterConfiguration = translationFilterConfiguration; + textUnitFilterService.excludePlaceholdersPatternMap = + Map.of("test", Pattern.compile("\\{[^\\}]*\\}")); + } +}