Source code for airflow.providers.ydb.hooks._vendor.dbapi.connection

import collections.abc
import posixpath
from typing import Any, List, NamedTuple, Optional

import sqlalchemy.util as util
import ydb

from .cursor import AsyncCursor, Cursor
from .errors import InterfaceError, InternalError, NotSupportedError


class IsolationLevel:
    """String names of the isolation levels understood by ``Connection``.

    Each constant is a key accepted by ``Connection.set_isolation_level`` and
    a possible return value of ``Connection.get_isolation_level``.
    """

    SERIALIZABLE = "SERIALIZABLE"
    ONLINE_READONLY = "ONLINE READONLY"
    ONLINE_READONLY_INCONSISTENT = "ONLINE READONLY INCONSISTENT"
    STALE_READONLY = "STALE READONLY"
    SNAPSHOT_READONLY = "SNAPSHOT READONLY"
    AUTOCOMMIT = "AUTOCOMMIT"


class Connection:
    """DB-API style connection to a YDB database.

    The connection owns (or borrows) a YDB driver and a session pool, tracks
    the current transaction mode, and hands out cursors bound to that state.
    The same code serves the synchronous and the asynchronous variant: every
    potentially awaitable SDK call goes through ``_maybe_await``, and
    subclasses switch behaviour by overriding the class attributes below.
    """

    # Bridge used to run awaitables when ``_is_async`` is true.
    _await = staticmethod(util.await_only)

    _is_async = False
    _ydb_driver_class = ydb.Driver
    _ydb_session_pool_class = ydb.SessionPool
    _ydb_table_client_class = ydb.TableClient
    _cursor_class = Cursor

    def __init__(
        self,
        host: str = "",
        port: str = "",
        database: str = "",
        **conn_kwargs: Any,
    ):
        """Open the connection.

        :param host: host part of the ``grpc://host:port`` endpoint.
        :param port: port part of the endpoint.
        :param database: database path; also the root for relative table paths.
        :param conn_kwargs: extra options. Recognised keys are ``credentials``,
            ``ydb_table_path_prefix`` and ``ydb_session_pool`` (an externally
            managed session pool to reuse instead of creating a driver).
        :raises InterfaceError: if a new driver cannot connect.
        """
        self.endpoint = f"grpc://{host}:{port}"
        self.database = database
        self.conn_kwargs = conn_kwargs
        self.credentials = self.conn_kwargs.pop("credentials", None)
        self.table_path_prefix = self.conn_kwargs.pop("ydb_table_path_prefix", "")

        if "ydb_session_pool" in self.conn_kwargs:  # Use session pool managed manually
            self._shared_session_pool = True
            self.session_pool: ydb.SessionPool = self.conn_kwargs.pop("ydb_session_pool")
            # The driver lives on the pool itself or on its implementation
            # object, depending on which pool class was supplied.
            self.driver = (
                self.session_pool._driver
                if hasattr(self.session_pool, "_driver")
                else self.session_pool._pool_impl._driver
            )
            self.driver.table_client = self._ydb_table_client_class(self.driver, self._get_table_client_settings())
        else:
            self._shared_session_pool = False
            self.driver = self._create_driver()
            self.session_pool = self._ydb_session_pool_class(self.driver, size=5)

        self.interactive_transaction: bool = False  # AUTOCOMMIT
        self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite()
        self.tx_context: Optional[ydb.TxContext] = None
        self.use_scan_query: bool = False

    def cursor(self):
        """Return a new cursor bound to this connection's current state."""
        return self._cursor_class(
            self.driver,
            self.session_pool,
            self.tx_mode,
            self.tx_context,
            self.use_scan_query,
            self.table_path_prefix,
        )

    def describe(self, table_path: str) -> ydb.TableDescription:
        """Describe the table at ``table_path`` (relative to database and prefix)."""
        abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path)
        cursor = self.cursor()
        return cursor.describe_table(abs_table_path)

    def check_exists(self, table_path: str) -> ydb.SchemeEntry:
        """Look up ``table_path`` (relative to database and prefix) via a cursor."""
        abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path)
        cursor = self.cursor()
        return cursor.check_exists(abs_table_path)

    def get_table_names(self) -> List[str]:
        """Return table paths relative to the database/prefix directory."""
        abs_dir_path = posixpath.join(self.database, self.table_path_prefix)
        cursor = self.cursor()
        return [posixpath.relpath(path, abs_dir_path) for path in cursor.get_table_names(abs_dir_path)]

    def set_isolation_level(self, isolation_level: str):
        """Select the transaction mode used by subsequent transactions.

        :param isolation_level: one of the ``IsolationLevel`` constants.
        :raises KeyError: if ``isolation_level`` is not a known level.
        :raises InternalError: if a transaction is currently in progress.
        """

        class IsolationSettings(NamedTuple):
            ydb_mode: ydb.AbstractTransactionModeBuilder
            interactive: bool

        ydb_isolation_settings_map = {
            IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.SerializableReadWrite(), interactive=False),
            IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.SerializableReadWrite(), interactive=True),
            IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.OnlineReadOnly(), interactive=False),
            IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings(
                ydb.OnlineReadOnly().with_allow_inconsistent_reads(), interactive=False
            ),
            IsolationLevel.STALE_READONLY: IsolationSettings(ydb.StaleReadOnly(), interactive=False),
            IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly(), interactive=True),
        }
        ydb_isolation_settings = ydb_isolation_settings_map[isolation_level]
        if self.tx_context and self.tx_context.tx_id:
            raise InternalError("Failed to set transaction mode: transaction is already began")
        self.tx_mode = ydb_isolation_settings.ydb_mode
        self.interactive_transaction = ydb_isolation_settings.interactive

    def get_isolation_level(self) -> str:
        """Return the ``IsolationLevel`` constant matching the current mode.

        :raises NotSupportedError: if the current transaction mode is not one
            of the modes this class knows how to name.
        """
        mode_name = self.tx_mode.name
        if mode_name == ydb.SerializableReadWrite().name:
            # Serializable read-write backs both levels; the interactive flag
            # tells them apart.
            if self.interactive_transaction:
                return IsolationLevel.SERIALIZABLE
            return IsolationLevel.AUTOCOMMIT
        if mode_name == ydb.OnlineReadOnly().name:
            if self.tx_mode.settings.allow_inconsistent_reads:
                return IsolationLevel.ONLINE_READONLY_INCONSISTENT
            return IsolationLevel.ONLINE_READONLY
        if mode_name == ydb.StaleReadOnly().name:
            return IsolationLevel.STALE_READONLY
        if mode_name == ydb.SnapshotReadOnly().name:
            return IsolationLevel.SNAPSHOT_READONLY
        raise NotSupportedError(f"{self.tx_mode.name} is not supported")

    def set_ydb_scan_query(self, value: bool) -> None:
        """Enable or disable scan-query mode for cursors created afterwards."""
        self.use_scan_query = value

    def get_ydb_scan_query(self) -> bool:
        """Return whether scan-query mode is enabled."""
        return self.use_scan_query

    def begin(self):
        """Start an interactive transaction when the current settings call for one.

        Nothing is started in non-interactive modes or when scan queries are
        enabled; ``tx_context`` is simply reset to ``None`` in that case.
        """
        self.tx_context = None
        if self.interactive_transaction and not self.use_scan_query:
            session = self._maybe_await(self.session_pool.acquire)
            try:
                tx_context = session.transaction(self.tx_mode)
                self._maybe_await(tx_context.begin)
            except Exception:
                # Hand the session back instead of leaking it when the
                # transaction could not be started.
                self._maybe_await(self.session_pool.release, session)
                raise
            self.tx_context = tx_context

    def commit(self):
        """Commit the open transaction, if any, and release its session."""
        tx_context = self.tx_context
        if not tx_context:
            return
        try:
            if tx_context.tx_id:
                self._maybe_await(tx_context.commit)
        finally:
            self._release_transaction(tx_context)

    def rollback(self):
        """Roll back the open transaction, if any, and release its session."""
        tx_context = self.tx_context
        if not tx_context:
            return
        try:
            if tx_context.tx_id:
                self._maybe_await(tx_context.rollback)
        finally:
            self._release_transaction(tx_context)

    def close(self):
        """Roll back any open transaction and stop owned resources.

        The session pool and driver are stopped only when this connection
        created them; a pool passed in via ``ydb_session_pool`` is left alone.
        """
        try:
            self.rollback()
        finally:
            # Shut down even if the rollback failed, so the pool and driver
            # are not left running.
            if not self._shared_session_pool:
                self._maybe_await(self.session_pool.stop)
                self._stop_driver()

    def _release_transaction(self, tx_context) -> None:
        """Forget ``tx_context`` and return its session to the pool.

        ``tx_context`` is only ever set by ``begin``, which acquires the
        session from ``session_pool``; so the session must always go back,
        whether or not the transaction finished cleanly.
        NOTE(review): the SDK appears to clear ``tx_id`` when a request fails;
        without this unconditional release such sessions were never returned.
        """
        self.tx_context = None
        self._maybe_await(self.session_pool.release, tx_context.session)

    @classmethod
    def _maybe_await(cls, callee: collections.abc.Callable, *args, **kwargs) -> Any:
        """Call ``callee`` and, for async connections, wait for its result."""
        if cls._is_async:
            return cls._await(callee(*args, **kwargs))
        return callee(*args, **kwargs)

    def _get_table_client_settings(self) -> ydb.TableClientSettings:
        """Build the table-client settings used by this connection's driver."""
        return (
            ydb.TableClientSettings()
            .with_native_date_in_result_sets(True)
            .with_native_datetime_in_result_sets(True)
            .with_native_timestamp_in_result_sets(True)
            .with_native_interval_in_result_sets(True)
            .with_native_json_in_result_sets(False)
        )

    def _create_driver(self):
        """Create a driver and wait for it to become ready.

        :raises InterfaceError: if the driver does not become ready; the
            driver is stopped before the error is raised.
        """
        driver_config = ydb.DriverConfig(
            endpoint=self.endpoint,
            database=self.database,
            table_client_settings=self._get_table_client_settings(),
            credentials=self.credentials,
        )
        driver = self._ydb_driver_class(driver_config)
        try:
            self._maybe_await(driver.wait, timeout=5, fail_fast=True)
        except ydb.Error as e:
            # Stop the half-started driver here too, as the generic branch
            # below already does; otherwise it is abandoned while running.
            self._maybe_await(driver.stop)
            raise InterfaceError(e.message, original_error=e) from e
        except Exception as e:
            self._maybe_await(driver.stop)
            raise InterfaceError(f"Failed to connect to YDB, details {driver.discovery_debug_details()}") from e
        return driver

    def _stop_driver(self):
        """Stop the driver held by this connection."""
        self._maybe_await(self.driver.stop)


class AsyncConnection(Connection):
    """Asynchronous variant of ``Connection``.

    Only the class-level switches differ: ``_is_async`` is set and the
    ``ydb.aio`` driver, session pool and table client, plus ``AsyncCursor``,
    replace their synchronous counterparts.
    """

    _is_async = True
    _ydb_driver_class = ydb.aio.Driver
    _ydb_session_pool_class = ydb.aio.SessionPool
    _ydb_table_client_class = ydb.aio.table.TableClient
    _cursor_class = AsyncCursor

Was this entry helpful?