threadLocalBuffer = new WeakHashMap<>();
- public static void setOutputStream(ByteArrayOutputStream stream) {
- threadLocalBuffer.put(Thread.currentThread(), stream);
- }
-
- public static ByteArrayOutputStream getOutputStream() {
- return threadLocalBuffer.getOrDefault(Thread.currentThread(), new ByteArrayOutputStream());
+ private static ByteArrayOutputStream getThreadOutputStream() {
+ Thread currentThread = Thread.currentThread();
+ ByteArrayOutputStream outputStream;
+ synchronized (threadLocalBuffer) {
+ if ((outputStream = threadLocalBuffer.get(currentThread)) != null) return outputStream;
+ outputStream = new ByteArrayOutputStream();
+ threadLocalBuffer.put(currentThread, outputStream);
+ }
+ return outputStream;
}
public static String getThreadOutput() {
- return getOutputStream().toString();
+ ByteArrayOutputStream outputStream = getThreadOutputStream();
+ try {
+ outputStream.flush();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ return outputStream.toString();
}
public static void clearThreadOutput() {
- getOutputStream().reset();
+ getThreadOutputStream().reset();
}
public static String getGlobalOutput() {
- return centralStream.toString();
+ synchronized (globalStreamLock) {
+ return globalStream.toString();
+ }
}
public static void clearGlobalOutput() {
- centralStream.reset();
+ synchronized (globalStreamLock) {
+ globalStream.reset();
+ }
}
- public static PrintStream createInterceptorStream(PrintStream originalStream) {
+ private static class OutputStreamRouter extends ByteArrayOutputStream {
+ private final PrintStream originalStream;
int maxGlobalBuffer = 8 * 1024 * 1024;
int maxThreadBuffer = 1024 * 1024;
- return new PrintStream(new ByteArrayOutputStream() {
- @Override
- public void write(int b) {
- originalStream.write(b);
- if(centralStream.size() > maxGlobalBuffer) {
- centralStream.reset();
- }
- centralStream.write(b);
- ByteArrayOutputStream stream = getOutputStream();
- if(stream.size() > maxThreadBuffer) {
- stream.reset();
- }
- stream.write(b);
- }
- @Override
- public void write(byte[] b, int off, int len) {
- originalStream.write(b, off, len);
- if(centralStream.size() > 1024 * 1024) {
- centralStream.reset();
+ public OutputStreamRouter(PrintStream originalStream) {
+ this.originalStream = originalStream;
+ }
+
+ @Override
+ public void write(int b) {
+ originalStream.write(b);
+ synchronized (globalStreamLock) {
+ if (globalStream.size() > maxGlobalBuffer) {
+ globalStream.reset();
}
- centralStream.write(b, off, len);
- ByteArrayOutputStream threadStream = getOutputStream();
- if (threadStream != null) {
- if (threadStream.size() > 1024 * 1024) {
- threadStream.reset();
- }
- threadStream.write(b, off, len);
+ globalStream.write(b);
+ }
+ ByteArrayOutputStream threadOutputStream = getThreadOutputStream();
+ if (threadOutputStream.size() > maxThreadBuffer) {
+ threadOutputStream.reset();
+ }
+ threadOutputStream.write(b);
+ }
+
+ @Override
+ public void write(byte[] b, int off, int len) {
+ originalStream.write(b, off, len);
+ synchronized (globalStreamLock) {
+ if (globalStream.size() > maxGlobalBuffer) {
+ globalStream.reset();
}
+ globalStream.write(b, off, len);
}
- });
+ ByteArrayOutputStream threadOutputStream = getThreadOutputStream();
+ if (threadOutputStream.size() > maxThreadBuffer) {
+ threadOutputStream.reset();
+ }
+ threadOutputStream.write(b, off, len);
+ }
}
-
}
+
+
diff --git a/core/src/test/java/com/simiacryptus/skyenet/OutputInterceptorThreadedTest.java b/core/src/test/java/com/simiacryptus/skyenet/OutputInterceptorThreadedTest.java
index 59c441e0..2ae5bd5d 100644
--- a/core/src/test/java/com/simiacryptus/skyenet/OutputInterceptorThreadedTest.java
+++ b/core/src/test/java/com/simiacryptus/skyenet/OutputInterceptorThreadedTest.java
@@ -22,15 +22,20 @@ public void testThreadedInterceptor() throws InterruptedException {
String threadName = Thread.currentThread().getName();
System.out.println("Thread: " + threadName + " output");
System.err.println("Thread: " + threadName + " error");
-
- String expectedOutput = "Thread: " + threadName + " output\nThread: " + threadName + " error\n";
- String threadOutput = OutputInterceptor.getThreadOutput().replace("\r", "");
+ try {
+ Thread.sleep(1);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ String expectedOutput = ("Thread: " + threadName + " output\nThread: " + threadName + " error\n").trim();
+ String threadOutput = OutputInterceptor.getThreadOutput().replace("\r", "").trim();
if (threadOutput.trim().equals(expectedOutput.trim())) {
successCounter.incrementAndGet();
} else {
synchronized (lock) {
System.out.println("Expected:\n " + expectedOutput.replaceAll("\n", "\n "));
System.out.println("Actual:\n " + threadOutput.replaceAll("\n", "\n "));
+ System.out.flush();
}
}
};
diff --git a/gradle.properties b/gradle.properties
index a3bccd4b..24a56afd 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -1,6 +1,6 @@
# Gradle Releases -> https://github.com/gradle/gradle/releases
libraryGroup = com.simiacryptus.skyenet
-libraryVersion = 1.0.25
+libraryVersion = 1.0.26
gradleVersion = 7.6.1
# Opt-out flag for bundling Kotlin standard library -> https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library
diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatApplicationBase.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatApplicationBase.kt
index 8091aebb..63fae4c0 100644
--- a/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatApplicationBase.kt
+++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatApplicationBase.kt
@@ -4,10 +4,12 @@ abstract class ChatApplicationBase(
applicationName: String,
oauthConfig: String? = null,
temperature: Double = 0.1,
+ resourceBase: String = "simpleSession",
) : ApplicationBase(
applicationName = applicationName,
oauthConfig = oauthConfig,
temperature = temperature,
+ resourceBase = resourceBase,
) {
override fun newSession(sessionId: String): SessionInterface {
diff --git a/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatSession.kt b/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatSession.kt
index e930b1c0..97d40212 100644
--- a/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatSession.kt
+++ b/webui/src/main/kotlin/com/simiacryptus/skyenet/sessions/ChatSession.kt
@@ -1,6 +1,7 @@
package com.simiacryptus.skyenet.sessions
import com.simiacryptus.openai.OpenAIClient
+import com.simiacryptus.skyenet.util.MarkdownUtil
open class ChatSession(
val parent: ApplicationBase,
@@ -48,7 +49,7 @@ open class ChatSession(
open fun getResponse() = api.chat(newChatRequest, model).choices.first().message?.content.orEmpty()
- open fun renderResponse(response: String) = """$response
"""
+ open fun renderResponse(response: String) = """${MarkdownUtil.renderMarkdown(response)}
"""
open fun onResponse(response: String, responseContents: String) {}
diff --git a/webui/src/test/kotlin/com/simiacryptus/skyenet/SessionDataStorageTest.kt b/webui/src/test/kotlin/com/simiacryptus/skyenet/SessionDataStorageTest.kt
index 27d22157..1aa713f0 100644
--- a/webui/src/test/kotlin/com/simiacryptus/skyenet/SessionDataStorageTest.kt
+++ b/webui/src/test/kotlin/com/simiacryptus/skyenet/SessionDataStorageTest.kt
@@ -26,7 +26,7 @@ class SessionDataStorageTest {
@Test
fun testUpdateAndLoadMessage() {
- val sessionId = "session1"
+ val sessionId = SessionDataStorage.newID()
val messageId = "message1"
val message = "This is a test message."