# Source code for ditto.rendering

import os
# CI detection: any non-empty TRAVIS environment variable switches on CI mode.
# In CI mode matplotlib is never imported, so the names ``matplotlib``,
# ``patches`` and ``plt`` are undefined and the drawing helpers in this module
# that use them raise NameError.
CI_MODE = os.environ.get('TRAVIS', '') != ''

if not CI_MODE:
    # Plotting stack, only loaded outside CI.
    import matplotlib
    from matplotlib import patches
    import matplotlib.pyplot as plt

import random
from typing import List
import networkx as nx
from airflow import DAG
from networkx.drawing.nx_agraph import graphviz_layout

from ditto.api import Transformer
from ditto.utils import TransformerUtils

def ut_relabeler(dg: nx.DiGraph):
    """Label every node of ``dg`` with its ``task_id``.

    :param dg: graph whose nodes expose a ``task_id`` attribute
        (presumably Airflow operators -- confirm against callers).
    :return: dict mapping each node to ``node.task_id``.
    """
    return {task: task.task_id for task in dg.nodes}
def ut_colorer(dg: nx.DiGraph):
    """Colour nodes by the prefix of their ``task_id``.

    Ids starting with ``"tp"`` are red, ids starting with ``"t2p"`` are green,
    and every other node is blue.

    :param dg: graph whose nodes expose a ``task_id`` attribute.
    :return: list of colour names, one per node, in ``dg.nodes`` order.
    """
    def _colour_for(task_id):
        # Checked in this order: the "tp" rule wins over "t2p".
        if task_id.startswith("tp"):
            return 'red'
        if task_id.startswith("t2p"):
            return 'green'
        return 'blue'

    return [_colour_for(task.task_id) for task in dg.nodes]
def debug_relabeler(dg: nx.DiGraph):
    """Label nodes with 1-based sequence numbers, in ``dg.nodes`` order.

    The numbers match the ``"<n>:"`` prefixes that ``debug_legender`` builds
    when it iterates the same graph, so the legend acts as a key.

    :param dg: graph to label.
    :return: dict mapping each node to its position as a string.
    """
    return {task: f"{position}" for position, task in enumerate(dg.nodes, start=1)}
def debug_colorer(dg: nx.DiGraph):
    """Colour nodes red when a transformer touched them, blue otherwise.

    A node counts as transformed when ``Transformer.TRANSFORMED_BY_HEADER``
    is contained in its ``params``.

    :param dg: graph whose nodes expose a ``params`` container.
    :return: list of colour names, one per node, in ``dg.nodes`` order.
    """
    return [
        'red' if Transformer.TRANSFORMED_BY_HEADER in task.params else 'blue'
        for task in dg.nodes
    ]
def debug_legender(dg: nx.DiGraph):
    """Build one legend entry per node for debug rendering.

    Each label reads ``"<n>:<task_id><<ClassName>>"``, where ``n`` is the
    node's 1-based position in ``dg.nodes``. The patch is red when
    ``Transformer.TRANSFORMED_BY_HEADER`` is in the node's ``params`` and blue
    otherwise. ``patches`` is only imported when ``CI_MODE`` is false.

    :param dg: graph whose nodes expose ``task_id`` and ``params``.
    :return: list of ``matplotlib.patches.Patch`` legend handles.
    """
    legend_handles = []
    for position, task in enumerate(dg.nodes, start=1):
        label = f"{position}:{task.task_id}<{task.__class__.__name__}>"
        colour = 'red' if Transformer.TRANSFORMED_BY_HEADER in task.params else 'blue'
        legend_handles.append(patches.Patch(color=colour, label=label))
    return legend_handles
def _draw_legend(dag, handles, own_figure):
    """Add a two-column legend built from ``handles``.

    :param dag: DAG whose ``dag_id`` titles the separate legend figure.
    :param handles: legend handles, e.g. as returned by a legender callable.
    :param own_figure: when true, open a new figure just for the legend;
        otherwise place it at the lower centre of the current axes.
    """
    # Font size is passed per call rather than written to plt.rcParams, which
    # would silently change the legend size of every later plot.
    if own_figure:
        plt.figure()
        plt.title(dag.dag_id)
        plt.legend(handles=handles, ncol=2, fontsize=8)
    else:
        plt.legend(handles=handles, borderaxespad=0.9, ncol=2, loc='lower center', fontsize=7)


