From c3cf48e954990d23cf80b4c55261765cc03b5341 Mon Sep 17 00:00:00 2001
From: cruise
Date: Mon, 31 Jul 2023 23:18:19 +0800
Subject: [PATCH] support langchain transformer on fabric

---
Review notes (this section is ignored by `git am`; not part of the commit):

  * Adds a `useFabricInternalEndpoints` Param (constructor default True)
    with `setUseFabricInternalEndpoints` / `getUseFabricInternalEndpoints`.
  * In `__init__`, when `running_on_synapse_internal()` is true and the flag
    is set, the url Param is set to
    `ml_workload_endpoint + "cognitive/openai"` before `setParams` runs, so
    an explicitly passed `url` still overrides it.
  * In the UDF, the same two conditions select
    `OpenAIPrerun(api_base=self.getUrl()).init_personalized_session(None)`
    instead of assigning the `openai.api_*` module globals.
  * Changed versus the submitted patch: the `__init__` check now reuses
    `self.running_on_synapse_internal` instead of calling
    `running_on_synapse_internal()` a second time, so the constructor and
    the UDF always branch on the same captured value.
  * NOTE(review): the url concatenation assumes `ml_workload_endpoint`
    ends with "/" -- confirm against the service discovery config.
  * NOTE(review): several added lines are longer than 88 columns; run the
    project's formatter before merging if one is enforced.
  * Line structure and context indentation were reconstructed from a
    whitespace-collapsed copy; apply with `git apply --ignore-whitespace`
    if the context does not match exactly.

 .../cognitive/langchain/LangchainTransform.py | 28 ++++++++++++++++---
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/cognitive/src/main/python/synapse/ml/cognitive/langchain/LangchainTransform.py b/cognitive/src/main/python/synapse/ml/cognitive/langchain/LangchainTransform.py
index cbf6b528b8..71bf727856 100644
--- a/cognitive/src/main/python/synapse/ml/cognitive/langchain/LangchainTransform.py
+++ b/cognitive/src/main/python/synapse/ml/cognitive/langchain/LangchainTransform.py
@@ -44,6 +44,7 @@
 )
 from pyspark.sql.functions import udf
 from typing import cast, Optional, TypeVar, Type
+from synapse.ml.core.platform import running_on_synapse_internal
 
 OPENAI_API_VERSION = "2022-12-01"
 RL = TypeVar("RL", bound="MLReadable")
@@ -115,6 +116,7 @@ def __init__(
         subscriptionKey=None,
         url=None,
         apiVersion=OPENAI_API_VERSION,
+        useFabricInternalEndpoints=True,
     ):
         super(LangchainTransformer, self).__init__()
         self.chain = Param(
@@ -125,6 +127,11 @@ def __init__(
         self.subscriptionKey = Param(self, "subscriptionKey", "openai api key")
         self.url = Param(self, "url", "openai api base")
         self.apiVersion = Param(self, "apiVersion", "openai api version")
+        self.useFabricInternalEndpoints = Param(self, "useFabricInternalEndpoints", "use internal openai endpoints when on fabric")
+        self.running_on_synapse_internal = running_on_synapse_internal()
+        if self.running_on_synapse_internal and useFabricInternalEndpoints:
+            from synapse.ml.fabric.service_discovery import get_fabric_env_config
+            self.setUrl(get_fabric_env_config().fabric_env_config.ml_workload_endpoint + "cognitive/openai")
         kwargs = self._input_kwargs
         if subscriptionKey:
             kwargs["subscriptionKey"] = subscriptionKey
@@ -132,6 +139,8 @@ def __init__(
             kwargs["url"] = url
         if apiVersion:
             kwargs["apiVersion"] = apiVersion
+        if useFabricInternalEndpoints:
+            kwargs["useFabricInternalEndpoints"] = useFabricInternalEndpoints
         self.setParams(**kwargs)
 
     @keyword_only
@@ -143,6 +152,7 @@ def setParams(
         subscriptionKey=None,
         url=None,
         apiVersion=OPENAI_API_VERSION,
+        useFabricInternalEndpoints=True,
     ):
         kwargs = self._input_kwargs
         return self._set(**kwargs)
@@ -153,6 +163,12 @@ def setChain(self, value):
     def getChain(self):
         return self.getOrDefault(self.chain)
 
+    def setUseFabricInternalEndpoints(self, value):
+        return self._set(useFabricInternalEndpoints=value)
+
+    def getUseFabricInternalEndpoints(self):
+        return self.getOrDefault(self.useFabricInternalEndpoints)
+
     def setSubscriptionKey(self, value: str):
         """
         set the openAI api key
@@ -196,10 +212,14 @@ def _transform(self, dataset):
         def udfFunction(x):
             import openai
 
-            openai.api_type = "azure"
-            openai.api_key = self.getSubscriptionKey()
-            openai.api_base = self.getUrl()
-            openai.api_version = self.getApiVersion()
+            if self.running_on_synapse_internal and self.getUseFabricInternalEndpoints():
+                from synapse.ml.fabric.prerun.openai_prerun import OpenAIPrerun
+                OpenAIPrerun(api_base=self.getUrl()).init_personalized_session(None)
+            else:
+                openai.api_type = "azure"
+                openai.api_key = self.getSubscriptionKey()
+                openai.api_base = self.getUrl()
+                openai.api_version = self.getApiVersion()
             return self.getChain().run(x)
 
         outCol = self.getOutputCol()