diff --git a/pyproject.toml b/pyproject.toml index 15f212dc90..c2fb78ec9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ dependencies = [ "typer==0.9.0", "urllib3<1.27,>=1.21.1", "GitPython==3.1.37", - "pyyaml==6.0.1", ] classifiers = [ diff --git a/src/snowcli/cli/common/sql_execution.py b/src/snowcli/cli/common/sql_execution.py index 72447a5dc7..145cdad8df 100644 --- a/src/snowcli/cli/common/sql_execution.py +++ b/src/snowcli/cli/common/sql_execution.py @@ -44,13 +44,6 @@ def _execute_schema_query(self, query: str): self.check_database_and_schema() return self._execute_query(query) - def _execute_query(self, query: str): - *_, last_result = self._conn.execute_string(dedent(query)) - return last_result - - def _execute_queries(self, queries: str): - return self._conn.execute_string(dedent(queries)) - def check_database_and_schema(self) -> None: database = self._conn.database schema = self._conn.schema diff --git a/src/snowcli/cli/snowpark/services/manager.py b/src/snowcli/cli/snowpark/services/manager.py index 5fb7a45fdb..7247c4b41b 100644 --- a/src/snowcli/cli/snowpark/services/manager.py +++ b/src/snowcli/cli/snowpark/services/manager.py @@ -2,7 +2,7 @@ import os from pathlib import Path -import yaml +import strictyaml from snowcli.cli.common.sql_execution import SqlExecutionMixin from snowflake.connector.cursor import SnowflakeCursor @@ -32,7 +32,7 @@ def create( def _read_yaml(self, path: Path) -> str: # TODO(aivanou): Add validation towards schema with open(path, "r") as content: - spec_obj = yaml.safe_load(content) + spec_obj = strictyaml.load(content) return json.dumps(spec_obj) def desc(self, service_name: str) -> SnowflakeCursor: diff --git a/tests/snowpark/test_services.py b/tests/snowpark/test_services.py index f0a31c93b8..1427751162 100644 --- a/tests/snowpark/test_services.py +++ b/tests/snowpark/test_services.py @@ -1,4 +1,4 @@ -import yaml +import strictyaml import unittest from pathlib import Path from unittest.mock import Mock, patch @@ -45,8 +45,8 @@ def test_create_service_with_invalid_spec(self, mock_read_yaml): compute_pool = "test_pool" spec_path = "/path/to/spec.yaml" num_instances = 42 - mock_read_yaml.side_effect = yaml.YAMLError("Invalid YAML") - with self.assertRaises(yaml.YAMLError): + mock_read_yaml.side_effect = strictyaml.YAMLError("Invalid YAML") + with self.assertRaises(strictyaml.YAMLError): self.service_manager.create( service_name, compute_pool, Path(spec_path), num_instances ) @@ -104,73 +104,5 @@ def test_logs(self, mock_execute_schema_query): self.assertEqual(result, cursor) -<<<<<<< HEAD -@mock.patch("snowflake.connector.connect") -def test_desc_service(mock_connector, runner, mock_ctx): - ctx = mock_ctx() - mock_connector.return_value = ctx - service_name = "test_service" - - result = runner.invoke(["snowpark", "services", "desc", service_name]) - - assert result.exit_code == 0, result.output - assert ctx.get_query() == f"desc service {service_name}" - - -@mock.patch("snowflake.connector.connect") -def test_list_service(mock_connector, runner, mock_ctx): - ctx = mock_ctx() - mock_connector.return_value = ctx - - result = runner.invoke(["snowpark", "services", "list"]) - - assert result.exit_code == 0, result.output - assert ctx.get_query() == "show services" - - -@mock.patch("snowflake.connector.connect") -def test_drop_service(mock_connector, runner, mock_ctx): - ctx = mock_ctx() - mock_connector.return_value = ctx - - result = runner.invoke(["snowpark", "services", "drop", "serviceName"]) - - assert result.exit_code == 0, result.output - assert ctx.get_query() == "drop service serviceName" - - -@mock.patch("snowflake.connector.connect") -def test_service_status(mock_connector, runner, mock_ctx): - ctx = mock_ctx() - mock_connector.return_value = ctx - - result = runner.invoke(["snowpark", "services", "status", "serviceName"]) - - assert result.exit_code == 0, result.output - assert ctx.get_query() == "CALL SYSTEM$GET_SERVICE_STATUS('serviceName')" - - -@mock.patch("snowflake.connector.connect") -def test_service_logs(mock_connector, runner, mock_ctx, snapshot): - ctx = mock_ctx() - mock_connector.return_value = ctx - - result = runner.invoke( - [ - "snowpark", - "services", - "logs", - "--container_name", - "containerName", - "serviceName", - ] - ) - - assert result.exit_code == 0, result.output - assert ( - ctx.get_query() - == "call SYSTEM$GET_SERVICE_LOGS('serviceName', '0', 'containerName');" - ) - if __name__ == "__main__": unittest.main()