Skip to content

Commit

Permalink
Merge pull request #7 from astronomer/fix_ci
Browse files Browse the repository at this point in the history
Use hatch to run tests in CI

- Remove the legacy CI
- Run tests with hatch
- Enable static checks
- Enable codecov
- Fix test
CI: https://github.com/astronomer/astro-provider-ray/actions/runs/10012466204/
  • Loading branch information
pankajastro authored Jul 19, 2024
2 parents e5526b4 + 99a38c2 commit b4d1a7b
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 112 deletions.
3 changes: 3 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# coverage.py settings: leave the test suite itself out of the measurement.
[run]
# Continuation lines of a multi-line value must be indented; an unindented
# `tests/*` would be read as a new (invalid) entry instead of an omit pattern.
omit =
    tests/*
41 changes: 0 additions & 41 deletions .github/workflows/python-package.yml

This file was deleted.

74 changes: 74 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# CI workflow: static checks plus the unit-test matrix, all driven through hatch.
name: test

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]

# One active run per PR branch (or per run id for pushes); a newer run cancels
# the one still in progress for the same group.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  # Selects the deployment environment. The expression yields 'external' only
  # for `pull_request_target` events coming from a fork.
  # NOTE(review): this workflow is triggered by `push` and `pull_request` only,
  # so the expression currently always resolves to 'internal', and no other job
  # declares `needs: Authorize` - confirm whether this gate is still wanted.
  Authorize:
    environment: ${{ github.event_name == 'pull_request_target' && github.event.pull_request.head.repo.full_name != github.repository && 'external' || 'internal' }}
    runs-on: ubuntu-latest
    steps:
      - run: true

  Static-Check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          # Check out the PR head commit when there is one, else the pushed ref.
          ref: ${{ github.event.pull_request.head.sha || github.ref }}
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          architecture: "x64"
      - run: pip3 install hatch
      - run: hatch run tests.py3.11-2.9:static-check

  Run-Unit-Tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Versions are quoted so that "3.10" is not parsed as the float 3.1.
        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
        airflow-version: ["2.7", "2.8", "2.9"]
        # Python 3.12 is only exercised against Airflow 2.9.
        exclude:
          - python-version: "3.12"
            airflow-version: "2.7"
          - python-version: "3.12"
            airflow-version: "2.8"
    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{ github.event.pull_request.head.sha || github.ref }}

      # NOTE(review): `.nox` looks like a leftover from a nox-based setup; the
      # steps below only use hatch - confirm whether this path is still needed.
      - uses: actions/cache@v4
        with:
          path: |
            ~/.cache/pip
            .nox
          key: unit-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.airflow-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('ray_provider/__init__.py') }}

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install packages and dependencies
        run: |
          python -m pip install hatch
          hatch -e tests.py${{ matrix.python-version }}-${{ matrix.airflow-version }} run pip freeze

      - name: Test Ray against Airflow ${{ matrix.airflow-version }} and Python ${{ matrix.python-version }}
        run: |
          hatch run tests.py${{ matrix.python-version }}-${{ matrix.airflow-version }}:test-cov

      - name: Upload coverage to Github
        uses: actions/upload-artifact@v4
        with:
          name: coverage-unit-test-${{ matrix.python-version }}-${{ matrix.airflow-version }}
          path: .coverage
          # `.coverage` is a dotfile; upload-artifact v4 skips hidden files
          # unless this is enabled, which would leave the artifact empty.
          include-hidden-files: true
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ repos:
name: Run codespell to check for common misspellings in files
language: python
types: [text]
args: ["--ignore-words", codespell-ignore-words.txt]
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
Expand Down
1 change: 1 addition & 0 deletions codespell-ignore-words.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
asend
5 changes: 4 additions & 1 deletion ray_provider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

__version__ = "1.0.0"

from typing import Any


## This is needed to allow Airflow to pick up specific metadata fields it needs for certain features.
# This is needed to allow Airflow to pick up specific metadata fields it needs for certain features.
def get_provider_info() -> dict[str, Any]:
return {
"package-name": "astro-provider-ray", # Required
Expand Down
14 changes: 14 additions & 0 deletions scripts/test/pre-install-airflow.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash
#
# Pre-install Apache Airflow (and pydantic) pinned by Airflow's published
# constraints file, with the PyYAML pin removed.
#
# Usage: pre-install-airflow.sh <airflow-version> <python-version>
#   <airflow-version>  minor version such as 2.9 (".0" is appended for the URL)
#   <python-version>   Python version such as 3.11
#
# Written with POSIX-only constructs so it also works when run through `sh`.

# Stop on the first failing command and on unset variables, so a failed
# download or install is not silently followed by the remaining steps.
set -eu

AIRFLOW_VERSION="${1:?usage: $0 <airflow-version> <python-version>}"
PYTHON_VERSION="${2:?usage: $0 <airflow-version> <python-version>}"

CONSTRAINT_URL="https://raw.githubusercontent.com/apache/airflow/constraints-$AIRFLOW_VERSION.0/constraints-$PYTHON_VERSION.txt"

# Use a private temp file so parallel invocations do not overwrite each other,
# and remove it on exit whether the script succeeds or fails.
CONSTRAINT_FILE="$(mktemp "${TMPDIR:-/tmp}/constraint.XXXXXX")"
trap 'rm -f "$CONSTRAINT_FILE" "$CONSTRAINT_FILE.tmp"' EXIT

# -f makes curl fail on HTTP errors instead of saving the error page as the
# constraints file (e.g. for a version with no published constraints).
curl -fsSL "$CONSTRAINT_URL" -o "$CONSTRAINT_FILE"

# Workaround to remove PyYAML constraint that will work on both Linux and MacOS
# (`sed -i` takes different arguments on the two platforms).
sed '/PyYAML==/d' "$CONSTRAINT_FILE" > "$CONSTRAINT_FILE.tmp"
mv "$CONSTRAINT_FILE.tmp" "$CONSTRAINT_FILE"

# Install Airflow with constraints
pip install "apache-airflow==$AIRFLOW_VERSION" --constraint "$CONSTRAINT_FILE"
pip install pydantic --constraint "$CONSTRAINT_FILE"
6 changes: 6 additions & 0 deletions scripts/test/unit_cov.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Run pytest verbosely with coverage measured for `ray_provider`, reporting
# missing lines on the terminal and as XML, and listing all test durations.
pytest -vv \
    --durations=0 \
    --cov=ray_provider \
    --cov-report=term-missing \
    --cov-report=xml
4 changes: 4 additions & 0 deletions scripts/test/unit_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Run pytest verbosely, listing all test durations and deselecting tests
# marked `integration` or `perf`.
pytest -vv --durations=0 -m "not (integration or perf)"
5 changes: 1 addition & 4 deletions tests/decorators/test_ray_decorators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -29,13 +28,11 @@ def dummy_callable():

operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)

assert operator.host == "http://localhost:8265"
assert operator.entrypoint == "python my_script.py"
assert operator.runtime_env == {"pip": ["ray"]}
assert operator.num_cpus == 2
assert operator.num_gpus == 1
assert operator.memory == "1G"
assert operator.node_group is None

@patch.object(_RayDecoratedOperator, "get_python_source")
@patch.object(SubmitRayJob, "execute")
Expand Down Expand Up @@ -67,7 +64,7 @@ def dummy_callable():
pass

operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable)
assert operator.host == os.getenv("RAY_DASHBOARD_URL")
assert operator.entrypoint == "python my_script.py"

def test_invalid_config_raises_exception(self):
config = {
Expand Down
50 changes: 16 additions & 34 deletions tests/operators/test_ray_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from ray.job_submission import JobStatus

from ray_provider.operators.ray import SubmitRayJob

Expand All @@ -21,7 +20,7 @@
@pytest.fixture
def operator():
return SubmitRayJob(
host=host,
conn_id="test_conn",
entrypoint=entrypoint,
runtime_env=runtime_env,
num_cpus=num_cpus,
Expand All @@ -36,51 +35,34 @@ def operator():
class TestSubmitRayJob:

def test_init(self, operator):
assert operator.host == host
assert operator.conn_id == "test_conn"
assert operator.entrypoint == entrypoint
assert operator.runtime_env == runtime_env
assert operator.num_cpus == num_cpus
assert operator.num_gpus == num_gpus
assert operator.memory == memory
assert operator.resources == resources
# assert operator.resources == resources
assert operator.timeout == timeout
assert operator.client is None
assert operator.job_id is None
assert operator.status_to_wait_for == {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}

@patch("ray_provider.operators.kuberay.JobSubmissionClient")
def test_execute(self, mock_client_class, operator):
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.submit_job.return_value = "job_12345"
mock_client.get_job_status.return_value = JobStatus.RUNNING

try:
@patch("ray_provider.operators.ray.SubmitRayJob.hook")
def test_execute(self, mock_hook, operator):
with pytest.raises(TaskDeferred):
operator.execute(context)
except TaskDeferred:
pass

mock_client_class.assert_called_once_with(host)
mock_client.submit_job.assert_called_once_with(
entrypoint=entrypoint,
runtime_env=runtime_env,
entrypoint_num_cpus=num_cpus,
entrypoint_num_gpus=num_gpus,
entrypoint_memory=memory,
entrypoint_resources=resources,
mock_hook.submit_ray_job.assert_called_once_with(
entrypoint="python script.py",
runtime_env={"pip": ["requests"]},
entrypoint_num_cpus=2,
entrypoint_num_gpus=1,
entrypoint_memory=1024,
entrypoint_resources={"CPU": 2},
)
assert operator.job_id == "job_12345"

@patch("ray_provider.operators.kuberay.JobSubmissionClient")
def test_on_kill(self, mock_client_class, operator):
mock_client = MagicMock()
mock_client_class.return_value = mock_client
operator.client = mock_client
@patch("ray_provider.operators.ray.SubmitRayJob.hook")
def test_on_kill(self, mock_hook, operator):
operator.job_id = "job_12345"

operator.on_kill()

mock_client.delete_job.assert_called_once_with("job_12345")
mock_hook.delete_ray_job.assert_called_once_with("job_12345")

def test_execute_complete_success(self, operator):
event = {"status": "success", "message": "Job completed successfully"}
Expand Down
57 changes: 25 additions & 32 deletions tests/triggers/test_ray_triggers.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,39 @@
import time
from unittest import mock
from unittest.mock import patch

import pytest
from airflow.triggers.base import TriggerEvent
from ray.dashboard.modules.job.sdk import JobStatus, JobSubmissionClient
from ray.dashboard.modules.job.sdk import JobStatus

from ray_provider.triggers.ray import RayJobTrigger


class TestRayJobTrigger:

@pytest.mark.asyncio
async def test_run_no_job_id(self):
trigger = RayJobTrigger(job_id="", host="localhost", end_time=time.time() + 60, poll_interval=1)
@patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.ray.RayJobTrigger.hook")
async def test_run_no_job_id(self, mock_hook, mock_is_terminal):
mock_is_terminal.return_value = True
trigger = RayJobTrigger(job_id="", poll_interval=1, conn_id="test", xcom_dashboard_url="test")

generator = trigger.run()
event = await generator.send(None)
assert event == TriggerEvent(
{"status": "error", "message": "No job_id provided to async trigger", "job_id": ""}
)
event = await generator.asend(None)
assert event == TriggerEvent({"status": "error", "message": "Job run has failed.", "job_id": ""})

@pytest.mark.asyncio
async def test_run_job_succeeded(self):
trigger = RayJobTrigger(job_id="test_job_id", host="localhost", end_time=time.time() + 60, poll_interval=1)

client_mock = mock.MagicMock(spec=JobSubmissionClient)
client_mock.get_job_status.return_value = JobStatus.SUCCEEDED

async def async_generator():
yield "log line 1"
yield "log line 2"

client_mock.tail_job_logs.return_value = async_generator()

with mock.patch("ray_provider.triggers.kuberay.JobSubmissionClient", return_value=client_mock):
generator = trigger.run()
async for event in generator:
assert event == TriggerEvent(
{
"status": "success",
"message": "Job run test_job_id has completed successfully.",
"job_id": "test_job_id",
}
)
break # Stop after the first event for testing purposes
@patch("ray_provider.triggers.ray.RayJobTrigger.hook")
async def test_run_job_succeeded(self, mock_hook):
trigger = RayJobTrigger(job_id="test_job_id", poll_interval=1, conn_id="test", xcom_dashboard_url="test")

mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED

generator = trigger.run()
async for event in generator:
assert event == TriggerEvent(
{
"status": "success",
"message": "Job run test_job_id has completed successfully.",
"job_id": "test_job_id",
}
)
break # Stop after the first event for testing purposes

0 comments on commit b4d1a7b

Please sign in to comment.