# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import argparse
from functools import cached_property
from typing import TYPE_CHECKING, Sequence, cast
from flask import session, url_for
from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
from airflow.providers.amazon.aws.auth_manager.avp.facade import (
AwsAuthManagerAmazonVerifiedPermissionsFacade,
IsAuthorizedRequest,
)
from airflow.providers.amazon.aws.auth_manager.cli.definition import (
AWS_AUTH_MANAGER_COMMANDS,
)
from airflow.providers.amazon.aws.auth_manager.constants import (
CONF_ENABLE_KEY,
CONF_SECTION_NAME,
)
from airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override import (
AwsSecurityManagerOverride,
)
from airflow.security.permissions import (
RESOURCE_AUDIT_LOG,
RESOURCE_CLUSTER_ACTIVITY,
RESOURCE_CONFIG,
RESOURCE_CONNECTION,
RESOURCE_DAG,
RESOURCE_DAG_CODE,
RESOURCE_DAG_DEPENDENCIES,
RESOURCE_DAG_RUN,
RESOURCE_DATASET,
RESOURCE_DOCS,
RESOURCE_JOB,
RESOURCE_PLUGIN,
RESOURCE_POOL,
RESOURCE_PROVIDER,
RESOURCE_SLA_MISS,
RESOURCE_TASK_INSTANCE,
RESOURCE_TASK_RESCHEDULE,
RESOURCE_TRIGGER,
RESOURCE_VARIABLE,
RESOURCE_XCOM,
)
# Guarded import of the auth manager interface. If any of these imports fails, the
# ImportError is converted into AirflowOptionalProviderFeatureException.
try:
    from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
    from airflow.auth.managers.models.resource_details import (
        AccessView,
        ConnectionDetails,
        DagAccessEntity,
        DagDetails,
        PoolDetails,
        VariableDetails,
    )
except ImportError as import_error:
    # Chain explicitly so the traceback reports the original ImportError as the direct
    # cause instead of an unrelated error raised "during handling" of it.
    # NOTE(review): the message names BaseUser although this block does not import it;
    # the text is left unchanged in case callers match on it -- confirm before rewording.
    raise AirflowOptionalProviderFeatureException(
        "Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0"
    ) from import_error
if TYPE_CHECKING:
from flask_appbuilder.menu import MenuItem
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
IsAuthorizedDagRequest,
IsAuthorizedPoolRequest,
IsAuthorizedVariableRequest,
)
from airflow.auth.managers.models.resource_details import (
ConfigurationDetails,
DatasetDetails,
)
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
def _entity_menu_request(entity_type: AvpEntities) -> IsAuthorizedRequest:
    """Build a ``GET`` request on ``entity_type`` with no entity id and no context."""
    return {
        "method": "GET",
        "entity_type": entity_type,
    }


def _dag_menu_request(dag_entity: DagAccessEntity) -> IsAuthorizedRequest:
    """Build a ``GET`` request on the DAG entity type, scoped to ``dag_entity`` through the context."""
    return {
        "method": "GET",
        "entity_type": AvpEntities.DAG,
        "context": {
            "dag_entity": {
                "string": dag_entity.value,
            },
        },
    }


def _view_menu_request(access_view: AccessView) -> IsAuthorizedRequest:
    """Build a ``GET`` request on the VIEW entity type whose entity id is ``access_view``'s value."""
    return {
        "method": "GET",
        "entity_type": AvpEntities.VIEW,
        "entity_id": access_view.value,
    }


