From f3396e1c146baec9dae057d7f9eeff8f5d4baff5 Mon Sep 17 00:00:00 2001 From: Phodal Huang Date: Sun, 7 Jul 2024 16:06:24 +0800 Subject: [PATCH] refactor(embedding): modify LocalEmbedding creation method Modify the LocalEmbedding companion object to accept a ClassLoader parameter for flexibility in model loading. Remove unnecessary comments and instantiation in test class. --- .../src/main/kotlin/cc/unitmesh/cf/LocalEmbedding.kt | 8 +------- server/src/test/kotlin/RagIntegrationTests.kt | 3 --- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/llm-modules/local-embedding/src/main/kotlin/cc/unitmesh/cf/LocalEmbedding.kt b/llm-modules/local-embedding/src/main/kotlin/cc/unitmesh/cf/LocalEmbedding.kt index b94d3fd..9dc3170 100644 --- a/llm-modules/local-embedding/src/main/kotlin/cc/unitmesh/cf/LocalEmbedding.kt +++ b/llm-modules/local-embedding/src/main/kotlin/cc/unitmesh/cf/LocalEmbedding.kt @@ -67,14 +67,8 @@ open class LocalEmbedding( companion object { - /** - * Create a new instance of [LocalEmbedding] with default model. - * We use official model: [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) - * We can use [optimum](https://github.com/huggingface/optimum) to transform the model to onnx. - */ - fun create(): LocalEmbedding { - val classLoader = Thread.currentThread().getContextClassLoader() + fun create(classLoader: ClassLoader = Thread.currentThread().getContextClassLoader()): LocalEmbedding { val tokenizer = loadTokenizer(classLoader)!! val ortEnv = OrtEnvironment.getEnvironment() diff --git a/server/src/test/kotlin/RagIntegrationTests.kt b/server/src/test/kotlin/RagIntegrationTests.kt index fbd0a4c..1773e4d 100644 --- a/server/src/test/kotlin/RagIntegrationTests.kt +++ b/server/src/test/kotlin/RagIntegrationTests.kt @@ -1,4 +1,3 @@ -import cc.unitmesh.cf.LocalEmbedding import cc.unitmesh.cf.infrastructure.llms.embedding.SentenceTransformersEmbedding import cc.unitmesh.nlp.embedding.Embedding import cc.unitmesh.rag.document.Document @@ -13,8 +12,6 @@ import io.kotest.matchers.shouldBe import org.junit.jupiter.api.Test class RagIntegrationTests { - val semantic = LocalEmbedding.create() - private val embeddingProvider = SentenceTransformersEmbedding() @Test