From 7b8728251970aec0acceb056d4049cc88df0f0d6 Mon Sep 17 00:00:00 2001 From: vga91 Date: Fri, 6 Sep 2024 17:46:06 +0200 Subject: [PATCH 1/2] Fixes #4182: The huggingface examples return strange results --- docs/asciidoc/modules/ROOT/pages/ml/openai.adoc | 6 ++++-- extended/src/main/java/apoc/ml/OpenAI.java | 3 +++ extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java | 14 +++++++------- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc index b7bbf172b1..b9ae333bd7 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -161,10 +161,12 @@ For the https://huggingface.co/[HuggingFace API], we have to define the config ` For example: [source,cypher] ---- -CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $huggingFaceApiKey, -{endpoint: 'https://api-inference.huggingface.co/models/gpt2', apiType: 'HUGGINGFACE', model: 'gpt2', path: ''}) +CALL apoc.ml.openai.completion('The sky has a [MASK] color', $huggingFaceApiKey, +{endpoint: 'https://api-inference.huggingface.co/models/google-bert/bert-base-uncased', apiType: 'HUGGINGFACE'}) ---- +With gpt2 or other text-completion models, the returned answers are not valid. 
+ Or also, by using the https://docs.cohere.com/docs[Cohere API], where we have to define `path: '''` not to add the `/completions` suffix to the URL: [source,cypher] ---- diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index 7ab0e61549..ed9389a1cd 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -25,6 +25,7 @@ import static apoc.ExtendedApocConfig.APOC_ML_OPENAI_TYPE; import static apoc.ExtendedApocConfig.APOC_OPENAI_KEY; import static apoc.ml.MLUtil.*; +import static apoc.ml.RestAPIConfig.METHOD_KEY; @Extended @@ -103,6 +104,8 @@ private static void handleAPIProvider(OpenAIRequestHandler.Type type, } case HUGGINGFACE -> { configForPayload.putIfAbsent("inputs", inputs); + configuration.putIfAbsent(PATH_CONF_KEY, ""); + headers.putIfAbsent(METHOD_KEY, "POST"); configuration.putIfAbsent(JSON_PATH_CONF_KEY, "$[0]"); } case ANTHROPIC -> { diff --git a/extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java b/extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java index ed7eba6a36..dd822385ec 100644 --- a/extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java +++ b/extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java @@ -1,6 +1,7 @@ package apoc.ml; import apoc.util.TestUtil; +import apoc.util.Util; import org.junit.Assume; import org.junit.Before; import org.junit.Rule; @@ -44,18 +45,17 @@ public void setUp() throws Exception { public void completionWithHuggingFace() { String huggingFaceApiKey = System.getenv("HF_API_TOKEN"); Assume.assumeNotNull("No HF_API_TOKEN environment configured", huggingFaceApiKey); - - String modelId = "gpt2"; + + String modelId = "google-bert/bert-base-uncased"; Map conf = Map.of(ENDPOINT_CONF_KEY, "https://api-inference.huggingface.co/models/" + modelId, - API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.HUGGINGFACE.name(), - PATH_CONF_KEY, "", - MODEL_CONF_KEY, modelId + API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.HUGGINGFACE.name() ); - 
testCall(db, COMPLETION_QUERY, + + testCall(db, "CALL apoc.ml.openai.completion('The sky has a [MASK] color', $apiKey, $conf)", Map.of("conf", conf, "apiKey", huggingFaceApiKey), (row) -> { var result = (Map) row.get("value"); - String generatedText = (String) result.get("generated_text"); + String generatedText = (String) result.get("sequence"); assertTrue(generatedText.toLowerCase().contains("blue"), "Actual generatedText is " + generatedText); }); From 710e93de87c1ab9b927921ba769a3ba0467b0f97 Mon Sep 17 00:00:00 2001 From: vga91 Date: Fri, 4 Oct 2024 12:15:04 +0200 Subject: [PATCH 2/2] Apply review changes and remove unused imports --- docs/asciidoc/modules/ROOT/pages/ml/openai.adoc | 2 +- extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc index b9ae333bd7..d72688d1be 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -161,7 +161,7 @@ For the https://huggingface.co/[HuggingFace API], we have to define the config ` For example: [source,cypher] ---- -CALL apoc.ml.openai.completion('The sky has a [MASK] color', $huggingFaceApiKey, +CALL apoc.ml.openai.completion('[MASK] is the color of the sky', $huggingFaceApiKey, {endpoint: 'https://api-inference.huggingface.co/models/google-bert/bert-base-uncased', apiType: 'HUGGINGFACE'}) ---- diff --git a/extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java b/extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java index dd822385ec..3df7c69136 100644 --- a/extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java +++ b/extended/src/test/java/apoc/ml/OpenAIOpenLMIT.java @@ -1,7 +1,6 @@ package apoc.ml; import apoc.util.TestUtil; -import apoc.util.Util; import org.junit.Assume; import org.junit.Before; import org.junit.Rule; @@ -51,7 +50,7 @@ public void completionWithHuggingFace() { API_TYPE_CONF_KEY, 
OpenAIRequestHandler.Type.HUGGINGFACE.name() ); - testCall(db, "CALL apoc.ml.openai.completion('The sky has a [MASK] color', $apiKey, $conf)", + testCall(db, "CALL apoc.ml.openai.completion('[MASK] is the color of the sky', $apiKey, $conf)", Map.of("conf", conf, "apiKey", huggingFaceApiKey), (row) -> { var result = (Map) row.get("value");