Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aivanou committed Sep 28, 2023
1 parent 9ebd1cc commit 3566c08
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 81 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ dependencies = [
"typer==0.9.0",
"urllib3<1.27,>=1.21.1",
"GitPython==3.1.37",
"pyyaml==6.0.1",

]
classifiers = [
Expand Down
7 changes: 0 additions & 7 deletions src/snowcli/cli/common/sql_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,6 @@ def _execute_schema_query(self, query: str):
self.check_database_and_schema()
return self._execute_query(query)

def _execute_query(self, query: str):
*_, last_result = self._conn.execute_string(dedent(query))
return last_result

def _execute_queries(self, queries: str):
return self._conn.execute_string(dedent(queries))

def check_database_and_schema(self) -> None:
database = self._conn.database
schema = self._conn.schema
Expand Down
4 changes: 2 additions & 2 deletions src/snowcli/cli/snowpark/services/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from pathlib import Path

import yaml
import strictyaml
from snowcli.cli.common.sql_execution import SqlExecutionMixin
from snowflake.connector.cursor import SnowflakeCursor

Expand Down Expand Up @@ -32,7 +32,7 @@ def create(
def _read_yaml(self, path: Path) -> str:
# TODO(aivanou): Add validation towards schema
with open(path, "r") as content:
spec_obj = yaml.safe_load(content)
spec_obj = strictyaml.load(content)
return json.dumps(spec_obj)

def desc(self, service_name: str) -> SnowflakeCursor:
Expand Down
74 changes: 3 additions & 71 deletions tests/snowpark/test_services.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import yaml
import strictyaml
import unittest
from pathlib import Path
from unittest.mock import Mock, patch
Expand Down Expand Up @@ -45,8 +45,8 @@ def test_create_service_with_invalid_spec(self, mock_read_yaml):
compute_pool = "test_pool"
spec_path = "/path/to/spec.yaml"
num_instances = 42
mock_read_yaml.side_effect = yaml.YAMLError("Invalid YAML")
with self.assertRaises(yaml.YAMLError):
mock_read_yaml.side_effect = strictyaml.YAMLError("Invalid YAML")
with self.assertRaises(strictyaml.YAMLError):
self.service_manager.create(
service_name, compute_pool, Path(spec_path), num_instances
)
Expand Down Expand Up @@ -104,73 +104,5 @@ def test_logs(self, mock_execute_schema_query):
self.assertEqual(result, cursor)


<<<<<<< HEAD
@mock.patch("snowflake.connector.connect")
def test_desc_service(mock_connector, runner, mock_ctx):
ctx = mock_ctx()
mock_connector.return_value = ctx
service_name = "test_service"

result = runner.invoke(["snowpark", "services", "desc", service_name])

assert result.exit_code == 0, result.output
assert ctx.get_query() == f"desc service {service_name}"


@mock.patch("snowflake.connector.connect")
def test_list_service(mock_connector, runner, mock_ctx):
ctx = mock_ctx()
mock_connector.return_value = ctx

result = runner.invoke(["snowpark", "services", "list"])

assert result.exit_code == 0, result.output
assert ctx.get_query() == "show services"


@mock.patch("snowflake.connector.connect")
def test_drop_service(mock_connector, runner, mock_ctx):
ctx = mock_ctx()
mock_connector.return_value = ctx

result = runner.invoke(["snowpark", "services", "drop", "serviceName"])

assert result.exit_code == 0, result.output
assert ctx.get_query() == "drop service serviceName"


@mock.patch("snowflake.connector.connect")
def test_service_status(mock_connector, runner, mock_ctx):
ctx = mock_ctx()
mock_connector.return_value = ctx

result = runner.invoke(["snowpark", "services", "status", "serviceName"])

assert result.exit_code == 0, result.output
assert ctx.get_query() == "CALL SYSTEM$GET_SERVICE_STATUS('serviceName')"


@mock.patch("snowflake.connector.connect")
def test_service_logs(mock_connector, runner, mock_ctx, snapshot):
ctx = mock_ctx()
mock_connector.return_value = ctx

result = runner.invoke(
[
"snowpark",
"services",
"logs",
"--container_name",
"containerName",
"serviceName",
]
)

assert result.exit_code == 0, result.output
assert (
ctx.get_query()
== "call SYSTEM$GET_SERVICE_LOGS('serviceName', '0', 'containerName');"
)

if __name__ == "__main__":
unittest.main()

0 comments on commit 3566c08

Please sign in to comment.