# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Mapping, Sequence
import ydb
from sqlalchemy.engine import URL
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.ydb.hooks._vendor.dbapi.connection import Connection as DbApiConnection
from airflow.providers.ydb.hooks._vendor.dbapi.cursor import YdbQuery
from airflow.providers.ydb.utils.credentials import get_credentials_from_connection
from airflow.providers.ydb.utils.defaults import CONN_NAME_ATTR, CONN_TYPE, DEFAULT_CONN_NAME
# Port used for the YDB endpoint when the Airflow connection does not define one
# (see ``YDBHook.get_conn``).
DEFAULT_YDB_GRPCS_PORT: int = 2135

if TYPE_CHECKING:
    from airflow.models.connection import Connection
    from airflow.providers.ydb.hooks._vendor.dbapi.cursor import Cursor as DbApiCursor
class YDBCursor:
    """
    YDB cursor wrapper.

    Forwards DB-API cursor calls to the vendored YDB DB-API cursor. SQL text
    given to :meth:`execute` is wrapped in a ``YdbQuery`` carrying ``is_ddl``.
    """

    def __init__(self, delegatee: DbApiCursor, is_ddl: bool):
        # Underlying vendored DB-API cursor that performs the actual work.
        self.delegatee: DbApiCursor = delegatee
        # Forwarded to every ``YdbQuery`` built by ``execute``.
        self.is_ddl: bool = is_ddl

    def execute(self, sql: str, parameters: Mapping[str, Any] | None = None):
        """
        Execute a single YQL statement.

        :param sql: YQL text to execute.
        :param parameters: query parameters; not supported yet, must be ``None``.
        :raises AirflowException: if ``parameters`` is not ``None``.
        :return: the result of the underlying cursor's ``execute``.
        """
        if parameters is not None:
            raise AirflowException("parameters is not supported yet")
        q = YdbQuery(yql_text=sql, is_ddl=self.is_ddl)
        return self.delegatee.execute(q, parameters)

    def executemany(self, sql: str, seq_of_parameters: Sequence[Mapping[str, Any]]):
        """
        Call :meth:`execute` once per element of ``seq_of_parameters``.

        Because :meth:`execute` rejects non-``None`` parameters, any real
        parameter mapping in the sequence raises ``AirflowException``.
        """
        for parameters in seq_of_parameters:
            self.execute(sql, parameters)

    def executescript(self, script):
        """Execute ``script`` as a single statement via :meth:`execute`."""
        return self.execute(script)

    def fetchone(self):
        """Return the next row from the underlying cursor."""
        return self.delegatee.fetchone()

    def fetchmany(self, size=None):
        """Return up to ``size`` rows from the underlying cursor."""
        return self.delegatee.fetchmany(size=size)

    def fetchall(self):
        """Return all remaining rows from the underlying cursor."""
        return self.delegatee.fetchall()

    def nextset(self):
        """Advance the underlying cursor to its next result set."""
        return self.delegatee.nextset()

    def setoutputsize(self, column=None):
        """Forward ``setoutputsize`` to the underlying cursor."""
        # NOTE(review): PEP 249 defines ``setoutputsize(size, column=None)``; this
        # wrapper accepts only ``column`` and forwards it positionally. Presumably
        # this matches the vendored cursor's signature; verify before changing.
        return self.delegatee.setoutputsize(column)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close the cursor on leaving the ``with`` block; exceptions propagate.
        self.close()

    def close(self):
        """Close the underlying cursor."""
        return self.delegatee.close()

    @property
    def rowcount(self):
        """Row count reported by the underlying cursor."""
        return self.delegatee.rowcount

    @property
    def description(self):
        """Result-set description reported by the underlying cursor."""
        return self.delegatee.description
