diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..b3e11a0 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,29 @@ +import pytest + +from ray_provider import __version__, get_provider_info + + +def test_get_provider_info(): + expected_info = { + "package-name": "astro-provider-ray", + "name": "Ray", + "description": "An integration between airflow and ray", + "connection-types": [{"connection-type": "ray", "hook-class-name": "ray_provider.hooks.ray.RayHook"}], + "versions": [__version__], + } + + result = get_provider_info() + + assert result == expected_info + assert isinstance(result, dict) + assert "package-name" in result + assert "name" in result + assert "description" in result + assert "connection-types" in result + assert "versions" in result + assert isinstance(result["versions"], list) + assert result["versions"][0] == __version__ + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/hooks/test_ray_hooks.py b/tests/hooks/test_ray_hooks.py index 4928003..7f22886 100644 --- a/tests/hooks/test_ray_hooks.py +++ b/tests/hooks/test_ray_hooks.py @@ -1,7 +1,9 @@ -from unittest.mock import MagicMock, mock_open, patch +import subprocess +from unittest.mock import MagicMock, Mock, mock_open, patch import pytest from airflow.exceptions import AirflowException +from kubernetes import client from kubernetes.client.exceptions import ApiException from ray.job_submission import JobStatus @@ -10,6 +12,20 @@ class TestRayHook: + @pytest.fixture + def ray_hook(self): + with patch("ray_provider.hooks.ray.KubernetesHook.get_connection") as mock_get_connection: + mock_connection = Mock() + mock_connection.extra_dejson = {"kube_config_path": None, "kube_config": None, "cluster_context": None} + mock_get_connection.return_value = mock_connection + + with patch("ray_provider.hooks.ray.KubernetesHook.__init__", return_value=None): + hook = RayHook(conn_id="test_conn") + # Manually set the necessary attributes + hook.namespace = "default" + hook.kubeconfig = "/path/to/kubeconfig" + return hook + @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") @patch("ray_provider.hooks.ray.KubernetesHook.__init__") def test_init(self, mock_kubernetes_init, mock_get_connection): @@ -189,6 +205,109 @@ def test_is_port_open(self, mock_socket, mock_get_connection): result = hook._is_port_open("localhost", 8080) assert result is True + @patch("ray_provider.hooks.ray.RayHook.core_v1_client") + def test_get_service_success(self, mock_core_v1_client, ray_hook): + mock_service = Mock(spec=client.V1Service) + mock_core_v1_client.read_namespaced_service.return_value = mock_service + + service = ray_hook._get_service("test-service", "default") + + assert service == mock_service + mock_core_v1_client.read_namespaced_service.assert_called_once_with("test-service", "default") + + @patch("ray_provider.hooks.ray.RayHook.core_v1_client") + def test_get_service_not_found(self, mock_core_v1_client, ray_hook): + mock_core_v1_client.read_namespaced_service.side_effect = client.exceptions.ApiException(status=404) + + with pytest.raises(AirflowException) as exc_info: + ray_hook._get_service("non-existent-service", "default") + + assert "Service non-existent-service not found" in str(exc_info.value) + + def test_get_load_balancer_details_with_ingress(self, ray_hook): + mock_service = Mock(spec=client.V1Service) + mock_ingress = Mock(spec=client.V1LoadBalancerIngress) + mock_ingress.ip = "192.168.1.1" + mock_ingress.hostname = None + mock_service.status.load_balancer.ingress = [mock_ingress] + + mock_port = Mock() + mock_port.name = "http" + mock_port.port = 80 + mock_service.spec.ports = [mock_port] + + lb_details = ray_hook._get_load_balancer_details(mock_service) + + assert lb_details == {"ip_or_hostname": "192.168.1.1", "ports": [{"name": "http", "port": 80}]} + + def test_get_load_balancer_details_with_hostname(self, ray_hook): + mock_service = Mock(spec=client.V1Service) + mock_ingress = Mock(spec=client.V1LoadBalancerIngress) + mock_ingress.ip = None + mock_ingress.hostname = "example.com" + mock_service.status.load_balancer.ingress = [mock_ingress] + + mock_port = Mock() + mock_port.name = "https" + mock_port.port = 443 + mock_service.spec.ports = [mock_port] + + lb_details = ray_hook._get_load_balancer_details(mock_service) + + assert lb_details == {"ip_or_hostname": "example.com", "ports": [{"name": "https", "port": 443}]} + + def test_get_load_balancer_details_no_ingress(self, ray_hook): + mock_service = Mock(spec=client.V1Service) + mock_service.status.load_balancer.ingress = None + + lb_details = ray_hook._get_load_balancer_details(mock_service) + + assert lb_details is None + + def test_get_load_balancer_details_no_ip_or_hostname(self, ray_hook): + mock_service = Mock(spec=client.V1Service) + mock_ingress = Mock(spec=client.V1LoadBalancerIngress) + mock_ingress.ip = None + mock_ingress.hostname = None + mock_service.status.load_balancer.ingress = [mock_ingress] + + lb_details = ray_hook._get_load_balancer_details(mock_service) + + assert lb_details is None + + @patch("ray_provider.hooks.ray.RayHook.log") + @patch("ray_provider.hooks.ray.subprocess.run") + def test_run_bash_command_exception(self, mock_subprocess_run, mock_log, ray_hook): + # Simulate a CalledProcessError + mock_subprocess_run.side_effect = subprocess.CalledProcessError( + returncode=1, cmd="test command", output="test output", stderr="test error" + ) + + # Call the method + stdout, stderr = ray_hook._run_bash_command("test command") + + # Assert that the method returned None for both stdout and stderr + assert stdout is None + assert stderr is None + + # Assert that the error was logged + mock_log.error.assert_any_call( + "An error occurred while executing the command: %s", mock_subprocess_run.side_effect + ) + mock_log.error.assert_any_call("Return code: %s", 1) + mock_log.error.assert_any_call("Standard Output: %s", "test output") + mock_log.error.assert_any_call("Standard Error: %s", "test error") + + # Verify that subprocess.run was called with the correct arguments + mock_subprocess_run.assert_called_once_with( + "test command", + shell=True, + check=True, + text=True, + capture_output=True, + env=ray_hook._run_bash_command.__globals__["os"].environ.copy(), + ) + @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") @patch("ray_provider.hooks.ray.KubernetesHook.__init__") @patch("ray_provider.hooks.ray.subprocess.run") @@ -217,6 +336,66 @@ def test_uninstall_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_i assert "uninstall output" in stdout assert stderr == "" + @patch("ray_provider.hooks.ray.RayHook._get_service") + @patch("ray_provider.hooks.ray.RayHook._get_load_balancer_details") + @patch("ray_provider.hooks.ray.RayHook._is_port_open") + def test_wait_for_load_balancer_success(self, mock_is_port_open, mock_get_lb_details, mock_get_service, ray_hook): + # Mock the service + mock_service = Mock(spec=client.V1Service) + mock_get_service.return_value = mock_service + + # Mock the load balancer details + mock_get_lb_details.return_value = { + "ip_or_hostname": "test-lb.example.com", + "ports": [{"name": "http", "port": 80}, {"name": "https", "port": 443}], + } + + # Mock the port check to return True (ports are open) + mock_is_port_open.return_value = True + + # Call the method + result = ray_hook.wait_for_load_balancer("test-service", namespace="default", max_retries=1, retry_interval=1) + + # Assertions + assert result == mock_get_lb_details.return_value + mock_get_service.assert_called_once_with("test-service", "default") + mock_get_lb_details.assert_called_once_with(mock_service) + assert mock_is_port_open.call_count == 2 # Called for both ports + + @patch("ray_provider.hooks.ray.RayHook._get_service") + @patch("ray_provider.hooks.ray.RayHook._get_load_balancer_details") + @patch("ray_provider.hooks.ray.RayHook._is_port_open") + def test_wait_for_load_balancer_timeout(self, mock_is_port_open, mock_get_lb_details, mock_get_service, ray_hook): + # Mock the service + mock_service = Mock(spec=client.V1Service) + mock_get_service.return_value = mock_service + + # Mock the load balancer details + mock_get_lb_details.return_value = { + "ip_or_hostname": "test-lb.example.com", + "ports": [{"name": "http", "port": 80}], + } + + # Mock the port check to return False (port is not open) + mock_is_port_open.return_value = False + + # Call the method and expect an AirflowException + with pytest.raises(AirflowException) as exc_info: + ray_hook.wait_for_load_balancer("test-service", namespace="default", max_retries=2, retry_interval=1) + + assert "LoadBalancer did not become ready after 2 attempts" in str(exc_info.value) + + @patch("ray_provider.hooks.ray.RayHook._get_service") + def test_wait_for_load_balancer_service_not_found(self, mock_get_service, ray_hook): + # Mock the service to raise an AirflowException (service not found) + mock_get_service.side_effect = AirflowException("Service test-service not found") + + # Call the method and expect an AirflowException + with pytest.raises(AirflowException) as exc_info: + ray_hook.wait_for_load_balancer("test-service", namespace="default", max_retries=1, retry_interval=1) + + assert "LoadBalancer did not become ready after 1 attempts" in str(exc_info.value) + @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") @patch("ray_provider.hooks.ray.KubernetesHook.__init__") @patch("ray_provider.hooks.ray.client.AppsV1Api.read_namespaced_daemon_set")