diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java index e71589f300ca..4869ec522c3c 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMServiceType.java @@ -2,5 +2,5 @@ public enum LLMServiceType { ATHENA_PRELIMINARY_FEEDBACK, ATHENA_FEEDBACK_SUGGESTION, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION, - IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, NOT_SET + IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, IRIS_LECTURE_RETRIEVAL_PIPELINE, IRIS_LECTURE_INGESTION, NOT_SET } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java index cd26f899113a..3f51390ce397 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java @@ -7,8 +7,6 @@ import jakarta.persistence.Entity; import jakarta.persistence.EnumType; import jakarta.persistence.Enumerated; -import jakarta.persistence.Inheritance; -import jakarta.persistence.InheritanceType; import jakarta.persistence.JoinColumn; import jakarta.persistence.ManyToOne; import jakarta.persistence.Table; @@ -18,16 +16,13 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonTypeInfo; import de.tum.cit.aet.artemis.exercise.domain.Exercise; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; @Entity @Table(name = "llm_token_usage") -@Inheritance(strategy = InheritanceType.SINGLE_TABLE) @Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE) -@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") @JsonInclude(JsonInclude.Include.NON_EMPTY) public class LLMTokenUsage extends DomainObject { @@ -39,16 +34,16 @@ public class LLMTokenUsage extends DomainObject { private String model; @Column(name = "num_input_tokens") - private int num_input_tokens; + private int numInputTokens; - @Column(name = "cost_per_input_token") - private float cost_per_input_token; + @Column(name = "cost_per_million_input_tokens") + private float costPerMillionInputTokens; @Column(name = "num_output_tokens") - private int num_output_tokens; + private int numOutputTokens; - @Column(name = "cost_per_output_token") - private float cost_per_output_token; + @Column(name = "cost_per_million_output_tokens") + private float costPerMillionOutputTokens; @Nullable @ManyToOne @@ -66,11 +61,11 @@ public class LLMTokenUsage extends DomainObject { private long userId; @Nullable - @Column(name = "timestamp") - private ZonedDateTime timestamp = ZonedDateTime.now(); + @Column(name = "time") + private ZonedDateTime time = ZonedDateTime.now(); @Column(name = "trace_id") - private Long traceId; + private String traceId; @Nullable @ManyToOne @@ -94,36 +89,36 @@ public void setModel(String model) { this.model = model; } - public float getCost_per_input_token() { - return cost_per_input_token; + public float getCostPerMillionInputTokens() { + return costPerMillionInputTokens; } - public void setCost_per_input_token(float cost_per_input_token) { - this.cost_per_input_token = cost_per_input_token; + public void setCostPerMillionInputTokens(float costPerMillionInputToken) { + this.costPerMillionInputTokens = costPerMillionInputToken; } - public float getCost_per_output_token() { - return cost_per_output_token; + public float getCostPerMillionOutputTokens() { + return costPerMillionOutputTokens; } - public void setCost_per_output_token(float cost_per_output_token) { - this.cost_per_output_token = cost_per_output_token; + public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) { + this.costPerMillionOutputTokens = costPerMillionOutputToken; } - public int getNum_input_tokens() { - return num_input_tokens; + public int getNumInputTokens() { + return numInputTokens; } - public void setNum_input_tokens(int num_input_tokens) { - this.num_input_tokens = num_input_tokens; + public void setNumInputTokens(int numInputTokens) { + this.numInputTokens = numInputTokens; } - public int getNum_output_tokens() { - return num_output_tokens; + public int getNumOutputTokens() { + return numOutputTokens; } - public void setNum_output_tokens(int num_output_tokens) { - this.num_output_tokens = num_output_tokens; + public void setNumOutputTokens(int numOutputTokens) { + this.numOutputTokens = numOutputTokens; } public Course getCourse() { @@ -150,19 +145,19 @@ public void setUserId(long userId) { this.userId = userId; } - public ZonedDateTime getTimestamp() { - return timestamp; + public ZonedDateTime getTime() { + return time; } - public void setTimestamp(ZonedDateTime timestamp) { - this.timestamp = timestamp; + public void setTime(ZonedDateTime time) { + this.time = time; } - public Long getTraceId() { + public String getTraceId() { return traceId; } - public void setTraceId(Long traceId) { + public void setTraceId(String traceId) { this.traceId = traceId; } @@ -176,8 +171,8 @@ public void setIrisMessage(IrisMessage message) { @Override public String toString() { - return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", num_input_tokens=" + num_input_tokens + ", cost_per_input_token=" + cost_per_input_token - + ", num_output_tokens=" + num_output_tokens + ", cost_per_output_token=" + cost_per_output_token + ", course=" + course + ", exercise=" + exercise + ", userId=" - + userId + ", timestamp=" + timestamp + ", trace_id=" + traceId + ", irisMessage=" + irisMessage + '}'; + return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", num_input_tokens=" + numInputTokens + ", cost_per_input_token=" + costPerMillionInputTokens + + ", num_output_tokens=" + numOutputTokens + ", cost_per_output_token=" + costPerMillionOutputTokens + ", course=" + course + ", exercise=" + exercise + ", userId=" + + userId + ", timestamp=" + time + ", trace_id=" + traceId + ", irisMessage=" + irisMessage + '}'; } } diff --git a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java index 3446c61117e3..8dab3e056029 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/service/LLMTokenUsageService.java @@ -4,7 +4,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.UUID; import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Service; @@ -16,6 +15,7 @@ import de.tum.cit.aet.artemis.exercise.domain.Exercise; import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage; import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; +import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob; /** * Service for managing the LLMTokenUsage by all LLMs in Artemis @@ -31,31 +31,25 @@ public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) { } /** - * saves the tokens used for a specific IrisMessage or Athena call - * in case of an Athena call IrisMessage can be null and the - * LLMServiceType in tokens has to by Athena + * method saves the token usage to the database with a link to the IrisMessage + * messages of the same job are grouped together by saving the job id as a trace id * - * @param message IrisMessage related to the TokenUsage - * @param exercise Exercise in which the request was made - * @param user User that made the request - * @param course Course in which the request was made - * @param tokens List with Tokens of the PyrisLLMCostDTO Mdel - * @return List of the created LLMTokenUsage entries + * @param job used to create a unique traceId to group multiple LLM calls + * @param message IrisMessage to map the usage to an IrisMessage + * @param exercise to map the token cost to an exercise + * @param user to map the token cost to a user + * @param course to map the token to a course + * @param tokens token cost lsit of type PyrisLLMCostDTO + * @return list of the saved data */ - - public List saveTokenUsage(IrisMessage message, Exercise exercise, User user, Course course, List tokens) { + public List saveIrisTokenUsage(PyrisJob job, IrisMessage message, Exercise exercise, User user, Course course, List tokens) { List tokenUsages = new ArrayList<>(); - // Combine current time and UUID to create a unique traceId - long timestamp = System.currentTimeMillis(); - long uuidComponent = UUID.randomUUID().getLeastSignificantBits() & Long.MAX_VALUE; - Long traceId = timestamp + uuidComponent; - for (PyrisLLMCostDTO cost : tokens) { LLMTokenUsage llmTokenUsage = new LLMTokenUsage(); if (message != null) { llmTokenUsage.setIrisMessage(message); - llmTokenUsage.setTimestamp(message.getSentAt()); + llmTokenUsage.setTime(message.getSentAt()); } llmTokenUsage.setServiceType(cost.pipeline()); llmTokenUsage.setExercise(exercise); @@ -63,12 +57,12 @@ public List saveTokenUsage(IrisMessage message, Exercise exercise llmTokenUsage.setUserId(user.getId()); } llmTokenUsage.setCourse(course); - llmTokenUsage.setNum_input_tokens(cost.num_input_tokens()); - llmTokenUsage.setCost_per_input_token(cost.cost_per_input_token()); - llmTokenUsage.setNum_output_tokens(cost.num_output_tokens()); - llmTokenUsage.setCost_per_output_token(cost.cost_per_output_token()); - llmTokenUsage.setModel(cost.model_info()); - llmTokenUsage.setTraceId(traceId); + llmTokenUsage.setNumInputTokens(cost.numInputTokens()); + llmTokenUsage.setCostPerMillionInputTokens(cost.costPerInputToken()); + llmTokenUsage.setNumOutputTokens(cost.numOutputTokens()); + llmTokenUsage.setCostPerMillionOutputTokens(cost.costPerOutputToken()); + llmTokenUsage.setModel(cost.modelInfo()); + llmTokenUsage.setTraceId(job.jobId()); tokenUsages.add(llmTokenUsageRepository.save(llmTokenUsage)); } return tokenUsages; diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index 0ac0872a1767..93111ad2c234 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -8,6 +8,7 @@ import de.tum.cit.aet.artemis.atlas.domain.competency.CompetencyTaxonomy; import de.tum.cit.aet.artemis.core.domain.Course; import de.tum.cit.aet.artemis.core.domain.User; +import de.tum.cit.aet.artemis.core.repository.CourseRepository; import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisJobService; import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService; @@ -28,13 +29,16 @@ public class IrisCompetencyGenerationService { private final LLMTokenUsageService llmTokenUsageService; + private final CourseRepository courseRepository; + private final IrisWebsocketService websocketService; private final PyrisJobService pyrisJobService; - public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, IrisWebsocketService websocketService, - PyrisJobService pyrisJobService) { + public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, CourseRepository courseRepository, + IrisWebsocketService websocketService, PyrisJobService pyrisJobService) { this.pyrisPipelineService = pyrisPipelineService; + this.courseRepository = courseRepository; this.websocketService = websocketService; this.pyrisJobService = pyrisJobService; this.llmTokenUsageService = llmTokenUsageService; @@ -63,15 +67,15 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String /** * Takes a status update from Pyris containing a new competency extraction result and sends it to the client via websocket * - * @param userLogin the login of the user - * @param courseId the id of the course + * @param job Job related to the status update * @param statusUpdate the status update containing the new competency recommendations */ - public void handleStatusUpdate(String userLogin, long courseId, PyrisCompetencyStatusUpdateDTO statusUpdate) { + public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { + Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId()); if (statusUpdate.tokens() != null) { - var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens()); + llmTokenUsageService.saveIrisTokenUsage(job, null, null, null, course, statusUpdate.tokens()); } - websocketService.send(userLogin, websocketTopic(courseId), statusUpdate); + websocketService.send(job.userLogin(), websocketTopic(job.courseId()), statusUpdate); } private static String websocketTopic(long courseId) { diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java index aed62b6049c1..732b2b572458 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/PyrisStatusUpdateService.java @@ -71,13 +71,13 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu /** * Handles the status update of a competency extraction job and forwards it to - * {@link IrisCompetencyGenerationService#handleStatusUpdate(String, long, PyrisCompetencyStatusUpdateDTO)} + * {@link IrisCompetencyGenerationService#handleStatusUpdate(CompetencyExtractionJob, PyrisCompetencyStatusUpdateDTO)} * * @param job the job that is updated * @param statusUpdate the status update */ public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) { - competencyGenerationService.handleStatusUpdate(job.userLogin(), job.courseId(), statusUpdate); + competencyGenerationService.handleStatusUpdate(job, statusUpdate); removeJobIfTerminated(statusUpdate.stages(), job.jobId()); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java index 13fd40d84bf1..74f40cce6873 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/pyris/dto/data/PyrisLLMCostDTO.java @@ -2,5 +2,5 @@ import de.tum.cit.aet.artemis.core.domain.LLMServiceType; -public record PyrisLLMCostDTO(String model_info, int num_input_tokens, float cost_per_input_token, int num_output_tokens, float cost_per_output_token, LLMServiceType pipeline) { +public record PyrisLLMCostDTO(String modelInfo, int numInputTokens, float costPerInputToken, int numOutputTokens, float costPerOutputToken, LLMServiceType pipeline) { } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java index b87e64081c07..388a0539cf0b 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisCourseChatSessionService.java @@ -142,11 +142,11 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); + var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); + var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, null, null, session.getUser(), session.getCourse(), statusUpdate.tokens()); irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } updateLatestSuggestions(session, statusUpdate.suggestions()); diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index 12fd6fdf6905..613cbcc4a9eb 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -65,14 +65,14 @@ public class IrisExerciseChatSessionService extends AbstractIrisChatSessionServi private final ProgrammingExerciseRepository programmingExerciseRepository; - public IrisExerciseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService LLMTokenUsageService, IrisSettingsService irisSettingsService, + public IrisExerciseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService llmTokenUsageService, IrisSettingsService irisSettingsService, IrisChatWebsocketService irisChatWebsocketService, AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository, ProgrammingExerciseStudentParticipationRepository programmingExerciseStudentParticipationRepository, ProgrammingSubmissionRepository programmingSubmissionRepository, IrisRateLimitService rateLimitService, PyrisPipelineService pyrisPipelineService, ProgrammingExerciseRepository programmingExerciseRepository, ObjectMapper objectMapper) { super(irisSessionRepository, objectMapper); this.irisMessageService = irisMessageService; - this.llmTokenUsageService = LLMTokenUsageService; + this.llmTokenUsageService = llmTokenUsageService; this.irisSettingsService = irisSettingsService; this.irisChatWebsocketService = irisChatWebsocketService; this.authCheckService = authCheckService; @@ -177,14 +177,14 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); if (statusUpdate.tokens() != null) { - var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(), + var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); } irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { if (statusUpdate.tokens() != null) { - var tokenUsages = llmTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(), + var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, null, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); } irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); diff --git a/src/main/resources/config/liquibase/changelog/20241012125003_changelog.xml b/src/main/resources/config/liquibase/changelog/20241012080932_changelog.xml similarity index 86% rename from src/main/resources/config/liquibase/changelog/20241012125003_changelog.xml rename to src/main/resources/config/liquibase/changelog/20241012080932_changelog.xml index fcb5bb25ac22..e8f846219bb2 100644 --- a/src/main/resources/config/liquibase/changelog/20241012125003_changelog.xml +++ b/src/main/resources/config/liquibase/changelog/20241012080932_changelog.xml @@ -5,7 +5,7 @@ xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-4.29.xsd" objectQuotingStrategy="QUOTE_ONLY_RESERVED_WORDS"> - + @@ -13,14 +13,14 @@ - + - + - - + + - + diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java index adf594a5608f..282e4294eef2 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisCompetencyGenerationIntegrationTest.java @@ -22,6 +22,7 @@ import de.tum.cit.aet.artemis.iris.service.pyris.dto.competency.PyrisCompetencyStatusUpdateDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; +import de.tum.cit.aet.artemis.iris.service.pyris.job.CompetencyExtractionJob; class IrisCompetencyGenerationIntegrationTest extends AbstractIrisIntegrationTest { @@ -66,7 +67,8 @@ void generateCompetencies_asEditor_shouldSucceed() throws Exception { List stages = List.of(new PyrisStageDTO("Generating Competencies", 10, PyrisStageState.DONE, null)); // In the real system, this would be triggered by Pyris via a REST call to the Artemis server - irisCompetencyGenerationService.handleStatusUpdate(TEST_PREFIX + "editor1", course.getId(), new PyrisCompetencyStatusUpdateDTO(stages, recommendations, null)); + CompetencyExtractionJob job = new CompetencyExtractionJob("1", course.getId(), TEST_PREFIX + "editor1"); + irisCompetencyGenerationService.handleStatusUpdate(job, new PyrisCompetencyStatusUpdateDTO(stages, recommendations, null)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(PyrisCompetencyStatusUpdateDTO.class); verify(websocketMessagingService, timeout(200).times(3)).sendMessageToUser(eq(TEST_PREFIX + "editor1"), eq("/topic/iris/competencies/" + course.getId()),