Source code for airflow.providers.keycloak.auth_manager.keycloak_auth_manager

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import base64
import json
import logging
import time
import warnings
from base64 import urlsafe_b64decode
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin

import requests
from fastapi import FastAPI
from keycloak import KeycloakOpenID
from keycloak.exceptions import KeycloakPostError
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager
from airflow.exceptions import AirflowConfigException, AirflowProviderDeprecationWarning

try:
    from airflow.api_fastapi.auth.managers.base_auth_manager import ExtendedResourceMethod
except ImportError:
    from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod as ExtendedResourceMethod

from airflow.api_fastapi.common.types import MenuItem
from airflow.cli.cli_config import CLICommand
from airflow.providers.common.compat.sdk import AirflowException, conf
from airflow.providers.keycloak.auth_manager.cache import single_flight
from airflow.providers.keycloak.auth_manager.constants import (
    CONF_CLIENT_ID_KEY,
    CONF_CLIENT_SECRET_KEY,
    CONF_REALM_KEY,
    CONF_REQUESTS_POOL_SIZE_KEY,
    CONF_REQUESTS_RETRIES_KEY,
    CONF_SECTION_NAME,
    CONF_SERVER_URL_KEY,
)
from airflow.providers.keycloak.auth_manager.resources import KeycloakResource
from airflow.providers.keycloak.auth_manager.user import KeycloakAuthManagerUser
from airflow.utils.helpers import prune_dict

if TYPE_CHECKING:
    from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
    from airflow.api_fastapi.auth.managers.models.resource_details import (
        AccessView,
        AssetAliasDetails,
        AssetDetails,
        BackfillDetails,
        ConfigurationDetails,
        ConnectionDetails,
        DagAccessEntity,
        DagDetails,
        PoolDetails,
        TeamDetails,
        VariableDetails,
    )
    from airflow.cli.cli_config import CLICommand

log = logging.getLogger(__name__)

# Key of the pushed claim that carries the identifier of the single resource being checked
# (for example a DAG id or a connection id); only added when a resource id is given.
RESOURCE_ID_ATTRIBUTE_NAME = "resource_id"

