# Source code for ditto.utils
import re
from collections import deque
from queue import Queue
from typing import List, Optional, Tuple, Type

import networkx as nx
from airflow import DAG
from airflow.models import BaseOperator

from ditto.api import TaskMatcher, DAGFragment
class TransformerUtils:
    """
    Collection of static/class helper methods for parsing airflow ``xcom_pull``
    templates, traversing operator DAGs and :class:`~ditto.api.DAGFragment`
    graphs, and matching sub-DAGs with networkx. The class holds no state.
    """

    # Raw strings keep the regex escapes intact (non-raw "\[" / "\d" / "\s" are
    # invalid string escapes and trigger warnings on modern Python).
    _LIST_INDEX_RE = re.compile(r"\[(\d+)\]")
    # The quote character class is ["'] only: a "|" inside a character class is
    # a literal pipe, not an alternation, and would wrongly be accepted as a quote.
    _XCOM_TASK_ID_RE = re.compile(r"""{{.*xcom_pull\s*\(.*["'](.*)["'],.*}}""")

    @staticmethod
    def get_list_index_from_xcom_pull(xcom_template: str) -> str:
        """
        Parses an airflow template variable and finds the list index accessed
        from the return value of an :meth:`~airflow.models.TaskInstance.xcom_pull`

        :Example:

            ``task_instance.xcom_pull("add_steps_to_cluster", key="return_value")[3]``
            will return ``"3"``

            ``{{ ti.xcom_pull("add_steps_to_cluster", key="return_value")[2] }}``
            will return ``"2"``

        :param xcom_template: airflow template string to parse
        :return: the first bracketed integer index found, as a string
        :raises AttributeError: if the template contains no ``[<digits>]`` index
        """
        return TransformerUtils._LIST_INDEX_RE.search(xcom_template).group(1)

    @staticmethod
    def get_task_id_from_xcom_pull(xcom_template: str) -> str:
        """
        Parses an airflow template variable and finds the task ID
        from which an :meth:`~airflow.models.TaskInstance.xcom_pull` is being done

        :Example:

            ``{{ ti.xcom_pull("add_steps_to_cluster", key="return_value")[0] }}``
            will return ``"add_steps_to_cluster"``

            ``{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}``
            will return ``"add_steps"``

        :param xcom_template: airflow template string to parse
        :return: the task_id
        :raises AttributeError: if the template does not look like an ``xcom_pull``
            call whose quoted task id is followed by a comma
        """
        return TransformerUtils._XCOM_TASK_ID_RE.search(xcom_template).group(1)

    @staticmethod
    def add_downstream_dag_fragment(fragment_up: DAGFragment,
                                    fragment_down: DAGFragment) -> Optional[DAGFragment]:
        """
        Attaches the roots of the `fragment_down` to the leaves of `fragment_up`,

        .. note::

            The leaves of `fragment_up` are found by traversing its operator DAG, not
            its :class:`~ditto.api.DAGFragment` dag. This does not join fragment DAGS but
            only operator DAGS in two :class:`~ditto.api.DAGFragment` objects.
            See the documentation of :class:`~ditto.api.DAGFragment` for understanding
            what that means.

        :param fragment_up: the upstream :class:`~ditto.api.DAGFragment` to which `fragment_down`
            has to be added
        :param fragment_down: the downstream :class:`~ditto.api.DAGFragment`
        :return: `fragment_up` after joining; if either argument is `None`, the
            other argument is returned untouched
        """
        if fragment_up is None:
            return fragment_down
        if fragment_down is None:
            return fragment_up

        task_q: "deque[BaseOperator]" = deque()
        seen_tasks = set()
        # Seed the traversal with the fragment's own tasks, marking them as seen
        # so that a task reachable both as a root and as a descendant is handled
        # once. Handling it twice would walk into `fragment_down` (after the
        # first attach) and wire the leaves of `fragment_down` to its own roots.
        for task in fragment_up.tasks:
            if task not in seen_tasks:
                seen_tasks.add(task)
                task_q.append(task)

        while task_q:
            task = task_q.popleft()
            if task.downstream_list:
                for downstream_task in task.downstream_list:
                    if downstream_task not in seen_tasks:
                        seen_tasks.add(downstream_task)
                        task_q.append(downstream_task)
            else:
                # a leaf of fragment_up: hook the roots of fragment_down below it
                task.set_downstream(fragment_down.tasks)
        return fragment_up

    @classmethod
    def find_op_in_parent_fragment_chain(cls, parent_fragment: DAGFragment,
                                         operator_type: Type[BaseOperator] = None,
                                         task_id: str = None) -> Optional[BaseOperator]:
        """
        Finds the operator matched by the `operator_type` class and having task ID `task_id`
        by searching `parent_fragment` and then, breadth-first, the fragments
        reachable through the ``parents`` links of each :class:`~ditto.api.DAGFragment`.
        See :meth:`~ditto.utils.TransformerUtils.find_op_in_dag_fragment` for
        how each individual fragment is searched.

        :param parent_fragment: See :paramref:`~ditto.api.OperatorTransformer.transform.parent_fragment`
        :param operator_type: the type of operator to find
        :param task_id: the task_id of the operator to find
        :return: the operator found, or `None` if there is no match
            (or `parent_fragment` is `None`)
        """
        if parent_fragment is None:
            return None

        fragment_q: "deque[DAGFragment]" = deque([parent_fragment])
        # Ancestors shared by several fragments are searched only once.
        seen_fragments = {parent_fragment}
        while fragment_q:
            dag_fragment = fragment_q.popleft()
            op_found = cls.find_op_in_dag_fragment(dag_fragment,
                                                   operator_type=operator_type,
                                                   task_id=task_id)
            if op_found is not None:
                return op_found
            for parent in dag_fragment.parents or ():
                if parent not in seen_fragments:
                    seen_fragments.add(parent)
                    fragment_q.append(parent)
        return None

    @classmethod
    def find_op_in_fragment_list(cls, fragment_list: List[DAGFragment],
                                 operator_type: Type[BaseOperator] = None,
                                 task_id: str = None) -> Optional[BaseOperator]:
        """
        Lenient version of :meth:`~ditto.utils.TransformerUtils.find_op_in_fragment_list_strict`:
        first searches with both `operator_type` and `task_id`; if nothing is
        found, searches again using only `operator_type`.

        :param fragment_list: the list of :class:`~ditto.api.DAGFragment` objects to search in
        :param operator_type: the type of operator to find
        :param task_id: the task_id of the operator to find
        :return: the operator found, or `None` if there is no match
        """
        found_op = cls.find_op_in_fragment_list_strict(fragment_list,
                                                       operator_type=operator_type,
                                                       task_id=task_id)
        if found_op is None:
            found_op = cls.find_op_in_fragment_list_strict(fragment_list,
                                                           operator_type=operator_type)
        return found_op

    @classmethod
    def find_op_in_fragment_list_strict(cls, fragment_list: List[DAGFragment],
                                        operator_type: Type[BaseOperator] = None,
                                        task_id: str = None) -> Optional[BaseOperator]:
        """
        Uses :meth:`~ditto.utils.TransformerUtils.find_op_in_dag_fragment` to find
        an operator in a list of :class:`~ditto.api.DAGFragment` objects, searching
        the fragments in list order.

        :param fragment_list: the list of :class:`~ditto.api.DAGFragment` objects to search in
        :param operator_type: the type of operator to find
        :param task_id: the task_id of the operator to find
        :return: the first operator found, or `None` if there is no match
        """
        for fragment in fragment_list:
            op_found = cls.find_op_in_dag_fragment(fragment,
                                                   operator_type=operator_type,
                                                   task_id=task_id)
            if op_found is not None:
                return op_found
        return None

    @staticmethod
    def _task_matches(task: BaseOperator,
                      operator_type: Type[BaseOperator] = None,
                      task_id: str = None) -> bool:
        """Return `True` if `task` satisfies every criterion given (at least one is required)."""
        if not operator_type and not task_id:
            # no criteria at all never matches anything
            return False
        if operator_type and not isinstance(task, operator_type):
            return False
        if task_id and task.task_id != task_id:
            return False
        return True

    @staticmethod
    def find_op_in_dag_fragment(dag_fragment: DAGFragment,
                                operator_type: Type[BaseOperator] = None,
                                task_id: str = None,
                                upstream=False) -> Optional[BaseOperator]:
        """
        Traverses the operator dag of the given :class:`~ditto.api.DAGFragment`
        breadth-first and finds a :class:`~airflow.models.BaseOperator` matching
        the given `operator_type` and `task_id`. When both are given, both must
        match; when only one is given, only that one is checked; when neither is
        given nothing matches. Can search upstream or downstream of the tasks in
        the given :class:`~ditto.api.DAGFragment`

        :param dag_fragment: fragment whose operator dag has to be searched
        :param operator_type: the type of operator to find
        :param task_id: the task_id of the operator to find
        :param upstream: search upstream if `True` otherwise search `downstream`
        :return: the operator found, or `None` if there is no match
        """
        task_q: "deque[BaseOperator]" = deque()
        seen_tasks = set()
        for task in dag_fragment.tasks:
            if task not in seen_tasks:
                seen_tasks.add(task)
                task_q.append(task)

        while task_q:
            task = task_q.popleft()
            if TransformerUtils._task_matches(task, operator_type, task_id):
                return task
            # Follow exactly one direction; a task without relatives in that
            # direction ends the branch instead of flipping the direction.
            relative_task_list = task.upstream_list if upstream else task.downstream_list
            for relative_task in relative_task_list or ():
                if relative_task not in seen_tasks:
                    seen_tasks.add(relative_task)
                    task_q.append(relative_task)
        return None

    @staticmethod
    def _new_ordered_digraph() -> nx.DiGraph:
        """
        Create an empty, insertion-ordered directed graph.

        Uses ``nx.OrderedDiGraph`` when the installed networkx still provides it
        and falls back to ``nx.DiGraph`` otherwise (newer networkx releases
        removed the ordered classes).
        """
        graph_cls = getattr(nx, "OrderedDiGraph", nx.DiGraph)
        return graph_cls()

    @staticmethod
    def get_digraph_from_airflow_dag(dag: DAG) -> nx.DiGraph:
        """
        Construct a :class:`~networkx.DiGraph` from the given airflow :class:`~airflow.models.DAG`.
        Every node is an operator and carries it again in the ``op`` node attribute.

        :param dag: the airflow DAG
        :return: the networkx DiGraph
        """
        dg = TransformerUtils._new_ordered_digraph()
        task_q: "deque[BaseOperator]" = deque()
        # Each task is expanded once; without this, a task reachable through
        # several paths would be re-expanded once per path.
        seen_tasks = set()
        for root in dag.roots:
            if root not in seen_tasks:
                seen_tasks.add(root)
                task_q.append(root)

        while task_q:
            task = task_q.popleft()
            dg.add_node(task, op=task)
            for child in task.downstream_list or ():
                dg.add_node(child, op=child)
                dg.add_edge(task, child)
                if child not in seen_tasks:
                    seen_tasks.add(child)
                    task_q.append(child)
        return dg

    @staticmethod
    def get_digraph_from_matcher_dag(matcher_roots: List[TaskMatcher]) -> nx.DiGraph:
        """
        Construct a :class:`~networkx.DiGraph` from the given :class:`~ditto.api.TaskMatcher` dag.
        Every node is a matcher and carries it again in the ``m`` node attribute.

        :param matcher_roots: the root matchers of the matcher DAG
        :return: the networkx DiGraph
        """
        dg = TransformerUtils._new_ordered_digraph()
        matcher_q: "deque[TaskMatcher]" = deque()
        # Each matcher is expanded once even when shared by several parents.
        seen_matchers = set()
        for root in matcher_roots:
            if root not in seen_matchers:
                seen_matchers.add(root)
                matcher_q.append(root)

        while matcher_q:
            matcher = matcher_q.popleft()
            dg.add_node(matcher, m=matcher)
            for child in matcher.children or ():
                dg.add_node(child, m=child)
                dg.add_edge(matcher, child)
                if child not in seen_matchers:
                    seen_matchers.add(child)
                    matcher_q.append(child)
        return dg

    @classmethod
    def find_sub_dag(cls, dag: DAG, matcher_roots: List[TaskMatcher]) -> Tuple[nx.DiGraph, List[nx.DiGraph]]:
        """
        The problem is to find a sub-DAG in a DAG where the sub-DAG's nodes are
        matcher functions which test nodes

        It can be generalized to: find if a DAG or DiGraph G1 is isomorphic with
        a DAG G2, with the node comparison function being running of the matchers in G1
        on nodes in G2

        .. note::

            This uses python's NetworkX graph library which uses the
            `VF2 <https://networkx.github.io/documentation/stable/reference/algorithms/isomorphism.vf2.html>`_
            algorithm for `graph isomorphism <https://ieeexplore.ieee.org/document/1323804>`_.

        .. note::

            We are trying to find an exact sub-DAG match. In graph theory, this is called a
            `node-induced <https://math.stackexchange.com/questions/1013143/difference-between-a-sub-graph-and-induced-sub-graph>`_
            subgraph. A subgraph H of G is called INDUCED, if for any two vertices
            u, v in H, u and v are adjacent in H if and only if they are adjacent in G.
            In other words, H has the same edges as G between the vertices in H.

        .. seealso::

            This is an NP-complete problem: https://en.wikipedia.org/wiki/Subgraph_isomorphism_problem

        :param dag: the DAG where the sub-dag has to be found
        :param matcher_roots: the root task matchers of the :class:`~ditto.api.TaskMatcher` dag
        :return: a tuple containing the :class:`~networkx.DiGraph` of the source DAG
            and the list of matching subdag :class:`~networkx.DiGraph` objects
            (empty when nothing matches)
        """
        dag_dg = cls.get_digraph_from_airflow_dag(dag)
        matcher_dg = cls.get_digraph_from_matcher_dag(matcher_roots)

        def node_matcher(n1, n2):
            # n1 is the attribute dict of an operator node, n2 of a matcher node
            task: BaseOperator = n1['op']
            matcher: TaskMatcher = n2['m']
            return matcher.does_match(task)

        digm = nx.isomorphism.DiGraphMatcher(dag_dg, matcher_dg, node_match=node_matcher)
        # The iterator simply yields nothing when there is no match, so a separate
        # subgraph_is_isomorphic() pre-check would only repeat the search.
        subdags: List[nx.DiGraph] = [
            dag_dg.subgraph(mapping.keys())
            for mapping in digm.subgraph_isomorphisms_iter()
        ]
        return (dag_dg, subdags)

    @staticmethod
    def remove_task_from_dag(dag: DAG, dag_nodes: List[BaseOperator], task: BaseOperator):
        """
        Removes the given :class:`~airflow.models.BaseOperator` from the given
        :class:`~airflow.models.DAG`: its task ID is dropped from the upstream and
        downstream ID collections of every other task in `dag_nodes`, its own
        relations and DAG reference are cleared, and it is deleted from
        ``dag.task_dict``.

        :param dag: the source airflow DAG
        :param dag_nodes: the list of nodes in the source DAG
        :param task: the task to remove
        :raises KeyError: if the task's ID is not present in ``dag.task_dict``
        """
        removed_id = task.task_id
        for this_task in dag_nodes:
            if this_task is task:
                continue
            if removed_id in this_task._upstream_task_ids:
                this_task._upstream_task_ids.remove(removed_id)
            if removed_id in this_task._downstream_task_ids:
                this_task._downstream_task_ids.remove(removed_id)

        task._upstream_task_ids.clear()
        task._downstream_task_ids.clear()
        task._dag = None
        del dag.task_dict[removed_id]

    @classmethod
    def find_matching_tasks(cls, subdag: nx.DiGraph, matcher: TaskMatcher):
        """
        Find matching tasks in a :class:`~networkx.DiGraph` of operators

        :param subdag: the dag to search for matches
        :param matcher: the task matcher to use
        :return: list of the nodes for which ``matcher.does_match`` is truthy,
            in the iteration order of ``subdag.nodes``
        """
        return [node for node in subdag.nodes if matcher.does_match(node)]

    @staticmethod
    def assign_task_to_dag(op: BaseOperator, dag: DAG):
        """
        Assigns the given :class:`~airflow.models.BaseOperator` and all its downstream
        tasks to the given :class:`~airflow.models.DAG`

        :param op: the task to assign
        :param dag: the dag to assign the task and its downstream to
        """
        task_q: "deque[BaseOperator]" = deque([op])
        seen_tasks = {op}
        while task_q:
            task = task_q.popleft()
            task.dag = dag
            for child in task.downstream_list or ():
                if child not in seen_tasks:
                    seen_tasks.add(child)
                    task_q.append(child)

    @classmethod
    def add_dag_fragment_to_dag(cls, dag: DAG, frag: DAGFragment):
        """
        Traverses this fragment and its child fragments breadth-first and assigns
        all their tasks to the given DAG using :meth:`.assign_task_to_dag`

        :param dag: the dag to assign the fragment's tasks to
        :param frag: the dag fragment to assign
        """
        fragment_q: "deque[DAGFragment]" = deque([frag])
        seen_frag = {frag}
        while fragment_q:
            current = fragment_q.popleft()
            for task in current.tasks:
                cls.assign_task_to_dag(task, dag)
            for child in current.children or ():
                if child not in seen_frag:
                    seen_frag.add(child)
                    fragment_q.append(child)