Source code for airflow.providers.amazon.aws.hooks.base_aws
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains Base AWS Hook.

.. seealso::
    For more information on how to use this hook, take a look at the guide:
    :ref:`howto/connection:aws`
"""

from __future__ import annotations

import datetime
import inspect
import json
import logging
import os
import warnings
from copy import deepcopy
from functools import cached_property, wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union

import boto3
import botocore
import botocore.session
import jinja2
import requests
import tenacity
from botocore.config import Config
from botocore.waiter import Waiter, WaiterModel
from dateutil.tz import tzlocal
from slugify import slugify

from airflow.configuration import conf
from airflow.exceptions import (
    AirflowException,
    AirflowNotFoundException,
    AirflowProviderDeprecationWarning,
)
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
from airflow.providers.amazon.aws.utils.suppress import return_on_error
from airflow.providers_manager import ProvidersManager
from airflow.utils.helpers import exactly_one
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.secrets_masker import mask_secret
[docs]class BaseSessionFactory(LoggingMixin):
    """Base AWS Session Factory class.

    This handles synchronous and async boto session creation. It can handle most
    of the AWS supported authentication methods.

    Users can also derive from this class to have full control of boto3 session
    creation or to support custom federation.

    .. note::
        Not all features implemented for synchronous sessions are available for async sessions.

    .. seealso::
        - :ref:`howto/connection:aws:session-factory`
    """

    def __init__(
        self,
        conn: Connection | AwsConnectionWrapper | None,
        region_name: str | None = None,
        config: Config | None = None,
    ) -> None:
        super().__init__()
        self._conn = conn
        self._region_name = region_name
        self._config = config

    @cached_property
[docs]    def basic_session(self) -> boto3.session.Session:
        """Cached property with basic boto3.session.Session."""
        return self._create_basic_session(session_kwargs=self.conn.session_kwargs)
[docs]    def create_session(self, deferrable: bool = False) -> boto3.session.Session:
        """Create boto3 or aiobotocore Session from connection config."""
        if not self.conn:
            self.log.info(
                "No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). "
                "See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html",
                self.region_name,
            )
            if deferrable:
                session = self.get_async_session()
                self._apply_session_kwargs(session)
                return session
            else:
                return boto3.session.Session(region_name=self.region_name)
        elif not self.role_arn:
            if deferrable:
                session = self.get_async_session()
                self._apply_session_kwargs(session)
                return session
            else:
                return self.basic_session

        # Values stored in ``AwsConnectionWrapper.session_kwargs`` are intended to be used only
        # to create the initial boto3 session.
        # If the user wants to use the 'assume_role' mechanism then only the 'region_name' needs to be
        # provided, otherwise other parameters might conflict with the base botocore session.
        # Unfortunately it is not a part of public boto3 API, see source of boto3.session.Session:
        # https://boto3.amazonaws.com/v1/documentation/api/latest/_modules/boto3/session.html#Session
        # If we provide 'aws_access_key_id' or 'aws_secret_access_key' or 'aws_session_token'
        # as part of session kwargs it will use them instead of assumed credentials.
        assume_session_kwargs = {}
        if self.conn.region_name:
            assume_session_kwargs["region_name"] = self.conn.region_name
        return self._create_session_with_assume_role(
            session_kwargs=assume_session_kwargs, deferrable=deferrable
        )
    def _create_basic_session(self, session_kwargs: dict[str, Any]) -> boto3.session.Session:
        return boto3.session.Session(**session_kwargs)

    def _create_session_with_assume_role(
        self, session_kwargs: dict[str, Any], deferrable: bool = False
    ) -> boto3.session.Session:
        if self.conn.assume_role_method == "assume_role_with_web_identity":
            # Deferred credentials have no initial credentials
            credential_fetcher = self._get_web_identity_credential_fetcher()

            params = {
                "method": "assume-role-with-web-identity",
                "refresh_using": credential_fetcher.fetch_credentials,
                "time_fetcher": lambda: datetime.datetime.now(tz=tzlocal()),
            }

            if deferrable:
                from aiobotocore.credentials import AioDeferredRefreshableCredentials

                credentials = AioDeferredRefreshableCredentials(**params)
            else:
                credentials = botocore.credentials.DeferredRefreshableCredentials(**params)
        else:
            # Refreshable credentials do have initial credentials
            params = {
                "metadata": self._refresh_credentials(),
                "refresh_using": self._refresh_credentials,
                "method": "sts-assume-role",
            }
            if deferrable:
                from aiobotocore.credentials import AioRefreshableCredentials

                credentials = AioRefreshableCredentials.create_from_metadata(**params)
            else:
                credentials = botocore.credentials.RefreshableCredentials.create_from_metadata(**params)

        if deferrable:
            from aiobotocore.session import get_session as async_get_session

            session = async_get_session()
        else:
            session = botocore.session.get_session()

        session._credentials = credentials
        session.set_config_variable("region", self.basic_session.region_name)

        return boto3.session.Session(botocore_session=session, **session_kwargs)

    def _refresh_credentials(self) -> dict[str, Any]:
        self.log.debug("Refreshing credentials")
        assume_role_method = self.conn.assume_role_method
        if assume_role_method not in ("assume_role", "assume_role_with_saml"):
            raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")

        sts_client = self.basic_session.client(
            "sts",
            config=self.config,
            endpoint_url=self.conn.get_service_endpoint_url("sts", sts_connection_assume=True),
        )

        if assume_role_method == "assume_role":
            sts_response = self._assume_role(sts_client=sts_client)
        else:
            sts_response = self._assume_role_with_saml(sts_client=sts_client)

        sts_response_http_status = sts_response["ResponseMetadata"]["HTTPStatusCode"]
        if sts_response_http_status != 200:
            raise RuntimeError(f"sts_response_http_status={sts_response_http_status}")

        credentials = sts_response["Credentials"]
        expiry_time = credentials.get("Expiration").isoformat()
        self.log.debug("New credentials expiry_time: %s", expiry_time)
        credentials = {
            "access_key": credentials.get("AccessKeyId"),
            "secret_key": credentials.get("SecretAccessKey"),
            "token": credentials.get("SessionToken"),
            "expiry_time": expiry_time,
        }
        return credentials

    def _assume_role(self, sts_client: boto3.client) -> dict:
        kw = {
            "RoleSessionName": self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}"),
            **self.conn.assume_role_kwargs,
            "RoleArn": self.role_arn,
        }
        return sts_client.assume_role(**kw)

    def _assume_role_with_saml(self, sts_client: boto3.client) -> dict[str, Any]:
        saml_config = self.extra_config["assume_role_with_saml"]
        principal_arn = saml_config["principal_arn"]

        idp_auth_method = saml_config["idp_auth_method"]
        if idp_auth_method == "http_spegno_auth":
            saml_assertion = self._fetch_saml_assertion_using_http_spegno_auth(saml_config)
        else:
            raise NotImplementedError(
                f"idp_auth_method={idp_auth_method} in Connection {self.conn.conn_id} Extra."
                'Currently only "http_spegno_auth" is supported, and must be specified.'
            )

        self.log.debug("Doing sts_client.assume_role_with_saml to role_arn=%s", self.role_arn)
        return sts_client.assume_role_with_saml(
            RoleArn=self.role_arn,
            PrincipalArn=principal_arn,
            SAMLAssertion=saml_assertion,
            **self.conn.assume_role_kwargs,
        )

    def _get_idp_response(
        self, saml_config: dict[str, Any], auth: requests.auth.AuthBase
    ) -> requests.models.Response:
        idp_url = saml_config["idp_url"]
        self.log.debug("idp_url= %s", idp_url)

        session = requests.Session()

        # Configurable Retry when querying the IDP endpoint
        if "idp_request_retry_kwargs" in saml_config:
            idp_request_retry_kwargs = saml_config["idp_request_retry_kwargs"]
            self.log.info("idp_request_retry_kwargs= %s", idp_request_retry_kwargs)
            from requests.adapters import HTTPAdapter
            from requests.packages.urllib3.util.retry import Retry

            retry_strategy = Retry(**idp_request_retry_kwargs)
            adapter = HTTPAdapter(max_retries=retry_strategy)
            session.mount("https://", adapter)
            session.mount("http://", adapter)

        idp_request_kwargs = {}
        if "idp_request_kwargs" in saml_config:
            idp_request_kwargs = saml_config["idp_request_kwargs"]

        idp_response = session.get(idp_url, auth=auth, **idp_request_kwargs)
        idp_response.raise_for_status()

        return idp_response

    def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: dict[str, Any]) -> str:
        # requests_gssapi will need paramiko > 2.6 since you'll need
        # 'gssapi' not 'python-gssapi' from PyPi.
        # https://github.com/paramiko/paramiko/pull/1311
        import requests_gssapi
        from lxml import etree

        auth = requests_gssapi.HTTPSPNEGOAuth()
        if "mutual_authentication" in saml_config:
            mutual_auth = saml_config["mutual_authentication"]
            if mutual_auth == "REQUIRED":
                auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.REQUIRED)
            elif mutual_auth == "OPTIONAL":
                auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.OPTIONAL)
            elif mutual_auth == "DISABLED":
                auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.DISABLED)
            else:
                raise NotImplementedError(
                    f"mutual_authentication={mutual_auth} in Connection {self.conn.conn_id} Extra."
                    'Currently "REQUIRED", "OPTIONAL" and "DISABLED" are supported.'
                    "(Exclude this setting will default to HTTPSPNEGOAuth() )."
                )
        # Query the IDP
        idp_response = self._get_idp_response(saml_config, auth=auth)
        # Assist with debugging. Note: contains sensitive info!
        xpath = saml_config["saml_response_xpath"]
        log_idp_response = "log_idp_response" in saml_config and saml_config["log_idp_response"]
        if log_idp_response:
            self.log.warning(
                "The IDP response contains sensitive information, but log_idp_response is ON (%s).",
                log_idp_response,
            )
            self.log.debug("idp_response.content= %s", idp_response.content)
            self.log.debug("xpath= %s", xpath)
        # Extract SAML Assertion from the returned HTML / XML
        xml = etree.fromstring(idp_response.content)
        saml_assertion = xml.xpath(xpath)
        if isinstance(saml_assertion, list):
            if len(saml_assertion) == 1:
                saml_assertion = saml_assertion[0]
        if not saml_assertion:
            raise ValueError("Invalid SAML Assertion")
        return saml_assertion

    def _get_web_identity_credential_fetcher(
        self,
    ) -> botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher:
        base_session = self.basic_session._session or botocore.session.get_session()
        client_creator = base_session.create_client
        federation = str(self.extra_config.get("assume_role_with_web_identity_federation"))

        web_identity_token_loader = {
            "file": self._get_file_token_loader,
            "google": self._get_google_identity_token_loader,
        }.get(federation)

        if not web_identity_token_loader:
            raise AirflowException(f"Unsupported federation: {federation}.")

        return botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher(
            client_creator=client_creator,
            web_identity_token_loader=web_identity_token_loader(),
            role_arn=self.role_arn,
            extra_args=self.conn.assume_role_kwargs,
        )

    def _get_file_token_loader(self):
        from botocore.credentials import FileWebIdentityTokenLoader

        token_file = self.extra_config.get("assume_role_with_web_identity_token_file") or os.getenv(
            "AWS_WEB_IDENTITY_TOKEN_FILE"
        )

        return FileWebIdentityTokenLoader(token_file)

    def _get_google_identity_token_loader(self):
        from google.auth.transport import requests as requests_transport

        from airflow.providers.google.common.utils.id_token_credentials import (
            get_default_id_token_credentials,
        )

        audience = self.extra_config.get("assume_role_with_web_identity_federation_audience")

        google_id_token_credentials = get_default_id_token_credentials(target_audience=audience)

        def web_identity_token_loader():
            if not google_id_token_credentials.valid:
                request_adapter = requests_transport.Request()
                google_id_token_credentials.refresh(request=request_adapter)
            return google_id_token_credentials.token

        return web_identity_token_loader

    def _strip_invalid_session_name_characters(self, role_session_name: str) -> str:
        return slugify(role_session_name, regex_pattern=r"[^\w+=,.@-]+")
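
# --- Illustrative example (not part of the upstream module) -------------------------------
# A minimal sketch of a custom session factory, assuming you want the full control over
# boto3 session creation that the class docstring above describes. The profile name used
# here is a hypothetical placeholder. A class like this is selected via the
# ``session_factory`` option in the ``[aws]`` section of the Airflow configuration
# (see ``resolve_session_factory`` further below).
class _ExampleProfileSessionFactory(BaseSessionFactory):
    def create_session(self, deferrable: bool = False) -> boto3.session.Session:
        # Ignore the stored connection and always build the session from a local AWS profile.
        return boto3.session.Session(profile_name="example-profile", region_name=self.region_name)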
[docs]class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
    """Generic class for interacting with AWS.

    This class provides a thin wrapper around the boto3 Python library.

    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is None or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
    :param verify: Whether or not to verify SSL certificates. See:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param client_type: Reference to :external:py:meth:`boto3.client service_name \
        <boto3.session.Session.client>`, e.g. 'emr', 'batch', 's3', etc.
        Mutually exclusive with ``resource_type``.
    :param resource_type: Reference to :external:py:meth:`boto3.resource service_name \
        <boto3.session.Session.resource>`, e.g. 's3', 'ec2', 'dynamodb', etc.
        Mutually exclusive with ``client_type``.
    :param config: Configuration for botocore client. See:
        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
    """
    def __init__(
        self,
        aws_conn_id: str | None = default_conn_name,
        verify: bool | str | None = None,
        region_name: str | None = None,
        client_type: str | None = None,
        resource_type: str | None = None,
        config: Config | dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.aws_conn_id = aws_conn_id
        self.client_type = client_type
        self.resource_type = resource_type

        self._region_name = region_name
        if isinstance(config, dict):
            config = Config(**config)
        self._config = config
        self._verify = verify

    @classmethod
    @return_on_error("Unknown")
    def _get_provider_version(cls) -> str:
        """Check the Providers Manager for the package version."""
        manager = ProvidersManager()
        hook = manager.hooks[cls.conn_type]
        if not hook:
            # This gets caught immediately, but without it MyPy complains
            # Item "None" of "Optional[HookInfo]" has no attribute "package_name"
            # on the following line and static checks fail.
            raise ValueError(f"Hook info for {cls.conn_type} not found in the Provider Manager.")
        return manager.providers[hook.package_name].version

    @staticmethod
    def _find_class_name(target_function_name: str) -> str:
        """Given a frame off the stack, return the name of the class that made the call.

        This method may raise a ValueError or an IndexError. The caller is
        responsible for catching and handling those.
        """
        stack = inspect.stack()
        # Find the index of the most recent frame which called the provided function name
        # and pull that frame off the stack.
        target_frame = next(frame for frame in stack if frame.function == target_function_name)[0]
        # Get the local variables for that frame.
        frame_variables = target_frame.f_locals["self"]
        # Get the class object for that frame.
        frame_class_object = frame_variables.__class__
        # Return the name of the class object.
        return frame_class_object.__name__

    @return_on_error("Unknown")
    def _get_caller(self, target_function_name: str = "execute") -> str:
        """Given a function name, walk the stack and return the name of the class which called it last."""
        caller = self._find_class_name(target_function_name)
        if caller == "BaseSensorOperator":
            # If the result is a BaseSensorOperator, then look for whatever last called "poke".
            return self._get_caller("poke")
        return caller

    @staticmethod
    @return_on_error("00000000-0000-0000-0000-000000000000")
    def _generate_dag_key() -> str:
        """Generate a DAG key.

        The Object Identifier (OID) namespace is used to salt the dag_id value.
        That salted value is used to generate a SHA-1 hash which, by definition,
        can not (reasonably) be reversed.  No personal data can be inferred or
        extracted from the resulting UUID.
        """
        return generate_uuid(os.environ.get("AIRFLOW_CTX_DAG_ID"))

    @staticmethod
    @return_on_error("Unknown")
    def _get_airflow_version() -> str:
        """Fetch and return the current Airflow version."""
        # This can be a circular import under specific configurations.
        # Importing locally to either avoid or catch it if it does happen.
        from airflow import __version__ as airflow_version

        return airflow_version

    def _generate_user_agent_extra_field(self, existing_user_agent_extra: str) -> str:
        user_agent_extra_values = [
            f"Airflow/{self._get_airflow_version()}",
            f"AmPP/{self._get_provider_version()}",
            f"Caller/{self._get_caller()}",
            f"DagRunKey/{self._generate_dag_key()}",
            existing_user_agent_extra or "",
        ]
        return " ".join(user_agent_extra_values).strip()

    @cached_property
[docs]    def conn_config(self) -> AwsConnectionWrapper:
        """Get the Airflow Connection object and wrap it in helper (cached)."""
        connection = None
        if self.aws_conn_id:
            try:
                connection = self.get_connection(self.aws_conn_id)
            except AirflowNotFoundException:
                self.log.warning(
                    "Unable to find AWS Connection ID '%s', switching to empty.", self.aws_conn_id
                )

        return AwsConnectionWrapper(
            conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
        )
    def _resolve_service_name(self, is_resource_type: bool = False) -> str:
        """Resolve service name based on type or raise an error."""
        if exactly_one(self.client_type, self.resource_type):
            # It is possible to write simpler conditions, however it makes mypy unhappy.
            if self.client_type:
                if is_resource_type:
                    raise LookupError("Requested `resource_type`, but `client_type` was set instead.")
                return self.client_type
            elif self.resource_type:
                if not is_resource_type:
                    raise LookupError("Requested `client_type`, but `resource_type` was set instead.")
                return self.resource_type

        raise ValueError(
            f"Either client_type={self.client_type!r} or "
            f"resource_type={self.resource_type!r} must be provided, not both."
        )

    @property
[docs]    def service_name(self) -> str:
        """Extracted botocore/boto3 service name from hook parameters."""
        return self._resolve_service_name(is_resource_type=bool(self.resource_type))

    @property
[docs]    def service_config(self) -> dict:
        """Config for hook-specific service from AWS Connection."""
        return self.conn_config.get_service_config(service_name=self.service_name)

    @property
[docs]    def region_name(self) -> str | None:
        """AWS Region Name (read-only property)."""
        return self.conn_config.region_name

    @property
[docs]    def config(self) -> Config:
        """Configuration for the botocore client (read-only property)."""
        return self.conn_config.botocore_config or botocore.config.Config()

    @property
[docs]    def verify(self) -> bool | str | None:
        """Whether or not to verify SSL certificates for the boto3 client/resource (read-only property)."""
        return self.conn_config.verify
[docs]    def get_session(self, region_name: str | None = None, deferrable: bool = False) -> boto3.session.Session:
        """Get the underlying boto3.session.Session(region_name=region_name)."""
        return SessionFactory(
            conn=self.conn_config, region_name=region_name, config=self.config
        ).create_session(deferrable=deferrable)
    def _get_config(self, config: Config | None = None) -> Config:
        """
        No AWS Operators use the config argument to this method.

        Keep backward compatibility with other users who might use it.
        """
        if config is None:
            config = deepcopy(self.config)

        # ignore[union-attr] is required for this block to appease MyPy
        # because the user_agent_extra field is generated at runtime.
        user_agent_config = Config(
            user_agent_extra=self._generate_user_agent_extra_field(
                existing_user_agent_extra=config.user_agent_extra  # type: ignore[union-attr]
            )
        )
        return config.merge(user_agent_config)  # type: ignore[union-attr]
[docs]    def get_client_type(
        self,
        region_name: str | None = None,
        config: Config | None = None,
        deferrable: bool = False,
    ) -> boto3.client:
        """Get the underlying boto3 client using boto3 session."""
        service_name = self._resolve_service_name(is_resource_type=False)
        session = self.get_session(region_name=region_name, deferrable=deferrable)
        endpoint_url = self.conn_config.get_service_endpoint_url(service_name=service_name)
        if not isinstance(session, boto3.session.Session):
            return session.create_client(
                service_name=service_name,
                endpoint_url=endpoint_url,
                config=self._get_config(config),
                verify=self.verify,
            )

        return session.client(
            service_name=service_name,
            endpoint_url=endpoint_url,
            config=self._get_config(config),
            verify=self.verify,
        )
[docs]    def get_resource_type(
        self,
        region_name: str | None = None,
        config: Config | None = None,
    ) -> boto3.resource:
        """Get the underlying boto3 resource using boto3 session."""
        service_name = self._resolve_service_name(is_resource_type=True)
        session = self.get_session(region_name=region_name)
        return session.resource(
            service_name=service_name,
            endpoint_url=self.conn_config.get_service_endpoint_url(service_name=service_name),
            config=self._get_config(config),
            verify=self.verify,
        )
    @cached_property
[docs]    def conn(self) -> BaseAwsConnection:
        """
        Get the underlying boto3 client/resource (cached).

        :return: boto3.client or boto3.resource
        """
        if self.client_type:
            return self.get_client_type(region_name=self.region_name)
        return self.get_resource_type(region_name=self.region_name)
    @property
[docs]    def async_conn(self):
        """Get an aiobotocore client to use for async operations."""
        if not self.client_type:
            raise ValueError("client_type must be specified.")

        return self.get_client_type(region_name=self.region_name, deferrable=True)
[docs]    def conn_client_meta(self) -> ClientMeta:
        """Get botocore client metadata from Hook connection (cached)."""
        return self._client.meta
    @property
[docs]    def conn_region_name(self) -> str:
        """Get actual AWS Region Name from Hook connection (cached)."""
        return self.conn_client_meta.region_name
    @property
[docs]    def conn_partition(self) -> str:
        """Get associated AWS Region Partition from Hook connection (cached)."""
        return self.conn_client_meta.partition
[docs]    def get_conn(self) -> BaseAwsConnection:
        """
        Get the underlying boto3 client/resource (cached).

        Implemented so that caching works as intended. It exists for compatibility
        with subclasses that rely on a super().get_conn() method.

        :return: boto3.client or boto3.resource
        """
        # Compat shim
        return self.conn
[docs]    def get_credentials(self, region_name: str | None = None) -> ReadOnlyCredentials:
        """
        Get the underlying `botocore.Credentials` object.

        This contains the following authentication attributes: access_key, secret_key and token.
        By using this method, the secret_key and token are also masked in task logs.
        """
        # Credentials are refreshable, so accessing your access key and
        # secret key separately can lead to a race condition.
        # See https://stackoverflow.com/a/36291428/8283373
        creds = self.get_session(region_name=region_name).get_credentials().get_frozen_credentials()
        mask_secret(creds.secret_key)
        if creds.token:
            mask_secret(creds.token)
        return creds
[docs]    def expand_role(self, role: str, region_name: str | None = None) -> str:
        """Get the Amazon Resource Name (ARN) for the role.

        If the IAM role is already an IAM role ARN, the value is returned unchanged.

        :param role: IAM role name or ARN
        :param region_name: Optional region name to get credentials for
        :return: IAM role ARN
        """
        if "/" in role:
            return role
        else:
            session = self.get_session(region_name=region_name)
            _client = session.client(
                service_name="iam",
                endpoint_url=self.conn_config.get_service_endpoint_url("iam"),
                config=self.config,
                verify=self.verify,
            )
            return _client.get_role(RoleName=role)["Role"]["Arn"]
    @staticmethod
[docs]    def retry(should_retry: Callable[[Exception], bool]):
        """Repeat requests in response to exceeding a temporary quota limit."""

        def retry_decorator(fun: Callable):
            @wraps(fun)
            def decorator_f(self, *args, **kwargs):
                retry_args = getattr(self, "retry_args", None)
                if retry_args is None:
                    return fun(self, *args, **kwargs)
                multiplier = retry_args.get("multiplier", 1)
                min_limit = retry_args.get("min", 1)
                max_limit = retry_args.get("max", 1)
                stop_after_delay = retry_args.get("stop_after_delay", 10)
                tenacity_before_logger = tenacity.before_log(self.log, logging.INFO) if self.log else None
                tenacity_after_logger = tenacity.after_log(self.log, logging.INFO) if self.log else None
                default_kwargs = {
                    "wait": tenacity.wait_exponential(multiplier=multiplier, max=max_limit, min=min_limit),
                    "retry": tenacity.retry_if_exception(should_retry),
                    "stop": tenacity.stop_after_delay(stop_after_delay),
                    "before": tenacity_before_logger,
                    "after": tenacity_after_logger,
                }
                return tenacity.retry(**default_kwargs)(fun)(self, *args, **kwargs)

            return decorator_f

        return retry_decorator
    @staticmethod
[docs]    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom UI field behaviour for AWS Connection."""
        return {
            "hidden_fields": ["host", "schema", "port"],
            "relabeling": {
                "login": "AWS Access Key ID",
                "password": "AWS Secret Access Key",
            },
            "placeholders": {
                "login": "AKIAIOSFODNN7EXAMPLE",
                "password": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
                "extra": json.dumps(
                    {
                        "region_name": "us-east-1",
                        "session_kwargs": {"profile_name": "default"},
                        "config_kwargs": {"retries": {"mode": "standard", "max_attempts": 10}},
                        "role_arn": "arn:aws:iam::123456789098:role/role-name",
                        "assume_role_method": "assume_role",
                        "assume_role_kwargs": {"RoleSessionName": "airflow"},
                        "aws_session_token": "AQoDYXdzEJr...EXAMPLETOKEN",
                        "endpoint_url": "http://localhost:4566",
                    },
                    indent=2,
                ),
            },
        }
[docs]    def test_connection(self):
        """Test the AWS connection by calling the AWS STS (Security Token Service) GetCallerIdentity API.

        .. seealso::
            https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
        """
        try:
            session = self.get_session()
            conn_info = session.client(
                service_name="sts",
                endpoint_url=self.conn_config.get_service_endpoint_url("sts", sts_test_connection=True),
            ).get_caller_identity()
            metadata = conn_info.pop("ResponseMetadata", {})
            if metadata.get("HTTPStatusCode") != 200:
                try:
                    return False, json.dumps(metadata)
                except TypeError:
                    return False, str(metadata)
            conn_info["credentials_method"] = session.get_credentials().method
            conn_info["region_name"] = session.region_name
            return True, ", ".join(f"{k}={v!r}" for k, v in conn_info.items())
        except Exception as e:
            return False, f"{type(e).__name__!r} error occurred while testing connection: {e}"
[docs]    def get_waiter(
        self,
        waiter_name: str,
        parameters: dict[str, str] | None = None,
        deferrable: bool = False,
        client=None,
    ) -> Waiter:
        """Get a waiter by name.

        First checks if there is a custom waiter with the provided waiter_name and
        uses that if it exists, otherwise it will check the service client for a
        waiter that matches the name and pass that through.

        If `deferrable` is True, the waiter will be an AIOWaiter, generated from the
        client that is passed as a parameter. If `deferrable` is True, `client` must be
        provided.

        :param waiter_name: The name of the waiter.  The name should exactly match the
            name of the key in the waiter model file (typically this is CamelCase).
        :param parameters: will scan the waiter config for the keys of that dict, and replace them with the
            corresponding value. If a custom waiter has such keys to be expanded, they need to be provided
            here.
        :param deferrable: If True, the waiter is going to be an async custom waiter.
            An async client must be provided in that case.
        :param client: The client to use for the waiter's operations
        """
        from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter

        if deferrable and not client:
            raise ValueError("client must be provided for a deferrable waiter.")
        # Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
        client = client or self._client
        if self.waiter_path and (waiter_name in self._list_custom_waiters()):
            # Technically if waiter_name is in custom_waiters then self.waiter_path must
            # exist but MyPy doesn't like the fact that self.waiter_path could be None.
            with open(self.waiter_path) as config_file:
                config = json.loads(config_file.read())

            config = self._apply_parameters_value(config, waiter_name, parameters)
            return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
                waiter_name
            )
        # If there is no custom waiter found for the provided name,
        # then try checking the service's official waiters.
        return client.get_waiter(waiter_name)
    @staticmethod
    def _apply_parameters_value(config: dict, waiter_name: str, parameters: dict[str, str] | None) -> dict:
        """Replace potential jinja templates in acceptors definition."""
        # Only process the waiter we're going to use, to not raise errors for missing params for other waiters.
        acceptors = config["waiters"][waiter_name]["acceptors"]
        for a in acceptors:
            arg = a["argument"]
            template = jinja2.Template(arg, autoescape=False, undefined=jinja2.StrictUndefined)
            try:
                a["argument"] = template.render(parameters or {})
            except jinja2.UndefinedError as e:
                raise AirflowException(
                    f"Parameter was not supplied for templated waiter's acceptor '{arg}'", e
                )
        return config
[docs]    def list_waiters(self) -> list[str]:
        """Return a list containing the names of all waiters for the service, official and custom."""
        return [*self._list_official_waiters(), *self._list_custom_waiters()]
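
# --- Illustrative example (not part of the upstream module) -------------------------------
# A minimal usage sketch of the generic hook above, assuming an Airflow connection named
# "aws_default" exists. The waiter name is one of the official boto3 EMR waiters; treat the
# specific service and waiter choices as placeholders.
def _example_generic_hook_usage():
    hook = AwsGenericHook(aws_conn_id="aws_default", client_type="emr")
    # ``conn`` is the cached boto3 client resolved from ``client_type``.
    emr_client = hook.conn
    # ``get_waiter`` prefers a custom waiter definition when one exists for the name,
    # otherwise it falls back to the service's official boto3 waiters.
    waiter = hook.get_waiter("cluster_running")
    return emr_client, waiter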
[docs]class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
    """Base class for interacting with AWS.

    This class provides a thin wrapper around the boto3 Python library.

    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is None or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
    :param verify: Whether or not to verify SSL certificates. See:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param client_type: Reference to :external:py:meth:`boto3.client service_name \
        <boto3.session.Session.client>`, e.g. 'emr', 'batch', 's3', etc.
        Mutually exclusive with ``resource_type``.
    :param resource_type: Reference to :external:py:meth:`boto3.resource service_name \
        <boto3.session.Session.resource>`, e.g. 's3', 'ec2', 'dynamodb', etc.
        Mutually exclusive with ``client_type``.
    :param config: Configuration for botocore client. See:
        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
    """
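
# --- Illustrative example (not part of the upstream module) -------------------------------
# A minimal sketch of the concrete ``AwsBaseHook``, assuming a connection named "aws_default"
# and an IAM role named "example-role"; both names are hypothetical placeholders.
def _example_base_hook_usage():
    hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")
    # ``test_connection`` calls STS GetCallerIdentity and returns a (success, message) tuple.
    ok, message = hook.test_connection()
    # ``get_credentials`` returns frozen read-only credentials; secret values are masked in task logs.
    credentials = hook.get_credentials()
    # ``expand_role`` resolves a plain role name to its full ARN via the IAM API.
    role_arn = hook.expand_role("example-role")
    return ok, message, credentials, role_arn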
[docs]def resolve_session_factory() -> type[BaseSessionFactory]:
    """Resolve custom SessionFactory class."""
    clazz = conf.getimport("aws", "session_factory", fallback=None)
    if not clazz:
        return BaseSessionFactory
    if not issubclass(clazz, BaseSessionFactory):
        raise TypeError(
            f"Your custom AWS SessionFactory class `{clazz.__name__}` is not a subclass "
            f"of `{BaseSessionFactory.__name__}`."
        )
    return clazz
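
# Note (assumed wiring, not shown in this excerpt): the resolved factory class is bound to a
# module-level name at import time, which is what ``AwsGenericHook.get_session`` above
# instantiates. A minimal sketch of that binding:
SessionFactory = resolve_session_factory()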
def _parse_s3_config(
    config_file_name: str, config_format: str | None = "boto", profile: str | None = None
):
    """For compatibility with airflow.contrib.hooks.aws_hook."""
    from airflow.providers.amazon.aws.utils.connection_wrapper import _parse_s3_config

    return _parse_s3_config(
        config_file_name=config_file_name,
        config_format=config_format,
        profile=profile,
    )


try:
    import aiobotocore.credentials
    from aiobotocore.session import AioSession, get_session
except ImportError:
    pass
[docs]class BaseAsyncSessionFactory(BaseSessionFactory):
    """
    Base AWS Session Factory class to handle aiobotocore session creation.

    It currently handles ENV, AWS secret key, and the STS client method ``assume_role``
    provided in the Airflow connection.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "airflow.providers.amazon.aws.hook.base_aws.BaseAsyncSessionFactory has been deprecated and "
            "will be removed in future",
            AirflowProviderDeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
[docs]    async def get_role_credentials(self) -> dict:
        """Get the role_arn, method credentials from connection and get the role credentials."""
        async with self._basic_session.create_client("sts", region_name=self.region_name) as client:
            response = await client.assume_role(
                RoleArn=self.role_arn,
                RoleSessionName=self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}"),
                **self.conn.assume_role_kwargs,
            )
            return response["Credentials"]
    async def _get_refresh_credentials(self) -> dict[str, Any]:
        self.log.debug("Refreshing credentials")
        assume_role_method = self.conn.assume_role_method
        if assume_role_method != "assume_role":
            raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")

        credentials = await self.get_role_credentials()

        expiry_time = credentials["Expiration"].isoformat()
        self.log.debug("New credentials expiry_time: %s", expiry_time)
        credentials = {
            "access_key": credentials.get("AccessKeyId"),
            "secret_key": credentials.get("SecretAccessKey"),
            "token": credentials.get("SessionToken"),
            "expiry_time": expiry_time,
        }
        return credentials

    def _get_session_with_assume_role(self) -> AioSession:
        assume_role_method = self.conn.assume_role_method
        if assume_role_method != "assume_role":
            raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")

        credentials = aiobotocore.credentials.AioRefreshableCredentials.create_from_metadata(
            metadata=self._get_refresh_credentials(),
            refresh_using=self._get_refresh_credentials,
            method="sts-assume-role",
        )

        session = aiobotocore.session.get_session()
        session._credentials = credentials
        return session

    @cached_property
    def _basic_session(self) -> AioSession:
        """Cached property with basic aiobotocore.session.AioSession."""
        session_kwargs = self.conn.session_kwargs
        aws_access_key_id = session_kwargs.get("aws_access_key_id")
        aws_secret_access_key = session_kwargs.get("aws_secret_access_key")
        aws_session_token = session_kwargs.get("aws_session_token")
        region_name = session_kwargs.get("region_name")
        profile_name = session_kwargs.get("profile_name")

        aio_session = get_session()

        if profile_name is not None:
            aio_session.set_config_variable("profile", profile_name)

        if aws_access_key_id or aws_secret_access_key or aws_session_token:
            aio_session.set_credentials(
                access_key=aws_access_key_id,
                secret_key=aws_secret_access_key,
                token=aws_session_token,
            )

        if region_name is not None:
            aio_session.set_config_variable("region", region_name)

        return aio_session
[docs]    def create_session(self, deferrable: bool = False) -> AioSession:
        """Create aiobotocore Session from connection and config."""
        if not self._conn:
            self.log.info("No connection ID provided. Fallback on boto3 credential strategy")
            return get_session()
        elif not self.role_arn:
            return self._basic_session
        return self._get_session_with_assume_role()
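
# --- Illustrative example (not part of the upstream module) -------------------------------
# Since ``BaseAsyncSessionFactory`` is deprecated, an async-capable session can instead be
# obtained from ``BaseSessionFactory`` by passing ``deferrable=True``. The connection wrapper
# argument and region below are placeholders for whatever ``AwsGenericHook.conn_config`` resolves.
def _example_deferrable_session(conn_wrapper: AwsConnectionWrapper):
    factory = BaseSessionFactory(conn=conn_wrapper, region_name="us-east-1")
    # With ``deferrable=True`` an aiobotocore session is created rather than a boto3 one.
    return factory.create_session(deferrable=True)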
[docs]class AwsBaseAsyncHook(AwsBaseHook):
    """Interacts with AWS using aiobotocore asynchronously.

    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is None or empty then the default botocore behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default botocore configuration would be used (and must be
        maintained on each worker node).
    :param verify: Whether to verify SSL certificates.
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param client_type: boto3.client client_type, e.g. 's3', 'emr', etc.
    :param resource_type: boto3.resource resource_type, e.g. 'dynamodb', etc.
    :param config: Configuration for botocore client.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "airflow.providers.amazon.aws.hook.base_aws.AwsBaseAsyncHook has been deprecated and "
            "will be removed in future",
            AirflowProviderDeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
[docs]    def get_async_session(self) -> AioSession:
        """Get the underlying aiobotocore.session.AioSession(...)."""
        return BaseAsyncSessionFactory(
            conn=self.conn_config, region_name=self.region_name, config=self.config
        ).create_session()
[docs]    async def get_client_async(self):
        """Get the underlying aiobotocore client using aiobotocore session."""
        return self.get_async_session().create_client(
            self.client_type,
            region_name=self.region_name,
            verify=self.verify,
            endpoint_url=self.conn_config.endpoint_url,
            config=self.config,
        )
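
# --- Illustrative example (not part of the upstream module) -------------------------------
# A minimal async usage sketch, assuming a connection named "aws_default". ``async_conn``
# returns an aiobotocore client creator that is used as an async context manager, which is
# the pattern the provider's deferrable triggers rely on.
async def _example_async_client_usage():
    hook = AwsBaseHook(aws_conn_id="aws_default", client_type="sts")
    async with hook.async_conn as client:
        # GetCallerIdentity is a cheap call that validates the resolved credentials.
        return await client.get_caller_identity()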