Skip to content

Commit

Permalink
Simplifying code and cleaning unit tests. In unit tests, prompt objec…
Browse files Browse the repository at this point in the history
…ts can be reused, in this iteration, I am removing this possibility. Now each test create a new OpenAIPrompt
  • Loading branch information
FMasudMsft committed Nov 21, 2024
1 parent 1d848ba commit 8a5f5d9
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
object OpenAIPrompt extends ComplexParamsReadable[OpenAIPrompt]

class OpenAIPrompt(override val uid: String) extends Transformer
with HasOpenAITextParams with HasMessagesInput
with HasOpenAITextParamsExtended with HasMessagesInput
with HasErrorCol with HasOutputCol
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
Expand Down Expand Up @@ -109,43 +109,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer

def setSystemPrompt(value: String): this.type = set(systemPrompt, value)

val responseFormat = new Param[String](
this, "responseFormat", "The response format from the OpenAI API.")

def getResponseFormat: String = $(responseFormat)

def setResponseFormat(value: String): this.type = {
if (value.isEmpty) {
this
} else {
val normalizedValue = value.toLowerCase match {
case "json" => "json_object"
case other => other
}

// Validate the normalized value using the OpenAIResponseFormat enum
if (!OpenAIResponseFormat.values
.map(_.asInstanceOf[OpenAIResponseFormat.ResponseFormat].name)
.contains(normalizedValue)) {
throw new IllegalArgumentException("Response format must be valid for OpenAI API. " +
"Currently supported formats are " + OpenAIResponseFormat.values
.map(_.asInstanceOf[OpenAIResponseFormat.ResponseFormat].name)
.mkString(", "))
}

set(responseFormat, normalizedValue)
}
}

def setResponseFormat(value: OpenAIResponseFormat.ResponseFormat): this.type = {
this.setResponseFormat(value.name)
}

private val defaultSystemPrompt = "You are an AI chatbot who wants to answer user's questions and complete tasks. " +
"Follow their instructions carefully and be brief if they don't say otherwise."

setDefault(
responseFormat -> "",
postProcessing -> "",
postProcessingOptions -> Map.empty,
outputCol -> (this.uid + "_output"),
Expand All @@ -162,7 +129,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer

private val localParamNames = Seq(
"promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages",
"systemPrompt", "responseFormat")
"systemPrompt")

private def addRAIErrors(df: DataFrame, errorCol: String, outputCol: String): DataFrame = {
val openAIResultFromRow = ChatCompletionResponse.makeFromRowConverter
Expand Down Expand Up @@ -258,6 +225,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}
// apply all parameters
extractParamMap().toSeq
.filter(p => completion.hasParam(p.param.name))
.filter(p => !localParamNames.contains(p.param.name))
.foreach(p => completion.set(completion.getParam(p.param.name), p.value))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ trait OpenAIAPIKey {
lazy val deploymentName: String = "gpt-35-turbo"
lazy val modelName: String = "gpt-35-turbo"
lazy val deploymentNameGpt4: String = "gpt-4"
lazy val deploymentNameDavinci3: String = "text-davinci-003"
lazy val deploymentNameGpt4o: String = "gpt-4o"
lazy val modelNameGpt4: String = "gpt-4"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.col
import org.scalactic.Equality
import org.scalatest.matchers.must.Matchers.{be, include}
import org.scalatest.matchers.must.Matchers.be
import org.scalatest.matchers.should.Matchers.{an, convertToAnyShouldWrapper}

class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIKey with Flaky {
Expand All @@ -23,12 +23,6 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
super.beforeAll()
}

lazy val prompt: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentName)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)

lazy val df: DataFrame = Seq(
("apple", "fruits"),
Expand All @@ -38,8 +32,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
).toDF("text", "category")

test("RAI Usage") {
val prompt = createPromptInstance(deploymentNameGpt4)
val result = prompt
.setDeploymentName(deploymentNameGpt4)
.setPromptTemplate("Tell me about a graphically disgusting movie in detail")
.transform(df)
.select(prompt.getErrorCol)
Expand All @@ -48,6 +42,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage") {
val prompt: OpenAIPrompt = createPromptInstance(deploymentName)
val nonNullCount = prompt
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setPostProcessing("csv")
Expand All @@ -60,7 +55,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage with only post processing options") {
val nonNullCount = prompt
val nonNullCount = createPromptInstance(deploymentName)
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setPostProcessingOptions(Map("delimiter" -> ","))
.transform(df)
Expand All @@ -72,6 +67,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage JSON") {
val prompt = createPromptInstance(deploymentName)
prompt.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
Expand All @@ -87,6 +83,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage JSON using text response format") {
val prompt = createPromptInstance(deploymentName)
prompt.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
Expand All @@ -103,6 +100,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage JSON using only post processing oiptions") {
val prompt = createPromptInstance(deploymentName)
prompt.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
Expand All @@ -116,16 +114,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}



lazy val promptGpt4: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentNameGpt4)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)

test("Basic Usage - Gpt 4") {
val promptGpt4 = createPromptInstance(deploymentNameGpt4)
val nonNullCount = promptGpt4
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setPostProcessing("csv")
Expand All @@ -138,6 +128,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage JSON - Gpt 4") {
val promptGpt4 = createPromptInstance(deploymentNameGpt4)
promptGpt4.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
Expand All @@ -153,6 +144,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage JSON - Gpt 4 using responseFormat TEXT") {
val promptGpt4 = createPromptInstance(deploymentNameGpt4)
promptGpt4.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
Expand All @@ -168,14 +160,8 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

lazy val promptGpt4o: OpenAIPrompt = new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentNameGpt4o)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)

