Source code for airflow.providers.amazon.aws.operators.sagemaker
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportjsonfromtypingimportTYPE_CHECKING,Any,Sequencefrombotocore.exceptionsimportClientErrorfromairflow.compat.functoolsimportcached_propertyfromairflow.exceptionsimportAirflowExceptionfromairflow.modelsimportBaseOperatorfromairflow.providers.amazon.aws.hooks.base_awsimportAwsBaseHookfromairflow.providers.amazon.aws.hooks.sagemakerimportSageMakerHookfromairflow.providers.amazon.aws.utils.sagemakerimportApprovalStatusfromairflow.utils.jsonimportAirflowJsonEncoderifTYPE_CHECKING:fromairflow.utils.contextimportContext
class SageMakerBaseOperator(BaseOperator):
    """
    This is the base operator for all SageMaker operators.

    :param config: The configuration necessary to start a training job (templated)
    """

    def parse_integer(self, config: dict, field: list[str] | str) -> None:
        """
        Recursively cast the value found at the key path ``field`` inside ``config`` to ``int``.

        ``field`` lists the keys to follow, outermost first. Whenever a list is met while
        walking the path, the same remaining path is applied to every element of that list.
        Paths that are absent from ``config`` are skipped silently and an empty path is a
        no-op. The conversion happens in place.

        :param config: The (sub-)configuration to update; may also be a list of
            sub-configurations.
        :param field: Remaining key path leading to the value to convert.
        """
        if isinstance(config, list):
            # Fan out: apply the same remaining path to each element.
            for sub_config in config:
                self.parse_integer(sub_config, field)
            return
        if not field:
            # Nothing left to look up; prevents an IndexError on ``field[0]`` below.
            return
        head, tail = field[0], field[1:]
        if head not in config:
            return
        if tail:
            self.parse_integer(config[head], tail)
        else:
            # End of the path: this is the value to convert (``int`` errors propagate).
            config[head] = int(config[head])

    def parse_config_integers(self) -> None:
        """Parse the integer fields to ints in case the config is rendered by Jinja and all fields are str."""
        for field in self.integer_fields:
            self.parse_integer(self.config, field)

    def expand_role(self) -> None:
        """Placeholder for calling boto3's `expand_role`, which expands an IAM role name into an ARN."""

    def preprocess_config(self) -> None:
        """
        Process the config into a usable form.

        Steps, in order: collect the integer fields for this operator, let the hook
        configure the S3 resources referenced by the config, cast the integer fields,
        expand the IAM role, and finally log the resulting config as pretty-printed JSON.
        """
        self._create_integer_fields()
        self.log.info("Preprocessing the config and doing required s3_operations")
        self.hook.configure_s3_resources(self.config)
        self.parse_config_integers()
        self.expand_role()
        self.log.info(
            "After preprocessing the config is:\n%s",
            json.dumps(self.config, sort_keys=True, indent=4, separators=(",", ": ")),
        )

    def _create_integer_fields(self) -> None:
        """
        Set fields which should be cast to integers.

        Child classes should override this method if they need integer fields parsed.
        """
        self.integer_fields = []

    def execute(self, context: Context):
        """Must be overridden by subclasses; the base implementation always raises."""
        raise NotImplementedError("Please implement execute() in sub class!")
