Source code for airflow.providers.amazon.aws.sensors.sagemaker_base
## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.fromtypingimportOptional,Setfromairflow.exceptionsimportAirflowExceptionfromairflow.providers.amazon.aws.hooks.sagemakerimportSageMakerHookfromairflow.sensors.baseimportBaseSensorOperator
class SageMakerBaseSensor(BaseSensorOperator):
    """
    Contains general sensor behavior for SageMaker.

    Subclasses must implement ``get_sagemaker_response()``,
    ``state_from_response()``, ``non_terminal_states()`` and
    ``failed_states()``. They may also override
    ``get_failed_reason_from_response()``, which otherwise reports
    ``'Unknown'`` as the failure reason.
    """

    def poke(self, context):
        """
        Check the status of the SageMaker job once.

        :param context: context passed in by the sensor framework; not used here.
        :return: ``False`` while the job is in a non-terminal state or the
            response could not be evaluated, ``True`` once the job has reached
            a terminal state that is not listed in ``failed_states()``.
        :raises AirflowException: if the job state is one of ``failed_states()``.
        """
        response = self.get_sagemaker_response()
        # get_sagemaker_response() is typed Optional[dict]; without this guard a
        # None return would crash with a TypeError on the subscript below.
        # Treat it like an unusable response and poke again later.
        if response is None:
            self.log.info('No response received from SageMaker')
            return False
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            self.log.info('Bad HTTP response: %s', response)
            return False

        state = self.state_from_response(response)
        self.log.info('Job currently %s', state)

        # Non-terminal states are checked first: keep waiting.
        if state in self.non_terminal_states():
            return False

        if state in self.failed_states():
            failed_reason = self.get_failed_reason_from_response(response)
            raise AirflowException(f'Sagemaker job failed for the following reason: {failed_reason}')

        # Any other terminal state counts as success.
        return True

    def non_terminal_states(self) -> Set[str]:
        """Return the states in which the job is still running, so poking continues. Must be overridden."""
        raise NotImplementedError('Please implement non_terminal_states() in subclass')

    def failed_states(self) -> Set[str]:
        """Return the states that are considered failures. Must be overridden."""
        raise NotImplementedError('Please implement failed_states() in subclass')

    def get_sagemaker_response(self) -> Optional[dict]:
        """Fetch the current status response for the SageMaker task. Must be overridden."""
        raise NotImplementedError('Please implement get_sagemaker_response() in subclass')

    def get_failed_reason_from_response(self, response: dict) -> str:
        """Extract the failure reason from an AWS response; defaults to ``'Unknown'``."""
        return 'Unknown'

    def state_from_response(self, response: dict) -> str:
        """Extract the job state from an AWS response. Must be overridden."""
        raise NotImplementedError('Please implement state_from_response() in subclass')