Skip to content

Commit

Permalink
additional tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Jul 21, 2024
1 parent bffd018 commit e5d9aec
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 1 deletion.
29 changes: 29 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

from ray_provider import __version__, get_provider_info


def test_get_provider_info():
expected_info = {
"package-name": "astro-provider-ray",
"name": "Ray",
"description": "An integration between airflow and ray",
"connection-types": [{"connection-type": "ray", "hook-class-name": "ray_provider.hooks.ray.RayHook"}],
"versions": [__version__],
}

result = get_provider_info()

assert result == expected_info
assert isinstance(result, dict)
assert "package-name" in result
assert "name" in result
assert "description" in result
assert "connection-types" in result
assert "versions" in result
assert isinstance(result["versions"], list)
assert result["versions"][0] == __version__


if __name__ == "__main__":
pytest.main()
181 changes: 180 additions & 1 deletion tests/hooks/test_ray_hooks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest.mock import MagicMock, mock_open, patch
import subprocess
from unittest.mock import MagicMock, Mock, mock_open, patch

import pytest
from airflow.exceptions import AirflowException
from kubernetes import client
from kubernetes.client.exceptions import ApiException
from ray.job_submission import JobStatus

Expand All @@ -10,6 +12,20 @@

class TestRayHook:

@pytest.fixture
def ray_hook(self):
with patch("ray_provider.hooks.ray.KubernetesHook.get_connection") as mock_get_connection:
mock_connection = Mock()
mock_connection.extra_dejson = {"kube_config_path": None, "kube_config": None, "cluster_context": None}
mock_get_connection.return_value = mock_connection

with patch("ray_provider.hooks.ray.KubernetesHook.__init__", return_value=None):
hook = RayHook(conn_id="test_conn")
# Manually set the necessary attributes
hook.namespace = "default"
hook.kubeconfig = "/path/to/kubeconfig"
return hook

@patch("ray_provider.hooks.ray.KubernetesHook.get_connection")
@patch("ray_provider.hooks.ray.KubernetesHook.__init__")
def test_init(self, mock_kubernetes_init, mock_get_connection):
Expand Down Expand Up @@ -189,6 +205,109 @@ def test_is_port_open(self, mock_socket, mock_get_connection):
result = hook._is_port_open("localhost", 8080)
assert result is True

@patch("ray_provider.hooks.ray.RayHook.core_v1_client")
def test_get_service_success(self, mock_core_v1_client, ray_hook):
mock_service = Mock(spec=client.V1Service)
mock_core_v1_client.read_namespaced_service.return_value = mock_service

service = ray_hook._get_service("test-service", "default")

assert service == mock_service
mock_core_v1_client.read_namespaced_service.assert_called_once_with("test-service", "default")

@patch("ray_provider.hooks.ray.RayHook.core_v1_client")
def test_get_service_not_found(self, mock_core_v1_client, ray_hook):
mock_core_v1_client.read_namespaced_service.side_effect = client.exceptions.ApiException(status=404)

with pytest.raises(AirflowException) as exc_info:
ray_hook._get_service("non-existent-service", "default")

assert "Service non-existent-service not found" in str(exc_info.value)

def test_get_load_balancer_details_with_ingress(self, ray_hook):
mock_service = Mock(spec=client.V1Service)
mock_ingress = Mock(spec=client.V1LoadBalancerIngress)
mock_ingress.ip = "192.168.1.1"
mock_ingress.hostname = None
mock_service.status.load_balancer.ingress = [mock_ingress]

mock_port = Mock()
mock_port.name = "http"
mock_port.port = 80
mock_service.spec.ports = [mock_port]

lb_details = ray_hook._get_load_balancer_details(mock_service)

assert lb_details == {"ip_or_hostname": "192.168.1.1", "ports": [{"name": "http", "port": 80}]}

def test_get_load_balancer_details_with_hostname(self, ray_hook):
mock_service = Mock(spec=client.V1Service)
mock_ingress = Mock(spec=client.V1LoadBalancerIngress)
mock_ingress.ip = None
mock_ingress.hostname = "example.com"
mock_service.status.load_balancer.ingress = [mock_ingress]

mock_port = Mock()
mock_port.name = "https"
mock_port.port = 443
mock_service.spec.ports = [mock_port]

lb_details = ray_hook._get_load_balancer_details(mock_service)

assert lb_details == {"ip_or_hostname": "example.com", "ports": [{"name": "https", "port": 443}]}

def test_get_load_balancer_details_no_ingress(self, ray_hook):
mock_service = Mock(spec=client.V1Service)
mock_service.status.load_balancer.ingress = None

lb_details = ray_hook._get_load_balancer_details(mock_service)

assert lb_details is None

