Source code for airflow.providers.amazon.aws.triggers.sagemaker
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsfromfunctoolsimportcached_propertyfromtypingimportAnyfromairflow.providers.amazon.aws.hooks.sagemakerimportSageMakerHookfromairflow.providers.amazon.aws.utils.waiter_with_loggingimportasync_waitfromairflow.triggers.baseimportBaseTrigger,TriggerEvent
class SageMakerTrigger(BaseTrigger):
    """
    Trigger that waits, in the triggerer, for a SageMaker job or endpoint to finish.

    :param job_name: name of the job (or endpoint) whose status is checked
    :param job_type: type of the SageMaker resource, one of ``training``, ``transform``,
        ``processing``, ``tuning`` or ``endpoint`` (case-insensitive)
    :param poke_interval: polling period in seconds to check for the status
    :param max_attempts: number of times to poll for the job state before giving up,
        defaults to 480.
    :param aws_conn_id: AWS connection ID for sagemaker
    """

    # Per job type: (waiter name, waiter argument name, status key in the describe response).
    # Kept in one table so the three lookups below can never disagree on supported types.
    _JOB_TYPE_SETTINGS = {
        "training": ("TrainingJobComplete", "TrainingJobName", "TrainingJobStatus"),
        "transform": ("TransformJobComplete", "TransformJobName", "TransformJobStatus"),
        "processing": ("ProcessingJobComplete", "ProcessingJobName", "ProcessingJobStatus"),
        "tuning": (
            "TuningJobComplete",
            "HyperParameterTuningJobName",
            "HyperParameterTuningJobStatus",
        ),
        # "endpoint_in_service" is provided by boto itself rather than by a custom waiter.
        "endpoint": ("endpoint_in_service", "EndpointName", "EndpointStatus"),
    }

    def __init__(
        self,
        job_name: str,
        job_type: str,
        poke_interval: int = 30,
        max_attempts: int = 480,
        aws_conn_id: str = "aws_default",
    ):
        super().__init__()
        self.job_name = job_name
        self.job_type = job_type
        self.poke_interval = poke_interval
        self.max_attempts = max_attempts
        self.aws_conn_id = aws_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes SagemakerTrigger arguments and classpath."""
        return (
            "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger",
            {
                "job_name": self.job_name,
                "job_type": self.job_type,
                "poke_interval": self.poke_interval,
                "max_attempts": self.max_attempts,
                "aws_conn_id": self.aws_conn_id,
            },
        )

    @cached_property
    def hook(self) -> SageMakerHook:
        """SageMaker hook built from ``aws_conn_id``; created once per trigger instance."""
        return SageMakerHook(aws_conn_id=self.aws_conn_id)

    @staticmethod
    def _get_job_type_settings(job_type: str) -> tuple[str, str, str]:
        """
        Return ``(waiter name, waiter arg name, status key)`` for ``job_type``.

        :raises KeyError: if ``job_type`` (case-insensitive) is not a supported type.
        """
        settings = SageMakerTrigger._JOB_TYPE_SETTINGS
        try:
            return settings[job_type.lower()]
        except KeyError:
            # Still a KeyError for backward compatibility, but with an actionable message.
            raise KeyError(
                f"Unsupported SageMaker job type {job_type!r}; expected one of {sorted(settings)}"
            ) from None

    @staticmethod
    def _get_job_type_waiter(job_type: str) -> str:
        """Return the name of the waiter used for ``job_type``."""
        return SageMakerTrigger._get_job_type_settings(job_type)[0]

    @staticmethod
    def _get_waiter_arg_name(job_type: str) -> str:
        """Return the waiter argument name that carries the job name for ``job_type``."""
        return SageMakerTrigger._get_job_type_settings(job_type)[1]

    @staticmethod
    def _get_response_status_key(job_type: str) -> str:
        """Return the key under which the status appears in the response for ``job_type``."""
        return SageMakerTrigger._get_job_type_settings(job_type)[2]

    async def run(self):
        """
        Poll the job with the matching waiter and yield a success event once it completes.

        Any failure reported by ``async_wait`` propagates out of this coroutine; the
        success event is yielded only after ``async_wait`` returns.
        """
        self.log.info("job name is %s and job type is %s", self.job_name, self.job_type)
        # Resolve the job type before opening a client so an unsupported type fails fast.
        waiter_name = self._get_job_type_waiter(self.job_type)
        waiter_arg_name = self._get_waiter_arg_name(self.job_type)
        status_key = self._get_response_status_key(self.job_type)
        async with self.hook.async_conn as client:
            waiter = self.hook.get_waiter(waiter_name, deferrable=True, client=client)
            await async_wait(
                waiter=waiter,
                waiter_delay=self.poke_interval,
                waiter_max_attempts=self.max_attempts,
                args={waiter_arg_name: self.job_name},
                failure_message=f"Error while waiting for {self.job_type} job",
                status_message=f"{self.job_type} job not done yet",
                status_args=[status_key],
            )
        yield TriggerEvent({"status": "success", "message": "Job completed."})