# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow.models import TaskInstance
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.db import provide_session
from airflow.utils.decorators import apply_defaults
from airflow.utils.state import State
class ExternalTaskSensor(BaseSensorOperator):
    """
    Waits for a task to complete in a different DAG.

    On every poke the sensor counts the task instances of
    ``external_dag_id.external_task_id`` whose state is one of
    ``allowed_states`` and whose execution date is one of the dates being
    looked up. The sensor succeeds once that count equals the number of
    dates being looked up.

    :param external_dag_id: The dag_id that contains the task you want to
        wait for
    :type external_dag_id: string
    :param external_task_id: The task_id that contains the task you want to
        wait for
    :type external_task_id: string
    :param allowed_states: list of allowed states, default is ``['success']``
    :type allowed_states: list
    :param execution_delta: time difference with the previous execution to
        look at, the default is the same execution_date as the current task.
        For yesterday, use [positive!] datetime.timedelta(days=1). Either
        execution_delta or execution_date_fn can be passed to
        ExternalTaskSensor, but not both.
    :type execution_delta: datetime.timedelta
    :param execution_date_fn: function that receives the current execution
        date and returns the desired execution date(s) to query; it may
        return a single date or a list of dates. Either execution_delta
        or execution_date_fn can be passed to ExternalTaskSensor, but not
        both.
    :type execution_date_fn: callable
    :raises ValueError: if both ``execution_delta`` and
        ``execution_date_fn`` are provided.
    """
    template_fields = ['external_dag_id', 'external_task_id']
    ui_color = '#19647e'

    @apply_defaults
    def __init__(self,
                 external_dag_id,
                 external_task_id,
                 allowed_states=None,
                 execution_delta=None,
                 execution_date_fn=None,
                 *args,
                 **kwargs):
        super(ExternalTaskSensor, self).__init__(*args, **kwargs)
        self.allowed_states = allowed_states or [State.SUCCESS]
        # The two ways of selecting the external execution date are
        # mutually exclusive; reject ambiguous configurations early.
        if execution_delta is not None and execution_date_fn is not None:
            raise ValueError(
                'Only one of `execution_delta` or `execution_date_fn` may '
                'be provided to ExternalTaskSensor; not both.')

        self.execution_delta = execution_delta
        self.execution_date_fn = execution_date_fn
        self.external_dag_id = external_dag_id
        self.external_task_id = external_task_id

    @provide_session
    def poke(self, context, session=None):
        """
        Check whether the external task reached an allowed state.

        :param context: task context; only ``context['execution_date']``
            is read here.
        :param session: database session used for the lookup (supplied via
            the ``provide_session`` decorator when not passed explicitly).
        :return: True when the number of matching task instances equals the
            number of execution dates being looked up, False otherwise.
        """
        if self.execution_delta:
            dttm = context['execution_date'] - self.execution_delta
        elif self.execution_date_fn:
            dttm = self.execution_date_fn(context['execution_date'])
        else:
            dttm = context['execution_date']

        # execution_date_fn may return either one date or a list of dates;
        # normalise to a list so the query and the final comparison are
        # uniform.
        dttm_filter = dttm if isinstance(dttm, list) else [dttm]
        serialized_dttm_filter = ','.join(
            [dt.isoformat() for dt in dttm_filter])

        self.log.info(
            'Poking for %s.%s on %s ... ',
            self.external_dag_id,
            self.external_task_id,
            serialized_dttm_filter)

        TI = TaskInstance
        count = session.query(TI).filter(
            TI.dag_id == self.external_dag_id,
            TI.task_id == self.external_task_id,
            TI.state.in_(self.allowed_states),
            TI.execution_date.in_(dttm_filter),
        ).count()
        session.commit()

        # Every requested execution date must have a task instance in an
        # allowed state for the sensor to succeed.
        return count == len(dttm_filter)