Source code for airflow.providers.amazon.aws.sensors.sagemaker
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimporttimefromtypingimportTYPE_CHECKING,Sequencefromairflow.exceptionsimportAirflowExceptionfromairflow.providers.amazon.aws.hooks.sagemakerimportLogState,SageMakerHookfromairflow.sensors.baseimportBaseSensorOperatorifTYPE_CHECKING:fromairflow.utils.contextimportContext
class SageMakerBaseSensor(BaseSensorOperator):
    """
    Contains general sensor behavior for SageMaker.

    Subclasses should implement the get_sagemaker_response() and state_from_response() methods.
    Subclasses should also implement the non_terminal_states() and failed_states() methods,
    and may override get_failed_reason_from_response() to report why a resource failed.
    """

    def poke(self, context: Context) -> bool:
        """
        Check the SageMaker resource once and report whether the sensor is done.

        :param context: Airflow task context (not read by this implementation).
        :return: ``False`` when the describe call did not answer with HTTP 200 or the
            state is one of ``non_terminal_states()``; ``True`` when the state is
            terminal and not one of ``failed_states()``.
        :raises AirflowException: if the state is one of ``failed_states()``; the
            message includes ``get_failed_reason_from_response()``.
        """
        response = self.get_sagemaker_response()
        # A non-200 answer is not treated as a failure: log it and report "not done yet".
        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            self.log.info("Bad HTTP response: %s", response)
            return False

        state = self.state_from_response(response)
        self.log.info("Job currently %s", state)

        if state in self.non_terminal_states():
            return False

        if state in self.failed_states():
            failed_reason = self.get_failed_reason_from_response(response)
            raise AirflowException(f"Sagemaker job failed for the following reason: {failed_reason}")

        # Terminal and not failed: the sensor has succeeded.
        return True

    def non_terminal_states(self) -> set[str]:
        """Return the states in which the sensor should keep waiting; subclasses must override this."""
        raise NotImplementedError("Please implement non_terminal_states() in subclass")

    def failed_states(self) -> set[str]:
        """Return the states which are considered failed; subclasses must override this."""
        raise NotImplementedError("Please implement failed_states() in subclass")

    def get_sagemaker_response(self) -> dict:
        """Fetch the current status of the SageMaker resource; subclasses must override this."""
        raise NotImplementedError("Please implement get_sagemaker_response() in subclass")

    def get_failed_reason_from_response(self, response: dict) -> str:
        """
        Extract the reason for failure from an AWS response.

        The default implementation ignores ``response`` and returns ``"Unknown"``;
        subclasses may override it to surface a more specific reason.
        """
        return "Unknown"

    def state_from_response(self, response: dict) -> str:
        """Extract the state from an AWS response; subclasses must override this."""
        raise NotImplementedError("Please implement state_from_response() in subclass")
class SageMakerEndpointSensor(SageMakerBaseSensor):
    """
    Wait until a SageMaker endpoint reaches a terminal state.

    The endpoint state is checked repeatedly; if the terminal state it ends up in
    is a failed one, an AirflowException carrying the failure reason is raised.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerEndpointSensor`

    :param endpoint_name: Name of the endpoint instance to watch.
    """
class SageMakerTransformSensor(SageMakerBaseSensor):
    """
    Wait until a SageMaker transform job reaches a terminal state.

    The transform job is checked repeatedly; reaching a failed terminal state
    raises an AirflowException that includes the failure reason.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerTransformSensor`

    :param job_name: Name of the transform job to watch.
    """
class SageMakerTuningSensor(SageMakerBaseSensor):
    """
    Poll the state of a SageMaker tuning job until it reaches a terminal state.

    Raises an AirflowException with the failure reason if a failed state is reached.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerTuningSensor`

    :param job_name: Name of the tuning job to watch.
    """
class SageMakerTrainingSensor(SageMakerBaseSensor):
    """
    Wait until a SageMaker training job reaches a terminal state.

    The training job is checked repeatedly; reaching a failed terminal state
    raises an AirflowException that includes the failure reason.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerTrainingSensor`

    :param job_name: Name of the training job to watch.
    :param print_log: If True (the default), print the job's CloudWatch log.
    """

    def init_log_resource(self, hook: SageMakerHook) -> None:
        """
        Prepare the log-tailing bookkeeping for the watched training job.

        Describes the job once through ``hook`` and caches the instance count, the
        description itself and the ``time.monotonic()`` timestamp of the call. The
        initial ``LogState`` is COMPLETE when the job's status is not one of
        ``non_terminal_states()``, and TAILING otherwise.

        :param hook: SageMaker hook used to describe the training job.
        """
        job_info = hook.describe_training_job(self.job_name)
        self.instance_count = job_info["ResourceConfig"]["InstanceCount"]

        if job_info["TrainingJobStatus"] in self.non_terminal_states():
            # Job is still in progress, so there may be more log output to follow.
            self.state = LogState.TAILING
        else:
            # Job already reached a terminal state (succeeded or not).
            self.state = LogState.COMPLETE

        self.last_description = job_info
        self.last_describe_job_call = time.monotonic()
        self.log_resource_inited = True