def test_get_load_balancer_details_no_ip_or_hostname(self, ray_hook):
mock_service = Mock(spec=client.V1Service)
mock_ingress = Mock(spec=client.V1LoadBalancerIngress)
mock_ingress.ip = None
mock_ingress.hostname = None
mock_service.status.load_balancer.ingress = [mock_ingress]

lb_details = ray_hook._get_load_balancer_details(mock_service)

assert lb_details is None

@patch("ray_provider.hooks.ray.RayHook.log")
@patch("ray_provider.hooks.ray.subprocess.run")
def test_run_bash_command_exception(self, mock_subprocess_run, mock_log, ray_hook):
# Simulate a CalledProcessError
mock_subprocess_run.side_effect = subprocess.CalledProcessError(
returncode=1, cmd="test command", output="test output", stderr="test error"
)

# Call the method
stdout, stderr = ray_hook._run_bash_command("test command")

# Assert that the method returned None for both stdout and stderr
assert stdout is None
assert stderr is None

# Assert that the error was logged
mock_log.error.assert_any_call(
"An error occurred while executing the command: %s", mock_subprocess_run.side_effect
)
mock_log.error.assert_any_call("Return code: %s", 1)
mock_log.error.assert_any_call("Standard Output: %s", "test output")
mock_log.error.assert_any_call("Standard Error: %s", "test error")

# Verify that subprocess.run was called with the correct arguments
mock_subprocess_run.assert_called_once_with(
"test command",
shell=True,
check=True,
text=True,
capture_output=True,
env=ray_hook._run_bash_command.__globals__["os"].environ.copy(),
)

@patch("ray_provider.hooks.ray.KubernetesHook.get_connection")
@patch("ray_provider.hooks.ray.KubernetesHook.__init__")
@patch("ray_provider.hooks.ray.subprocess.run")
Expand Down Expand Up @@ -217,6 +336,66 @@ def test_uninstall_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_i
assert "uninstall output" in stdout
assert stderr == ""

@patch("ray_provider.hooks.ray.RayHook._get_service")
@patch("ray_provider.hooks.ray.RayHook._get_load_balancer_details")
@patch("ray_provider.hooks.ray.RayHook._is_port_open")
def test_wait_for_load_balancer_success(self, mock_is_port_open, mock_get_lb_details, mock_get_service, ray_hook):
# Mock the service
mock_service = Mock(spec=client.V1Service)
mock_get_service.return_value = mock_service

# Mock the load balancer details
mock_get_lb_details.return_value = {
"ip_or_hostname": "test-lb.example.com",
"ports": [{"name": "http", "port": 80}, {"name": "https", "port": 443}],
}

# Mock the port check to return True (ports are open)
mock_is_port_open.return_value = True

# Call the method
result = ray_hook.wait_for_load_balancer("test-service", namespace="default", max_retries=1, retry_interval=1)

# Assertions
assert result == mock_get_lb_details.return_value
mock_get_service.assert_called_once_with("test-service", "default")
mock_get_lb_details.assert_called_once_with(mock_service)
assert mock_is_port_open.call_count == 2 # Called for both ports

@patch("ray_provider.hooks.ray.RayHook._get_service")
@patch("ray_provider.hooks.ray.RayHook._get_load_balancer_details")
@patch("ray_provider.hooks.ray.RayHook._is_port_open")
def test_wait_for_load_balancer_timeout(self, mock_is_port_open, mock_get_lb_details, mock_get_service, ray_hook):
# Mock the service
mock_service = Mock(spec=client.V1Service)
mock_get_service.return_value = mock_service

# Mock the load balancer details
mock_get_lb_details.return_value = {
"ip_or_hostname": "test-lb.example.com",
"ports": [{"name": "http", "port": 80}],
}

# Mock the port check to return False (port is not open)
mock_is_port_open.return_value = False

# Call the method and expect an AirflowException
with pytest.raises(AirflowException) as exc_info:
ray_hook.wait_for_load_balancer("test-service", namespace="default", max_retries=2, retry_interval=1)

assert "LoadBalancer did not become ready after 2 attempts" in str(exc_info.value)

@patch("ray_provider.hooks.ray.RayHook._get_service")
def test_wait_for_load_balancer_service_not_found(self, mock_get_service, ray_hook):
# Mock the service to raise an AirflowException (service not found)
mock_get_service.side_effect = AirflowException("Service test-service not found")

# Call the method and expect an AirflowException
with pytest.raises(AirflowException) as exc_info:
ray_hook.wait_for_load_balancer("test-service", namespace="default", max_retries=1, retry_interval=1)

assert "LoadBalancer did not become ready after 1 attempts" in str(exc_info.value)

@patch("ray_provider.hooks.ray.KubernetesHook.get_connection")
@patch("ray_provider.hooks.ray.KubernetesHook.__init__")
@patch("ray_provider.hooks.ray.client.AppsV1Api.read_namespaced_daemon_set")
Expand Down

0 comments on commit e5d9aec

Please sign in to comment.