airflow.providers.amazon.aws.operators.sagemaker

Module Contents

Classes

SageMakerBaseOperator

This is the base operator for all SageMaker operators.

SageMakerProcessingOperator

Use Amazon SageMaker Processing to analyze data and evaluate machine learning models on Amazon SageMaker.

SageMakerEndpointConfigOperator

Creates an endpoint configuration that Amazon SageMaker hosting services uses to deploy models.

SageMakerEndpointOperator

When you create a serverless endpoint, SageMaker provisions and manages the compute resources for you.

SageMakerTransformOperator

Starts a transform job.

SageMakerTuningOperator

Starts a hyperparameter tuning job.

SageMakerModelOperator

Creates a model in Amazon SageMaker.

SageMakerTrainingOperator

Starts a model training job.

SageMakerDeleteModelOperator

Deletes a SageMaker model.

SageMakerStartPipelineOperator

Starts a SageMaker pipeline execution.

SageMakerStopPipelineOperator

Stops a SageMaker pipeline execution.

SageMakerRegisterModelVersionOperator

Register a SageMaker model by creating a model version that specifies the model group to which it belongs.

SageMakerAutoMLOperator

Creates an auto ML job, learning to predict the given column from the data provided through S3.

SageMakerCreateExperimentOperator

Creates a SageMaker experiment that can later be associated with jobs.

Functions

serialize(result)

Attributes

DEFAULT_CONN_ID

CHECK_INTERVAL_SECOND

airflow.providers.amazon.aws.operators.sagemaker.DEFAULT_CONN_ID: str = 'aws_default'[source]
airflow.providers.amazon.aws.operators.sagemaker.CHECK_INTERVAL_SECOND: int = 30[source]
airflow.providers.amazon.aws.operators.sagemaker.serialize(result)[source]
class airflow.providers.amazon.aws.operators.sagemaker.SageMakerBaseOperator(*, config, aws_conn_id=DEFAULT_CONN_ID, **kwargs)[source]

Bases: airflow.models.BaseOperator

This is the base operator for all SageMaker operators.

Parameters

config (dict) – The configuration necessary to start a training job (templated)

template_fields: Sequence[str] = ('config',)[source]
template_ext: Sequence[str] = ()[source]
template_fields_renderers: dict[source]
ui_color: str = '#ededed'[source]
integer_fields: list[list[Any]] = [][source]
parse_integer(config, field)[source]

Recursive method for parsing string fields holding integer values to integers.

parse_config_integers()[source]

Parse the integer fields to ints in case the config is rendered by Jinja and all fields are str.
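
A minimal sketch of the idea behind parse_integer and parse_config_integers, assuming each entry of integer_fields is a list of nested keys and that a list value means "apply the remaining path to every element". This is an illustration, not the provider's actual implementation; the sample config is hypothetical.

```python
from typing import Any


def parse_integer(config: Any, field: list) -> None:
    """Convert the value found at the nested key path `field` to int, in place."""
    if isinstance(config, list):
        # A list of sub-configs: apply the same path to every element.
        for item in config:
            parse_integer(item, field)
        return
    head, rest = field[0], field[1:]
    if head not in config:
        return  # optional field that is absent: nothing to convert
    if rest:
        parse_integer(config[head], rest)
    else:
        config[head] = int(config[head])


# Hypothetical rendered config: Jinja has turned every value into a string.
rendered = {
    "ResourceConfig": {"InstanceCount": "2", "InstanceType": "ml.m5.large"},
    "ProductionVariants": [
        {"InitialInstanceCount": "1"},
        {"InitialInstanceCount": "3"},
    ],
}
integer_fields = [
    ["ResourceConfig", "InstanceCount"],
    ["ProductionVariants", "InitialInstanceCount"],
    ["StoppingCondition", "MaxRuntimeInSeconds"],  # absent here, so skipped
]
for path in integer_fields:
    parse_integer(rendered, path)
```

Non-integer strings such as InstanceType are left untouched because only the listed paths are visited.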

expand_role()[source]

Placeholder for calling boto3’s expand_role, which expands an IAM role name into an ARN.

preprocess_config()[source]

Process the config into a usable form.

abstract execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

hook()[source]

Return SageMakerHook.

class airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator(*, config, aws_conn_id=DEFAULT_CONN_ID, wait_for_completion=True, print_log=True, check_interval=CHECK_INTERVAL_SECOND, max_attempts=None, max_ingestion_time=None, action_if_job_exists='timestamp', deferrable=conf.getboolean('operators', 'default_deferrable', fallback=False), **kwargs)[source]

Bases: SageMakerBaseOperator

Use Amazon SageMaker Processing to analyze data and evaluate machine learning models on Amazon SageMaker.

With Processing, you can use a simplified, managed experience on SageMaker to run your data processing workloads, such as feature engineering, data validation, model evaluation, and model interpretation.

See also

For more information on how to use this operator, take a look at the guide: Create an Amazon SageMaker processing job

Parameters
  • config (dict) – The configuration necessary to start a processing job (templated). For details of the configuration parameter see SageMaker.Client.create_processing_job()

  • aws_conn_id (str) – The AWS connection ID to use.

  • wait_for_completion (bool) – Whether the operator should wait until the processing job finishes.

  • print_log (bool) – Whether the operator should print the CloudWatch log during processing.

  • check_interval (int) – If wait_for_completion is True, the interval, in seconds, at which the operator checks the status of the processing job.

  • max_attempts (int | None) – Number of times to poll for query state before returning the current state, defaults to None.

  • max_ingestion_time (int | None) – If wait is set to True, the operation fails if the processing job doesn’t finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not time out.

  • action_if_job_exists (str) – Behaviour if the job name already exists. Possible options are “timestamp” (default), “increment” (deprecated) and “fail”.

  • deferrable (bool) – Run operator in the deferrable mode. This is only effective if wait_for_completion is set to True.

Return Dict

Returns The ARN of the processing job created in Amazon SageMaker.
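
As a hedged illustration, the config dictionary could be assembled as below. The keys are those accepted by SageMaker.Client.create_processing_job(); the helper names, task_id, instance type, and sample values are assumptions made for this sketch. The operator import is deferred so the config helper can be run without Airflow installed.

```python
def build_processing_config(job_name: str, role_arn: str, image_uri: str) -> dict:
    """Return a minimal create_processing_job request (hypothetical values)."""
    return {
        "ProcessingJobName": job_name,
        "ProcessingResources": {
            "ClusterConfig": {
                "InstanceCount": 1,
                "InstanceType": "ml.m5.large",
                "VolumeSizeInGB": 30,
            }
        },
        "AppSpecification": {"ImageUri": image_uri},
        "RoleArn": role_arn,
    }


def make_processing_task(config: dict):
    """Build the operator; requires the Amazon provider to be installed."""
    from airflow.providers.amazon.aws.operators.sagemaker import (
        SageMakerProcessingOperator,
    )

    return SageMakerProcessingOperator(
        task_id="run_processing",  # hypothetical task id
        config=config,
        wait_for_completion=True,
        check_interval=60,
        max_ingestion_time=3600,
        action_if_job_exists="timestamp",
    )


processing_config = build_processing_config(
    job_name="example-processing-job",
    role_arn="arn:aws:iam::123456789012:role/example-sagemaker-role",
    image_uri="123456789012.dkr.ecr.us-east-1.amazonaws.com/example:latest",
)
```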

expand_role()[source]

Expands an IAM role name into an ARN.

execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

execute_complete(context, event=None)[source]
class airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointConfigOperator(*, config, aws_conn_id=DEFAULT_CONN_ID, **kwargs)[source]

Bases: SageMakerBaseOperator

Creates an endpoint configuration that Amazon SageMaker hosting services uses to deploy models.

In the configuration, you identify one or more models, created using the CreateModel API, to deploy and the resources that you want Amazon SageMaker to provision.

See also

For more information on how to use this operator, take a look at the guide: Create an Amazon SageMaker endpoint config job

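Parameters
  • config (dict) –

    The configuration necessary to create an endpoint config.

    For details of the configuration parameter see SageMaker.Client.create_endpoint_config()

  • aws_conn_id (str) – The AWS connection ID to use.

Return Dict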

Returns The ARN of the endpoint config created in Amazon SageMaker.

execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

class airflow.providers.amazon.aws.operators.sagemaker.SageMakerEndpointOperator(*, config, aws_conn_id=DEFAULT_CONN_ID, wait_for_completion=True, check_interval=CHECK_INTERVAL_SECOND, max_ingestion_time=None, operation='create', deferrable=conf.getboolean('operators', 'default_deferrable', fallback=False), **kwargs)[source]

Bases: SageMakerBaseOperator

When you create a serverless endpoint, SageMaker provisions and manages the compute resources for you.

Then, you can make inference requests to the endpoint and receive model predictions in response. SageMaker scales the compute resources up and down as needed to handle your request traffic.

Requires an Endpoint Config.

See also

For more information on how to use this operator, take a look at the guide: Create an Amazon SageMaker endpoint job

Parameters
  • config (dict) –

    The configuration necessary to create an endpoint.

    If you need to create a SageMaker endpoint based on an existing SageMaker model and an existing SageMaker endpoint config:

    config = endpoint_configuration
    

    If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint:

    config = {
        'Model': model_configuration,
        'EndpointConfig': endpoint_config_configuration,
        'Endpoint': endpoint_configuration
    }
    

    For details of the configuration parameter of model_configuration see SageMaker.Client.create_model()

    For details of the configuration parameter of endpoint_config_configuration see SageMaker.Client.create_endpoint_config()

    For details of the configuration parameter of endpoint_configuration see SageMaker.Client.create_endpoint()

  • wait_for_completion (bool) – Whether the operator should wait until the endpoint creation finishes.

  • check_interval (int) – If wait is set to True, this is the time interval, in seconds, that this operation waits before polling the status of the endpoint creation.

  • max_ingestion_time (int | None) – If wait is set to True, this operation fails if the endpoint creation doesn’t finish within max_ingestion_time seconds. If you set this parameter to None it never times out.

  • operation (str) – Whether to create an endpoint or update an endpoint. Must be either ‘create’ or ‘update’.

  • aws_conn_id (str) – The AWS connection ID to use.

  • deferrable (bool) – Will wait asynchronously for completion.

Return Dict

Returns The ARN of the endpoint created in Amazon SageMaker.
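
The combined form of config described above could look like the following sketch. The inner keys are those of create_model(), create_endpoint_config() and create_endpoint(); the function name, variant name, instance type, and sample values are assumptions for illustration only.

```python
def build_full_endpoint_config(
    model_name: str,
    endpoint_config_name: str,
    endpoint_name: str,
    image_uri: str,
    model_data_url: str,
    role_arn: str,
) -> dict:
    """Return a config that creates the model, endpoint config and endpoint."""
    model_configuration = {
        "ModelName": model_name,
        "PrimaryContainer": {"Image": image_uri, "ModelDataUrl": model_data_url},
        "ExecutionRoleArn": role_arn,
    }
    endpoint_config_configuration = {
        "EndpointConfigName": endpoint_config_name,
        "ProductionVariants": [
            {
                "VariantName": "AllTraffic",  # hypothetical variant name
                "ModelName": model_name,
                "InitialInstanceCount": 1,
                "InstanceType": "ml.m5.large",
            }
        ],
    }
    endpoint_configuration = {
        "EndpointName": endpoint_name,
        "EndpointConfigName": endpoint_config_name,
    }
    return {
        "Model": model_configuration,
        "EndpointConfig": endpoint_config_configuration,
        "Endpoint": endpoint_configuration,
    }


full_config = build_full_endpoint_config(
    model_name="example-model",
    endpoint_config_name="example-endpoint-config",
    endpoint_name="example-endpoint",
    image_uri="123456789012.dkr.ecr.us-east-1.amazonaws.com/example:latest",
    model_data_url="s3://example-bucket/model/model.tar.gz",
    role_arn="arn:aws:iam::123456789012:role/example-sagemaker-role",
)
```

When the model and endpoint config already exist, passing only full_config["Endpoint"] as config corresponds to the first form.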

expand_role()[source]

Expands an IAM role name into an ARN.

execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

execute_complete(context, event=None)[source]
class airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator(*, config, aws_conn_id=DEFAULT_CONN_ID, wait_for_completion=True, check_interval=CHECK_INTERVAL_SECOND, max_attempts=None, max_ingestion_time=None, check_if_job_exists=True, action_if_job_exists='timestamp', deferrable=conf.getboolean('operators', 'default_deferrable', fallback=False), **kwargs)[source]

Bases: SageMakerBaseOperator

Starts a transform job.

A transform job uses a trained model to get inferences on a dataset and saves these results to an Amazon S3 location that you specify.

See also

For more information on how to use this operator, take a look at the guide: Create an Amazon SageMaker transform job

Parameters
  • config (dict) –

    The configuration necessary to start a transform job (templated).

    If you need to create a SageMaker transform job based on an existing SageMaker model:

    config = transform_config
    

    If you need to create both SageMaker model and SageMaker Transform job:

    config = {
        'Model': model_config,
        'Transform': transform_config
    }
    

    For details of the configuration parameter of transform_config see SageMaker.Client.create_transform_job()

    For details of the configuration parameter of model_config see SageMaker.Client.create_model()

  • aws_conn_id (str) – The AWS connection ID to use.

  • wait_for_completion (bool) – Set to True to wait until the transform job finishes.

  • check_interval (int) – If wait is set to True, the time interval, in seconds, that this operation waits to check the status of the transform job.

  • max_attempts (int | None) – Number of times to poll for query state before returning the current state, defaults to None.

  • max_ingestion_time (int | None) – If wait is set to True, the operation fails if the transform job doesn’t finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not time out.

  • check_if_job_exists (bool) – If set to true, then the operator will check whether a transform job already exists for the name in the config.

  • action_if_job_exists (str) – Behaviour if the job name already exists. Possible options are “timestamp” (default), “increment” (deprecated) and “fail”. This is only relevant if check_if_job_exists is True.

Return Dict

Returns The ARN of the model created in Amazon SageMaker.
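
A sketch of the two-part config that creates both the model and the transform job. The inner keys come from create_model() and create_transform_job(); the helper name, content type, instance type, and sample values are assumptions.

```python
def build_transform_config(
    job_name: str,
    model_name: str,
    image_uri: str,
    model_data_url: str,
    role_arn: str,
    input_s3_uri: str,
    output_s3_path: str,
) -> dict:
    """Return a config with both 'Model' and 'Transform' sections."""
    model_config = {
        "ModelName": model_name,
        "PrimaryContainer": {"Image": image_uri, "ModelDataUrl": model_data_url},
        "ExecutionRoleArn": role_arn,
    }
    transform_config = {
        "TransformJobName": job_name,
        "ModelName": model_name,  # must match the model created above
        "TransformInput": {
            "DataSource": {
                "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": input_s3_uri}
            },
            "ContentType": "text/csv",
        },
        "TransformOutput": {"S3OutputPath": output_s3_path},
        "TransformResources": {"InstanceType": "ml.m5.large", "InstanceCount": 1},
    }
    return {"Model": model_config, "Transform": transform_config}


transform_job_config = build_transform_config(
    job_name="example-transform-job",
    model_name="example-model",
    image_uri="123456789012.dkr.ecr.us-east-1.amazonaws.com/example:latest",
    model_data_url="s3://example-bucket/model/model.tar.gz",
    role_arn="arn:aws:iam::123456789012:role/example-sagemaker-role",
    input_s3_uri="s3://example-bucket/batch-input/",
    output_s3_path="s3://example-bucket/batch-output/",
)
```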

expand_role()[source]

Expands an IAM role name into an ARN.

execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

execute_complete(context, event=None)[source]
class airflow.providers.amazon.aws.operators.sagemaker.SageMakerTuningOperator(*, config, aws_conn_id=DEFAULT_CONN_ID, wait_for_completion=True, check_interval=CHECK_INTERVAL_SECOND, max_ingestion_time=None, deferrable=conf.getboolean('operators', 'default_deferrable', fallback=False), **kwargs)[source]

Bases: SageMakerBaseOperator

Starts a hyperparameter tuning job.

A hyperparameter tuning job finds the best version of a model by running many training jobs on your dataset using the algorithm you choose and values for hyperparameters within ranges that you specify. It then chooses the hyperparameter values that result in a model that performs the best, as measured by an objective metric that you choose.

See also

For more information on how to use this operator, take a look at the guide: Start a hyperparameter tuning job

Parameters
  • config (dict) –

    The configuration necessary to start a tuning job (templated).

    For details of the configuration parameter see SageMaker.Client.create_hyper_parameter_tuning_job()

  • aws_conn_id (str) – The AWS connection ID to use.

  • wait_for_completion (bool) – Set to True to wait until the tuning job finishes.

  • check_interval (int) – If wait is set to True, the time interval, in seconds, that this operation waits to check the status of the tuning job.

  • max_ingestion_time (int | None) – If wait is set to True, the operation fails if the tuning job doesn’t finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not time out.

  • deferrable (bool) – Will wait asynchronously for completion.

Return Dict

Returns The ARN of the tuning job created in Amazon SageMaker.
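
For orientation, a tuning config might be shaped as follows. The keys follow create_hyper_parameter_tuning_job(); the metric name, hyperparameter name, ranges, and resource values are hypothetical. Note that boto3 expects MinValue and MaxValue as strings.

```python
def build_tuning_config(job_name: str, training_job_definition: dict) -> dict:
    """Return a minimal hyperparameter tuning request (hypothetical values)."""
    return {
        "HyperParameterTuningJobName": job_name,
        "HyperParameterTuningJobConfig": {
            "Strategy": "Bayesian",
            "HyperParameterTuningJobObjective": {
                "Type": "Maximize",
                "MetricName": "validation:auc",  # hypothetical metric
            },
            "ResourceLimits": {
                "MaxNumberOfTrainingJobs": 10,
                "MaxParallelTrainingJobs": 2,
            },
            "ParameterRanges": {
                "ContinuousParameterRanges": [
                    # Range bounds are passed as strings.
                    {"Name": "eta", "MinValue": "0.01", "MaxValue": "0.3"}
                ]
            },
        },
        "TrainingJobDefinition": training_job_definition,
    }


tuning_config = build_tuning_config(
    job_name="example-tuning-job",
    training_job_definition={
        "AlgorithmSpecification": {
            "TrainingImage": "123456789012.dkr.ecr.us-east-1.amazonaws.com/example:latest",
            "TrainingInputMode": "File",
        },
        "RoleArn": "arn:aws:iam::123456789012:role/example-sagemaker-role",
        "OutputDataConfig": {"S3OutputPath": "s3://example-bucket/tuning-output/"},
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.m5.large",
            "VolumeSizeInGB": 30,
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
    },
)
```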

expand_role()[source]

Expands an IAM role name into an ARN.

execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

execute_complete(context, event=None)[source]
class airflow.providers.amazon.aws.operators.sagemaker.SageMakerModelOperator(*, config, aws_conn_id=DEFAULT_CONN_ID, **kwargs)[source]

Bases: SageMakerBaseOperator

Creates a model in Amazon SageMaker.

In the request, you name the model and describe a primary container. For the primary container, you specify the Docker image that contains inference code, artifacts (from prior training), and a custom environment map that the inference code uses when you deploy the model for predictions.

See also

For more information on how to use this operator, take a look at the guide: Create an Amazon SageMaker model

Parameters
  • config (dict) –

    The configuration necessary to create a model.

    For details of the configuration parameter see SageMaker.Client.create_model()

  • aws_conn_id (str) – The AWS connection ID to use.

Return Dict

Returns The ARN of the model created in Amazon SageMaker.

expand_role()[source]

Expands an IAM role name into an ARN.

execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

class airflow.providers.amazon.aws.operators.sagemaker.SageMakerTrainingOperator(*, config, aws_conn_id=DEFAULT_CONN_ID, wait_for_completion=True, print_log=True, check_interval=CHECK_INTERVAL_SECOND, max_attempts=None, max_ingestion_time=None, check_if_job_exists=True, action_if_job_exists='timestamp', deferrable=conf.getboolean('operators', 'default_deferrable', fallback=False), **kwargs)[source]

Bases: SageMakerBaseOperator

Starts a model training job.

After training completes, Amazon SageMaker saves the resulting model artifacts to an Amazon S3 location that you specify.

See also

For more information on how to use this operator, take a look at the guide: Create an Amazon SageMaker training job

Parameters
  • config (dict) –

    The configuration necessary to start a training job (templated).

    For details of the configuration parameter see SageMaker.Client.create_training_job()

  • aws_conn_id (str) – The AWS connection ID to use.

  • wait_for_completion (bool) – Whether the operator should wait until the training job finishes.

  • print_log (bool) – Whether the operator should print the CloudWatch log during training.

  • check_interval (int) – If wait_for_completion is True, the interval, in seconds, at which the operator checks the status of the training job.

  • max_attempts (int | None) – Number of times to poll for query state before returning the current state, defaults to None.

  • max_ingestion_time (int | None) – If wait is set to True, the operation fails if the training job doesn’t finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not time out.

  • check_if_job_exists (bool) – If set to true, then the operator will check whether a training job already exists for the name in the config.

  • action_if_job_exists (str) – Behaviour if the job name already exists. Possible options are “timestamp” (default), “increment” (deprecated) and “fail”. This is only relevant if check_if_job_exists is True.

  • deferrable (bool) – Run operator in the deferrable mode. This is only effective if wait_for_completion is set to True.

Return Dict

Returns The ARN of the training job created in Amazon SageMaker.
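
A hedged example of a training config and of passing it to the operator. The dictionary keys are those of create_training_job(); the helper names, task_id, channel name, and sample values are assumptions. The operator import is deferred so the config helper runs without Airflow installed.

```python
def build_training_config(
    job_name: str, image_uri: str, role_arn: str, train_s3_uri: str, output_s3_path: str
) -> dict:
    """Return a minimal create_training_job request (hypothetical values)."""
    return {
        "TrainingJobName": job_name,
        "AlgorithmSpecification": {
            "TrainingImage": image_uri,
            "TrainingInputMode": "File",
        },
        "RoleArn": role_arn,
        "InputDataConfig": [
            {
                "ChannelName": "train",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": train_s3_uri,
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
            }
        ],
        "OutputDataConfig": {"S3OutputPath": output_s3_path},
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.m5.large",
            "VolumeSizeInGB": 30,
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
    }


def make_training_task(config: dict):
    """Build the operator; requires the Amazon provider to be installed."""
    from airflow.providers.amazon.aws.operators.sagemaker import (
        SageMakerTrainingOperator,
    )

    return SageMakerTrainingOperator(
        task_id="train_model",  # hypothetical task id
        config=config,
        wait_for_completion=True,
        print_log=False,
        check_if_job_exists=True,
        action_if_job_exists="fail",  # do not silently rename the job
    )


training_config = build_training_config(
    job_name="example-training-job",
    image_uri="123456789012.dkr.ecr.us-east-1.amazonaws.com/example:latest",
    role_arn="arn:aws:iam::123456789012:role/example-sagemaker-role",
    train_s3_uri="s3://example-bucket/train/",
    output_s3_path="s3://example-bucket/output/",
)
```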

expand_role()[source]

Expands an IAM role name into an ARN.

execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

execute_complete(context, event=None)[source]
class airflow.providers.amazon.aws.operators.sagemaker.SageMakerDeleteModelOperator(*, config, aws_conn_id=DEFAULT_CONN_ID, **kwargs)[source]

Bases: SageMakerBaseOperator

Deletes a SageMaker model.

See also

For more information on how to use this operator, take a look at the guide: Delete an Amazon SageMaker model

Parameters
  • config (dict) – The configuration necessary to delete the model. For details of the configuration parameter see SageMaker.Client.delete_model()

  • aws_conn_id (str) – The AWS connection ID to use.

execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

class airflow.providers.amazon.aws.operators.sagemaker.SageMakerStartPipelineOperator(*, aws_conn_id=DEFAULT_CONN_ID, pipeline_name, display_name='airflow-triggered-execution', pipeline_params=None, wait_for_completion=False, check_interval=CHECK_INTERVAL_SECOND, verbose=True, **kwargs)[source]

Bases: SageMakerBaseOperator

Starts a SageMaker pipeline execution.

See also

For more information on how to use this operator, take a look at the guide: Start an Amazon SageMaker pipeline execution

Parameters
  • config – The configuration to start the pipeline execution.

  • aws_conn_id (str) – The AWS connection ID to use.

  • pipeline_name (str) – Name of the pipeline to start.

  • display_name (str) – The name this pipeline execution will have in the UI. Doesn’t need to be unique.

  • pipeline_params (dict | None) – Optional parameters for the pipeline. All parameters supplied need to already be present in the pipeline definition.

  • wait_for_completion (bool) – If true, this operator will only complete once the pipeline is complete.

  • check_interval (int) – How long to wait between checks for pipeline status when waiting for completion.

  • verbose (bool) – Whether to print step details when waiting for completion. Defaults to True; consider turning it off for pipelines that have thousands of steps.

Return str

Returns The ARN of the pipeline execution created in Amazon SageMaker.
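
The boto3 call start_pipeline_execution takes PipelineParameters as a list of Name/Value pairs rather than a plain dict. The sketch below shows one way a pipeline_params dict could be converted; the helper name is hypothetical, the str() coercion is a choice made for this example, and the provider's own conversion may differ.

```python
from typing import Optional


def to_pipeline_parameters(pipeline_params: Optional[dict]) -> list:
    """Convert {'name': value} into [{'Name': ..., 'Value': ...}] entries."""
    if not pipeline_params:
        return []
    return [
        {"Name": name, "Value": str(value)}  # boto3 expects string values
        for name, value in pipeline_params.items()
    ]


# Hypothetical parameters; each must already exist in the pipeline definition.
formatted = to_pipeline_parameters({"InputDataUrl": "s3://example-bucket/in/", "Epochs": 5})
```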

template_fields: Sequence[str] = ('aws_conn_id', 'pipeline_name', 'display_name', 'pipeline_params')[source]
execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

class airflow.providers.amazon.aws.operators.sagemaker.SageMakerStopPipelineOperator(*, aws_conn_id=DEFAULT_CONN_ID, pipeline_exec_arn, wait_for_completion=False, check_interval=CHECK_INTERVAL_SECOND, verbose=True, fail_if_not_running=False, **kwargs)[source]

Bases: SageMakerBaseOperator

Stops a SageMaker pipeline execution.

See also

For more information on how to use this operator, take a look at the guide: Stop an Amazon SageMaker pipeline execution

Parameters
  • config – The configuration to start the pipeline execution.

  • aws_conn_id (str) – The AWS connection ID to use.

  • pipeline_exec_arn (str) – Amazon Resource Name of the pipeline execution to stop.

  • wait_for_completion (bool) – If true, this operator will only complete once the pipeline is fully stopped.

  • check_interval (int) – How long to wait between checks for pipeline status when waiting for completion.

  • verbose (bool) – Whether to print step details when waiting for completion. Defaults to True; consider turning it off for pipelines that have thousands of steps.

  • fail_if_not_running (bool) – Whether to raise an exception if the pipeline had already stopped or succeeded before this operator ran.

Return str

Returns the status of the pipeline execution after the operation has been done.

template_fields: Sequence[str] = ('aws_conn_id', 'pipeline_exec_arn')[source]
execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

class airflow.providers.amazon.aws.operators.sagemaker.SageMakerRegisterModelVersionOperator(*, image_uri, model_url, package_group_name, package_group_desc='', package_desc='', model_approval=ApprovalStatus.PENDING_MANUAL_APPROVAL, extras=None, aws_conn_id=DEFAULT_CONN_ID, config=None, **kwargs)[source]

Bases: SageMakerBaseOperator

Register a SageMaker model by creating a model version that specifies the model group to which it belongs.

Will create the model group if it does not exist already.

See also

For more information on how to use this operator, take a look at the guide: Register a Sagemaker Model Version

Parameters
  • image_uri (str) – The Amazon EC2 Container Registry (Amazon ECR) path where inference code is stored.

  • model_url (str) – The Amazon S3 path where the model artifacts (the trained weights of the model), which result from model training, are stored. This path must point to a single gzip compressed tar archive (.tar.gz suffix).

  • package_group_name (str) – The name of the model package group that the model is going to be registered to. Will be created if it doesn’t already exist.

  • package_group_desc (str) – Description of the model package group, if it was to be created (optional).

  • package_desc (str) – Description of the model package (optional).

  • model_approval (airflow.providers.amazon.aws.utils.sagemaker.ApprovalStatus) – Approval status of the model package. Defaults to PendingManualApproval

  • extras (dict | None) – Can contain extra parameters for the boto call to create_model_package, and/or overrides for any parameter defined above. For a complete list of available parameters, see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_model_package

Return str

Returns the ARN of the model package created.
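
To make the role of extras concrete, the following sketch builds a create_model_package request from the operator's arguments and then applies extras as shallow overrides. The key names are real create_model_package parameters, but the function, the default content types, and the merge strategy are assumptions, not the operator's confirmed behaviour.

```python
from typing import Optional


def build_model_package_request(
    image_uri: str,
    model_url: str,
    package_group_name: str,
    package_desc: str = "",
    model_approval: str = "PendingManualApproval",
    extras: Optional[dict] = None,
) -> dict:
    """Assemble a create_model_package request; extras win on key conflicts."""
    request = {
        "ModelPackageGroupName": package_group_name,
        "ModelPackageDescription": package_desc,
        "ModelApprovalStatus": model_approval,
        "InferenceSpecification": {
            "Containers": [{"Image": image_uri, "ModelDataUrl": model_url}],
            "SupportedContentTypes": ["text/csv"],  # assumed default
            "SupportedResponseMIMETypes": ["text/csv"],  # assumed default
        },
    }
    request.update(extras or {})  # shallow merge: top-level keys are replaced
    return request


package_request = build_model_package_request(
    image_uri="123456789012.dkr.ecr.us-east-1.amazonaws.com/example:latest",
    model_url="s3://example-bucket/model/model.tar.gz",
    package_group_name="example-model-group",
    extras={"ModelApprovalStatus": "Approved"},
)
```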

template_fields: Sequence[str] = ('image_uri', 'model_url', 'package_group_name', 'package_group_desc', 'package_desc', 'model_approval')[source]
execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

class airflow.providers.amazon.aws.operators.sagemaker.SageMakerAutoMLOperator(*, job_name, s3_input, target_attribute, s3_output, role_arn, compressed_input=False, time_limit=None, autodeploy_endpoint_name=None, extras=None, wait_for_completion=True, check_interval=30, aws_conn_id=DEFAULT_CONN_ID, config=None, **kwargs)[source]

Bases: SageMakerBaseOperator

Creates an auto ML job, learning to predict the given column from the data provided through S3.

The learning output is written to the specified S3 location.

See also

For more information on how to use this operator, take a look at the guide: Launch an AutoML experiment

Parameters
  • job_name (str) – Name of the job to create, needs to be unique within the account.

  • s3_input (str) – The S3 location (folder or file) where to fetch the data. By default, it expects csv with headers.

  • target_attribute (str) – The name of the column containing the values to predict.

  • s3_output (str) – The S3 folder where to write the model artifacts. Must be 128 characters or fewer.

  • role_arn (str) – The ARN of the IAM role to use when interacting with S3. Must have read access to the input, and write access to the output folder.

  • compressed_input (bool) – Set to True if the input is gzipped.

  • time_limit (int | None) – The maximum amount of time in seconds to spend training the model(s).

  • autodeploy_endpoint_name (str | None) – If specified, the best model will be deployed to an endpoint with that name. No deployment made otherwise.

  • extras (dict | None) – Use this dictionary to set any input variable for job creation that is not offered through the parameters of this function. The format is described in: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.create_auto_ml_job

  • wait_for_completion (bool) – Whether to wait for the job to finish before returning. Defaults to True.

  • check_interval (int) – Interval in seconds between two status checks when waiting for completion.

Returns

Only if waiting for completion, a dictionary detailing the best model. The structure is that of the “BestCandidate” key in: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
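
The parameters above plausibly map onto a create_auto_ml_job request along the following lines. The request keys are real boto3 parameters; the function name, the length check, and the exact mapping are assumptions made for illustration rather than the operator's confirmed logic.

```python
from typing import Optional


def build_auto_ml_request(
    job_name: str,
    s3_input: str,
    target_attribute: str,
    s3_output: str,
    role_arn: str,
    compressed_input: bool = False,
    time_limit: Optional[int] = None,
    autodeploy_endpoint_name: Optional[str] = None,
    extras: Optional[dict] = None,
) -> dict:
    """Assemble a create_auto_ml_job request from the operator-style arguments."""
    if len(s3_output) > 128:
        raise ValueError("s3_output must be 128 characters or fewer")
    channel = {
        "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": s3_input}},
        "TargetAttributeName": target_attribute,
    }
    if compressed_input:
        channel["CompressionType"] = "Gzip"
    request = {
        "AutoMLJobName": job_name,
        "InputDataConfig": [channel],
        "OutputDataConfig": {"S3OutputPath": s3_output},
        "RoleArn": role_arn,
    }
    if time_limit:
        request["AutoMLJobConfig"] = {
            "CompletionCriteria": {"MaxAutoMLJobRuntimeInSeconds": time_limit}
        }
    if autodeploy_endpoint_name:
        request["ModelDeployConfig"] = {"EndpointName": autodeploy_endpoint_name}
    request.update(extras or {})  # anything not covered by the arguments above
    return request


auto_ml_request = build_auto_ml_request(
    job_name="example-automl-job",
    s3_input="s3://example-bucket/automl/input.csv.gz",
    target_attribute="label",
    s3_output="s3://example-bucket/automl/output/",
    role_arn="arn:aws:iam::123456789012:role/example-sagemaker-role",
    compressed_input=True,
    time_limit=1800,
)
```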

template_fields: Sequence[str] = ('job_name', 's3_input', 'target_attribute', 's3_output', 'role_arn', 'compressed_input',...[source]
execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.

class airflow.providers.amazon.aws.operators.sagemaker.SageMakerCreateExperimentOperator(*, name, description=None, tags=None, aws_conn_id=DEFAULT_CONN_ID, **kwargs)[source]

Bases: SageMakerBaseOperator

Creates a SageMaker experiment that can later be associated with jobs.

See also

For more information on how to use this operator, take a look at the guide: Create an Experiment for later use

Parameters
  • name (str) – Name of the experiment. Must be unique within the AWS account.

  • description (str | None) – Description of the experiment (optional).

  • tags (dict | None) – Tags to attach to the experiment (optional).

  • aws_conn_id (str) – The AWS connection ID to use.

Returns

the ARN of the experiment created, though experiments are referred to by name
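
The boto3 create_experiment call expects Tags as a list of Key/Value pairs, whereas this operator accepts a plain dict. A minimal sketch of such a conversion, with a hypothetical helper name and sample tags:

```python
from typing import Optional


def to_boto_tags(tags: Optional[dict]) -> list:
    """Convert {'team': 'ml'} into [{'Key': 'team', 'Value': 'ml'}]."""
    return [{"Key": key, "Value": value} for key, value in (tags or {}).items()]


experiment_tags = to_boto_tags({"team": "ml", "env": "dev"})
```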

template_fields: Sequence[str] = ('name', 'description', 'tags')[source]
execute(context)[source]

This is the main method to derive when creating an operator. Context is the same dictionary used as when rendering jinja templates.

Refer to get_template_context for more context.
