diff --git a/README.md b/README.md index 6d0189e9..23568ed2 100644 --- a/README.md +++ b/README.md @@ -76,18 +76,18 @@ Maven: com.simiacryptus skyenet-webui - 1.0.25 + 1.0.26 ``` Gradle: ```groovy -implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.25' +implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.26' ``` ```kotlin -implementation("com.simiacryptus:skyenet:1.0.25") +implementation("com.simiacryptus:skyenet:1.0.26") ``` ### 🌟 To Use diff --git a/core/src/main/java/com/simiacryptus/skyenet/OutputInterceptor.java b/core/src/main/java/com/simiacryptus/skyenet/OutputInterceptor.java index 63db8bc9..6fe7ef38 100644 --- a/core/src/main/java/com/simiacryptus/skyenet/OutputInterceptor.java +++ b/core/src/main/java/com/simiacryptus/skyenet/OutputInterceptor.java @@ -1,8 +1,8 @@ package com.simiacryptus.skyenet; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.io.PrintStream; -import java.util.HashMap; import java.util.Map; import java.util.WeakHashMap; import java.util.concurrent.atomic.AtomicBoolean; @@ -16,83 +16,96 @@ private OutputInterceptor() { private static final PrintStream originalOut = System.out; private static final PrintStream originalErr = System.err; private static final AtomicBoolean isSetup = new AtomicBoolean(false); + private static final Object globalStreamLock = new Object(); public static void setupInterceptor() { if (isSetup.getAndSet(true)) return; - System.setOut(createInterceptorStream(originalOut)); - System.setErr(createInterceptorStream(originalErr)); + System.setOut(new PrintStream(new OutputStreamRouter(originalOut))); + System.setErr(new PrintStream(new OutputStreamRouter(originalErr))); } - private static final ByteArrayOutputStream centralStream = new ByteArrayOutputStream(); - - public static void initThreadOutputStream() { - setOutputStream(new ByteArrayOutputStream()); - } - - public static void resetThreadOutputStream() { - setOutputStream(centralStream); - } + private static final ByteArrayOutputStream globalStream = new ByteArrayOutputStream(); private static final Map threadLocalBuffer = new WeakHashMap<>(); - public static void setOutputStream(ByteArrayOutputStream stream) { - threadLocalBuffer.put(Thread.currentThread(), stream); - } - - public static ByteArrayOutputStream getOutputStream() { - return threadLocalBuffer.getOrDefault(Thread.currentThread(), new ByteArrayOutputStream()); + private static ByteArrayOutputStream getThreadOutputStream() { + Thread currentThread = Thread.currentThread(); + ByteArrayOutputStream outputStream; + synchronized (threadLocalBuffer) { + if ((outputStream = threadLocalBuffer.get(currentThread)) != null) return outputStream; + outputStream = new ByteArrayOutputStream(); + threadLocalBuffer.put(currentThread, outputStream); + } + return outputStream; } public static String getThreadOutput() { - return getOutputStream().toString(); + ByteArrayOutputStream outputStream = getThreadOutputStream(); + try { + outputStream.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } + return outputStream.toString(); } public static void clearThreadOutput() { - getOutputStream().reset(); + getThreadOutputStream().reset(); } public static String getGlobalOutput() { - return centralStream.toString(); + synchronized (globalStreamLock) { + return globalStream.toString(); + } } public static void clearGlobalOutput() { - centralStream.reset(); + synchronized (globalStreamLock) { + globalStream.reset(); + } } - public static PrintStream createInterceptorStream(PrintStream originalStream) { + private static class OutputStreamRouter extends ByteArrayOutputStream { + private final PrintStream originalStream; int maxGlobalBuffer = 8 * 1024 * 1024; int maxThreadBuffer = 1024 * 1024; - return new PrintStream(new ByteArrayOutputStream() { - @Override - public void write(int b) { - originalStream.write(b); - if(centralStream.size() > maxGlobalBuffer) { - centralStream.reset(); - } - centralStream.write(b); - ByteArrayOutputStream stream = getOutputStream(); - if(stream.size() > maxThreadBuffer) { - stream.reset(); - } - stream.write(b); - } - @Override - public void write(byte[] b, int off, int len) { - originalStream.write(b, off, len); - if(centralStream.size() > 1024 * 1024) { - centralStream.reset(); + public OutputStreamRouter(PrintStream originalStream) { + this.originalStream = originalStream; + } + + @Override + public void write(int b) { + originalStream.write(b); + synchronized (globalStreamLock) { + if (globalStream.size() > maxGlobalBuffer) { + globalStream.reset(); } - centralStream.write(b, off, len); - ByteArrayOutputStream threadStream = getOutputStream(); - if (threadStream != null) { - if (threadStream.size() > 1024 * 1024) { - threadStream.reset(); - } - threadStream.write(b, off, len); + globalStream.write(b); + } + ByteArrayOutputStream threadOutputStream = getThreadOutputStream(); + if (threadOutputStream.size() > maxThreadBuffer) { + threadOutputStream.reset(); + } + threadOutputStream.write(b); + } + + @Override + public void write(byte[] b, int off, int len) { + originalStream.write(b, off, len); + synchronized (globalStreamLock) { + if (globalStream.size() > maxGlobalBuffer) { + globalStream.reset(); } + globalStream.write(b, off, len); } - }); + ByteArrayOutputStream threadOutputStream = getThreadOutputStream(); + if (threadOutputStream.size() > maxThreadBuffer) { + threadOutputStream.reset(); + } + threadOutputStream.write(b, off, len); + } } - } + + diff --git a/core/src/test/java/com/simiacryptus/skyenet/OutputInterceptorThreadedTest.java b/core/src/test/java/com/simiacryptus/skyenet/OutputInterceptorThreadedTest.java index 59c441e0..2ae5bd5d 100644 --- a/core/src/test/java/com/simiacryptus/skyenet/OutputInterceptorThreadedTest.java +++ b/core/src/test/java/com/simiacryptus/skyenet/OutputInterceptorThreadedTest.java @@ -22,15 +22,20 @@ public void testThreadedInterceptor() throws InterruptedException { String threadName = Thread.currentThread().getName(); System.out.println("Thread: " + threadName + " output"); System.err.println("Thread: " + threadName + " error"); - - String expectedOutput = "Thread: " + threadName + " output\nThread: " + threadName + " error\n"; - String threadOutput = OutputInterceptor.getThreadOutput().replace("\r", ""); + try { + Thread.sleep(1); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + String expectedOutput = ("Thread: " + threadName + " output\nThread: " + threadName + " error\n").trim(); + String threadOutput = OutputInterceptor.getThreadOutput().replace("\r", "").trim(); if (threadOutput.trim().equals(expectedOutput.trim())) { successCounter.incrementAndGet(); } else { synchronized (lock) { System.out.println("Expected:\n " + expectedOutput.replaceAll("\n", "\n ")); System.out.println("Actual:\n " + threadOutput.replaceAll("\n", "\n ")); + System.out.flush(); } } }; diff --git a/gradle.properties b/gradle.properties index a3bccd4b..24a56afd 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,6 +1,6 @@ # Gradle Releases -> https://github.com/gradle/gradle/releases libraryGroup = com.simiacryptus.skyenet -libraryVersion = 1.0.25 +libraryVersion = 1.0.26 gradleVersion = 7.6.1 # Opt-out flag for bundling Kotlin standard library -> https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatApplicationBase.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatApplicationBase.kt index 8091aebb..63fae4c0 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatApplicationBase.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatApplicationBase.kt @@ -4,10 +4,12 @@ abstract class ChatApplicationBase( applicationName: String, oauthConfig: String? = null, temperature: Double = 0.1, + resourceBase: String = "simpleSession", ) : ApplicationBase( applicationName = applicationName, oauthConfig = oauthConfig, temperature = temperature, + resourceBase = resourceBase, ) { override fun newSession(sessionId: String): SessionInterface { diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatSession.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatSession.kt index e930b1c0..97d40212 100644 --- a/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatSession.kt +++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatSession.kt @@ -1,6 +1,7 @@ package com.simiacryptus.skyenet.sessions import com.simiacryptus.openai.OpenAIClient +import com.simiacryptus.skyenet.util.MarkdownUtil open class ChatSession( val parent: ApplicationBase, @@ -48,7 +49,7 @@ open class ChatSession( open fun getResponse() = api.chat(newChatRequest, model).choices.first().message?.content.orEmpty() - open fun renderResponse(response: String) = """
$response
""" + open fun renderResponse(response: String) = """
${MarkdownUtil.renderMarkdown(response)}
""" open fun onResponse(response: String, responseContents: String) {} diff --git a/webui/src/test/kotlin/com/simiacryptus/skyenet/SessionDataStorageTest.kt b/webui/src/test/kotlin/com/simiacryptus/skyenet/SessionDataStorageTest.kt index 27d22157..1aa713f0 100644 --- a/webui/src/test/kotlin/com/simiacryptus/skyenet/SessionDataStorageTest.kt +++ b/webui/src/test/kotlin/com/simiacryptus/skyenet/SessionDataStorageTest.kt @@ -26,7 +26,7 @@ class SessionDataStorageTest { @Test fun testUpdateAndLoadMessage() { - val sessionId = "session1" + val sessionId = SessionDataStorage.newID() val messageId = "message1" val message = "This is a test message."