Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1.0.23 #27

Merged
merged 2 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,18 @@ Maven:
<dependency>
<groupId>com.simiacryptus</groupId>
<artifactId>skyenet-webui</artifactId>
<version>1.0.22</version>
<version>1.0.23</version>
</dependency>
```

Gradle:

```groovy
implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.22'
implementation group: 'com.simiacryptus', name: 'skyenet', version: '1.0.23'
```

```kotlin
implementation("com.simiacryptus:skyenet:1.0.22")
implementation("com.simiacryptus:skyenet:1.0.23")
```

### 🌟 To Use
Expand Down
25 changes: 21 additions & 4 deletions core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ version = properties("libraryVersion")
plugins {
java
`java-library`
id("org.jetbrains.kotlin.jvm") version "1.7.22"
id("org.jetbrains.kotlin.jvm") version "1.9.20"
`maven-publish`
id("signing")
}
Expand All @@ -26,23 +26,40 @@ kotlin {
jvmToolchain(11)
}

val kotlin_version = "1.7.22"
val kotlin_version = "1.9.20"
val junit_version = "5.9.2"
val jetty_version = "11.0.17"
val logback_version = "1.2.12"

dependencies {

implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.24")
implementation(group = "com.simiacryptus", name = "joe-penai", version = "1.0.25")

implementation(group = "org.slf4j", name = "slf4j-api", version = "2.0.9")

implementation(group = "commons-io", name = "commons-io", version = "2.11.0")

compileOnlyApi(kotlin("stdlib"))
implementation(group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-core", version = "1.7.3")
testImplementation(kotlin("stdlib"))
testImplementation(kotlin("script-runtime"))

testImplementation(group = "org.junit.jupiter", name = "junit-jupiter-api", version = junit_version)
testRuntimeOnly(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = junit_version)
compileOnlyApi(group = "org.junit.jupiter", name = "junit-jupiter-api", version = junit_version)
compileOnlyApi(group = "org.junit.jupiter", name = "junit-jupiter-engine", version = junit_version)

// compileOnlyApi(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version)
compileOnlyApi(group = "com.google.cloud", name = "google-cloud-texttospeech", version = "2.28.0")
compileOnlyApi(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.454")
compileOnlyApi(group = "ch.qos.logback", name = "logback-classic", version = logback_version)
compileOnlyApi(group = "ch.qos.logback", name = "logback-core", version = logback_version)

// testImplementation(group = "org.eclipse.jetty", name = "jetty-server", version = jetty_version)
testImplementation(group = "com.google.cloud", name = "google-cloud-texttospeech", version = "2.28.0")
testImplementation(group = "com.amazonaws", name = "aws-java-sdk", version = "1.12.454")
testImplementation(group = "ch.qos.logback", name = "logback-classic", version = logback_version)
testImplementation(group = "ch.qos.logback", name = "logback-core", version = logback_version)

}

tasks {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.io.PrintStream;
import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

public class OutputInterceptor {
Expand Down Expand Up @@ -32,14 +33,14 @@ public static void resetThreadOutputStream() {
setOutputStream(centralStream);
}

private static final Map<Thread, ByteArrayOutputStream> threadLocalBuffer = new HashMap<>();
private static final Map<Thread, ByteArrayOutputStream> threadLocalBuffer = new WeakHashMap<>();

public static void setOutputStream(ByteArrayOutputStream stream) {
threadLocalBuffer.put(Thread.currentThread(), stream);
}

public static ByteArrayOutputStream getOutputStream() {
return threadLocalBuffer.get(Thread.currentThread());
return threadLocalBuffer.getOrDefault(Thread.currentThread(), new ByteArrayOutputStream());
}

public static String getThreadOutput() {
Expand Down Expand Up @@ -83,11 +84,13 @@ public void write(byte[] b, int off, int len) {
centralStream.reset();
}
centralStream.write(b, off, len);
ByteArrayOutputStream stream = getOutputStream();
if(stream.size() > 1024 * 1024) {
stream.reset();
ByteArrayOutputStream threadStream = getOutputStream();
if (threadStream != null) {
if (threadStream.size() > 1024 * 1024) {
threadStream.reset();
}
threadStream.write(b, off, len);
}
stream.write(b, off, len);
}
});
}
Expand Down
55 changes: 15 additions & 40 deletions core/src/main/kotlin/com/simiacryptus/skyenet/Brain.kt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
@file:Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")

package com.simiacryptus.skyenet

import com.simiacryptus.openai.OpenAIClient
Expand All @@ -11,39 +13,28 @@ import java.lang.reflect.Method
import java.lang.reflect.Modifier
import java.util.concurrent.atomic.AtomicInteger

@Suppress("MemberVisibilityCanBePrivate")
open class Brain(
val api: OpenAIClient,
val symbols: java.util.Map<String, Object> = java.util.HashMap<String, Object>() as java.util.Map<String, Object>,
var model: OpenAIClient.Model = OpenAIClient.Models.GPT35Turbo,
var verbose: Boolean = false,
var temperature: Double = 0.3,
var describer: TypeDescriber = YamlDescriber(),
val model: OpenAIClient.Model = OpenAIClient.Models.GPT35Turbo,
private val verbose: Boolean = false,
val temperature: Double = 0.3,
val describer: TypeDescriber = YamlDescriber(),
val language: String = "Kotlin",
private val moderated: Boolean = true,
val apiDescription: String = apiDescription(symbols, describer),
private val apiDescription: String = apiDescription(symbols, describer),
) {
val metrics: Map<String, Any>
get() = hashMapOf(
"totalInputLength" to totalInputLength.get(),
"totalOutputLength" to totalOutputLength.get(),
"totalApiDescriptionLength" to totalApiDescriptionLength.get(),
"totalExamplesLength" to totalExamplesLength.get(),
) + api.metrics
protected val totalInputLength = AtomicInteger(0)
protected val totalOutputLength = AtomicInteger(0)
protected val totalExamplesLength = AtomicInteger(0)
protected val totalApiDescriptionLength: AtomicInteger = AtomicInteger(0)

private val totalInputLength = AtomicInteger(0)
private val totalOutputLength = AtomicInteger(0)
private val totalApiDescriptionLength: AtomicInteger = AtomicInteger(0)

open fun implement(vararg prompt: String): String {
if (verbose) log.info("Prompt: \n\t" + prompt.joinToString("\n\t"))
return implement(*getChatMessages(*prompt).toTypedArray())
return implement(*(getChatSystemMessages(apiDescription) +
prompt.map { ChatMessage(ChatMessage.Role.user, it) }).toTypedArray()
)
}

fun getChatMessages(vararg prompt: String) = getChatSystemMessages(apiDescription) +
prompt.map { ChatMessage(ChatMessage.Role.user, it) }

fun implement(
vararg messages: ChatMessage
): String {
Expand Down Expand Up @@ -143,7 +134,7 @@ open class Brain(
if (moderated) api.moderate(json)
totalInputLength.addAndGet(json.length)
val chatResponse = api.chat(request, model)
var response = chatResponse.choices.first()?.message?.content.orEmpty()
var response = chatResponse.choices.first().message?.content.orEmpty()
if (verbose) log.info(response)
totalOutputLength.addAndGet(response.length)
response = response.trim()
Expand All @@ -163,7 +154,7 @@ open class Brain(
return superMethod?.superMethod() ?: superMethod
}

val <T> Class<T>.superclasses: List<Class<*>>
private val <T> Class<T>.superclasses: List<Class<*>>
get() {
val superclass = superclass
val supers = if (superclass == null) listOf()
Expand Down Expand Up @@ -210,12 +201,6 @@ open class Brain(
""".trimMargin()
}

/***
* The input stream is parsed based on ```language\n...\n``` blocks.
* A list of tuples is returned, where the first element is the language and the second is the code block
* For intermediate non-code blocks, the language is "text"
* For unlabeled code blocks, the language is "code"
*/
fun extractCodeBlocks(response: String): List<Pair<String, String>> {
val codeBlockRegex = Regex("(?s)```(.*?)\\n(.*?)```")
val languageRegex = Regex("([a-zA-Z0-9-_]+)")
Expand Down Expand Up @@ -249,16 +234,6 @@ open class Brain(
return result
}

fun extractCodeBlock(response: String): String {
var response1 = response
if (response1.contains("```")) {
val startIndex = response1.indexOf('\n', response1.indexOf("```"))
val endIndex = response1.lastIndexOf("```")
val trim = response1.substring(startIndex, endIndex).trim()
response1 = trim
}
return response1
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import java.util.concurrent.atomic.AtomicInteger
/**
* The ears are the interface to the audio input for the SkyeNet system
*/
@Suppress("MemberVisibilityCanBePrivate")
@Suppress("MemberVisibilityCanBePrivate", "unused")
open class Ears(
val api: OpenAIClient,
val secondsPerAudioPacket : Double = 0.25,
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/kotlin/com/simiacryptus/skyenet/Heart.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ interface Heart {

fun getLanguage(): String
fun run(code: String): Any?
fun validate(code: String): Exception?
fun validate(code: String): Throwable?

fun wrapCode(code: String): String = code
fun <T : Any> wrapExecution(fn: java.util.function.Supplier<T?>): T? = fn.get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ abstract class HeartTestBase {
@Test
fun `test validate with invalid code`() {
val interpreter = newInterpreter(mapOf<String,Object>() as Map<String,Object>)
assertThrows<Exception> { interpreter.validate("2 +") }
assertThrows<Exception> { with(interpreter.validate("2 +")) { throw this!! } }
}

@Test
open fun `test run with variables`() {
val interpreter = newInterpreter(mapOf("x" to (2 as Object), "y" to (3 as Object)) as Map<String,Object>)
interpreter.run("x * y")
val result = interpreter.run("x * y")
Assertions.assertEquals(6, result)
}

@Test
Expand Down Expand Up @@ -76,13 +77,13 @@ abstract class HeartTestBase {
@Test
open fun `test validate with tool object and invalid code`() {
val interpreter = newInterpreter(mapOf("tool" to (FooBar() as Object)) as Map<String,Object>)
assertThrows<Exception> { interpreter.validate("tool.baz()") }
assertThrows<Exception> { with(interpreter.validate("tool.baz()")) { throw this!! } }
}

@Test
open fun `test validate with undefined variable`() {
val interpreter = newInterpreter(mapOf<String,Object>() as Map<String,Object>)
assertThrows<Exception> { interpreter.validate("x * y") }
assertThrows<Exception> { with(interpreter.validate("x * y")) { throw this!! } }
}

abstract fun newInterpreter(map: Map<String, Object>): Heart
Expand Down
Loading
Loading