Source code for airflow.providers.fab.auth_manager.models

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime

# This product contains a modified portion of 'Flask App Builder' developed by Daniel Vaz Gaspar.
# (https://github.com/dpgaspar/Flask-AppBuilder).
# Copyright 2013, Daniel Vaz Gaspar
from typing import TYPE_CHECKING

from flask import current_app, g
from flask_appbuilder.models.sqla import Model
from sqlalchemy import (
    Boolean,
    Column,
    DateTime,
    ForeignKey,
    Index,
    Integer,
    String,
    Table,
    UniqueConstraint,
    event,
    func,
    select,
)
from sqlalchemy.orm import backref, declared_attr, relationship

from airflow.auth.managers.models.base_user import BaseUser
from airflow.models.base import Base

"""
Compatibility note: The models in this file are duplicated from Flask AppBuilder.
"""
# Use airflow metadata to create the tables
Model.metadata = Base.metadata

if TYPE_CHECKING:
    # This branch is only seen by static type checkers; it never runs at
    # runtime. A failing ``from ... import ...`` raises ImportError, so catch
    # exactly that instead of every Exception.
    try:
        from sqlalchemy import Identity
    except ImportError:
        # Fall back to None when the installed SQLAlchemy does not export Identity.
        Identity = None
class Action(Model):
    """Represents permission actions such as `can_read`."""

    # Stored in the table named "ab_permission".
    __tablename__ = "ab_permission"

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Unique, required action name (at most 100 characters).
    name = Column(String(100), unique=True, nullable=False)

    def __repr__(self):
        """Return the action name as the object's representation."""
        return self.name
class Resource(Model):
    """Represents permission object such as `User` or `Dag`.

    Two resources compare equal when the other object is an instance of this
    class and both carry the same ``name``.
    """

    # Stored in the table named "ab_view_menu".
    __tablename__ = "ab_view_menu"

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Unique, required resource name (at most 250 characters).
    name = Column(String(250), unique=True, nullable=False)

    def __eq__(self, other):
        """Return True when ``other`` is a resource of this class with the same name."""
        return isinstance(other, self.__class__) and self.name == other.name

    def __ne__(self, other):
        """Return the negation of ``__eq__``.

        ``__ne__`` is the special method Python consults for ``!=``; it is
        spelled out here so inequality is visibly tied to ``__eq__``.
        """
        return not self.__eq__(other)

    def __neq__(self, other):
        """Compare names for inequality (legacy helper).

        ``__neq__`` is not a Python special method, so the ``!=`` operator
        never calls it. It is kept, with its original semantics, only for code
        that invokes it explicitly.
        """
        return self.name != other.name

    def __repr__(self):
        """Return the resource name as the object's representation."""
        return self.name
# Association table linking permissions ("ab_permission_view") to roles ("ab_role").
# Each (permission_view_id, role_id) pair may appear only once.
assoc_permission_role = Table(
    "ab_permission_view_role",
    Model.metadata,
    Column("id", Integer, primary_key=True),
    Column("permission_view_id", Integer, ForeignKey("ab_permission_view.id")),
    Column("role_id", Integer, ForeignKey("ab_role.id")),
    UniqueConstraint("permission_view_id", "role_id"),
)
class Role(Model):
    """Represents a user role to which permissions can be assigned."""

    # Stored in the table named "ab_role".
    __tablename__ = "ab_role"

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Unique, required role name (at most 64 characters).
    name = Column(String(64), unique=True, nullable=False)
    # Permissions attached to this role through ``assoc_permission_role``;
    # loaded with a joined load, and exposed on ``Permission`` as ``role``.
    permissions = relationship(
        "Permission",
        secondary=assoc_permission_role,
        backref="role",
        lazy="joined",
    )

    def __repr__(self):
        """Return the role name as the object's representation."""
        return self.name
