# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import warnings
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Any, Iterable, Sequence
import pendulum
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.serialization.enums import DagAttributeTypes
if TYPE_CHECKING:
    from logging import Logger
    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.utils.edgemodifier import EdgeModifier
    from airflow.utils.task_group import TaskGroup
class DependencyMixin:
    """
    Mixin implementing common dependency-setting methods like ``>>`` and ``<<``.

    Subclasses provide ``set_upstream`` and ``set_downstream``; the shift
    operators defined here are thin wrappers around those two methods.

    Note that this class does not itself use ``ABCMeta``, so the
    ``@abstractmethod`` markers below are only enforced for subclasses whose
    metaclass is ``ABCMeta``.
    """

    @property
    def roots(self) -> Sequence[DependencyMixin]:
        """
        List of root nodes -- ones with no upstream dependencies.

        a.k.a. the "start" of this sub-graph

        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    @property
    def leaves(self) -> Sequence[DependencyMixin]:
        """
        List of leaf nodes -- ones with only upstream dependencies.

        a.k.a. the "end" of this sub-graph

        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    @abstractmethod
    def set_upstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Set a task or a task list to be directly upstream from the current task."""
        raise NotImplementedError()

    @abstractmethod
    def set_downstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Set a task or a task list to be directly downstream from the current task."""
        raise NotImplementedError()

    def update_relative(self, other: DependencyMixin, upstream=True) -> None:
        """
        Update relationship information about another DependencyMixin.

        Default is no-op. Override if necessary.

        :param other: the node this one is being related to.
        :param upstream: direction flag for the relationship being recorded.
        """

    def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implements Task << Task; returns ``other`` so that expressions can be chained."""
        self.set_upstream(other)
        return other

    def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Implements Task >> Task; returns ``other`` so that expressions can be chained."""
        self.set_downstream(other)
        return other

    def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Called for [Task] >> Task because lists don't have __rshift__ operators."""
        # ``other >> self`` means ``other`` is upstream of ``self``.
        self.__lshift__(other)
        return self

    def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
        """Called for [Task] << Task because lists don't have __lshift__ operators."""
        # ``other << self`` means ``other`` is downstream of ``self``.
        self.__rshift__(other)
        return self
class TaskMixin(DependencyMixin):
    """
    Deprecated name for ``DependencyMixin``; subclassing it emits a warning.

    :meta private:
    """

    def __init_subclass__(cls) -> None:
        """Warn, at class-definition time, that ``cls`` should derive from ``DependencyMixin``."""
        warnings.warn(
            f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}",
            category=RemovedInAirflow3Warning,
            # stacklevel=2 attributes the warning to the caller's frame
            # rather than to this hook.
            stacklevel=2,
        )
        return super().__init_subclass__()
