Source code for ditto.transformers.subdag.check_cluster_subdag_transformer
from typing import List, Callable
import networkx as nx
from airflow import DAG
from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.models.taskinstance import TaskInstance
from ditto.api import SubDagTransformer, TaskMatcher, DAGFragment, TransformerDefaults
from ditto.matchers import PythonCallTaskMatcher, ClassTaskMatcher
from ditto.transformers.emr import EmrCreateJobFlowOperatorTransformer
from ditto.utils import TransformerUtils
class CheckClusterSubDagTransformer(SubDagTransformer):
    """
    Rewrite the EMR "create the cluster only if it does not already exist"
    sub-DAG pattern into its Azure HDI equivalent.

    The matched pattern is a check task that branches to either a
    cluster-creation task or a no-op task, with both branches converging
    on a task that pulls the cluster id. The replacement is produced by
    running :class:`EmrCreateJobFlowOperatorTransformer` over the
    cluster-creation task of the match.
    """

    def __init__(self, dag: DAG, defaults: TransformerDefaults):
        """
        :param dag: the DAG being transformed; forwarded unchanged to the
            base ``SubDagTransformer``. It is also handed to the delegate
            transformer built in :meth:`transform`.
        :param defaults: transformer defaults. Its ``other_defaults``
            mapping must contain the key ``'pycall_check_cluster'``;
            a missing key raises ``KeyError``.
        """
        super().__init__(dag, defaults)
        extra_settings = defaults.other_defaults
        #: the python callable which checks if the cluster exists.
        self.check_for_existing_emr_cluster = extra_settings['pycall_check_cluster']

    def get_sub_dag_matcher(self) -> List[TaskMatcher]:
        """
        Build the matcher graph for the check-then-create cluster pattern.

        Shape of the pattern::

            check (python call) --> EmrCreateJobFlowOperator --+
                                \\-> DummyOperator ------------+--> xcom_pull (python call)

        :return: a single-element list holding the root (check) matcher.
        """
        check_matcher = PythonCallTaskMatcher(self.check_for_existing_emr_cluster)
        create_matcher = ClassTaskMatcher(EmrCreateJobFlowOperator)
        exists_matcher = ClassTaskMatcher(DummyOperator)
        cluster_id_matcher = PythonCallTaskMatcher(TaskInstance.xcom_pull)

        # The check task fans out to both possible branches...
        check_matcher >> [create_matcher, exists_matcher]
        # ...and both branches converge on the cluster-id lookup.
        create_matcher >> cluster_id_matcher
        exists_matcher >> cluster_id_matcher

        return [check_matcher]

    def transform(self, subdag: nx.DiGraph, parent_fragment: DAGFragment) -> DAGFragment:
        """
        Replace the matched sub-DAG with the HDI equivalent of its
        cluster-creation task.

        Only the first ``EmrCreateJobFlowOperator`` task found in
        ``subdag`` is transformed. The check, no-op and xcom-pull tasks are
        not passed to the delegate.

        :param subdag: the sub-DAG matched by :meth:`get_sub_dag_matcher`.
        :param parent_fragment: fragment passed through to the delegate
            transformer.
        :return: the fragment produced by
            ``EmrCreateJobFlowOperatorTransformer.transform``.
        """
        emr_transformer = EmrCreateJobFlowOperatorTransformer(self.dag, self.defaults)
        create_job_flow_tasks = TransformerUtils.find_matching_tasks(
            subdag, ClassTaskMatcher(EmrCreateJobFlowOperator))
        # NOTE(review): indexing assumes at least one match. The matcher
        # pattern requires an EmrCreateJobFlowOperator task, so the list
        # should never be empty here.
        create_job_flow_task = create_job_flow_tasks[0]
        return emr_transformer.transform(create_job_flow_task, parent_fragment)