diff --git a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java index 755ed4a1b014..2e6d9f1902e1 100644 --- a/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java +++ b/src/main/java/de/tum/cit/aet/artemis/core/repository/LLMTokenUsageRepository.java @@ -1,10 +1,14 @@ package de.tum.cit.aet.artemis.core.repository; +import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_IRIS; + +import org.springframework.context.annotation.Profile; import org.springframework.stereotype.Repository; import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage; import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository; @Repository +@Profile(PROFILE_IRIS) public interface LLMTokenUsageRepository extends ArtemisJpaRepository { } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java index c38293c8a150..0ac0872a1767 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/IrisCompetencyGenerationService.java @@ -68,7 +68,9 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String * @param statusUpdate the status update containing the new competency recommendations */ public void handleStatusUpdate(String userLogin, long courseId, PyrisCompetencyStatusUpdateDTO statusUpdate) { - var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens()); + if (statusUpdate.tokens() != null) { + var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens()); + } websocketService.send(userLogin, websocketTopic(courseId), statusUpdate); } diff --git a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java index bcdc912a99f2..12fd6fdf6905 100644 --- a/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java +++ b/src/main/java/de/tum/cit/aet/artemis/iris/service/session/IrisExerciseChatSessionService.java @@ -176,13 +176,17 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta var message = new IrisMessage(); message.addContent(new IrisTextMessageContent(statusUpdate.result())); var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM); - var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(), - session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); + if (statusUpdate.tokens() != null) { + var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(), + session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); + } irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages()); } else { - var tokenUsages = llmTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(), - statusUpdate.tokens()); + if (statusUpdate.tokens() != null) { + var tokenUsages = llmTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(), + session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens()); + } irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens()); } diff --git a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java index ad92db87287f..8e36d9063cd9 100644 --- a/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java +++ b/src/test/java/de/tum/cit/aet/artemis/iris/IrisChatMessageIntegrationTest.java @@ -41,7 +41,6 @@ import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository; import de.tum.cit.aet.artemis.iris.service.IrisMessageService; import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO; -import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO; import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState; import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService; @@ -131,7 +130,7 @@ void sendOneMessage() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null)); pipelineDone.set(true); }); @@ -156,7 +155,7 @@ void sendSuggestions() throws Exception { List suggestions = List.of("suggestion1", "suggestion2", "suggestion3"); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, dto.initialStages(), suggestions, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, dto.initialStages(), suggestions)); pipelineDone.set(true); }); @@ -195,7 +194,7 @@ void sendTwoMessages() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 1", dto.initialStages(), null, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 1", dto.initialStages(), null)); pipelineDone.set(true); }); @@ -203,7 +202,7 @@ void sendTwoMessages() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 2", dto.initialStages(), null, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 2", dto.initialStages(), null)); pipelineDone.set(true); }); @@ -299,7 +298,7 @@ void resendMessage() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null)); pipelineDone.set(true); }); @@ -322,7 +321,7 @@ void sendMessageRateLimitReached() throws Exception { irisRequestMockProvider.mockRunResponse(dto -> { assertThat(dto.settings().authenticationToken()).isNotNull(); - assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null)); + assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null)); pipelineDone.set(true); }); @@ -445,9 +444,9 @@ public String toString() { }; } - private void sendStatus(String jobId, String result, List stages, List suggestions, List tokens) throws Exception { + private void sendStatus(String jobId, String result, List stages, List suggestions) throws Exception { var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId)))); - request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, tokens), + request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, null), HttpStatus.OK, headers); } }