Skip to content

Commit

Permalink
fixing the bug for the cerberus import
Browse files Browse the repository at this point in the history
  • Loading branch information
asingamaneni committed Sep 1, 2023
1 parent f5b3ea5 commit ecec44b
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 29 deletions.
12 changes: 6 additions & 6 deletions spark_expectations/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def get_secret_type(self) -> Optional[str]:
return self._se_streaming_stats_dict[user_config.secret_type].lower()
raise SparkExpectationsMiscException(
"""The spark expectations context is not set completely, please assign
'user_config.secret_type' before
'UserConfig.secret_type' before
accessing it"""
)

Expand All @@ -879,7 +879,7 @@ def get_server_url_key(self) -> Optional[str]:
return _server_url_key
raise SparkExpectationsMiscException(
"""The spark expectations context is not set completely, please assign
'user_config.cbs_kafka_server_url' before
'UserConfig.cbs_kafka_server_url' before
accessing it"""
)

Expand All @@ -899,7 +899,7 @@ def get_token_endpoint_url(self) -> Optional[str]:
return _token_endpoint_url
raise SparkExpectationsMiscException(
"""The spark expectations context is not set completely, please assign
'user_config.cbs_secret_token_url' before
'UserConfig.cbs_secret_token_url' before
accessing it"""
)

Expand All @@ -919,7 +919,7 @@ def get_token(self) -> Optional[str]:
return _token
raise SparkExpectationsMiscException(
"""The spark expectations context is not set completely, please assign
'user_config.cbs_secret_token' before
'UserConfig.cbs_secret_token' before
accessing it"""
)

Expand All @@ -939,7 +939,7 @@ def get_client_id(self) -> Optional[str]:
return _client_id
raise SparkExpectationsMiscException(
"""The spark expectations context is not set completely, please assign
'user_config.cbs_secret_app_name' before
'UserConfig.cbs_secret_app_name' before
accessing it"""
)

Expand All @@ -959,7 +959,7 @@ def get_topic_name(self) -> Optional[str]:
return _topic_name
raise SparkExpectationsMiscException(
"""The spark expectations context is not set completely, please assign
'user_config.cbs_topic_name' before
'UserConfig.cbs_topic_name' before
accessing it"""
)

Expand Down
17 changes: 9 additions & 8 deletions spark_expectations/secrets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from dataclasses import dataclass
from typing import Optional, Dict
import pluggy
from cerberus.client import CerberusClient
from spark_expectations.config.user_config import Constants as user_config
from spark_expectations.config.user_config import Constants as UserConfig
from spark_expectations.core import get_spark_session
from spark_expectations import _log

