from __future__ import annotations

from copy import deepcopy
from typing import (

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import BaseOperator
from import KiotaRequestAdapterHook
from import (
from airflow.utils.xcom import XCOM_RETURN_KEY

    from io import BytesIO

    from kiota_abstractions.request_adapter import ResponseType
    from kiota_abstractions.request_information import QueryParams
    from msgraph_core import APIVersion

    from airflow.utils.context import Context

class MSGraphAsyncOperator(BaseOperator):
    """
    A Microsoft Graph API operator which allows you to execute a REST call to the Microsoft Graph API.

    The operator defers immediately to a ``MSGraphTrigger``; the response is handled in
    :meth:`execute_complete` once the trigger fires.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:MSGraphAsyncOperator`

    :param url: The url being executed on the Microsoft Graph API (templated).
    :param response_type: The expected return type of the response as a string. Possible values are:
        `bytes`, `str`, `int`, `float`, `bool` and `datetime` (default is None).
    :param path_parameters: Path parameters forwarded to the `MSGraphTrigger` (templated, default is None).
    :param url_template: URL template forwarded to the `MSGraphTrigger` (templated, default is None).
    :param method: The HTTP method being used to do the REST call (default is GET).
    :param query_parameters: Query parameters forwarded to the `MSGraphTrigger` (templated, default is
        None). The default pagination reads `$top` from them and sets `$skip` on a copy.
    :param headers: HTTP headers forwarded to the `MSGraphTrigger` (templated, default is None).
    :param data: The data forwarded to the `MSGraphTrigger` (templated, default is None).
    :param conn_id: The HTTP Connection ID to run the operator against (templated).
    :param key: The key that will be used to store `XCom's` ("return_value" is default).
    :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None).
        When no timeout is specified or set to None then there is no HTTP timeout on each request.
    :param proxies: A dict defining the HTTP proxies to be used (default is None).
    :param api_version: The API version of the Microsoft Graph API to be used (default is v1).
        You can pass an enum named APIVersion which has 2 possible members v1 and beta,
        or you can pass a string as `v1.0` or `beta`.
    :param pagination_function: Callable receiving the operator and the deserialized response dict and
        returning a `(url, query_parameters)` tuple (default is :meth:`paginate`).
    :param result_processor: Function to further process the response from MS Graph API
        (default is lambda context, response: response). When the response returned by the
        `KiotaRequestAdapterHook` is bytes, then those will be base64 encoded into a string.
    :param serializer: Class which handles response serialization (default is ResponseSerializer).
        Bytes will be base64 encoded into a string, so it can be stored as an XCom.
    """

    template_fields: Sequence[str] = (
        "url",
        "response_type",
        "path_parameters",
        "url_template",
        "query_parameters",
        "headers",
        "data",
        "conn_id",
    )

    def __init__(
        self,
        *,
        url: str,
        response_type: ResponseType | None = None,
        path_parameters: dict[str, Any] | None = None,
        url_template: str | None = None,
        method: str = "GET",
        query_parameters: dict[str, QueryParams] | None = None,
        headers: dict[str, str] | None = None,
        data: dict[str, Any] | str | BytesIO | None = None,
        conn_id: str = KiotaRequestAdapterHook.default_conn_name,
        key: str = XCOM_RETURN_KEY,
        timeout: float | None = None,
        proxies: dict | None = None,
        api_version: APIVersion | None = None,
        pagination_function: Callable[[MSGraphAsyncOperator, dict], tuple[str, dict]] | None = None,
        result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
        serializer: type[ResponseSerializer] = ResponseSerializer,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.url = url
        self.response_type = response_type
        self.path_parameters = path_parameters
        self.url_template = url_template
        self.method = method
        self.query_parameters = query_parameters
        self.headers = headers
        # `data` is a template field, so it must live on its own attribute
        # (and must not overwrite `headers`).
        self.data = data
        self.conn_id = conn_id
        self.key = key
        self.timeout = timeout
        self.proxies = proxies
        self.api_version = api_version
        self.pagination_function = pagination_function or self.paginate
        self.result_processor = result_processor
        # The class is kept as a parameter so the trigger can receive it via `type(...)`;
        # the operator itself works with an instance.
        self.serializer: ResponseSerializer = serializer()
        # Accumulated (processed) responses; stays None until the first result arrives.
        self.results: list[Any] | None = None

    def execute(self, context: Context) -> None:
        """Defer to a ``MSGraphTrigger`` built from the operator settings; resume in ``execute_complete``."""
        self.defer(
            trigger=MSGraphTrigger(
                url=self.url,
                response_type=self.response_type,
                path_parameters=self.path_parameters,
                url_template=self.url_template,
                method=self.method,
                query_parameters=self.query_parameters,
                headers=self.headers,
                data=self.data,
                conn_id=self.conn_id,
                timeout=self.timeout,
                proxies=self.proxies,
                api_version=self.api_version,
                serializer=type(self.serializer),
            ),
            method_name=self.execute_complete.__name__,
        )

    def execute_complete(
        self,
        context: Context,
        event: dict[Any, Any] | None = None,
    ) -> Any:
        """
        Execute callback when MSGraphTrigger finishes execution.

        This method gets executed automatically when MSGraphTrigger completes its execution.

        :param context: The task context, passed on to ``result_processor`` and used for XCom pushes.
        :param event: The trigger event; its ``status``, ``message`` and ``response`` keys are read.
        :return: The accumulated results, or None when there is no event or no response.
        :raises AirflowException: When the event status is ``failure``.
        :raises TaskDeferred: Re-raised when ``trigger_next_link`` defers the task again.
        """
        self.log.debug("context: %s", context)

        if not event:
            return None

        self.log.debug("%s completed with %s: %s", self.task_id, event.get("status"), event)

        if event.get("status") == "failure":
            raise AirflowException(event.get("message"))

        response = event.get("response")
        self.log.debug("response: %s", response)

        if not response:
            return None

        response = self.serializer.deserialize(response)
        self.log.debug("deserialize response: %s", response)

        result = self.result_processor(context, response)
        self.log.debug("processed response: %s", result)

        event["response"] = result

        try:
            self.trigger_next_link(response, method_name=self.pull_execute_complete.__name__)
        except TaskDeferred:
            # The task is about to be deferred again: persist what we have so far through
            # XCom so that `pull_execute_complete` can restore it when the task resumes.
            self.append_result(
                result=result,
                append_result_as_list_if_absent=True,
            )
            self.push_xcom(context=context, value=self.results)
            raise

        self.append_result(result=result)
        self.log.debug("results: %s", self.results)

        return self.results

    def append_result(
        self,
        result: Any,
        append_result_as_list_if_absent: bool = False,
    ) -> None:
        """
        Merge ``result`` into ``self.results``.

        When ``self.results`` already is a list, list results are extended into it and any other
        value is appended. Otherwise ``self.results`` is replaced: by ``result`` itself, or - when
        ``append_result_as_list_if_absent`` is set - by a list (``result`` if it already is one,
        else ``[result]``).

        :param result: The processed response to store.
        :param append_result_as_list_if_absent: Force ``self.results`` to become a list when it
            is not one yet.
        """
        self.log.debug("value: %s", result)

        if isinstance(self.results, list):
            if isinstance(result, list):
                self.results.extend(result)
            else:
                self.results.append(result)
        elif append_result_as_list_if_absent and not isinstance(result, list):
            self.results = [result]
        else:
            self.results = result

    def push_xcom(self, context: Context, value) -> None:
        """Push ``value`` to XCom under ``self.key`` when ``do_xcom_push`` is enabled."""
        self.log.debug("do_xcom_push: %s", self.do_xcom_push)
        if self.do_xcom_push:
            # NOTE(review): the logging call was lost in the extracted source; `info` level assumed.
            self.log.info("Pushing XCom with key '%s': %s", self.key, value)
            self.xcom_push(context=context, key=self.key, value=value)

    def pull_execute_complete(self, context: Context, event: dict[Any, Any] | None = None) -> Any:
        """
        Restore previously pushed results from XCom, then continue with ``execute_complete``.

        :param context: The task context.
        :param event: The trigger event, handed over unchanged to ``execute_complete``.
        :return: Whatever ``execute_complete`` returns.
        """
        self.results = list(
            self.xcom_pull(
                context=context,
                task_ids=self.task_id,
                dag_id=self.dag_id,
                key=self.key,
            )
            or []
        )
        # NOTE(review): the logging call was lost in the extracted source; `info` level assumed.
        self.log.info(
            "Pulled XCom with task_id '%s' and dag_id '%s' and key '%s': %s",
            self.task_id,
            self.dag_id,
            self.key,
            self.results,
        )
        return self.execute_complete(context, event)

    # Static on purpose: it is the default for `pagination_function`, whose declared signature
    # takes the operator explicitly as first argument.
    @staticmethod
    def paginate(operator: MSGraphAsyncOperator, response: dict) -> tuple[Any, dict[str, Any] | None]:
        """
        Determine the url and query parameters of the next request.

        When the response carries an ``@odata.count``, the operator has query parameters with a
        ``$top`` value and the current page holds exactly ``$top`` items, the operator url is
        returned together with a copy of the query parameters in which ``$skip`` is set to the
        number of items in ``operator.results`` plus ``$top``. In every other case the response's
        ``@odata.nextLink`` (possibly None) is returned with the unmodified query parameters.

        :param operator: The operator whose ``url``, ``query_parameters`` and ``results`` are read.
        :param response: The deserialized response dict.
        :return: A ``(url, query_parameters)`` tuple.
        """
        if response.get("@odata.count") and operator.query_parameters:
            # Work on a copy so the operator's own query parameters are never mutated.
            query_parameters = deepcopy(operator.query_parameters)
            top = query_parameters.get("$top")

            if top and len(response.get("value", [])) == top:
                # `operator.results` holds the pages stored before this one; the current page
                # contributes `top` further items.
                already_fetched = (
                    sum(len(result["value"]) for result in operator.results) if operator.results else 0
                )
                query_parameters["$skip"] = already_fetched + top
                return operator.url, query_parameters
        return response.get("@odata.nextLink"), operator.query_parameters

