Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-nmoiseyev committed Sep 27, 2024
1 parent b2c3097 commit 4077470
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 27 deletions.
57 changes: 32 additions & 25 deletions libs/snowflake/langchain_snowflake/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,16 @@ def validate_environment(cls, values: Dict) -> Dict:
"""Validate the environment needed to establish a Snowflake session or obtain
an API root from a provided Snowflake session."""

values["database"] = get_from_dict_or_env(
values, "database", "SNOWFLAKE_DATABASE"
)
values["schema"] = get_from_dict_or_env(values, "schema", "SNOWFLAKE_SCHEMA")

if "sp_session" not in values or values["sp_session"] is None:
values["username"] = get_from_dict_or_env(
values, "username", "SNOWFLAKE_USERNAME"
)
values["account"] = get_from_dict_or_env(
values, "account", "SNOWFLAKE_ACCOUNT"
)
values["role"] = get_from_dict_or_env(values, "role", "SNOWFLAKE_ROLE")
for param, env_var in [
("username", "SNOWFLAKE_USERNAME"),
("account", "SNOWFLAKE_ACCOUNT"),
("role", "SNOWFLAKE_ROLE"),
("database", "SNOWFLAKE_DATABASE"),
("schema", "SNOWFLAKE_SCHEMA"),
]:
if param not in values and env_var_is_set(env_var):
values[param] = get_from_dict_or_env(values, param, env_var)

# check whether to authenticate with password or authenticator
if "password" in values or env_var_is_set("SNOWFLAKE_PASSWORD"):
Expand Down Expand Up @@ -263,20 +260,30 @@ def validate_environment(cls, values: Dict) -> Dict:
f"password, account, role, authenticator)."
)

# If overridable parameters are not provided, use the value from the session
for param, method in [
("database", "get_current_database"),
("schema", "get_current_schema"),
# Set the overridable session parameters.
for param, env_var, method in [
("database", "SNOWFLAKE_DATABASE", "get_current_database"),
("schema", "SNOWFLAKE_SCHEMA", "get_current_schema"),
]:
if param not in values:
session_value = getattr(values["sp_session"], method)()
if session_value is None:
raise CortexSearchRetrieverError(
f"Snowflake {param} not set on the provided session. Pass "
f"the {param} as an argument, set it as an environment "
f"variable, or provide it in your session configuration."
)
values[param] = session_value
# Try to get the param as an env var if it's not in the kwargs.
if param in values:
continue

if env_var_is_set(env_var):
values[param] = get_from_dict_or_env(values, param, env_var)
continue

# If we're still missing the param, try to get it from the
# session. Error out at this point since we couldn't find it
# anywhere.
session_value = getattr(values["sp_session"], method)()
if session_value is None:
raise CortexSearchRetrieverError(
f"Snowflake {param} not set on the provided session. Pass "
f"the {param} as an argument, set it as an environment "
f"variable, or provide it in your session configuration."
)
values[param] = session_value

return values

Expand Down
44 changes: 42 additions & 2 deletions libs/snowflake/tests/integration_tests/test_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,41 @@ def test_snowflake_cortex_search_session_auth_validation_error() -> None:
del kwargs[param]


@pytest.mark.requires("snowflake.core")
def test_snowflake_cortex_search_session_auth_no_database() -> None:
"""Test that a database param is not needed when the provided
`snowlfake.snowpark.Session object` has a database."""

db = os.environ["SNOWFLAKE_DATABASE"]

with mock.patch.dict(os.environ, {"SNOWFLAKE_DATABASE": ""}):
columns = ["name", "description", "era", "diet"]
search_column = "description"
kwargs = {
"search_service": "dinosaur_svc",
"columns": columns,
"search_column": search_column,
"limit": 10,
}

session_config = {
"account": os.environ["SNOWFLAKE_ACCOUNT"],
"user": os.environ["SNOWFLAKE_USERNAME"],
"password": os.environ["SNOWFLAKE_PASSWORD"],
"database": db,
"schema": os.environ["SNOWFLAKE_SCHEMA"],
"role": os.environ["SNOWFLAKE_ROLE"],
}

session = Session.builder.configs(session_config).create()

retriever = CortexSearchRetriever(sp_session=session, **kwargs)

documents = retriever.invoke("dinosaur with a large tail")
assert len(documents) > 0
check_documents(documents, columns, search_column)


@pytest.mark.requires("snowflake.core")
def test_snowflake_cortex_search_session_auth_overrides() -> None:
"""Test overrides to the provided `snowlfake.snowpark.Session object`."""
Expand All @@ -288,12 +323,17 @@ def test_snowflake_cortex_search_session_auth_overrides() -> None:
"role": os.environ["SNOWFLAKE_ROLE"],
}

for param in ["database", "schema"]:
for param, env_var in [
("database", "SNOWFLAKE_DATABASE"),
("schema", "SNOWFLAKE_SCHEMA"),
]:
session_config_copy = session_config.copy()
del session_config_copy[param]
session = Session.builder.configs(session_config_copy).create()

retriever = CortexSearchRetriever(sp_session=session, **kwargs)
kwargs_copy = kwargs.copy()
kwargs_copy[param] = os.environ[env_var]
retriever = CortexSearchRetriever(sp_session=session, **kwargs_copy)

documents = retriever.invoke("dinosaur with a large tail")
assert len(documents) > 0
Expand Down

0 comments on commit 4077470

Please sign in to comment.