Skip to content

Commit

Permalink
Update database for cost tracking and trace_id functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjoham committed Oct 12, 2024
1 parent 65fb259 commit 188ff22
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package de.tum.cit.aet.artemis.core.domain;

public enum LLMServiceType {
ATHENA, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION, IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION,
IRIS_CITATION_PIPELINE, NOT_SET
ATHENA_PRELIMINARY_FEEDBACK, ATHENA_FEEDBACK_SUGGESTION, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION,
IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, NOT_SET
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,18 @@ public class LLMTokenUsage extends DomainObject {
@Column(name = "model")
private String model;

@Column(name = "cost_per_token")
private double cost_per_token;

@Column(name = "num_input_tokens")
private int num_input_tokens;

@Column(name = "cost_per_input_token")
private float cost_per_input_token;

@Column(name = "num_output_tokens")
private int num_output_tokens;

@Column(name = "cost_per_output_token")
private float cost_per_output_token;

@Nullable
@ManyToOne
@JsonIgnore
Expand All @@ -66,6 +69,9 @@ public class LLMTokenUsage extends DomainObject {
@Column(name = "timestamp")
private ZonedDateTime timestamp = ZonedDateTime.now();

@Column(name = "trace_id")
private Long traceId;

@Nullable
@ManyToOne
@JsonIgnore
Expand All @@ -88,12 +94,20 @@ public void setModel(String model) {
this.model = model;
}

public double getCost_per_token() {
return cost_per_token;
public float getCost_per_input_token() {
return cost_per_input_token;
}

public void setCost_per_input_token(float cost_per_input_token) {
this.cost_per_input_token = cost_per_input_token;
}

public float getCost_per_output_token() {
return cost_per_output_token;
}

public void setCost_per_token(double cost_per_token) {
this.cost_per_token = cost_per_token;
public void setCost_per_output_token(float cost_per_output_token) {
this.cost_per_output_token = cost_per_output_token;
}

public int getNum_input_tokens() {
Expand Down Expand Up @@ -144,11 +158,26 @@ public void setTimestamp(ZonedDateTime timestamp) {
this.timestamp = timestamp;
}

public Long getTraceId() {
return traceId;
}

public void setTraceId(Long traceId) {
this.traceId = traceId;
}

public IrisMessage getIrisMessage() {
return irisMessage;
}

public void setIrisMessage(IrisMessage message) {
this.irisMessage = message;
}

@Override
public String toString() {
return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", num_input_tokens=" + num_input_tokens + ", cost_per_input_token=" + cost_per_input_token
+ ", num_output_tokens=" + num_output_tokens + ", cost_per_output_token=" + cost_per_output_token + ", course=" + course + ", exercise=" + exercise + ", userId="
+ userId + ", timestamp=" + timestamp + ", trace_id=" + traceId + ", irisMessage=" + irisMessage + '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;
Expand Down Expand Up @@ -44,6 +45,12 @@ public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) {

public List<LLMTokenUsage> saveTokenUsage(IrisMessage message, Exercise exercise, User user, Course course, List<PyrisLLMCostDTO> tokens) {
List<LLMTokenUsage> tokenUsages = new ArrayList<>();

// Combine current time and UUID to create a unique traceId
long timestamp = System.currentTimeMillis();
long uuidComponent = UUID.randomUUID().getLeastSignificantBits() & Long.MAX_VALUE;
Long traceId = timestamp + uuidComponent;

for (PyrisLLMCostDTO cost : tokens) {
LLMTokenUsage llmTokenUsage = new LLMTokenUsage();
if (message != null) {
Expand All @@ -57,8 +64,11 @@ public List<LLMTokenUsage> saveTokenUsage(IrisMessage message, Exercise exercise
}
llmTokenUsage.setCourse(course);
llmTokenUsage.setNum_input_tokens(cost.num_input_tokens());
llmTokenUsage.setCost_per_input_token(cost.cost_per_input_token());
llmTokenUsage.setNum_output_tokens(cost.num_output_tokens());
llmTokenUsage.setCost_per_output_token(cost.cost_per_output_token());
llmTokenUsage.setModel(cost.model_info());
llmTokenUsage.setTraceId(traceId);
tokenUsages.add(llmTokenUsageRepository.save(llmTokenUsage));
}
return tokenUsages;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

import de.tum.cit.aet.artemis.core.domain.LLMServiceType;

public record PyrisLLMCostDTO(String model_info, int num_input_tokens, int num_output_tokens, LLMServiceType pipeline) {
public record PyrisLLMCostDTO(String model_info, int num_input_tokens, float cost_per_input_token, int num_output_tokens, float cost_per_output_token, LLMServiceType pipeline) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@
xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog
http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-4.29.xsd"
objectQuotingStrategy="QUOTE_ONLY_RESERVED_WORDS">
<changeSet id="20241011140701" author="alex.joham">
<changeSet id="20241012125003" author="alex.joham">
<createTable tableName="llm_token_usage">
<column autoIncrement="true" name="id" type="BIGINT">
<constraints nullable="false" primaryKey="true" primaryKeyName="pk_llm_token_usage"/>
</column>
<column name="service" type="VARCHAR(255)"/>
<column name="model" type="VARCHAR(255)"/>
<column name="cost_per_token" type="DOUBLE"/>
<column name="num_input_tokens" type="INT"/>
<column name="cost_per_input_token" type="FLOAT"/>
<column name="num_output_tokens" type="INT"/>
<column name="cost_per_output_token" type="FLOAT"/>
<column name="course_id" type="BIGINT"/>
<column name="exercise_id" type="BIGINT"/>
<column name="user_id" type="BIGINT"/>
<column name="timestamp" type="DATETIME"/>
<column name="trace_id" type="BIGINT"/>
<column name="iris_message_id" type="BIGINT"/>
</createTable>
<addForeignKeyConstraint baseColumnNames="course_id" baseTableName="llm_token_usage"
Expand Down
2 changes: 1 addition & 1 deletion src/main/resources/config/liquibase/master.xml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
<include file="classpath:config/liquibase/changelog/20240626200000_changelog.xml" relativeToChangelogFile="false"/>
<include file="classpath:config/liquibase/changelog/20240708144500_changelog.xml" relativeToChangelogFile="false"/>
<include file="classpath:config/liquibase/changelog/20240802091201_changelog.xml" relativeToChangelogFile="false"/>
<include file="classpath:config/liquibase/changelog/20241011140701_changelog.xml" relativeToChangelogFile="false"/>
<include file="classpath:config/liquibase/changelog/20241012125003_changelog.xml" relativeToChangelogFile="false"/>
<!-- NOTE: please use the format "YYYYMMDDhhmmss_changelog.xml", i.e. year month day hour minutes seconds and not something else! -->
<!-- we should also stay in a chronological order! -->
<!-- you can use the command 'date '+%Y%m%d%H%M%S'' to get the current date and time in the correct format -->
Expand Down

0 comments on commit 188ff22

Please sign in to comment.