class SageMakerProcessingOperator(SageMakerBaseOperator):
    """
    Use Amazon SageMaker Processing to analyze data and evaluate machine learning models on Amazon SageMaker.

    With Processing, you can use a simplified, managed experience on SageMaker
    to run your data processing workloads, such as feature engineering, data
    validation, model evaluation, and model interpretation.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerProcessingOperator`

    :param config: The configuration necessary to start a processing job (templated).
        For details of the configuration parameter see :py:meth:`SageMaker.Client.create_processing_job`
    :param aws_conn_id: The AWS connection ID to use.
    :param wait_for_completion: Whether the operator should wait until the processing job finishes.
    :param print_log: if the operator should print the cloudwatch log during processing
    :param check_interval: if wait is set to be true, this is the time interval
        in seconds which the operator will check the status of the processing job
    :param max_ingestion_time: If wait is set to True, the operation fails if the processing job
        doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
        the operation does not timeout.
    :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
        (default) and "fail".
    :return Dict: Returns a dict whose ``"Processing"`` key holds the serialized description of the
        processing job created in Amazon SageMaker.
    """

    def __init__(
        self,
        *,
        config: dict,
        aws_conn_id: str = DEFAULT_CONN_ID,
        wait_for_completion: bool = True,
        print_log: bool = True,
        check_interval: int = CHECK_INTERVAL_SECOND,
        max_ingestion_time: int | None = None,
        action_if_job_exists: str = "increment",
        **kwargs,
    ):
        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
        if action_if_job_exists not in ("increment", "fail"):
            # Adjacent literals instead of a backslash continuation inside the string, which
            # leaked a stray backslash / source indentation into the message.
            raise AirflowException(
                "Argument action_if_job_exists accepts only 'increment' and 'fail'. "
                f"Provided value: '{action_if_job_exists}'."
            )
        self.action_if_job_exists = action_if_job_exists
        self.wait_for_completion = wait_for_completion
        self.print_log = print_log
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time

    def _create_integer_fields(self) -> None:
        """Set fields which should be cast to integers."""
        self.integer_fields: list[list[str] | list[list[str]]] = [
            ["ProcessingResources", "ClusterConfig", "InstanceCount"],
            ["ProcessingResources", "ClusterConfig", "VolumeSizeInGB"],
        ]
        # The stopping condition is optional, so only register it when present.
        if "StoppingCondition" in self.config:
            self.integer_fields.append(["StoppingCondition", "MaxRuntimeInSeconds"])

    def expand_role(self) -> None:
        """Expands an IAM role name into an ARN."""
        if "RoleArn" in self.config:
            hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
            self.config["RoleArn"] = hook.expand_role(self.config["RoleArn"])

    def execute(self, context: Context) -> dict:
        """
        Create the processing job, handling a name clash according to ``action_if_job_exists``.

        :return: ``{"Processing": <serialized describe_processing_job response>}``
        :raises AirflowException: If a job with the same name exists and ``action_if_job_exists``
            is ``"fail"``, or if the create call does not answer with HTTP status 200.
        """
        self.preprocess_config()

        processing_job_name = self.config["ProcessingJobName"]
        # Suffix pattern of names produced by earlier "increment" runs (``<name>-<number>``).
        processing_job_dedupe_pattern = "-[0-9]+$"
        existing_jobs_found = self.hook.count_processing_jobs_by_name(
            processing_job_name, processing_job_dedupe_pattern
        )
        if existing_jobs_found:
            if self.action_if_job_exists == "fail":
                raise AirflowException(
                    f"A SageMaker processing job with name {processing_job_name} already exists."
                )
            elif self.action_if_job_exists == "increment":
                self.log.info("Found existing processing job with name '%s'.", processing_job_name)
                new_processing_job_name = f"{processing_job_name}-{existing_jobs_found + 1}"
                self.config["ProcessingJobName"] = new_processing_job_name
                self.log.info("Incremented processing job name to '%s'.", new_processing_job_name)

        response = self.hook.create_processing_job(
            self.config,
            wait_for_completion=self.wait_for_completion,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time,
        )
        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            raise AirflowException(f"Sagemaker Processing Job creation failed: {response}")
        return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
    """
    Creates an endpoint configuration that Amazon SageMaker hosting services uses to deploy models.

    The configuration identifies one or more models, created using the CreateModel API, to deploy,
    together with the resources that Amazon SageMaker should provision for them.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerEndpointConfigOperator`

    :param config: The configuration necessary to create an endpoint config.

        For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config`
    :param aws_conn_id: The AWS connection ID to use.
    :return Dict: Returns The ARN of the endpoint config created in Amazon SageMaker.
    """

    def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
        # Nothing operator-specific to store; everything is handled by the base class.
        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)

    def _create_integer_fields(self) -> None:
        """Register the key paths whose values must be cast to ``int`` before the API call."""
        initial_instance_count_path = ["ProductionVariants", "InitialInstanceCount"]
        self.integer_fields: list[list[str]] = [initial_instance_count_path]
