Source code for airflow.providers.amazon.aws.sensors.sagemaker
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.importtimefromtypingimportTYPE_CHECKING,Optional,Sequence,Setfromairflow.exceptionsimportAirflowExceptionfromairflow.providers.amazon.aws.hooks.sagemakerimportLogState,SageMakerHookfromairflow.sensors.baseimportBaseSensorOperatorifTYPE_CHECKING:fromairflow.utils.contextimportContext
class SageMakerBaseSensor(BaseSensorOperator):
    """
    Contains general sensor behavior for SageMaker.

    Subclasses must implement get_sagemaker_response(), state_from_response(),
    non_terminal_states() and failed_states(). They may also override
    get_failed_reason_from_response() to report why a job failed; the
    default implementation returns 'Unknown'.
    """

    def poke(self, context: 'Context') -> bool:
        """
        Check the watched SageMaker resource once.

        :param context: Airflow task context (not used by this implementation).
        :return: False if the HTTP status is not 200 or the state is one of
            non_terminal_states(); True for any other state that is not in
            failed_states().
        :raises AirflowException: if the state is one of failed_states().
        """
        response = self.get_sagemaker_response()
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            # A non-200 response is not treated as failure: log it and report "not done yet".
            self.log.info('Bad HTTP response: %s', response)
            return False

        state = self.state_from_response(response)
        self.log.info('Job currently %s', state)

        if state in self.non_terminal_states():
            return False

        if state in self.failed_states():
            failed_reason = self.get_failed_reason_from_response(response)
            raise AirflowException(f'Sagemaker job failed for the following reason: {failed_reason}')

        # Terminal and not failed: the sensor is satisfied.
        return True

    def non_terminal_states(self) -> Set[str]:
        """Return the states which should not terminate the sensor (must be overridden)."""
        raise NotImplementedError('Please implement non_terminal_states() in subclass')

    def failed_states(self) -> Set[str]:
        """Return the states which are considered failed (must be overridden)."""
        raise NotImplementedError('Please implement failed_states() in subclass')

    def get_sagemaker_response(self) -> dict:
        """
        Fetch the current status of the SageMaker task (must be overridden).

        The returned dict must contain ``ResponseMetadata.HTTPStatusCode``,
        which poke() reads.
        """
        raise NotImplementedError('Please implement get_sagemaker_response() in subclass')

    def get_failed_reason_from_response(self, response: dict) -> str:
        """
        Extract the reason for failure from an AWS response.

        The default implementation ignores ``response`` and returns 'Unknown'.
        """
        return 'Unknown'

    def state_from_response(self, response: dict) -> str:
        """Extract the job state from an AWS response (must be overridden)."""
        raise NotImplementedError('Please implement state_from_response() in subclass')
class SageMakerEndpointSensor(SageMakerBaseSensor):
    """
    Wait for a SageMaker endpoint to leave its in-progress states.

    The sensor keeps polling the endpoint until its state is terminal. If
    that terminal state is a failed one, an AirflowException carrying the
    failure reason is raised.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerEndpointSensor`

    :param endpoint_name: Name of the endpoint instance to watch.
    """
class SageMakerTransformSensor(SageMakerBaseSensor):
    """
    Wait for a SageMaker transform job to reach a terminal state.

    The sensor polls the transform job repeatedly. If the job ends in a
    failed state, an AirflowException carrying the failure reason is raised.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerTransformSensor`

    :param job_name: Name of the transform job to watch.
    """
class SageMakerTuningSensor(SageMakerBaseSensor):
    """
    Polls the tuning job until it reaches a terminal state.

    Raises an AirflowException with the failure reason if a failed state is reached.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerTuningSensor`

    :param job_name: Name of the tuning job to watch.
    """
[docs]classSageMakerTrainingSensor(SageMakerBaseSensor):""" Polls the training job until it reaches a terminal state. Raises an AirflowException with the failure reason if a failed state is reached. .. seealso:: For more information on how to use this sensor, take a look at the guide: :ref:`howto/sensor:SageMakerTrainingSensor` :param job_name: Name of the training job to watch. :param print_log: Prints the cloudwatch log if True; Defaults to True. """
def init_log_resource(self, hook: SageMakerHook) -> None:
    """
    Prime the log-tailing state for this sensor's training job.

    Describes the job once through ``hook.describe_training_job`` and caches
    the instance count, the description and the call time. The log state is
    set as follows:

    * ``LogState.TAILING`` while the job status is one of non_terminal_states();
    * ``LogState.COMPLETE`` otherwise.

    :param hook: SageMakerHook used to describe the training job.
    """
    job_info = hook.describe_training_job(self.job_name)
    self.instance_count = job_info['ResourceConfig']['InstanceCount']
    still_running = job_info['TrainingJobStatus'] in self.non_terminal_states()
    self.state = LogState.TAILING if still_running else LogState.COMPLETE
    self.last_description = job_info
    # Record when the job was last described (monotonic clock).
    self.last_describe_job_call = time.monotonic()
    self.log_resource_inited = True