-
Notifications
You must be signed in to change notification settings - Fork 831
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* GlobalParamObject implementation with higher-order getters. DeploymentName added to GlobalParam set * Enable ServiceParams as part of GlobalParam * Only one GlobalKey type and only one GlobalParams dictionary. Added asserts to type check Params vs Service Params. Added more tests. Tests pass! * Use initialize objects with GlobalParams and get rid of boxedClass logic * Refactor GlobalParams. Move to core, and set up OpenAIDefaults as user interface. * fix comments - Works! * Remove unused imports * Add headers and fix style bugs ---------
- Loading branch information
Showing
7 changed files
with
183 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 12 additions & 0 deletions
12
cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// Copyright (C) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
package com.microsoft.azure.synapse.ml.services.openai | ||
|
||
import com.microsoft.azure.synapse.ml.param.GlobalParams | ||
|
||
object OpenAIDefaults { | ||
def setDeploymentName(v: String): Unit = { | ||
GlobalParams.setGlobalParam(OpenAIDeploymentNameKey, Left(v)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
...e/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
// Copyright (C) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
package com.microsoft.azure.synapse.ml.services.openai | ||
|
||
import com.microsoft.azure.synapse.ml.core.test.base.Flaky | ||
import org.apache.spark.sql.{DataFrame, Row} | ||
|
||
class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { | ||
|
||
import spark.implicits._ | ||
|
||
OpenAIDefaults.setDeploymentName(deploymentName) | ||
|
||
def promptCompletion: OpenAICompletion = new OpenAICompletion() | ||
.setCustomServiceName(openAIServiceName) | ||
.setMaxTokens(200) | ||
.setOutputCol("out") | ||
.setSubscriptionKey(openAIAPIKey) | ||
.setPromptCol("prompt") | ||
|
||
lazy val promptDF: DataFrame = Seq( | ||
"Once upon a time", | ||
"Best programming language award goes to", | ||
"SynapseML is " | ||
).toDF("prompt") | ||
|
||
test("Completion w Globals") { | ||
val fromRow = CompletionResponse.makeFromRowConverter | ||
promptCompletion.transform(promptDF).collect().foreach(r => | ||
fromRow(r.getAs[Row]("out")).choices.foreach(c => | ||
assert(c.text.length > 10))) | ||
} | ||
|
||
lazy val prompt: OpenAIPrompt = new OpenAIPrompt() | ||
.setSubscriptionKey(openAIAPIKey) | ||
.setCustomServiceName(openAIServiceName) | ||
.setOutputCol("outParsed") | ||
.setTemperature(0) | ||
|
||
lazy val df: DataFrame = Seq( | ||
("apple", "fruits"), | ||
("mercedes", "cars"), | ||
("cake", "dishes"), | ||
(null, "none") //scalastyle:ignore null | ||
).toDF("text", "category") | ||
|
||
test("OpenAIPrompt w Globals") { | ||
val nonNullCount = prompt | ||
.setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ") | ||
.setPostProcessing("csv") | ||
.transform(df) | ||
.select("outParsed") | ||
.collect() | ||
.count(r => Option(r.getSeq[String](0)).isDefined) | ||
|
||
assert(nonNullCount == 3) | ||
} | ||
} |
52 changes: 52 additions & 0 deletions
52
core/src/main/scala/com/microsoft/azure/synapse/ml/param/GlobalParams.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright (C) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
package com.microsoft.azure.synapse.ml.param | ||
|
||
import org.apache.spark.ml.param.{Param, Params} | ||
|
||
import scala.collection.mutable | ||
|
||
trait GlobalKey[T] | ||
|
||
object GlobalParams { | ||
private val ParamToKeyMap: mutable.Map[Any, GlobalKey[_]] = mutable.Map.empty | ||
private val GlobalParams: mutable.Map[GlobalKey[_], Any] = mutable.Map.empty | ||
|
||
|
||
def setGlobalParam[T](key: GlobalKey[T], value: T): Unit = { | ||
GlobalParams(key) = value | ||
} | ||
|
||
private def getGlobalParam[T](key: GlobalKey[T]): Option[T] = { | ||
GlobalParams.get(key.asInstanceOf[GlobalKey[Any]]).map(_.asInstanceOf[T]) | ||
} | ||
|
||
def getParam[T](p: Param[T]): Option[T] = { | ||
ParamToKeyMap.get(p).flatMap { key => | ||
key match { | ||
case k: GlobalKey[T] => | ||
getGlobalParam(k) | ||
case _ => None | ||
} | ||
} | ||
} | ||
|
||
def registerParam[T](p: Param[T], key: GlobalKey[T]): Unit = { | ||
ParamToKeyMap(p) = key | ||
} | ||
} | ||
|
||
trait HasGlobalParams extends Params{ | ||
|
||
private[ml] def transferGlobalParamsToParamMap(): Unit = { | ||
// check for empty params. Fill em with GlobalParams | ||
this.params | ||
.filter(p => !this.isSet(p) && !this.hasDefault(p)) | ||
.foreach { p => | ||
GlobalParams.getParam(p).foreach { v => | ||
set(p.asInstanceOf[Param[Any]], v) | ||
} | ||
} | ||
} | ||
} |
49 changes: 49 additions & 0 deletions
49
core/src/test/scala/com/microsoft/azure/synapse/ml/param/GlobalParamsSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
// Copyright (C) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. See LICENSE in project root for information. | ||
|
||
package com.microsoft.azure.synapse.ml.param | ||
|
||
import com.microsoft.azure.synapse.ml.core.test.base.Flaky | ||
import org.apache.spark.ml.Transformer | ||
import org.apache.spark.ml.param.{Param, ParamMap} | ||
import org.apache.spark.ml.util.Identifiable | ||
|
||
case object TestParamKey extends GlobalKey[Double] | ||
|
||
class TestGlobalParams extends HasGlobalParams { | ||
override val uid: String = Identifiable.randomUID("TestGlobalParams") | ||
|
||
val testParam: Param[Double] = new Param[Double]( | ||
this, "TestParam", "Test Param for testing") | ||
|
||
println(testParam.parent) | ||
println(hasParam(testParam.name)) | ||
|
||
GlobalParams.registerParam(testParam, TestParamKey) | ||
|
||
def getTestParam: Double = $(testParam) | ||
|
||
def setTestParam(v: Double): this.type = set(testParam, v) | ||
|
||
override def copy(extra: ParamMap): Transformer = defaultCopy(extra) | ||
|
||
} | ||
|
||
class GlobalParamSuite extends Flaky { | ||
|
||
val testGlobalParams = new TestGlobalParams() | ||
|
||
test("Basic Usage") { | ||
GlobalParams.setGlobalParam(TestParamKey, 12.5) | ||
testGlobalParams.transferGlobalParamsToParamMap() | ||
assert(testGlobalParams.getTestParam == 12.5) | ||
} | ||
|
||
test("Test Setting Directly Value") { | ||
testGlobalParams.setTestParam(18.7334) | ||
GlobalParams.setGlobalParam(TestParamKey, 19.853) | ||
testGlobalParams.transferGlobalParamsToParamMap() | ||
assert(testGlobalParams.getTestParam == 18.7334) | ||
} | ||
} | ||
|