class SageMakerEndpointOperator(SageMakerBaseOperator):
    """
    When you create a serverless endpoint, SageMaker provisions and manages the compute resources for you.

    Then, you can make inference requests to the endpoint and receive model predictions
    in response. SageMaker scales the compute resources up and down as needed to handle
    your request traffic.

    Requires an Endpoint Config.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerEndpointOperator`

    :param config:
        The configuration necessary to create an endpoint.

        If you need to create a SageMaker endpoint based on an existed
        SageMaker model and an existed SageMaker endpoint config::

            config = endpoint_configuration;

        If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::

            config = {
                'Model': model_configuration,
                'EndpointConfig': endpoint_config_configuration,
                'Endpoint': endpoint_configuration
            }

        For details of the configuration parameter of model_configuration see
        :py:meth:`SageMaker.Client.create_model`

        For details of the configuration parameter of endpoint_config_configuration see
        :py:meth:`SageMaker.Client.create_endpoint_config`

        For details of the configuration parameter of endpoint_configuration see
        :py:meth:`SageMaker.Client.create_endpoint`

    :param wait_for_completion: Whether the operator should wait until the endpoint creation finishes.
    :param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation
        waits before polling the status of the endpoint creation.
    :param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
        finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
    :param operation: Whether to create an endpoint or update an endpoint. Must be either 'create' or 'update'.
    :param aws_conn_id: The AWS connection ID to use.
    :return Dict: Returns a dict with the serialized descriptions of the endpoint config
        (``"EndpointConfig"``) and of the endpoint (``"Endpoint"``) in Amazon SageMaker.
    """

    def __init__(
        self,
        *,
        config: dict,
        aws_conn_id: str = DEFAULT_CONN_ID,
        wait_for_completion: bool = True,
        check_interval: int = CHECK_INTERVAL_SECOND,
        max_ingestion_time: int | None = None,
        operation: str = "create",
        **kwargs,
    ):
        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
        self.wait_for_completion = wait_for_completion
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time
        # Normalised so that "Create" / "UPDATE" are accepted as well.
        self.operation = operation.lower()
        if self.operation not in ["create", "update"]:
            raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')

    def _create_integer_fields(self) -> None:
        """Set fields which should be cast to integers."""
        # Always assign the attribute so that parsing never depends on a value set elsewhere;
        # without a nested "EndpointConfig" there is simply nothing to cast.
        self.integer_fields: list[list[str]] = []
        if "EndpointConfig" in self.config:
            self.integer_fields.append(["EndpointConfig", "ProductionVariants", "InitialInstanceCount"])

    def expand_role(self) -> None:
        """Expands an IAM role name into an ARN."""
        if "Model" not in self.config:
            return
        hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
        config = self.config["Model"]
        if "ExecutionRoleArn" in config:
            config["ExecutionRoleArn"] = hook.expand_role(config["ExecutionRoleArn"])

    def execute(self, context: Context) -> dict:
        """
        Optionally create the model and the endpoint config, then create or update the endpoint.

        If creating the endpoint fails with a ``ClientError``, the operator falls back to
        updating it. A failing update is not retried; its error is propagated.

        :return: ``{"EndpointConfig": ..., "Endpoint": ...}`` with the serialized describe responses.
        :raises ValueError: If ``self.operation`` is neither ``"create"`` nor ``"update"``.
        :raises AirflowException: If the endpoint call does not answer with HTTP status 200.
        """
        self.preprocess_config()
        model_info = self.config.get("Model")
        endpoint_config_info = self.config.get("EndpointConfig")
        # A flat config (no "Endpoint" key) is itself the endpoint configuration.
        endpoint_info = self.config.get("Endpoint", self.config)
        if model_info:
            self.log.info("Creating SageMaker model %s.", model_info["ModelName"])
            self.hook.create_model(model_info)
        if endpoint_config_info:
            self.log.info("Creating endpoint config %s.", endpoint_config_info["EndpointConfigName"])
            self.hook.create_endpoint_config(endpoint_config_info)
        if self.operation == "create":
            sagemaker_operation = self.hook.create_endpoint
            log_str = "Creating"
        elif self.operation == "update":
            sagemaker_operation = self.hook.update_endpoint
            log_str = "Updating"
        else:
            raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
        self.log.info("%s SageMaker endpoint %s.", log_str, endpoint_info["EndpointName"])

        wait_kwargs = {
            "wait_for_completion": self.wait_for_completion,
            "check_interval": self.check_interval,
            "max_ingestion_time": self.max_ingestion_time,
        }
        try:
            response = sagemaker_operation(endpoint_info, **wait_kwargs)
        except ClientError:
            if self.operation != "create":
                # Repeating an update that has just failed cannot help; surface the error.
                raise
            self.log.warning(
                "Creating SageMaker endpoint %s failed; trying to update it instead.",
                endpoint_info["EndpointName"],
            )
            self.operation = "update"
            response = self.hook.update_endpoint(endpoint_info, **wait_kwargs)

        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            raise AirflowException(f"Sagemaker endpoint creation failed: {response}")
        return {
            "EndpointConfig": serialize(
                self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
            ),
            "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
        }
