Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Athena: Add LLM token usage tracking #9548

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package de.tum.cit.aet.artemis.athena.dto;

import java.util.List;

import com.fasterxml.jackson.annotation.JsonInclude;

import de.tum.cit.aet.artemis.core.domain.LLMRequest;

/**
* DTO representing the meta information in the Athena response.
* <p>
* Both components are optional: consumers in this code base check the whole meta object as well as
* {@code llmRequests} for {@code null} before using them, and components that are {@code null} are
* left out when this record is serialized (see {@link JsonInclude.Include#NON_NULL}).
*
* @param totalUsage the usage summary sent by Athena; presumably the aggregate over all {@code llmRequests} (TODO confirm against the Athena API)
* @param llmRequests the individual LLM requests reported by Athena; this is the list that gets handed to the token usage service for persistence
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public record ResponseMetaDTO(TotalUsage totalUsage, List<LLMRequest> llmRequests) {

/**
* Usage summary contained in the meta information.
* All components are boxed types, so a missing value is represented as {@code null}.
*
* @param numInputTokens number of input tokens
* @param numOutputTokens number of output tokens
* @param numTotalTokens total number of tokens; presumably input plus output (TODO confirm)
* @param cost the reported cost; currency and unit are not visible here (TODO confirm)
*/
public record TotalUsage(Integer numInputTokens, Integer numOutputTokens, Integer numTotalTokens, Float cost) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,18 @@
import de.tum.cit.aet.artemis.athena.dto.ExerciseBaseDTO;
import de.tum.cit.aet.artemis.athena.dto.ModelingFeedbackDTO;
import de.tum.cit.aet.artemis.athena.dto.ProgrammingFeedbackDTO;
import de.tum.cit.aet.artemis.athena.dto.ResponseMetaDTO;
import de.tum.cit.aet.artemis.athena.dto.SubmissionBaseDTO;
import de.tum.cit.aet.artemis.athena.dto.TextFeedbackDTO;
import de.tum.cit.aet.artemis.core.domain.LLMRequest;
import de.tum.cit.aet.artemis.core.domain.LLMServiceType;
import de.tum.cit.aet.artemis.core.domain.User;
import de.tum.cit.aet.artemis.core.exception.ConflictException;
import de.tum.cit.aet.artemis.core.exception.NetworkingException;
import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService;
import de.tum.cit.aet.artemis.exercise.domain.Exercise;
import de.tum.cit.aet.artemis.exercise.domain.Submission;
import de.tum.cit.aet.artemis.exercise.domain.participation.StudentParticipation;
import de.tum.cit.aet.artemis.modeling.domain.ModelingExercise;
import de.tum.cit.aet.artemis.modeling.domain.ModelingSubmission;
import de.tum.cit.aet.artemis.programming.domain.ProgrammingExercise;
Expand Down Expand Up @@ -48,36 +56,40 @@ public class AthenaFeedbackSuggestionsService {

private final AthenaDTOConverterService athenaDTOConverterService;

private final LLMTokenUsageService llmTokenUsageService;

/**
* Create a new AthenaFeedbackSuggestionsService to receive feedback suggestions from the Athena service.
*
* @param athenaRestTemplate REST template used for the communication with Athena
* @param athenaModuleService Athena module service used to determine the URLs for the different modules
* @param athenaDTOConverterService Service to convert exr
* @param athenaDTOConverterService Service to convert exercises and submissions to DTOs
* @param llmTokenUsageService Service to store the usage of LLM tokens
*/
public AthenaFeedbackSuggestionsService(@Qualifier("athenaRestTemplate") RestTemplate athenaRestTemplate, AthenaModuleService athenaModuleService,
AthenaDTOConverterService athenaDTOConverterService) {
AthenaDTOConverterService athenaDTOConverterService, LLMTokenUsageService llmTokenUsageService) {
textAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOText.class);
programmingAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOProgramming.class);
modelingAthenaConnector = new AthenaConnector<>(athenaRestTemplate, ResponseDTOModeling.class);
this.athenaDTOConverterService = athenaDTOConverterService;
this.athenaModuleService = athenaModuleService;
this.llmTokenUsageService = llmTokenUsageService;
}

