Cluster Policies

Cluster Policies let you check or mutate DAGs or Tasks on a cluster-wide level. They have three main purposes:

  • Checking that DAGs/Tasks meet a certain standard

  • Setting default arguments on DAGs/Tasks

  • Performing custom routing logic

There are three types of cluster policy:

  • dag_policy: Takes a DAG parameter called dag. Runs at load time.

  • task_policy: Takes a BaseOperator parameter called task. Runs at load time.

  • task_instance_mutation_hook: Takes a TaskInstance parameter called task_instance. Called right before task execution.

The DAG and Task cluster policies can raise the AirflowClusterPolicyViolation exception to indicate that the dag/task they were passed is not compliant and should not be loaded.

Any extra attributes set by a cluster policy take priority over those defined in your DAG file; for example, if you set an sla on your Task in the DAG file, and then your cluster policy also sets an sla, the cluster policy’s value will take precedence.

To configure cluster policies, create an airflow_local_settings.py file either in the config folder under your $AIRFLOW_HOME or somewhere on the $PYTHONPATH, then add callables to the file whose names match one or more of the cluster policy names above (e.g. dag_policy).
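
As a rough sketch of the layout described above, an airflow_local_settings.py might define all three callables side by side. The rule bodies, the MAX_RETRIES value and the retry_pool name are made up for illustration. Type annotations are left out so the sketch does not depend on Airflow imports; a real file would typically annotate the parameters with DAG, BaseOperator and TaskInstance.

```python
# airflow_local_settings.py (illustrative skeleton; the rules are placeholders)

# Hypothetical cluster-wide limit, not an Airflow setting.
MAX_RETRIES = 3


def dag_policy(dag):
    """Runs at load time, once per DAG."""
    # Hypothetical rule: give undocumented DAGs a placeholder description.
    if not dag.description:
        dag.description = "No description provided"


def task_policy(task):
    """Runs at load time, once per task."""
    # Hypothetical rule: cap the number of retries a task may request.
    if task.retries > MAX_RETRIES:
        task.retries = MAX_RETRIES


def task_instance_mutation_hook(task_instance):
    """Called right before a task instance executes."""
    # Hypothetical rule: move later attempts into a separate pool.
    if task_instance.try_number > 1:
        task_instance.pool = "retry_pool"
```

Only the function names matter to Airflow here; any helper functions or constants in the file can be named freely.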

Examples

DAG policies

This policy checks if each DAG has at least one tag defined:

from airflow.exceptions import AirflowClusterPolicyViolation
from airflow.models import DAG


def dag_policy(dag: DAG):
    """Ensure that DAG has at least one tag"""
    if not dag.tags:
        raise AirflowClusterPolicyViolation(
            f"DAG {dag.dag_id} has no tags. At least one tag required. File path: {dag.fileloc}"
        )


Note

To avoid import cycles, if you use DAG in type annotations in your cluster policy, be sure to import from airflow.models and not from airflow.

Note

DAG policies are applied after the DAG has been completely loaded, so overriding the default_args parameter has no effect. If you want to override the default operator settings, use task policies instead.
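
To illustrate the note above, a task policy can fill in a default directly on each task instead of going through default_args. This is only a sketch: the six-hour value and the DEFAULT_EXECUTION_TIMEOUT name are assumptions, and the annotation is omitted so the snippet runs without Airflow installed.

```python
from datetime import timedelta

# Hypothetical cluster-wide default; pick a value that suits your workloads.
DEFAULT_EXECUTION_TIMEOUT = timedelta(hours=6)


def task_policy(task):
    """Give every task an execution timeout if its author did not set one."""
    # execution_timeout is a BaseOperator attribute and is None when unset.
    if task.execution_timeout is None:
        task.execution_timeout = DEFAULT_EXECUTION_TIMEOUT
```

Because this version only fills in a missing value, a timeout set in the DAG file is left alone; assigning unconditionally would make the policy's value win instead.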

Task policies

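Here’s an example of a task policy that routes Hive partition sensors to a dedicated queue and enforces a maximum timeout of 48 hours on every task:

from datetime import timedelta
from airflow.models import BaseOperator

class TimedOperator(BaseOperator):  # illustrative type, not part of Airflow
    timeout: timedelta

def task_policy(task: TimedOperator):
    if task.task_type == 'HivePartitionSensor':
        task.queue = "sensor_queue"  # route these sensors to their own queue
    if task.timeout > timedelta(hours=48):
        task.timeout = timedelta(hours=48)  # cap the timeout at 48 hours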


You could also implement policies to protect against common errors, rather than as technical security controls. For example, don’t run tasks without Airflow owners:

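from airflow.configuration import conf
from airflow.exceptions import AirflowClusterPolicyViolation
from airflow.models import BaseOperator

def task_must_have_owners(task: BaseOperator):
    # Reject tasks whose owner is unset or still the configured default.
    if not task.owner or task.owner.lower() == conf.get('operators', 'default_owner').lower():
        raise AirflowClusterPolicyViolation(
            f"Task must have non-None non-default owner. Current value: {task.owner}"
        )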


If you have multiple checks to apply, it is best practice to keep these rules in a separate Python module and have a single policy / task mutation hook that runs all of them and aggregates their error messages. That way a single AirflowClusterPolicyViolation is reported in the UI (and in the import errors table in the database).

For example, your airflow_local_settings.py might follow this pattern:

from typing import Callable, List

TASK_RULES: List[Callable[[BaseOperator], None]] = [
    task_must_have_owners,
]


def _check_task_rules(current_task: BaseOperator):
    """Check task rules for given task."""
    notices = []
    for rule in TASK_RULES:
        try:
            rule(current_task)
        except AirflowClusterPolicyViolation as ex:
            notices.append(str(ex))
    if notices:
        notices_list = " * " + "\n * ".join(notices)
        raise AirflowClusterPolicyViolation(
            f"DAG policy violation (DAG ID: {current_task.dag_id}, Path: {current_task.dag.fileloc}):\n"
            f"Notices:\n"
            f"{notices_list}"
        )


def task_policy(task: BaseOperator):
    """Run every task rule and report all violations together."""
    _check_task_rules(task)
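
The aggregation pattern is easier to see with more than one rule in play. The sketch below is a simplified, standalone variant: task_must_have_retries is a hypothetical second rule, the default owner is hard-coded as "airflow" instead of being read from the configuration, and SimpleNamespace objects stand in for real operators. If Airflow is not installed, a stand-in exception class is used.

```python
from types import SimpleNamespace
from typing import Callable, List

try:
    from airflow.exceptions import AirflowClusterPolicyViolation
except ImportError:
    # Stand-in so the sketch runs without Airflow installed.
    class AirflowClusterPolicyViolation(Exception):
        pass


def task_must_have_owner(task) -> None:
    # Simplified owner check; "airflow" stands in for the configured default.
    if not task.owner or task.owner.lower() == "airflow":
        raise AirflowClusterPolicyViolation(
            f"Task must have a non-default owner. Current value: {task.owner}"
        )


def task_must_have_retries(task) -> None:
    # Hypothetical second rule: require at least one retry.
    if task.retries < 1:
        raise AirflowClusterPolicyViolation(
            f"Task must allow at least one retry. Current value: {task.retries}"
        )


TASK_RULES: List[Callable] = [task_must_have_owner, task_must_have_retries]


def collect_violations(task) -> List[str]:
    """Run every rule and return the messages of the ones that failed."""
    notices = []
    for rule in TASK_RULES:
        try:
            rule(task)
        except AirflowClusterPolicyViolation as ex:
            notices.append(str(ex))
    return notices


# Stand-ins carrying only the attributes the rules read.
bad_task = SimpleNamespace(task_id="extract", owner="airflow", retries=0)
good_task = SimpleNamespace(task_id="load", owner="data-team", retries=2)

bad_notices = collect_violations(bad_task)    # one message per failed rule
good_notices = collect_violations(good_task)  # empty list, nothing to report
```

Collecting messages first and raising once at the end means a DAG author sees every problem in one pass rather than fixing them one import error at a time.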


Task instance mutation

Here’s an example of re-routing tasks that are on their second (or greater) retry to a different queue:

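from airflow.models import TaskInstance

def task_instance_mutation_hook(task_instance: TaskInstance):
    # In versions where try_number starts at 1, this condition also matches a
    # first attempt; raise the threshold to target later retries only.
    if task_instance.try_number >= 1:
        task_instance.queue = 'retry_queue'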
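
Since the hook is an ordinary function, its routing threshold can be checked in isolation before it is deployed. The following sketch assumes a threshold of 2 and uses SimpleNamespace objects in place of real TaskInstance objects; RETRY_THRESHOLD and RETRY_QUEUE are illustrative names, not Airflow settings.

```python
from types import SimpleNamespace

# Hypothetical settings; adjust the threshold to how your Airflow version
# counts try_number.
RETRY_THRESHOLD = 2
RETRY_QUEUE = "retry_queue"


def task_instance_mutation_hook(task_instance):
    """Send attempts at or beyond the threshold to a dedicated queue."""
    if task_instance.try_number >= RETRY_THRESHOLD:
        task_instance.queue = RETRY_QUEUE


# Stand-ins carrying only the attributes the hook touches.
first_attempt = SimpleNamespace(try_number=1, queue="default")
later_attempt = SimpleNamespace(try_number=3, queue="default")

for ti in (first_attempt, later_attempt):
    task_instance_mutation_hook(ti)
```

After the loop, first_attempt keeps its original queue while later_attempt has been moved to the retry queue.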