class YDBConnection:
    """
    YDB connection wrapper.

    Creates a ``ydb.Driver`` and a ``ydb.SessionPool`` and exposes a
    DB-API-like connection on top of them through the vendored
    ``DbApiConnection``. The driver and pool are stopped by :meth:`close`.
    """

    def __init__(self, endpoint: str, database: str, credentials: Any, is_ddl: bool = False):
        """
        Connect to YDB and prepare a session pool.

        :param endpoint: YDB endpoint, e.g. ``host:port`` as built by ``YDBHook.get_conn``.
        :param database: YDB database path.
        :param credentials: credentials object passed to ``ydb.DriverConfig``.
        :param is_ddl: flag propagated to every cursor returned by :meth:`cursor`.
        :raises Exception: whatever ``driver.wait`` / pool creation raises; the
            driver (and pool, if created) are stopped before re-raising.
        """
        self.is_ddl = is_ddl
        driver_config = ydb.DriverConfig(
            endpoint=endpoint,
            database=database,
            table_client_settings=YDBConnection._get_table_client_settings(),
            credentials=credentials,
        )
        driver = ydb.Driver(driver_config)
        session_pool = None
        try:
            # wait until driver become initialized
            driver.wait(fail_fast=True, timeout=10)
            session_pool = ydb.SessionPool(driver, size=5)
            self.delegatee: DbApiConnection = DbApiConnection(ydb_session_pool=session_pool)
        except BaseException:
            # Don't leak the driver / pool when the connection cannot be set up.
            if session_pool is not None:
                session_pool.stop()
            driver.stop()
            raise
        # This wrapper created the pool and driver, so it is responsible for
        # stopping them in close().
        # NOTE(review): assumes DbApiConnection.close() does not itself stop a
        # pool supplied via ``ydb_session_pool`` — confirm against the vendored code.
        self._driver = driver
        self._session_pool = session_pool

    def cursor(self) -> YDBCursor:
        """Return a new :class:`YDBCursor` carrying this connection's ``is_ddl`` flag."""
        return YDBCursor(self.delegatee.cursor(), is_ddl=self.is_ddl)

    def begin(self) -> None:
        """Begin a transaction on the underlying connection."""
        self.delegatee.begin()

    def commit(self) -> None:
        """Commit the current transaction on the underlying connection."""
        self.delegatee.commit()

    def rollback(self) -> None:
        """Roll back the current transaction on the underlying connection."""
        self.delegatee.rollback()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def close(self) -> None:
        """
        Close the DB-API connection, then stop the session pool and the driver.

        The pool and driver are stopped even if closing the DB-API connection raises.
        """
        try:
            self.delegatee.close()
        finally:
            try:
                self._session_pool.stop()
            finally:
                self._driver.stop()

    @staticmethod
    def _get_table_client_settings() -> ydb.TableClientSettings:
        """Return table-client settings: native date/time/interval types, JSON not parsed."""
        return (
            ydb.TableClientSettings()
            .with_native_date_in_result_sets(True)
            .with_native_datetime_in_result_sets(True)
            .with_native_timestamp_in_result_sets(True)
            .with_native_interval_in_result_sets(True)
            .with_native_json_in_result_sets(False)
        )
class YDBHook(DbApiHook):
    """
    Interact with YDB.

    :param is_ddl: passed to every :class:`YDBConnection` returned by
        :meth:`get_conn`, and from there to its cursors.
    """

    conn_name_attr: str = CONN_NAME_ATTR
    default_conn_name: str = DEFAULT_CONN_NAME
    conn_type: str = CONN_TYPE
    supports_autocommit: bool = True
    supports_executemany: bool = True

    def __init__(self, *args, is_ddl: bool = False, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.is_ddl = is_ddl

    # A single @classmethod: stacking it twice breaks on Python 3.8 / 3.13+ and
    # binds ``cls`` to the metaclass on 3.9-3.12.
    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom UI field behaviour for YDB connection."""
        return {
            "hidden_fields": ["schema", "extra"],
            "relabeling": {},
            "placeholders": {
                "host": "eg. grpcs://my_host or ydb.serverless.yandexcloud.net or lb.etn9txxxx.ydb.mdb.yandexcloud.net",
                "login": "root",
                "password": "my_password",
                "database": "e.g. /local or /ru-central1/b1gtl2kg13him37quoo6/etndqstq7ne4v68n6c9b",
                "service_account_json": 'e.g. {"id": "...", "service_account_id": "...", "private_key": "..."}',
                "token": "t1.9....AAQ",
            },
        }

    @property
    def sqlalchemy_url(self) -> URL:
        """
        Build a SQLAlchemy URL for the ``ydb`` dialect from the Airflow connection.

        The ``database`` connection extra is put into the URL query string. It is
        omitted when not set, because SQLAlchemy rejects ``None`` query values.
        """
        conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
        connection_extra: dict[str, Any] = conn.extra_dejson
        database: str | None = connection_extra.get("database")
        query: dict[str, str] = {"database": database} if database else {}
        return URL.create(
            drivername="ydb",
            username=conn.login,
            password=conn.password,
            host=conn.host,
            port=conn.port,
            query=query,
        )

    def get_conn(self) -> YDBConnection:
        """
        Establish a connection to a YDB database.

        The endpoint is ``host:port``. The port falls back to
        ``DEFAULT_YDB_GRPCS_PORT`` when the connection has none. The database
        path comes from the ``database`` connection extra.

        :raises ValueError: if the connection has no host or no ``database`` extra.
        """
        conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
        host: str | None = conn.host
        if not host:
            raise ValueError("YDB host must be specified")
        port: int = conn.port or DEFAULT_YDB_GRPCS_PORT

        connection_extra: dict[str, Any] = conn.extra_dejson
        database: str | None = connection_extra.get("database")
        if not database:
            raise ValueError("YDB database must be specified")

        endpoint = f"{host}:{port}"
        credentials = get_credentials_from_connection(
            endpoint=endpoint, database=database, connection=conn, connection_extra=connection_extra
        )
        return YDBConnection(
            endpoint=endpoint, database=database, credentials=credentials, is_ddl=self.is_ddl
        )

    @staticmethod
    def _serialize_cell(cell: object, conn: YDBConnection | None = None) -> Any:
        """Return ``cell`` unchanged (no per-cell serialization is applied)."""
        return cell