Source code for airflow.models.connection

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import json
import warnings
from json import JSONDecodeError
from typing import Dict, Optional, Union
from urllib.parse import parse_qsl, quote, unquote, urlencode, urlparse

from sqlalchemy import Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import reconstructor, synonym

from airflow.configuration import ensure_secrets_loaded
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
from airflow.providers_manager import ProvidersManager
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.secrets_masker import mask_secret
from airflow.utils.module_loading import import_string


[docs]def parse_netloc_to_hostname(*args, **kwargs): """This method is deprecated.""" warnings.warn("This method is deprecated.", DeprecationWarning) return _parse_netloc_to_hostname(*args, **kwargs)
# Python automatically converts all letters to lowercase in hostname # See: https://issues.apache.org/jira/browse/AIRFLOW-3615
[docs]def _parse_netloc_to_hostname(uri_parts): """Parse a URI string to get correct Hostname.""" hostname = unquote(uri_parts.hostname or '') if '/' in hostname: hostname = uri_parts.netloc if "@" in hostname: hostname = hostname.rsplit("@", 1)[1] if ":" in hostname: hostname = hostname.split(":", 1)[0] hostname = unquote(hostname) return hostname
[docs]class Connection(Base, LoggingMixin): """ Placeholder to store information about different database instances connection information. The idea here is that scripts use references to database instances (conn_id) instead of hard coding hostname, logins and passwords when using operators or hooks. .. seealso:: For more information on how to use this class, see: :doc:`/howto/connection` :param conn_id: The connection ID. :type conn_id: str :param conn_type: The connection type. :type conn_type: str :param description: The connection description. :type description: str :param host: The host. :type host: str :param login: The login. :type login: str :param password: The password. :type password: str :param schema: The schema. :type schema: str :param port: The port number. :type port: int :param extra: Extra metadata. Non-standard data such as private/SSH keys can be saved here. JSON encoded object. :type extra: str :param uri: URI address describing connection parameters. :type uri: str """
[docs] EXTRA_KEY = '__extra__'
[docs] __tablename__ = "connection"
[docs] id = Column(Integer(), primary_key=True)
[docs] conn_id = Column(String(ID_LEN), unique=True, nullable=False)
[docs] conn_type = Column(String(500), nullable=False)
[docs] description = Column(Text(5000))
[docs] host = Column(String(500))
[docs] schema = Column(String(500))
[docs] login = Column(String(500))
[docs] _password = Column('password', String(5000))
[docs] port = Column(Integer())
[docs] is_encrypted = Column(Boolean, unique=False, default=False)
[docs] is_extra_encrypted = Column(Boolean, unique=False, default=False)
[docs] _extra = Column('extra', Text())
def __init__( self, conn_id: Optional[str] = None, conn_type: Optional[str] = None, description: Optional[str] = None, host: Optional[str] = None, login: Optional[str] = None, password: Optional[str] = None, schema: Optional[str] = None, port: Optional[int] = None, extra: Optional[Union[str, dict]] = None, uri: Optional[str] = None, ): super().__init__() self.conn_id = conn_id self.description = description if extra and not isinstance(extra, str): extra = json.dumps(extra) if uri and (conn_type or host or login or password or schema or port or extra): raise AirflowException( "You must create an object using the URI or individual values " "(conn_type, host, login, password, schema, port or extra)." "You can't mix these two ways to create this object." ) if uri: self._parse_from_uri(uri) else: self.conn_type = conn_type self.host = host self.login = login self.password = password self.schema = schema self.port = port self.extra = extra if self.password: mask_secret(self.password) @reconstructor
[docs] def on_db_load(self): if self.password: mask_secret(self.password)
[docs] def parse_from_uri(self, **uri): """This method is deprecated. Please use uri parameter in constructor.""" warnings.warn( "This method is deprecated. Please use uri parameter in constructor.", DeprecationWarning ) self._parse_from_uri(**uri)
[docs] def _parse_from_uri(self, uri: str): uri_parts = urlparse(uri) conn_type = uri_parts.scheme if conn_type == 'postgresql': conn_type = 'postgres' elif '-' in conn_type: conn_type = conn_type.replace('-', '_') self.conn_type = conn_type self.host = _parse_netloc_to_hostname(uri_parts) quoted_schema = uri_parts.path[1:] self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema self.login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username self.password = unquote(uri_parts.password) if uri_parts.password else uri_parts.password self.port = uri_parts.port if uri_parts.query: query = dict(parse_qsl(uri_parts.query, keep_blank_values=True)) if self.EXTRA_KEY in query: self.extra = query[self.EXTRA_KEY] else: self.extra = json.dumps(query)
[docs] def get_uri(self) -> str: """Return connection in URI format""" uri = f"{str(self.conn_type).lower().replace('_', '-')}://" authority_block = '' if self.login is not None: authority_block += quote(self.login, safe='') if self.password is not None: authority_block += ':' + quote(self.password, safe='') if authority_block > '': authority_block += '@' uri += authority_block host_block = '' if self.host: host_block += quote(self.host, safe='') if self.port: if host_block > '': host_block += f':{self.port}' else: host_block += f'@:{self.port}' if self.schema: host_block += f"/{quote(self.schema, safe='')}" uri += host_block if self.extra: try: query = urlencode(self.extra_dejson) except TypeError: query = None if query and self.extra_dejson == dict(parse_qsl(query, keep_blank_values=True)): uri += '?' + query else: uri += '?' + urlencode({self.EXTRA_KEY: self.extra}) return uri
[docs] def get_password(self) -> Optional[str]: """Return encrypted password.""" if self._password and self.is_encrypted: fernet = get_fernet() if not fernet.is_encrypted: raise AirflowException( "Can't decrypt encrypted password for login={}, \ FERNET_KEY configuration is missing".format( self.login ) ) return fernet.decrypt(bytes(self._password, 'utf-8')).decode() else: return self._password
[docs] def set_password(self, value: Optional[str]): """Encrypt password and set in object attribute.""" if value: fernet = get_fernet() self._password = fernet.encrypt(bytes(value, 'utf-8')).decode() self.is_encrypted = fernet.is_encrypted
@declared_attr
[docs] def password(cls): """Password. The value is decrypted/encrypted when reading/setting the value.""" return synonym('_password', descriptor=property(cls.get_password, cls.set_password))
[docs] def get_extra(self) -> Dict: """Return encrypted extra-data.""" if self._extra and self.is_extra_encrypted: fernet = get_fernet() if not fernet.is_encrypted: raise AirflowException( "Can't decrypt `extra` params for login={},\ FERNET_KEY configuration is missing".format( self.login ) ) return fernet.decrypt(bytes(self._extra, 'utf-8')).decode() else: return self._extra
[docs] def set_extra(self, value: str): """Encrypt extra-data and save in object attribute to object.""" if value: fernet = get_fernet() self._extra = fernet.encrypt(bytes(value, 'utf-8')).decode() self.is_extra_encrypted = fernet.is_encrypted else: self._extra = value self.is_extra_encrypted = False
@declared_attr
[docs] def extra(cls): """Extra data. The value is decrypted/encrypted when reading/setting the value.""" return synonym('_extra', descriptor=property(cls.get_extra, cls.set_extra))
[docs] def rotate_fernet_key(self): """Encrypts data with a new key. See: :ref:`security/fernet`""" fernet = get_fernet() if self._password and self.is_encrypted: self._password = fernet.rotate(self._password.encode('utf-8')).decode() if self._extra and self.is_extra_encrypted: self._extra = fernet.rotate(self._extra.encode('utf-8')).decode()
[docs] def get_hook(self): """Return hook based on conn_type.""" hook_class_name, conn_id_param, package_name, hook_name = ProvidersManager().hooks.get( self.conn_type, (None, None, None, None) ) if not hook_class_name: raise AirflowException(f'Unknown hook type "{self.conn_type}"') try: hook_class = import_string(hook_class_name) except ImportError: warnings.warn( "Could not import %s when discovering %s %s", hook_class_name, hook_name, package_name ) raise return hook_class(**{conn_id_param: self.conn_id})
[docs] def __repr__(self): return self.conn_id
[docs] def log_info(self): """ This method is deprecated. You can read each field individually or use the default representation (`__repr__`). """ warnings.warn( "This method is deprecated. You can read each field individually or " "use the default representation (__repr__).", DeprecationWarning, stacklevel=2, ) return "id: {}. Host: {}, Port: {}, Schema: {}, Login: {}, Password: {}, extra: {}".format( self.conn_id, self.host, self.port, self.schema, self.login, "XXXXXXXX" if self.password else None, "XXXXXXXX" if self.extra_dejson else None,
)
[docs] def debug_info(self): """ This method is deprecated. You can read each field individually or use the default representation (`__repr__`). """ warnings.warn( "This method is deprecated. You can read each field individually or " "use the default representation (__repr__).", DeprecationWarning, stacklevel=2, ) return "id: {}. Host: {}, Port: {}, Schema: {}, Login: {}, Password: {}, extra: {}".format( self.conn_id, self.host, self.port, self.schema, self.login, "XXXXXXXX" if self.password else None, self.extra_dejson,
) @property
[docs] def extra_dejson(self) -> Dict: """Returns the extra property by deserializing json.""" obj = {} if self.extra: try: obj = json.loads(self.extra) except JSONDecodeError: self.log.exception("Failed parsing the json for conn_id %s", self.conn_id) # Mask sensitive keys from this list mask_secret(obj) return obj
@classmethod
[docs] def get_connection_from_secrets(cls, conn_id: str) -> 'Connection': """ Get connection by conn_id. :param conn_id: connection id :return: connection """ for secrets_backend in ensure_secrets_loaded(): conn = secrets_backend.get_connection(conn_id=conn_id) if conn: return conn raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")

Was this entry helpful?