Source code for airflow.providers.amazon.aws.triggers.sagemaker
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportasynciofromcollectionsimportCounterfromcollections.abcimportAsyncIteratorfromenumimportIntEnumfromfunctoolsimportcached_propertyfromtypingimportAnyfrombotocore.exceptionsimportWaiterErrorfromairflow.exceptionsimportAirflowExceptionfromairflow.providers.amazon.aws.hooks.sagemakerimportSageMakerHookfromairflow.providers.amazon.aws.utils.waiter_with_loggingimportasync_waitfromairflow.triggers.baseimportBaseTrigger,TriggerEvent
class SageMakerTrigger(BaseTrigger):
    """
    SageMakerTrigger is fired as deferred class with params to run the task in triggerer.

    :param job_name: name of the job to check status
    :param job_type: type of the SageMaker job, matched case-insensitively; one of
        ``training``, ``transform``, ``processing``, ``tuning`` or ``endpoint``
    :param poke_interval: polling period in seconds to check for the status
    :param max_attempts: maximum number of polling attempts handed to the waiter, defaults to 480
    :param aws_conn_id: AWS connection ID for sagemaker
    """

    def __init__(
        self,
        job_name: str,
        job_type: str,
        poke_interval: int = 30,
        max_attempts: int = 480,
        aws_conn_id: str | None = "aws_default",
    ):
        super().__init__()
        # Keep every argument on the instance: serialize() and run() read them back.
        self.job_name = job_name
        self.job_type = job_type
        self.poke_interval = poke_interval
        self.max_attempts = max_attempts
        self.aws_conn_id = aws_conn_id

    @cached_property
    def hook(self) -> SageMakerHook:
        """Return the SageMaker hook for ``aws_conn_id``, created on first access."""
        return SageMakerHook(aws_conn_id=self.aws_conn_id)

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize SagemakerTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger",
            {
                "job_name": self.job_name,
                "job_type": self.job_type,
                "poke_interval": self.poke_interval,
                "max_attempts": self.max_attempts,
                "aws_conn_id": self.aws_conn_id,
            },
        )

    @staticmethod
    def _get_job_type_waiter(job_type: str) -> str:
        """
        Return the waiter name for ``job_type`` (case-insensitive).

        :raises KeyError: if ``job_type`` is not one of the supported job types
        """
        return {
            "training": "TrainingJobComplete",
            "transform": "TransformJobComplete",
            "processing": "ProcessingJobComplete",
            "tuning": "TuningJobComplete",
            "endpoint": "endpoint_in_service",  # this one is provided by boto
        }[job_type.lower()]

    @staticmethod
    def _get_waiter_arg_name(job_type: str) -> str:
        """
        Return the name of the waiter argument that carries the job name for ``job_type``.

        :raises KeyError: if ``job_type`` is not one of the supported job types
        """
        return {
            "training": "TrainingJobName",
            "transform": "TransformJobName",
            "processing": "ProcessingJobName",
            "tuning": "HyperParameterTuningJobName",
            "endpoint": "EndpointName",
        }[job_type.lower()]

    @staticmethod
    def _get_response_status_key(job_type: str) -> str:
        """
        Return the response key that holds the status for ``job_type``.

        :raises KeyError: if ``job_type`` is not one of the supported job types
        """
        return {
            "training": "TrainingJobStatus",
            "transform": "TransformJobStatus",
            "processing": "ProcessingJobStatus",
            "tuning": "HyperParameterTuningJobStatus",
            "endpoint": "EndpointStatus",
        }[job_type.lower()]

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Wait on the waiter matching ``job_type``, then yield a single success event."""
        self.log.info("job name is %s and job type is %s", self.job_name, self.job_type)
        async with self.hook.async_conn as client:
            waiter = self.hook.get_waiter(
                self._get_job_type_waiter(self.job_type), deferrable=True, client=client
            )
            await async_wait(
                waiter=waiter,
                waiter_delay=self.poke_interval,
                waiter_max_attempts=self.max_attempts,
                # The waiter expects the job name under a job-type-specific argument name.
                args={self._get_waiter_arg_name(self.job_type): self.job_name},
                failure_message=f"Error while waiting for {self.job_type} job",
                status_message=f"{self.job_type} job not done yet",
                status_args=[self._get_response_status_key(self.job_type)],
            )
            yield TriggerEvent({"status": "success", "message": "Job completed.", "job_name": self.job_name})
class SageMakerPipelineTrigger(BaseTrigger):
    """Trigger to wait for a sagemaker pipeline execution to finish."""

    # NOTE(review): this extract shows neither an ``__init__`` nor the ``_waiter_name`` mapping,
    # although the methods below read ``self.waiter_type``, ``self.pipeline_execution_arn``,
    # ``self.waiter_delay``, ``self.waiter_max_attempts``, ``self.aws_conn_id`` and
    # ``self._waiter_name`` — confirm they are defined in the full source.

    class Type(IntEnum):
        """Type of waiter to use."""

        # NOTE(review): no members are visible in this extract — confirm against the full source.

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the trigger's classpath and arguments."""
        trigger_cls = self.__class__
        classpath = f"{trigger_cls.__module__}.{trigger_cls.__qualname__}"
        arguments = {
            "waiter_type": self.waiter_type.value,  # saving the int value here
            "pipeline_execution_arn": self.pipeline_execution_arn,
            "waiter_delay": self.waiter_delay,
            "waiter_max_attempts": self.waiter_max_attempts,
            "aws_conn_id": self.aws_conn_id,
        }
        return classpath, arguments

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll the pipeline execution one waiter attempt at a time.

        Yields a single success event carrying the execution ARN once the waiter succeeds.
        A ``WaiterError`` whose message contains "terminal failure" is re-raised. Any other
        ``WaiterError`` is logged along with a summary of the pipeline steps, followed by a
        pause of ``waiter_delay`` seconds.

        :raises AirflowException: when ``waiter_max_attempts`` polls have been used up
        """
        sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        async with sagemaker_hook.async_conn as client:
            waiter = sagemaker_hook.get_waiter(
                self._waiter_name[self.waiter_type], deferrable=True, client=client
            )
            for _ in range(self.waiter_max_attempts):
                try:
                    # One attempt per call, so progress can be logged between polls.
                    await waiter.wait(
                        PipelineExecutionArn=self.pipeline_execution_arn,
                        WaiterConfig={"MaxAttempts": 1},
                    )
                except WaiterError as waiter_error:
                    if "terminal failure" in str(waiter_error):
                        raise

                    self.log.info(
                        "Status of the pipeline execution: %s",
                        waiter_error.last_response["PipelineExecutionStatus"],
                    )

                    response = await client.list_pipeline_execution_steps(
                        PipelineExecutionArn=self.pipeline_execution_arn
                    )
                    steps = response["PipelineExecutionSteps"]
                    status_counts = Counter(step["StepStatus"] for step in steps)
                    in_progress = [step["StepName"] for step in steps if step["StepStatus"] == "Executing"]
                    self.log.info("State of the pipeline steps: %s", status_counts)
                    self.log.info("Steps currently in progress: %s", in_progress)

                    await asyncio.sleep(int(self.waiter_delay))
                else:
                    # we reach this point only if the waiter met a success criteria
                    yield TriggerEvent({"status": "success", "value": self.pipeline_execution_arn})
                    return

            raise AirflowException("Waiter error: max attempts reached")