class SageMakerTransformOperator(SageMakerBaseOperator):
    """
    Starts a transform job.

    A transform job uses a trained model to get inferences on a dataset
    and saves these results to an Amazon S3 location that you specify.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerTransformOperator`

    :param config: The configuration necessary to start a transform job (templated).

        If you need to create a SageMaker transform job based on an existed SageMaker model::

            config = transform_config

        If you need to create both SageMaker model and SageMaker Transform job::

            config = {
                'Model': model_config,
                'Transform': transform_config
            }

        For details of the configuration parameter of transform_config see
        :py:meth:`SageMaker.Client.create_transform_job`

        For details of the configuration parameter of model_config, See:
        :py:meth:`SageMaker.Client.create_model`

    :param aws_conn_id: The AWS connection ID to use.
    :param wait_for_completion: Set to True to wait until the transform job finishes.
    :param check_interval: If wait is set to True, the time interval, in seconds,
        that this operation waits to check the status of the transform job.
    :param max_ingestion_time: If wait is set to True, the operation fails
        if the transform job doesn't finish within max_ingestion_time seconds. If you
        set this parameter to None, the operation does not timeout.
    :param check_if_job_exists: If set to true, then the operator will check whether a transform job
        already exists for the name in the config.
    :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
        (default) and "fail".
        This is only relevant if check_if_job_exists is True.
    :return Dict: Returns a dict with the serialized descriptions of the model (``"Model"``) and of
        the transform job (``"Transform"``) in Amazon SageMaker.
    """

    def __init__(
        self,
        *,
        config: dict,
        aws_conn_id: str = DEFAULT_CONN_ID,
        wait_for_completion: bool = True,
        check_interval: int = CHECK_INTERVAL_SECOND,
        max_ingestion_time: int | None = None,
        check_if_job_exists: bool = True,
        action_if_job_exists: str = "increment",
        **kwargs,
    ):
        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
        self.wait_for_completion = wait_for_completion
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time
        self.check_if_job_exists = check_if_job_exists
        if action_if_job_exists in ("increment", "fail"):
            self.action_if_job_exists = action_if_job_exists
        else:
            # Adjacent literals instead of a backslash continuation inside the string, which
            # leaked a stray backslash / source indentation into the message.
            raise AirflowException(
                "Argument action_if_job_exists accepts only 'increment' and 'fail'. "
                f"Provided value: '{action_if_job_exists}'."
            )

    def _create_integer_fields(self) -> None:
        """Set fields which should be cast to integers."""
        self.integer_fields: list[list[str]] = [
            ["Transform", "TransformResources", "InstanceCount"],
            ["Transform", "MaxConcurrentTransforms"],
            ["Transform", "MaxPayloadInMB"],
        ]
        # With a flat config (no "Transform" wrapper) the paths start one level deeper.
        if "Transform" not in self.config:
            for field in self.integer_fields:
                field.pop(0)

    def expand_role(self) -> None:
        """Expands an IAM role name into an ARN."""
        if "Model" not in self.config:
            return
        config = self.config["Model"]
        if "ExecutionRoleArn" in config:
            hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
            config["ExecutionRoleArn"] = hook.expand_role(config["ExecutionRoleArn"])

    def execute(self, context: Context) -> dict:
        """
        Optionally create the model, then start the transform job.

        :return: ``{"Model": ..., "Transform": ...}`` with the serialized describe responses.
        :raises AirflowException: If the job name clashes and ``action_if_job_exists`` is ``"fail"``,
            or if the create call does not answer with HTTP status 200.
        """
        self.preprocess_config()

        model_config = self.config.get("Model")
        # A flat config (no "Transform" key) is itself the transform configuration.
        transform_config = self.config.get("Transform", self.config)

        if self.check_if_job_exists:
            self._check_if_transform_job_exists()

        if model_config:
            self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"])
            self.hook.create_model(model_config)

        self.log.info("Creating SageMaker transform Job %s.", transform_config["TransformJobName"])
        response = self.hook.create_transform_job(
            transform_config,
            wait_for_completion=self.wait_for_completion,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time,
        )
        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            raise AirflowException(f"Sagemaker transform Job creation failed: {response}")
        return {
            "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
            "Transform": serialize(
                self.hook.describe_transform_job(transform_config["TransformJobName"])
            ),
        }

    def _check_if_transform_job_exists(self) -> None:
        """
        Rename the job or fail when a transform job with the configured name already exists.

        The lookup lists jobs whose name contains the configured name; only an exact match
        counts as a clash. On "increment" the new suffix is the number of listed jobs plus one.

        :raises AirflowException: If the name clashes and ``action_if_job_exists`` is ``"fail"``.
        """
        transform_config = self.config.get("Transform", self.config)
        transform_job_name = transform_config["TransformJobName"]
        transform_jobs = self.hook.list_transform_jobs(name_contains=transform_job_name)
        if transform_job_name in [tj["TransformJobName"] for tj in transform_jobs]:
            if self.action_if_job_exists == "increment":
                self.log.info("Found existing transform job with name '%s'.", transform_job_name)
                new_transform_job_name = f"{transform_job_name}-{(len(transform_jobs) + 1)}"
                transform_config["TransformJobName"] = new_transform_job_name
                self.log.info("Incremented transform job name to '%s'.", new_transform_job_name)
            elif self.action_if_job_exists == "fail":
                raise AirflowException(
                    f"A SageMaker transform job with name {transform_job_name} already exists."
                )
