Skip to content

Commit

Permalink
Token counting
Browse files Browse the repository at this point in the history
  • Loading branch information
acharneski committed Nov 12, 2023
1 parent 48123ca commit caa5574
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.skyenet.OutputInterceptor
import com.simiacryptus.skyenet.util.AwsUtil.decryptResource
import com.simiacryptus.skyenet.servlet.AuthenticatedWebsite
import com.simiacryptus.skyenet.servlet.UsageServlet
import com.simiacryptus.skyenet.webui.ApplicationBase
import com.simiacryptus.skyenet.servlet.UserInfoServlet
import com.simiacryptus.skyenet.servlet.UserSettingsServlet
Expand Down Expand Up @@ -46,6 +47,7 @@ abstract class AppServerBase(
val welcomeResources = Resource.newResource(javaClass.classLoader.getResource("welcome"))
val userInfoServlet = UserInfoServlet()
val userSettingsServlet = UserSettingsServlet()
val usageServlet = UsageServlet()

protected fun _main(args: Array<String>) {
try {
Expand All @@ -65,6 +67,7 @@ abstract class AppServerBase(
*(arrayOf(
newWebAppContext("/userInfo", userInfoServlet),
newWebAppContext("/userSettings", userSettingsServlet),
newWebAppContext("/usage", usageServlet),
newWebAppContext("/proxy", ProxyHttpServlet()),
authentication.configure(
newWebAppContext(
Expand Down Expand Up @@ -110,6 +113,7 @@ abstract class AppServerBase(
requestURI == "/index.html" -> resp?.writer?.write(homepage().trimIndent())
requestURI.startsWith("/userInfo") -> userInfoServlet.doGet(req!!, resp!!)
requestURI.startsWith("/userSettings") -> userSettingsServlet.doGet(req!!, resp!!)
requestURI.startsWith("/usage") -> usageServlet.doGet(req!!, resp!!)
else -> try {
val inputStream = welcomeResources.addPath(requestURI)?.inputStream
inputStream?.copyTo(resp?.outputStream!!)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,31 @@ class SessionSettingsServlet(
resp.contentType = "text/html"
resp.status = HttpServletResponse.SC_OK
val sessionId = req.getParameter("sessionId")
if (null == sessionId) {
resp.status = HttpServletResponse.SC_BAD_REQUEST
resp.writer.write("Session ID is required")
} else {
if (null != sessionId) {
val settings = server.getSettings<Any>(sessionId)
val json = if(settings != null) JsonUtil.toJson(settings) else ""
//language=HTML
resp.writer.write(
"""
|<html>
|<head>
| <title>Settings</title>
| <link rel="icon" type="image/svg+xml" href="/favicon.svg"/>
|</head>
|<body>
|<form action="${req.contextPath}/settings" method="post">
| <input type="hidden" name="sessionId" value="$sessionId"/>
| <input type="hidden" name="action" value="save"/>
| <textarea name="settings" style="width: 100%; height: 100px;">$json</textarea>
| <input type="submit" value="Save"/>
|</form>
|</body>
|</html>
""".trimMargin()
|<html>
|<head>
| <title>Settings</title>
| <link rel="icon" type="image/svg+xml" href="/favicon.svg"/>
|</head>
|<body>
|<form action="${req.contextPath}/settings" method="post">
| <input type="hidden" name="sessionId" value="$sessionId"/>
| <input type="hidden" name="action" value="save"/>
| <textarea name="settings" style="width: 100%; height: 100px;">$json</textarea>
| <input type="submit" value="Save"/>
|</form>
|</body>
|</html>
""".trimMargin()
)
} else {
resp.status = HttpServletResponse.SC_BAD_REQUEST
resp.writer.write("Session ID is required")
}
}

Expand Down
107 changes: 107 additions & 0 deletions webui/src/main/kotlin/com/simiacryptus/skyenet/servlet/UsageServlet.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package com.simiacryptus.skyenet.servlet

import com.google.api.services.oauth2.model.Userinfo
import com.simiacryptus.openai.OpenAIClient
import jakarta.servlet.http.HttpServlet
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import java.util.concurrent.atomic.AtomicInteger

class UsageServlet : HttpServlet() {
public override fun doGet(req: HttpServletRequest, resp: HttpServletResponse) {
resp.contentType = "text/html"
resp.status = HttpServletResponse.SC_OK

val sessionId = req.getParameter("sessionId")
if (null != sessionId) {
serve(resp, getSessionUsageSummary(sessionId))
} else {
val userinfo = AuthenticatedWebsite.getUser(req)
if (null == userinfo) {
resp.status = HttpServletResponse.SC_BAD_REQUEST
} else {
val usage = getUserUsageSummary(userinfo)
serve(resp, usage)
}
}
}

private fun serve(
resp: HttpServletResponse,
usage: Map<OpenAIClient.Model, Int>
) {
resp.writer.write(
"""
|<html>
|<head>
| <title>Usage</title>
| <link rel="icon" type="image/svg+xml" href="/favicon.svg"/>
|</head>
|<body>
|<table>
| <tr>
| <th>Model</th>
| <th>Usage</th>
| </tr>
| ${
usage.entries.joinToString("\n") { (model, count) ->
"""
|<tr>
| <td>$model</td>
| <td>$count</td>
|</tr>
""".trimMargin()
}
}
|</table>
|</body>
|</html>
""".trimMargin()
)
}

data class UsageCounters(
val tokensPerModel: HashMap<OpenAIClient.Model, AtomicInteger> = HashMap(),
)

companion object {
val log = org.slf4j.LoggerFactory.getLogger(UsageServlet::class.java)

private val usagePerSession = HashMap<String, UsageCounters>()
val sessionsByUser = HashMap<String, ArrayList<String>>()

fun incrementUsage(sessionId: String, userinfo: Userinfo?, model: OpenAIClient.Model, tokens: Int) {
val usage = usagePerSession.getOrPut(sessionId) {
UsageCounters()
}
val tokensPerModel = usage.tokensPerModel.getOrPut(model) {
AtomicInteger()
}
tokensPerModel.addAndGet(tokens)
if (userinfo != null) {
val sessions = sessionsByUser.getOrPut(userinfo.id) {
ArrayList()
}
sessions.add(sessionId)
}
}

fun getUserUsageSummary(userinfo: Userinfo): Map<OpenAIClient.Model, Int> {
val sessions = sessionsByUser[userinfo.id]
return sessions?.flatMap { sessionId ->
val usage = usagePerSession[sessionId]
usage?.tokensPerModel?.entries?.map { (model, counter) ->
model to counter.get()
} ?: emptyList()
}?.groupBy { it.first }?.mapValues { it.value.map { it.second }.sum() } ?: emptyMap()
}

fun getSessionUsageSummary(sessionId: String): Map<OpenAIClient.Model, Int> {
val usage = usagePerSession[sessionId]
return usage?.tokensPerModel?.entries?.map { (model, counter) ->
model to counter.get()
}?.groupBy { it.first }?.mapValues { it.value.map { it.second }.sum() } ?: emptyMap()
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.simiacryptus.skyenet.webui

import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.skyenet.servlet.AuthenticatedWebsite
import com.simiacryptus.skyenet.servlet.UsageServlet.Companion.incrementUsage
import com.simiacryptus.skyenet.servlet.UserSettingsServlet.Companion.getUserSettings
import org.eclipse.jetty.websocket.api.Session
import org.eclipse.jetty.websocket.api.WebSocketAdapter
Expand All @@ -19,20 +20,30 @@ class MessageWebSocket(
val userinfo = AuthenticatedWebsite.users[authId]
val userApi: OpenAIClient? = if (userinfo == null) null else {
val userSettings = getUserSettings(userinfo)
if (userSettings.apiKey.isBlank()) null else OpenAIClient(
if (userSettings.apiKey.isBlank()) null else object : OpenAIClient(
key = userSettings.apiKey,
logLevel = Level.DEBUG,
logStreams = mutableListOf(
sessionDataStorage.getSessionDir(sessionId).resolve("openai.log").outputStream().buffered()
),
)
) {
override fun incrementTokens(model: Model?, tokens: Int) {
incrementUsage(sessionId, userinfo, model!!, tokens)
super.incrementTokens(model, tokens)
}
}
}
return userApi ?: OpenAIClient(
return userApi ?: object : OpenAIClient(
logLevel = Level.DEBUG,
logStreams = mutableListOf(
sessionDataStorage.getSessionDir(sessionId).resolve("openai.log").outputStream().buffered()
)
)
) {
override fun incrementTokens(model: Model?, tokens: Int) {
incrementUsage(sessionId, userinfo, model!!, tokens)
super.incrementTokens(model, tokens)
}
}
}

override fun onWebSocketConnect(session: Session) {
Expand Down
1 change: 1 addition & 0 deletions webui/src/main/resources/simpleSession/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
<a id="history">Session List</a>
<a id="settings">Session Settings</a>
<a id="files">Files</a>
<a href="/usage" id="usage">Usage</a>
</div>

<div id="namebar">
Expand Down
2 changes: 2 additions & 0 deletions webui/src/main/resources/simpleSession/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ document.addEventListener('DOMContentLoaded', () => {

const form = document.getElementById('form');
const messageInput = document.getElementById('message');
const usage = document.getElementById('usage');

form.addEventListener('submit', (event) => {
event.preventDefault();
Expand All @@ -90,6 +91,7 @@ document.addEventListener('DOMContentLoaded', () => {
const sessionId = getSessionId();
if (sessionId) {
connect(sessionId, onWebSocketText);
usage.href = '/usage/?sessionId=' + sessionId;
} else {
connect(undefined, onWebSocketText);
}
Expand Down

0 comments on commit caa5574

Please sign in to comment.