test("Basic Usage JSON - Gpt 4o using responseFormat JSON") {
val promptGpt4o = createPromptInstance(deploymentNameGpt4o)
promptGpt4o.setPromptTemplate(
"""Split a word into prefix and postfix
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
Expand All @@ -191,6 +177,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage - Gpt 4o with response format json") {
val promptGpt4o = createPromptInstance(deploymentNameGpt4o)
val nonNullCount = promptGpt4o
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setResponseFormat(OpenAIResponseFormat.JSON)
Expand All @@ -203,6 +190,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Basic Usage - Gpt 4o with response format text") {
val promptGpt4o = createPromptInstance(deploymentNameGpt4o)
val nonNullCount = promptGpt4o
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setResponseFormat(OpenAIResponseFormat.TEXT)
Expand All @@ -215,6 +203,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

test("Setting and Keeping Messages Col - Gpt 4") {
val promptGpt4 = createPromptInstance(deploymentNameGpt4)
promptGpt4.setMessagesCol("messages")
.setDropPrompt(false)
.setPromptTemplate(
Expand All @@ -230,6 +219,31 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.foreach(r => assert(r.get(0) != null))
}

test("Basic Usage - Davinci 3 with no response format") {
val promptDavinci3 = createPromptInstance(deploymentNameDavinci3)
val rowCount = promptDavinci3
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.transform(df)
.select("outParsed")
.collect()
.length
assert(rowCount == 4)
}

test("Basic Usage - Davinci 3 with response format json") {
val promptDavinci3 = createPromptInstance(deploymentNameDavinci3)
intercept[IllegalArgumentException] {
promptDavinci3
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
.setResponseFormat(OpenAIResponseFormat.JSON)
.transform(df)
.select("outParsed")
.collect()
.length
}
}


ignore("Custom EndPoint") {
lazy val accessToken: String = sys.env.getOrElse("CUSTOM_ACCESS_TOKEN", "")
lazy val customRootUrlValue: String = sys.env.getOrElse("CUSTOM_ROOT_URL", "")
Expand Down Expand Up @@ -257,60 +271,25 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
.count(r => Option(r.getSeq[String](0)).isDefined)
}

test("getResponseFormat should return the default response format") {
val prompt = new OpenAIPrompt()
prompt.getResponseFormat shouldEqual ""
}

test("setResponseFormat should set the response format correctly with String") {
val prompt = new OpenAIPrompt()
prompt.setResponseFormat("json")
prompt.getResponseFormat shouldEqual "json_object"
prompt.getResponseFormat shouldEqual Map("type" -> "json_object")

prompt.setResponseFormat("json_object")
prompt.getResponseFormat shouldEqual "json_object"
prompt.getResponseFormat shouldEqual Map("type" -> "json_object")

prompt.setResponseFormat("text")
prompt.getResponseFormat shouldEqual "text"
prompt.getResponseFormat shouldEqual Map("type" -> "text")
}


test("setResponseFormat should throw an exception for invalid response format") {
val prompt = new OpenAIPrompt()
an[IllegalArgumentException] should be thrownBy {
prompt.setResponseFormat("invalid_format")
}
}

test("setResponseFormat should set the response format correctly with ResponseFormat") {
val prompt = new OpenAIPrompt()
prompt.setResponseFormat(OpenAIResponseFormat.JSON)
prompt.getResponseFormat shouldEqual "json_object"

prompt.setResponseFormat(OpenAIResponseFormat.TEXT)
prompt.getResponseFormat shouldEqual "text"
}


test("setResponseFormat should set the response format correctly for valid values") {
val prompt = new OpenAIPrompt()
prompt.setResponseFormat("text")
prompt.getResponseFormat should be ("text")

prompt.setResponseFormat("json")
prompt.getResponseFormat should be ("json_object")

prompt.setResponseFormat("json_object")
prompt.getResponseFormat should be ("json_object")

prompt.setResponseFormat("jSoN")
prompt.getResponseFormat should be ("json_object")

prompt.setResponseFormat("TEXT")
prompt.getResponseFormat should be ("text")
}


test("setPostProcessingOptions should set postProcessing to 'csv' for delimiter option") {
val prompt = new OpenAIPrompt()
prompt.setPostProcessingOptions(Map("delimiter" -> ","))
Expand Down Expand Up @@ -356,6 +335,7 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK
}

override def testObjects(): Seq[TestObject[OpenAIPrompt]] = {
val prompt = createPromptInstance(deploymentName)
val testPrompt = prompt
.setPromptTemplate("{text} rhymes with ")

Expand All @@ -364,4 +344,12 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK

override def reader: MLReadable[_] = OpenAIPrompt

private def createPromptInstance(deploymentName: String): OpenAIPrompt = {
new OpenAIPrompt()
.setSubscriptionKey(openAIAPIKey)
.setDeploymentName(deploymentName)
.setCustomServiceName(openAIServiceName)
.setOutputCol("outParsed")
.setTemperature(0)
}
}

0 comments on commit 8a5f5d9

Please sign in to comment.