Skip to content

Commit

Permalink
Add AI translation functionality (#147)
Browse files Browse the repository at this point in the history
* DB updates

* Initial GPT translate flow compiling

* Fix package names

* Added unit test for AITranslationTextUnitFilterService

* Fix test failures

* Add @ConditionalOnProperty to AITranslateJob to get tests running

* Updated filter configuration to handle different config across repositories

* Null Handling

* Fix return type for prompt query lookup

* Added implementation to extract translation from json key response

* Cleaned up default prompt handling, added unit tests

* Add metric to record time taken for text unit to be machine translated

* Parse language from BCP-47 tag correctly

* Remove unused join

* No AI translation if current variant exists for locale

* Code review updates

* Update log and add metric

* Added AI translation handling for virtual text units

* Updating code readability

* Tests passing

* Only translate supplied locales, added support for optional placeholder templating

* Further review updates

* Update AITranslateJob to a Quartz Cron job, other code review updates

* Add cron setup

* Added unit test for batch processing failure logic

* Update configuration to use ConfigProperties and added repository level config for re-using source on lang match

* Added CLI support for isJsonResponse boolean for ai prompts

* Added jsonResponseKey to ai prompt table and added findLocalesWithVariantByTmTextUnitId query

* Fix test failures

* Update exception handling logic for batch processing

* Added error logging

* Added parellization, updated database inserts to be done in batches of server side inserts

* Fix log condition

* Remove dead code, fix docs

* Doc fix
  • Loading branch information
maallen authored Oct 11, 2024
1 parent 80031ba commit a9eec85
Show file tree
Hide file tree
Showing 39 changed files with 2,749 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class CreateAIPromptCommand extends Command {

static Logger logger = LoggerFactory.getLogger(CreateAIPromptCommand.class);

@Autowired AIServiceClient AIServiceClient;
@Autowired AIServiceClient aiServiceClient;

@Parameter(
names = {"--repository-name", "-r"},
Expand Down Expand Up @@ -58,6 +58,18 @@ public class CreateAIPromptCommand extends Command {
description = "The temperature to use for the prompt")
float promptTemperature = 0.0F;

@Parameter(
names = {"--is-json-response", "-ijr"},
required = false,
description = "The prompt response is expected to be in JSON format from the LLM")
boolean isJsonResponse = false;

@Parameter(
names = {"--json-response-key", "-jrk"},
required = false,
description = "The key to use to extract the translation from the JSON response")
String jsonResponseKey;

@Autowired private ConsoleWriter consoleWriter;

@Override
Expand All @@ -67,14 +79,19 @@ protected void execute() throws CommandException {

private void createPrompt() {
logger.debug("Received request to create prompt");
AIPromptCreateRequest AIPromptCreateRequest = new AIPromptCreateRequest();
AIPromptCreateRequest.setRepositoryName(repository);
AIPromptCreateRequest.setSystemPrompt(systemPromptText);
AIPromptCreateRequest.setUserPrompt(userPromptText);
AIPromptCreateRequest.setModelName(modelName);
AIPromptCreateRequest.setPromptType(promptType);
AIPromptCreateRequest.setPromptTemperature(promptTemperature);
long promptId = AIServiceClient.createPrompt(AIPromptCreateRequest);
AIPromptCreateRequest aiPromptCreateRequest = new AIPromptCreateRequest();
aiPromptCreateRequest.setRepositoryName(repository);
aiPromptCreateRequest.setSystemPrompt(systemPromptText);
aiPromptCreateRequest.setUserPrompt(userPromptText);
aiPromptCreateRequest.setModelName(modelName);
aiPromptCreateRequest.setPromptType(promptType);
aiPromptCreateRequest.setPromptTemperature(promptTemperature);
aiPromptCreateRequest.setJsonResponse(isJsonResponse);
if (isJsonResponse && jsonResponseKey == null) {
throw new CommandException("jsonResponseKey is required when isJsonResponse is true");
}
aiPromptCreateRequest.setJsonResponseKey(jsonResponseKey);
long promptId = aiServiceClient.createPrompt(aiPromptCreateRequest);
consoleWriter.newLine().a("Prompt created with id: " + promptId).println();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,11 @@ public record Choice(
int index, Delta delta, @JsonProperty("finish_reason") String finishReason) {

public enum FinishReasons {
STOP("stop");
STOP("stop"),
LENGTH("length"),
FUNCTION_CALL("function_call"),
CONTENT_FILTER("content_filter"),
NULL("null");

String value;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ public class AIPromptCreateRequest {
private boolean deleted;
private String repositoryName;
private String promptType;
private boolean isJsonResponse;
private String jsonResponseKey;

public boolean isDeleted() {
return deleted;
Expand Down Expand Up @@ -68,4 +70,20 @@ public String getRepositoryName() {
public void setRepositoryName(String repositoryName) {
this.repositoryName = repositoryName;
}

public boolean isJsonResponse() {
return isJsonResponse;
}

public void setJsonResponse(boolean jsonResponse) {
isJsonResponse = jsonResponse;
}

public String getJsonResponseKey() {
return jsonResponseKey;
}

public void setJsonResponseKey(String jsonResponseKey) {
this.jsonResponseKey = jsonResponseKey;
}
}
36 changes: 36 additions & 0 deletions webapp/src/main/java/com/box/l10n/mojito/entity/AIPrompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.FetchType;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.OneToMany;
import jakarta.persistence.OrderBy;
import jakarta.persistence.Table;
Expand Down Expand Up @@ -31,6 +33,10 @@ public class AIPrompt extends BaseEntity {
@Column(name = "deleted")
private boolean deleted;

@ManyToOne
@JoinColumn(name = "prompt_type_id")
private AIPromptType promptType;

@CreatedDate
@Column(name = "created_date")
private ZonedDateTime createdDate;
Expand All @@ -44,6 +50,12 @@ public class AIPrompt extends BaseEntity {
@OrderBy("orderIndex ASC")
List<AIPromptContextMessage> contextMessages;

@Column(name = "json_response")
private boolean jsonResponse;

@Column(name = "json_response_key")
private String jsonResponseKey;

public String getModelName() {
return modelName;
}
Expand Down Expand Up @@ -107,4 +119,28 @@ public ZonedDateTime getLastModifiedDate() {
public void setLastModifiedDate(ZonedDateTime lastModifiedDate) {
this.lastModifiedDate = lastModifiedDate;
}

public AIPromptType getPromptType() {
return promptType;
}

public void setPromptType(AIPromptType promptType) {
this.promptType = promptType;
}

public boolean isJsonResponse() {
return jsonResponse;
}

public void setJsonResponse(boolean jsonResponse) {
this.jsonResponse = jsonResponse;
}

public String getJsonResponseKey() {
return jsonResponseKey;
}

public void setJsonResponseKey(String jsonResponseKey) {
this.jsonResponseKey = jsonResponseKey;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package com.box.l10n.mojito.entity;

public enum PromptType {
SOURCE_STRING_CHECKER;
SOURCE_STRING_CHECKER,
TRANSLATION;

public String toString() {
return name();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package com.box.l10n.mojito.entity;

import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.Table;
import jakarta.persistence.UniqueConstraint;

@Entity
@Table(
name = "repository_locale_ai_prompt",
uniqueConstraints = {
@UniqueConstraint(
name = "UK__REPOSITORY_LOCALE_AI_PROMPT__REPO_ID__LOCALE_ID__AI_PROMPT",
columnNames = {"repository_id", "locale_id", "ai_prompt_id"})
})
public class RepositoryLocaleAIPrompt extends BaseEntity {

@ManyToOne
@JoinColumn(name = "repository_id", nullable = false)
private Repository repository;

@ManyToOne
@JoinColumn(name = "locale_id", nullable = true)
private Locale locale;

@ManyToOne
@JoinColumn(name = "ai_prompt_id", nullable = false)
private AIPrompt aiPrompt;

@Column(name = "disabled")
private boolean disabled;

public AIPrompt getAiPrompt() {
return aiPrompt;
}

public void setAiPrompt(AIPrompt aiPrompt) {
this.aiPrompt = aiPrompt;
}

public Repository getRepository() {
return repository;
}

public void setRepository(Repository repository) {
this.repository = repository;
}

public Locale getLocale() {
return locale;
}

public void setLocale(Locale locale) {
this.locale = locale;
}

public boolean isDisabled() {
return disabled;
}

public void setDisabled(boolean disabled) {
this.disabled = disabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ public enum Status {
* TRANSLATION_NEEDED status along with a comment.
*/
REVIEW_NEEDED,

MT_TRANSLATED,

MT_REVIEW,
/** A string that doesn't need any work to be performed on it. */
APPROVED;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.box.l10n.mojito.entity;

import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Table;
import java.time.ZonedDateTime;

@Entity
@Table(name = "tm_text_unit_pending_mt")
public class TmTextUnitPendingMT extends BaseEntity {

@Column(name = "tm_text_unit_id")
private Long tmTextUnitId;

@Column(name = "created_date")
private ZonedDateTime createdDate;

public ZonedDateTime getCreatedDate() {
return createdDate;
}

public void setCreatedDate(ZonedDateTime createdDate) {
this.createdDate = createdDate;
}

public Long getTmTextUnitId() {
return tmTextUnitId;
}

public void setTmTextUnitId(Long tmTextUnitId) {
this.tmTextUnitId = tmTextUnitId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ public AIException(String message) {
public AIException(String message, Exception e) {
super(message, e);
}

public AIException(String message, Throwable t) {
super(message, t);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ public class AIPromptCreateRequest {
private boolean deleted;
private String repositoryName;
private String promptType;
private boolean isJsonResponse;
private String jsonResponseKey;

public boolean isDeleted() {
return deleted;
Expand Down Expand Up @@ -68,4 +70,20 @@ public String getRepositoryName() {
public void setRepositoryName(String repositoryName) {
this.repositoryName = repositoryName;
}

public boolean isJsonResponse() {
return isJsonResponse;
}

public void setJsonResponse(boolean jsonResponse) {
isJsonResponse = jsonResponse;
}

public String getJsonResponseKey() {
return jsonResponseKey;
}

public void setJsonResponseKey(String jsonResponseKey) {
this.jsonResponseKey = jsonResponseKey;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@ public interface AIPromptRepository extends JpaRepository<AIPrompt, Long> {

@Query(
"SELECT ap FROM AIPrompt ap "
+ "JOIN RepositoryAIPrompt rap ON ap.id = rap.aiPromptId "
+ "JOIN AIPromptType apt ON rap.promptTypeId = apt.id "
+ "WHERE rap.repositoryId = :repositoryId AND apt.name = :promptTypeName AND ap.deleted = false")
+ "JOIN RepositoryLocaleAIPrompt rlap ON ap.id = rlap.aiPrompt.id "
+ "JOIN AIPromptType apt ON ap.promptType.id = apt.id "
+ "WHERE rlap.repository.id = :repositoryId AND apt.name = :promptTypeName AND ap.deleted = false")
List<AIPrompt> findByRepositoryIdAndPromptTypeName(
@Param("repositoryId") Long repositoryId, @Param("promptTypeName") String promptTypeName);

List<AIPrompt> findByDeletedFalse();

@Query(
"SELECT ap FROM AIPrompt ap "
+ "JOIN RepositoryAIPrompt rap ON ap.id = rap.aiPromptId "
+ "WHERE rap.repositoryId = :repositoryId AND ap.deleted = false")
+ "JOIN RepositoryLocaleAIPrompt rlap ON ap.id = rlap.aiPrompt.id "
+ "WHERE rlap.repository.id = :repositoryId AND ap.deleted = false")
List<AIPrompt> findByRepositoryIdAndDeletedFalse(Long repositoryId);
}
Loading

0 comments on commit a9eec85

Please sign in to comment.