class Permission(Model):
    """Permission pair comprised of an Action + Resource combo."""

    # Stored in the table named "ab_permission_view".
    __tablename__ = "ab_permission_view"
    # A given (action, resource) pair may exist only once.
    __table_args__ = (UniqueConstraint("permission_id", "view_menu_id"),)

    # Integer primary key.
    id = Column(Integer, primary_key=True)

    # The attribute is called ``action_id`` but the database column is "permission_id".
    action_id = Column("permission_id", Integer, ForeignKey("ab_permission.id"))
    action = relationship(
        "Action",
        uselist=False,
        lazy="joined",
    )

    # The attribute is called ``resource_id`` but the database column is "view_menu_id".
    resource_id = Column("view_menu_id", Integer, ForeignKey("ab_view_menu.id"))
    resource = relationship(
        "Resource",
        uselist=False,
        lazy="joined",
    )

    def __repr__(self):
        """Return ``"<action> on <resource>"`` with underscores in the action shown as spaces."""
        action_label = str(self.action).replace("_", " ")
        return f"{action_label} on {self.resource}"
# Association table linking users ("ab_user") to roles ("ab_role").
# Each (user_id, role_id) pair may appear only once.
assoc_user_role = Table(
    "ab_user_role",
    Model.metadata,
    Column("id", Integer, primary_key=True),
    Column("user_id", Integer, ForeignKey("ab_user.id")),
    Column("role_id", Integer, ForeignKey("ab_role.id")),
    UniqueConstraint("user_id", "role_id"),
)
class User(Model, BaseUser):
    """Represents an Airflow user which has roles assigned to it."""

    # Stored in the table named "ab_user".
    __tablename__ = "ab_user"

    # Per-instance cache filled lazily by the ``perms`` property; the class-level
    # default of None means "not computed yet".
    _perms = None

    id = Column(Integer, primary_key=True)
    first_name = Column(String(256), nullable=False)
    last_name = Column(String(256), nullable=False)
    # On SQLite the column uses the NOCASE collation.
    username = Column(
        String(512).with_variant(String(512, collation="NOCASE"), "sqlite"),
        unique=True,
        nullable=False,
    )
    password = Column(String(256))
    active = Column(Boolean, default=True)
    email = Column(String(512), unique=True, nullable=False)
    last_login = Column(DateTime)
    login_count = Column(Integer)
    fail_login_count = Column(Integer)
    # Roles attached through ``assoc_user_role``; exposed on ``Role`` as ``user``.
    roles = relationship(
        "Role",
        secondary=assoc_user_role,
        backref="user",
        lazy="selectin",
    )
    # Both timestamps default to the current local time when the row is inserted.
    created_on = Column(DateTime, default=datetime.datetime.now, nullable=True)
    changed_on = Column(DateTime, default=datetime.datetime.now, nullable=True)

    @declared_attr
    def created_by_fk(self):
        """Build the self-referencing FK column defaulting to ``get_user_id()``."""
        return Column(
            Integer,
            ForeignKey("ab_user.id"),
            default=self.get_user_id,
            nullable=True,
        )

    @declared_attr
    def changed_by_fk(self):
        """Build the self-referencing FK column defaulting to ``get_user_id()``."""
        return Column(
            Integer,
            ForeignKey("ab_user.id"),
            default=self.get_user_id,
            nullable=True,
        )

    # Self-referential relationships: the user referenced by the *_by_fk columns.
    created_by = relationship(
        "User",
        backref=backref("created", uselist=True),
        remote_side=[id],
        primaryjoin="User.created_by_fk == User.id",
        uselist=False,
    )
    changed_by = relationship(
        "User",
        backref=backref("changed", uselist=True),
        remote_side=[id],
        primaryjoin="User.changed_by_fk == User.id",
        uselist=False,
    )

    @classmethod
    def get_user_id(cls):
        """Return ``g.user.get_id()``, or None when that lookup raises."""
        try:
            current_user_id = g.user.get_id()
        except Exception:
            # Best effort by design: any failure to reach the user yields None.
            current_user_id = None
        return current_user_id

    @property
    def is_authenticated(self):
        """Always True."""
        return True

    @property
    def is_active(self):
        """Return the value of the ``active`` column."""
        return self.active

    @property
    def is_anonymous(self):
        """Always False."""
        return False

    @property
    def perms(self):
        """Return the (action name, resource name) pairs granted through the user's roles.

        The result is cached on the instance in ``_perms``; a falsy cache
        (None or an empty collection) is recomputed on access.
        """
        if self._perms:
            return self._perms

        # Using the ORM here is _slow_ (Creating lots of objects to then throw them away) since this is in
        # the path for every request. Avoid it if we can!
        if current_app:
            sm = current_app.appbuilder.sm
            session = sm.get_session
            statement = (
                select(sm.action_model.name, sm.resource_model.name)
                .join(sm.permission_model.action)
                .join(sm.permission_model.resource)
                .join(sm.permission_model.role)
                .where(sm.role_model.user.contains(self))
            )
            self._perms: set[tuple[str, str]] = set(session.execute(statement))
        else:
            # No application context: walk the loaded relationships instead.
            self._perms = {
                (perm.action.name, perm.resource.name) for role in self.roles for perm in role.permissions
            }
        return self._perms

    def get_id(self):
        """Return the primary key of the user."""
        return self.id

    def get_name(self) -> str:
        """Return the first truthy value among username, email and ``user_id``."""
        # NOTE(review): ``user_id`` is not defined in this class; presumably it comes
        # from BaseUser or is set elsewhere -- confirm.
        return self.username or self.email or self.user_id

    def get_full_name(self):
        """Return the first and last name separated by a single space."""
        return f"{self.first_name} {self.last_name}"

    def __repr__(self):
        """Return the user's full name as the object's representation."""
        return self.get_full_name()
