Source code for airflow.providers.amazon.aws.operators.sagemaker
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.importjsonimportsysfromtypingimportTYPE_CHECKING,Any,List,Optional,Sequencefrombotocore.exceptionsimportClientErrorfromairflow.exceptionsimportAirflowExceptionfromairflow.modelsimportBaseOperatorfromairflow.providers.amazon.aws.hooks.base_awsimportAwsBaseHookfromairflow.providers.amazon.aws.hooks.sagemakerimportSageMakerHookifsys.version_info>=(3,8):fromfunctoolsimportcached_propertyelse:fromcached_propertyimportcached_propertyifTYPE_CHECKING:fromairflow.utils.contextimportContext
class SageMakerBaseOperator(BaseOperator):
    """Common base class shared by all SageMaker operators.

    :param config: The configuration necessary to start a training job (templated)
    :param aws_conn_id: The AWS connection ID to use.
    """

    # NOTE(review): ``self.config``, ``self.hook`` and ``self.integer_fields`` are
    # read by the methods below but are not assigned anywhere in this view --
    # presumably they are provided by parts of the class not shown here; confirm.

    def parse_integer(self, config, field):
        """Cast the value addressed by the key path ``field`` to ``int``, in place.

        The path is followed one key at a time. Whenever the current node is a
        list, every element of the list is visited with the same remaining path.
        A key that is absent from the current node ends the walk silently.

        :param config: The (sub-)configuration to walk; mutated in place.
        :param field: Sequence of keys leading to the value to convert.
        """
        # Lists are transparent: fan out over the items with the path unchanged.
        if isinstance(config, list):
            for item in config:
                self.parse_integer(item, field)
            return

        key = field[0]
        if key not in config:
            return

        if len(field) == 1:
            # Last path component: this is the value to convert.
            config[key] = int(config[key])
        else:
            self.parse_integer(config[key], field[1:])

    def parse_config_integers(self):
        """Cast every path listed in ``self.integer_fields`` to ``int``.

        Needed because a config rendered by Jinja carries all of its values
        as ``str``.
        """
        for path in self.integer_fields:
            self.parse_integer(self.config, path)

    def expand_role(self):
        """Placeholder for calling boto3's `expand_role`, which expands an IAM role name into an ARN."""

    def preprocess_config(self):
        """Process the config into a usable form.

        Configures the S3 resources through the hook, casts the integer fields,
        expands the role and finally logs the resulting config as sorted,
        indented JSON.
        """
        self.log.info('Preprocessing the config and doing required s3_operations')
        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()
        pretty_config = json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': '))
        self.log.info('After preprocessing the config is:\n%s', pretty_config)

    def execute(self, context: 'Context'):
        """Abstract entry point; every concrete operator must override it.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError('Please implement execute() in sub class!')
class SageMakerProcessingOperator(SageMakerBaseOperator):
    """Initiate a SageMaker processing job.

    This operator returns the ARN of the processing job created in Amazon SageMaker.

    :param config: The configuration necessary to start a processing job (templated).
        For details of the configuration parameter see
        :py:meth:`SageMaker.Client.create_processing_job`
    :param aws_conn_id: The AWS connection ID to use.
    :param wait_for_completion: Set to True to wait until the processing job finishes.
    :param print_log: if the operator should print the cloudwatch log during processing
    :param check_interval: if wait is set to be true, this is the time interval
        in seconds which the operator will check the status of the processing job
    :param max_ingestion_time: If wait is set to True, the operation fails if the processing job
        doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
        the operation does not timeout.
    :param action_if_job_exists: Behaviour if the job name already exists. Possible options are
        "increment" (default) and "fail".
    """

    def __init__(
        self,
        *,
        config: dict,
        aws_conn_id: str,
        wait_for_completion: bool = True,
        print_log: bool = True,
        check_interval: int = 30,
        max_ingestion_time: Optional[int] = None,
        action_if_job_exists: str = 'increment',
        **kwargs,
    ):
        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
        if action_if_job_exists not in ('increment', 'fail'):
            # Adjacent literals instead of a backslash continuation, so the
            # message does not pick up the indentation of the source line.
            raise AirflowException(
                "Argument action_if_job_exists accepts only 'increment' and 'fail'. "
                f"Provided value: '{action_if_job_exists}'."
            )
        self.action_if_job_exists = action_if_job_exists
        self.wait_for_completion = wait_for_completion
        # NOTE(review): print_log is stored but execute() does not forward it to the hook.
        self.print_log = print_log
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time
        self._create_integer_fields()

    def _create_integer_fields(self) -> None:
        """Set fields which should be casted to integers."""
        self.integer_fields = [
            ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
            ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
        ]
        if 'StoppingCondition' in self.config:
            self.integer_fields += [['StoppingCondition', 'MaxRuntimeInSeconds']]

    def _resolve_job_name(self) -> None:
        """Apply ``action_if_job_exists`` to ``config['ProcessingJobName']``.

        Does nothing when the requested name is free. Otherwise, with
        ``'increment'`` the name is rewritten in the config to the first free
        ``<name>-<n>`` (n = 2, 3, ...); any other setting fails the task.

        :raises AirflowException: if the job exists and incrementing is not requested.
        """
        requested_name = self.config['ProcessingJobName']
        if not self.hook.find_processing_job_by_name(requested_name):
            return

        if self.action_if_job_exists != 'increment':
            raise AirflowException(
                f'A SageMaker processing job with name {requested_name} already exists.'
            )

        self.log.info("Found existing processing job with name '%s'.", requested_name)
        suffix = 2
        candidate = f'{requested_name}-{suffix}'
        # One lookup per candidate; stops at the first unused name.
        while self.hook.find_processing_job_by_name(candidate):
            suffix += 1
            candidate = f'{requested_name}-{suffix}'
        self.config['ProcessingJobName'] = candidate
        self.log.info("Incremented processing job name to '%s'.", candidate)

    def execute(self, context: 'Context') -> dict:
        """Create the processing job and return its description.

        :return: ``{'Processing': <result of hook.describe_processing_job>}``
        :raises AirflowException: if the job name is taken and
            ``action_if_job_exists`` is ``'fail'``, or if the create call does
            not answer with HTTP status 200.
        """
        self.preprocess_config()
        # Previously an existing job always failed the task, which ignored the
        # documented default of action_if_job_exists='increment'.
        self._resolve_job_name()
        processing_job_name = self.config['ProcessingJobName']

        self.log.info('Creating SageMaker processing job %s.', processing_job_name)
        response = self.hook.create_processing_job(
            self.config,
            wait_for_completion=self.wait_for_completion,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time,
        )
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Sagemaker Processing Job creation failed: {response}')
        return {'Processing': self.hook.describe_processing_job(processing_job_name)}
class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
    """Create a SageMaker endpoint config.

    This operator returns the ARN of the endpoint config created in Amazon SageMaker.

    :param config: The configuration necessary to create an endpoint config.
        For details of the configuration parameter see
        :py:meth:`SageMaker.Client.create_endpoint_config`
    :param aws_conn_id: The AWS connection ID to use.
    """
class SageMakerEndpointOperator(SageMakerBaseOperator):
    """Create a SageMaker endpoint.

    This operator returns the ARN of the endpoint created in Amazon SageMaker.

    :param config: The configuration necessary to create an endpoint.

        If you need to create a SageMaker endpoint based on an existed
        SageMaker model and an existed SageMaker endpoint config::

            config = endpoint_configuration;

        If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::

            config = {
                'Model': model_configuration,
                'EndpointConfig': endpoint_config_configuration,
                'Endpoint': endpoint_configuration
            }

        For details of the configuration parameter of model_configuration see
        :py:meth:`SageMaker.Client.create_model`

        For details of the configuration parameter of endpoint_config_configuration see
        :py:meth:`SageMaker.Client.create_endpoint_config`

        For details of the configuration parameter of endpoint_configuration see
        :py:meth:`SageMaker.Client.create_endpoint`

    :param aws_conn_id: The AWS connection ID to use.
    :param wait_for_completion: Whether the operator should wait until the endpoint creation finishes.
    :param check_interval: If wait is set to True, this is the time interval, in seconds, that this
        operation waits before polling the status of the endpoint creation.
    :param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation
        doesn't finish within max_ingestion_time seconds. If you set this parameter to None it
        never times out.
    :param operation: Whether to create an endpoint or update an endpoint. Must be either
        'create' or 'update' (case-insensitive).
    """

    def __init__(
        self,
        *,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
        max_ingestion_time: Optional[int] = None,
        operation: str = 'create',
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.config = config
        self.wait_for_completion = wait_for_completion
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time
        self.operation = operation.lower()
        if self.operation not in ['create', 'update']:
            raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
        self.create_integer_fields()

    def create_integer_fields(self) -> None:
        """Set fields which should be casted to integers."""
        # Only the nested form of the config carries an 'EndpointConfig' section;
        # otherwise ``integer_fields`` is left as it was.
        if 'EndpointConfig' in self.config:
            self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']]

    def execute(self, context: 'Context') -> dict:
        """Create (or update) the endpoint, plus model and endpoint config when given.

        When ``operation`` is ``'create'`` and the create call raises a
        ``ClientError``, the operator switches to ``'update'`` and tries again
        with ``update_endpoint``. A ``ClientError`` raised while updating is
        propagated unchanged.

        :return: dict with the described ``'EndpointConfig'`` and ``'Endpoint'``.
        :raises ValueError: if ``self.operation`` is neither 'create' nor 'update'.
        :raises AirflowException: if the response status code is not 200.
        """
        self.preprocess_config()
        model_info = self.config.get('Model')
        endpoint_config_info = self.config.get('EndpointConfig')
        # Flat configs describe the endpoint directly at the top level.
        endpoint_info = self.config.get('Endpoint', self.config)

        if model_info:
            self.log.info('Creating SageMaker model %s.', model_info['ModelName'])
            self.hook.create_model(model_info)
        if endpoint_config_info:
            self.log.info('Creating endpoint config %s.', endpoint_config_info['EndpointConfigName'])
            self.hook.create_endpoint_config(endpoint_config_info)

        if self.operation == 'create':
            sagemaker_operation = self.hook.create_endpoint
            log_str = 'Creating'
        elif self.operation == 'update':
            sagemaker_operation = self.hook.update_endpoint
            log_str = 'Updating'
        else:
            raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')

        self.log.info('%s SageMaker endpoint %s.', log_str, endpoint_info['EndpointName'])
        call_kwargs = {
            'wait_for_completion': self.wait_for_completion,
            'check_interval': self.check_interval,
            'max_ingestion_time': self.max_ingestion_time,
        }
        try:
            response = sagemaker_operation(endpoint_info, **call_kwargs)
        except ClientError:
            # Retrying only makes sense when a *create* failed (presumably because
            # the endpoint already exists). A failed update used to be re-sent
            # verbatim, which only duplicated the failing API call.
            if self.operation != 'create':
                raise
            self.log.warning(
                'Creating SageMaker endpoint %s failed, updating it instead.',
                endpoint_info['EndpointName'],
            )
            self.operation = 'update'
            response = self.hook.update_endpoint(endpoint_info, **call_kwargs)

        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Sagemaker endpoint creation failed: {response}')
        return {
            'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']),
            'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']),
        }
class SageMakerTransformOperator(SageMakerBaseOperator):
    """Initiate a SageMaker transform job.

    This operator returns the ARN of the model created in Amazon SageMaker.

    :param config: The configuration necessary to start a transform job (templated).

        If you need to create a SageMaker transform job based on an existed SageMaker model::

            config = transform_config

        If you need to create both SageMaker model and SageMaker Transform job::

            config = {
                'Model': model_config,
                'Transform': transform_config
            }

        For details of the configuration parameter of transform_config see
        :py:meth:`SageMaker.Client.create_transform_job`

        For details of the configuration parameter of model_config, See:
        :py:meth:`SageMaker.Client.create_model`

    :param aws_conn_id: The AWS connection ID to use.
    :param wait_for_completion: Set to True to wait until the transform job finishes.
    :param check_interval: If wait is set to True, the time interval, in seconds,
        that this operation waits to check the status of the transform job.
    :param max_ingestion_time: If wait is set to True, the operation fails
        if the transform job doesn't finish within max_ingestion_time seconds. If you
        set this parameter to None, the operation does not timeout.
    """

    def __init__(
        self,
        *,
        config: dict,
        wait_for_completion: bool = True,
        check_interval: int = 30,
        max_ingestion_time: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(config=config, **kwargs)
        self.config = config
        self.wait_for_completion = wait_for_completion
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time
        self.create_integer_fields()

    def create_integer_fields(self) -> None:
        """Set fields which should be casted to integers."""
        self.integer_fields: List[List[str]] = [
            ['Transform', 'TransformResources', 'InstanceCount'],
            ['Transform', 'MaxConcurrentTransforms'],
            ['Transform', 'MaxPayloadInMB'],
        ]
        if 'Transform' in self.config:
            return
        # Flat config: the transform settings live at the top level, so the
        # leading 'Transform' component is dropped from every path.
        self.integer_fields[:] = [path[1:] for path in self.integer_fields]

    def execute(self, context: 'Context') -> dict:
        """Create the optional model, start the transform job and describe both.

        :return: dict with the described ``'Model'`` and ``'Transform'`` job.
        :raises AirflowException: if the response status code is not 200.
        """
        self.preprocess_config()
        model_settings = self.config.get('Model')
        transform_settings = self.config.get('Transform', self.config)

        if model_settings:
            self.log.info('Creating SageMaker Model %s for transform job', model_settings['ModelName'])
            self.hook.create_model(model_settings)

        self.log.info('Creating SageMaker transform Job %s.', transform_settings['TransformJobName'])
        result = self.hook.create_transform_job(
            transform_settings,
            wait_for_completion=self.wait_for_completion,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time,
        )
        if result['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Sagemaker transform Job creation failed: {result}')

        return {
            'Model': self.hook.describe_model(transform_settings['ModelName']),
            'Transform': self.hook.describe_transform_job(transform_settings['TransformJobName']),
        }
class SageMakerTuningOperator(SageMakerBaseOperator):
    """Initiate a SageMaker hyperparameter tuning job.

    This operator returns the ARN of the tuning job created in Amazon SageMaker.

    :param config: The configuration necessary to start a tuning job (templated).
        For details of the configuration parameter see
        :py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
    :param aws_conn_id: The AWS connection ID to use.
    :param wait_for_completion: Set to True to wait until the tuning job finishes.
    :param check_interval: If wait is set to True, the time interval, in seconds,
        that this operation waits to check the status of the tuning job.
    :param max_ingestion_time: If wait is set to True, the operation fails
        if the tuning job doesn't finish within max_ingestion_time seconds. If you
        set this parameter to None, the operation does not timeout.
    """
class SageMakerModelOperator(SageMakerBaseOperator):
    """Create a SageMaker model.

    This operator returns the ARN of the model created in Amazon SageMaker.

    :param config: The configuration necessary to create a model.
        For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model`
    :param aws_conn_id: The AWS connection ID to use.
    """

    def __init__(self, *, config, **kwargs):
        super().__init__(config=config, **kwargs)
        self.config = config

    def execute(self, context: 'Context') -> dict:
        """Create the model described by the config and return its description.

        :return: ``{'Model': <result of hook.describe_model>}``
        :raises AirflowException: if the response status code is not 200.
        """
        self.preprocess_config()

        self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
        result = self.hook.create_model(self.config)
        status_code = result['ResponseMetadata']['HTTPStatusCode']
        if status_code != 200:
            raise AirflowException(f'Sagemaker model creation failed: {result}')

        return {'Model': self.hook.describe_model(self.config['ModelName'])}
class SageMakerTrainingOperator(SageMakerBaseOperator):
    """Initiate a SageMaker training job.

    This operator returns the ARN of the training job created in Amazon SageMaker.

    :param config: The configuration necessary to start a training job (templated).
        For details of the configuration parameter see
        :py:meth:`SageMaker.Client.create_training_job`
    :param aws_conn_id: The AWS connection ID to use.
    :param wait_for_completion: Set to True to wait until the training job finishes.
    :param print_log: if the operator should print the cloudwatch log during training
    :param check_interval: if wait is set to be true, this is the time interval
        in seconds which the operator will check the status of the training job
    :param max_ingestion_time: If wait is set to True, the operation fails if the training job
        doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
        the operation does not timeout.
    :param check_if_job_exists: If set to true, then the operator will check whether a training job
        already exists for the name in the config.
    :param action_if_job_exists: Behaviour if the job name already exists. Possible options are
        "increment" (default) and "fail". This is only relevant if check_if_job_exists is True.
    """

    # NOTE(review): the attributes read below (wait_for_completion, print_log,
    # check_interval, max_ingestion_time, check_if_job_exists,
    # action_if_job_exists) are not assigned anywhere in this view --
    # presumably a constructor not shown here sets them; confirm.

    def execute(self, context: 'Context') -> dict:
        """Start the training job and return its description.

        :return: ``{'Training': <result of hook.describe_training_job>}``
        :raises AirflowException: if the job already exists and
            ``action_if_job_exists`` is ``'fail'``, or if the response status
            code is not 200.
        """
        self.preprocess_config()
        if self.check_if_job_exists:
            # May rewrite config['TrainingJobName'] to an unused name.
            self._check_if_job_exists()

        self.log.info('Creating SageMaker training job %s.', self.config['TrainingJobName'])
        response = self.hook.create_training_job(
            self.config,
            wait_for_completion=self.wait_for_completion,
            print_log=self.print_log,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time,
        )
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(f'Sagemaker Training Job creation failed: {response}')
        return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])}

    def _check_if_job_exists(self) -> None:
        """Apply ``action_if_job_exists`` when the configured job name is taken.

        With ``'increment'`` the name in the config is replaced by
        ``<name>-<n>``, starting at n = (number of listed jobs + 1) and moving
        on until the candidate is not among the listed names. With ``'fail'``
        an exception is raised. Any other setting leaves the config untouched.

        :raises AirflowException: if the job exists and the action is ``'fail'``.
        """
        training_job_name = self.config['TrainingJobName']
        training_jobs = self.hook.list_training_jobs(name_contains=training_job_name)
        existing_names = {job['TrainingJobName'] for job in training_jobs}
        if training_job_name not in existing_names:
            return

        if self.action_if_job_exists == 'increment':
            self.log.info("Found existing training job with name '%s'.", training_job_name)
            suffix = len(training_jobs) + 1
            new_training_job_name = f'{training_job_name}-{suffix}'
            # The count-based suffix alone can hit a name that is already taken
            # (e.g. existing 'job' and 'job-3' -> 'job-3'). Every candidate
            # contains the requested name, so -- assuming the listing returns all
            # jobs whose name contains it -- a clash would show up in existing_names.
            while new_training_job_name in existing_names:
                suffix += 1
                new_training_job_name = f'{training_job_name}-{suffix}'
            self.config['TrainingJobName'] = new_training_job_name
            self.log.info("Incremented training job name to '%s'.", new_training_job_name)
        elif self.action_if_job_exists == 'fail':
            raise AirflowException(
                f'A SageMaker training job with name {training_job_name} already exists.'
            )
[docs]classSageMakerDeleteModelOperator(SageMakerBaseOperator):"""Deletes a SageMaker model. This operator deletes the Model entry created in SageMaker. :param config: The configuration necessary to delete the model. For details of the configuration parameter see :py:meth:`SageMaker.Client.delete_model` :param aws_conn_id: The AWS connection ID to use. """def__init__(self,*,config,aws_conn_id:str,**kwargs):super().__init__(config=config,**kwargs)self.aws_conn_id=aws_conn_idself.config=config