from typing import AbstractSet, List, Mapping, Optional, Set, Tuple
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from dagster import (
AssetKey,
AssetsDefinition,
GraphDefinition,
OutputMapping,
TimeWindowPartitionsDefinition,
)
from dagster._core.definitions.graph_definition import create_adjacency_lists
from dagster._utils.schedules import is_valid_cron_schedule
from dagster_airflow.dagster_job_factory import make_dagster_job_from_airflow_dag
from dagster_airflow.utils import (
DagsterAirflowError,
normalized_name,
)
def _build_asset_dependencies(
dag: DAG,
graph: GraphDefinition,
task_ids_by_asset_key: Mapping[AssetKey, AbstractSet[str]],
upstream_dependencies_by_asset_key: Mapping[AssetKey, AbstractSet[AssetKey]],
) -> Tuple[AbstractSet[OutputMapping], Mapping[str, AssetKey], Mapping[str, Set[AssetKey]]]:
"""Builds the asset dependency graph for a given set of airflow task mappings and a dagster graph
"""
output_mappings = set()
keys_by_output_name = {}
internal_asset_deps: dict[str, Set[AssetKey]] = {}
visited_nodes: dict[str, bool] = {}
upstream_deps = set()
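    # `visited_nodes` and `upstream_deps` are rebound for each task in the loop below;
    # the nested DFS helper reads and mutates them through its closure.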
def find_upstream_dependency(node_name: str) -> None:
"""find_upstream_dependency uses Depth-Firs-Search to find all upstream asset dependencies
as described in task_ids_by_asset_key
"""
# node has been visited
if visited_nodes[node_name]:
return
        # mark node as visited
visited_nodes[node_name] = True
# traverse upstream nodes
for output_handle in graph.dependency_structure.all_upstream_outputs_from_node(node_name):
            upstream_node = output_handle.node_name
            match = False
            # find any assets produced by upstream nodes and add them to the internal asset deps
            for asset_key in task_ids_by_asset_key:
                if (
                    upstream_node.replace(f"{normalized_name(dag.dag_id)}__", "")
                    in task_ids_by_asset_key[asset_key]
                ):
                    upstream_deps.add(asset_key)
                    match = True
            # don't traverse past nodes that already produce assets
            if not match:
                find_upstream_dependency(upstream_node)
# iterate through each asset to find all upstream asset dependencies
for asset_key in task_ids_by_asset_key:
asset_upstream_deps = set()
        for task_id in task_ids_by_asset_key[asset_key]:
            visited_nodes = {s.name: False for s in graph.nodes}
            upstream_deps = set()
            node_name = normalized_name(dag.dag_id, task_id)
            find_upstream_dependency(node_name)
            asset_upstream_deps.update(upstream_deps)
            keys_by_output_name[f"result_{node_name}"] = asset_key
            output_mappings.add(
                OutputMapping(
                    graph_output_name=f"result_{node_name}",
                    mapped_node_name=node_name,
                    mapped_node_output_name="airflow_task_complete",  # Default output name
                )
            )
# the tasks for a given asset should have the same internal deps
        for task_id in task_ids_by_asset_key[asset_key]:
            output_name = f"result_{normalized_name(dag.dag_id, task_id)}"
            if output_name in internal_asset_deps:
                internal_asset_deps[output_name].update(asset_upstream_deps)
            else:
                internal_asset_deps[output_name] = asset_upstream_deps
# add new upstream asset dependencies to the internal deps
    for asset_key, new_upstream_keys in upstream_dependencies_by_asset_key.items():
        for output_name, output_asset_key in keys_by_output_name.items():
            if output_asset_key == asset_key:
                internal_asset_deps[output_name].update(new_upstream_keys)
return (output_mappings, keys_by_output_name, internal_asset_deps)
def load_assets_from_airflow_dag(
dag: DAG,
task_ids_by_asset_key: Mapping[AssetKey, AbstractSet[str]] = {},
upstream_dependencies_by_asset_key: Mapping[AssetKey, AbstractSet[AssetKey]] = {},
connections: Optional[List[Connection]] = None,
) -> List[AssetsDefinition]:
"""[Experimental] Construct Dagster Assets for a given Airflow DAG.
Args:
        dag (DAG): The Airflow DAG to compile into Dagster assets.
        task_ids_by_asset_key (Optional[Mapping[AssetKey, AbstractSet[str]]]): A mapping from
            asset keys to task ids. Used to break up the Airflow DAG into multiple
            software-defined assets (SDAs).
        upstream_dependencies_by_asset_key (Optional[Mapping[AssetKey, AbstractSet[AssetKey]]]): A
            mapping from upstream asset keys to assets provided in task_ids_by_asset_key. Used to
            declare new upstream SDA dependencies.
        connections (Optional[List[Connection]]): List of Airflow Connections to be created in the
            Airflow DB.
Returns:
List[AssetsDefinition]
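    Example:
        A minimal sketch; ``my_airflow_dag`` and the asset key below are hypothetical
        placeholders for your own DAG and naming:

        .. code-block:: python

            from dagster import AssetKey

            assets = load_assets_from_airflow_dag(
                dag=my_airflow_dag,
                task_ids_by_asset_key={AssetKey("my_table"): {"load_my_table"}},
            )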
"""
cron_schedule = dag.normalized_schedule_interval
if cron_schedule is not None and not is_valid_cron_schedule(str(cron_schedule)):
raise DagsterAirflowError(
"Invalid cron schedule: {} in DAG {}".format(cron_schedule, dag.dag_id)
)
job = make_dagster_job_from_airflow_dag(dag, connections=connections)
graph = job._graph_def
start_date = dag.start_date if dag.start_date else dag.default_args.get("start_date")
if start_date is None:
raise DagsterAirflowError("Invalid start_date: {} in DAG {}".format(start_date, dag.dag_id))
# leaf nodes have no downstream nodes
forward_edges, _ = create_adjacency_lists(graph.nodes, graph.dependency_structure)
leaf_nodes = {
node_name.replace(f"{normalized_name(dag.dag_id)}__", "")
for node_name, downstream_nodes in forward_edges.items()
if not downstream_nodes
}
mutated_task_ids_by_asset_key: dict[AssetKey, set[str]] = {}
    if not task_ids_by_asset_key:
        # if no mappings are provided the DAG becomes a single SDA
        task_ids_by_asset_key = {AssetKey(dag.dag_id): leaf_nodes}
    else:
        # if mappings were provided, any unmapped leaf nodes are added to a default asset
        used_nodes: set[str] = set()
        for task_ids in task_ids_by_asset_key.values():
            used_nodes.update(task_ids)
        mutated_task_ids_by_asset_key[AssetKey(dag.dag_id)] = leaf_nodes - used_nodes
    for key, task_ids in task_ids_by_asset_key.items():
        mutated_task_ids_by_asset_key.setdefault(key, set()).update(task_ids)
output_mappings, keys_by_output_name, internal_asset_deps = _build_asset_dependencies(
dag, graph, mutated_task_ids_by_asset_key, upstream_dependencies_by_asset_key
)
new_graph = GraphDefinition(
name=graph.name,
node_defs=graph.node_defs,
dependencies=graph.dependencies,
output_mappings=list(output_mappings),
)
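    # Mirror the Airflow schedule as Dagster time-window partitions: each cron tick of
    # the DAG corresponds to one partition, anchored at the DAG's start_date.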
    partitions_def = (
        TimeWindowPartitionsDefinition(
            cron_schedule=str(cron_schedule),
            timezone=dag.timezone.name,
            start=start_date.strftime("%Y-%m-%dT%H:%M:%S"),
            fmt="%Y-%m-%dT%H:%M:%S",
        )
        if cron_schedule is not None
        else None
    )
    asset_def = AssetsDefinition.from_graph(
        graph_def=new_graph,
        partitions_def=partitions_def,
        group_name=dag.dag_id,
        keys_by_output_name=keys_by_output_name,
        internal_asset_deps=internal_asset_deps,
        can_subset=True,
    )
return [asset_def]
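# Usage sketch (``my_airflow_dag`` is a placeholder for an importable Airflow DAG object):
#
#     from dagster import Definitions
#
#     defs = Definitions(assets=load_assets_from_airflow_dag(my_airflow_dag))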