From 8a5f5d92d0008dfb516e03b8ff304e30f2744d5a Mon Sep 17 00:00:00 2001 From: Farrukh Masud Date: Wed, 20 Nov 2024 20:23:32 -0800 Subject: [PATCH] Simplifying code and cleaning unit tests. In unit tests, prompt objects can be reused; in this iteration, I am removing this possibility. Now each test creates a new OpenAIPrompt --- .../ml/services/openai/OpenAIPrompt.scala | 38 +----- .../openai/OpenAICompletionSuite.scala | 1 + .../services/openai/OpenAIPromptSuite.scala | 114 ++++++++---------- 3 files changed, 55 insertions(+), 98 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index f33a54d961..0c46a209ab 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt] class OpenAIPrompt(override val uid: String) extends Transformer - with HasOpenAITextParams with HasMessagesInput + with HasOpenAITextParamsExtended with HasMessagesInput with HasErrorCol with HasOutputCol with HasURL with HasCustomCogServiceDomain with ConcurrencyParams with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader @@ -109,43 +109,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer def setSystemPrompt(value: String): this.type = set(systemPrompt, value) - val responseFormat = new Param[String]( - this, "responseFormat", "The response format from the OpenAI API.") - - def getResponseFormat: String = $(responseFormat) - - def setResponseFormat(value: String): this.type = { - if (value.isEmpty) { - this - } else { - val normalizedValue = value.toLowerCase match { - case "json" => "json_object" - case other => other - } - - // Validate the normalized 
value using the OpenAIResponseFormat enum - if (!OpenAIResponseFormat.values - .map(_.asInstanceOf[OpenAIResponseFormat.ResponseFormat].name) - .contains(normalizedValue)) { - throw new IllegalArgumentException("Response format must be valid for OpenAI API. " + - "Currently supported formats are " + OpenAIResponseFormat.values - .map(_.asInstanceOf[OpenAIResponseFormat.ResponseFormat].name) - .mkString(", ")) - } - - set(responseFormat, normalizedValue) - } - } - - def setResponseFormat(value: OpenAIResponseFormat.ResponseFormat): this.type = { - this.setResponseFormat(value.name) - } - private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " + "Follow their instructions carefully and be brief if they don't say otherwise." setDefault( - responseFormat -> "", postProcessing -> "", postProcessingOptions -> Map.empty, outputCol -> (this.uid + "_output"), @@ -162,7 +129,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer private val localParamNames = Seq( "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages", - "systemPrompt", "responseFormat") + "systemPrompt") private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = { val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter @@ -258,6 +225,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer } // apply all parameters extractParamMap().toSeq + .filter(p => completion.hasParam(p.param.name)) .filter(p => !localParamNames.contains(p.param.name)) .foreach(p => completion.set(completion.getParam(p.param.name), p.value)) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala index 56bcb506a9..209778d805 100644 --- 
a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletionSuite.scala @@ -17,6 +17,7 @@ trait OpenAIAPIKey { lazy val deploymentName: String = "gpt-35-turbo" lazy val modelName: String = "gpt-35-turbo" lazy val deploymentNameGpt4: String = "gpt-4" + lazy val deploymentNameDavinci3: String = "text-davinci-003" lazy val deploymentNameGpt4o: String = "gpt-4o" lazy val modelNameGpt4: String = "gpt-4" } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 6f03380f0d..86f64314a1 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -10,7 +10,7 @@ import org.apache.spark.ml.util.MLReadable import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.col import org.scalactic.Equality -import org.scalatest.matchers.must.Matchers.{be, include} +import org.scalatest.matchers.must.Matchers.be import org.scalatest.matchers.should.Matchers.{an, convertToAnyShouldWrapper} class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKey with Flaky { @@ -23,12 +23,6 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK super.beforeAll() } - lazy val prompt: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName(deploymentName) - .setCustomServiceName(openAIServiceName) - .setOutputCol("outParsed") - .setTemperature(0) lazy val df: DataFrame = Seq( ("apple", "fruits"), @@ -38,8 +32,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK ).toDF("text", "category") test("RAI Usage") { + val prompt = 
createPromptInstance(deploymentNameGpt4) val result = prompt - .setDeploymentName(deploymentNameGpt4) .setPromptTemplate("Tell me about a graphically disgusting movie in detail") .transform(df) .select(prompt.getErrorCol) @@ -48,6 +42,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage") { + val prompt: OpenAIPrompt = createPromptInstance(deploymentName) val nonNullCount = prompt .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -60,7 +55,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage with only post processing options") { - val nonNullCount = prompt + val nonNullCount = createPromptInstance(deploymentName) .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessingOptions(Map("delimiter" -> ",")) .transform(df) @@ -72,6 +67,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage JSON") { + val prompt = createPromptInstance(deploymentName) prompt.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -87,6 +83,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage JSON using text response format") { + val prompt = createPromptInstance(deploymentName) prompt.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -103,6 +100,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage JSON using only post processing oiptions") { + val prompt = createPromptInstance(deploymentName) prompt.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -116,16 +114,8 @@ class OpenAIPromptSuite extends 
TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) } - - - lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName(deploymentNameGpt4) - .setCustomServiceName(openAIServiceName) - .setOutputCol("outParsed") - .setTemperature(0) - test("Basic Usage - Gpt 4") { + val promptGpt4 = createPromptInstance(deploymentNameGpt4) val nonNullCount = promptGpt4 .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setPostProcessing("csv") @@ -138,6 +128,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage JSON - Gpt 4") { + val promptGpt4 = createPromptInstance(deploymentNameGpt4) promptGpt4.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -153,6 +144,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage JSON - Gpt 4 using responseFormat TEXT") { + val promptGpt4 = createPromptInstance(deploymentNameGpt4) promptGpt4.setPromptTemplate( """Split a word into prefix and postfix a respond in JSON |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -168,14 +160,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty)) } - lazy val promptGpt4o: OpenAIPrompt = new OpenAIPrompt() - .setSubscriptionKey(openAIAPIKey) - .setDeploymentName(deploymentNameGpt4o) - .setCustomServiceName(openAIServiceName) - .setOutputCol("outParsed") - .setTemperature(0) - test("Basic Usage JSON - Gpt 4o using responseFormat JSON") { + val promptGpt4o = createPromptInstance(deploymentNameGpt4o) promptGpt4o.setPromptTemplate( """Split a word into prefix and postfix |Cherry: {{"prefix": "Che", "suffix": "rry"}} @@ -191,6 +177,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } 
test("Basic Usage - Gpt 4o with response format json") { + val promptGpt4o = createPromptInstance(deploymentNameGpt4o) val nonNullCount = promptGpt4o .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setResponseFormat(OpenAIResponseFormat.JSON) @@ -203,6 +190,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Basic Usage - Gpt 4o with response format text") { + val promptGpt4o = createPromptInstance(deploymentNameGpt4o) val nonNullCount = promptGpt4o .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") .setResponseFormat(OpenAIResponseFormat.TEXT) @@ -215,6 +203,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } test("Setting and Keeping Messages Col - Gpt 4") { + val promptGpt4 = createPromptInstance(deploymentNameGpt4) promptGpt4.setMessagesCol("messages") .setDropPrompt(false) .setPromptTemplate( @@ -230,6 +219,31 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .foreach(r => assert(r.get(0) != null)) } + test("Basic Usage - Davinci 3 with no response format") { + val promptDavinci3 = createPromptInstance(deploymentNameDavinci3) + val rowCount = promptDavinci3 + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .transform(df) + .select("outParsed") + .collect() + .length + assert(rowCount == 4) + } + + test("Basic Usage - Davinci 3 with response format json") { + val promptDavinci3 = createPromptInstance(deploymentNameDavinci3) + intercept[IllegalArgumentException] { + promptDavinci3 + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .setResponseFormat(OpenAIResponseFormat.JSON) + .transform(df) + .select("outParsed") + .collect() + .length + } + } + + ignore("Custom EndPoint") { lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "") lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "") 
@@ -257,24 +271,18 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .count(r => Option(r.getSeq[String](0)).isDefined) } - test("getResponseFormat should return the default response format") { - val prompt = new OpenAIPrompt() - prompt.getResponseFormat shouldEqual "" - } - test("setResponseFormat should set the response format correctly with String") { val prompt = new OpenAIPrompt() prompt.setResponseFormat("json") - prompt.getResponseFormat shouldEqual "json_object" + prompt.getResponseFormat shouldEqual Map("type" -> "json_object") prompt.setResponseFormat("json_object") - prompt.getResponseFormat shouldEqual "json_object" + prompt.getResponseFormat shouldEqual Map("type" -> "json_object") prompt.setResponseFormat("text") - prompt.getResponseFormat shouldEqual "text" + prompt.getResponseFormat shouldEqual Map("type" -> "text") } - test("setResponseFormat should throw an exception for invalid response format") { val prompt = new OpenAIPrompt() an[IllegalArgumentException] should be thrownBy { @@ -282,35 +290,6 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } } - test("setResponseFormat should set the response format correctly with ResponseFormat") { - val prompt = new OpenAIPrompt() - prompt.setResponseFormat(OpenAIResponseFormat.JSON) - prompt.getResponseFormat shouldEqual "json_object" - - prompt.setResponseFormat(OpenAIResponseFormat.TEXT) - prompt.getResponseFormat shouldEqual "text" - } - - - test("setResponseFormat should set the response format correctly for valid values") { - val prompt = new OpenAIPrompt() - prompt.setResponseFormat("text") - prompt.getResponseFormat should be ("text") - - prompt.setResponseFormat("json") - prompt.getResponseFormat should be ("json_object") - - prompt.setResponseFormat("json_object") - prompt.getResponseFormat should be ("json_object") - - prompt.setResponseFormat("jSoN") - prompt.getResponseFormat should be ("json_object") - - 
prompt.setResponseFormat("TEXT") - prompt.getResponseFormat should be ("text") - } - - test("setPostProcessingOptions should set postProcessing to 'csv' for delimiter option") { val prompt = new OpenAIPrompt() prompt.setPostProcessingOptions(Map("delimiter" -> ",")) @@ -356,6 +335,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK } override def testObjects(): Seq[TestObject[OpenAIPrompt]] = { + val prompt = createPromptInstance(deploymentName) val testPrompt = prompt .setPromptTemplate("{text} rhymes with ") @@ -364,4 +344,12 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK override def reader: MLReadable[_] = OpenAIPrompt + private def createPromptInstance(deploymentName: String): OpenAIPrompt = { + new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setDeploymentName(deploymentName) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setTemperature(0) + } }