Skip to content

Commit

Permalink
making mappedMessages immutable again
Browse files Browse the repository at this point in the history
  • Loading branch information
FMasudMsft committed Nov 21, 2024
1 parent 70eead5 commit 1d848ba
Showing 1 changed file with 13 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,21 +132,23 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(
override def responseDataType: DataType = ChatCompletionResponse.schema

private[this] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = {
var mappedMessages: Seq[Map[String, String]] = messages.map { m =>
Seq("role", "content", "name").map(n =>
n -> Option(m.getAs[String](n))
).toMap.filter(_._2.isDefined).mapValues(_.get)
// Convert each message row to a map of non-null values
val mappedMessages: Seq[Map[String, String]] = messages.map { messageRow =>
Seq("role", "content", "name").map { fieldName =>
fieldName -> Option(messageRow.getAs[String](fieldName))
}.toMap.filter(_._2.isDefined).mapValues(_.get)
}

// if the optionalParams contains "response_format" key, and it's value contains "json_object",
// then we need to add a message to instruct openAI to return the response in JSON format
if (optionalParams.get("response_format")
.exists(_.asInstanceOf[Map[String, String]]("type")
.contentEquals("json_object"))) {
mappedMessages :+= Map("role" -> "system", "content" -> OpenAIResponseFormat.JSON.prompt)
// Check if the response format is JSON and add a system message if needed
val updatedMessages = if (optionalParams.get("response_format")
.exists(_.asInstanceOf[Map[String, String]]("type") == "json_object")) {
mappedMessages :+ Map("role" -> "system", "content" -> OpenAIResponseFormat.JSON.prompt)
} else {
mappedMessages
}

val fullPayload = optionalParams.updated("messages", mappedMessages)
// Update the optional parameters with the messages
val fullPayload = optionalParams.updated("messages", updatedMessages)
new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON)
}

Expand Down

0 comments on commit 1d848ba

Please sign in to comment.