Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
acharneski committed Nov 22, 2023
1 parent 282e116 commit 52019b2
Show file tree
Hide file tree
Showing 25 changed files with 134 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,10 @@ object ApplicationServices {
require(!isLocked) { "ApplicationServices is locked" }
field = value
}
var clientManager: ClientManager = ClientManager()
set(value) {
require(!isLocked) { "ApplicationServices is locked" }
field = value
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.simiacryptus.skyenet.platform

import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.openai.OpenAIClientBase
import com.simiacryptus.openai.models.OpenAIModel
import org.slf4j.event.Level
import java.io.File

open class ClientManager {

open fun createClient(
session: Session,
user: User?,
dataStorage: DataStorage,
): OpenAIClient {
if (user != null) {
val userSettings = ApplicationServices.userSettingsManager.getUserSettings(user)
val logfile = dataStorage.getSessionDir(user, session).resolve(".sys/openai.log")
logfile.parentFile?.mkdirs()
val userApi = createClient(session, user, logfile, userSettings.apiKey)
if (userApi != null) return userApi
}
val canUseGlobalKey = ApplicationServices.authorizationManager.isAuthorized(null, user,
AuthorizationManager.OperationType.GlobalKey
)
if (!canUseGlobalKey) throw RuntimeException("No API key")
val logfile = dataStorage.getSessionDir(user, session).resolve(".sys/openai.log")
logfile.parentFile?.mkdirs()
return createClient(session, user, logfile)!!
}

open protected fun createClient(
session: Session,
user: User?,
logfile: File,
key: String? = OpenAIClientBase.keyTxt
): OpenAIClient? = if (key.isNullOrBlank()) null else object : OpenAIClient(
key = key,
logLevel = Level.DEBUG,
logStreams = mutableListOf(
logfile?.outputStream()?.buffered()
).filterNotNull().toMutableList(),
) {
override fun incrementTokens(model: OpenAIModel?, tokens: Usage) {
ApplicationServices.usageManager.incrementUsage(session, user, model!!, tokens)
super.incrementTokens(model, tokens)
}
}

private val log = org.slf4j.LoggerFactory.getLogger(ClientManager::class.java)
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package com.simiacryptus.skyenet
package com.simiacryptus.skyenet.application


import com.simiacryptus.openai.OpenAIClientBase
import com.simiacryptus.skyenet.OutputInterceptor
import com.simiacryptus.skyenet.chat.ChatServer
import com.simiacryptus.skyenet.platform.ApplicationServices
import com.simiacryptus.skyenet.servlet.*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.simiacryptus.skyenet.session
package com.simiacryptus.skyenet.application

import com.simiacryptus.skyenet.session.SessionMessage
import java.util.function.Consumer

class ApplicationInterface(private val inner: ApplicationSocketManager) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.simiacryptus.skyenet
package com.simiacryptus.skyenet.application

import com.simiacryptus.openai.OpenAIAPI
import com.simiacryptus.skyenet.servlet.AppInfoServlet
Expand All @@ -13,8 +13,6 @@ import com.simiacryptus.skyenet.platform.AuthorizationManager
import com.simiacryptus.skyenet.platform.DataStorage
import com.simiacryptus.skyenet.platform.Session
import com.simiacryptus.skyenet.platform.User
import com.simiacryptus.skyenet.session.ApplicationInterface
import com.simiacryptus.skyenet.session.ApplicationSocketManager
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.eclipse.jetty.servlet.FilterHolder
Expand All @@ -23,7 +21,7 @@ import org.eclipse.jetty.webapp.WebAppContext
import org.slf4j.LoggerFactory
import java.io.File

abstract class ApplicationBase(
abstract class ApplicationServer(
final override val applicationName: String,
resourceBase: String = "simpleSession",
val temperature: Double = 0.1,
Expand All @@ -40,17 +38,17 @@ abstract class ApplicationBase(
override fun newSession(user: User?, session: Session): SocketManager {
return object : ApplicationSocketManager(
session = session,
userId = user,
user = user,
dataStorage = dataStorage,
applicationClass = this@ApplicationBase::class.java,
applicationClass = this@ApplicationServer::class.java,
) {
override fun newSession(
session: Session,
user: User?,
userMessage: String,
socketManager: ApplicationSocketManager,
api: OpenAIAPI
) = this@ApplicationBase.newSession(
) = this@ApplicationServer.newSession(
session = session,
user = user,
userMessage = userMessage,
Expand Down Expand Up @@ -93,7 +91,7 @@ abstract class ApplicationBase(
FilterHolder { request, response, chain ->
val user = authenticationManager.getUser((request as HttpServletRequest).getCookie())
val canRead = authorizationManager.isAuthorized(
applicationClass = this@ApplicationBase.javaClass,
applicationClass = this@ApplicationServer.javaClass,
user = user,
operationType = AuthorizationManager.OperationType.Read
)
Expand All @@ -116,7 +114,7 @@ abstract class ApplicationBase(
}

companion object {
private val log = LoggerFactory.getLogger(ApplicationBase::class.java)
private val log = LoggerFactory.getLogger(ApplicationServer::class.java)
val spinner =
"""<div class="spinner-border" role="status"><span class="sr-only">Loading...</span></div>"""

Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
package com.simiacryptus.skyenet.session
package com.simiacryptus.skyenet.application

import com.simiacryptus.openai.OpenAIAPI
import com.simiacryptus.skyenet.ApplicationBase
import com.simiacryptus.skyenet.chat.ChatSocket
import com.simiacryptus.skyenet.platform.DataStorage
import com.simiacryptus.skyenet.platform.Session
import com.simiacryptus.skyenet.platform.User
import com.simiacryptus.skyenet.platform.*
import com.simiacryptus.skyenet.session.SocketManagerBase
import java.util.function.Consumer

abstract class ApplicationSocketManager(
session: Session,
userId: User?,
user: User?,
dataStorage: DataStorage?,
applicationClass: Class<*>,
) : SocketManagerBase(
session = session,
dataStorage = dataStorage,
userId = userId,
user = user,
applicationClass = applicationClass,
) {
private val threads = mutableMapOf<String, Thread>()
Expand All @@ -26,7 +24,13 @@ abstract class ApplicationSocketManager(
override fun onRun(userMessage: String, socket: ChatSocket) {
val operationID = randomID()
threads[operationID] = Thread.currentThread()
newSession(session, user = userId, userMessage, this, socket.api)
newSession(
session, user = user, userMessage, this, ApplicationServices.clientManager.createClient(
session,
user,
dataStorage ?: throw IllegalStateException("No data storage")
)
)
}

val applicationInterface by lazy { ApplicationInterface(this) }
Expand Down Expand Up @@ -68,7 +72,7 @@ abstract class ApplicationSocketManager(
)

companion object {
val spinner: String get() = """<div>${ApplicationBase.spinner}</div>"""
val spinner: String get() = """<div>${ApplicationServer.spinner}</div>"""
// val playButton: String get() = """<button class="play-button" data-id="$operationID">▶</button>"""
// val cancelButton: String get() = """<button class="cancel-button" data-id="$operationID">&times;</button>"""
// val regenButton: String get() = """<button class="regen-button" data-id="$operationID">♲</button>"""
Expand Down
28 changes: 11 additions & 17 deletions webui/src/main/kotlin/com/simiacryptus/skyenet/chat/ChatServer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,26 @@ abstract class ChatServer(val resourceBase: String) {
override fun configure(factory: JettyWebSocketServletFactory) {
factory.setCreator { req, resp ->
try {
val authId = req.getCookie(AUTH_COOKIE)
return@setCreator if (!req.parameterMap.containsKey("sessionId")) {
null
throw IllegalArgumentException("sessionId is required")
} else {
val session = Session(req.parameterMap["sessionId"]?.first()!!)
val sessionState: SocketManager = getSession(session, req)
val user = authenticationManager.getUser(authId)
ChatSocket(session, sessionState, dataStorage, user)
ChatSocket(
if (stateCache.containsKey(session)) {
stateCache[session]!!
} else {
val user = authenticationManager.getUser(req.getCookie(AUTH_COOKIE))
val sessionState = newSession(user, session)
stateCache[session] = sessionState
sessionState
}
)
}
} catch (e: Exception) {
log.warn("Error configuring websocket", e)
}
}
}

private fun getSession(
session: Session,
req: JettyServerUpgradeRequest
) = if (stateCache.containsKey(session)) {
stateCache[session]!!
} else {
val user = authenticationManager.getUser(req.getCookie(AUTH_COOKIE))
val sessionState = newSession(user, session)
stateCache[session] = sessionState
sessionState
}
}

abstract fun newSession(user: User?, session: Session): SocketManager
Expand Down
57 changes: 2 additions & 55 deletions webui/src/main/kotlin/com/simiacryptus/skyenet/chat/ChatSocket.kt
Original file line number Diff line number Diff line change
@@ -1,68 +1,13 @@
package com.simiacryptus.skyenet.chat

import com.simiacryptus.openai.models.OpenAIModel
import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.skyenet.platform.*
import com.simiacryptus.skyenet.platform.ApplicationServices.authorizationManager
import com.simiacryptus.skyenet.session.SocketManager
import com.simiacryptus.skyenet.platform.AuthorizationManager.OperationType.GlobalKey
import org.eclipse.jetty.websocket.api.Session
import org.eclipse.jetty.websocket.api.WebSocketAdapter
import org.slf4j.event.Level

class ChatSocket(
private val session: com.simiacryptus.skyenet.platform.Session,
private val sessionState: SocketManager,
private val dataStorage: DataStorage?,
private val user: User?,
) : WebSocketAdapter() {

val logfile by lazy {
val file = dataStorage?.getSessionDir(user, session)?.resolve(".sys/openai.log")
file?.parentFile?.mkdirs()
file
}

val api: OpenAIClient
get() {
val user = user
val userApi = userApi
if (userApi != null) return userApi
val canUseGlobalKey = authorizationManager.isAuthorized(null, user, GlobalKey)
if (!canUseGlobalKey) throw RuntimeException("No API key")
return object : OpenAIClient(
logLevel = Level.DEBUG,
logStreams = mutableListOf(
logfile?.outputStream()?.buffered()
).filterNotNull().toMutableList()
) {
override fun incrementTokens(model: OpenAIModel?, tokens: Usage) {
if(null != model) ApplicationServices.usageManager.incrementUsage(session, user, model, tokens)
super.incrementTokens(model, tokens)
}
}
}

private val userApi: OpenAIClient?
get() {
val user = user
val userSettings = if (user == null) null else ApplicationServices.userSettingsManager.getUserSettings(user)
return if (userSettings == null) null else {
if (userSettings.apiKey.isBlank()) null else object : OpenAIClient(
key = userSettings.apiKey,
logLevel = Level.DEBUG,
logStreams = mutableListOf(
logfile?.outputStream()?.buffered()
).filterNotNull().toMutableList(),
) {
override fun incrementTokens(model: OpenAIModel?, tokens: Usage) {
ApplicationServices.usageManager.incrementUsage(session, user, model!!, tokens)
super.incrementTokens(model, tokens)
}
}
}
}


override fun onWebSocketConnect(session: Session) {
super.onWebSocketConnect(session)
Expand Down Expand Up @@ -91,3 +36,5 @@ class ChatSocket(
private val log = org.slf4j.LoggerFactory.getLogger(ChatSocket::class.java)
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.openai.OpenAIClientBase.Companion.toContentList
import com.simiacryptus.openai.models.ChatModels
import com.simiacryptus.openai.models.OpenAITextModel
import com.simiacryptus.skyenet.ApplicationBase
import com.simiacryptus.skyenet.application.ApplicationServer
import com.simiacryptus.skyenet.platform.Session
import com.simiacryptus.skyenet.session.SocketManagerBase
import com.simiacryptus.skyenet.util.MarkdownUtil
Expand All @@ -18,8 +18,8 @@ open class ChatSocketManager(
val systemPrompt: String,
val api: OpenAIClient,
val temperature: Double = 0.3,
applicationClass: Class<out ApplicationBase>,
) : SocketManagerBase(session, parent.dataStorage, userId = null, applicationClass = applicationClass) {
applicationClass: Class<out ApplicationServer>,
) : SocketManagerBase(session, parent.dataStorage, user = null, applicationClass = applicationClass) {

init {
if (userInterfacePrompt.isNotBlank()) {
Expand All @@ -39,7 +39,7 @@ open class ChatSocketManager(
override fun onRun(userMessage: String, socket: ChatSocket) {
var responseContents = divInitializer(cancelable = false)
responseContents += """<div class="user-message">${renderResponse(userMessage)}</div>"""
send("""$responseContents<div class="chat-response">${ApplicationBase.spinner}</div>""")
send("""$responseContents<div class="chat-response">${ApplicationServer.spinner}</div>""")
val response = handleMessage(userMessage, responseContents)
if(null != response) {
responseContents += """<div class="chat-response">${renderResponse(response)}</div>"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.simiacryptus.skyenet.chat
import com.simiacryptus.openai.OpenAIClient
import com.simiacryptus.openai.models.ChatModels
import com.simiacryptus.openai.models.OpenAITextModel
import com.simiacryptus.skyenet.ApplicationBase
import com.simiacryptus.skyenet.application.ApplicationServer
import com.simiacryptus.skyenet.platform.Session
import com.simiacryptus.skyenet.platform.User
import com.simiacryptus.skyenet.servlet.AppInfoServlet
Expand Down Expand Up @@ -47,7 +47,7 @@ class CodeChatServer(
|
|Responses may use markdown formatting.
""".trimMargin(),
applicationClass = ApplicationBase::class.java,
applicationClass = ApplicationServer::class.java,
) {
override fun canWrite(user: User?): Boolean = true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport
import com.google.api.client.json.gson.GsonFactory
import com.google.api.services.oauth2.Oauth2
import com.google.api.services.oauth2.model.Userinfo
import com.simiacryptus.skyenet.ApplicationBase.Companion.getCookie
import com.simiacryptus.skyenet.application.ApplicationServer.Companion.getCookie
import com.simiacryptus.skyenet.platform.ApplicationServices
import com.simiacryptus.skyenet.platform.AuthenticationManager.Companion.AUTH_COOKIE
import com.simiacryptus.skyenet.platform.User
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.simiacryptus.skyenet.servlet

import com.simiacryptus.skyenet.ApplicationBase
import com.simiacryptus.skyenet.ApplicationBase.Companion.getCookie
import com.simiacryptus.skyenet.application.ApplicationServer
import com.simiacryptus.skyenet.application.ApplicationServer.Companion.getCookie
import com.simiacryptus.skyenet.platform.ApplicationServices
import com.simiacryptus.skyenet.platform.DataStorage
import com.simiacryptus.skyenet.platform.Session
Expand All @@ -20,7 +20,7 @@ class FileServlet(val dataStorage: DataStorage) : HttpServlet() {
val file = File(sessionDir, filePath)
if (file.isFile) {
val filename = file.name
resp.contentType = ApplicationBase.getMimeType(filename)
resp.contentType = ApplicationServer.getMimeType(filename)
resp.status = HttpServletResponse.SC_OK
file.inputStream().use { inputStream ->
resp.outputStream.use { outputStream ->
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.simiacryptus.skyenet.servlet

import com.simiacryptus.skyenet.ApplicationBase.Companion.getCookie
import com.simiacryptus.skyenet.application.ApplicationServer.Companion.getCookie
import com.simiacryptus.skyenet.platform.ApplicationServices
import com.simiacryptus.skyenet.platform.DataStorage
import com.simiacryptus.skyenet.platform.Session
Expand Down
Loading

0 comments on commit 52019b2

Please sign in to comment.