airflow.providers.amazon.aws.hooks.sagemaker

Module Contents

Classes

LogState

Enum-style class holding all possible states of CloudWatch log streams.

SageMakerHook

Interact with Amazon SageMaker.

Functions

argmin(arr, f: Callable) → Optional[int]

Return the index, i, in arr that minimizes f(arr[i]).

secondary_training_status_changed(current_job_description: dict, prev_job_description: dict) → bool

Return True if the training job's secondary status message has changed.

secondary_training_status_message(job_description: Dict[str, List[dict]], prev_description: Optional[dict]) → str

Return a string containing the start time and the secondary training job status message.

Attributes

Position

class airflow.providers.amazon.aws.hooks.sagemaker.LogState

Enum-style class holding all possible states of CloudWatch log streams. https://sagemaker.readthedocs.io/en/stable/session.html#sagemaker.session.LogState

STARTING = 1
WAIT_IN_PROGRESS = 2
TAILING = 3
JOB_COMPLETE = 4
COMPLETE = 5
airflow.providers.amazon.aws.hooks.sagemaker.Position
airflow.providers.amazon.aws.hooks.sagemaker.argmin(arr, f: Callable) → Optional[int]

Return the index, i, in arr that minimizes f(arr[i]).
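The contract above can be illustrated with a minimal pure-Python sketch. It is not the provider's source: skipping `None` entries, and returning `None` when no entry remains, is an assumption of this sketch, chosen because a merge over log streams can mark an exhausted stream with `None`.

```python
from typing import Any, Callable, Optional, Sequence


def argmin(arr: Sequence[Any], f: Callable[[Any], Any]) -> Optional[int]:
    """Return the index i in arr that minimizes f(arr[i]).

    Assumption of this sketch: None entries are skipped, and None is
    returned when arr has no non-None entries.
    """
    min_value = None
    min_idx = None
    for idx, item in enumerate(arr):
        if item is None:
            continue  # treat None as "no candidate at this position"
        value = f(item)
        # Strict < keeps the earliest index when values tie.
        if min_value is None or value < min_value:
            min_value = value
            min_idx = idx
    return min_idx
```

For example, `argmin([5, None, 2, 7], lambda x: x)` picks index 2; because the comparison is strict, ties resolve to the earliest index.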

airflow.providers.amazon.aws.hooks.sagemaker.secondary_training_status_changed(current_job_description: dict, prev_job_description: dict) → bool

Return True if the training job's secondary status message has changed.

Parameters
  • current_job_description (dict) -- Current job description, returned by a DescribeTrainingJob call.

  • prev_job_description (dict) -- Previous job description, returned by a DescribeTrainingJob call.

Returns

Whether the secondary status message of a training job changed or not.
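A minimal sketch of the comparison, assuming the "secondary status message" is the `StatusMessage` of the last entry in the `SecondaryStatusTransitions` list that DescribeTrainingJob returns. The helper `_last_status_message` is hypothetical, and this sketch also tolerates `prev_job_description=None`, which the documented signature does not advertise.

```python
from typing import Optional


def _last_status_message(description: Optional[dict]) -> str:
    """Hypothetical helper: StatusMessage of the newest transition, or ''."""
    if not description:
        return ""
    transitions = description.get("SecondaryStatusTransitions") or []
    return transitions[-1].get("StatusMessage", "") if transitions else ""


def secondary_training_status_changed(current_job_description: dict,
                                      prev_job_description: Optional[dict]) -> bool:
    # With no transitions in the current description there is nothing new to report.
    if not current_job_description.get("SecondaryStatusTransitions"):
        return False
    # Changed means the newest message differs from the previously seen one.
    return (_last_status_message(current_job_description)
            != _last_status_message(prev_job_description))
```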

airflow.providers.amazon.aws.hooks.sagemaker.secondary_training_status_message(job_description: Dict[str, List[dict]], prev_description: Optional[dict]) → str

Return a string containing the start time and the secondary training job status message.

Parameters
  • job_description (dict) -- Response returned by a DescribeTrainingJob call.

  • prev_description (dict) -- Previous job description returned by a DescribeTrainingJob call, or None.

Returns

Job status string to be printed.
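The following sketch shows one way to build such a message, under assumptions not stated above: only transitions added since `prev_description` are printed (or the latest one if none are new), and the timestamp is the job's `LastModifiedTime` formatted in UTC. The field names `SecondaryStatusTransitions`, `Status`, `StatusMessage` and `LastModifiedTime` come from the DescribeTrainingJob response.

```python
from datetime import datetime, timezone
from typing import Optional


def secondary_training_status_message(job_description: dict,
                                      prev_description: Optional[dict]) -> str:
    current = job_description.get("SecondaryStatusTransitions") or []
    if not current:
        return ""
    prev_count = len((prev_description or {}).get("SecondaryStatusTransitions") or [])
    # Print only transitions not seen before; if nothing is new, repeat the latest one.
    new = current[prev_count:] if len(current) > prev_count else current[-1:]
    # Assumption: the timestamp is the job's LastModifiedTime (a tz-aware datetime), shown in UTC.
    stamp = job_description["LastModifiedTime"].astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    return "\n".join(f"{stamp} {t['Status']} - {t['StatusMessage']}" for t in new)
```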

class airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook(*args, **kwargs)

Bases: airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook

Interact with Amazon SageMaker.

Additional arguments (such as aws_conn_id) may be specified and are passed down to the underlying AwsBaseHook.

See also

AwsBaseHook

non_terminal_states
endpoint_non_terminal_states
failed_states
tar_and_s3_upload(self, path: str, key: str, bucket: str) → None

Tar the local file or directory and upload the archive to S3.

Parameters
  • path (str) -- local file or directory

  • key (str) -- S3 key

  • bucket (str) -- S3 bucket

Returns

None
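The tar-then-upload flow can be sketched with the standard library alone. `tar_and_upload` and its `upload` callback are hypothetical stand-ins: the callback takes the place of whatever S3 client the hook actually uses, so the archiving step can be exercised without AWS credentials. The gzip compression (`w:gz`) and the "top-level entries of a directory" layout are assumptions of this sketch.

```python
import os
import tarfile
import tempfile
from typing import BinaryIO, Callable


def tar_and_upload(path: str, key: str, bucket: str,
                   upload: Callable[[BinaryIO, str, str], None]) -> None:
    """Hypothetical sketch: gzip-tar `path` into a temp file and pass it to `upload`."""
    with tempfile.TemporaryFile() as temp_file:
        # A directory contributes its top-level entries; a single file contributes itself.
        if os.path.isdir(path):
            files = [os.path.join(path, name) for name in sorted(os.listdir(path))]
        else:
            files = [path]
        with tarfile.open(mode="w:gz", fileobj=temp_file) as tar:
            for f in files:
                tar.add(f, arcname=os.path.basename(f))
        temp_file.seek(0)  # rewind so the uploader reads the archive from the start
        upload(temp_file, key, bucket)
```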

configure_s3_resources(self, config: dict) → None

Extract the S3 operations from the configuration and execute them.

Parameters

config (dict) -- config of SageMaker operation

Returns

None
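A minimal sketch of the extraction step. The key names (`S3Operations`, `S3CreateBucket`, `S3Upload`, and the per-operation `Path`, `Key`, `Bucket`, `Tar` fields) reflect our reading of the Airflow provider source and should be verified against the installed version; the `create_bucket` and `upload` callbacks are hypothetical substitutes for the hook's S3 calls. Popping the key keeps it out of the request later sent to SageMaker, where boto3 would reject it as an unknown parameter.

```python
from typing import Callable


def configure_s3_resources(config: dict,
                           create_bucket: Callable[[str], None],
                           upload: Callable[[str, str, str, bool], None]) -> None:
    """Sketch: pop the S3 operations out of `config` and run them in order."""
    # Assumed key name; removed so it is not forwarded to the SageMaker API.
    s3_operations = config.pop("S3Operations", None)
    if s3_operations is None:
        return
    # Buckets are created first so that uploads have a destination.
    for op in s3_operations.get("S3CreateBucket", []):
        create_bucket(op["Bucket"])
    for op in s3_operations.get("S3Upload", []):
        # Tar=True asks for the local path to be archived before uploading.
        upload(op["Path"], op["Key"], op["Bucket"], op.get("Tar", False))
```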

check_s3_url(self, s3url: str) → bool

Check whether an S3 URL exists.

Parameters

s3url (str) -- S3 URL

Return type

bool
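The sketch below shows the URL parsing and two existence checks (bucket first, then key or prefix) under stated assumptions: `parse_s3_url` and `check_s3_url` here are hypothetical local functions, the existence checks are injected callables rather than real S3 calls, and raising `ValueError` on a missing resource (returning `True` otherwise) is only one plausible reading of the documented `bool` return type.

```python
from typing import Callable, Tuple
from urllib.parse import urlparse


def parse_s3_url(s3url: str) -> Tuple[str, str]:
    """Split 's3://bucket/some/key' into ('bucket', 'some/key')."""
    parsed = urlparse(s3url)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not a valid S3 URL: {s3url}")
    return parsed.netloc, parsed.path.lstrip("/")


def check_s3_url(s3url: str,
                 bucket_exists: Callable[[str], bool],
                 key_or_prefix_exists: Callable[[str, str], bool]) -> bool:
    bucket, key = parse_s3_url(s3url)
    if not bucket_exists(bucket):
        raise ValueError(f"The input S3 bucket {bucket} does not exist")
    # A bare bucket URL has an empty key, so only the bucket is checked.
    if key and not key_or_prefix_exists(bucket, key):
        raise ValueError(f"The input S3 key {key} does not exist in bucket {bucket}")
    return True
```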

check_training_config(self, training_config: dict) → None

Check whether a training configuration is valid.

Parameters

training_config (dict) -- the training job configuration to validate

Returns

None
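What "valid" means is not spelled out above. A reasonable sketch, assuming validation consists of confirming that every S3-backed input channel points at an existing location, collects the `S3Uri` of each channel in `InputDataConfig` (the structure used by the SageMaker CreateTrainingJob API) and passes it to an S3 check. `input_s3_uris` is a hypothetical helper, and the `check_s3_url` callable is expected to raise on a missing location.

```python
from typing import Callable, List


def input_s3_uris(training_config: dict) -> List[str]:
    """Hypothetical helper: the S3Uri of every S3-backed input channel."""
    uris = []
    for channel in training_config.get("InputDataConfig", []):
        s3_source = channel.get("DataSource", {}).get("S3DataSource")
        if s3_source is not None:  # skip non-S3 sources such as file systems
            uris.append(s3_source["S3Uri"])
    return uris


def check_training_config(training_config: dict,
                          check_s3_url: Callable[[str], bool]) -> None:
    # Validate each S3 input location; check_s3_url raises if one is missing.
    for uri in input_s3_uris(training_config):
        check_s3_url(uri)
```

For a tuning configuration, the same channels would sit under the `TrainingJobDefinition` key of the CreateHyperParameterTuningJob request, so `input_s3_uris(tuning_config["TrainingJobDefinition"])` would apply the same check.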

check_tuning_config(self, tuning_config: dict) → None

Check whether a tuning configuration is valid.

Parameters

tuning_config (dict) -- the tuning job configuration to validate

Returns

None

get_log_conn(self)

This method is deprecated. Please use airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn() instead.

log_stream(self, log_group, stream_name, start_time=0, skip=0)

This method is deprecated. Please use airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events() instead.

multi_stream_iter(self, log_group: str, streams: list, positions=None) → Generator

Iterate over the available events from a set of log streams in a single log group, interleaving the events from each stream so that they are yielded in timestamp order.

Parameters
  • log_group (str) -- The name of the log group.

  • streams (list) -- A list of the log stream names. The position of the stream in this list is the stream number.

  • positions (list) -- A list of (timestamp, skip) pairs representing the last record read from each stream.

Returns

A generator of (stream number, CloudWatch log event) tuples.
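The interleaving can be sketched as a k-way merge over in-memory event lists: keep the pending head event of each stream and repeatedly yield the one with the smallest `timestamp`. This sketch uses Python's built-in `min` in place of the module's `argmin`, omits the `positions`/skip bookkeeping, and `merge_streams` is a hypothetical name; real events would come from CloudWatch Logs.

```python
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple

Event = Dict[str, Any]


def merge_streams(streams: List[Iterable[Event]]) -> Iterator[Tuple[int, Event]]:
    """Yield (stream number, event) pairs in timestamp order across all streams."""
    iters = [iter(s) for s in streams]
    # Pending head event of each stream; None marks an exhausted stream.
    heads: List[Optional[Event]] = [next(it, None) for it in iters]
    while any(h is not None for h in heads):
        # Choose the live stream whose head has the smallest timestamp
        # (min returns the first such stream on ties).
        i = min((idx for idx, h in enumerate(heads) if h is not None),
                key=lambda idx: heads[idx]["timestamp"])
        yield i, heads[i]
        heads[i] = next(iters[i], None)  # advance only the stream we consumed
```

`heapq.merge` would also produce timestamp order, but it does not report which stream each event came from without extra wrapping.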

create_training_job(self, config: dict, wait_for_completion: bool = True, print_log: bool = True, check_interval: int = 30, max_ingestion_time: Optional[int] = None)

Create a training job.

Parameters
  • config (dict) -- the config for training

  • wait_for_completion (bool) -- whether to keep running until the job finishes

  • print_log (bool) -- whether to display the training job's CloudWatch logs

  • check_interval (int) -- the interval, in seconds, at which to check the status of the SageMaker job

  • max_ingestion_time (int) -- the maximum ingestion time in seconds. Any SageMaker jobs that run longer than this will fail. Setting this to None implies no timeout for any SageMaker job.

Returns

A response to training job creation
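For orientation, `config` is the request body of SageMaker's CreateTrainingJob API. The dictionary below is a hypothetical minimal example: every name, ARN, image URI, and S3 path is a placeholder, and the instance type and limits are arbitrary. The hook call is left commented out because it needs Airflow and AWS credentials.

```python
# Hypothetical minimal CreateTrainingJob request; replace every placeholder value.
training_config = {
    "TrainingJobName": "example-training-job",
    "AlgorithmSpecification": {
        "TrainingImage": "<account>.dkr.ecr.<region>.amazonaws.com/<image>:latest",
        "TrainingInputMode": "File",
    },
    "RoleArn": "arn:aws:iam::123456789012:role/ExampleSageMakerRole",
    "InputDataConfig": [
        {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": "s3://example-bucket/train/",
                    "S3DataDistributionType": "FullyReplicated",
                }
            },
        }
    ],
    "OutputDataConfig": {"S3OutputPath": "s3://example-bucket/output/"},
    "ResourceConfig": {"InstanceType": "ml.m5.large", "InstanceCount": 1, "VolumeSizeInGB": 10},
    "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
}

# With Airflow and an AWS connection configured, the call would look like:
# from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
# SageMakerHook(aws_conn_id="aws_default").create_training_job(
#     training_config, wait_for_completion=True, check_interval=60, max_ingestion_time=7200
# )
```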

create_tuning_job(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30, max_ingestion_time: Optional[int] = None)

Create a tuning job.

Parameters
  • config (dict) -- the config for tuning

  • wait_for_completion (bool) -- whether to keep running until the job finishes

  • check_interval (int) -- the interval, in seconds, at which to check the status of the SageMaker job

  • max_ingestion_time (int) -- the maximum ingestion time in seconds. Any SageMaker jobs that run longer than this will fail. Setting this to None implies no timeout for any SageMaker job.

Returns

A response to tuning job creation

create_transform_job(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30, max_ingestion_time: Optional[int] = None)

Create a transform job.

Parameters
  • config (dict) -- the config for transform job

  • wait_for_completion (bool) -- whether to keep running until the job finishes

  • check_interval (int) -- the interval, in seconds, at which to check the status of the SageMaker job

  • max_ingestion_time (int) -- the maximum ingestion time in seconds. Any SageMaker jobs that run longer than this will fail. Setting this to None implies no timeout for any SageMaker job.

Returns

A response to transform job creation

create_processing_job(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30, max_ingestion_time: Optional[int] = None)

Create a processing job.

Parameters
  • config (dict) -- the config for processing job

  • wait_for_completion (bool) -- whether to keep running until the job finishes

  • check_interval (int) -- the interval, in seconds, at which to check the status of the SageMaker job

  • max_ingestion_time (int) -- the maximum ingestion time in seconds. Any SageMaker jobs that run longer than this will fail. Setting this to None implies no timeout for any SageMaker job.

Returns

A response to processing job creation

create_model(self, config: dict)

Create a model.

Parameters

config (dict) -- the config for model

Returns

A response to model creation

create_endpoint_config(self, config: dict)

Create an endpoint config

Parameters

config (dict) -- the config for endpoint-config

Returns

A response to endpoint config creation

create_endpoint(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30, max_ingestion_time: Optional[int] = None)

Create an endpoint.

Parameters
  • config (dict) -- the config for endpoint

  • wait_for_completion (bool) -- whether to keep running until the endpoint creation finishes

  • check_interval (int) -- the interval, in seconds, at which to check the status of the endpoint

  • max_ingestion_time (int) -- the maximum ingestion time in seconds. Any SageMaker jobs that run longer than this will fail. Setting this to None implies no timeout for any SageMaker job.

Returns

A response to endpoint creation

update_endpoint(self, config: dict, wait_for_completion: bool = True, check_interval: int = 30, max_ingestion_time: Optional[int] = None)

Update an endpoint.

Parameters
  • config (dict) -- the config for endpoint

  • wait_for_completion (bool) -- whether to keep running until the endpoint update finishes

  • check_interval (int) -- the interval, in seconds, at which to check the status of the endpoint

  • max_ingestion_time (int) -- the maximum ingestion time in seconds. Any SageMaker jobs that run longer than this will fail. Setting this to None implies no timeout for any SageMaker job.

Returns

A response to endpoint update

describe_training_job(self, name: str)

Return the training job info associated with the name.

Parameters

name (str) -- the name of the training job

Returns

A dict containing all the training job info

describe_training_job_with_log(self, job_name: str, positions, stream_names: list, instance_count: int, state: int, last_description: dict, last_describe_job_call: float)

Return the training job info associated with job_name and print its CloudWatch logs.

describe_tuning_job(self, name: str) → dict

Return the tuning job info associated with the name.

Parameters

name (str) -- the name of the tuning job

Returns

A dict containing all the tuning job info

describe_model(self, name: str) → dict

Return the SageMaker model info associated with the name.

Parameters

name (str) -- the name of the SageMaker model

Returns

A dict containing all the model info

describe_transform_job(self, name: str) → dict

Return the transform job info associated with the name.

Parameters

name (str) -- the name of the transform job

Returns

A dict containing all the transform job info

describe_processing_job(self, name: str) → dict

Return the processing job info associated with the name.

Parameters

name (str) -- the name of the processing job

Returns

A dict containing all the processing job info

describe_endpoint_config(self, name: str) → dict

Return the endpoint config info associated with the name.

Parameters

name (str) -- the name of the endpoint config

Returns

A dict containing all the endpoint config info

describe_endpoint(self, name: str) → dict

Return the endpoint info associated with the name.

Parameters

name (str) -- the name of the endpoint

Returns

A dict containing all the endpoint info

check_status(self, job_name: str, key: str, describe_function: Callable, check_interval: int, max_ingestion_time: Optional[int] = None, non_terminal_states: Optional[Set] = None)

Check the status of a SageMaker job.

Parameters
  • job_name (str) -- name of the job whose status to check

  • key (str) -- the key of the response dict that points to the state

  • describe_function (python callable) -- the function used to retrieve the status

  • check_interval (int) -- the interval, in seconds, at which to check the status of the SageMaker job

  • max_ingestion_time (int) -- the maximum ingestion time in seconds. Any SageMaker jobs that run longer than this will fail. Setting this to None implies no timeout for any SageMaker job.

  • non_terminal_states (set) -- the set of nonterminal states

Returns

The response of the describe call once the job has finished.
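The polling behaviour described by `check_interval`, `max_ingestion_time`, and `non_terminal_states` can be sketched as a loop around `describe_function`. `poll_until_terminal` is hypothetical, as are its default state sets and its `failed_states` parameter (the hook exposes a `failed_states` attribute instead); elapsed time is approximated by summing the sleep intervals rather than measured.

```python
import time
from typing import Callable, Optional, Set


def poll_until_terminal(job_name: str, key: str,
                        describe_function: Callable[[str], dict],
                        check_interval: float,
                        max_ingestion_time: Optional[float] = None,
                        non_terminal_states: Optional[Set[str]] = None,
                        failed_states: Optional[Set[str]] = None) -> dict:
    # Assumed defaults, modelled on SageMaker job status values.
    if non_terminal_states is None:
        non_terminal_states = {"InProgress", "Stopping"}
    if failed_states is None:
        failed_states = {"Failed"}
    elapsed = 0.0
    while True:
        response = describe_function(job_name)
        status = response[key]
        if status in failed_states:
            reason = response.get("FailureReason", "unknown")
            raise RuntimeError(f"SageMaker job {job_name} failed: {reason}")
        if status not in non_terminal_states:
            return response  # terminal, non-failed state reached
        # None means no timeout, as documented for max_ingestion_time.
        if max_ingestion_time is not None and elapsed >= max_ingestion_time:
            raise TimeoutError(f"SageMaker job {job_name} exceeded {max_ingestion_time}s")
        time.sleep(check_interval)
        elapsed += check_interval
```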

check_training_status_with_log(self, job_name: str, non_terminal_states: set, failed_states: set, wait_for_completion: bool, check_interval: int, max_ingestion_time: Optional[int] = None)

Display the logs for a given training job, optionally tailing them until the job is complete.

Parameters
  • job_name (str) -- name of the training job to check status and display logs for

  • non_terminal_states (set) -- the set of non-terminal states

  • failed_states (set) -- the set of failed states

  • wait_for_completion (bool) -- Whether to keep looking for new log entries until the job completes

  • check_interval (int) -- The interval in seconds between polling for new log entries and job completion

  • max_ingestion_time (int) -- the maximum ingestion time in seconds. Any SageMaker jobs that run longer than this will fail. Setting this to None implies no timeout for any SageMaker job.

Returns

None

list_training_jobs(self, name_contains: Optional[str] = None, max_results: Optional[int] = None, **kwargs) → List[Dict]

This method wraps boto3's list_training_jobs. The training job name filter and the maximum number of results are configurable via arguments; all other arguments should be provided via kwargs. Note that boto3 expects these in CamelCase format, for example:

list_training_jobs(name_contains="myjob", StatusEquals="Failed")
Parameters
  • name_contains -- (optional) partial name to match

  • max_results -- (optional) maximum number of results to return. None returns all results

  • kwargs -- (optional) kwargs to boto3's list_training_jobs method

Returns

results of the list_training_jobs request
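When `max_results` is `None`, returning all results means following boto3's `NextToken` pagination until it runs out. The sketch below shows that loop with an injected `list_fn`; with boto3 it would be a SageMaker client's `list_training_jobs`, whose summaries arrive under `TrainingJobSummaries`. `list_all` is hypothetical and not the hook's internal helper.

```python
from typing import Callable, Dict, List, Optional


def list_all(list_fn: Callable[..., dict], result_key: str,
             max_results: Optional[int] = None, **kwargs) -> List[Dict]:
    """Follow NextToken pages until exhausted or max_results items are collected."""
    results: List[Dict] = []
    next_token: Optional[str] = None
    while True:
        call_kwargs = dict(kwargs)  # CamelCase boto3 arguments pass through unchanged
        if next_token is not None:
            call_kwargs["NextToken"] = next_token
        page = list_fn(**call_kwargs)
        results.extend(page.get(result_key, []))
        if max_results is not None and len(results) >= max_results:
            return results[:max_results]
        next_token = page.get("NextToken")
        if not next_token:  # no token means this was the last page
            return results
```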

list_processing_jobs(self, **kwargs) → List[Dict]

This method wraps boto3's list_processing_jobs. All arguments should be provided via kwargs. Note that boto3 expects these in CamelCase format, for example:

list_processing_jobs(NameContains="myjob", StatusEquals="Failed")
Parameters

kwargs -- (optional) kwargs to boto3's list_processing_jobs method

Returns

results of the list_processing_jobs request

find_processing_job_by_name(self, processing_job_name: str) → bool

Query a processing job by name and return whether it exists.