class DAGNode(DependencyMixin, metaclass=ABCMeta):
    """
    A base class for a node in the graph of a workflow -- an Operator or a Task Group, either mapped or
    unmapped.
    """

    task_group: TaskGroup | None = None
    """The task_group that contains this node"""

    @property
    @abstractmethod
    def node_id(self) -> str:
        """Identifier of this node; concrete subclasses must provide it."""
        raise NotImplementedError()

    @property
    def label(self) -> str | None:
        """
        Return ``node_id`` without the enclosing task group's prefix.

        The prefix is stripped only when the node has a task group whose
        ``node_id`` and ``prefix_group_id`` are both truthy; otherwise
        ``node_id`` is returned unchanged.
        """
        tg = self.task_group
        if tg and tg.node_id and tg.prefix_group_id:
            # "task_group_id.task_id" -> "task_id"
            return self.node_id[len(tg.node_id) + 1 :]
        return self.node_id

    start_date: pendulum.DateTime | None
    end_date: pendulum.DateTime | None
    upstream_task_ids: set[str]
    downstream_task_ids: set[str]

    def has_dag(self) -> bool:
        """Return True when ``self.dag`` is not ``None``."""
        # NOTE(review): ``dag`` is not declared in this class body; it is
        # presumably supplied by concrete subclasses -- confirm.
        return self.dag is not None

    @property
    def dag_id(self) -> str:
        """Returns dag id if it has one or an adhoc/meaningless ID"""
        if self.dag:
            return self.dag.dag_id
        return "_in_memory_dag_"

    @property
    def log(self) -> Logger:
        """Logger for this node; must be overridden (used when a duplicate dependency is registered)."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        """Nodes at the "start" of this sub-graph; concrete subclasses must provide them."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        """Nodes at the "end" of this sub-graph; concrete subclasses must provide them."""
        raise NotImplementedError()

    def _set_relatives(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        upstream: bool = False,
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """
        Sets relatives for the task or task list.

        :param task_or_task_list: a single node or a sequence of nodes to relate to this one.
        :param upstream: if True the given nodes become upstream of this node,
            otherwise they become downstream.
        :param edge_modifier: optional modifier that is told about every edge created.
        :raises AirflowException: if a resolved relative is not an operator, if the
            tasks involved belong to more than one DAG, or if none of them has a DAG.
        """
        # Imported locally rather than at module level (the module-level
        # imports of these names are guarded by TYPE_CHECKING).
        from airflow.models.baseoperator import BaseOperator
        from airflow.models.mappedoperator import MappedOperator
        from airflow.models.operator import Operator

        if not isinstance(task_or_task_list, Sequence):
            task_or_task_list = [task_or_task_list]

        task_list: list[Operator] = []
        for task_object in task_or_task_list:
            # Let the other side record the relationship from its own point of
            # view, hence the inverted direction flag.
            task_object.update_relative(self, not upstream)
            # An upstream object connects through its leaves, a downstream one
            # through its roots.
            relatives = task_object.leaves if upstream else task_object.roots
            for task in relatives:
                if not isinstance(task, (BaseOperator, MappedOperator)):
                    raise AirflowException(
                        f"Relationships can only be set between Operators; received {task.__class__.__name__}"
                    )
                task_list.append(task)

        # relationships can only be set if the tasks share a single DAG. Tasks
        # without a DAG are assigned to that DAG.
        dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}

        if len(dags) > 1:
            raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}")
        elif len(dags) == 1:
            dag = dags.pop()
        else:
            raise AirflowException(
                "Tried to create relationships between tasks that don't have DAGs yet. "
                f"Set the DAG for at least one task and try again: {[self, *task_list]}"
            )

        if not self.has_dag():
            # If this task does not yet have a dag, add it to the same dag as the other task.
            self.dag = dag

        def add_only_new(obj, item_set: set[str], item: str) -> None:
            """Adds only new items to item set"""
            if item in item_set:
                self.log.warning("Dependency %s, %s already registered for DAG: %s", obj, item, dag.dag_id)
            else:
                item_set.add(item)

        for task in task_list:
            if dag and not task.has_dag():
                # If the other task does not yet have a dag, add it to the same dag as this task.
                dag.add_task(task)
            if upstream:
                # Edge direction: task -> self.
                add_only_new(task, task.downstream_task_ids, self.node_id)
                add_only_new(self, self.upstream_task_ids, task.node_id)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, task.node_id, self.node_id)
            else:
                # Edge direction: self -> task.
                add_only_new(self, self.downstream_task_ids, task.node_id)
                add_only_new(task, task.upstream_task_ids, self.node_id)
                if edge_modifier:
                    edge_modifier.add_edge_info(self.dag, self.node_id, task.node_id)

    def set_downstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Set a node (or nodes) to be directly downstream from the current node."""
        self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)

    def set_upstream(
        self,
        task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
        edge_modifier: EdgeModifier | None = None,
    ) -> None:
        """Set a node (or nodes) to be directly upstream from the current node."""
        self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)

    @property
    def downstream_list(self) -> Iterable[Operator]:
        """
        List of nodes directly downstream

        :raises AirflowException: if this node has not been assigned to a DAG.
        """
        if not self.dag:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [self.dag.get_task(tid) for tid in self.downstream_task_ids]

    @property
    def upstream_list(self) -> Iterable[Operator]:
        """
        List of nodes directly upstream

        :raises AirflowException: if this node has not been assigned to a DAG.
        """
        if not self.dag:
            raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
        return [self.dag.get_task(tid) for tid in self.upstream_task_ids]

    def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
        """
        Get set of the direct relative ids to the current task, upstream or
        downstream.

        The returned set is the node's own ``upstream_task_ids`` /
        ``downstream_task_ids`` object, not a copy.
        """
        if upstream:
            return self.upstream_task_ids
        else:
            return self.downstream_task_ids

    def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
        """
        Get list of the direct relatives to the current task, upstream or
        downstream.
        """
        if upstream:
            return self.upstream_list
        else:
            return self.downstream_list

    def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
        """This is used by TaskGroupSerialization to serialize a task group's content."""
        raise NotImplementedError()