CMIP6 Ingestion DAG #258

Draft
wants to merge 2 commits into base: dev
33 changes: 16 additions & 17 deletions dags/veda_data_pipeline/groups/discover_group.py
@@ -1,25 +1,17 @@
from datetime import timedelta
import time
import uuid

from airflow.models.variable import Variable
from airflow.models.xcom import LazyXComAccess
from airflow.operators.dummy_operator import DummyOperator as EmptyOperator
from airflow.decorators import task_group, task
from airflow.models.baseoperator import chain
from airflow.operators.python import BranchPythonOperator, PythonOperator, ShortCircuitOperator
from airflow.utils.trigger_rule import TriggerRule
from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator
from airflow.decorators import task
from veda_data_pipeline.utils.s3_discovery import (
s3_discovery_handler, EmptyFileListError
s3_discovery_handler, EmptyFileListError, cmip_discovery_handler
)
from veda_data_pipeline.groups.processing_tasks import build_stac_kwargs, submit_to_stac_ingestor_task


group_kwgs = {"group_id": "Discover", "tooltip": "Discover"}

@task(retries=1, retry_delay=timedelta(minutes=1))
def discover_from_s3_task(ti=None, event={}, **kwargs):
def discover_from_s3_task(ti=None, event={}, asset_prediction=False, **kwargs):
"""Discover grouped assets/files from S3 in batches of 2800. Produce a list of such files stored on S3 to process.
This task is used as part of the discover_group subdag and outputs data to EVENT_BUCKET.
"""
@@ -38,12 +30,19 @@ def discover_from_s3_task(ti=None, event={}, **kwargs):
# passing a large chunk of 500
chunk_size = config.get("chunk_size", 500)
try:
return s3_discovery_handler(
event=config,
role_arn=read_assume_arn,
bucket_output=MWAA_STAC_CONF["EVENT_BUCKET"],
chunk_size=chunk_size
)
if not asset_prediction:
return s3_discovery_handler(
event=config,
role_arn=read_assume_arn,
bucket_output=MWAA_STAC_CONF["EVENT_BUCKET"],
chunk_size=chunk_size
)
else:
return cmip_discovery_handler(
event=config,
role_arn=read_assume_arn,
bucket_output=MWAA_STAC_CONF["EVENT_BUCKET"],
)
except EmptyFileListError as ex:
print(f"Received an exception {ex}")
# TODO test continued short circuit operator behavior (no files -> skip remaining tasks)
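The net effect of this hunk is that `discover_from_s3_task` now dispatches on the new `asset_prediction` flag: the default path still lists objects with `s3_discovery_handler`, while `asset_prediction=True` routes the event to the new `cmip_discovery_handler`, which predicts asset hrefs instead of listing them. A minimal sketch of the two call paths, with illustrative event payloads (field values are placeholders, not taken from this PR):

```python
# Illustrative discovery events; values are placeholders, not part of this PR.
s3_event = {
    "collection": "my-collection",
    "bucket": "veda-data-store-staging",
    "prefix": "my-collection/",
    "filename_regex": "^(.*).tif$",
}
cmip_event = {
    "collection": "cmip6-test",
    "variables": ["rlds"],
    "start_datetime": "20150101",
    "end_datetime": "20151231",
    "datetime_range": "month",
}

# Default: listing-based discovery via s3_discovery_handler
# discover_from_s3_task(event=s3_event)

# New: prediction-based discovery via cmip_discovery_handler
# discover_from_s3_task(event=cmip_event, asset_prediction=True)
```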
152 changes: 151 additions & 1 deletion dags/veda_data_pipeline/utils/s3_discovery.py
@@ -2,7 +2,7 @@
import json
import os
import re
from typing import List
from typing import List, Dict
from uuid import uuid4
from pathlib import Path

@@ -290,3 +290,153 @@ def s3_discovery_handler(event, chunk_size=2800, role_arn=None, bucket_output=No
except KeyError:
pass
return {**event, "payload": out_keys, "discovered": discovered}

def generate_dates_from_datetime_args(date_fields: Dict[str, str], frequency: str) -> List[str]:
"""
Generate a list of date strings based on date_fields and a frequency.
:param date_fields: Dictionary with datetime-related fields, i.e. 'start_datetime' and 'end_datetime' in 'YYYYMMDD' format.
:param frequency: The frequency of dates ('day', 'month', or 'year').
:return: List of dates in 'YYYYMMDD', 'YYYYMM', or 'YYYY' format, depending on frequency.
"""
if "start_datetime" in date_fields and "end_datetime" in date_fields:
start_date = datetime.strptime(date_fields["start_datetime"], "%Y%m%d")
end_date = datetime.strptime(date_fields["end_datetime"], "%Y%m%d")
else:
raise ValueError("Datetime range with 'start_datetime' and 'end_datetime' must be provided.")

current_date = start_date
dates = []

while current_date <= end_date:
if frequency == "day":
dates.append(current_date.strftime("%Y%m%d"))
current_date += timedelta(days=1)
elif frequency == "month":
dates.append(current_date.strftime("%Y%m"))
# Move to the first day of the next month
month = current_date.month + 1 if current_date.month < 12 else 1
year = current_date.year if month > 1 else current_date.year + 1
current_date = current_date.replace(year=year, month=month, day=1)
elif frequency == "year":
dates.append(current_date.strftime("%Y"))
current_date = current_date.replace(year=current_date.year + 1, month=1, day=1)
else:
raise ValueError(f"Unsupported frequency: {frequency}")

print(f"Generated {len(dates)} dates with frequency '{frequency}' from datetime args {date_fields}.")
return dates
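# Example: generate_dates_from_datetime_args(
#     {"start_datetime": "20150101", "end_datetime": "20150401"}, "month")
# -> ["201501", "201502", "201503", "201504"]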

# Function to generate item IDs using the id template
def generate_item_ids(dates: List[str], id_template: str) -> List[str]:
item_ids = [id_template.format(date) for date in dates]
print(f"Generated {len(item_ids)} item IDs using template '{id_template}'.")
return item_ids

# Function to generate asset metadata for each item ID
def generate_asset_metadata(variable: str, date: str, model:str, asset_template: str, asset_definitions: Dict[str, dict]) -> Dict[str, dict]:
expected_href = asset_template.format(variable, variable, model, date)
assets = {}

asset_info = asset_definitions.get(variable, {})
asset_metadata = asset_info.copy()
asset_metadata["href"] = expected_href
assets[variable] = asset_metadata
# title and description must be provided, but can be autofilled if needed
asset_metadata["title"] = asset_info.get("title", variable)
asset_metadata["description"] = asset_info.get("description", f"{variable} asset (default description)")

return assets
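# Example: generate_asset_metadata("rlds", "201501", "ssp585",
#     "s3://bucket/{}/{}_{}_{}.tif", {"rlds": {"title": "rlds"}})
# -> {"rlds": {"title": "rlds",
#              "href": "s3://bucket/rlds/rlds_ssp585_201501.tif",
#              "description": "rlds asset (default description)"}}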

def asset_exists(s3_client, bucket_name: str, asset_href: str) -> bool:
"""Check if an asset exists in the specified S3 bucket."""
try:
s3_client.head_object(Bucket=bucket_name, Key=asset_href.split(f"s3://{bucket_name}/", 1)[-1])
print(f"Asset '{asset_href}' exists in S3 bucket '{bucket_name}'.")
return True
except s3_client.exceptions.ClientError as e:
print(f"Asset '{asset_href}' does not exist in S3 bucket '{bucket_name}'.")
print(f"Error: {e}")
return False

# Reduced chunk size to reflect more assets per item
def cmip_discovery_handler(event, chunk_size=200, role_arn=None, bucket_output=None):
variables = event.get("variables", ["rlds", "huss", "hurs"])
date_fields = propagate_forward_datetime_args(event)
frequency = date_fields.get("datetime_range", "month") # Frequency: "day", "month", or "year"
assets = event.get("assets")
id_template = event.get("id_template", "{}-{}")
asset_template = event.get("asset_template", "s3://nex-gddp-cmip6-cog/monthly/CMIP6_ensemble/median/{}/{}_month_ensemble-median_{}_{}.tif")
bucket = event.get("bucket", "nex-gddp-cmip6-cog")
collection = event.get("collection", "cmip6-test")
model = event.get("model", "ssp585")
dry_run = event.get("dry_run", False)

s3_client = boto3.client("s3")

dates = generate_dates_from_datetime_args(date_fields, frequency)
item_ids = generate_item_ids(dates, id_template)
items_with_assets = []
for item_id in item_ids:
_, date = item_id.split("-")
valid_assets = {}
for variable in variables:
assets_metadata = generate_asset_metadata(variable, date, model, asset_template, assets)
for asset_type, asset_metadata in assets_metadata.items():
print(f"Asset type: {asset_type}\nAsset metadata: {asset_metadata}\n")
asset_key = asset_metadata["href"].split(f"s3://{bucket}/", 1)[-1]
if asset_exists(s3_client, bucket, asset_key):
valid_assets[asset_type] = asset_metadata

# Only include items with at least one valid asset
if valid_assets:
items_with_assets.append({"item_id": item_id, "assets": valid_assets})

print(f"Initial discovery completed. {len(items_with_assets)} items with assets found.")

payload = {**event, "objects": []}
slice = event.get("slice")

bucket_output = os.environ.get("EVENT_BUCKET", bucket_output)
key = f"s3://{bucket_output}/events/{collection}"
records = 0
out_keys = []
discovered = []

item_count = 0
for item in items_with_assets:
item_count += 1
if slice:
if item_count < slice[0]:
continue
if (
item_count >= slice[1]
):
break
file_obj = {
"collection": collection,
"item_id": item["item_id"],
"assets": item["assets"],
#"properties": properties,
"datetime_range": frequency,
}

if dry_run and item_count < 10:
print("-DRYRUN- Example item")
print(json.dumps(file_obj))

payload["objects"].append(file_obj)
if records == chunk_size:
out_keys.append(generate_payload(s3_prefix_key=key, payload=payload))
records = 0
discovered.append(len(payload["objects"]))
payload["objects"] = []
records += 1

if payload["objects"]:
out_keys.append(generate_payload(s3_prefix_key=key, payload=payload))
discovered.append(len(payload["objects"]))
try:
del event["assets"]
except KeyError:
pass
return {**event, "payload": out_keys, "discovered": discovered}
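To make the new handler's contract concrete, here is a minimal local invocation sketch. It assumes AWS credentials with read access to the asset bucket and write access to the event bucket, and it assumes `generate_payload` uploads each chunk as JSON and returns its S3 key (as on the existing `s3_discovery_handler` path). The event values are placeholders modelled on `template_dag_run_conf` in the new DAG below, not values taken from this PR.

```python
from veda_data_pipeline.utils.s3_discovery import cmip_discovery_handler

# Placeholder event; mirrors the shape of template_dag_run_conf below.
event = {
    "collection": "cmip6-test",
    "bucket": "nex-gddp-cmip6-cog",
    "variables": ["rlds"],
    "model": "ssp585",
    "id_template": "ssp585-{}",
    "start_datetime": "20150101",
    "end_datetime": "20151231",
    "datetime_range": "month",
    "assets": {"rlds": {"title": "Surface downwelling longwave radiation"}},
    "dry_run": True,
}

# Checks each predicted href with head_object, keeps items that have at least
# one existing asset, chunks them, writes each chunk under
# s3://<event bucket>/events/cmip6-test, and returns
# {"payload": [<chunk file keys>], "discovered": [<items per chunk>], ...}.
result = cmip_discovery_handler(event, bucket_output="my-event-bucket")
print(result["discovered"])
```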
125 changes: 125 additions & 0 deletions dags/veda_data_pipeline/veda_cmip6_dag.py
@@ -0,0 +1,125 @@
import pendulum

from datetime import timedelta
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.trigger_rule import TriggerRule
from airflow.models.variable import Variable
from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator

from veda_data_pipeline.groups.discover_group import discover_from_s3_task, get_files_to_process
from veda_data_pipeline.groups.processing_tasks import build_stac_kwargs, submit_to_stac_ingestor_task


dag_doc_md = """
### Discover files from S3
#### Purpose
This DAG discovers files from either S3 or CMR and then triggers the `veda_ingest` DAG.
The `veda_ingest` DAG runs in parallel, processing 2800 files per DAG run.
#### Notes
- This DAG can run with the following configuration <br>
```json
{
"collection": "collection-id",
"bucket": "veda-data-store-staging",
"prefix": "s3-prefix/",
"filename_regex": "^(.*).tif$",
"id_regex": ".*_(.*).tif$",
"process_from_yyyy_mm_dd": "YYYY-MM-DD",
"id_template": "example-id-prefix-{}",
"datetime_range": "month",
"last_successful_execution": datetime(2015,01,01),
"assets": {
"asset1": {
"title": "Asset type 1",
"description": "First of a multi-asset item.",
"regex": ".*asset1.*",
},
"asset2": {
"title": "Asset type 2",
"description": "Second of a multi-asset item.",
"regex": ".*asset2.*",
},
}
}
```
- [Supports linking to external content](https://github.com/NASA-IMPACT/veda-data-pipelines)
"""

dag_args = {
"start_date": pendulum.today("UTC").add(days=-1),
"catchup": False,
"doc_md": dag_doc_md,
"is_paused_upon_creation": False,
}

template_dag_run_conf = {
"asset_template": "s3://nex-gddp-cmip6-cog/monthly/CMIP6_ensemble_median/{}/{}_month_ensemble-median_{}_{}.tif",
"assets": {
"rlds": {
"title": "Example title for rlds variable - can be anything!",
"description": "Asset name must match one of the provided variables"
}
},
"bucket": "nex-gddp-cmip6-cog",
"collection": "CMIP6-ensemble-median-ssp245",
"datetime_range": "month|year|day",
"end_datetime": "21010101",
"id_template": "ssp245-{}",
"start_datetime": "20150101",
"model": "ssp245|used to populate asset hrefs",
"variables": [
"hurs",
"huss",
"pr",
"rlds",
"rsds",
"sfcWind",
"tas",
"tasmax",
"tasmin"
]
}


def get_discover_dag(id, event={}):
params_dag_run_conf = event or template_dag_run_conf
with DAG(
id,
schedule_interval=event.get("schedule"),
params=params_dag_run_conf,
**dag_args
) as dag:
# ECS dependency variable
mwaa_stack_conf = Variable.get("MWAA_STACK_CONF", deserialize_json=True)

start = DummyOperator(task_id="Start", dag=dag)
end = DummyOperator(
task_id="End", trigger_rule=TriggerRule.ONE_SUCCESS, dag=dag
)
# define DAG using taskflow notation

discover = discover_from_s3_task(event=event, asset_prediction=True)
get_files = get_files_to_process(payload=discover)
build_stac_kwargs_task = build_stac_kwargs.expand(event=get_files)
# partial() is needed for the operator to be used with taskflow inputs
build_stac = EcsRunTaskOperator.partial(
task_id="build_stac",
execution_timeout=timedelta(minutes=60),
trigger_rule=TriggerRule.NONE_FAILED,
cluster=f"{mwaa_stack_conf.get('PREFIX')}-cluster",
task_definition=f"{mwaa_stack_conf.get('PREFIX')}-tasks",
launch_type="FARGATE",
do_xcom_push=True
).expand_kwargs(build_stac_kwargs_task)
# .output is needed coming from a non-taskflow operator
submit_stac = submit_to_stac_ingestor_task.expand(built_stac=build_stac.output)

discover.set_upstream(start)
submit_stac.set_downstream(end)

return dag

get_discover_dag("veda_cmip6_discover")
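Once deployed, the DAG can be triggered with a run configuration shaped like `template_dag_run_conf`. A sketch of one way to do that programmatically (conf values are placeholders; the same conf can be pasted into the Airflow UI's "Trigger DAG w/ config" dialog):

```python
# Sketch: trigger the CMIP6 discovery DAG with an explicit run conf.
# Values are placeholders; adjust to the target collection and scenario.
from airflow.api.client.local_client import Client

client = Client(None, None)
client.trigger_dag(
    dag_id="veda_cmip6_discover",
    conf={
        "collection": "CMIP6-ensemble-median-ssp245",
        "variables": ["rlds", "tas"],
        "model": "ssp245",
        "id_template": "ssp245-{}",
        "start_datetime": "20150101",
        "end_datetime": "20151231",
        "datetime_range": "month",
        "assets": {
            "rlds": {"title": "rlds", "description": "Ensemble-median rlds"},
            "tas": {"title": "tas", "description": "Ensemble-median tas"},
        },
    },
)
```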