# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import datetime
# This product contains a modified portion of 'Flask App Builder' developed by Daniel Vaz Gaspar.
# (https://github.com/dpgaspar/Flask-AppBuilder).
# Copyright 2013, Daniel Vaz Gaspar
from typing import TYPE_CHECKING
import packaging.version
from flask import current_app, g
from flask_appbuilder.models.sqla import Model
from sqlalchemy import (
    Boolean,
    Column,
    DateTime,
    ForeignKey,
    Index,
    Integer,
    MetaData,
    String,
    Table,
    UniqueConstraint,
    event,
    func,
    select,
)
from sqlalchemy.orm import backref, declared_attr, registry, relationship
from airflow import __version__ as airflow_version
from airflow.auth.managers.models.base_user import BaseUser
from airflow.models.base import _get_schema, naming_convention
if TYPE_CHECKING:
    try:
        from sqlalchemy import Identity
    except Exception:
        # Only reached by static type checkers on SQLAlchemy versions without ``Identity``.
        Identity = None

"""
Compatibility note: The models in this file are duplicated from Flask AppBuilder.
"""

# MetaData shared by the FAB tables in this module, built from the schema returned by
# ``airflow.models.base._get_schema()`` and the ``naming_convention`` exported by the same module.
metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention)
mapper_registry = registry(metadata=metadata)

# Compare on ``base_version`` so that pre-release/dev builds (e.g. "3.0.0.dev0") are treated
# as their final release. From Airflow 3.0.0 the FAB models use this module's own MetaData;
# older Airflow versions keep them on Airflow's ``Base.metadata``.
if packaging.version.parse(packaging.version.parse(airflow_version).base_version) >= packaging.version.parse(
    "3.0.0"
):
    Model.metadata = metadata
else:
    from airflow.models.base import Base

    Model.metadata = Base.metadata
class Action(Model):
    """Represents permission actions such as `can_read`."""

    __tablename__ = "ab_permission"

    id = Column(Integer, primary_key=True)
    # Action names are unique across the table and required.
    name = Column(String(100), unique=True, nullable=False)

    def __repr__(self):
        return self.name
class Resource(Model):
    """Represents permission object such as `User` or `Dag`."""

    __tablename__ = "ab_view_menu"

    id = Column(Integer, primary_key=True)
    # Resource names are unique across the table and required.
    name = Column(String(250), unique=True, nullable=False)

    def __eq__(self, other):
        # Equal when ``other`` is an instance of this class with the same name;
        # primary keys are not compared.
        return (isinstance(other, self.__class__)) and (self.name == other.name)

    # NOTE(review): Python's inequality hook is ``__ne__``, not ``__neq__``, so ``!=`` never
    # calls this method; ``!=`` instead inverts ``__eq__``. Kept as duplicated from Flask
    # AppBuilder. Also note that defining ``__eq__`` without ``__hash__`` makes Python set
    # ``__hash__`` to None, so Resource instances are unhashable.
    def __neq__(self, other):
        return self.name != other.name

    def __repr__(self):
        return self.name
# Association table linking permissions (ab_permission_view) to roles (ab_role);
# used as the ``secondary`` of ``Role.permissions``. Each (permission, role) pair is unique.
assoc_permission_role = Table(
    "ab_permission_view_role",
    Model.metadata,
    Column("id", Integer, primary_key=True),
    Column("permission_view_id", Integer, ForeignKey("ab_permission_view.id")),
    Column("role_id", Integer, ForeignKey("ab_role.id")),
    UniqueConstraint("permission_view_id", "role_id"),
)
class Role(Model):
    """Represents a user role to which permissions can be assigned."""

    __tablename__ = "ab_role"

    id = Column(Integer, primary_key=True)
    name = Column(String(64), unique=True, nullable=False)
    # Many-to-many via ``assoc_permission_role``, eagerly joined. The backref adds a
    # ``Permission.role`` collection holding the roles that grant each permission.
    permissions = relationship("Permission", secondary=assoc_permission_role, backref="role", lazy="joined")

    def __repr__(self):
        return self.name
class Permission(Model):
    """Permission pair comprised of an Action + Resource combo."""

    __tablename__ = "ab_permission_view"
    # A given (action, resource) combination may exist only once.
    __table_args__ = (UniqueConstraint("permission_id", "view_menu_id"),)

    id = Column(Integer, primary_key=True)
    # Python attribute ``action_id`` is stored in the legacy FAB column "permission_id".
    action_id = Column("permission_id", Integer, ForeignKey("ab_permission.id"))
    action = relationship(
        "Action",
        uselist=False,
        lazy="joined",
    )
    # Python attribute ``resource_id`` is stored in the legacy FAB column "view_menu_id".
    resource_id = Column("view_menu_id", Integer, ForeignKey("ab_view_menu.id"))
    resource = relationship(
        "Resource",
        uselist=False,
        lazy="joined",
    )

    def __repr__(self):
        # e.g. an action named "can_read" on resource "Dag" renders as "can read on Dag".
        return str(self.action).replace("_", " ") + " on " + str(self.resource)
