Source code for airflow.providers.google.cloud.utils.external_token_supplier

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import abc
import time
from functools import wraps
from typing import TYPE_CHECKING, Any

import requests
from google.auth.exceptions import RefreshError
from google.auth.identity_pool import SubjectTokenSupplier

if TYPE_CHECKING:
    from google.auth.external_account import SupplierContext
    from google.auth.transport import Request

from airflow.utils.log.logging_mixin import LoggingMixin


def cache_token_decorator(get_subject_token_method):
    """
    Cache calls to ``SubjectTokenSupplier`` instances' ``get_subject_token`` methods.

    Different instances of a same SubjectTokenSupplier class with the same attributes share
    the OIDC token cache, because entries are keyed by the value returned from the
    instance's ``get_subject_key()``.

    :param get_subject_token_method: A method that returns both a token and an integer
        specifying the time in seconds until the token expires
    :return: A wrapper that returns only the token string, fetching a new token through
        ``get_subject_token_method`` when no unexpired token is cached for the instance's key.

    See also:
        https://googleapis.dev/python/google-auth/latest/reference/google.auth.identity_pool.html#google.auth.identity_pool.SubjectTokenSupplier.get_subject_token
    """
    # One cache per decorated method. It is only mutated in place by the wrapper,
    # never rebound, so no ``nonlocal`` declaration is needed.
    cache: dict[str, dict[str, str | float]] = {}

    @wraps(get_subject_token_method)
    def wrapper(supplier_instance: CacheTokenSupplier, *args, **kwargs) -> str:
        """
        Obeys the interface set by ``SubjectTokenSupplier`` for ``get_subject_token`` methods.

        :param supplier_instance: the SubjectTokenSupplier instance whose get_subject_token
            method is being decorated
        :return: The token string
        :raises RefreshError: if the decorated method raises it, or if it returns anything
            other than a ``(str, int)`` pair.
        """
        cache_key = supplier_instance.get_subject_key()
        cached = cache.get(cache_key)

        if cached is None or cached["expiration_time"] < time.monotonic():
            supplier_instance.log.info("OIDC token missing or expired")
            try:
                access_token, expires_in = get_subject_token_method(supplier_instance, *args, **kwargs)
                if not isinstance(expires_in, int) or not isinstance(access_token, str):
                    # Assume an error if strange values are provided.
                    raise RefreshError(
                        "Subject token supplier returned an invalid (access_token, expires_in) pair"
                    )
            except RefreshError:
                supplier_instance.log.error("Failed retrieving new OIDC Token from IdP")
                raise

            # time.monotonic() is used so that wall-clock adjustments cannot
            # extend or shorten the lifetime of a cached token.
            cache[cache_key] = {
                "access_token": access_token,
                "expiration_time": time.monotonic() + float(expires_in),
            }
            supplier_instance.log.info("New OIDC token retrieved, expires in %s seconds.", expires_in)

        return cache[cache_key]["access_token"]

    return wrapper
class CacheTokenSupplier(LoggingMixin, SubjectTokenSupplier):
    """
    Base class for Subject Token Supplier classes that want their tokens cached.

    Subclasses have to provide ``get_subject_key``. The string it returns is the key under
    which a token is stored in the cache, so instances that return the same key share the
    same cached token.

    Methods:
        get_subject_key: Abstract; implemented by subclasses to return the cache key string.
    """

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def get_subject_key(self) -> str:
        """Return the string that identifies this supplier's token in the cache."""
        raise NotImplementedError("")
class ClientCredentialsGrantFlowTokenSupplier(CacheTokenSupplier):
    """
    Class that retrieves an OIDC token from an external IdP using OAuth2.0 Client Credentials Grant flow.

    This class implements the ``SubjectTokenSupplier`` interface class used by
    ``google.auth.identity_pool.Credentials``

    :param oidc_issuer_url: URL of the IdP that performs OAuth2.0 Client Credentials Grant flow
        and returns an OIDC token.
    :param client_id: Client ID of the application requesting the token
    :param client_secret: Client secret of the application requesting the token
    :param extra_params_kwargs: Extra parameters to be passed in the payload of the POST request
        to the `oidc_issuer_url`

    See also:
        https://googleapis.dev/python/google-auth/latest/reference/google.auth.identity_pool.html#google.auth.identity_pool.SubjectTokenSupplier
    """

    def __init__(
        self,
        oidc_issuer_url: str,
        client_id: str,
        client_secret: str,
        **extra_params_kwargs: Any,
    ) -> None:
        super().__init__()
        self.oidc_issuer_url = oidc_issuer_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.extra_params_kwargs = extra_params_kwargs

    @cache_token_decorator
    def get_subject_token(self, context: SupplierContext, request: Request) -> tuple[str, int]:
        """
        Perform Client Credentials Grant flow with IdP and retrieves an OIDC token and expiration time.

        The method is wrapped by ``cache_token_decorator``, so callers receive only the token
        string; the ``(access_token, expires_in)`` pair returned here is consumed by the wrapper.

        :param context: Supplier context passed in by the caller; not used by this implementation.
        :param request: Transport request object passed in by the caller; not used by this
            implementation (the POST is issued with ``requests`` directly).
        :return: The ``access_token`` and ``expires_in`` fields of the IdP's JSON response.
        :raises RefreshError: on an HTTP error status, a connection error, a non-JSON response,
            a JSON response that is not an object, or a response missing a required field.
        """
        self.log.info("Requesting new OIDC token from external IdP.")
        try:
            # NOTE(review): no ``timeout`` is passed, so this call can block for as long as
            # the IdP keeps the connection open -- consider making it configurable.
            response = requests.post(
                self.oidc_issuer_url,
                data={
                    "grant_type": "client_credentials",
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                    **self.extra_params_kwargs,
                },
            )
            response.raise_for_status()
        except (requests.HTTPError, requests.ConnectionError) as e:
            # Chain the original exception so the underlying cause stays in the traceback.
            raise RefreshError(str(e)) from e

        try:
            response_dict = response.json()
        except requests.JSONDecodeError as e:
            raise RefreshError(f"Didn't get a json response from {self.oidc_issuer_url}") from e

        # Valid JSON may still be a list, string, number, ...; only an object can hold the fields.
        if not isinstance(response_dict, dict):
            raise RefreshError(f"Didn't get a json object response from {self.oidc_issuer_url}")

        # These fields are required
        if {"access_token", "expires_in"} - set(response_dict):
            # TODO more information about the error can be provided in the exception by inspecting the response
            raise RefreshError(f"No access token returned from {self.oidc_issuer_url}")

        return response_dict["access_token"], response_dict["expires_in"]

    def get_subject_key(self) -> str:
        """
        Create a cache key using the OIDC issuer URL, client ID, client secret and additional parameters.

        Instances with the same credentials and the same extra parameters (names and values)
        will share tokens. The key is the ``repr`` of a tuple, so every field is quoted and
        delimited: different splits of the same characters across fields cannot collide, and
        instances that differ only in the value of an extra parameter get separate entries.
        """
        # Keyword-argument names are unique strings, so sorting by name alone gives a
        # deterministic order without ever comparing the (arbitrary) values.
        extra_params = sorted(self.extra_params_kwargs.items(), key=lambda item: item[0])
        return repr((self.oidc_issuer_url, self.client_id, self.client_secret, extra_params))

Was this entry helpful?