Skip to content

Commit

Permalink
Added parallelization, updated database inserts to be done in batches o…
Browse files Browse the repository at this point in the history
…f server side inserts
  • Loading branch information
maallen committed Oct 10, 2024
1 parent a8abaa0 commit 41edb05
Show file tree
Hide file tree
Showing 12 changed files with 685 additions and 297 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,23 @@
import com.box.l10n.mojito.rest.ai.AIException;
import com.box.l10n.mojito.service.ai.LLMService;
import com.box.l10n.mojito.service.ai.RepositoryLocaleAIPromptRepository;
import com.box.l10n.mojito.service.repository.RepositoryRepository;
import com.box.l10n.mojito.service.tm.TMService;
import com.box.l10n.mojito.service.tm.TMTextUnitCurrentVariantRepository;
import com.box.l10n.mojito.service.tm.TMTextUnitRepository;
import com.box.l10n.mojito.service.tm.TMTextUnitVariantRepository;
import com.google.common.collect.Lists;
import io.micrometer.core.annotation.Timed;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tags;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.codec.digest.DigestUtils;
import org.quartz.DisallowConcurrentExecution;
import org.quartz.Job;
import org.quartz.JobDetail;
Expand All @@ -34,10 +38,11 @@
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Lazy;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.scheduling.quartz.CronTriggerFactoryBean;
import org.springframework.scheduling.quartz.JobDetailFactoryBean;
import org.springframework.stereotype.Component;
Expand All @@ -57,26 +62,29 @@ public class AITranslateCronJob implements Job {

private static final String REPOSITORY_DEFAULT_PROMPT = "repository_default_prompt";

@Autowired TmTextUnitPendingMTRepository tmTextUnitPendingMTRepository;

@Autowired TMTextUnitRepository tmTextUnitRepository;

@Autowired TMTextUnitCurrentVariantRepository tmTextUnitCurrentVariantRepository;

@Lazy @Autowired TMService tmService;
@Autowired TMTextUnitVariantRepository tmTextUnitVariantRepository;

@Autowired LLMService llmService;

@Autowired MeterRegistry meterRegistry;

@Autowired RepositoryRepository repositoryRepository;

@Autowired RepositoryLocaleAIPromptRepository repositoryLocaleAIPromptRepository;

@Autowired AITranslationTextUnitFilterService aiTranslationTextUnitFilterService;

@Autowired AITranslationConfiguration aiTranslationConfiguration;

@Autowired AITranslationService aiTranslationService;

@Autowired TmTextUnitPendingMTRepository tmTextUnitPendingMTRepository;

@Autowired JdbcTemplate jdbcTemplate;

@Value("${l10n.ai.translation.job.threads:1}")
int threads;

@Timed("AITranslateCronJob.translate")
public void translate(Repository repository, TMTextUnit tmTextUnit, TmTextUnitPendingMT pendingMT)
throws AIException {
Expand Down Expand Up @@ -114,8 +122,7 @@ public void translate(Repository repository, TMTextUnit tmTextUnit, TmTextUnitPe

private Set<Locale> getLocalesForMT(Repository repository, TMTextUnit tmTextUnit) {
Set<Locale> localesWithVariants =
tmTextUnitCurrentVariantRepository.findLocalesWithVariantByTmTextUnit_Id(
tmTextUnit.getId());
tmTextUnitVariantRepository.findLocalesWithVariantByTmTextUnit_Id(tmTextUnit.getId());
return repository.getRepositoryLocales().stream()
.map(RepositoryLocale::getLocale)
.filter(
Expand All @@ -140,6 +147,7 @@ private void translateLocales(
? rlap.getLocale().getBcp47Tag()
: REPOSITORY_DEFAULT_PROMPT,
Function.identity()));
List<AITranslation> aiTranslations = Lists.newArrayList();
localesForMT.forEach(
targetLocale -> {
try {
Expand All @@ -148,7 +156,8 @@ private void translateLocales(
.getRepositorySettings(repository.getName())
.isReuseSourceOnLanguageMatch()
&& targetLocale.getBcp47Tag().startsWith(sourceLang)) {
reuseSourceStringAsTranslation(tmTextUnit, repository, targetLocale, sourceLang);
aiTranslations.add(
reuseSourceStringAsTranslation(tmTextUnit, repository, targetLocale, sourceLang));
return;
}
// Get the prompt override for this locale if it exists, otherwise use the
Expand All @@ -163,8 +172,9 @@ private void translateLocales(
tmTextUnit.getId(),
targetLocale.getBcp47Tag(),
repositoryLocaleAIPrompt.getAiPrompt().getId());
executeTranslationPrompt(
tmTextUnit, repository, targetLocale, repositoryLocaleAIPrompt);
aiTranslations.add(
executeTranslationPrompt(
tmTextUnit, repository, targetLocale, repositoryLocaleAIPrompt));
} else {
logger.debug(
"No active translation prompt found for locale: {}, skipping AI translation.",
Expand All @@ -185,9 +195,10 @@ private void translateLocales(
Tags.of("repository", repository.getName(), "locale", targetLocale.getBcp47Tag()));
}
});
aiTranslationService.insertMultiRowAITranslationVariant(tmTextUnit.getId(), aiTranslations);
}

private void reuseSourceStringAsTranslation(
private AITranslation reuseSourceStringAsTranslation(
TMTextUnit tmTextUnit, Repository repository, Locale targetLocale, String sourceLang) {
logger.debug(
"Target language {} matches source language {}, re-using source string as translation.",
Expand All @@ -197,17 +208,10 @@ private void reuseSourceStringAsTranslation(
"AITranslateCronJob.translate.reuseSourceAsTranslation",
Tags.of("repository", repository.getName(), "locale", targetLocale.getBcp47Tag()));

tmService.addTMTextUnitVariant(
tmTextUnit.getId(),
targetLocale.getId(),
tmTextUnit.getContent(),
tmTextUnit.getComment(),
TMTextUnitVariant.Status.MT_TRANSLATED,
false,
JSR310Migration.dateTimeNow());
return createAITranslationDTO(tmTextUnit, targetLocale, tmTextUnit.getContent());
}

private void executeTranslationPrompt(
private AITranslation executeTranslationPrompt(
TMTextUnit tmTextUnit,
Repository repository,
Locale targetLocale,
Expand All @@ -218,17 +222,23 @@ private void executeTranslationPrompt(
repository.getSourceLocale().getBcp47Tag(),
targetLocale.getBcp47Tag(),
repositoryLocaleAIPrompt.getAiPrompt());
tmService.addTMTextUnitVariant(
tmTextUnit.getId(),
targetLocale.getId(),
translation,
tmTextUnit.getComment(),
TMTextUnitVariant.Status.MT_TRANSLATED,
false,
JSR310Migration.dateTimeNow());
meterRegistry.counter(
"AITranslateCronJob.translate.success",
Tags.of("repository", repository.getName(), "locale", targetLocale.getBcp47Tag()));
return createAITranslationDTO(tmTextUnit, targetLocale, translation);
}

/**
 * Builds the {@link AITranslation} holder for one text unit / locale pair.
 *
 * <p>The holder is stamped with the MD5 hex digest of the translation, the {@code MT_TRANSLATED}
 * status, {@code includedInLocalizedFile = false} and the current date-time.
 *
 * @param tmTextUnit the text unit the translation belongs to
 * @param locale the target locale; only its id is stored
 * @param translation the translated content
 * @return a freshly populated {@link AITranslation}
 */
private AITranslation createAITranslationDTO(
    TMTextUnit tmTextUnit, Locale locale, String translation) {
  // Digest first, mirroring the evaluation order of the values being stored.
  final String translationMd5 = DigestUtils.md5Hex(translation);

  AITranslation dto = new AITranslation();
  dto.setTmTextUnit(tmTextUnit);
  dto.setContentMd5(translationMd5);
  dto.setLocaleId(locale.getId());
  dto.setTranslation(translation);
  dto.setStatus(TMTextUnitVariant.Status.MT_TRANSLATED);
  dto.setIncludedInLocalizedFile(false);
  dto.setCreatedDate(JSR310Migration.dateTimeNow());
  return dto;
}

private boolean isExpired(TmTextUnitPendingMT pendingMT) {
Expand All @@ -252,38 +262,71 @@ private boolean isExpired(TmTextUnitPendingMT pendingMT) {
@Timed("AITranslateCronJob.execute")
public void execute(JobExecutionContext jobExecutionContext) throws JobExecutionException {
logger.info("Executing AITranslateCronJob");

ExecutorService executorService = Executors.newFixedThreadPool(threads);

List<TmTextUnitPendingMT> pendingMTs;
do {
pendingMTs =
tmTextUnitPendingMTRepository.findBatch(aiTranslationConfiguration.getBatchSize());
logger.info("Processing {} pending MTs", pendingMTs.size());
pendingMTs.forEach(
pendingMT -> {
try {
TMTextUnit tmTextUnit = getTmTextUnit(pendingMT);
Repository repository = tmTextUnit.getAsset().getRepository();
translate(repository, tmTextUnit, pendingMT);
} catch (Exception e) {
logger.error(
"Error processing pending MT for text unit id: {}",
pendingMT.getTmTextUnitId(),
e);
meterRegistry.counter("AITranslateCronJob.pendingMT.error");
} finally {
if (pendingMT != null) {
logger.debug(
"Deleting pending MT for tmTextUnitId: {}", pendingMT.getTmTextUnitId());
tmTextUnitPendingMTRepository.delete(pendingMT);
}
}
});
} while (!pendingMTs.isEmpty());
try {
do {
pendingMTs =
tmTextUnitPendingMTRepository.findBatch(aiTranslationConfiguration.getBatchSize());
logger.info("Processing {} pending MTs", pendingMTs.size());

List<CompletableFuture<Void>> futures =
pendingMTs.stream()
.map(
pendingMT ->
CompletableFuture.runAsync(
() -> {
try {
TMTextUnit tmTextUnit = getTmTextUnit(pendingMT);
Repository repository = tmTextUnit.getAsset().getRepository();
translate(repository, tmTextUnit, pendingMT);
} catch (Exception e) {
logger.error(
"Error processing pending MT for text unit id: {}",
pendingMT.getTmTextUnitId(),
e);
meterRegistry
.counter("AITranslateCronJob.pendingMT.error")
.increment();
} finally {
if (pendingMT != null) {
logger.debug(
"Sending pending MT for tmTextUnitId: {} for deletion",
pendingMT.getTmTextUnitId());
aiTranslationService.sendForDeletion(pendingMT);
}
}
},
executorService))
.toList();

// Wait for all tasks in this batch to complete
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
} while (!pendingMTs.isEmpty());
} finally {
shutdownExecutor(executorService);
}

logger.info("Finished executing AITranslateCronJob");
}

/**
 * Shuts down the given executor, waiting up to one minute for already-submitted tasks to finish.
 *
 * <p>If the tasks do not terminate within that window, or the waiting thread is interrupted, the
 * executor is stopped with {@code shutdownNow()}. On interruption the thread's interrupt flag is
 * restored so that code further up the call stack can still observe it.
 *
 * @param executorService the executor to shut down
 */
private static void shutdownExecutor(ExecutorService executorService) {
  executorService.shutdown();
  try {
    if (!executorService.awaitTermination(1, TimeUnit.MINUTES)) {
      logger.error("Thread pool tasks didn't finish in the expected time.");
      executorService.shutdownNow();
    }
  } catch (InterruptedException e) {
    executorService.shutdownNow();
    // Catching InterruptedException clears the flag; re-assert it instead of swallowing it.
    Thread.currentThread().interrupt();
  }
}

private TMTextUnit getTmTextUnit(TmTextUnitPendingMT pendingMT) {
return tmTextUnitRepository
.findByIdWithAssetAndRepositoryFetched(pendingMT.getTmTextUnitId())
.findByIdWithAssetAndRepositoryAndTMFetched(pendingMT.getTmTextUnitId())
.orElseThrow(
() -> new AIException("TMTextUnit not found for id: " + pendingMT.getTmTextUnitId()));
}
Expand All @@ -294,6 +337,7 @@ public JobDetailFactoryBean jobDetailAiTranslateCronJob() {
jobDetailFactory.setJobClass(AITranslateCronJob.class);
jobDetailFactory.setDescription("Translate text units in batches via AI");
jobDetailFactory.setDurability(true);
jobDetailFactory.setName("aiTranslateCron");
return jobDetailFactory;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package com.box.l10n.mojito.service.ai.translation;

import com.box.l10n.mojito.entity.TMTextUnit;
import com.box.l10n.mojito.entity.TMTextUnitVariant;
import java.time.ZonedDateTime;

/**
 * Mutable data holder describing a single translation of a text unit for one locale.
 *
 * <p>The class performs no validation or computation of its own: every value, including the
 * content MD5, is supplied by the caller through the setters.
 */
public class AITranslation {

  /** The text unit this translation belongs to. */
  TMTextUnit tmTextUnit;

  /** Identifier of the target locale. */
  Long localeId;

  /** The translated content. */
  String translation;

  /** MD5 value for the content, as supplied by the caller (not computed by this class). */
  String contentMd5;

  /** Status assigned to the translation. */
  TMTextUnitVariant.Status status;

  /** Whether the translation is flagged as included in the localized file. */
  boolean includedInLocalizedFile;

  /** Creation timestamp of the translation. */
  ZonedDateTime createdDate;

  /** @return the text unit this translation belongs to */
  public TMTextUnit getTmTextUnit() {
    return tmTextUnit;
  }

  /** @param tmTextUnit the text unit this translation belongs to */
  public void setTmTextUnit(TMTextUnit tmTextUnit) {
    this.tmTextUnit = tmTextUnit;
  }

  /** @return the identifier of the target locale */
  public Long getLocaleId() {
    return localeId;
  }

  /** @param localeId the identifier of the target locale */
  public void setLocaleId(Long localeId) {
    this.localeId = localeId;
  }

  /** @return the translated content */
  public String getTranslation() {
    return translation;
  }

  /** @param translation the translated content */
  public void setTranslation(String translation) {
    this.translation = translation;
  }

  /** @return the status assigned to the translation */
  public TMTextUnitVariant.Status getStatus() {
    return status;
  }

  /** @param status the status assigned to the translation */
  public void setStatus(TMTextUnitVariant.Status status) {
    this.status = status;
  }

  /** @return {@code true} if the translation is flagged as included in the localized file */
  public boolean isIncludedInLocalizedFile() {
    return includedInLocalizedFile;
  }

  /** @param includedInLocalizedFile whether the translation is included in the localized file */
  public void setIncludedInLocalizedFile(boolean includedInLocalizedFile) {
    this.includedInLocalizedFile = includedInLocalizedFile;
  }

  /** @return the creation timestamp of the translation */
  public ZonedDateTime getCreatedDate() {
    return createdDate;
  }

  /** @param createdDate the creation timestamp of the translation */
  public void setCreatedDate(ZonedDateTime createdDate) {
    this.createdDate = createdDate;
  }

  /** @return the MD5 value stored for the content */
  public String getContentMd5() {
    return contentMd5;
  }

  /** @param contentMd5 the MD5 value to store for the content */
  public void setContentMd5(String contentMd5) {
    this.contentMd5 = contentMd5;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
public class AITranslationConfiguration {

private boolean enabled = false;
private int batchSize = 1000;
private int batchSize = 10;

/**
* Duration after which a pending MT is considered expired and will not be processed in AI
Expand All @@ -28,6 +28,16 @@ public class AITranslationConfiguration {
private Map<String, RepositorySettings> repositorySettings = Maps.newHashMap();

public static class RepositorySettings {
/**
* If true, reuse the source text if the language of the source text matches the target
* language.
*
* <p>Uses the language piece of the BCP47 tag to determine if the language matches. For
* example, en-US and en-GB would match.
*
* <p>If a match is found, the source text will be used as the translation for the target locale
* and no AI translation will be requested for en-GB.
*/
private boolean reuseSourceOnLanguageMatch = false;

public Boolean isReuseSourceOnLanguageMatch() {
Expand Down
Loading

0 comments on commit 41edb05

Please sign in to comment.