Skip to content

Commit

Permalink
Update database, add information to competency gen, change traceId calc
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjoham committed Oct 12, 2024
1 parent 188ff22 commit be85a3b
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

public enum LLMServiceType {
ATHENA_PRELIMINARY_FEEDBACK, ATHENA_FEEDBACK_SUGGESTION, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION,
IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, NOT_SET
IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, IRIS_LECTURE_RETRIEVAL_PIPELINE, IRIS_LECTURE_INGESTION, NOT_SET
}
73 changes: 34 additions & 39 deletions src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import jakarta.persistence.Entity;
import jakarta.persistence.EnumType;
import jakarta.persistence.Enumerated;
import jakarta.persistence.Inheritance;
import jakarta.persistence.InheritanceType;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.Table;
Expand All @@ -18,16 +16,13 @@

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonTypeInfo;

import de.tum.cit.aet.artemis.exercise.domain.Exercise;
import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage;

@Entity
@Table(name = "llm_token_usage")
@Inheritance(strategy = InheritanceType.SINGLE_TABLE)
@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE)
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public class LLMTokenUsage extends DomainObject {

Expand All @@ -39,16 +34,16 @@ public class LLMTokenUsage extends DomainObject {
private String model;

@Column(name = "num_input_tokens")
private int num_input_tokens;
private int numInputTokens;

@Column(name = "cost_per_input_token")
private float cost_per_input_token;
@Column(name = "cost_per_million_input_tokens")
private float costPerMillionInputTokens;

@Column(name = "num_output_tokens")
private int num_output_tokens;
private int numOutputTokens;

@Column(name = "cost_per_output_token")
private float cost_per_output_token;
@Column(name = "cost_per_million_output_tokens")
private float costPerMillionOutputTokens;

@Nullable
@ManyToOne
Expand All @@ -66,11 +61,11 @@ public class LLMTokenUsage extends DomainObject {
private long userId;

@Nullable
@Column(name = "timestamp")
private ZonedDateTime timestamp = ZonedDateTime.now();
@Column(name = "time")
private ZonedDateTime time = ZonedDateTime.now();

@Column(name = "trace_id")
private Long traceId;
private String traceId;

@Nullable
@ManyToOne
Expand All @@ -94,36 +89,36 @@ public void setModel(String model) {
this.model = model;
}

public float getCost_per_input_token() {
return cost_per_input_token;
public float getCostPerMillionInputTokens() {
return costPerMillionInputTokens;
}

public void setCost_per_input_token(float cost_per_input_token) {
this.cost_per_input_token = cost_per_input_token;
public void setCostPerMillionInputTokens(float costPerMillionInputToken) {
this.costPerMillionInputTokens = costPerMillionInputToken;
}

public float getCost_per_output_token() {
return cost_per_output_token;
public float getCostPerMillionOutputTokens() {
return costPerMillionOutputTokens;
}

public void setCost_per_output_token(float cost_per_output_token) {
this.cost_per_output_token = cost_per_output_token;
public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) {
this.costPerMillionOutputTokens = costPerMillionOutputToken;
}

public int getNum_input_tokens() {
return num_input_tokens;
public int getNumInputTokens() {
return numInputTokens;
}

public void setNum_input_tokens(int num_input_tokens) {
this.num_input_tokens = num_input_tokens;
public void setNumInputTokens(int numInputTokens) {
this.numInputTokens = numInputTokens;
}

public int getNum_output_tokens() {
return num_output_tokens;
public int getNumOutputTokens() {
return numOutputTokens;
}

public void setNum_output_tokens(int num_output_tokens) {
this.num_output_tokens = num_output_tokens;
public void setNumOutputTokens(int numOutputTokens) {
this.numOutputTokens = numOutputTokens;
}

public Course getCourse() {
Expand All @@ -150,19 +145,19 @@ public void setUserId(long userId) {
this.userId = userId;
}

public ZonedDateTime getTimestamp() {
return timestamp;
public ZonedDateTime getTime() {
return time;
}

public void setTimestamp(ZonedDateTime timestamp) {
this.timestamp = timestamp;
public void setTime(ZonedDateTime time) {
this.time = time;
}

public Long getTraceId() {
public String getTraceId() {
return traceId;
}

public void setTraceId(Long traceId) {
public void setTraceId(String traceId) {
this.traceId = traceId;
}

Expand All @@ -176,8 +171,8 @@ public void setIrisMessage(IrisMessage message) {

@Override
public String toString() {
return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", num_input_tokens=" + num_input_tokens + ", cost_per_input_token=" + cost_per_input_token
+ ", num_output_tokens=" + num_output_tokens + ", cost_per_output_token=" + cost_per_output_token + ", course=" + course + ", exercise=" + exercise + ", userId="
+ userId + ", timestamp=" + timestamp + ", trace_id=" + traceId + ", irisMessage=" + irisMessage + '}';
return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", num_input_tokens=" + numInputTokens + ", cost_per_input_token=" + costPerMillionInputTokens
+ ", num_output_tokens=" + numOutputTokens + ", cost_per_output_token=" + costPerMillionOutputTokens + ", course=" + course + ", exercise=" + exercise + ", userId="
+ userId + ", timestamp=" + time + ", trace_id=" + traceId + ", irisMessage=" + irisMessage + '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;
Expand All @@ -16,6 +15,7 @@
import de.tum.cit.aet.artemis.exercise.domain.Exercise;
import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO;
import de.tum.cit.aet.artemis.iris.service.pyris.job.PyrisJob;

/**
* Service for managing the LLMTokenUsage by all LLMs in Artemis
Expand All @@ -31,44 +31,38 @@ public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) {
}

/**
* saves the tokens used for a specific IrisMessage or Athena call
* in case of an Athena call IrisMessage can be null and the
* LLMServiceType in tokens has to be Athena
* method saves the token usage to the database with a link to the IrisMessage
* messages of the same job are grouped together by saving the job id as a trace id
*
* @param message IrisMessage related to the TokenUsage
* @param exercise Exercise in which the request was made
* @param user User that made the request
* @param course Course in which the request was made
* @param tokens List with Tokens of the PyrisLLMCostDTO Model
* @return List of the created LLMTokenUsage entries
* @param job used to create a unique traceId to group multiple LLM calls
* @param message IrisMessage to map the usage to an IrisMessage
* @param exercise to map the token cost to an exercise
* @param user to map the token cost to a user
* @param course to map the token cost to a course
* @param tokens token cost list of type PyrisLLMCostDTO
* @return list of the saved data
*/

public List<LLMTokenUsage> saveTokenUsage(IrisMessage message, Exercise exercise, User user, Course course, List<PyrisLLMCostDTO> tokens) {
public List<LLMTokenUsage> saveIrisTokenUsage(PyrisJob job, IrisMessage message, Exercise exercise, User user, Course course, List<PyrisLLMCostDTO> tokens) {
List<LLMTokenUsage> tokenUsages = new ArrayList<>();

// Combine current time and UUID to create a unique traceId
long timestamp = System.currentTimeMillis();
long uuidComponent = UUID.randomUUID().getLeastSignificantBits() & Long.MAX_VALUE;
Long traceId = timestamp + uuidComponent;

for (PyrisLLMCostDTO cost : tokens) {
LLMTokenUsage llmTokenUsage = new LLMTokenUsage();
if (message != null) {
llmTokenUsage.setIrisMessage(message);
llmTokenUsage.setTimestamp(message.getSentAt());
llmTokenUsage.setTime(message.getSentAt());
}
llmTokenUsage.setServiceType(cost.pipeline());
llmTokenUsage.setExercise(exercise);
if (user != null) {
llmTokenUsage.setUserId(user.getId());
}
llmTokenUsage.setCourse(course);
llmTokenUsage.setNum_input_tokens(cost.num_input_tokens());
llmTokenUsage.setCost_per_input_token(cost.cost_per_input_token());
llmTokenUsage.setNum_output_tokens(cost.num_output_tokens());
llmTokenUsage.setCost_per_output_token(cost.cost_per_output_token());
llmTokenUsage.setModel(cost.model_info());
llmTokenUsage.setTraceId(traceId);
llmTokenUsage.setNumInputTokens(cost.numInputTokens());
llmTokenUsage.setCostPerMillionInputTokens(cost.costPerInputToken());
llmTokenUsage.setNumOutputTokens(cost.numOutputTokens());
llmTokenUsage.setCostPerMillionOutputTokens(cost.costPerOutputToken());
llmTokenUsage.setModel(cost.modelInfo());
llmTokenUsage.setTraceId(job.jobId());
tokenUsages.add(llmTokenUsageRepository.save(llmTokenUsage));
}
return tokenUsages;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import de.tum.cit.aet.artemis.atlas.domain.competency.CompetencyTaxonomy;
import de.tum.cit.aet.artemis.core.domain.Course;
import de.tum.cit.aet.artemis.core.domain.User;
import de.tum.cit.aet.artemis.core.repository.CourseRepository;
import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService;
import de.tum.cit.aet.artemis.iris.service.pyris.PyrisJobService;
import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService;
Expand All @@ -28,13 +29,16 @@ public class IrisCompetencyGenerationService {

private final LLMTokenUsageService llmTokenUsageService;

private final CourseRepository courseRepository;

private final IrisWebsocketService websocketService;

private final PyrisJobService pyrisJobService;

public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, IrisWebsocketService websocketService,
PyrisJobService pyrisJobService) {
public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, CourseRepository courseRepository,
IrisWebsocketService websocketService, PyrisJobService pyrisJobService) {
this.pyrisPipelineService = pyrisPipelineService;
this.courseRepository = courseRepository;
this.websocketService = websocketService;
this.pyrisJobService = pyrisJobService;
this.llmTokenUsageService = llmTokenUsageService;
Expand Down Expand Up @@ -63,15 +67,15 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String
/**
* Takes a status update from Pyris containing a new competency extraction result and sends it to the client via websocket
*
* @param userLogin the login of the user
* @param courseId the id of the course
* @param job Job related to the status update
* @param statusUpdate the status update containing the new competency recommendations
*/
public void handleStatusUpdate(String userLogin, long courseId, PyrisCompetencyStatusUpdateDTO statusUpdate) {
public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) {
Course course = courseRepository.findByIdForUpdateElseThrow(job.courseId());
if (statusUpdate.tokens() != null) {
var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens());
llmTokenUsageService.saveIrisTokenUsage(job, null, null, null, course, statusUpdate.tokens());
}
websocketService.send(userLogin, websocketTopic(courseId), statusUpdate);
websocketService.send(job.userLogin(), websocketTopic(job.courseId()), statusUpdate);
}

private static String websocketTopic(long courseId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu

/**
* Handles the status update of a competency extraction job and forwards it to
* {@link IrisCompetencyGenerationService#handleStatusUpdate(String, long, PyrisCompetencyStatusUpdateDTO)}
* {@link IrisCompetencyGenerationService#handleStatusUpdate(CompetencyExtractionJob, PyrisCompetencyStatusUpdateDTO)}
*
* @param job the job that is updated
* @param statusUpdate the status update
*/
public void handleStatusUpdate(CompetencyExtractionJob job, PyrisCompetencyStatusUpdateDTO statusUpdate) {
competencyGenerationService.handleStatusUpdate(job.userLogin(), job.courseId(), statusUpdate);
competencyGenerationService.handleStatusUpdate(job, statusUpdate);

removeJobIfTerminated(statusUpdate.stages(), job.jobId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

import de.tum.cit.aet.artemis.core.domain.LLMServiceType;

public record PyrisLLMCostDTO(String model_info, int num_input_tokens, float cost_per_input_token, int num_output_tokens, float cost_per_output_token, LLMServiceType pipeline) {
public record PyrisLLMCostDTO(String modelInfo, int numInputTokens, float costPerInputToken, int numOutputTokens, float costPerOutputToken, LLMServiceType pipeline) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,11 @@ public void handleStatusUpdate(CourseChatJob job, PyrisChatStatusUpdateDTO statu
var message = new IrisMessage();
message.addContent(new IrisTextMessageContent(statusUpdate.result()));
var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM);
var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, null, session.getUser(), session.getCourse(), statusUpdate.tokens());
var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, null, session.getUser(), session.getCourse(), statusUpdate.tokens());
irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages());
}
else {
var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, session.getUser(), session.getCourse(), statusUpdate.tokens());
var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, null, null, session.getUser(), session.getCourse(), statusUpdate.tokens());
irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens());
}
updateLatestSuggestions(session, statusUpdate.suggestions());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,14 @@ public class IrisExerciseChatSessionService extends AbstractIrisChatSessionServi

private final ProgrammingExerciseRepository programmingExerciseRepository;

public IrisExerciseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService LLMTokenUsageService, IrisSettingsService irisSettingsService,
public IrisExerciseChatSessionService(IrisMessageService irisMessageService, LLMTokenUsageService llmTokenUsageService, IrisSettingsService irisSettingsService,
IrisChatWebsocketService irisChatWebsocketService, AuthorizationCheckService authCheckService, IrisSessionRepository irisSessionRepository,
ProgrammingExerciseStudentParticipationRepository programmingExerciseStudentParticipationRepository, ProgrammingSubmissionRepository programmingSubmissionRepository,
IrisRateLimitService rateLimitService, PyrisPipelineService pyrisPipelineService, ProgrammingExerciseRepository programmingExerciseRepository,
ObjectMapper objectMapper) {
super(irisSessionRepository, objectMapper);
this.irisMessageService = irisMessageService;
this.llmTokenUsageService = LLMTokenUsageService;
this.llmTokenUsageService = llmTokenUsageService;
this.irisSettingsService = irisSettingsService;
this.irisChatWebsocketService = irisChatWebsocketService;
this.authCheckService = authCheckService;
Expand Down Expand Up @@ -177,14 +177,14 @@ public void handleStatusUpdate(ExerciseChatJob job, PyrisChatStatusUpdateDTO sta
message.addContent(new IrisTextMessageContent(statusUpdate.result()));
var savedMessage = irisMessageService.saveMessage(message, session, IrisMessageSender.LLM);
if (statusUpdate.tokens() != null) {
var tokenUsages = llmTokenUsageService.saveTokenUsage(savedMessage, session.getExercise(), session.getUser(),
var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, savedMessage, session.getExercise(), session.getUser(),
session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens());
}
irisChatWebsocketService.sendMessage(session, savedMessage, statusUpdate.stages());
}
else {
if (statusUpdate.tokens() != null) {
var tokenUsages = llmTokenUsageService.saveTokenUsage(null, session.getExercise(), session.getUser(),
var tokenUsages = llmTokenUsageService.saveIrisTokenUsage(job, null, session.getExercise(), session.getUser(),
session.getExercise().getCourseViaExerciseGroupOrCourseMember(), statusUpdate.tokens());
}
irisChatWebsocketService.sendStatusUpdate(session, statusUpdate.stages(), statusUpdate.suggestions(), statusUpdate.tokens());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@
xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog
http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-4.29.xsd"
objectQuotingStrategy="QUOTE_ONLY_RESERVED_WORDS">
<changeSet id="20241012125003" author="alex.joham">
<changeSet id="20241012080932" author="alex.joham">
<createTable tableName="llm_token_usage">
<column autoIncrement="true" name="id" type="BIGINT">
<constraints nullable="false" primaryKey="true" primaryKeyName="pk_llm_token_usage"/>
</column>
<column name="service" type="VARCHAR(255)"/>
<column name="model" type="VARCHAR(255)"/>
<column name="num_input_tokens" type="INT"/>
<column name="cost_per_input_token" type="FLOAT"/>
<column name="cost_per_million_input_tokens" type="FLOAT"/>
<column name="num_output_tokens" type="INT"/>
<column name="cost_per_output_token" type="FLOAT"/>
<column name="cost_per_million_output_tokens" type="FLOAT"/>
<column name="course_id" type="BIGINT"/>
<column name="exercise_id" type="BIGINT"/>
<column name="user_id" type="BIGINT"/>
<column name="timestamp" type="DATETIME"/>
<column name="trace_id" type="BIGINT"/>
<column name="time" type="DATETIME"/>
<column name="trace_id" type="VARCHAR(255)"/>
<column name="iris_message_id" type="BIGINT"/>
</createTable>
<addForeignKeyConstraint baseColumnNames="course_id" baseTableName="llm_token_usage"
Expand Down
Loading

0 comments on commit be85a3b

Please sign in to comment.