# Association table linking users (ab_user) to roles (ab_role); used as the ``secondary``
# of ``User.roles``. Each (user, role) pair is unique.
assoc_user_role = Table(
    "ab_user_role",
    Model.metadata,
    Column("id", Integer, primary_key=True),
    Column("user_id", Integer, ForeignKey("ab_user.id")),
    Column("role_id", Integer, ForeignKey("ab_role.id")),
    UniqueConstraint("user_id", "role_id"),
)
class User(Model, BaseUser):
    """Represents an Airflow user which has roles assigned to it."""

    __tablename__ = "ab_user"

    id = Column(Integer, primary_key=True)
    first_name = Column(String(256), nullable=False)
    last_name = Column(String(256), nullable=False)
    # On SQLite the NOCASE collation is used for the username column. On PostgreSQL a unique
    # index on lower(username) is attached by a ``before_create`` listener in this module.
    username = Column(
        String(512).with_variant(String(512, collation="NOCASE"), "sqlite"), unique=True, nullable=False
    )
    password = Column(String(256))
    active = Column(Boolean, default=True)
    email = Column(String(512), unique=True, nullable=False)
    last_login = Column(DateTime)
    login_count = Column(Integer)
    fail_login_count = Column(Integer)
    # Many-to-many via ``assoc_user_role``; the backref adds ``Role.user``.
    roles = relationship("Role", secondary=assoc_user_role, backref="user", lazy="selectin")
    # Defaults use naive local time (``datetime.datetime.now`` with no tz argument).
    created_on = Column(DateTime, default=datetime.datetime.now, nullable=True)
    changed_on = Column(DateTime, default=datetime.datetime.now, nullable=True)

    # Self-referencing audit columns. ``declared_attr`` passes the mapped class as the first
    # argument, so ``self`` below is the class and ``self.get_user_id`` is the classmethod.
    @declared_attr
    def created_by_fk(self):
        return Column(Integer, ForeignKey("ab_user.id"), default=self.get_user_id, nullable=True)

    @declared_attr
    def changed_by_fk(self):
        return Column(Integer, ForeignKey("ab_user.id"), default=self.get_user_id, nullable=True)

    created_by = relationship(
        "User",
        backref=backref("created", uselist=True),
        remote_side=[id],
        primaryjoin="User.created_by_fk == User.id",
        uselist=False,
    )
    changed_by = relationship(
        "User",
        backref=backref("changed", uselist=True),
        remote_side=[id],
        primaryjoin="User.changed_by_fk == User.id",
        uselist=False,
    )

    @classmethod
    def get_user_id(cls):
        """
        Return the id of the user stored on ``flask.g``, or None.

        Used as the column default for ``created_by_fk``/``changed_by_fk``. Any failure
        (for example no ``g.user`` being set) is deliberately swallowed and yields None.
        """
        try:
            return g.user.get_id()
        except Exception:
            return None

    @property
    def is_authenticated(self):
        return True

    @property
    def is_active(self):
        return self.active

    @property
    def is_anonymous(self):
        return False

    @property
    def perms(self):
        """
        Return the set of ``(action_name, resource_name)`` pairs granted by the user's roles.

        The result is cached on the instance in ``_perms``. Because the cache check is a
        truthiness test, an empty result is recomputed on every access.
        """
        if not self._perms:
            # Using the ORM here is _slow_ (Creating lots of objects to then throw them away) since this is in
            # the path for every request. Avoid it if we can!
            if current_app:
                sm = current_app.appbuilder.sm
                self._perms: set[tuple[str, str]] = set(
                    sm.get_session.execute(
                        select(sm.action_model.name, sm.resource_model.name)
                        .join(sm.permission_model.action)
                        .join(sm.permission_model.resource)
                        .join(sm.permission_model.role)
                        .where(sm.role_model.user.contains(self))
                    )
                )
            else:
                # No bound Flask app: fall back to walking the ORM relationships.
                self._perms = {
                    (perm.action.name, perm.resource.name) for role in self.roles for perm in role.permissions
                }
        return self._perms

    def get_id(self):
        return self.id

    def get_name(self) -> str:
        """Return the username, falling back to the email, then to the stringified id."""
        # This model has no ``user_id`` attribute; fall back to the primary key instead.
        return self.username or self.email or str(self.id)

    def get_full_name(self):
        return f"{self.first_name} {self.last_name}"

    def __repr__(self):
        return self.get_full_name()

    # Per-instance cache for ``perms``; None until first computed.
    _perms = None
class RegisterUser(Model):
    """Represents a user registration."""

    __tablename__ = "ab_register_user"

    id = Column(Integer, primary_key=True)
    first_name = Column(String(256), nullable=False)
    last_name = Column(String(256), nullable=False)
    # Same column definition as ``User.username`` (NOCASE collation variant on SQLite).
    username = Column(
        String(512).with_variant(String(512, collation="NOCASE"), "sqlite"), unique=True, nullable=False
    )
    password = Column(String(256))
    # Unlike ``User.email``, this column is not declared unique.
    email = Column(String(512), nullable=False)
    # Default uses naive local time (``datetime.datetime.now`` with no tz argument).
    registration_date = Column(DateTime, default=datetime.datetime.now, nullable=True)
    registration_hash = Column(String(256))
def _add_lower_username_unique_index(table, conn, index_name: str) -> None:
    """
    Attach a unique index on ``lower(username)`` named *index_name* to *table*.

    No-op on any dialect other than PostgreSQL, or when an index with that name is already
    present on the table (so repeated table creation does not add it twice).
    """
    if conn.dialect.name != "postgresql":
        return
    if not any(table_index.name == index_name for table_index in table.indexes):
        table.indexes.add(Index(index_name, func.lower(table.c.username), unique=True))


@event.listens_for(User.__table__, "before_create")
def add_index_on_ab_user_username_postgres(table, conn, **kw):
    """Before creating ``ab_user`` on PostgreSQL, add a unique index on ``lower(username)``."""
    _add_lower_username_unique_index(table, conn, "idx_ab_user_username")


@event.listens_for(RegisterUser.__table__, "before_create")
def add_index_on_ab_register_user_username_postgres(table, conn, **kw):
    """Before creating ``ab_register_user`` on PostgreSQL, add a unique index on ``lower(username)``."""
    _add_lower_username_unique_index(table, conn, "idx_ab_register_user_username")