Skip to content

Commit

Permalink
Fix server test failures by checking if tokens received
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjoham committed Oct 12, 2024
1 parent f85cf46 commit 65fb259
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package de.tum.cit.aet.artemis.core.repository;

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_IRIS;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Repository;

import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage;
import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository;

@Repository
@Profile(PROFILE_IRIS)
public interface LLMTokenUsageRepository extends ArtemisJpaRepository<LLMTokenUsage, Long> {
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String
* @param statusUpdate the status update containing the new competency recommendations
*/
public void handleStatusUpdate(String userLogin, long courseId, PyrisCompetencyStatusUpdateDTO statusUpdate) {
var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens());
if (statusUpdate.tokens() != null) {
var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens());
}
websocketService.send(userLogin, websocketTopic(courseId), statusUpdate);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,17 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta
var message = new IrisMessage();
message.addContent(new IrisTextMessageContent(statusUpdate.result()));
var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM);
var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(),
session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens());
if (statusUpdate.tokens() != null) {
var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(),
session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens());
}
irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages());
}
else {
var tokenUsages = llmTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(), session.getExercise().getCourseViaExerciseGroupOrCourseMember(),
statusUpdate.tokens());
if (statusUpdate.tokens() != null) {
var tokenUsages = llmTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(),
session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens());
}
irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import de.tum.cit.aet.artemis.iris.repository.IrisSessionRepository;
import de.tum.cit.aet.artemis.iris.service.IrisMessageService;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.chat.PyrisChatStatusUpdateDTO;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageState;
import de.tum.cit.aet.artemis.iris.service.session.IrisExerciseChatSessionService;
Expand Down Expand Up @@ -131,7 +130,7 @@ void sendOneMessage() throws Exception {
irisRequestMockProvider.mockRunResponse(dto -> {
assertThat(dto.settings().authenticationToken()).isNotNull();

assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null));
assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null));

pipelineDone.set(true);
});
Expand All @@ -156,7 +155,7 @@ void sendSuggestions() throws Exception {

List<String> suggestions = List.of("suggestion1", "suggestion2", "suggestion3");

assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, dto.initialStages(), suggestions, null));
assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), null, dto.initialStages(), suggestions));

pipelineDone.set(true);
});
Expand Down Expand Up @@ -195,15 +194,15 @@ void sendTwoMessages() throws Exception {
irisRequestMockProvider.mockRunResponse(dto -> {
assertThat(dto.settings().authenticationToken()).isNotNull();

assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 1", dto.initialStages(), null, null));
assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 1", dto.initialStages(), null));

pipelineDone.set(true);
});

irisRequestMockProvider.mockRunResponse(dto -> {
assertThat(dto.settings().authenticationToken()).isNotNull();

assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 2", dto.initialStages(), null, null));
assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World 2", dto.initialStages(), null));

pipelineDone.set(true);
});
Expand Down Expand Up @@ -299,7 +298,7 @@ void resendMessage() throws Exception {
irisRequestMockProvider.mockRunResponse(dto -> {
assertThat(dto.settings().authenticationToken()).isNotNull();

assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null));
assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null));

pipelineDone.set(true);
});
Expand All @@ -322,7 +321,7 @@ void sendMessageRateLimitReached() throws Exception {
irisRequestMockProvider.mockRunResponse(dto -> {
assertThat(dto.settings().authenticationToken()).isNotNull();

assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null, null));
assertThatNoException().isThrownBy(() -> sendStatus(dto.settings().authenticationToken(), "Hello World", dto.initialStages(), null));

pipelineDone.set(true);
});
Expand Down Expand Up @@ -445,9 +444,9 @@ public String toString() {
};
}

private void sendStatus(String jobId, String result, List<PyrisStageDTO> stages, List<String> suggestions, List<PyrisLLMCostDTO> tokens) throws Exception {
private void sendStatus(String jobId, String result, List<PyrisStageDTO> stages, List<String> suggestions) throws Exception {
var headers = new HttpHeaders(new LinkedMultiValueMap<>(Map.of("Authorization", List.of("Bearer " + jobId))));
request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, tokens),
request.postWithoutResponseBody("/api/public/pyris/pipelines/tutor-chat/runs/" + jobId + "/status", new PyrisChatStatusUpdateDTO(result, stages, suggestions, null),
HttpStatus.OK, headers);
}
}

0 comments on commit 65fb259

Please sign in to comment.