Skip to content

Commit

Permalink
avoid addtional param
Browse files Browse the repository at this point in the history
  • Loading branch information
mslhrotk committed Aug 7, 2023
1 parent c3cf48e commit 50d09ee
Showing 1 changed file with 3 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def __init__(
subscriptionKey=None,
url=None,
apiVersion=OPENAI_API_VERSION,
useFabricInternalEndpoints=True,
):
super(LangchainTransformer, self).__init__()
self.chain = Param(
Expand All @@ -127,20 +126,17 @@ def __init__(
self.subscriptionKey = Param(self, "subscriptionKey", "openai api key")
self.url = Param(self, "url", "openai api base")
self.apiVersion = Param(self, "apiVersion", "openai api version")
self.useFabricInternalEndpoints = Param(self, "useFabricInternalEndpoints", "use internal openai endpoints when on fabric")
self.running_on_synapse_internal = running_on_synapse_internal()
if running_on_synapse_internal() and useFabricInternalEndpoints:
if running_on_synapse_internal():
from synapse.ml.fabric.service_discovery import get_fabric_env_config
self.setUrl(get_fabric_env_config().fabric_env_config.ml_workload_endpoint + "cognitive/openai")
self._setDefault(url=get_fabric_env_config().fabric_env_config.ml_workload_endpoint + "cognitive/openai")
kwargs = self._input_kwargs
if subscriptionKey:
kwargs["subscriptionKey"] = subscriptionKey
if url:
kwargs["url"] = url
if apiVersion:
kwargs["apiVersion"] = apiVersion
if useFabricInternalEndpoints:
kwargs["useFabricInternalEndpoints"] = useFabricInternalEndpoints
self.setParams(**kwargs)

@keyword_only
Expand All @@ -152,7 +148,6 @@ def setParams(
subscriptionKey=None,
url=None,
apiVersion=OPENAI_API_VERSION,
useFabricInternalEndpoints=True,
):
kwargs = self._input_kwargs
return self._set(**kwargs)
Expand All @@ -163,12 +158,6 @@ def setChain(self, value):
def getChain(self):
return self.getOrDefault(self.chain)

def setUseFabricInternalEndpoints(self, value):
return self._set(useFabricInternalEndpoints=value)

def getUseFabricInternalEndpoints(self):
return self.getOrDefault(self.useFabricInternalEndpoints)

def setSubscriptionKey(self, value: str):
"""
set the openAI api key
Expand Down Expand Up @@ -212,7 +201,7 @@ def _transform(self, dataset):
def udfFunction(x):
import openai

if self.running_on_synapse_internal and self.getUseFabricInternalEndpoints():
if self.running_on_synapse_internal and not self.isSet(self.url):
from synapse.ml.fabric.prerun.openai_prerun import OpenAIPrerun
OpenAIPrerun(api_base=self.getUrl()).init_personalized_session(None)
else:
Expand Down

0 comments on commit 50d09ee

Please sign in to comment.