class SageMakerTuningOperator(SageMakerBaseOperator):
    """
    Starts a hyperparameter tuning job.

    A hyperparameter tuning job finds the best version of a model by running
    many training jobs on your dataset using the algorithm you choose and
    values for hyperparameters within ranges that you specify. It then chooses
    the hyperparameter values that result in a model that performs the best,
    as measured by an objective metric that you choose.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerTuningOperator`

    :param config: The configuration necessary to start a tuning job (templated).

        For details of the configuration parameter see
        :py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
    :param aws_conn_id: The AWS connection ID to use.
    :param wait_for_completion: Set to True to wait until the tuning job finishes.
    :param check_interval: If wait is set to True, the time interval, in seconds,
        that this operation waits to check the status of the tuning job.
    :param max_ingestion_time: If wait is set to True, the operation fails
        if the tuning job doesn't finish within max_ingestion_time seconds. If you
        set this parameter to None, the operation does not timeout.
    :return Dict: Returns The ARN of the tuning job created in Amazon SageMaker.
    """

    def __init__(
        self,
        *,
        config: dict,
        aws_conn_id: str = DEFAULT_CONN_ID,
        wait_for_completion: bool = True,
        check_interval: int = CHECK_INTERVAL_SECOND,
        max_ingestion_time: int | None = None,
        **kwargs,
    ):
        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
        # Waiting / polling settings, kept on the instance for later use.
        self.max_ingestion_time = max_ingestion_time
        self.check_interval = check_interval
        self.wait_for_completion = wait_for_completion

    def expand_role(self) -> None:
        """Replace the IAM role name of the training job definition, if any, with its ARN."""
        if "TrainingJobDefinition" not in self.config:
            return
        training_definition = self.config["TrainingJobDefinition"]
        if "RoleArn" not in training_definition:
            return
        iam_hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
        training_definition["RoleArn"] = iam_hook.expand_role(training_definition["RoleArn"])

    def _create_integer_fields(self) -> None:
        """Register the key paths whose values must be cast to ``int`` before the API call."""
        limits_prefix = ["HyperParameterTuningJobConfig", "ResourceLimits"]
        definition = "TrainingJobDefinition"
        # Every path is built as a fresh list, so the entries never share state.
        self.integer_fields: list[list[str]] = [
            [*limits_prefix, "MaxNumberOfTrainingJobs"],
            [*limits_prefix, "MaxParallelTrainingJobs"],
            [definition, "ResourceConfig", "InstanceCount"],
            [definition, "ResourceConfig", "VolumeSizeInGB"],
            [definition, "StoppingCondition", "MaxRuntimeInSeconds"],
        ]
