diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala index 2d123edf56..850b2915a7 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala @@ -10,7 +10,7 @@ import com.microsoft.azure.synapse.ml.fabric.FabricClient import com.microsoft.azure.synapse.ml.io.http._ import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails -import com.microsoft.azure.synapse.ml.param.ServiceParam +import com.microsoft.azure.synapse.ml.param.{GlobalParams, HasGlobalParams, ServiceParam} import com.microsoft.azure.synapse.ml.stages.{DropColumns, Lambda} import org.apache.http.NameValuePair import org.apache.http.client.methods.{HttpEntityEnclosingRequestBase, HttpPost, HttpRequestBase} @@ -493,7 +493,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform with HasURL with ComplexParamsWritable with HasSubscriptionKey with HasErrorCol with HasAADToken with HasCustomCogServiceDomain - with SynapseMLLogging { + with SynapseMLLogging with HasGlobalParams { setDefault( outputCol -> (this.uid + "_output"), @@ -547,6 +547,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform } override def transform(dataset: Dataset[_]): DataFrame = { + transferGlobalParamsToParamMap() logTransform[DataFrame](getInternalTransformer(dataset.schema).transform(dataset), dataset.columns.length ) } diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index 5435fcdbf3..58fe0ece79 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.services.openai import com.microsoft.azure.synapse.ml.codegen.GenerationUtils import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting} import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails -import com.microsoft.azure.synapse.ml.param.ServiceParam +import com.microsoft.azure.synapse.ml.param.{GlobalKey, GlobalParams, ServiceParam} import com.microsoft.azure.synapse.ml.services._ import org.apache.spark.ml.PipelineModel import org.apache.spark.ml.param.{Param, Params} @@ -51,11 +51,15 @@ trait HasMessagesInput extends Params { def setMessagesCol(v: String): this.type = set(messagesCol, v) } +case object OpenAIDeploymentNameKey extends GlobalKey[Either[String, String]] + trait HasOpenAISharedParams extends HasServiceParams with HasAPIVersion { val deploymentName = new ServiceParam[String]( this, "deploymentName", "The name of the deployment", isRequired = false) + GlobalParams.registerParam(deploymentName, OpenAIDeploymentNameKey) + def getDeploymentName: String = getScalarParam(deploymentName) def setDeploymentName(v: String): this.type = setScalarParam(deploymentName, v) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala new file mode 100644 index 0000000000..8b91625064 --- /dev/null +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -0,0 +1,12 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.param.GlobalParams + +object OpenAIDefaults { + def setDeploymentName(v: String): Unit = { + GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, Left(v)) + } +} diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index a43f3ffe3a..49db4911f2 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -7,7 +7,7 @@ import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol import com.microsoft.azure.synapse.ml.core.spark.Functions import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL} import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging} -import com.microsoft.azure.synapse.ml.param.StringStringMapParam +import com.microsoft.azure.synapse.ml.param.{HasGlobalParams, StringStringMapParam} import com.microsoft.azure.synapse.ml.services._ import org.apache.http.entity.AbstractHttpEntity import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} @@ -28,7 +28,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer with HasURL with HasCustomCogServiceDomain with ConcurrencyParams with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader with HasCognitiveServiceInput - with ComplexParamsWritable with SynapseMLLogging { + with ComplexParamsWritable with SynapseMLLogging with HasGlobalParams { logClass(FeatureNames.AiServices.OpenAI) @@ -124,7 +124,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer override def transform(dataset: Dataset[_]): DataFrame = { import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._ - + transferGlobalParamsToParamMap() logTransform[DataFrame]({ val df = dataset.toDF diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala new file mode 100644 index 0000000000..0f156d02f2 --- /dev/null +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala @@ -0,0 +1,59 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.services.openai + +import com.microsoft.azure.synapse.ml.core.test.base.Flaky +import org.apache.spark.sql.{DataFrame, Row} + +class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { + + import spark.implicits._ + + OpenAIDefaults.setDeploymentName(deploymentName) + + def promptCompletion: OpenAICompletion = new OpenAICompletion() + .setCustomServiceName(openAIServiceName) + .setMaxTokens(200) + .setOutputCol("out") + .setSubscriptionKey(openAIAPIKey) + .setPromptCol("prompt") + + lazy val promptDF: DataFrame = Seq( + "Once upon a time", + "Best programming language award goes to", + "SynapseML is " + ).toDF("prompt") + + test("Completion w Globals") { + val fromRow = CompletionResponse.makeFromRowConverter + promptCompletion.transform(promptDF).collect().foreach(r => + fromRow(r.getAs[Row]("out")).choices.foreach(c => + assert(c.text.length > 10))) + } + + lazy val prompt: OpenAIPrompt = new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setTemperature(0) + + lazy val df: DataFrame = Seq( + ("apple", "fruits"), + ("mercedes", "cars"), + ("cake", "dishes"), + (null, "none") //scalastyle:ignore null + ).toDF("text", "category") + + test("OpenAIPrompt w Globals") { + val nonNullCount = prompt + .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") + .setPostProcessing("csv") + .transform(df) + .select("outParsed") + .collect() + .count(r => Option(r.getSeq[String](0)).isDefined) + + assert(nonNullCount == 3) + } +} diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala new file mode 100644 index 0000000000..98f7eb33e6 --- /dev/null +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala @@ -0,0 +1,52 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.param + +import org.apache.spark.ml.param.{Param, Params} + +import scala.collection.mutable + +trait GlobalKey[T] + +object GlobalParams { + private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty + private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty + + + def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = { + GlobalParams(key) = value + } + + private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { + GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) + } + + def getParam[T](p: Param[T]): Option[T] = { + ParamToKeyMap.get(p).flatMap { key => + key match { + case k: GlobalKey[T] => + getGlobalParam(k) + case _ => None + } + } + } + + def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = { + ParamToKeyMap(p) = key + } +} + +trait HasGlobalParams extends Params{ + + private[ml] def transferGlobalParamsToParamMap(): Unit = { + // check for empty params. Fill em with GlobalParams + this.params + .filter(p => !this.isSet(p) && !this.hasDefault(p)) + .foreach { p => + GlobalParams.getParam(p).foreach { v => + set(p.asInstanceOf[Param[Any]], v) + } + } + } +} diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala new file mode 100644 index 0000000000..268f4b48f0 --- /dev/null +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala @@ -0,0 +1,49 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in project root for information. + +package com.microsoft.azure.synapse.ml.param + +import com.microsoft.azure.synapse.ml.core.test.base.Flaky +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.util.Identifiable + +case object TestParamKey extends GlobalKey[Double] + +class TestGlobalParams extends HasGlobalParams { + override val uid: String = Identifiable.randomUID("TestGlobalParams") + + val testParam: Param[Double] = new Param[Double]( + this, "TestParam", "Test Param for testing") + + println(testParam.parent) + println(hasParam(testParam.name)) + + GlobalParams.registerParam(testParam, TestParamKey) + + def getTestParam: Double = $(testParam) + + def setTestParam(v: Double): this.type = set(testParam, v) + + override def copy(extra: ParamMap): Transformer = defaultCopy(extra) + +} + +class GlobalParamSuite extends Flaky { + + val testGlobalParams = new TestGlobalParams() + + test("Basic Usage") { + GlobalParams.setGlobalParam(TestParamKey, 12.5) + testGlobalParams.transferGlobalParamsToParamMap() + assert(testGlobalParams.getTestParam == 12.5) + } + + test("Test Setting Directly Value") { + testGlobalParams.setTestParam(18.7334) + GlobalParams.setGlobalParam(TestParamKey, 19.853) + testGlobalParams.transferGlobalParamsToParamMap() + assert(testGlobalParams.getTestParam == 18.7334) + } +} +