# Maps a resource name (the ``RESOURCE_*`` constants) to the authorization request
# describing read access to that resource. Looked up by ``_get_menu_item_request``.
_MENU_ITEM_REQUESTS: dict[str, IsAuthorizedRequest] = {
    RESOURCE_AUDIT_LOG: _dag_menu_request(DagAccessEntity.AUDIT_LOG),
    RESOURCE_CLUSTER_ACTIVITY: _view_menu_request(AccessView.CLUSTER_ACTIVITY),
    RESOURCE_CONFIG: _entity_menu_request(AvpEntities.CONFIGURATION),
    RESOURCE_CONNECTION: _entity_menu_request(AvpEntities.CONNECTION),
    RESOURCE_DAG: _entity_menu_request(AvpEntities.DAG),
    RESOURCE_DAG_CODE: _dag_menu_request(DagAccessEntity.CODE),
    RESOURCE_DAG_DEPENDENCIES: _dag_menu_request(DagAccessEntity.DEPENDENCIES),
    RESOURCE_DAG_RUN: _dag_menu_request(DagAccessEntity.RUN),
    RESOURCE_DATASET: _entity_menu_request(AvpEntities.DATASET),
    RESOURCE_DOCS: _view_menu_request(AccessView.DOCS),
    RESOURCE_PLUGIN: _view_menu_request(AccessView.PLUGINS),
    RESOURCE_JOB: _view_menu_request(AccessView.JOBS),
    RESOURCE_POOL: _entity_menu_request(AvpEntities.POOL),
    RESOURCE_PROVIDER: _view_menu_request(AccessView.PROVIDERS),
    RESOURCE_SLA_MISS: _dag_menu_request(DagAccessEntity.SLA_MISS),
    RESOURCE_TASK_INSTANCE: _dag_menu_request(DagAccessEntity.TASK_INSTANCE),
    RESOURCE_TASK_RESCHEDULE: _dag_menu_request(DagAccessEntity.TASK_RESCHEDULE),
    RESOURCE_TRIGGER: _view_menu_request(AccessView.TRIGGERS),
    RESOURCE_VARIABLE: _entity_menu_request(AvpEntities.VARIABLE),
    RESOURCE_XCOM: _dag_menu_request(DagAccessEntity.XCOM),
}
class AwsAuthManager(BaseAuthManager):
    """
    AWS auth manager.

    Leverages AWS services such as Amazon Identity Center and Amazon Verified Permissions to perform
    authentication and authorization in Airflow.

    :param appbuilder: the flask app builder
    """

    def __init__(self, appbuilder: AirflowAppBuilder) -> None:
        super().__init__(appbuilder)
        # The manager is opt-in: refuse to be instantiated unless the enable flag is set
        # in the Airflow configuration.
        if not conf.getboolean(CONF_SECTION_NAME, CONF_ENABLE_KEY):
            raise NotImplementedError(
                "The AWS auth manager is currently being built. It is not finalized. It is not intended to be used yet."
            )

    @cached_property
    def avp_facade(self):
        """Amazon Verified Permissions facade; created on first access and cached on the instance."""
        return AwsAuthManagerAmazonVerifiedPermissionsFacade()

    def get_user(self) -> AwsAuthManagerUser | None:
        """Return the object stored under ``aws_user`` in the session, or ``None`` if not logged in."""
        if not self.is_logged_in():
            return None
        return session["aws_user"]

    def is_logged_in(self) -> bool:
        """Return whether the session holds an ``aws_user`` entry."""
        return "aws_user" in session

    def is_authorized_configuration(
        self,
        *,
        method: ResourceMethod,
        details: ConfigurationDetails | None = None,
        user: BaseUser | None = None,
    ) -> bool:
        """
        Return the facade's authorization decision for ``method`` on the configuration.

        :param method: the method to perform
        :param details: optional details; when given, its ``section`` is used as the entity id
        :param user: the user to check; falls back to the logged-in user when not provided
        """
        section = None
        if details:
            section = details.section
        return self.avp_facade.is_authorized(
            method=method,
            entity_type=AvpEntities.CONFIGURATION,
            user=user or self.get_user(),
            entity_id=section,
        )

    def is_authorized_connection(
        self,
        *,
        method: ResourceMethod,
        details: ConnectionDetails | None = None,
        user: BaseUser | None = None,
    ) -> bool:
        """
        Return the facade's authorization decision for ``method`` on a connection.

        :param method: the method to perform
        :param details: optional details; when given, its ``conn_id`` is used as the entity id
        :param user: the user to check; falls back to the logged-in user when not provided
        """
        conn_id = None
        if details:
            conn_id = details.conn_id
        return self.avp_facade.is_authorized(
            method=method,
            entity_type=AvpEntities.CONNECTION,
            user=user or self.get_user(),
            entity_id=conn_id,
        )

    def is_authorized_dag(
        self,
        *,
        method: ResourceMethod,
        access_entity: DagAccessEntity | None = None,
        details: DagDetails | None = None,
        user: BaseUser | None = None,
    ) -> bool:
        """
        Return the facade's authorization decision for ``method`` on a DAG.

        :param method: the method to perform
        :param access_entity: optional DAG sub-entity; when given, its value is passed to the
            facade in the request context under ``dag_entity``
        :param details: optional details; when given, its ``id`` is used as the entity id
        :param user: the user to check; falls back to the logged-in user when not provided
        """
        entity_id = None
        if details:
            entity_id = details.id

        request_context = None
        if access_entity is not None:
            request_context = {
                "dag_entity": {
                    "string": access_entity.value,
                },
            }

        return self.avp_facade.is_authorized(
            method=method,
            entity_type=AvpEntities.DAG,
            user=user or self.get_user(),
            entity_id=entity_id,
            context=request_context,
        )

    def is_authorized_dataset(
        self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None
    ) -> bool:
        """
        Return the facade's authorization decision for ``method`` on a dataset.

        :param method: the method to perform
        :param details: optional details; when given, its ``uri`` is used as the entity id
        :param user: the user to check; falls back to the logged-in user when not provided
        """
        uri = None
        if details:
            uri = details.uri
        return self.avp_facade.is_authorized(
            method=method,
            entity_type=AvpEntities.DATASET,
            user=user or self.get_user(),
            entity_id=uri,
        )

    def is_authorized_pool(
        self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None
    ) -> bool:
        """
        Return the facade's authorization decision for ``method`` on a pool.

        :param method: the method to perform
        :param details: optional details; when given, its ``name`` is used as the entity id
        :param user: the user to check; falls back to the logged-in user when not provided
        """
        name = None
        if details:
            name = details.name
        return self.avp_facade.is_authorized(
            method=method,
            entity_type=AvpEntities.POOL,
            user=user or self.get_user(),
            entity_id=name,
        )

    def is_authorized_variable(
        self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None
    ) -> bool:
        """
        Return the facade's authorization decision for ``method`` on a variable.

        :param method: the method to perform
        :param details: optional details; when given, its ``key`` is used as the entity id
        :param user: the user to check; falls back to the logged-in user when not provided
        """
        key = None
        if details:
            key = details.key
        return self.avp_facade.is_authorized(
            method=method,
            entity_type=AvpEntities.VARIABLE,
            user=user or self.get_user(),
            entity_id=key,
        )

    def is_authorized_view(
        self,
        *,
        access_view: AccessView,
        user: BaseUser | None = None,
    ) -> bool:
        """
        Return the facade's authorization decision for a ``GET`` on the given view.

        :param access_view: the view to access; its value is used as the entity id
        :param user: the user to check; falls back to the logged-in user when not provided
        """
        return self.avp_facade.is_authorized(
            method="GET",
            entity_type=AvpEntities.VIEW,
            user=user or self.get_user(),
            entity_id=access_view.value,
        )

    def batch_is_authorized_connection(
        self,
        requests: Sequence[IsAuthorizedConnectionRequest],
    ) -> bool:
        """
        Batch version of ``is_authorized_connection``.

        Each request is converted to a facade request and the whole batch is checked for the
        logged-in user.

        :param requests: a list of requests containing the parameters for ``is_authorized_connection``
        """
        facade_requests: list[IsAuthorizedRequest] = []
        for request in requests:
            method = request["method"]
            details = request.get("details")
            facade_requests.append(
                {
                    "method": method,
                    "entity_type": AvpEntities.CONNECTION,
                    "entity_id": cast(ConnectionDetails, details).conn_id if details else None,
                }
            )
        return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

    def batch_is_authorized_dag(
        self,
        requests: Sequence[IsAuthorizedDagRequest],
    ) -> bool:
        """
        Batch version of ``is_authorized_dag``.

        Each request is converted to a facade request and the whole batch is checked for the
        logged-in user.

        :param requests: a list of requests containing the parameters for ``is_authorized_dag``
        """
        facade_requests: list[IsAuthorizedRequest] = []
        for request in requests:
            method = request["method"]
            details = request.get("details")
            entity_id = cast(DagDetails, details).id if details else None

            # The "context" key is always present; it is None when no access entity is given.
            access_entity = request.get("access_entity")
            request_context = None
            if access_entity:
                request_context = {
                    "dag_entity": {
                        "string": cast(DagAccessEntity, access_entity).value,
                    },
                }

            facade_requests.append(
                {
                    "method": method,
                    "entity_type": AvpEntities.DAG,
                    "entity_id": entity_id,
                    "context": request_context,
                }
            )
        return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

    def batch_is_authorized_pool(
        self,
        requests: Sequence[IsAuthorizedPoolRequest],
    ) -> bool:
        """
        Batch version of ``is_authorized_pool``.

        Each request is converted to a facade request and the whole batch is checked for the
        logged-in user.

        :param requests: a list of requests containing the parameters for ``is_authorized_pool``
        """
        facade_requests: list[IsAuthorizedRequest] = []
        for request in requests:
            method = request["method"]
            details = request.get("details")
            facade_requests.append(
                {
                    "method": method,
                    "entity_type": AvpEntities.POOL,
                    "entity_id": cast(PoolDetails, details).name if details else None,
                }
            )
        return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

    def batch_is_authorized_variable(
        self,
        requests: Sequence[IsAuthorizedVariableRequest],
    ) -> bool:
        """
        Batch version of ``is_authorized_variable``.

        Each request is converted to a facade request and the whole batch is checked for the
        logged-in user.

        :param requests: a list of requests containing the parameters for ``is_authorized_variable``
        """
        facade_requests: list[IsAuthorizedRequest] = []
        for request in requests:
            method = request["method"]
            details = request.get("details")
            facade_requests.append(
                {
                    "method": method,
                    "entity_type": AvpEntities.VARIABLE,
                    "entity_id": cast(VariableDetails, details).key if details else None,
                }
            )
        return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

    def get_url_login(self, **kwargs) -> str:
        """Return the URL of the login endpoint; ``kwargs`` are accepted but not used."""
        login_endpoint = "AwsAuthManagerAuthenticationViews.login"
        return url_for(login_endpoint)

    def get_url_logout(self) -> str:
        """Return the URL of the logout endpoint."""
        logout_endpoint = "AwsAuthManagerAuthenticationViews.logout"
        return url_for(logout_endpoint)

    @cached_property
    def security_manager(self) -> AwsSecurityManagerOverride:
        """Security manager override bound to the app builder; created once and cached."""
        return AwsSecurityManagerOverride(self.appbuilder)

    @staticmethod
    def get_cli_commands() -> list[CLICommand]:
        """Vends CLI commands to be included in Airflow CLI."""
        aws_auth_manager_group = GroupCommand(
            name="aws-auth-manager",
            help="Manage resources used by AWS auth manager",
            subcommands=AWS_AUTH_MANAGER_COMMANDS,
        )
        return [aws_auth_manager_group]

    @staticmethod
    def _get_menu_item_request(fab_resource_name: str) -> IsAuthorizedRequest:
        """
        Return the authorization request registered for a menu item's resource name.

        :param fab_resource_name: the resource name to look up in ``_MENU_ITEM_REQUESTS``
        :raises AirflowException: if no request is registered for the resource name
        """
        registered_request = _MENU_ITEM_REQUESTS.get(fab_resource_name)
        if not registered_request:
            raise AirflowException(f"Unknown resource name {fab_resource_name}")
        return registered_request

    def _has_access_to_menu_item(
        self, batch_is_authorized_results: list[dict], request: IsAuthorizedRequest, user: AwsAuthManagerUser
    ):
        """
        Return whether the batch result matching ``request`` has the decision ``ALLOW``.

        :param batch_is_authorized_results: the results of a batch authorization call
        :param request: the request whose individual result is looked up by the facade
        :param user: the user the batch was evaluated for
        """
        single_result = self.avp_facade.get_batch_is_authorized_single_result(
            batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
        )
        return single_result["decision"] == "ALLOW"
def get_parser() -> argparse.ArgumentParser:
    """
    Generate documentation; used by Sphinx argparse.

    Builds an ``airflow`` parser and registers every command returned by
    ``AwsAuthManager.get_cli_commands`` as a sub-command.
    """
    from airflow.cli.cli_parser import AirflowHelpFormatter, _add_command

    root_parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter)
    command_parsers = root_parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND")
    for cli_command in AwsAuthManager.get_cli_commands():
        _add_command(command_parsers, cli_command)
    return root_parser