Source code for airflow.providers.amazon.aws.operators.sagemaker_base
## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.importjsonfromtypingimportIterabletry:fromfunctoolsimportcached_propertyexceptImportError:fromcached_propertyimportcached_propertyfromairflow.modelsimportBaseOperatorfromairflow.providers.amazon.aws.hooks.sagemakerimportSageMakerHook
[docs]classSageMakerBaseOperator(BaseOperator):""" This is the base operator for all SageMaker operators. :param config: The configuration necessary to start a training job (templated) :type config: dict :param aws_conn_id: The AWS connection ID to use. :type aws_conn_id: str """
def parse_integer(self, config, field):
    """
    Recursively convert the value stored at a key path from a string to an ``int``.

    The conversion is done in place. Whenever a list is met along the path,
    the same path is applied to each of its elements. A path whose next key
    is absent from ``config`` is skipped silently.

    :param config: The mapping to update, or a list of such mappings.
    :param field: Non-empty sequence of keys forming the path to the value;
        the last key names the entry that is converted.
    :raises IndexError: If ``field`` is empty when a non-list ``config`` is reached.
    :raises ValueError: Propagated from ``int()`` when the stored value
        cannot be parsed as an integer.
    """
    # Lists are transparent to the path: descend into every element with
    # the path unchanged, whatever its remaining length.
    if isinstance(config, list):
        for sub_config in config:
            self.parse_integer(sub_config, field)
        return

    head, tail = field[0], field[1:]
    if head not in config:
        return

    if tail:
        # More keys remain: keep walking down the nested structure.
        self.parse_integer(config[head], tail)
    else:
        # Last key of the path: this is the value to cast.
        config[head] = int(config[head])
def parse_config_integers(self):
    """
    Cast the integer fields of the config to ``int``.

    Needed when the config has been rendered by Jinja and every field is a
    ``str``. Each key path listed in ``self.integer_fields`` is handed to
    :meth:`parse_integer`, which updates ``self.config`` in place.
    """
    for field in self.integer_fields:
        self.parse_integer(self.config, field)
"""Placeholder for calling boto3's `expand_role`, which expands an IAM role name into an ARN."""
def preprocess_config(self):
    """
    Process the config into a usable form.

    In order: calls ``self.hook.configure_s3_resources`` with the config,
    casts the integer fields via :meth:`parse_config_integers`, calls
    :meth:`expand_role`, then logs the resulting config as indented,
    key-sorted JSON.
    """
    self.log.info('Preprocessing the config and doing required s3_operations')

    self.hook.configure_s3_resources(self.config)
    self.parse_config_integers()
    self.expand_role()

    self.log.info(
        "After preprocessing the config is:\n%s",
        json.dumps(self.config, sort_keys=True, indent=4, separators=(",", ": ")),
    )
def execute(self, context):
    """
    Run the operator; this base version always raises.

    :param context: Execution context passed by the caller (unused here).
    :raises NotImplementedError: Always, with a message asking for an
        ``execute()`` implementation in the subclass.
    """
    raise NotImplementedError('Please implement execute() in sub class!')