class SageMakerModelOperator(SageMakerBaseOperator):
    """
    Creates a model in Amazon SageMaker.

    In the request, you name the model and describe a primary container. For the
    primary container, you specify the Docker image that contains inference code,
    artifacts (from prior training), and a custom environment map that the inference
    code uses when you deploy the model for predictions.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerModelOperator`

    :param config: The configuration necessary to create a model.

        For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model`
    :param aws_conn_id: The AWS connection ID to use.
    :return Dict: Returns The ARN of the model created in Amazon SageMaker.
    """

    def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
        # Nothing operator-specific to store; everything is handled by the base class.
        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)

    def expand_role(self) -> None:
        """Replace the execution role name in the config, if any, with its ARN."""
        if "ExecutionRoleArn" not in self.config:
            return
        iam_hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
        self.config["ExecutionRoleArn"] = iam_hook.expand_role(self.config["ExecutionRoleArn"])

    def execute(self, context: Context) -> dict:
        """Create the model and return its serialized description under the ``"Model"`` key."""
        self.preprocess_config()

        self.log.info("Creating SageMaker Model %s.", self.config["ModelName"])
        creation_response = self.hook.create_model(self.config)

        status_code = creation_response["ResponseMetadata"]["HTTPStatusCode"]
        if status_code != 200:
            raise AirflowException(f"Sagemaker model creation failed: {creation_response}")

        description = self.hook.describe_model(self.config["ModelName"])
        return {"Model": serialize(description)}
