Merge pull request #249 from NASA-IMPACT/feature/use-pythonoperators
Feature/use pythonoperators
amarouane-ABDELHAK authored Oct 25, 2024
2 parents 61a797d + 686c4c8 commit c190564
Showing 22 changed files with 1,064 additions and 388 deletions.
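
The change repeated across these files swaps per-key Airflow Variables (including the JSON-deserialized `MWAA_STACK_CONF`) for a single JSON-encoded Variable, `aws_dags_variables`, parsed with `json.loads`. A minimal sketch of the new access pattern as it appears throughout this diff (it assumes a running Airflow deployment where the variable is set; the keys are taken from the diff, the values are deployment-specific):

```python
import json

from airflow.models.variable import Variable

# Read the consolidated variable once, then pull individual settings from it.
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)

# These keys all appear in this diff; .get() returns None for any that are unset.
bucket = airflow_vars_json.get("EVENT_BUCKET")
cognito_app_secret = airflow_vars_json.get("COGNITO_APP_SECRET")
stac_ingestor_api_url = airflow_vars_json.get("STAC_INGESTOR_API_URL")
role_arn = airflow_vars_json.get("ASSUME_ROLE_READ_ARN")
```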
6 changes: 3 additions & 3 deletions dags/generate_dags.py
@@ -15,9 +15,9 @@ def generate_dags():

from pathlib import Path


mwaa_stac_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True)
bucket = mwaa_stac_conf["EVENT_BUCKET"]
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
bucket = airflow_vars_json.get("EVENT_BUCKET")

try:
client = boto3.client("s3")
20 changes: 16 additions & 4 deletions dags/veda_data_pipeline/groups/collection_group.py
@@ -32,28 +32,40 @@ def ingest_collection_task(ti):
dataset (Dict[str, Any]): dataset dictionary (JSON)
role_arn (str): role arn for Zarr collection generation
"""
import json
collection = ti.xcom_pull(task_ids='Collection.generate_collection')
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
cognito_app_secret = airflow_vars_json.get("COGNITO_APP_SECRET")
stac_ingestor_api_url = airflow_vars_json.get("STAC_INGESTOR_API_URL")

return submission_handler(
event=collection,
endpoint="/collections",
cognito_app_secret=Variable.get("COGNITO_APP_SECRET"),
stac_ingestor_api_url=Variable.get("STAC_INGESTOR_API_URL"),
cognito_app_secret=cognito_app_secret,
stac_ingestor_api_url=stac_ingestor_api_url
)


# NOTE unused, but useful for item ingests, since collections are a dependency for items
def check_collection_exists_task(ti):
import json
config = ti.dag_run.conf
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
stac_url = airflow_vars_json.get("STAC_URL")
return check_collection_exists(
endpoint=Variable.get("STAC_URL", default_var=None),
endpoint=stac_url,
collection_id=config.get("collection"),
)


def generate_collection_task(ti):
import json
config = ti.dag_run.conf
role_arn = Variable.get("ASSUME_ROLE_READ_ARN", default_var=None)
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
role_arn = airflow_vars_json.get("ASSUME_ROLE_READ_ARN")

# TODO it would be ideal if this also works with complete collections where provided - this would make the collection ingest more re-usable
collection = generator.generate_stac(
43 changes: 21 additions & 22 deletions dags/veda_data_pipeline/groups/discover_group.py
@@ -1,23 +1,17 @@
from datetime import timedelta
import time
import json
import uuid

from airflow.models.variable import Variable
from airflow.models.xcom import LazyXComAccess
from airflow.operators.dummy_operator import DummyOperator as EmptyOperator
from airflow.decorators import task_group, task
from airflow.models.baseoperator import chain
from airflow.operators.python import BranchPythonOperator, PythonOperator, ShortCircuitOperator
from airflow.utils.trigger_rule import TriggerRule
from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator
from airflow.decorators import task
from veda_data_pipeline.utils.s3_discovery import (
s3_discovery_handler, EmptyFileListError
)
from veda_data_pipeline.groups.processing_tasks import build_stac_kwargs, submit_to_stac_ingestor_task


group_kwgs = {"group_id": "Discover", "tooltip": "Discover"}


@task(retries=1, retry_delay=timedelta(minutes=1))
def discover_from_s3_task(ti=None, event={}, **kwargs):
"""Discover grouped assets/files from S3 in batches of 2800. Produce a list of such files stored on S3 to process.
@@ -32,39 +26,44 @@ def discover_from_s3_task(ti=None, event={}, **kwargs):
if event.get("schedule") and last_successful_execution:
config["last_successful_execution"] = last_successful_execution.isoformat()
# (event, chunk_size=2800, role_arn=None, bucket_output=None):
MWAA_STAC_CONF = Variable.get("MWAA_STACK_CONF", deserialize_json=True)
read_assume_arn = Variable.get("ASSUME_ROLE_READ_ARN", default_var=None)

airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
event_bucket = airflow_vars_json.get("EVENT_BUCKET")
read_assume_arn = airflow_vars_json.get("ASSUME_ROLE_READ_ARN")
# Keep the chunk size small (default 500); this processed large data
# faster than passing one large chunk
chunk_size = config.get("chunk_size", 500)
try:
return s3_discovery_handler(
event=config,
role_arn=read_assume_arn,
bucket_output=MWAA_STAC_CONF["EVENT_BUCKET"],
bucket_output=event_bucket,
chunk_size=chunk_size
)
except EmptyFileListError as ex:
print(f"Received an exception {ex}")
# TODO test continued short circuit operator behavior (no files -> skip remaining tasks)
return {}
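
The `EmptyFileListError` branch returns an empty dict, and the TODO above notes the intended short-circuit behavior. A hedged sketch of one way to wire that skip with Airflow's built-in short-circuit decorator (the task name `halt_on_empty` is illustrative, not from this PR):

```python
from airflow.decorators import task

@task.short_circuit
def halt_on_empty(discovery_result: dict) -> bool:
    # Returning False skips all downstream tasks in the DAG run;
    # a non-empty discovery result lets the pipeline continue.
    return bool(discovery_result)
```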


@task
def get_files_to_process(payload, ti=None):
"""Get files from S3 produced by the discovery task.
Used as part of both the parallel_run_process_rasters and parallel_run_process_vectors tasks.
"""
if isinstance(payload, LazyXComAccess): # if used as part of a dynamic task mapping
if isinstance(payload, LazyXComAccess): # if used as part of a dynamic task mapping
payloads_xcom = payload[0].pop("payload", [])
payload = payload[0]
else:
payloads_xcom = payload.pop("payload", [])
dag_run_id = ti.dag_run.run_id
return [{
"run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}",
**payload,
"payload": payload_xcom,
} for indx, payload_xcom in enumerate(payloads_xcom)]
"run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}",
**payload,
"payload": payload_xcom,
} for indx, payload_xcom in enumerate(payloads_xcom)]


@task
def get_dataset_files_to_process(payload, ti=None):
@@ -75,16 +74,16 @@ def get_dataset_files_to_process(payload, ti=None):

result = []
for x in payload:
if isinstance(x, LazyXComAccess): # if used as part of a dynamic task mapping
if isinstance(x, LazyXComAccess): # if used as part of a dynamic task mapping
payloads_xcom = x[0].pop("payload", [])
payload_0 = x[0]
else:
payloads_xcom = x.pop("payload", [])
payload_0 = x
for indx, payload_xcom in enumerate(payloads_xcom):
result.append({
"run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}",
**payload_0,
"payload": payload_xcom,
})
"run_id": f"{dag_run_id}_{uuid.uuid4()}_{indx}",
**payload_0,
"payload": payload_xcom,
})
return result
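
Both helpers above flatten the discovery output into one dict per chunk, each carrying a unique `run_id` alongside the shared payload fields. A minimal sketch of how such a list can drive Airflow dynamic task mapping (the `process_chunk` task and the wiring are hypothetical, not part of this diff):

```python
from airflow.decorators import task

@task
def process_chunk(chunk: dict):
    # Each mapped task instance receives one {"run_id": ..., "payload": ...} dict.
    print(f"processing {chunk['run_id']}")

# Hypothetical wiring inside a DAG definition:
#   discovered = discover_from_s3_task(event={"collection": "my-collection"})
#   chunks = get_files_to_process(payload=discovered)
#   process_chunk.expand(chunk=chunks)
```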
209 changes: 11 additions & 198 deletions dags/veda_data_pipeline/groups/processing_tasks.py
@@ -1,12 +1,9 @@
from datetime import timedelta
import json
import logging

import smart_open
from airflow.models.variable import Variable
from airflow.operators.python import PythonOperator
from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator
from airflow.decorators import task_group, task
from airflow.decorators import task
from veda_data_pipeline.utils.submit_stac import submission_handler

group_kwgs = {"group_id": "Process", "tooltip": "Process"}
@@ -15,213 +12,29 @@
def log_task(text: str):
logging.info(text)


@task(retries=1, retry_delay=timedelta(minutes=1))
def submit_to_stac_ingestor_task(built_stac:str):
def submit_to_stac_ingestor_task(built_stac: dict):
"""Submit STAC items to the STAC ingestor API."""
event = json.loads(built_stac)
event = built_stac.copy()
success_file = event["payload"]["success_event_key"]

airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)
cognito_app_secret = airflow_vars_json.get("COGNITO_APP_SECRET")
stac_ingestor_api_url = airflow_vars_json.get("STAC_INGESTOR_API_URL")
with smart_open.open(success_file, "r") as _file:
stac_items = json.loads(_file.read())

for item in stac_items:
submission_handler(
event=item,
endpoint="/ingestions",
cognito_app_secret=Variable.get("COGNITO_APP_SECRET"),
stac_ingestor_api_url=Variable.get("STAC_INGESTOR_API_URL"),
cognito_app_secret=cognito_app_secret,
stac_ingestor_api_url=stac_ingestor_api_url,
)
return event
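
The reworked task takes the built STAC event as a dict rather than a JSON string and reads the item list from S3 via the `success_event_key` it carries. A sketch of the expected input shape (illustrative only; fields beyond `success_event_key` come from the upstream build step and are not spelled out in this diff):

```python
# Hypothetical example of what the task receives and reads.
built_stac = {
    "payload": {
        "success_event_key": "s3://my-event-bucket/successes/run_0.json",
    },
}
# submit_to_stac_ingestor_task opens that key with smart_open, parses a JSON
# list of STAC items, and submits each one to the "/ingestions" endpoint.
```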

@task
def build_stac_kwargs(event={}):
"""Build kwargs for the ECS operator."""
mwaa_stack_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True)
if event:
intermediate = {
**event
} # this is dumb but it resolves the MappedArgument to a dict that can be JSON serialized
payload = json.dumps(intermediate)
else:
payload = "{{ task_instance.dag_run.conf }}"

return {
"overrides": {
"containerOverrides": [
{
"name": f"{mwaa_stack_conf.get('PREFIX')}-veda-stac-build",
"command": [
"/usr/local/bin/python",
"handler.py",
"--payload",
payload,
],
"environment": [
{
"name": "EXTERNAL_ROLE_ARN",
"value": Variable.get(
"ASSUME_ROLE_READ_ARN", default_var=""
),
},
{
"name": "BUCKET",
"value": "veda-data-pipelines-staging-lambda-ndjson-bucket",
},
{
"name": "EVENT_BUCKET",
"value": mwaa_stack_conf.get("EVENT_BUCKET"),
}
],
"memory": 2048,
"cpu": 1024,
},
],
},
"network_configuration": {
"awsvpcConfiguration": {
"securityGroups": mwaa_stack_conf.get("SECURITYGROUPS"),
"subnets": mwaa_stack_conf.get("SUBNETS"),
},
},
"awslogs_group": mwaa_stack_conf.get("LOG_GROUP_NAME"),
"awslogs_stream_prefix": f"ecs/{mwaa_stack_conf.get('PREFIX')}-veda-stac-build",
}

@task
def build_vector_kwargs(event={}):
"""Build kwargs for the ECS operator."""
mwaa_stack_conf = Variable.get(
"MWAA_STACK_CONF", default_var={}, deserialize_json=True
)
vector_ecs_conf = Variable.get(
"VECTOR_ECS_CONF", default_var={}, deserialize_json=True
)

if event:
intermediate = {
**event
}
payload = json.dumps(intermediate)
else:
payload = "{{ task_instance.dag_run.conf }}"

return {
"overrides": {
"containerOverrides": [
{
"name": f"{mwaa_stack_conf.get('PREFIX')}-veda-vector_ingest",
"command": [
"/var/lang/bin/python",
"handler.py",
"--payload",
payload,
],
"environment": [
{
"name": "EXTERNAL_ROLE_ARN",
"value": Variable.get(
"ASSUME_ROLE_READ_ARN", default_var=""
),
},
{
"name": "AWS_REGION",
"value": mwaa_stack_conf.get("AWS_REGION"),
},
{
"name": "VECTOR_SECRET_NAME",
"value": Variable.get("VECTOR_SECRET_NAME"),
},
{
"name": "AWS_STS_REGIONAL_ENDPOINTS",
"value": "regional", # to override this behavior, make sure AWS_REGION is set to `aws-global`
}
],
},
],
},
"network_configuration": {
"awsvpcConfiguration": {
"securityGroups": vector_ecs_conf.get("VECTOR_SECURITY_GROUP"),
"subnets": vector_ecs_conf.get("VECTOR_SUBNETS"),
},
},
"awslogs_group": mwaa_stack_conf.get("LOG_GROUP_NAME"),
"awslogs_stream_prefix": f"ecs/{mwaa_stack_conf.get('PREFIX')}-veda-vector_ingest",
}


@task
def build_generic_vector_kwargs(event={}):
"""Build kwargs for the ECS operator."""
mwaa_stack_conf = Variable.get(
"MWAA_STACK_CONF", default_var={}, deserialize_json=True
)
vector_ecs_conf = Variable.get(
"VECTOR_ECS_CONF", default_var={}, deserialize_json=True
)

if event:
intermediate = {
**event
}
payload = json.dumps(intermediate)
else:
payload = "{{ task_instance.dag_run.conf }}"

return {
"overrides":{
"containerOverrides": [
{
"name": f"{mwaa_stack_conf.get('PREFIX')}-veda-generic_vector_ingest",
"command": [
"/var/lang/bin/python",
"handler.py",
"--payload",
payload,
],
"environment": [
{
"name": "EXTERNAL_ROLE_ARN",
"value": Variable.get(
"ASSUME_ROLE_READ_ARN", default_var=""
),
},
{
"name": "AWS_REGION",
"value": mwaa_stack_conf.get("AWS_REGION"),
},
{
"name": "VECTOR_SECRET_NAME",
"value": Variable.get("VECTOR_SECRET_NAME"),
},
{
"name": "AWS_STS_REGIONAL_ENDPOINTS",
"value": "regional", # to override this behavior, make sure AWS_REGION is set to `aws-global`
}
],
},
],
},
"network_configuration":{
"awsvpcConfiguration": {
"securityGroups": vector_ecs_conf.get("VECTOR_SECURITY_GROUP") + mwaa_stack_conf.get("SECURITYGROUPS"),
"subnets": vector_ecs_conf.get("VECTOR_SUBNETS"),
},
},
"awslogs_group":mwaa_stack_conf.get("LOG_GROUP_NAME"),
"awslogs_stream_prefix":f"ecs/{mwaa_stack_conf.get('PREFIX')}-veda-generic-vector_ingest", # prefix with container name
}


@task_group
def subdag_process(event={}):

build_stac = EcsRunTaskOperator.partial(
task_id="build_stac"
).expand_kwargs(build_stac_kwargs(event=event))

submit_to_stac_ingestor = PythonOperator(
task_id="submit_to_stac_ingestor",
python_callable=submit_to_stac_ingestor_task,
)

build_stac >> submit_to_stac_ingestor
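
The removed `subdag_process` group chained an `EcsRunTaskOperator` into a `PythonOperator`; with the decorated tasks this PR introduces, the same flow can be expressed with the TaskFlow API. A minimal sketch, assuming a decorated `build_stac` step stands in for the old ECS build (the stub below is illustrative, not the PR's implementation):

```python
from airflow.decorators import task, task_group

from veda_data_pipeline.groups.processing_tasks import submit_to_stac_ingestor_task

@task
def build_stac(event: dict) -> dict:
    # Stub standing in for the real STAC build step (an ECS task in the old code).
    return {"payload": {"success_event_key": "s3://bucket/success.json"}}

@task_group(group_id="Process")
def process_group(event: dict):
    # TaskFlow infers the dependency: submit runs after build completes.
    submit_to_stac_ingestor_task(built_stac=build_stac(event=event))
```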