# Resource types whose Keycloak resource name receives a ``:<team_name>`` suffix when the
# ``[core] multi_team`` option is enabled and a team name is supplied.
TEAM_SCOPED_RESOURCES = frozenset(
    (
        KeycloakResource.CONNECTION,
        KeycloakResource.DAG,
        KeycloakResource.POOL,
        KeycloakResource.TEAM,
        KeycloakResource.VARIABLE,
    )
)
class KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
    """
    Keycloak auth manager.

    Leverages Keycloak to perform authentication and authorization in Airflow. Authorization
    checks are sent to the realm's token endpoint using the UMA ticket grant
    (``urn:ietf:params:oauth:grant-type:uma-ticket``) with the user's access token.
    """

    def __init__(self):
        super().__init__()
        # Built on first use by the ``http_session`` property.
        self._http_session: requests.Session | None = None

    @property
    def http_session(self) -> requests.Session:
        """
        Lazy-initialize and return the requests session with connection pooling.

        The pool size (``CONF_REQUESTS_POOL_SIZE_KEY``, default 10) and the retry count
        (``CONF_REQUESTS_RETRIES_KEY``, default 3) are read from the auth manager configuration
        section. Responses with status 500, 502, 503 or 504 are retried with a 0.1 backoff factor.
        """
        if self._http_session is not None:
            return self._http_session

        pool_size = conf.getint(CONF_SECTION_NAME, CONF_REQUESTS_POOL_SIZE_KEY, fallback=10)
        retry_total = conf.getint(CONF_SECTION_NAME, CONF_REQUESTS_RETRIES_KEY, fallback=3)
        retry_strategy = Retry(
            total=retry_total,
            backoff_factor=0.1,
            status_forcelist=[500, 502, 503, 504],
            # POST is listed explicitly because the authorization checks below are POST requests
            # to the token endpoint.
            allowed_methods=["HEAD", "GET", "OPTIONS", "POST"],
        )
        adapter = HTTPAdapter(pool_connections=pool_size, pool_maxsize=pool_size, max_retries=retry_strategy)
        session = requests.Session()
        session.mount("https://", adapter)
        session.mount("http://", adapter)
        # Cache the session only once it is fully configured: if reading the configuration above
        # raises, the next call tries again instead of reusing a session without these adapters.
        self._http_session = session
        return session

    def deserialize_user(self, token: dict[str, Any]) -> KeycloakAuthManagerUser:
        """
        Build a user from a dictionary holding the keys written by ``serialize_user``.

        The user fields are popped from ``token``, so the dictionary is modified in place.
        """
        return KeycloakAuthManagerUser(
            user_id=token.pop("user_id"),
            name=token.pop("name"),
            access_token=token.pop("access_token"),
            refresh_token=token.pop("refresh_token"),
        )

    def serialize_user(self, user: KeycloakAuthManagerUser) -> dict[str, Any]:
        """Return the user id, name and Keycloak tokens of ``user`` as a plain dictionary."""
        return {
            "user_id": user.get_id(),
            "name": user.get_name(),
            "access_token": user.access_token,
            "refresh_token": user.refresh_token,
        }

    def get_cli_user(self) -> KeycloakAuthManagerUser:
        """
        Return a service-account user for the local CLI to mint a token for.

        Keycloak tokens are issued by the external Keycloak server, so they cannot be forged
        locally. The Keycloak client is already configured for Airflow to talk to Keycloak, so we
        reuse it to obtain a service-account token through the ``client_credentials`` flow. The
        service account's effective permissions are governed by the Keycloak deployment.

        If the client credentials are not usable, the operator must provide a token via the
        ``AIRFLOW_CLI_TOKEN`` environment variable.

        :raises AirflowConfigException: if the ``client_credentials`` token request fails.
        """
        try:
            tokens = self.get_keycloak_client().token(grant_type="client_credentials")
        except Exception as e:
            raise AirflowConfigException(
                "Could not obtain a Keycloak service-account token for the CLI via the "
                "client_credentials flow. Set the AIRFLOW_CLI_TOKEN environment variable "
                f"with a valid API token instead. Original error: {e}"
            ) from e
        return KeycloakAuthManagerUser(
            user_id="airflow-cli",
            name="airflow-cli",
            access_token=tokens["access_token"],
            # No refresh token is issued for the client_credentials flow (RFC 6749 §4.4.3),
            # which marks this as a service account in refresh_user/refresh_tokens.
            refresh_token=tokens.get("refresh_token"),
        )

    def get_url_login(self, **kwargs) -> str:
        """Return the login URL served by this auth manager's FastAPI sub-application."""
        base_url = conf.get("api", "base_url", fallback="/")
        return urljoin(base_url, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")

    def get_url_logout(self) -> str | None:
        """Return the logout URL served by this auth manager's FastAPI sub-application."""
        base_url = conf.get("api", "base_url", fallback="/")
        return urljoin(base_url, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/logout")

    def refresh_user(self, *, user: KeycloakAuthManagerUser) -> KeycloakAuthManagerUser | None:
        """
        Refresh the user's tokens when the access token has expired.

        :return: ``user`` with updated tokens if a refresh happened, otherwise ``None`` (service
            account, access token still valid, or no tokens returned by the refresh).
        """
        # According to RFC6749 section 4.4.3, a refresh token should not be included when using
        # the Service accounts/client_credentials flow.
        # We check whether the user has a refresh token; if not, we assume it's a service account
        # and return None.
        if not user.refresh_token:
            return None

        if self._token_expired(user.access_token):
            tokens = self.refresh_tokens(user=user)
            if tokens:
                user.refresh_token = tokens["refresh_token"]
                user.access_token = tokens["access_token"]
                return user

        return None

    def refresh_tokens(self, *, user: KeycloakAuthManagerUser) -> dict[str, str]:
        """
        Exchange the user's refresh token for a new token set.

        :return: the token response from Keycloak; an empty dict for service accounts (no refresh
            token), or when the refresh fails on an Airflow version that does not provide
            ``AuthManagerRefreshTokenExpiredException``.
        :raises AuthManagerRefreshTokenExpiredException: when the refresh request fails with
            ``KeycloakPostError`` and the exception class is available.
        """
        if not user.refresh_token:
            # It is a service account. It used the client credentials flow and no refresh token is issued.
            return {}
        try:
            log.debug("Refreshing the token")
            client = self.get_keycloak_client()
            return client.refresh_token(user.refresh_token)
        except KeycloakPostError as exc:
            try:
                from airflow.api_fastapi.auth.managers.exceptions import (
                    AuthManagerRefreshTokenExpiredException,
                )
            except ImportError:
                # The exception class does not exist in this Airflow version; report "no tokens".
                return {}
            raise AuthManagerRefreshTokenExpiredException(exc) from exc

    def is_authorized_configuration(
        self,
        *,
        method: ResourceMethod,
        user: KeycloakAuthManagerUser,
        details: ConfigurationDetails | None = None,
    ) -> bool:
        """Check access to the configuration, optionally scoped to ``details.section``."""
        config_section = details.section if details else None
        return self._is_authorized(
            method=method, resource_type=KeycloakResource.CONFIGURATION, user=user, resource_id=config_section
        )

    def is_authorized_connection(
        self,
        *,
        method: ResourceMethod,
        user: KeycloakAuthManagerUser,
        details: ConnectionDetails | None = None,
    ) -> bool:
        """Check access to connections, optionally scoped to ``details.conn_id`` and its team."""
        connection_id = details.conn_id if details else None
        team_name = self._get_team_name(details)
        return self._is_authorized(
            method=method,
            resource_type=KeycloakResource.CONNECTION,
            user=user,
            resource_id=connection_id,
            team_name=team_name,
        )

    def is_authorized_dag(
        self,
        *,
        method: ResourceMethod,
        user: KeycloakAuthManagerUser,
        access_entity: DagAccessEntity | None = None,
        details: DagDetails | None = None,
    ) -> bool:
        """
        Check access to Dags, optionally scoped to ``details.id`` and its team.

        The requested ``access_entity`` (if any) is pushed to Keycloak as the ``dag_entity`` claim.
        """
        dag_id = details.id if details else None
        team_name = self._get_team_name(details)
        access_entity_str = access_entity.value if access_entity else None
        return self._is_authorized(
            method=method,
            resource_type=KeycloakResource.DAG,
            user=user,
            resource_id=dag_id,
            team_name=team_name,
            attributes={"dag_entity": access_entity_str},
        )

    def is_authorized_backfill(
        self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: BackfillDetails | None = None
    ) -> bool:
        """Check access to backfills (deprecated: use ``is_authorized_dag`` with ``DagAccessEntity.RUN``)."""
        # Method can be removed once the min Airflow version is >= 3.2.0.
        warnings.warn(
            "Use ``is_authorized_dag`` on ``DagAccessEntity.RUN`` instead for a dag level access control.",
            AirflowProviderDeprecationWarning,
            stacklevel=2,
        )
        backfill_id = str(details.id) if details else None
        return self._is_authorized(
            method=method, resource_type=KeycloakResource.BACKFILL, user=user, resource_id=backfill_id
        )

    def is_authorized_asset(
        self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: AssetDetails | None = None
    ) -> bool:
        """Check access to assets, optionally scoped to ``details.id``."""
        asset_id = details.id if details else None
        return self._is_authorized(
            method=method, resource_type=KeycloakResource.ASSET, user=user, resource_id=asset_id
        )

    def is_authorized_asset_alias(
        self,
        *,
        method: ResourceMethod,
        user: KeycloakAuthManagerUser,
        details: AssetAliasDetails | None = None,
    ) -> bool:
        """Check access to asset aliases, optionally scoped to ``details.id``."""
        asset_alias_id = details.id if details else None
        return self._is_authorized(
            method=method,
            resource_type=KeycloakResource.ASSET_ALIAS,
            user=user,
            resource_id=asset_alias_id,
        )

    def is_authorized_variable(
        self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: VariableDetails | None = None
    ) -> bool:
        """Check access to variables, optionally scoped to ``details.key`` and its team."""
        variable_key = details.key if details else None
        team_name = self._get_team_name(details)
        return self._is_authorized(
            method=method,
            resource_type=KeycloakResource.VARIABLE,
            user=user,
            resource_id=variable_key,
            team_name=team_name,
        )

    def is_authorized_pool(
        self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: PoolDetails | None = None
    ) -> bool:
        """Check access to pools, optionally scoped to ``details.name`` and its team."""
        pool_name = details.name if details else None
        team_name = self._get_team_name(details)
        return self._is_authorized(
            method=method,
            resource_type=KeycloakResource.POOL,
            user=user,
            resource_id=pool_name,
            team_name=team_name,
        )

    def is_authorized_team(
        self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: TeamDetails | None = None
    ) -> bool:
        """Check access to teams; ``details.name`` is used as the team scope, not as a resource id."""
        team_name = details.name if details else None
        return self._is_authorized(
            method=method,
            resource_type=KeycloakResource.TEAM,
            user=user,
            team_name=team_name,
        )

    def is_authorized_view(self, *, access_view: AccessView, user: KeycloakAuthManagerUser) -> bool:
        """Check read access to the given view."""
        return self._is_authorized(
            method="GET",
            resource_type=KeycloakResource.VIEW,
            user=user,
            resource_id=access_view.value,
        )

    def is_authorized_custom_view(
        self, *, method: ResourceMethod | str, resource_name: str, user: KeycloakAuthManagerUser
    ) -> bool:
        """Check access to a custom view identified by ``resource_name``."""
        return self._is_authorized(
            method=method, resource_type=KeycloakResource.CUSTOM, user=user, resource_id=resource_name
        )

    def filter_authorized_menu_items(
        self, menu_items: list[MenuItem], *, user: KeycloakAuthManagerUser
    ) -> list[MenuItem]:
        """Return the menu items the user may see, evaluated in a single batched Keycloak request."""
        authorized_menus = self._is_batch_authorized(
            permissions=[("MENU", menu_item.value) for menu_item in menu_items],
            user=user,
        )
        return [MenuItem(menu[1]) for menu in authorized_menus]

    def get_fastapi_app(self) -> FastAPI | None:
        """Return the FastAPI sub-application exposing the login and token routes."""
        from airflow.providers.keycloak.auth_manager.routes.login import login_router
        from airflow.providers.keycloak.auth_manager.routes.token import token_router

        app = FastAPI(
            title="Keycloak auth manager sub application",
            description=(
                "This is the Keycloak auth manager fastapi sub application. This API is only available if the "
                "auth manager used in the Airflow environment is Keycloak auth manager. "
                "This sub application provides login routes."
            ),
        )
        app.include_router(login_router)
        app.include_router(token_router)

        return app

    @staticmethod
    def get_cli_commands() -> list[CLICommand]:
        """Vends CLI commands to be included in Airflow CLI."""
        from airflow.providers.keycloak.cli.definition import get_keycloak_cli_commands

        return get_keycloak_cli_commands()

    @staticmethod
    def get_keycloak_client(client_id: str | None = None, client_secret: str | None = None) -> KeycloakOpenID:
        """
        Get a KeycloakOpenID client instance.

        :param client_id: Optional client ID to override config. If provided, client_secret must also be provided.
        :param client_secret: Optional client secret to override config. If provided, client_id must also be provided.
        :raises ValueError: if only one of ``client_id`` / ``client_secret`` is given.
        """
        if (client_id is None) != (client_secret is None):
            raise ValueError(
                "Both `client_id` and `client_secret` must be provided together, or both must be None"
            )

        if client_id is None:
            client_id = conf.get(CONF_SECTION_NAME, CONF_CLIENT_ID_KEY)
            client_secret = conf.get(CONF_SECTION_NAME, CONF_CLIENT_SECRET_KEY)

        realm = conf.get(CONF_SECTION_NAME, CONF_REALM_KEY)
        server_url = conf.get(CONF_SECTION_NAME, CONF_SERVER_URL_KEY)

        return KeycloakOpenID(
            server_url=server_url,
            client_id=client_id,
            client_secret_key=client_secret,
            realm_name=realm,
        )

    def _is_authorized(
        self,
        *,
        method: ResourceMethod | str,
        resource_type: KeycloakResource,
        user: KeycloakAuthManagerUser,
        resource_id: str | None = None,
        team_name: str | None = None,
        attributes: dict[str, str | None] | None = None,
    ) -> bool:
        """
        Ask Keycloak whether ``user`` holds the ``<resource_name>#<method>`` permission.

        :param method: the requested scope; ``GET`` without a ``resource_id`` is sent as ``LIST``.
        :param resource_type: the Keycloak resource checked.
        :param user: the user whose access token authenticates the request.
        :param resource_id: pushed as the ``resource_id`` claim when given.
        :param team_name: appended to the resource name as ``:<team_name>`` when ``[core] multi_team``
            is enabled and ``resource_type`` is team scoped.
        :param attributes: extra claims to push; entries whose value is ``None`` are pruned.
        :return: True if Keycloak grants the permission (200), False if it denies it (401/403).
        :raises AirflowException: on any other Keycloak response.
        """
        client_id = conf.get(CONF_SECTION_NAME, CONF_CLIENT_ID_KEY)
        realm = conf.get(CONF_SECTION_NAME, CONF_REALM_KEY)
        server_url = conf.get(CONF_SECTION_NAME, CONF_SERVER_URL_KEY)

        context_attributes = prune_dict(attributes or {})
        if resource_id:
            context_attributes[RESOURCE_ID_ATTRIBUTE_NAME] = resource_id
        elif method == "GET":
            # A read without a specific resource is a listing request.
            method = "LIST"

        if (
            team_name
            and conf.getboolean("core", "multi_team", fallback=False)
            and resource_type in TEAM_SCOPED_RESOURCES
        ):
            resource_name = f"{resource_type.value}:{team_name}"
        else:
            resource_name = resource_type.value
        permission = f"{resource_name}#{method}"

        resp = self.http_session.post(
            self._get_token_url(server_url, realm),
            data=self._get_payload(client_id, permission, context_attributes),
            headers=self._get_headers(user.access_token),
            timeout=5,
        )
        return self._check_authorization_response(resp)

    def filter_authorized_dag_ids(
        self,
        *,
        dag_ids: set[str],
        user: KeycloakAuthManagerUser,
        method: ResourceMethod = "GET",
        team_name: str | None = None,
    ) -> set[str]:
        """
        Return the subset of ``dag_ids`` the user is authorized for.

        The base-class implementation is run through ``single_flight`` keyed on the user, method,
        team and the exact set of Dag ids (presumably to coalesce identical concurrent queries —
        see ``single_flight``).
        """
        cache_key = (user.get_id(), method, team_name, frozenset(dag_ids))

        def query_keycloak() -> set[str]:
            kwargs: dict = dict(dag_ids=dag_ids, user=user, method=method)
            # Only forward ``team_name`` when set, so base classes without the argument still work.
            if team_name is not None:
                kwargs["team_name"] = team_name
            return super(KeycloakAuthManager, self).filter_authorized_dag_ids(**kwargs)

        return single_flight(cache_key, query_keycloak)

    def _is_batch_authorized(
        self,
        *,
        permissions: list[tuple[ExtendedResourceMethod, str]],
        user: KeycloakAuthManagerUser,
    ) -> set[tuple[ExtendedResourceMethod, str]]:
        """
        Evaluate several ``(method, resource_name)`` permissions in a single Keycloak request.

        :return: the granted permissions as ``(method, resource_name)`` tuples; an empty set when
            Keycloak answers 401 or 403.
        :raises AirflowException: on any other non-200 Keycloak response.
        """
        client_id = conf.get(CONF_SECTION_NAME, CONF_CLIENT_ID_KEY)
        realm = conf.get(CONF_SECTION_NAME, CONF_REALM_KEY)
        server_url = conf.get(CONF_SECTION_NAME, CONF_SERVER_URL_KEY)

        resp = self.http_session.post(
            self._get_token_url(server_url, realm),
            data=self._get_batch_payload(client_id, permissions),
            headers=self._get_headers(user.access_token),
            timeout=5,
        )
        if not self._check_authorization_response(resp):
            return set()
        # ``response_mode=permissions`` makes Keycloak answer with the list of granted permissions.
        return {(perm["scopes"][0], perm["rsname"]) for perm in resp.json()}

    @staticmethod
    def _check_authorization_response(resp: requests.Response) -> bool:
        """
        Interpret the token endpoint's answer to an authorization request.

        :return: True for 200 (granted), False for 401 or 403 (denied).
        :raises AirflowException: for 400 (with Keycloak's error details when the body is a JSON
            object, the raw body otherwise) and for any other status code.
        """
        if resp.status_code == 200:
            return True
        if resp.status_code == 401:
            log.debug("Received 401 from Keycloak: %s", resp.text)
            return False
        if resp.status_code == 403:
            return False
        if resp.status_code == 400:
            try:
                error = json.loads(resp.text)
            except ValueError:
                error = None
            if not isinstance(error, dict):
                # The body is not Keycloak's JSON error object (e.g. an HTML page from a proxy);
                # report it verbatim instead of failing while parsing it.
                raise AirflowException(f"Request not recognized by Keycloak. {resp.text}")
            raise AirflowException(
                f"Request not recognized by Keycloak. {error.get('error')}. {error.get('error_description')}"
            )
        raise AirflowException(f"Unexpected error: {resp.status_code} - {resp.text}")

    def _get_teams(self) -> set[str]:
        """
        Return the team names known to Keycloak.

        Team resources are named ``<team resource>:<team_name>``; they are listed through the
        protection API ``resource_set`` endpoint, authenticated with a ``client_credentials`` token.

        :raises requests.HTTPError: if the resource listing request fails.
        """
        realm = conf.get(CONF_SECTION_NAME, CONF_REALM_KEY)
        server_url = conf.get(CONF_SECTION_NAME, CONF_SERVER_URL_KEY)
        pat = self.get_keycloak_client().token(grant_type="client_credentials")["access_token"]
        prefix = f"{KeycloakResource.TEAM.value}:"
        resource_url = f"{server_url.rstrip('/')}/realms/{realm}/authz/protection/resource_set"
        resources_resp = self.http_session.get(
            resource_url,
            params={"name": prefix, "matchingUri": "false", "max": "-1", "deep": "true"},
            headers={"Authorization": f"Bearer {pat}"},
            timeout=5,
        )
        resources_resp.raise_for_status()
        # The name filter is re-checked locally so only names starting with the prefix are kept.
        return {r["name"][len(prefix) :] for r in resources_resp.json() if r["name"].startswith(prefix)}

    @staticmethod
    def _get_token_url(server_url, realm):
        """Return the OpenID Connect token endpoint URL of ``realm``."""
        # Normalize server_url to avoid double slashes (required for Keycloak 26.4+ strict path validation).
        return f"{server_url.rstrip('/')}/realms/{realm}/protocol/openid-connect/token"

    @staticmethod
    def _get_team_name(
        details: ConnectionDetails | DagDetails | PoolDetails | VariableDetails | None,
    ) -> str | None:
        """Return ``details.team_name`` if present, else None."""
        return getattr(details, "team_name", None) if details else None

    @staticmethod
    def _get_payload(client_id: str, permission: str, attributes: dict[str, str] | None = None):
        """Build the form payload for a single UMA permission request, pushing ``attributes`` as claims."""
        payload: dict[str, Any] = {
            "grant_type": "urn:ietf:params:oauth:grant-type:uma-ticket",
            "audience": client_id,
            "permission": permission,
        }
        if attributes:
            # Per UMA spec, push claims using claim_token parameter with base64-encoded JSON
            # Values must be arrays of strings per Keycloak documentation
            # See: https://www.keycloak.org/docs/latest/authorization_services/index.html#_service_pushing_claims
            claims = {key: [value] for key, value in attributes.items()}
            claim_json = json.dumps(claims, sort_keys=True)
            claim_token = base64.b64encode(claim_json.encode()).decode()
            payload["claim_token"] = claim_token
            payload["claim_token_format"] = "urn:ietf:params:oauth:token-type:jwt"

        return payload

    @staticmethod
    def _get_batch_payload(client_id: str, permissions: list[tuple[ExtendedResourceMethod, str]]):
        """Build the form payload for a batched UMA request; each permission becomes ``resource#method``."""
        payload: dict[str, Any] = {
            "grant_type": "urn:ietf:params:oauth:grant-type:uma-ticket",
            "audience": client_id,
            "permission": [f"{permission[1]}#{permission[0]}" for permission in permissions],
            "response_mode": "permissions",
        }

        return payload

    @staticmethod
    def _get_headers(access_token):
        """Return the bearer-auth, form-encoded request headers for ``access_token``."""
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/x-www-form-urlencoded",
        }

    @staticmethod
    def _token_expired(token: str) -> bool:
        """
        Check whether a JWT token is expired.

        The signature is not verified; only the ``exp`` claim of the payload is read.

        :meta private:

        :param token: the token
        """
        payload_b64 = token.split(".")[1]
        # The payload segment may lack base64 padding; restore exactly what is missing.
        payload_b64 += "=" * (-len(payload_b64) % 4)
        payload_bytes = urlsafe_b64decode(payload_b64)
        payload = json.loads(payload_bytes)
        return payload["exp"] < int(time.time())

Was this entry helpful?