class SageMakerTrainingOperator(SageMakerBaseOperator):
    """
    Starts a model training job.

    After training completes, Amazon SageMaker saves the resulting
    model artifacts to an Amazon S3 location that you specify.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerTrainingOperator`

    :param config: The configuration necessary to start a training job (templated).

        For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job`
    :param aws_conn_id: The AWS connection ID to use.
    :param wait_for_completion: Whether the operator should wait until the training job finishes.
    :param print_log: if the operator should print the cloudwatch log during training
    :param check_interval: if wait is set to be true, this is the time interval
        in seconds which the operator will check the status of the training job
    :param max_ingestion_time: If wait is set to True, the operation fails if the training job
        doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
        the operation does not timeout.
    :param check_if_job_exists: If set to true, then the operator will check whether a training job
        already exists for the name in the config.
    :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
        (default) and "fail".
        This is only relevant if check_if_job_exists is True.
    :return Dict: Returns a dict whose ``"Training"`` key holds the serialized description of the
        training job created in Amazon SageMaker.
    """

    def __init__(
        self,
        *,
        config: dict,
        aws_conn_id: str = DEFAULT_CONN_ID,
        wait_for_completion: bool = True,
        print_log: bool = True,
        check_interval: int = CHECK_INTERVAL_SECOND,
        max_ingestion_time: int | None = None,
        check_if_job_exists: bool = True,
        action_if_job_exists: str = "increment",
        **kwargs,
    ):
        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
        self.wait_for_completion = wait_for_completion
        self.print_log = print_log
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time
        self.check_if_job_exists = check_if_job_exists
        if action_if_job_exists in ("increment", "fail"):
            self.action_if_job_exists = action_if_job_exists
        else:
            # Adjacent literals instead of a backslash continuation inside the string, which
            # leaked a stray backslash / source indentation into the message.
            raise AirflowException(
                "Argument action_if_job_exists accepts only 'increment' and 'fail'. "
                f"Provided value: '{action_if_job_exists}'."
            )

    def expand_role(self) -> None:
        """Expands an IAM role name into an ARN."""
        if "RoleArn" in self.config:
            hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
            self.config["RoleArn"] = hook.expand_role(self.config["RoleArn"])

    def _create_integer_fields(self) -> None:
        """Set fields which should be cast to integers."""
        self.integer_fields: list[list[str]] = [
            ["ResourceConfig", "InstanceCount"],
            ["ResourceConfig", "VolumeSizeInGB"],
            ["StoppingCondition", "MaxRuntimeInSeconds"],
        ]

    def execute(self, context: Context) -> dict:
        """
        Start the training job, handling a name clash first when ``check_if_job_exists`` is set.

        :return: ``{"Training": <serialized describe_training_job response>}``
        :raises AirflowException: If the job name clashes and ``action_if_job_exists`` is ``"fail"``,
            or if the create call does not answer with HTTP status 200.
        """
        self.preprocess_config()
        if self.check_if_job_exists:
            self._check_if_job_exists()
        self.log.info("Creating SageMaker training job %s.", self.config["TrainingJobName"])
        response = self.hook.create_training_job(
            self.config,
            wait_for_completion=self.wait_for_completion,
            print_log=self.print_log,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time,
        )
        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            raise AirflowException(f"Sagemaker Training Job creation failed: {response}")
        return {"Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}

    def _check_if_job_exists(self) -> None:
        """
        Rename the job or fail when a training job with the configured name already exists.

        The lookup lists jobs whose name contains the configured name; only an exact match
        counts as a clash. On "increment" the new suffix is the number of listed jobs plus one.

        :raises AirflowException: If the name clashes and ``action_if_job_exists`` is ``"fail"``.
        """
        training_job_name = self.config["TrainingJobName"]
        training_jobs = self.hook.list_training_jobs(name_contains=training_job_name)
        if training_job_name in [tj["TrainingJobName"] for tj in training_jobs]:
            if self.action_if_job_exists == "increment":
                self.log.info("Found existing training job with name '%s'.", training_job_name)
                new_training_job_name = f"{training_job_name}-{(len(training_jobs) + 1)}"
                self.config["TrainingJobName"] = new_training_job_name
                self.log.info("Incremented training job name to '%s'.", new_training_job_name)
            elif self.action_if_job_exists == "fail":
                raise AirflowException(
                    f"A SageMaker training job with name {training_job_name} already exists."
                )
class SageMakerDeleteModelOperator(SageMakerBaseOperator):
    """
    Deletes a SageMaker model.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerDeleteModelOperator`

    :param config: The configuration necessary to delete the model.

        For details of the configuration parameter see :py:meth:`SageMaker.Client.delete_model`
    :param aws_conn_id: The AWS connection ID to use.
    """

    def __init__(
        self,
        *,
        config: dict,
        aws_conn_id: str = DEFAULT_CONN_ID,
        **kwargs,
    ):
        # Nothing operator-specific to store; everything is handled by the base class.
        super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
class SageMakerStartPipelineOperator(SageMakerBaseOperator):
    """
    Starts a SageMaker pipeline execution.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerStartPipelineOperator`

    :param config: The configuration to start the pipeline execution.
    :param aws_conn_id: The AWS connection ID to use.
    :param pipeline_name: Name of the pipeline to start.
    :param display_name: The name this pipeline execution will have in the UI. Doesn't need to be unique.
    :param pipeline_params: Optional parameters for the pipeline.
        All parameters supplied need to already be present in the pipeline definition.
    :param wait_for_completion: If true, this operator will only complete once the pipeline is complete.
    :param check_interval: How long to wait between checks for pipeline status when waiting for completion.
    :param verbose: Whether to print steps details when waiting for completion.
        Defaults to true, consider turning off for pipelines that have thousands of steps.

    :return str: Returns The ARN of the pipeline execution created in Amazon SageMaker.
    """

    def execute(self, context: Context) -> str:
        """Ask the hook to start the pipeline, log the execution ARN it returns and return it."""
        start_options = {
            "pipeline_name": self.pipeline_name,
            "display_name": self.display_name,
            "pipeline_params": self.pipeline_params,
            "wait_for_completion": self.wait_for_completion,
            "check_interval": self.check_interval,
            "verbose": self.verbose,
        }
        execution_arn = self.hook.start_pipeline(**start_options)
        self.log.info(
            "Starting a new execution for pipeline %s, running with ARN %s",
            self.pipeline_name,
            execution_arn,
        )
        return execution_arn
class SageMakerStopPipelineOperator(SageMakerBaseOperator):
    """
    Stops a SageMaker pipeline execution.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerStopPipelineOperator`

    :param config: The configuration to start the pipeline execution.
    :param aws_conn_id: The AWS connection ID to use.
    :param pipeline_exec_arn: Amazon Resource Name of the pipeline execution to stop.
    :param wait_for_completion: If true, this operator will only complete once the pipeline is fully stopped.
    :param check_interval: How long to wait between checks for pipeline status when waiting for completion.
    :param verbose: Whether to print steps details when waiting for completion.
        Defaults to true, consider turning off for pipelines that have thousands of steps.
    :param fail_if_not_running: raises an exception if the pipeline stopped or succeeded before this was run

    :return str: Returns the status of the pipeline execution after the operation has been done.
    """

    def execute(self, context: Context) -> str:
        """Ask the hook to stop the pipeline execution, log the status it returns and return it."""
        stop_options = {
            "pipeline_exec_arn": self.pipeline_exec_arn,
            "wait_for_completion": self.wait_for_completion,
            "check_interval": self.check_interval,
            "verbose": self.verbose,
            "fail_if_not_running": self.fail_if_not_running,
        }
        final_status = self.hook.stop_pipeline(**stop_options)
        self.log.info(
            "Stop requested for pipeline execution with ARN %s. Status is now %s",
            self.pipeline_exec_arn,
            final_status,
        )
        return final_status
class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
    """
    Registers an Amazon SageMaker model by creating a model version that specifies the model group to which it belongs.

    Will create the model group if it does not exist already.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerRegisterModelVersionOperator`

    :param image_uri: The Amazon EC2 Container Registry (Amazon ECR) path where inference code is stored.
    :param model_url: The Amazon S3 path where the model artifacts (the trained weights of the model), which
        result from model training, are stored. This path must point to a single gzip compressed tar archive
        (.tar.gz suffix).
    :param package_group_name: The name of the model package group that the model is going to be registered
        to. Will be created if it doesn't already exist.
    :param package_group_desc: Description of the model package group, if it was to be created (optional).
    :param package_desc: Description of the model package (optional).
    :param model_approval: Approval status of the model package. Defaults to PendingManualApproval
    :param extras: Can contain extra parameters for the boto call to create_model_package, and/or overrides
        for any parameter defined above. For a complete list of available parameters, see
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model_package

    :return str: Returns the ARN of the model package created.
    """

    def execute(self, context: Context):
        """
        Create the model package group if needed, then register a model package in it.

        If registering the package fails with a ``ClientError`` and the group was created by
        this call, the group is deleted again before the error is re-raised.

        :return: The ``ModelPackageArn`` of the created model package.
        """
        # create a model package group if it does not exist
        # NOTE(review): the group gets the *group* description, not the package description;
        # assumes __init__ stores it as ``self.package_group_desc`` - confirm against __init__.
        group_created = self.hook.create_model_package_group(
            self.package_group_name, self.package_group_desc
        )

        # then create a model package in that group
        input_dict = {
            "InferenceSpecification": {
                "Containers": [
                    {
                        "Image": self.image_uri,
                        "ModelDataUrl": self.model_url,
                    }
                ],
                "SupportedContentTypes": ["text/csv"],
                "SupportedResponseMIMETypes": ["text/csv"],
            },
            "ModelPackageGroupName": self.package_group_name,
            "ModelPackageDescription": self.package_desc,
            "ModelApprovalStatus": self.model_approval.value,
        }
        if self.extras:
            input_dict.update(self.extras)  # overrides config above if keys are redefined in extras
        try:
            res = self.hook.conn.create_model_package(**input_dict)
            return res["ModelPackageArn"]
        except ClientError:
            # rollback group creation if adding the model to it was not successful
            if group_created:
                self.hook.conn.delete_model_package_group(ModelPackageGroupName=self.package_group_name)
            raise