Skip to content

Commit

Permalink
Optimizing the method getOptionalParams
Browse files Browse the repository at this point in the history
  • Loading branch information
FMasudMsft committed Nov 15, 2024
1 parent a360886 commit 80315bc
Showing 1 changed file with 54 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ trait HasOpenAIEmbeddingParams extends HasOpenAISharedParams with HasAPIVersion
}

trait HasOpenAITextParams extends HasOpenAISharedParams {
private var cachedParams: Map[String, Any] = Map.empty

val maxTokens: ServiceParam[Int] = new ServiceParam[Int](
this, "maxTokens",
this, "max_tokens",
"The maximum number of tokens to generate. Has minimum of 0.",
isRequired = false)

Expand Down Expand Up @@ -143,7 +144,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
def setStopCol(v: String): this.type = setVectorParam(stop, v)

val topP: ServiceParam[Double] = new ServiceParam[Double](
this, "topP",
this, "top_p",
"An alternative to sampling with temperature, called nucleus sampling, where the model considers the" +
" results of the tokens with top_p probability mass." +
" So 0.1 means only the tokens comprising the top 10 percent probability mass are considered." +
Expand Down Expand Up @@ -173,7 +174,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
def setNCol(v: String): this.type = setVectorParam(n, v)

val logProbs: ServiceParam[Int] = new ServiceParam[Int](
this, "logProbs",
this, "logprobs",
"Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens." +
" So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens." +
" If `logprobs` is 0, only the chosen tokens will have logprobs returned." +
Expand Down Expand Up @@ -202,7 +203,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
def setEchoCol(v: String): this.type = setVectorParam(echo, v)

val cacheLevel: ServiceParam[Int] = new ServiceParam[Int](
this, "cacheLevel",
this, "cache_level",
"can be used to disable any server-side caching, 0=no cache, 1=prompt prefix enabled, 2=full cache",
isRequired = false)

Expand All @@ -215,7 +216,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
def setCacheLevelCol(v: String): this.type = setVectorParam(cacheLevel, v)

val presencePenalty: ServiceParam[Double] = new ServiceParam[Double](
this, "presencePenalty",
this, "presence_penalty",
"How much to penalize new tokens based on their existing frequency in the text so far." +
" Decreases the likelihood of the model to repeat the same line verbatim. Has minimum of -2 and maximum of 2.",
isRequired = false)
Expand All @@ -229,7 +230,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
def setPresencePenaltyCol(v: String): this.type = setVectorParam(presencePenalty, v)

val frequencyPenalty: ServiceParam[Double] = new ServiceParam[Double](
this, "frequencyPenalty",
this, "frequency_penalty",
"How much to penalize new tokens based on whether they appear in the text so far." +
" Increases the likelihood of the model to talk about new topics.",
isRequired = false)
Expand All @@ -243,7 +244,7 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
def setFrequencyPenaltyCol(v: String): this.type = setVectorParam(frequencyPenalty, v)

val bestOf: ServiceParam[Int] = new ServiceParam[Int](
this, "bestOf",
this, "best_of",
"How many generations to create server side, and display only the best." +
" Will not stream intermediate progress if best_of > 1. Has maximum value of 128.",
isRequired = false)
Expand All @@ -256,24 +257,53 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {

def setBestOfCol(v: String): this.type = setVectorParam(bestOf, v)

// list of shared text parameters. In method getOptionalParams, we will iterate over these parameters
// to compute the optional parameters. Since this list never changes, we can create it once and reuse it.
private val sharedTextParams = Seq(
maxTokens,
temperature,
topP,
user,
n,
echo,
stop,
cacheLevel,
presencePenalty,
frequencyPenalty,
bestOf,
logProbs
)

// This flag is used to ensure that we compute the cacheability of parameters only once.
private var areParamCacheable = true

private[ml] def getOptionalParams(r: Row): Map[String, Any] = {
Seq(
maxTokens,
temperature,
topP,
user,
n,
echo,
stop,
cacheLevel,
presencePenalty,
frequencyPenalty,
bestOf
).flatMap(param =>
getValueOpt(r, param).map(v => (GenerationUtils.camelToSnake(param.name), v))
).++(Seq(
getValueOpt(r, logProbs).map(v => ("logprobs", v))
).flatten).toMap
// Return cached parameters if they are already computed
if (cachedParams.nonEmpty) {
cachedParams
} else {
// The parameters are not cacheable if any of the shared text parameters are vector parameters
// i.e. if they are not constant for all rows. Following code checks if the parameters are cacheable
// or not. Since the cachability of parameter does not change over time, we can compute this only once.
// If the parameters are not cacheable, the instance variable areParamCacheable will be set to false
// and for subsequent calls, we will not even need to iterate parameters to check if they are cacheable.
// If parameters are cacheable, we will cache the computed parameters and return them.
areParamCacheable = areParamCacheable && !sharedTextParams.exists(param => get(param).exists(_.isRight))

// Compute the optional parameters
val optionalParams = sharedTextParams.flatMap { param =>
getValueOpt(r, param).map { value =>
param.name -> value
}
}.toMap

// Cache the computed parameters if caching is enabled
if (areParamCacheable) {
cachedParams = optionalParams
}

optionalParams
}
}
}

Expand Down

0 comments on commit 80315bc

Please sign in to comment.