diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala
index 8d40f49ee6..3c361a27d7 100644
--- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala
+++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala
@@ -100,9 +100,10 @@ trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion
 }
 
 trait HasOpenAITextParams extends HasOpenAISharedParams {
+  private var cachedParams: Map[String, Any] = Map.empty
 
   val maxTokens: ServiceParam[Int] = new ServiceParam[Int](
-    this, "maxTokens",
+    this, "max_tokens",
     "The maximum number of tokens to generate. Has minimum of 0.",
     isRequired = false)
 
@@ -143,7 +144,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
   def setStopCol(v: String): this.type = setVectorParam(stop, v)
 
   val topP: ServiceParam[Double] = new ServiceParam[Double](
-    this, "topP",
+    this, "top_p",
     "An alternative to sampling with temperature, called nucleus sampling, where the model considers the" +
       " results of the tokens with top_p probability mass." +
       " So 0.1 means only the tokens comprising the top 10 percent probability mass are considered." +
@@ -173,7 +174,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
   def setNCol(v: String): this.type = setVectorParam(n, v)
 
   val logProbs: ServiceParam[Int] = new ServiceParam[Int](
-    this, "logProbs",
+    this, "logprobs",
     "Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens." +
       " So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens." +
       " If `logprobs` is 0, only the chosen tokens will have logprobs returned." +
@@ -202,7 +203,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
   def setEchoCol(v: String): this.type = setVectorParam(echo, v)
 
   val cacheLevel: ServiceParam[Int] = new ServiceParam[Int](
-    this, "cacheLevel",
+    this, "cache_level",
     "can be used to disable any server-side caching, 0=no cache, 1=prompt prefix enabled, 2=full cache",
     isRequired = false)
 
@@ -215,7 +216,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
   def setCacheLevelCol(v: String): this.type = setVectorParam(cacheLevel, v)
 
   val presencePenalty: ServiceParam[Double] = new ServiceParam[Double](
-    this, "presencePenalty",
+    this, "presence_penalty",
     "How much to penalize new tokens based on their existing frequency in the text so far." +
       " Decreases the likelihood of the model to repeat the same line verbatim. Has minimum of -2 and maximum of 2.",
     isRequired = false)
@@ -229,7 +230,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
   def setPresencePenaltyCol(v: String): this.type = setVectorParam(presencePenalty, v)
 
   val frequencyPenalty: ServiceParam[Double] = new ServiceParam[Double](
-    this, "frequencyPenalty",
+    this, "frequency_penalty",
     "How much to penalize new tokens based on whether they appear in the text so far." +
       " Increases the likelihood of the model to talk about new topics.",
     isRequired = false)
@@ -243,7 +244,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
   def setFrequencyPenaltyCol(v: String): this.type = setVectorParam(frequencyPenalty, v)
 
   val bestOf: ServiceParam[Int] = new ServiceParam[Int](
-    this, "bestOf",
+    this, "best_of",
     "How many generations to create server side, and display only the best." +
       " Will not stream intermediate progress if best_of > 1. Has maximum value of 128.",
     isRequired = false)
@@ -256,24 +257,53 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
 
   def setBestOfCol(v: String): this.type = setVectorParam(bestOf, v)
 
+  // List of the shared text parameters. getOptionalParams iterates over these to compute the
+  // optional request parameters. Since this list never changes, we create it once and reuse it.
+  private val sharedTextParams = Seq(
+    maxTokens,
+    temperature,
+    topP,
+    user,
+    n,
+    echo,
+    stop,
+    cacheLevel,
+    presencePenalty,
+    frequencyPenalty,
+    bestOf,
+    logProbs
+  )
+
+  // This flag ensures that we compute the cacheability of the parameters only once.
+  private var areParamCacheable = true
+
   private[ml] def getOptionalParams(r: Row): Map[String, Any] = {
-    Seq(
-      maxTokens,
-      temperature,
-      topP,
-      user,
-      n,
-      echo,
-      stop,
-      cacheLevel,
-      presencePenalty,
-      frequencyPenalty,
-      bestOf
-    ).flatMap(param =>
-      getValueOpt(r, param).map(v => (GenerationUtils.camelToSnake(param.name), v))
-    ).++(Seq(
-      getValueOpt(r, logProbs).map(v => ("logprobs", v))
-    ).flatten).toMap
+    // Return the cached parameters if they have already been computed
+    if (cachedParams.nonEmpty) {
+      cachedParams
+    } else {
+      // The parameters are not cacheable if any of the shared text parameters is a vector parameter,
+      // i.e. not constant across rows. The check below determines whether the parameters are
+      // cacheable. Since the cacheability of a parameter does not change over time, we compute it
+      // only once: if the parameters are not cacheable, areParamCacheable is set to false and
+      // subsequent calls skip the per-parameter check entirely. If the parameters are cacheable,
+      // we cache the computed map and return it.
+      areParamCacheable = areParamCacheable && !sharedTextParams.exists(param => get(param).exists(_.isRight))
+
+      // Compute the optional parameters
+      val optionalParams = sharedTextParams.flatMap { param =>
+        getValueOpt(r, param).map { value =>
+          param.name -> value
+        }
+      }.toMap
+
+      // Cache the computed parameters if caching is enabled
+      if (areParamCacheable) {
+        cachedParams = optionalParams
+      }
+
+      optionalParams
+    }
   }
 }
 
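
Note on the renames (illustration, not part of the patch): switching the ServiceParam names to their wire-format snake_case spellings is what lets the new `getOptionalParams` emit `param.name -> value` directly, instead of converting every name on every row with `GenerationUtils.camelToSnake`. The stand-alone sketch below is not SynapseML code; `camelToSnake` here is a simplified re-implementation assumed only for the demo:

```scala
// Stand-alone illustration, not SynapseML code: camelToSnake below is a
// simplified re-implementation assumed only for this demo.
object ParamNameSketch extends App {
  def camelToSnake(s: String): String =
    s.replaceAll("([a-z0-9])([A-Z])", "$1_$2").toLowerCase

  // With the old camelCase names, every row paid a string conversion per parameter:
  val oldNames = Seq("maxTokens", "topP", "presencePenalty", "frequencyPenalty", "bestOf")
  println(oldNames.map(camelToSnake).mkString(", "))
  // max_tokens, top_p, presence_penalty, frequency_penalty, best_of

  // This also explains the old special case for logProbs: naive conversion yields
  // "log_probs", not the API's "logprobs", so logProbs could not join the shared
  // parameter list until it was renamed.
  println(camelToSnake("logProbs")) // log_probs
}
```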
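Note on when caching is safe (illustration only): a ServiceParam holds either a constant value (`Left`) or the name of an input column (`Right`), which is what `get(param).exists(_.isRight)` tests in the patch. The computed parameter map can be reused across rows only when every set parameter is a constant. A minimal stand-alone sketch of that condition, assuming the Left/Right convention; `ParamValue`, `cacheable`, and `"tempCol"` are hypothetical names:

```scala
// Minimal sketch of the cacheability condition; not SynapseML code.
object CacheabilitySketch extends App {
  // A set parameter is either a constant (Left) or an input-column name (Right)
  // whose value can differ on every row.
  type ParamValue = Either[Any, String]

  def cacheable(setParams: Seq[ParamValue]): Boolean =
    !setParams.exists(_.isRight) // same shape as the exists(_.isRight) check in the patch

  val allConstants: Seq[ParamValue] = Seq(Left(256), Left(0.7))        // e.g. setMaxTokens(256), setTemperature(0.7)
  val perRow: Seq[ParamValue]       = Seq(Left(256), Right("tempCol")) // e.g. setTemperatureCol("tempCol")

  println(cacheable(allConstants)) // true  -> same parameter map for every row; safe to cache once
  println(cacheable(perRow))       // false -> map depends on the row; recompute for each row
}
```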