# Source code for airflow.contrib.operators.ecs_operator

# -*- coding: utf-8 -*-
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.utils import apply_defaults

from airflow.contrib.hooks.aws_hook import AwsHook

class ECSOperator(BaseOperator):
    """
    Execute a task on AWS EC2 Container Service

    :param task_definition: the task definition name on EC2 Container Service
    :type task_definition: str
    :param cluster: the cluster name on EC2 Container Service
    :type cluster: str
    :param overrides: the same parameter that boto3 will receive, see
        http://boto3.readthedocs.org/en/latest/reference/services/ecs.html#ECS.Client.run_task
    :type overrides: dict
    :param aws_conn_id: connection id of AWS credentials / region name. If None,
        credential boto3 strategy will be used
        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
    :type aws_conn_id: str
    :param region_name: region name to use in AWS Hook.
        Override the region_name in connection (if provided)
    :type region_name: str
    """

    ui_color = '#f0ede4'
    # Both are set by execute(); the None defaults let on_kill detect that no
    # task was ever started.
    client = None
    arn = None
    template_fields = ('overrides',)

    @apply_defaults
    def __init__(self, task_definition, cluster, overrides,
                 aws_conn_id=None, region_name=None, **kwargs):
        super(ECSOperator, self).__init__(**kwargs)

        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.task_definition = task_definition
        self.cluster = cluster
        self.overrides = overrides

        self.hook = self.get_hook()

    def execute(self, context):
        """
        Start the ECS task, block until it stops, then verify its containers.

        :raises AirflowException: if ``run_task`` reports failures, or if
            ``_check_success_task`` finds an unsuccessful container.
        """
        self.log.info(
            'Running ECS Task - Task definition: %s - on cluster %s',
            self.task_definition, self.cluster
        )
        self.log.info('ECSOperator overrides: %s', self.overrides)

        self.client = self.hook.get_client_type(
            'ecs',
            region_name=self.region_name
        )

        response = self.client.run_task(
            cluster=self.cluster,
            taskDefinition=self.task_definition,
            overrides=self.overrides,
            startedBy=self.owner
        )

        if response['failures']:
            raise AirflowException(response)
        self.log.info('ECS Task started: %s', response)

        # Only the first started task is tracked; later calls use this ARN.
        self.arn = response['tasks'][0]['taskArn']
        self._wait_for_task_ended()

        self._check_success_task()
        self.log.info('ECS Task has been successfully executed: %s', response)

    def _wait_for_task_ended(self):
        """Block on boto3's 'tasks_stopped' waiter for the started task."""
        waiter = self.client.get_waiter('tasks_stopped')
        waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
        waiter.wait(
            cluster=self.cluster,
            tasks=[self.arn]
        )

    def _check_success_task(self):
        """
        Describe the stopped task and raise if any container did not succeed.

        :raises AirflowException: on describe failures, a non-zero (or missing)
            exit code, a still-pending container, or a launch error reason.
        """
        response = self.client.describe_tasks(
            cluster=self.cluster,
            tasks=[self.arn]
        )
        self.log.info('ECS Task stopped, check status: %s', response)

        if response.get('failures'):
            raise AirflowException(response)

        for task in response['tasks']:
            for container in task['containers']:
                status = container.get('lastStatus')
                # 'exitCode' is optional in the describe_tasks response; a
                # STOPPED container without one is treated as a failure
                # instead of crashing with KeyError.
                if status == 'STOPPED' and container.get('exitCode', 1) != 0:
                    raise AirflowException(
                        'This task is not in success state {}'.format(task))
                elif status == 'PENDING':
                    raise AirflowException(
                        'This task is still pending {}'.format(task))
                elif 'error' in container.get('reason', '').lower():
                    raise AirflowException(
                        'This containers encounter an error during launching : {}'.
                        format(container.get('reason', '').lower()))

    def get_hook(self):
        """Return an AwsHook for ``aws_conn_id`` (region is passed per client)."""
        return AwsHook(
            aws_conn_id=self.aws_conn_id
        )

    def on_kill(self):
        """Ask ECS to stop the task started by execute(), if there is one."""
        if self.client is None or self.arn is None:
            # execute() never got as far as starting a task; nothing to stop.
            self.log.info('No ECS task was started; nothing to stop')
            return
        response = self.client.stop_task(
            cluster=self.cluster,
            task=self.arn,
            reason='Task killed by the user')
        self.log.info('ECS Task stop requested: %s', response)