Skip to content

Commit

Permalink
feat: Global Params (#2318)
Browse files Browse the repository at this point in the history
* GlobalParamObject implementation with higher-order getters. DeploymentName added to GlobalParam set

* Enable ServiceParams as part of GlobalParam

* Only one GlobalKey type and only one GlobalParams dictionary. Added asserts to type check Params vs Service Params. Added more tests. Tests pass!

* Initialize objects with GlobalParams and get rid of the boxedClass logic

* Refactor GlobalParams. Move to core, and set up OpenAIDefaults as user interface.

* fix comments - Works!

* Remove unused imports

* Add headers and fix style bugs

---------
  • Loading branch information
sss04 authored Nov 22, 2024
1 parent 08aab6a commit 79d5b58
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient
import com.microsoft.azure.synapse.ml.io.http._
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.param.{GlobalParams, HasGlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda}
import org.apache.http.NameValuePair
import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase}
Expand Down Expand Up @@ -493,7 +493,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform
with HasURL with ComplexParamsWritable
with HasSubscriptionKey with HasErrorCol
with HasAADToken with HasCustomCogServiceDomain
with SynapseMLLogging {
with SynapseMLLogging with HasGlobalParams {

setDefault(
outputCol -> (this.uid + "_output"),
Expand Down Expand Up @@ -547,6 +547,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform
}

override def transform(dataset: Dataset[_]): DataFrame = {
transferGlobalParamsToParamMap()
logTransform[DataFrame](getInternalTransformer(dataset.schema).transform(dataset), dataset.columns.length
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.services.openai
import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting}
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, ServiceParam}
import com.microsoft.azure.synapse.ml.services._
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.param.{Param, Params}
Expand Down Expand Up @@ -51,11 +51,15 @@ trait HasMessagesInput extends Params {
def setMessagesCol(v: String): this.type = set(messagesCol, v)
}

/**
  * Global key for the OpenAI deployment name.
  *
  * `HasOpenAISharedParams` registers its `deploymentName` param against this key, and
  * `OpenAIDefaults.setDeploymentName` stores a value under it.
  */
case object OpenAIDeploymentNameKey extends GlobalKey[Either[String, String]]

trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion {

val deploymentName = new ServiceParam[String](
this, "deploymentName", "The name of the deployment", isRequired = false)

GlobalParams.registerParam(deploymentName, OpenAIDeploymentNameKey)

def getDeploymentName: String = getScalarParam(deploymentName)

def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.param.GlobalParams

/** User-facing entry point for setting default values shared by OpenAI stages. */
object OpenAIDefaults {

  /**
    * Stores `v` as the global value associated with `OpenAIDeploymentNameKey`.
    *
    * @param v deployment name to store
    */
  def setDeploymentName(v: String): Unit = {
    // NOTE(review): the value is wrapped in Left, presumably to mark it as a scalar value
    // rather than a column name -- confirm against ServiceParam.
    val wrapped: Either[String, String] = Left(v)
    GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, wrapped)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
import com.microsoft.azure.synapse.ml.core.spark.Functions
import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL}
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.StringStringMapParam
import com.microsoft.azure.synapse.ml.param.{HasGlobalParams, StringStringMapParam}
import com.microsoft.azure.synapse.ml.services._
import org.apache.http.entity.AbstractHttpEntity
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
Expand All @@ -28,7 +28,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
with HasCognitiveServiceInput
with ComplexParamsWritable with SynapseMLLogging {
with ComplexParamsWritable with SynapseMLLogging with HasGlobalParams {

logClass(FeatureNames.AiServices.OpenAI)

Expand Down Expand Up @@ -124,7 +124,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer

override def transform(dataset: Dataset[_]): DataFrame = {
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._

transferGlobalParamsToParamMap()
logTransform[DataFrame]({
val df = dataset.toDF

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import org.apache.spark.sql.{DataFrame, Row}

/**
  * Runs OpenAI stages whose deployment name is supplied only through `OpenAIDefaults`;
  * none of the stages built in this suite call `setDeploymentName` themselves.
  */
class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey {

  import spark.implicits._

  // Executed once when the suite is constructed.
  // NOTE(review): `deploymentName` is not defined in this class; presumably it is inherited
  // from OpenAIAPIKey -- confirm.
  OpenAIDefaults.setDeploymentName(deploymentName)

  /** Builds a new completion stage on every call. */
  def promptCompletion: OpenAICompletion = {
    new OpenAICompletion()
      .setCustomServiceName(openAIServiceName)
      .setMaxTokens(200)
      .setOutputCol("out")
      .setSubscriptionKey(openAIAPIKey)
      .setPromptCol("prompt")
  }

  /** Single-column frame of completion prompts. */
  lazy val promptDF: DataFrame = {
    val prompts = Seq(
      "Once upon a time",
      "Best programming language award goes to",
      "SynapseML is "
    )
    prompts.toDF("prompt")
  }

  test("Completion w Globals") {
    val toResponse = CompletionResponse.makeFromRowConverter
    val completedRows = promptCompletion.transform(promptDF).collect()
    // Every choice of every returned row must carry more than 10 characters of text.
    for {
      row <- completedRows
      choice <- toResponse(row.getAs[Row]("out")).choices
    } assert(choice.text.length > 10)
  }

  /** Prompt stage shared by the prompt test. */
  lazy val prompt: OpenAIPrompt = {
    new OpenAIPrompt()
      .setSubscriptionKey(openAIAPIKey)
      .setCustomServiceName(openAIServiceName)
      .setOutputCol("outParsed")
      .setTemperature(0)
  }

  /** Four (text, category) rows; the last one has a null text. */
  lazy val df: DataFrame = {
    val rows = Seq(
      ("apple", "fruits"),
      ("mercedes", "cars"),
      ("cake", "dishes"),
      (null, "none") //scalastyle:ignore null
    )
    rows.toDF("text", "category")
  }

  test("OpenAIPrompt w Globals") {
    val configured = prompt
      .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
      .setPostProcessing("csv")
    val parsedRows = configured
      .transform(df)
      .select("outParsed")
      .collect()

    // Count the rows whose parsed output column is not null.
    val nonNullCount = parsedRows.count(row => Option(row.getSeq[String](0)).nonEmpty)

    assert(nonNullCount == 3)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.param

import org.apache.spark.ml.param.{Param, Params}

import scala.collection.mutable

/**
  * Typed identifier for a value held in `GlobalParams`.
  *
  * @tparam T type of the value that `GlobalParams.setGlobalParam` accepts for this key
  */
trait GlobalKey[T]

/**
  * Registry of globally assigned parameter values.
  *
  * Two tables are kept: one associating a registered [[Param]] with a [[GlobalKey]], and one
  * associating a [[GlobalKey]] with the value currently assigned to it. Resolving a param
  * therefore goes param -> key -> value.
  *
  * NOTE(review): both tables are plain `mutable.Map`s without synchronization, and this
  * object offers no way to remove an entry. If params are registered once per stage
  * instance, the param table grows with every instance -- confirm this is acceptable.
  */
object GlobalParams {
  // Registered params and the key each of them resolves through.
  private val paramToKey: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty
  // Values assigned so far, indexed by their key.
  private val valuesByKey: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty

  /**
    * Assigns `value` to `key`, replacing whatever was stored under that key before.
    *
    * @param key   key to store the value under
    * @param value value to store
    */
  def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = {
    valuesByKey(key) = value
  }

  /** Returns the value stored under `key`, or `None` when nothing has been assigned to it. */
  private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = {
    // The table is untyped; setGlobalParam only accepts a T for a GlobalKey[T], which is
    // what justifies the cast back to T.
    valuesByKey.get(key).map(_.asInstanceOf[T])
  }

  /**
    * Resolves the global value for a registered param.
    *
    * @param p param to resolve
    * @return the value assigned to the key `p` was registered with; `None` when `p` was
    *         never registered or its key has no value
    */
  def getParam[T](p: Param[T]): Option[T] = {
    paramToKey.get(p).flatMap { key =>
      // Type arguments are erased at runtime, so a `case k: GlobalKey[T]` pattern cannot
      // verify T. The pairing is enforced at compile time by registerParam's signature,
      // so the cast is stated explicitly instead of hiding behind an unchecked match.
      getGlobalParam(key.asInstanceOf[GlobalKey[T]])
    }
  }

  /**
    * Associates `p` with `key` so that [[getParam]] can resolve `p` through it. A later
    * registration of an equal param replaces the earlier one.
    *
    * @param p   param to register
    * @param key key whose value should be used for `p`
    */
  def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = {
    paramToKey(p) = key
  }
}

/** Mix-in that lets a `Params` instance fall back to values held in [[GlobalParams]]. */
trait HasGlobalParams extends Params {

  /**
    * Copies global values into this instance's param map.
    *
    * Only params that are neither explicitly set nor backed by a default are considered;
    * each of those is set when [[GlobalParams]] yields a value for it and left untouched
    * otherwise.
    */
  private[ml] def transferGlobalParamsToParamMap(): Unit = {
    // Params that currently have no value of their own.
    val unfilled = this.params.filterNot(p => this.isSet(p) || this.hasDefault(p))
    unfilled.foreach { p =>
      GlobalParams.getParam(p).foreach { v =>
        set(p.asInstanceOf[Param[Any]], v)
      }
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.param

import com.microsoft.azure.synapse.ml.core.test.base.Flaky
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable

/** Global key that `TestGlobalParams` registers its `testParam` against; holds a `Double`. */
case object TestParamKey extends GlobalKey[Double]

/**
  * Minimal [[HasGlobalParams]] implementation used by the tests: a single `Double` param
  * that is registered against [[TestParamKey]] when the instance is constructed.
  */
class TestGlobalParams extends HasGlobalParams {
  override val uid: String = Identifiable.randomUID("TestGlobalParams")

  val testParam: Param[Double] = new Param[Double](
    this, "TestParam", "Test Param for testing")

  // Registration happens per instance, as part of construction.
  // (The debug println calls that used to run here on every instantiation were removed.)
  GlobalParams.registerParam(testParam, TestParamKey)

  /** Returns the current value of `testParam`. */
  def getTestParam: Double = $(testParam)

  /** Explicitly sets `testParam` to `v`. */
  def setTestParam(v: Double): this.type = set(testParam, v)

  // NOTE(review): this class only extends HasGlobalParams and declares no constructor
  // taking a uid, yet copy is typed as Transformer and delegates to defaultCopy --
  // presumably it cannot succeed at runtime. Confirm before relying on copy in a test.
  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

}

/** Tests the fallback from an unset param to the value held in [[GlobalParams]]. */
class GlobalParamSuite extends Flaky {

  // Yields a fresh instance on every access. A single shared instance made the tests
  // order-dependent: once any test had set `testParam`, "Basic Usage" could no longer
  // observe the global value being transferred.
  def testGlobalParams: TestGlobalParams = new TestGlobalParams()

  test("Basic Usage") {
    val stage = testGlobalParams
    GlobalParams.setGlobalParam(TestParamKey, 12.5)
    stage.transferGlobalParamsToParamMap()
    // The param was never set on the stage, so the global value is picked up.
    assert(stage.getTestParam == 12.5)
  }

  test("Test Setting Directly Value") {
    val stage = testGlobalParams
    stage.setTestParam(18.7334)
    GlobalParams.setGlobalParam(TestParamKey, 19.853)
    stage.transferGlobalParamsToParamMap()
    // An explicitly set value must not be overwritten by the global one.
    assert(stage.getTestParam == 18.7334)
  }
}

0 comments on commit 79d5b58

Please sign in to comment.