#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This module contains Base AWS Hook.
.. seealso::
For more information on how to use this hook, take a look at the guide:
:ref:`howto/connection:AWSHook`
"""
import configparser
import datetime
import logging
from typing import Any, Dict, Optional, Tuple, Union

import boto3
import botocore
import botocore.session
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
from cached_property import cached_property
from dateutil.tz import tzlocal

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection
from airflow.utils.log.logging_mixin import LoggingMixin


class _SessionFactory(LoggingMixin):
def __init__(self, conn: Connection, region_name: Optional[str], config: Config) -> None:
super().__init__()
self.conn = conn
self.region_name = region_name
self.config = config
self.extra_config = self.conn.extra_dejson

    def create_session(self) -> boto3.session.Session:
"""Create AWS session."""
session_kwargs = {}
if "session_kwargs" in self.extra_config:
self.log.info(
"Retrieving session_kwargs from Connection.extra_config['session_kwargs']: %s",
self.extra_config["session_kwargs"],
)
session_kwargs = self.extra_config["session_kwargs"]
session = self._create_basic_session(session_kwargs=session_kwargs)
role_arn = self._read_role_arn_from_extra_config()
# If role_arn was specified then STS + assume_role
if role_arn is None:
return session
return self._impersonate_to_role(role_arn=role_arn, session=session, session_kwargs=session_kwargs)
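
    def _read_role_arn_from_extra_config(self) -> Optional[str]:
        # Minimal sketch of the helper referenced by create_session() above; it is
        # assumed to read "role_arn" from the connection extras, optionally building
        # it from "aws_account_id" + "aws_iam_role" for backwards compatibility.
        aws_account_id = self.extra_config.get("aws_account_id")
        aws_iam_role = self.extra_config.get("aws_iam_role")
        role_arn = self.extra_config.get("role_arn")
        if role_arn is None and aws_account_id is not None and aws_iam_role is not None:
            self.log.info("Constructing role_arn from aws_account_id and aws_iam_role")
            role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}"
        self.log.info("role_arn is %s", role_arn)
        return role_arn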

    def _create_basic_session(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
aws_access_key_id, aws_secret_access_key = self._read_credentials_from_connection()
aws_session_token = self.extra_config.get("aws_session_token")
region_name = self.region_name
if self.region_name is None and 'region_name' in self.extra_config:
self.log.info("Retrieving region_name from Connection.extra_config['region_name']")
region_name = self.extra_config["region_name"]
self.log.info(
"Creating session with aws_access_key_id=%s region_name=%s",
aws_access_key_id,
region_name,
)
return boto3.session.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=region_name,
aws_session_token=aws_session_token,
**session_kwargs,
)

    def _impersonate_to_role(
self, role_arn: str, session: boto3.session.Session, session_kwargs: Dict[str, Any]
) -> boto3.session.Session:
assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
assume_role_method = self.extra_config.get('assume_role_method')
self.log.info("assume_role_method=%s", assume_role_method)
if not assume_role_method or assume_role_method == 'assume_role':
sts_client = session.client("sts", config=self.config)
sts_response = self._assume_role(
sts_client=sts_client, role_arn=role_arn, assume_role_kwargs=assume_role_kwargs
)
elif assume_role_method == 'assume_role_with_saml':
sts_client = session.client("sts", config=self.config)
sts_response = self._assume_role_with_saml(
sts_client=sts_client, role_arn=role_arn, assume_role_kwargs=assume_role_kwargs
)
elif assume_role_method == 'assume_role_with_web_identity':
botocore_session = self._assume_role_with_web_identity(
role_arn=role_arn,
assume_role_kwargs=assume_role_kwargs,
base_session=session._session, # pylint: disable=protected-access
)
return boto3.session.Session(
region_name=session.region_name,
botocore_session=botocore_session,
**session_kwargs,
)
        else:
            raise NotImplementedError(
                f'assume_role_method={assume_role_method} in Connection {self.conn.conn_id} Extra. '
                'Currently "assume_role", "assume_role_with_saml" and '
                '"assume_role_with_web_identity" are supported. '
                '(Omitting this setting defaults to "assume_role".)'
            )
# Use credentials retrieved from STS
credentials = sts_response["Credentials"]
aws_access_key_id = credentials["AccessKeyId"]
aws_secret_access_key = credentials["SecretAccessKey"]
aws_session_token = credentials["SessionToken"]
self.log.info(
"Creating session with aws_access_key_id=%s region_name=%s",
aws_access_key_id,
session.region_name,
)
return boto3.session.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=session.region_name,
aws_session_token=aws_session_token,
**session_kwargs,
)

    def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]:
aws_access_key_id = None
aws_secret_access_key = None
if self.conn.login:
aws_access_key_id = self.conn.login
aws_secret_access_key = self.conn.password
self.log.info("Credentials retrieved from login")
elif "aws_access_key_id" in self.extra_config and "aws_secret_access_key" in self.extra_config:
aws_access_key_id = self.extra_config["aws_access_key_id"]
aws_secret_access_key = self.extra_config["aws_secret_access_key"]
self.log.info("Credentials retrieved from extra_config")
elif "s3_config_file" in self.extra_config:
aws_access_key_id, aws_secret_access_key = _parse_s3_config(
self.extra_config["s3_config_file"],
self.extra_config.get("s3_config_format"),
self.extra_config.get("profile"),
)
self.log.info("Credentials retrieved from extra_config['s3_config_file']")
else:
self.log.info("No credentials retrieved from Connection")
return aws_access_key_id, aws_secret_access_key

    def _assume_role(
self, sts_client: boto3.client, role_arn: str, assume_role_kwargs: Dict[str, Any]
) -> Dict:
if "external_id" in self.extra_config: # Backwards compatibility
assume_role_kwargs["ExternalId"] = self.extra_config.get("external_id")
role_session_name = f"Airflow_{self.conn.conn_id}"
self.log.info(
"Doing sts_client.assume_role to role_arn=%s (role_session_name=%s)",
role_arn,
role_session_name,
)
return sts_client.assume_role(
RoleArn=role_arn, RoleSessionName=role_session_name, **assume_role_kwargs
)

    def _assume_role_with_saml(
self, sts_client: boto3.client, role_arn: str, assume_role_kwargs: Dict[str, Any]
) -> Dict[str, Any]:
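        # The connection extra "assume_role_with_saml" is expected to be a dict; the
        # keys read below (and in the SPNEGO helper) are principal_arn,
        # idp_auth_method, idp_url, idp_request_kwargs, saml_response_xpath and,
        # optionally, mutual_authentication and log_idp_response.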
saml_config = self.extra_config['assume_role_with_saml']
principal_arn = saml_config['principal_arn']
idp_auth_method = saml_config['idp_auth_method']
if idp_auth_method == 'http_spegno_auth':
saml_assertion = self._fetch_saml_assertion_using_http_spegno_auth(saml_config)
else:
            raise NotImplementedError(
                f'idp_auth_method={idp_auth_method} in Connection {self.conn.conn_id} Extra. '
                'Currently only "http_spegno_auth" is supported, and must be specified.'
            )
self.log.info("Doing sts_client.assume_role_with_saml to role_arn=%s", role_arn)
return sts_client.assume_role_with_saml(
RoleArn=role_arn, PrincipalArn=principal_arn, SAMLAssertion=saml_assertion, **assume_role_kwargs
)

    def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]) -> str:
import requests
        # requests_gssapi will need paramiko > 2.6 since you'll need
        # 'gssapi', not 'python-gssapi', from PyPI.
        # https://github.com/paramiko/paramiko/pull/1311
import requests_gssapi
from lxml import etree
idp_url = saml_config["idp_url"]
self.log.info("idp_url= %s", idp_url)
idp_request_kwargs = saml_config["idp_request_kwargs"]
auth = requests_gssapi.HTTPSPNEGOAuth()
if 'mutual_authentication' in saml_config:
mutual_auth = saml_config['mutual_authentication']
if mutual_auth == 'REQUIRED':
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.REQUIRED)
elif mutual_auth == 'OPTIONAL':
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.OPTIONAL)
elif mutual_auth == 'DISABLED':
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.DISABLED)
            else:
                raise NotImplementedError(
                    f'mutual_authentication={mutual_auth} in Connection {self.conn.conn_id} Extra. '
                    'Currently "REQUIRED", "OPTIONAL" and "DISABLED" are supported. '
                    '(Omitting this setting defaults to HTTPSPNEGOAuth().)'
                )
# Query the IDP
idp_response = requests.get(idp_url, auth=auth, **idp_request_kwargs)
idp_response.raise_for_status()
# Assist with debugging. Note: contains sensitive info!
xpath = saml_config['saml_response_xpath']
log_idp_response = 'log_idp_response' in saml_config and saml_config['log_idp_response']
if log_idp_response:
self.log.warning(
'The IDP response contains sensitive information, but log_idp_response is ON (%s).',
log_idp_response,
)
self.log.info('idp_response.content= %s', idp_response.content)
self.log.info('xpath= %s', xpath)
# Extract SAML Assertion from the returned HTML / XML
xml = etree.fromstring(idp_response.content)
saml_assertion = xml.xpath(xpath)
if isinstance(saml_assertion, list):
if len(saml_assertion) == 1:
saml_assertion = saml_assertion[0]
if not saml_assertion:
raise ValueError('Invalid SAML Assertion')
return saml_assertion

    def _assume_role_with_web_identity(self, role_arn, assume_role_kwargs, base_session):
base_session = base_session or botocore.session.get_session()
client_creator = base_session.create_client
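        # Relevant connection extras: "assume_role_with_web_identity_federation"
        # (only "google" is handled below) and, for the Google path,
        # "assume_role_with_web_identity_federation_audience".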
federation = self.extra_config.get('assume_role_with_web_identity_federation')
if federation == 'google':
web_identity_token_loader = self._get_google_identity_token_loader()
else:
            raise AirflowException(
                f'Unsupported federation: {federation}. Currently only "google" is supported.'
            )
fetcher = botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher(
client_creator=client_creator,
web_identity_token_loader=web_identity_token_loader,
role_arn=role_arn,
extra_args=assume_role_kwargs or {},
)
aws_creds = botocore.credentials.DeferredRefreshableCredentials(
method='assume-role-with-web-identity',
refresh_using=fetcher.fetch_credentials,
time_fetcher=lambda: datetime.datetime.now(tz=tzlocal()),
)
botocore_session = botocore.session.Session()
botocore_session._credentials = aws_creds # pylint: disable=protected-access
return botocore_session

    def _get_google_identity_token_loader(self):
from google.auth.transport import requests as requests_transport
from airflow.providers.google.common.utils.id_token_credentials import (
get_default_id_token_credentials,
)
audience = self.extra_config.get('assume_role_with_web_identity_federation_audience')
google_id_token_credentials = get_default_id_token_credentials(target_audience=audience)
def web_identity_token_loader():
if not google_id_token_credentials.valid:
request_adapter = requests_transport.Request()
google_id_token_credentials.refresh(request=request_adapter)
return google_id_token_credentials.token
return web_identity_token_loader


class AwsBaseHook(BaseHook):
"""
Interact with AWS.
This class is a thin wrapper around the boto3 python library.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:type aws_conn_id: str
:param verify: Whether or not to verify SSL certificates.
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:type verify: Union[bool, str, None]
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:type region_name: Optional[str]
:param client_type: boto3.client client_type. Eg 's3', 'emr' etc
:type client_type: Optional[str]
:param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc
:type resource_type: Optional[str]
:param config: Configuration for botocore client.
(https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html)
:type config: Optional[botocore.client.Config]
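
    A minimal usage sketch (the import path below is assumed, and ``"s3"`` is only an
    illustrative client_type):

    .. code-block:: python

        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

        # Hook backed by the "aws_default" connection, returning a boto3 S3 client
        hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")
        s3_client = hook.get_conn()  # same cached object as hook.conn
        bucket_names = [b["Name"] for b in s3_client.list_buckets()["Buckets"]]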
"""

    conn_name_attr = 'aws_conn_id'
    default_conn_name = 'aws_default'
    hook_name = 'Amazon Web Services'

def __init__(
self,
aws_conn_id: Optional[str] = default_conn_name,
verify: Union[bool, str, None] = None,
region_name: Optional[str] = None,
client_type: Optional[str] = None,
resource_type: Optional[str] = None,
config: Optional[Config] = None,
) -> None:
super().__init__()
self.aws_conn_id = aws_conn_id
self.verify = verify
self.client_type = client_type
self.resource_type = resource_type
self.region_name = region_name
self.config = config
if not (self.client_type or self.resource_type):
raise AirflowException('Either client_type or resource_type must be provided.')

    def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
if not self.aws_conn_id:
session = boto3.session.Session(region_name=region_name)
return session, None
self.log.info("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
try:
# Fetch the Airflow connection object
connection_object = self.get_connection(self.aws_conn_id)
extra_config = connection_object.extra_dejson
endpoint_url = extra_config.get("host")
# https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html#botocore.config.Config
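            # "config_kwargs" in the extras are unpacked straight into botocore's
            # Config, e.g. {"retries": {"max_attempts": 10}} (illustrative value).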
if "config_kwargs" in extra_config:
self.log.info(
"Retrieving config_kwargs from Connection.extra_config['config_kwargs']: %s",
extra_config["config_kwargs"],
)
self.config = Config(**extra_config["config_kwargs"])
session = _SessionFactory(
conn=connection_object, region_name=region_name, config=self.config
).create_session()
return session, endpoint_url
except AirflowException:
self.log.warning("Unable to use Airflow Connection for credentials.")
self.log.info("Fallback on boto3 credential strategy")
# http://boto3.readthedocs.io/en/latest/guide/configuration.html
self.log.info(
"Creating session using boto3 credential strategy region_name=%s",
region_name,
)
session = boto3.session.Session(region_name=region_name)
return session, None

    def get_client_type(
self,
client_type: str,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)
# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
config = self.config
return session.client(client_type, endpoint_url=endpoint_url, config=config, verify=self.verify)

    def get_resource_type(
self,
resource_type: str,
region_name: Optional[str] = None,
config: Optional[Config] = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)
# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
config = self.config
return session.resource(resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify)

    @cached_property
    def conn(self) -> Union[boto3.client, boto3.resource]:
"""
Get the underlying boto3 client/resource (cached)
:return: boto3.client or boto3.resource
:rtype: Union[boto3.client, boto3.resource]
"""
if self.client_type:
return self.get_client_type(self.client_type, region_name=self.region_name)
elif self.resource_type:
return self.get_resource_type(self.resource_type, region_name=self.region_name)
else:
# Rare possibility - subclasses have not specified a client_type or resource_type
raise NotImplementedError('Could not get boto3 connection!')

    def get_conn(self) -> Union[boto3.client, boto3.resource]:
"""
Get the underlying boto3 client/resource (cached)
Implemented so that caching works as intended. It exists for compatibility
with subclasses that rely on a super().get_conn() method.
:return: boto3.client or boto3.resource
:rtype: Union[boto3.client, boto3.resource]
"""
# Compat shim
return self.conn

    def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session:
"""Get the underlying boto3.session."""
session, _ = self._get_credentials(region_name)
return session

    def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials:
"""
Get the underlying `botocore.Credentials` object.
This contains the following authentication attributes: access_key, secret_key and token.
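
        A quick illustration (``hook`` is assumed to be an instance of this class)::

            creds = hook.get_credentials()
            access_key, secret_key, token = creds.access_key, creds.secret_key, creds.token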
"""
session, _ = self._get_credentials(region_name)
# Credentials are refreshable, so accessing your access key and
# secret key separately can lead to a race condition.
# See https://stackoverflow.com/a/36291428/8283373
return session.get_credentials().get_frozen_credentials()

    def expand_role(self, role: str) -> str:
"""
If the IAM role is a role name, get the Amazon Resource Name (ARN) for the role.
If IAM role is already an IAM role ARN, no change is made.
:param role: IAM role name or ARN
:return: IAM role ARN
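
        For example, assuming ``hook`` is an instance of this class (the account id
        below is purely illustrative)::

            hook.expand_role("my-role")
            # -> "arn:aws:iam::123456789012:role/my-role"
            hook.expand_role("arn:aws:iam::123456789012:role/my-role")
            # -> returned unchanged, since it already looks like an ARN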
"""
if "/" in role:
return role
else:
return self.get_client_type("iam").get_role(RoleName=role)["Role"]["Arn"]


def _parse_s3_config(
config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
) -> Tuple[Optional[str], Optional[str]]:
"""
Parses a config file for s3 credentials. Can currently
parse boto, s3cmd.conf and AWS SDK config formats
:param config_file_name: path to the config file
:type config_file_name: str
:param config_format: config type. One of "boto", "s3cmd" or "aws".
Defaults to "boto"
:type config_format: str
:param profile: profile name in AWS type config file
:type profile: str
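
    For a "boto"-format file with no profile, credentials are read from a
    ``[Credentials]`` section, e.g. this purely illustrative file::

        [Credentials]
        aws_access_key_id = AKIAEXAMPLEKEY
        aws_secret_access_key = example-secret-key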
"""
config = configparser.ConfigParser()
if config.read(config_file_name): # pragma: no cover
sections = config.sections()
else:
raise AirflowException(f"Couldn't read {config_file_name}")
# Setting option names depending on file format
if config_format is None:
config_format = "boto"
conf_format = config_format.lower()
if conf_format == "boto": # pragma: no cover
if profile is not None and "profile " + profile in sections:
cred_section = "profile " + profile
else:
cred_section = "Credentials"
elif conf_format == "aws" and profile is not None:
cred_section = profile
else:
cred_section = "default"
# Option names
if conf_format in ("boto", "aws"): # pragma: no cover
key_id_option = "aws_access_key_id"
secret_key_option = "aws_secret_access_key"
# security_token_option = 'aws_security_token'
else:
key_id_option = "access_key"
secret_key_option = "secret_key"
# Actual Parsing
if cred_section not in sections:
raise AirflowException("This config file format is not recognized")
else:
try:
access_key = config.get(cred_section, key_id_option)
secret_key = config.get(cred_section, secret_key_option)
except Exception:
logging.warning("Option Error in parsing s3 config file")
raise
return access_key, secret_key