Skip to content

Commit

Permalink
support langchain transformer on fabric
Browse files Browse the repository at this point in the history
  • Loading branch information
mslhrotk committed Aug 7, 2023
1 parent df2712a commit c3cf48e
Showing 1 changed file with 24 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from pyspark.sql.functions import udf
from typing import cast, Optional, TypeVar, Type
from synapse.ml.core.platform import running_on_synapse_internal

OPENAI_API_VERSION = "2022-12-01"
RL = TypeVar("RL", bound="MLReadable")
Expand Down Expand Up @@ -115,6 +116,7 @@ def __init__(
subscriptionKey=None,
url=None,
apiVersion=OPENAI_API_VERSION,
useFabricInternalEndpoints=True,
):
super(LangchainTransformer, self).__init__()
self.chain = Param(
Expand All @@ -125,13 +127,20 @@ def __init__(
self.subscriptionKey = Param(self, "subscriptionKey", "openai api key")
self.url = Param(self, "url", "openai api base")
self.apiVersion = Param(self, "apiVersion", "openai api version")
self.useFabricInternalEndpoints = Param(self, "useFabricInternalEndpoints", "use internal openai endpoints when on fabric")
self.running_on_synapse_internal = running_on_synapse_internal()
if running_on_synapse_internal() and useFabricInternalEndpoints:
from synapse.ml.fabric.service_discovery import get_fabric_env_config
self.setUrl(get_fabric_env_config().fabric_env_config.ml_workload_endpoint + "cognitive/openai")
kwargs = self._input_kwargs
if subscriptionKey:
kwargs["subscriptionKey"] = subscriptionKey
if url:
kwargs["url"] = url
if apiVersion:
kwargs["apiVersion"] = apiVersion
if useFabricInternalEndpoints:
kwargs["useFabricInternalEndpoints"] = useFabricInternalEndpoints
self.setParams(**kwargs)

@keyword_only
Expand All @@ -143,6 +152,7 @@ def setParams(
subscriptionKey=None,
url=None,
apiVersion=OPENAI_API_VERSION,
useFabricInternalEndpoints=True,
):
kwargs = self._input_kwargs
return self._set(**kwargs)
Expand All @@ -153,6 +163,12 @@ def setChain(self, value):
def getChain(self):
return self.getOrDefault(self.chain)

def setUseFabricInternalEndpoints(self, value):
return self._set(useFabricInternalEndpoints=value)

def getUseFabricInternalEndpoints(self):
return self.getOrDefault(self.useFabricInternalEndpoints)

def setSubscriptionKey(self, value: str):
"""
set the openAI api key
Expand Down Expand Up @@ -196,10 +212,14 @@ def _transform(self, dataset):
def udfFunction(x):
import openai

openai.api_type = "azure"
openai.api_key = self.getSubscriptionKey()
openai.api_base = self.getUrl()
openai.api_version = self.getApiVersion()
if self.running_on_synapse_internal and self.getUseFabricInternalEndpoints():
from synapse.ml.fabric.prerun.openai_prerun import OpenAIPrerun
OpenAIPrerun(api_base=self.getUrl()).init_personalized_session(None)
else:
openai.api_type = "azure"
openai.api_key = self.getSubscriptionKey()
openai.api_base = self.getUrl()
openai.api_version = self.getApiVersion()
return self.getChain().run(x)

outCol = self.getOutputCol()
Expand Down

0 comments on commit c3cf48e

Please sign in to comment.