Expand Down Expand Up @@ -56,14 +55,16 @@ def get_secret_value(
"""
This function implemented to get secret value from cerberus
Args:
secret_key: str which accepts url with secret key
secret_key_path: str which accepts url with secret key
secret_dict: dict which contains params for fetch secrets
Returns:
str | none : returns secret value in string or none
"""

if secret_dict[user_config.secret_type].lower() == "cerberus":
_client = CerberusClient(secret_dict[user_config.cbs_url])
from cerberus.client import CerberusClient

if secret_dict[UserConfig.secret_type].lower() == "cerberus":
_client = CerberusClient(secret_dict[UserConfig.cbs_url])
data = _client.get_secrets_data(secret_key_path)
return data

Expand All @@ -82,12 +83,12 @@ def get_secret_value(
# pragma: no cover
This function implemented to get secret value from databricks scope
Args:
secret_key: str which accepts url with secret key
secret_key_path: str which accepts url with secret key
secret_dict: dict which contains params for fetch secrets
Returns:
str | none : returns secret value in string or none
"""
if secret_dict[user_config.secret_type].lower() == "databricks":
if secret_dict[UserConfig.secret_type].lower() == "databricks":
try:
from pyspark.dbutils import DBUtils

Expand All @@ -101,7 +102,7 @@ def get_secret_value(
)

data = dbutils.secrets.get(
scope=secret_dict[user_config.dbx_secret_scope], key=secret_key_path
scope=secret_dict[UserConfig.dbx_secret_scope], key=secret_key_path
) # pragma: no cover
return data # pragma: no cover
return None # pragma: no cover
Expand Down
12 changes: 6 additions & 6 deletions tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@ def test_get_secret_type_exception():

with pytest.raises(SparkExpectationsMiscException,
match="""The spark expectations context is not set completely, please assign
'user_config.secret_type' before
'UserConfig.secret_type' before
accessing it"""):
context.get_secret_type

Expand All @@ -1183,7 +1183,7 @@ def test_get_server_url_key_exception():

with pytest.raises(SparkExpectationsMiscException,
match="""The spark expectations context is not set completely, please assign
'user_config.cbs_kafka_server_url' before
'UserConfig.cbs_kafka_server_url' before
accessing it"""):
context.get_server_url_key

Expand All @@ -1208,7 +1208,7 @@ def test_get_token_endpoint_url_exception():

with pytest.raises(SparkExpectationsMiscException,
match="""The spark expectations context is not set completely, please assign
'user_config.cbs_secret_token_url' before
'UserConfig.cbs_secret_token_url' before
accessing it"""):
context.get_token_endpoint_url

Expand All @@ -1233,7 +1233,7 @@ def test_get_token_exception():

with pytest.raises(SparkExpectationsMiscException,
match="""The spark expectations context is not set completely, please assign
'user_config.cbs_secret_token' before
'UserConfig.cbs_secret_token' before
accessing it"""):
context.get_token

Expand All @@ -1258,7 +1258,7 @@ def test_get_client_id_exception():

with pytest.raises(SparkExpectationsMiscException,
match="""The spark expectations context is not set completely, please assign
'user_config.cbs_secret_app_name' before
'UserConfig.cbs_secret_app_name' before
accessing it"""):
context.get_client_id

Expand All @@ -1283,7 +1283,7 @@ def test_get_topic_name_exception():

with pytest.raises(SparkExpectationsMiscException,
match="""The spark expectations context is not set completely, please assign
'user_config.cbs_topic_name' before
'UserConfig.cbs_topic_name' before
accessing it"""):
context.get_topic_name

Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_expectations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2321,7 +2321,7 @@ def get_dataset() -> DataFrame:
# write_to_table,
# agg_dq=None,
# query_dq=None,
# spark_conf={user_config.se_notifications_on_fail: False},
# spark_conf={UserConfig.se_notifications_on_fail: False},
# options={'mode': 'overwrite', "format": "delta"},
# options_error_table={'mode': 'overwrite', "format": "delta"}
# )(mock_func)
Expand Down
16 changes: 8 additions & 8 deletions tests/secrets/test__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
SparkExpectationsSecretPluginSpec,
CerberusSparkExpectationsSecretPluginImpl,
SparkExpectationsSecretsBackend)
from spark_expectations.config.user_config import Constants as user_config
from spark_expectations.config.user_config import Constants as UserConfig


def test_get_secret_value(mocker):
Expand All @@ -33,7 +33,7 @@ def test_get_spark_expectations_tasks_hook(caplog):
assert isinstance(result, pluggy._hooks._HookRelay)


@patch("spark_expectations.secrets.CerberusClient", autospec=True, spec_set=True)
@patch("cerberus.client.CerberusClient", autospec=True, spec_set=True)
def test_get_secret_value_with_cerberus(mock_cerberus):
cerberus_se_handler = CerberusSparkExpectationsSecretPluginImpl()
# Simulate the return value of get_secrets_data
Expand All @@ -46,8 +46,8 @@ def test_get_secret_value_with_cerberus(mock_cerberus):
# Set up the input parameters
secret_key_path = "my_secret_key"
secret_dict = {
user_config.secret_type: "cerberus",
user_config.cbs_url: "https://example.com/cerberus",
UserConfig.secret_type: "cerberus",
UserConfig.cbs_url: "https://example.com/cerberus",
}

# Call the function under test
Expand All @@ -64,15 +64,15 @@ def test_get_secret_value_with_cerberus_none():
# Set up the input parameters
secret_key_path = "my_secret_key"
secret_dict = {
user_config.secret_type: "x",
user_config.cbs_url: "https://example.com/cerberus",
UserConfig.secret_type: "x",
UserConfig.cbs_url: "https://example.com/cerberus",
}

# Call the function under test
result = cerberus_se_handler.get_secret_value(secret_key_path, secret_dict)

# Assert the CerberusClient methods were called as expected
assert result == None
assert result is None


@mock.patch("spark_expectations.secrets.get_spark_expectations_tasks_hook")
Expand Down Expand Up @@ -121,6 +121,6 @@ def test_get_secret_with_invalid_key(mock_hook):
def test_get_secret_exception():
# Create an instance of the class under test
secret_manager = SparkExpectationsSecretsBackend(secret_dict={"my_secret_key": "my_secret",
user_config.secret_type: "databricks"})
UserConfig.secret_type: "databricks"})
with pytest.raises(Exception):
secret_manager.get_secret("my_secret_key")

0 comments on commit ecec44b

Please sign in to comment.