From 1d848ba6dbf4922bc095fda30d9648e004e9c0db Mon Sep 17 00:00:00 2001 From: Farrukh Masud Date: Wed, 20 Nov 2024 19:17:35 -0800 Subject: [PATCH] making mappedMessages immutable again --- .../openai/OpenAIChatCompletion.scala | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala index afece1c8e6..8ae2819cf6 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala @@ -132,21 +132,23 @@ class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase( override def responseDataType: DataType = ChatCompletionResponse.schema private[this] def getStringEntity(messages: Seq[Row], optionalParams: Map[String, Any]): StringEntity = { - var mappedMessages: Seq[Map[String, String]] = messages.map { m => - Seq("role", "content", "name").map(n => - n -> Option(m.getAs[String](n)) - ).toMap.filter(_._2.isDefined).mapValues(_.get) + // Convert each message row to a map of non-null values + val mappedMessages: Seq[Map[String, String]] = messages.map { messageRow => + Seq("role", "content", "name").map { fieldName => + fieldName -> Option(messageRow.getAs[String](fieldName)) + }.toMap.filter(_._2.isDefined).mapValues(_.get) } - // if the optionalParams contains "response_format" key, and it's value contains "json_object", - // then we need to add a message to instruct openAI to return the response in JSON format - if (optionalParams.get("response_format") - .exists(_.asInstanceOf[Map[String, String]]("type") - .contentEquals("json_object"))) { - mappedMessages :+= Map("role" -> "system", "content" -> OpenAIResponseFormat.JSON.prompt) + // Check if the response format is JSON and add a system message if needed + val updatedMessages = if (optionalParams.get("response_format") + .exists(_.asInstanceOf[Map[String, String]]("type") == "json_object")) { + mappedMessages :+ Map("role" -> "system", "content" -> OpenAIResponseFormat.JSON.prompt) + } else { + mappedMessages } - val fullPayload = optionalParams.updated("messages", mappedMessages) + // Update the optional parameters with the messages + val fullPayload = optionalParams.updated("messages", updatedMessages) new StringEntity(fullPayload.toJson.compactPrint, ContentType.APPLICATION_JSON) }