♻️ refactor: CLIN-3267 Code review suggestion
laurabegin committed Oct 10, 2024
1 parent dda54a2 commit 98ef776
Showing 6 changed files with 123 additions and 183 deletions.
4 changes: 1 addition & 3 deletions dags/lib/config.py
@@ -2,8 +2,7 @@
 from airflow.exceptions import AirflowConfigException
 from airflow.models import Variable
 
-from lib import config_nextflow_pipelines
-from lib.operators.base_kubernetes import KubeConfig
+from lib.operators.base_kubernetes import KubeConfig, ConfigMap
 from lib.operators.nextflow import NextflowOperatorConfig
 
 
@@ -251,5 +250,4 @@ def k8s_load_config(context: str) -> None:
     persistent_volume_sub_path='workspace',
     persistent_volume_mount_path="/mnt/workspace",
     nextflow_working_dir=f's3://{clin_scratch_bucket}/nextflow/scratch',
-    config_maps=[config_nextflow_pipelines.default_config_map]
 )
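
Dropping the explicit config_maps argument here is safe because NextflowOperatorConfig now supplies the shared 'nextflow' config map by default (see dags/lib/operators/nextflow.py below). A minimal sketch of the resulting behavior, with placeholder values standing in for the real Airflow variables and with any fields inherited from BaseConfig omitted:

from lib.operators.nextflow import DEFAULT_NEXTFLOW_CONFIG_MAP, NextflowOperatorConfig

# Placeholder values; the real ones come from Airflow variables in config.py.
base = NextflowOperatorConfig(
    minio_credentials_secret_name='minio-credentials',
    minio_credentials_secret_access_key='access_key',
    minio_credentials_secret_secret_key='secret_key',
    persistent_volume_claim_name='workspace-pvc',
    persistent_volume_sub_path='workspace',
    persistent_volume_mount_path='/mnt/workspace',
    nextflow_working_dir='s3://scratch-bucket/nextflow/scratch',
)

# No config_maps argument needed: the default_factory supplies the
# 'nextflow' ConfigMap mounted at /root/nextflow/config.
assert base.config_maps == [DEFAULT_NEXTFLOW_CONFIG_MAP]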
43 changes: 0 additions & 43 deletions dags/lib/config_nextflow_pipelines.py

This file was deleted.

6 changes: 3 additions & 3 deletions dags/lib/operators/base_kubernetes.py
@@ -1,5 +1,5 @@
 import copy
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, field, fields, asdict
 from typing import Optional, List, Type, TypeVar
 from typing_extensions import Self
 
@@ -82,14 +82,14 @@ def __init__(
         )
         self.image_pull_secrets_name = image_pull_secrets_name
 
-    def execute(self, context: Context, **kwargs):
+    def execute(self, context: Context):
         if self.image_pull_secrets_name:
             self.image_pull_secrets = [
                 k8s.V1LocalObjectReference(
                     name=self.image_pull_secrets_name,
                 ),
             ]
-        super().execute(context, **kwargs)
+        super().execute(context)
 
 
 @dataclass
128 changes: 69 additions & 59 deletions dags/lib/operators/nextflow.py
@@ -1,13 +1,13 @@
 import copy
-from dataclasses import dataclass, field
 import logging
 import re
+from dataclasses import dataclass, field
 from typing import Optional, List, Type
-from typing_extensions import Self
 
 from airflow.exceptions import AirflowSkipException
 from airflow.utils.context import Context
 from kubernetes.client import models as k8s
+from typing_extensions import Self
 
 from lib.operators.base_kubernetes import (
     BaseConfig,
@@ -18,6 +18,12 @@
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_NEXTFLOW_CONFIG_MAP = ConfigMap(
+    name='nextflow',
+    mount_path='/root/nextflow/config'
+)
+DEFAULT_NEXTFLOW_CONFIG_FILE = f"{DEFAULT_NEXTFLOW_CONFIG_MAP.mount_path}/nextflow.config"
+
 
 class NextflowOperator(BaseKubernetesOperator):
     """
@@ -36,23 +42,27 @@ class NextflowOperator(BaseKubernetesOperator):
     and nextflow configuration file(s) are injected through
     Kubernetes configmaps.
     """
 
-    template_fields = [*BaseKubernetesOperator.template_fields, 'skip']
+    template_fields = [*BaseKubernetesOperator.template_fields, 'config_maps', 'nextflow_pipeline',
+                       'nextflow_working_dir', 'nextflow_config_files', 'nextflow_params_files',
+                       'nextflow_pipeline_revision', 'skip']
 
     def __init__(
-            self,
-            minio_credentials_secret_name: str,
-            minio_credentials_secret_access_key: str,
-            minio_credentials_secret_secret_key: str,
-            persistent_volume_claim_name: str,
-            persistent_volume_sub_path: str,
-            persistent_volume_mount_path: str,
-            nextflow_working_dir: str,
-            skip: bool = False,
-            config_maps: Optional[List[ConfigMap]] = None,
-            **kwargs
+            self,
+            config_maps: List[ConfigMap],
+            minio_credentials_secret_name: str,
+            minio_credentials_secret_access_key: str,
+            minio_credentials_secret_secret_key: str,
+            persistent_volume_claim_name: str,
+            persistent_volume_sub_path: str,
+            persistent_volume_mount_path: str,
+            nextflow_pipeline: str,
+            nextflow_working_dir: str,
+            nextflow_config_files: List[str],
+            nextflow_params_files: Optional[List[str]] = None,
+            nextflow_pipeline_revision: Optional[str] = None,
+            skip: bool = False,
+            **kwargs
     ) -> None:
 
         super().__init__(
             **kwargs
         )
Expand All @@ -66,15 +76,30 @@ def __init__(

# Where nextflow will write intermediate outputs. This is different
# from the pod working directory.
self.nextflow_pipeline = nextflow_pipeline
self.nextflow_pipeline_revision = nextflow_pipeline_revision
self.nextflow_working_dir = nextflow_working_dir
self.nextflow_config_files = nextflow_config_files
self.nextflow_params_files = nextflow_params_files if nextflow_params_files else []
self.skip = skip
self.config_maps = config_maps if config_maps else []

def execute(self, context: Context, **kwargs):
self.config_maps = config_maps

def execute(self, context: Context):
if self.skip:
raise AirflowSkipException()

# Prepare nextflow arguments
nextflow_revision_option = ['-r', self.nextflow_pipeline_revision] if self.nextflow_pipeline_revision else []
nextflow_config_file_options = [arg for file in self.nextflow_config_files for arg in ['-c', file] if file]
nextflow_params_file_options = [arg for file in self.nextflow_params_files for arg in ['-params-file', file] if
file]
arguments = [arg for arg in self.arguments if arg] if self.arguments else [] # Remove empty strings

self.arguments = ['nextflow', 'run', self.nextflow_pipeline, *nextflow_revision_option,
*nextflow_config_file_options, *nextflow_params_file_options, *arguments]

logger.info(f"Running arguments : {self.arguments}")

self.env_vars = [
k8s.V1EnvVar(
name="NXF_WORK",
@@ -149,8 +174,8 @@ def execute(self, context: Context, **kwargs):
         # configure it within the container specification in attribute
         # full_pod_spec.
         pod_working_dir = _get_pod_working_dir(
-                self.persistent_volume_mount_path,
-                context
+            self.persistent_volume_mount_path,
+            context
         )
 
         logger.info("Setting pod working directory to %s", pod_working_dir)
@@ -160,33 +185,38 @@ def execute(self, context: Context, **kwargs):
             container_name=self.base_container_name
         )
 
-        super().execute(context, **kwargs)
+        super().execute(context)
 
 
 @dataclass
-class NextflowPipeline:
-    """
-    Represents a nextflow pipeline to be executed in a nextflow pod.
-    """
-    url: Optional[str] = None
-    revision: Optional[str] = None
-    config_maps: List[ConfigMap] = field(default_factory=list)
-    config_files: List[str] = field(default_factory=list)
-    params_file: Optional[str] = None
+class NextflowOperatorConfig(BaseConfig):
+    nextflow_pipeline: Optional[str] = None
+    nextflow_pipeline_revision: Optional[str] = None
+    nextflow_config_files: List[str] = field(default_factory=lambda: [DEFAULT_NEXTFLOW_CONFIG_FILE])
+    nextflow_params_files: Optional[List[str]] = None
+    config_maps: List[ConfigMap] = field(default_factory=lambda: [DEFAULT_NEXTFLOW_CONFIG_MAP])
+    minio_credentials_secret_name: str = required()
+    minio_credentials_secret_access_key: str = required()
+    minio_credentials_secret_secret_key: str = required()
+    persistent_volume_claim_name: str = required()
+    persistent_volume_sub_path: str = required()
+    persistent_volume_mount_path: str = required()
+    nextflow_working_dir: str = required()
+    skip: bool = False
 
-    def with_url(self, new_url: str) -> Self:
+    def with_pipeline(self, pipeline: str) -> Self:
         c = copy.copy(self)
-        c.url = new_url
+        c.nextflow_pipeline = pipeline
         return c
 
-    def with_revision(self, new_revision: str) -> Self:
+    def with_revision(self, revision: str) -> Self:
         c = copy.copy(self)
-        c.revision = new_revision
+        c.nextflow_pipeline_revision = revision
         return c
 
-    def with_params_file(self, new_params_file: str) -> Self:
+    def with_params_file(self, params_file: str) -> Self:
         c = copy.copy(self)
-        c.params_file = new_params_file
+        c.nextflow_params_files = [params_file]
         return c

     def append_config_maps(self, *new_config_maps) -> Self:
@@ -196,27 +226,7 @@ def append_config_maps(self, *new_config_maps) -> Self:
 
     def append_config_files(self, *new_config_files) -> Self:
         c = copy.copy(self)
-        c.config_files = [*self.config_files, *new_config_files]
-        return c
-
-
-@dataclass
-class NextflowOperatorConfig(BaseConfig):
-
-    minio_credentials_secret_name: str = required()
-    minio_credentials_secret_access_key: str = required()
-    minio_credentials_secret_secret_key: str = required()
-    persistent_volume_claim_name: str = required()
-    persistent_volume_sub_path: str = required()
-    persistent_volume_mount_path: str = required()
-    nextflow_working_dir: str = required()
-    skip: bool = False
-
-    config_maps: List[str] = field(default_factory=list)
-
-    def with_config_maps(self, new_config_maps: List[ConfigMap]) -> Self:
-        c = copy.copy(self)
-        c.config_maps = new_config_maps
+        c.nextflow_config_files = [*self.nextflow_config_files, *new_config_files]
         return c

     def extend_config_maps(self, *new_config_maps) -> Self:
@@ -226,7 +236,7 @@ def extend_config_maps(self, *new_config_maps) -> Self:
 
     def operator(self,
                  class_to_instantiate: Type[NextflowOperator] = NextflowOperator,
-                 **kwargs) -> NextflowOperator:
+                 **kwargs) -> BaseKubernetesOperator:
         return super().build_operator(
             class_to_instantiate=class_to_instantiate,
             **kwargs
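
Taken together, NextflowOperatorConfig absorbs the deleted NextflowPipeline dataclass: pipeline, revision, config files and params files are now ordinary config fields with copy-on-write with_*/append_* builders. A hedged sketch of deriving a pipeline-specific config from the shared base config (the extra config file path is hypothetical, and task_id/name are assumed to be forwarded to the operator through BaseConfig.build_operator):

# Sketch only: nextflow_base_config is the NextflowOperatorConfig built in
# dags/lib/config.py.
svclustering = nextflow_base_config \
    .with_pipeline('Ferlab-Ste-Justine/ferlab-svclustering-parental-origin') \
    .with_revision('v1.1') \
    .append_config_files('/root/nextflow/config/svclustering.config')  # hypothetical extra -c file

# Every with_*/append_* call returns a modified copy, so the shared
# nextflow_base_config is never mutated and stays reusable across DAGs.
op = svclustering.operator(
    task_id='svclustering',  # assumed: extra kwargs reach the operator constructor
    name='svclustering',
)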
42 changes: 10 additions & 32 deletions dags/lib/tasks/nextflow.py
@@ -1,18 +1,11 @@
 from typing import List
 
 from airflow.exceptions import AirflowSkipException
+from airflow.models import MappedOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.utils.context import Context
 
-from lib.config import (
-    clin_datalake_bucket,
-    s3_conn_id,
-    nextflow_base_config
-)
-from lib.config_nextflow_pipelines import (
-    NextflowPipeline,
-    svclustering_pipeline
-)
+from lib.config import (clin_datalake_bucket, s3_conn_id, nextflow_base_config)
 from lib.operators.nextflow import NextflowOperator
 from lib.operators.spark_etl import SparkETLOperator
 from lib.utils_etl import ClinAnalysis
@@ -21,7 +14,7 @@
 def prepare_svclustering_parental_origin(
         batch_ids: List[str],
         spark_jar: str,
-        skip: str = '') -> SparkETLOperator:
+        skip: str = '') -> MappedOperator:
     return SparkETLOperator.partial(
         task_id='prepare_svclustering_parental_origin',
         name='prepare-svclustering-parental-origin',
@@ -60,7 +53,7 @@ def __init__(self,
             '--outdir', self.output_key
         ]
 
-    def execute(self, context: Context, **kwargs):
+    def execute(self, context: Context):
         batch_type = context['ti'].xcom_pull(
             task_ids='detect_batch_type',
             key=self.batch_id
@@ -73,33 +66,18 @@
         if not s3.check_for_key(self.input_key):
             raise AirflowSkipException(f'No CSV input file for batch id \'{self.batch_id}\'')
 
-        super().execute(context, **kwargs)
+        super().execute(context)
 
-    return nextflow_base_config\
+    return nextflow_base_config \
+        .with_pipeline('Ferlab-Ste-Justine/ferlab-svclustering-parental-origin') \
+        .with_revision('v1.1') \
         .append_args(
-            *get_run_pipeline_arguments(svclustering_pipeline),
-            '--fasta', f's3://{clin_datalake_bucket}/public/refgenomes/hg38/Homo_sapiens_assembly38.fasta',
+            '--fasta', f's3://{clin_datalake_bucket}/public/refgenomes/hg38/Homo_sapiens_assembly38.fasta',
             '--fasta_fai', f's3://{clin_datalake_bucket}/public/refgenomes/hg38/Homo_sapiens_assembly38.fasta.fai',
-            '--fasta_dict', f's3://{clin_datalake_bucket}/public/refgenomes/hg38/Homo_sapiens_assembly38.dict'
-        ) \
-        .with_config_maps(svclustering_pipeline.config_maps) \
+            '--fasta_dict', f's3://{clin_datalake_bucket}/public/refgenomes/hg38/Homo_sapiens_assembly38.dict') \
         .partial(
             SVClusteringParentalOrigin,
             task_id='svclustering_parental_origin',
             name='svclustering_parental_origin',
             skip=skip
         ).expand(batch_id=batch_ids)
-
-
-def get_run_pipeline_arguments(pipeline: NextflowPipeline) -> List[str]:
-    new_args = ['nextflow']
-
-    for config_file in pipeline.config_files:
-        new_args.extend(["-c", config_file])
-
-    new_args.extend(["run", pipeline.url, "-r", pipeline.revision])
-
-    if (pipeline.params_file):
-        new_args.extend(["-params-file", pipeline.params_file])
-
-    return new_args
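
With get_run_pipeline_arguments gone, NextflowOperator.execute now assembles the nextflow command itself from its templated fields. For the task above, the operator would build an argument list along these lines (illustrative; the bucket name is resolved from config and the config file comes from DEFAULT_NEXTFLOW_CONFIG_FILE):

# Order follows execute(): run target, -r revision, -c config files,
# -params-file files, then any pre-existing operator arguments.
['nextflow', 'run', 'Ferlab-Ste-Justine/ferlab-svclustering-parental-origin',
 '-r', 'v1.1',
 '-c', '/root/nextflow/config/nextflow.config',
 '--fasta', 's3://<clin_datalake_bucket>/public/refgenomes/hg38/Homo_sapiens_assembly38.fasta',
 '--fasta_fai', 's3://<clin_datalake_bucket>/public/refgenomes/hg38/Homo_sapiens_assembly38.fasta.fai',
 '--fasta_dict', 's3://<clin_datalake_bucket>/public/refgenomes/hg38/Homo_sapiens_assembly38.dict']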
