Source code for airflow.providers.amazon.aws.sensors.sagemaker
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimporttimefromfunctoolsimportcached_propertyfromtypingimportTYPE_CHECKING,Sequencefromdeprecatedimportdeprecatedfromairflow.exceptionsimportAirflowExceptionfromairflow.providers.amazon.aws.hooks.sagemakerimportLogState,SageMakerHookfromairflow.sensors.baseimportBaseSensorOperatorifTYPE_CHECKING:fromairflow.utils.contextimportContext
class SageMakerBaseSensor(BaseSensorOperator):
    """
    Contains general sensor behavior for SageMaker.

    Subclasses must implement ``get_sagemaker_response()``, ``state_from_response()``,
    ``non_terminal_states()`` and ``failed_states()``. They may also override
    ``get_failed_reason_from_response()``, which otherwise reports ``"Unknown"``.
    """

    def __init__(self, *, aws_conn_id: str = "aws_default", resource_type: str = "job", **kwargs):
        super().__init__(**kwargs)
        self.aws_conn_id = aws_conn_id
        # Only used in log and error messages, to say what kind of resource is being sensed.
        self.resource_type = resource_type

    @deprecated(reason="use `hook` property instead.")
    def get_hook(self) -> SageMakerHook:
        """Return the SageMaker hook; deprecated, use the ``hook`` property instead."""
        return self.hook

    @cached_property
    def hook(self) -> SageMakerHook:
        """SageMaker hook for ``aws_conn_id``, built on first access and cached on the instance."""
        return SageMakerHook(aws_conn_id=self.aws_conn_id)

    def poke(self, context: Context):
        """
        Check the SageMaker resource once.

        :param context: Airflow task context (not used here).
        :return: ``False`` if the API call did not return HTTP 200 or the resource is still in
            one of ``non_terminal_states()``; ``True`` once it reaches any other state that is
            not in ``failed_states()``.
        :raises AirflowException: if the resource reached one of ``failed_states()``.
        """
        response = self.get_sagemaker_response()
        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            # A non-200 response is not fatal: log it and simply poke again later.
            self.log.info("Bad HTTP response: %s", response)
            return False

        state = self.state_from_response(response)
        self.log.info("%s currently %s", self.resource_type, state)

        if state in self.non_terminal_states():
            return False

        if state in self.failed_states():
            failed_reason = self.get_failed_reason_from_response(response)
            raise AirflowException(
                f"Sagemaker {self.resource_type} failed for the following reason: {failed_reason}"
            )

        return True

    def non_terminal_states(self) -> set[str]:
        """Return the states in which the resource is still in progress, so poking continues."""
        raise NotImplementedError("Please implement non_terminal_states() in subclass")

    def failed_states(self) -> set[str]:
        """Return the terminal states that are considered failures."""
        raise NotImplementedError("Please implement failed_states() in subclass")

    def get_sagemaker_response(self) -> dict:
        """Check status of a SageMaker task and return the raw AWS response."""
        raise NotImplementedError("Please implement get_sagemaker_response() in subclass")

    def get_failed_reason_from_response(self, response: dict) -> str:
        """Extract the reason for failure from an AWS response; ``"Unknown"`` unless overridden."""
        return "Unknown"

    def state_from_response(self, response: dict) -> str:
        """Extract the state from an AWS response."""
        raise NotImplementedError("Please implement state_from_response() in subclass")
class SageMakerEndpointSensor(SageMakerBaseSensor):
    """
    Wait for a SageMaker endpoint to reach a terminal state.

    The endpoint status is polled repeatedly; if it ends up failed, an AirflowException
    carrying the failure reason is raised.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerEndpointSensor`

    :param endpoint_name: Name of the endpoint instance to watch.
    """
class SageMakerTransformSensor(SageMakerBaseSensor):
    """
    Wait for a SageMaker transform job to reach a terminal state.

    The job status is polled repeatedly; if the job fails, an AirflowException carrying the
    failure reason is raised.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerTransformSensor`

    :param job_name: Name of the transform job to watch.
    """
class SageMakerTuningSensor(SageMakerBaseSensor):
    """
    Wait for a SageMaker tuning job to reach a terminal state.

    The tuning status is polled repeatedly; if the job fails, an AirflowException carrying
    the failure reason is raised.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerTuningSensor`

    :param job_name: Name of the tuning job to watch.
    """
class SageMakerTrainingSensor(SageMakerBaseSensor):
    """
    Wait for a SageMaker training job to reach a terminal state.

    The job status is polled repeatedly; if the job fails, an AirflowException carrying the
    failure reason is raised.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerTrainingSensor`

    :param job_name: Name of the training job to watch.
    :param print_log: Whether to print the CloudWatch log; defaults to True.
    """

    def init_log_resource(self, hook: SageMakerHook) -> None:
        """
        Prime the log-tailing bookkeeping for ``self.job_name``.

        Describes the training job once and records its instance count, the description,
        the time of the describe call and an initial LogState. The LogState is TAILING while
        the job is still in a non-terminal state, and COMPLETE otherwise.
        """
        training_job = hook.describe_training_job(self.job_name)
        self.instance_count = training_job["ResourceConfig"]["InstanceCount"]

        still_running = training_job["TrainingJobStatus"] in self.non_terminal_states()
        if still_running:
            self.state = LogState.TAILING
        else:
            self.state = LogState.COMPLETE

        self.last_description = training_job
        # Monotonic clock: only used to measure elapsed time since this describe call.
        self.last_describe_job_call = time.monotonic()
        self.log_resource_inited = True
class SageMakerPipelineSensor(SageMakerBaseSensor):
    """
    Wait for a SageMaker pipeline execution to reach a terminal state.

    The pipeline status is polled repeatedly; if the execution fails, an AirflowException
    carrying the failure reason is raised.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerPipelineSensor`

    :param pipeline_exec_arn: ARN of the pipeline to watch.
    :param verbose: Whether to print step details while waiting for completion. Defaults to
        True; consider turning it off for pipelines that have thousands of steps.
    """
class SageMakerAutoMLSensor(SageMakerBaseSensor):
    """
    Wait for a SageMaker AutoML job to reach a terminal state.

    The job status is polled repeatedly; if the job fails, an AirflowException carrying the
    failure reason is raised.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerAutoMLSensor`

    :param job_name: Unique name of the AutoML job to watch.
    """