/**
* Request body sent to the {@code /feedback_suggestions} endpoint of an Athena module.
*
* @param exercise the exercise, converted to its Athena DTO representation
* @param submission the submission for which suggestions are requested, converted to its Athena DTO representation
* @param isGraded whether graded feedback suggestions are requested; non-graded requests are treated as preliminary feedback when token usage is stored
*/
@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record RequestDTO(ExerciseBaseDTO exercise, SubmissionBaseDTO submission, boolean isGraded) {
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record ResponseDTOText(List<TextFeedbackDTO> data) {
private record ResponseDTOText(List<TextFeedbackDTO> data, ResponseMetaDTO meta) {
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record ResponseDTOProgramming(List<ProgrammingFeedbackDTO> data) {
private record ResponseDTOProgramming(List<ProgrammingFeedbackDTO> data, ResponseMetaDTO meta) {
}

@JsonInclude(JsonInclude.Include.NON_EMPTY)
private record ResponseDTOModeling(List<ModelingFeedbackDTO> data) {
private record ResponseDTOModeling(List<ModelingFeedbackDTO> data, ResponseMetaDTO meta) {
}

/**
Expand All @@ -100,6 +112,7 @@ public List<TextFeedbackDTO> getTextFeedbackSuggestions(TextExercise exercise, T
final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded);
ResponseDTOText response = textAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0);
log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data);
storeTokenUsage(exercise, submission, response.meta, !isGraded);
return response.data.stream().toList();
}

Expand All @@ -117,6 +130,7 @@ public List<ProgrammingFeedbackDTO> getProgrammingFeedbackSuggestions(Programmin
final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded);
ResponseDTOProgramming response = programmingAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0);
log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data);
storeTokenUsage(exercise, submission, response.meta, !isGraded);
return response.data.stream().toList();
}

Expand All @@ -139,6 +153,30 @@ public List<ModelingFeedbackDTO> getModelingFeedbackSuggestions(ModelingExercise
final RequestDTO request = new RequestDTO(athenaDTOConverterService.ofExercise(exercise), athenaDTOConverterService.ofSubmission(exercise.getId(), submission), isGraded);
ResponseDTOModeling response = modelingAthenaConnector.invokeWithRetry(athenaModuleService.getAthenaModuleUrl(exercise) + "/feedback_suggestions", request, 0);
log.info("Athena responded to '{}' feedback suggestions request: {}", isGraded ? "Graded" : "Non Graded", response.data);
storeTokenUsage(exercise, submission, response.meta, !isGraded);
return response.data;
}

/**
* Store the usage of LLM tokens for a given submission.
* <p>
* Nothing is stored when Athena sent no meta information, or when the meta information contains no
* LLM requests (missing or empty list). Storing the usage is bookkeeping only, so a participation
* that is not a {@link StudentParticipation} (or is missing) does not fail the call: the usage is
* then stored without a user id.
*
* @param exercise the exercise the submission belongs to
* @param submission the submission for which the tokens were used
* @param meta the meta information of the response from Athena, may be {@code null}
* @param isPreliminaryFeedback whether the feedback is preliminary or not; currently not evaluated when storing the usage
*/
private void storeTokenUsage(Exercise exercise, Submission submission, ResponseMetaDTO meta, Boolean isPreliminaryFeedback) {
    if (meta == null) {
        return;
    }
    final List<LLMRequest> llmRequests = meta.llmRequests();
    // Check the cheap exit conditions first so that no course/user lookup happens when there is nothing to store.
    if (llmRequests == null || llmRequests.isEmpty()) {
        return;
    }

    final Long courseId = exercise.getCourseViaExerciseGroupOrCourseMember().getId();
    // Avoid a ClassCastException/NullPointerException for participations that are not student participations.
    final Long userId = submission.getParticipation() instanceof StudentParticipation studentParticipation
            ? studentParticipation.getStudent().map(User::getId).orElse(null)
            : null;

    llmTokenUsageService.saveLLMTokenUsage(llmRequests, LLMServiceType.ATHENA,
            (llmTokenUsageBuilder -> llmTokenUsageBuilder.withCourse(courseId).withExercise(exercise.getId()).withUser(userId)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,20 @@

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE;

import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;

import de.tum.cit.aet.artemis.core.domain.Course;
import de.tum.cit.aet.artemis.core.domain.LLMRequest;
import de.tum.cit.aet.artemis.core.domain.LLMServiceType;
import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest;
import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace;
import de.tum.cit.aet.artemis.core.domain.User;
import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRequestRepository;
import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageTraceRepository;
import de.tum.cit.aet.artemis.exercise.domain.Exercise;

/**
* Service for managing the LLMTokenUsage by all LLMs in Artemis
Expand All @@ -38,12 +34,17 @@ public LLMTokenUsageService(LLMTokenUsageTraceRepository llmTokenUsageTraceRepos
}

/**
* method saves the token usage to the database
* Saves the token usage to the database.
* This method records the usage of tokens by various LLM services in the system.
*
* @param llmRequests List of LLM requests
* @param serviceType type of the LLM service
* @param builderFunction of type Function<IrisTokenUsageBuilder, IrisTokenUsageBuilder> using IrisTokenUsageBuilder
* @return saved LLMTokenUsage as a List
* @param llmRequests List of LLM requests containing details about the token usage.
* @param serviceType Type of the LLM service (e.g., IRIS, ATHENA).
* @param builderFunction A function that takes an LLMTokenUsageBuilder and returns a modified LLMTokenUsageBuilder.
* This function is used to set additional properties on the LLMTokenUsageTrace object, such as
* the course ID, user ID, exercise ID, and Iris message ID.
* Example usage:
* builder -> builder.withCourse(courseId).withUser(userId)
* @return The saved LLMTokenUsageTrace object, which includes the details of the token usage.
*/
// TODO: this should ideally be done Async
public LLMTokenUsageTrace saveLLMTokenUsage(List<LLMRequest> llmRequests, LLMServiceType serviceType, Function<LLMTokenUsageBuilder, LLMTokenUsageBuilder> builderFunction) {
Expand All @@ -52,51 +53,49 @@ public LLMTokenUsageTrace saveLLMTokenUsage(List<LLMRequest> llmRequests, LLMSer

LLMTokenUsageBuilder builder = builderFunction.apply(new LLMTokenUsageBuilder());
builder.getIrisMessageID().ifPresent(llmTokenUsageTrace::setIrisMessageId);
builder.getUser().ifPresent(user -> llmTokenUsageTrace.setUserId(user.getId()));
builder.getExercise().ifPresent(exercise -> llmTokenUsageTrace.setExerciseId(exercise.getId()));
builder.getCourse().ifPresent(course -> llmTokenUsageTrace.setCourseId(course.getId()));
builder.getCourseID().ifPresent(llmTokenUsageTrace::setCourseId);
builder.getExerciseID().ifPresent(llmTokenUsageTrace::setExerciseId);
builder.getUserID().ifPresent(llmTokenUsageTrace::setUserId);

llmTokenUsageTrace.setLlmRequests(llmRequests.stream().map(LLMTokenUsageService::convertLLMRequestToLLMTokenUsageRequest)
.peek(llmTokenUsageRequest -> llmTokenUsageRequest.setTrace(llmTokenUsageTrace)).collect(Collectors.toSet()));

Set<LLMTokenUsageRequest> llmRequestsSet = llmTokenUsageTrace.getLLMRequests();
setLLMTokenUsageRequests(llmRequests, llmTokenUsageTrace, llmRequestsSet);
return llmTokenUsageTraceRepository.save(llmTokenUsageTrace);
}

private void setLLMTokenUsageRequests(List<LLMRequest> llmRequests, LLMTokenUsageTrace llmTokenUsageTrace, Set<LLMTokenUsageRequest> llmRequestsSet) {
for (LLMRequest llmRequest : llmRequests) {
LLMTokenUsageRequest llmTokenUsageRequest = new LLMTokenUsageRequest();
llmTokenUsageRequest.setModel(llmRequest.model());
llmTokenUsageRequest.setNumInputTokens(llmRequest.numInputTokens());
llmTokenUsageRequest.setNumOutputTokens(llmRequest.numOutputTokens());
llmTokenUsageRequest.setCostPerMillionInputTokens(llmRequest.costPerMillionInputToken());
llmTokenUsageRequest.setCostPerMillionOutputTokens(llmRequest.costPerMillionOutputToken());
llmTokenUsageRequest.setServicePipelineId(llmRequest.pipelineId());
llmTokenUsageRequest.setTrace(llmTokenUsageTrace);
llmRequestsSet.add(llmTokenUsageRequest);
}
/**
* Maps a single {@link LLMRequest} onto a freshly created {@link LLMTokenUsageRequest}.
* The trace of the returned object is not set here; callers assign it afterwards.
*
* @param llmRequest the reported LLM request to copy the values from
* @return a new token usage request carrying the model, pipeline, token counts and token prices of the given request
*/
private static LLMTokenUsageRequest convertLLMRequestToLLMTokenUsageRequest(LLMRequest llmRequest) {
    final var usageRequest = new LLMTokenUsageRequest();

    // where the tokens were spent
    usageRequest.setServicePipelineId(llmRequest.pipelineId());
    usageRequest.setModel(llmRequest.model());

    // how many tokens were spent
    usageRequest.setNumInputTokens(llmRequest.numInputTokens());
    usageRequest.setNumOutputTokens(llmRequest.numOutputTokens());

    // what the tokens cost
    usageRequest.setCostPerMillionInputTokens(llmRequest.costPerMillionInputToken());
    usageRequest.setCostPerMillionOutputTokens(llmRequest.costPerMillionOutputToken());

    return usageRequest;
}

// TODO: this should ideally be done Async
public void appendRequestsToTrace(List<LLMRequest> requests, LLMTokenUsageTrace trace) {
Set<LLMTokenUsageRequest> llmRequestsSet = new HashSet<>();
setLLMTokenUsageRequests(requests, trace, llmRequestsSet);
llmTokenUsageRequestRepository.saveAll(llmRequestsSet);
var requestSet = requests.stream().map(LLMTokenUsageService::convertLLMRequestToLLMTokenUsageRequest).peek(llmTokenUsageRequest -> llmTokenUsageRequest.setTrace(trace))
.collect(Collectors.toSet());
llmTokenUsageRequestRepository.saveAll(requestSet);
}

/**
* Class LLMTokenUsageBuilder to be used for saveLLMTokenUsage()
*/
public static class LLMTokenUsageBuilder {

private Optional<Course> course = Optional.empty();
private Optional<Long> courseID = Optional.empty();

private Optional<Long> irisMessageID = Optional.empty();

private Optional<Exercise> exercise = Optional.empty();
private Optional<Long> exerciseID = Optional.empty();

private Optional<User> user = Optional.empty();
private Optional<Long> userID = Optional.empty();

public LLMTokenUsageBuilder withCourse(Course course) {
this.course = Optional.ofNullable(course);
public LLMTokenUsageBuilder withCourse(Long courseID) {
this.courseID = Optional.ofNullable(courseID);
return this;
}

Expand All @@ -105,30 +104,30 @@ public LLMTokenUsageBuilder withIrisMessageID(Long irisMessageID) {
return this;
}

public LLMTokenUsageBuilder withExercise(Exercise exercise) {
this.exercise = Optional.ofNullable(exercise);
public LLMTokenUsageBuilder withExercise(Long exerciseID) {
this.exerciseID = Optional.ofNullable(exerciseID);
return this;
}

public LLMTokenUsageBuilder withUser(User user) {
this.user = Optional.ofNullable(user);
public LLMTokenUsageBuilder withUser(Long userID) {
this.userID = Optional.ofNullable(userID);
return this;
}

public Optional<Course> getCourse() {
return course;
public Optional<Long> getCourseID() {
return courseID;
}

public Optional<Long> getIrisMessageID() {
return irisMessageID;
}

public Optional<Exercise> getExercise() {
return exercise;
public Optional<Long> getExerciseID() {
return exerciseID;
}

public Optional<User> getUser() {
return user;
public Optional<Long> getUserID() {
return userID;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import de.tum.cit.aet.artemis.core.domain.LLMServiceType;
import de.tum.cit.aet.artemis.core.domain.User;
import de.tum.cit.aet.artemis.core.repository.CourseRepository;
import de.tum.cit.aet.artemis.core.repository.UserRepository;
import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService;
import de.tum.cit.aet.artemis.iris.service.pyris.PyrisJobService;
import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService;
Expand All @@ -36,13 +37,16 @@ public class IrisCompetencyGenerationService {

private final PyrisJobService pyrisJobService;

private final UserRepository userRepository;

public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, CourseRepository courseRepository,
IrisWebsocketService websocketService, PyrisJobService pyrisJobService) {
IrisWebsocketService websocketService, PyrisJobService pyrisJobService, UserRepository userRepository) {
this.pyrisPipelineService = pyrisPipelineService;
this.llmTokenUsageService = llmTokenUsageService;
this.courseRepository = courseRepository;
this.websocketService = websocketService;
this.pyrisJobService = pyrisJobService;
this.userRepository = userRepository;
}

/**
Expand All @@ -58,7 +62,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String
pyrisPipelineService.executePipeline(
"competency-extraction",
"default",
pyrisJobService.createTokenForJob(token -> new CompetencyExtractionJob(token, course.getId(), user)),
pyrisJobService.createTokenForJob(token -> new CompetencyExtractionJob(token, course.getId(), user.getId())),
executionDto -> new PyrisCompetencyExtractionPipelineExecutionDTO(executionDto, courseDescription, currentCompetencies, CompetencyTaxonomy.values(), 5),
stages -> websocketService.send(user.getLogin(), websocketTopic(course.getId()), new PyrisCompetencyStatusUpdateDTO(stages, null, null))
);
Expand All @@ -74,9 +78,11 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String
public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) {
Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId());
if (statusUpdate.tokens() != null && !statusUpdate.tokens().isEmpty()) {
llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withCourse(course).withUser(job.user()));
llmTokenUsageService.saveLLMTokenUsage(statusUpdate.tokens(), LLMServiceType.IRIS, builder -> builder.withCourse(course.getId()).withUser(job.userId()));
}
websocketService.send(job.user().getLogin(), websocketTopic(job.courseId()), statusUpdate);

var user = userRepository.findById(job.userId()).orElseThrow();
websocketService.send(user.getLogin(), websocketTopic(job.courseId()), statusUpdate);
}

private static String websocketTopic(long courseId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
import com.fasterxml.jackson.annotation.JsonInclude;

import de.tum.cit.aet.artemis.core.domain.Course;
import de.tum.cit.aet.artemis.core.domain.User;

/**
* A pyris job that extracts competencies from a course description.
*
* @param jobId the job id
* @param courseId the course in which the competencies are being extracted
* @param user the user who started the job
* @param userId the user who started the job
*/
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public record CompetencyExtractionJob(String jobId, long courseId, User user) implements PyrisJob {
public record CompetencyExtractionJob(String jobId, long courseId, long userId) implements PyrisJob {

@Override
public boolean canAccess(Course course) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* This job is used to reference the details of a course chat session when Pyris sends a status update.
*/
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public record CourseChatJob(String jobId, long courseId, long sessionId) implements PyrisJob {
public record CourseChatJob(String jobId, long courseId, long sessionId) implements SessionBasedPyrisJob {

@Override
public boolean canAccess(Course course) {
Expand Down
Loading
Loading