Source code for airflow.providers.amazon.aws.triggers.sagemaker
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
import time
from collections import Counter
from enum import IntEnum
from functools import cached_property
from typing import Any, AsyncIterator

from botocore.exceptions import WaiterError
from deprecated import deprecated

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
class SageMakerTrigger(BaseTrigger):
    """
    SageMakerTrigger is fired as a deferred trigger with params to run the task in the triggerer.

    :param job_name: name of the job to check status
    :param job_type: type of the SageMaker job, e.g. Training, Transform, Processing, Tuning or Endpoint
    :param poke_interval: polling period in seconds to check for the status
    :param max_attempts: number of times to poll for the job state before returning the current state,
        defaults to 480
    :param aws_conn_id: AWS connection ID for SageMaker
    """

    def __init__(
        self,
        job_name: str,
        job_type: str,
        poke_interval: int = 30,
        max_attempts: int = 480,
        aws_conn_id: str | None = "aws_default",
    ):
        super().__init__()
        self.job_name = job_name
        self.job_type = job_type
        self.poke_interval = poke_interval
        self.max_attempts = max_attempts
        self.aws_conn_id = aws_conn_id
    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize SageMakerTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger",
            {
                "job_name": self.job_name,
                "job_type": self.job_type,
                "poke_interval": self.poke_interval,
                "max_attempts": self.max_attempts,
                "aws_conn_id": self.aws_conn_id,
            },
        )
    @cached_property
    def hook(self) -> SageMakerHook:
        return SageMakerHook(aws_conn_id=self.aws_conn_id)

    @staticmethod
    def _get_job_type_waiter(job_type: str) -> str:
        return {
            "training": "TrainingJobComplete",
            "transform": "TransformJobComplete",
            "processing": "ProcessingJobComplete",
            "tuning": "TuningJobComplete",
            "endpoint": "endpoint_in_service",  # this one is provided by boto
        }[job_type.lower()]

    @staticmethod
    def _get_waiter_arg_name(job_type: str) -> str:
        return {
            "training": "TrainingJobName",
            "transform": "TransformJobName",
            "processing": "ProcessingJobName",
            "tuning": "HyperParameterTuningJobName",
            "endpoint": "EndpointName",
        }[job_type.lower()]

    @staticmethod
    def _get_response_status_key(job_type: str) -> str:
        return {
            "training": "TrainingJobStatus",
            "transform": "TransformJobStatus",
            "processing": "ProcessingJobStatus",
            "tuning": "HyperParameterTuningJobStatus",
            "endpoint": "EndpointStatus",
        }[job_type.lower()]
    async def run(self):
        self.log.info("job name is %s and job type is %s", self.job_name, self.job_type)
        async with self.hook.async_conn as client:
            waiter = self.hook.get_waiter(
                self._get_job_type_waiter(self.job_type), deferrable=True, client=client
            )
            await async_wait(
                waiter=waiter,
                waiter_delay=self.poke_interval,
                waiter_max_attempts=self.max_attempts,
                args={self._get_waiter_arg_name(self.job_type): self.job_name},
                failure_message=f"Error while waiting for {self.job_type} job",
                status_message=f"{self.job_type} job not done yet",
                status_args=[self._get_response_status_key(self.job_type)],
            )
            yield TriggerEvent({"status": "success", "message": "Job completed.", "job_name": self.job_name})
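
# Illustrative usage (not part of this module): a deferrable SageMaker operator would
# typically hand control to this trigger from its ``execute`` method and resume once
# the TriggerEvent above arrives. The job name and callback method name below are
# hypothetical; a minimal sketch:
#
#     self.defer(
#         trigger=SageMakerTrigger(
#             job_name="my-training-job",
#             job_type="training",
#             poke_interval=30,
#             max_attempts=480,
#             aws_conn_id="aws_default",
#         ),
#         method_name="execute_complete",
#     )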
class SageMakerPipelineTrigger(BaseTrigger):
    """Trigger to wait for a sagemaker pipeline execution to finish."""
    class Type(IntEnum):
        """Type of waiter to use."""

        COMPLETE = 1
        STOPPED = 2

    def __init__(
        self,
        waiter_type: Type,
        pipeline_execution_arn: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str | None,
    ):
        super().__init__()
        self.waiter_type = waiter_type
        self.pipeline_execution_arn = pipeline_execution_arn
        self.waiter_delay = waiter_delay
        self.waiter_max_attempts = waiter_max_attempts
        self.aws_conn_id = aws_conn_id

    _waiter_name = {
        Type.COMPLETE: "PipelineExecutionComplete",
        Type.STOPPED: "PipelineExecutionStopped",
    }
    def serialize(self) -> tuple[str, dict[str, Any]]:
        return (
            self.__class__.__module__ + "." + self.__class__.__qualname__,
            {
                "waiter_type": self.waiter_type.value,  # saving the int value here
                "pipeline_execution_arn": self.pipeline_execution_arn,
                "waiter_delay": self.waiter_delay,
                "waiter_max_attempts": self.waiter_max_attempts,
                "aws_conn_id": self.aws_conn_id,
            },
        )
    async def run(self) -> AsyncIterator[TriggerEvent]:
        hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        async with hook.async_conn as conn:
            waiter = hook.get_waiter(self._waiter_name[self.waiter_type], deferrable=True, client=conn)
            for _ in range(self.waiter_max_attempts):
                try:
                    await waiter.wait(
                        PipelineExecutionArn=self.pipeline_execution_arn, WaiterConfig={"MaxAttempts": 1}
                    )
                    # we reach this point only if the waiter met a success criteria
                    yield TriggerEvent({"status": "success", "value": self.pipeline_execution_arn})
                    return
                except WaiterError as error:
                    if "terminal failure" in str(error):
                        raise

                    self.log.info(
                        "Status of the pipeline execution: %s",
                        error.last_response["PipelineExecutionStatus"],
                    )

                    res = await conn.list_pipeline_execution_steps(
                        PipelineExecutionArn=self.pipeline_execution_arn
                    )
                    count_by_state = Counter(s["StepStatus"] for s in res["PipelineExecutionSteps"])
                    running_steps = [
                        s["StepName"] for s in res["PipelineExecutionSteps"] if s["StepStatus"] == "Executing"
                    ]
                    self.log.info("State of the pipeline steps: %s", count_by_state)
                    self.log.info("Steps currently in progress: %s", running_steps)

                    await asyncio.sleep(int(self.waiter_delay))

            raise AirflowException("Waiter error: max attempts reached")
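
# Illustrative usage (not part of this module): an operator that starts a pipeline
# execution could defer to this trigger until the execution completes. The ARN and
# callback method name below are hypothetical; a minimal sketch:
#
#     self.defer(
#         trigger=SageMakerPipelineTrigger(
#             waiter_type=SageMakerPipelineTrigger.Type.COMPLETE,
#             pipeline_execution_arn="arn:aws:sagemaker:...",
#             waiter_delay=30,
#             waiter_max_attempts=60,
#             aws_conn_id="aws_default",
#         ),
#         method_name="execute_complete",
#     )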
@deprecated(
    reason=(
        "`airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger` "
        "has been deprecated and will be removed in future. Please use ``SageMakerTrigger`` instead."
    ),
    category=AirflowProviderDeprecationWarning,
)
class SageMakerTrainingPrintLogTrigger(BaseTrigger):
    """
    SageMakerTrainingPrintLogTrigger is fired as a deferred trigger with params to run the task in the triggerer.

    :param job_name: name of the job to check status
    :param poke_interval: polling period in seconds to check for the status
    :param aws_conn_id: AWS connection ID for SageMaker
    """

    def __init__(
        self,
        job_name: str,
        poke_interval: float,
        aws_conn_id: str | None = "aws_default",
    ):
        super().__init__()
        self.job_name = job_name
        self.poke_interval = poke_interval
        self.aws_conn_id = aws_conn_id
    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize SageMakerTrainingPrintLogTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger",
            {
                "poke_interval": self.poke_interval,
                "aws_conn_id": self.aws_conn_id,
                "job_name": self.job_name,
            },
        )
    @cached_property
    def hook(self) -> SageMakerHook:
        return SageMakerHook(aws_conn_id=self.aws_conn_id)

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Make async connection to the SageMaker async hook and get the job status for a job submitted by the operator."""
        stream_names: list[str] = []  # The list of log streams
        positions: dict[str, Any] = {}  # The current position in each stream, map of stream name -> position
        last_description = await self.hook.describe_training_job_async(self.job_name)
        instance_count = last_description["ResourceConfig"]["InstanceCount"]
        status = last_description["TrainingJobStatus"]
        job_already_completed = status not in self.hook.non_terminal_states
        state = LogState.COMPLETE if job_already_completed else LogState.TAILING
        last_describe_job_call = time.time()
        try:
            while True:
                (
                    state,
                    last_description,
                    last_describe_job_call,
                ) = await self.hook.describe_training_job_with_log_async(
                    self.job_name,
                    positions,
                    stream_names,
                    instance_count,
                    state,
                    last_description,
                    last_describe_job_call,
                )
                status = last_description["TrainingJobStatus"]
                if status in self.hook.non_terminal_states:
                    await asyncio.sleep(self.poke_interval)
                elif status in self.hook.failed_states:
                    reason = last_description.get("FailureReason", "(No reason provided)")
                    error_message = f"SageMaker job failed because {reason}"
                    yield TriggerEvent({"status": "error", "message": error_message})
                    return
                else:
                    billable_seconds = SageMakerHook.count_billable_seconds(
                        training_start_time=last_description["TrainingStartTime"],
                        training_end_time=last_description["TrainingEndTime"],
                        instance_count=instance_count,
                    )
                    self.log.info("Billable seconds: %d", billable_seconds)
                    yield TriggerEvent({"status": "success", "message": last_description})
                    return
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})
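
# Illustrative usage (not part of this module): the operator method that resumes after
# deferral would typically inspect the TriggerEvent payload yielded by these triggers.
# The method and argument names below are hypothetical; a minimal sketch:
#
#     def execute_complete(self, context, event=None):
#         if event is None or event.get("status") != "success":
#             raise AirflowException(f"Trigger reported failure: {event}")
#         self.log.info("Job finished: %s", event["message"])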