#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Iterable
from sqlalchemy import Column, Integer, String, Text, func
from sqlalchemy.orm.session import Session
from airflow.exceptions import AirflowException, PoolNotFound
from airflow.models.base import Base
from airflow.ti_deps.dependencies_states import EXECUTION_STATES
from airflow.typing_compat import TypedDict
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import nowait, with_row_locks
from airflow.utils.state import State
class PoolStats(TypedDict):
    """
    Dictionary containing Pool Stats.

    These are the keys populated by ``Pool.slots_stats``. For a pool whose
    ``slots`` is ``-1`` (unlimited), ``total`` and therefore ``open`` hold
    ``float("inf")`` rather than an int.
    """

    total: int
    running: int
    queued: int
    open: int
 
class Pool(Base):
    """The class to get Pool info."""

    __tablename__ = "slot_pool"

    id = Column(Integer, primary_key=True)
    pool = Column(String(256), unique=True)
    # -1 for infinite
    slots = Column(Integer, default=0)
    description = Column(Text)

    DEFAULT_POOL_NAME = "default_pool"

    def __repr__(self):
        return str(self.pool)

    @staticmethod
    @provide_session
    def get_pools(session: Session = NEW_SESSION) -> list[Pool]:
        """
        Get all pools.

        :param session: SQLAlchemy ORM Session
        :return: every Pool row in the table
        """
        return session.query(Pool).all()

    @staticmethod
    @provide_session
    def get_pool(pool_name: str, session: Session = NEW_SESSION) -> Pool | None:
        """
        Get the Pool with specific pool name from the Pools.

        :param pool_name: The pool name of the Pool to get.
        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None if no pool has that name
        """
        return session.query(Pool).filter(Pool.pool == pool_name).first()

    @staticmethod
    @provide_session
    def get_default_pool(session: Session = NEW_SESSION) -> Pool | None:
        """
        Get the Pool of the default_pool from the Pools.

        :param session: SQLAlchemy ORM Session
        :return: the pool object, or None if the default pool row is missing
        """
        return Pool.get_pool(Pool.DEFAULT_POOL_NAME, session=session)

    @staticmethod
    @provide_session
    def is_default_pool(id: int, session: Session = NEW_SESSION) -> bool:
        """
        Check whether the pool with the given id is the default_pool.

        :param id: pool id (name kept for backward compatibility even though it
            shadows the builtin)
        :param session: SQLAlchemy ORM Session
        :return: True if id is default_pool, otherwise False
        """
        matches = (
            session.query(func.count(Pool.id))
            .filter(Pool.id == id, Pool.pool == Pool.DEFAULT_POOL_NAME)
            .scalar()
        )
        return matches > 0

    @staticmethod
    @provide_session
    def create_or_update_pool(
        name: str,
        slots: int,
        description: str,
        session: Session = NEW_SESSION,
    ) -> Pool | None:
        """
        Create a pool with given parameters or update it if it already exists.

        The change is committed on ``session``.

        :param name: pool name; when empty, nothing is done and None is returned
        :param slots: number of slots (-1 for infinite)
        :param description: free-text description of the pool
        :param session: SQLAlchemy ORM Session
        :return: the created or updated pool, or None when ``name`` is empty
        """
        if not name:
            # Historical behaviour: an empty name is a silent no-op.
            return None
        pool = session.query(Pool).filter_by(pool=name).first()
        if pool is None:
            pool = Pool(pool=name, slots=slots, description=description)
            session.add(pool)
        else:
            pool.slots = slots
            pool.description = description
        session.commit()
        return pool

    @staticmethod
    @provide_session
    def delete_pool(name: str, session: Session = NEW_SESSION) -> Pool:
        """
        Delete pool by a given name and commit the deletion.

        :param name: name of the pool to delete
        :param session: SQLAlchemy ORM Session
        :return: the deleted pool object
        :raises AirflowException: if ``name`` is the default pool
        :raises PoolNotFound: if no pool with ``name`` exists
        """
        if name == Pool.DEFAULT_POOL_NAME:
            raise AirflowException(f"{Pool.DEFAULT_POOL_NAME} cannot be deleted")
        pool = session.query(Pool).filter_by(pool=name).first()
        if pool is None:
            raise PoolNotFound(f"Pool '{name}' doesn't exist")
        session.delete(pool)
        session.commit()
        return pool

    @staticmethod
    @provide_session
    def slots_stats(
        *,
        lock_rows: bool = False,
        session: Session = NEW_SESSION,
    ) -> dict[str, PoolStats]:
        """
        Get Pool stats (Number of Running, Queued, Open & Total tasks).

        If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a
        non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will throw an
        OperationalError.

        :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returned
        :param session: SQLAlchemy ORM Session
        :return: mapping of pool name to its stats; a pool with ``slots == -1``
            reports ``float("inf")`` for ``total`` and ``open``
        :raises AirflowException: if a task instance in an unexpected state is counted
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        pools: dict[str, PoolStats] = {}

        query = session.query(Pool.pool, Pool.slots)
        if lock_rows:
            query = with_row_locks(query, session=session, **nowait(session))

        pool_rows: Iterable[tuple[str, int]] = query.all()
        for (pool_name, total_slots) in pool_rows:
            if total_slots == -1:
                total_slots = float("inf")  # type: ignore
            pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0)

        state_count_by_pool = (
            session.query(TaskInstance.pool, TaskInstance.state, func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.state.in_(list(EXECUTION_STATES)))
            .group_by(TaskInstance.pool, TaskInstance.state)
        ).all()

        # calculate queued and running metrics
        for (pool_name, state, count) in state_count_by_pool:
            # Some databases return decimal.Decimal here.
            count = int(count)
            stats_dict: PoolStats | None = pools.get(pool_name)
            if stats_dict is None:
                # Task instances may reference a pool that has no row; ignore them.
                continue
            # TypedDict key must be a string literal, so we use if-statements to set value
            if state == "running":
                stats_dict["running"] = count
            elif state == "queued":
                stats_dict["queued"] = count
            else:
                raise AirflowException(f"Unexpected state. Expected values: {EXECUTION_STATES}.")

        # calculate open metric
        for stats_dict in pools.values():
            stats_dict["open"] = stats_dict["total"] - stats_dict["running"] - stats_dict["queued"]

        return pools

    def to_json(self) -> dict:
        """
        Get the Pool in a json structure.

        :return: the pool object in json format
        """
        return {
            "id": self.id,
            "pool": self.pool,
            "slots": self.slots,
            "description": self.description,
        }

    def _pool_slots_in_states(self, states: Iterable, session: Session) -> int:
        """Sum ``pool_slots`` of this pool's task instances whose state is in ``states``."""
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        total = (
            session.query(func.sum(TaskInstance.pool_slots))
            .filter(TaskInstance.pool == self.pool)
            .filter(TaskInstance.state.in_(list(states)))
            .scalar()
        )
        # SUM over zero rows yields NULL -> None, which maps to 0 slots.
        return int(total or 0)

    @provide_session
    def occupied_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by running/queued tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._pool_slots_in_states(EXECUTION_STATES, session=session)

    @provide_session
    def running_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by running tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._pool_slots_in_states([State.RUNNING], session=session)

    @provide_session
    def queued_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots used by queued tasks at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the used number of slots
        """
        return self._pool_slots_in_states([State.QUEUED], session=session)

    @provide_session
    def scheduled_slots(self, session: Session = NEW_SESSION) -> int:
        """
        Get the number of slots scheduled at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of scheduled slots
        """
        return self._pool_slots_in_states([State.SCHEDULED], session=session)

    @provide_session
    def open_slots(self, session: Session = NEW_SESSION) -> float:
        """
        Get the number of slots open at the moment.

        :param session: SQLAlchemy ORM Session
        :return: the number of slots; ``float("inf")`` when ``slots`` is -1
        """
        if self.slots == -1:
            return float("inf")
        return self.slots - self.occupied_slots(session)