def draw_dag_graphiviz_rendering(dag: DAG, colorer=ut_colorer, relabeler=ut_relabeler, legender=None,
                                 figsize=(6.4, 4.8), legend_own_figure=False):
    """Draw ``dag`` on a new matplotlib figure using a graphviz ``dot`` layout.

    The figure is not displayed. Call ``plt.show()``, or use
    ``show_single_dag_graphviz``/``show_multi_dag_graphviz``, to display it.

    :param dag: Airflow DAG to render; its ``dag_id`` becomes the figure title.
    :param colorer: callable ``(DiGraph) -> list`` of node colours in
        ``dg.nodes`` order, or ``None`` to use networkx's default node colour.
    :param relabeler: callable ``(DiGraph) -> dict`` mapping node to label, or
        ``None`` to draw no node labels.
    :param legender: optional callable ``(DiGraph) -> list`` of legend handles.
    :param figsize: ``(width, height)`` of the graph figure, in inches.
    :param legend_own_figure: draw the legend on a separate figure instead of
        on the graph's axes.
    """
    dg = TransformerUtils.get_digraph_from_airflow_dag(dag)
    labels = relabeler(dg) if relabeler else {}
    # Without a colorer, leave node_color unset: an empty list would not give
    # one colour per node.
    node_color = colorer(dg) if colorer else None

    # Graph-level graphviz attributes read by the layout step.
    graph_attrs = dg.graph.setdefault('graph', {})
    graph_attrs['rankdir'] = 'LR'
    graph_attrs['newrank'] = 'true'

    plt.figure(figsize=figsize)
    plt.title(dag.dag_id)
    pos = graphviz_layout(dg, prog='dot', args='-Gnodesep=0.1')
    # Random arc curvature, presumably so edges do not lie exactly on top of
    # each other.
    rads = random.uniform(0.05, 0.1)
    draw_kwargs = {
        'labels': labels,
        'font_size': 8,
        'node_size': 900,
        'font_color': 'white',
        'font_weight': 'bold',
        'connectionstyle': f"arc3, rad={rads}",
    }
    if node_color is not None:
        draw_kwargs['node_color'] = node_color
    nx.draw_networkx(dg, pos=pos, **draw_kwargs)

    if legender:
        _draw_legend(dag, legender(dg), legend_own_figure)
def show_single_dag_graphviz(dag: DAG, **kwargs):
    """Render a single DAG and display it with the TkAgg backend.

    Keyword arguments whose value is ``None`` are dropped, so
    ``draw_dag_graphiviz_rendering`` keeps its own defaults for them. All
    other keyword arguments are forwarded unchanged.
    """
    matplotlib.use("TkAgg")
    forwarded = {option: value for option, value in kwargs.items() if value is not None}
    draw_dag_graphiviz_rendering(dag, **forwarded)
    plt.show()
def show_multi_dag_graphviz(daglist: List[DAG], **kwargs):
    """Render each DAG in ``daglist`` on its own figure, then show them all.

    Uses the TkAgg backend. Keyword arguments whose value is ``None`` are
    dropped, so ``draw_dag_graphiviz_rendering`` keeps its own defaults for
    them. All other keyword arguments are forwarded to every DAG's rendering.

    :param daglist: DAGs to render, one figure each, in list order.
    """
    matplotlib.use("TkAgg")
    # The filtered options are identical for every DAG, so build them once.
    render_options = {k: v for k, v in kwargs.items() if v is not None}
    for dag in daglist:
        draw_dag_graphiviz_rendering(dag, **render_options)
    plt.show()
def debug_dags(daglist: List[DAG], **kwargs):
    """Show ``daglist`` using the debug relabeler, colorer and legender.

    Nodes are numbered, and transformed nodes are drawn red. The legend maps
    each number to ``task_id<ClassName>``.

    Other keyword arguments (e.g. ``figsize``, ``legend_own_figure``) are
    forwarded. A non-``None`` ``relabeler``, ``colorer`` or ``legender``
    overrides the matching debug default; previously this raised ``TypeError``
    for a duplicated keyword. ``None`` values are ignored.

    :param daglist: DAGs to render.
    """
    options = {
        'relabeler': debug_relabeler,
        'colorer': debug_colorer,
        'legender': debug_legender,
    }
    options.update((k, v) for k, v in kwargs.items() if v is not None)
    show_multi_dag_graphviz(daglist, **options)