class RegisterUser(Model):
    """Represents a user registration."""

    # Stored in the table named "ab_register_user".
    __tablename__ = "ab_register_user"

    id = Column(Integer, primary_key=True)
    first_name = Column(String(256), nullable=False)
    last_name = Column(String(256), nullable=False)
    # On SQLite the column uses the NOCASE collation.
    username = Column(
        String(512).with_variant(String(512, collation="NOCASE"), "sqlite"),
        unique=True,
        nullable=False,
    )
    password = Column(String(256))
    email = Column(String(512), nullable=False)
    # Defaults to the current local time when the row is inserted.
    registration_date = Column(DateTime, default=datetime.datetime.now, nullable=True)
    registration_hash = Column(String(256))
@event.listens_for(User.__table__, "before_create")
def add_index_on_ab_user_username_postgres(table, conn, **kw):
    """Add a unique index on ``lower(username)`` before the user table is created.

    Only acts when the connection's dialect is PostgreSQL, and only when the
    table does not already carry an index named ``idx_ab_user_username``.

    :param table: the table about to be created
    :param conn: the connection used for the DDL; its dialect name is inspected
    :param kw: additional keyword arguments from the event; unused
    """
    if conn.dialect.name != "postgresql":
        return

    index_name = "idx_ab_user_username"
    for existing_index in table.indexes:
        if existing_index.name == index_name:
            # Already registered on the table; nothing to add.
            return

    lowered_username = func.lower(table.c.username)
    table.indexes.add(Index(index_name, lowered_username, unique=True))
@event.listens_for(RegisterUser.__table__, "before_create")
def add_index_on_ab_register_user_username_postgres(table, conn, **kw):
    """Add a unique index on ``lower(username)`` before the registration table is created.

    Only acts when the connection's dialect is PostgreSQL, and only when the
    table does not already carry an index named ``idx_ab_register_user_username``.

    :param table: the table about to be created
    :param conn: the connection used for the DDL; its dialect name is inspected
    :param kw: additional keyword arguments from the event; unused
    """
    if conn.dialect.name != "postgresql":
        return

    index_name = "idx_ab_register_user_username"
    for existing_index in table.indexes:
        if existing_index.name == index_name:
            # Already registered on the table; nothing to add.
            return

    lowered_username = func.lower(table.c.username)
    table.indexes.add(Index(index_name, lowered_username, unique=True))

Was this entry helpful?