# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from builtins import bytes
from urllib.parse import urlparse, unquote, parse_qsl
from sqlalchemy import Column, Integer, String, Boolean
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import synonym
from airflow.exceptions import AirflowException
from airflow import LoggingMixin
from airflow.models import get_fernet
from airflow.models.base import Base, ID_LEN
# Python automatically converts all letters to lowercase in hostname
# See: https://issues.apache.org/jira/browse/AIRFLOW-3615
[docs]def parse_netloc_to_hostname(uri_parts):
hostname = unquote(uri_parts.hostname or '')
if '/' in hostname:
hostname = uri_parts.netloc
if "@" in hostname:
hostname = hostname.rsplit("@", 1)[1]
if ":" in hostname:
hostname = hostname.split(":", 1)[0]
hostname = unquote(hostname)
return hostname
[docs]class Connection(Base, LoggingMixin):
"""
Placeholder to store information about different database instances
connection information. The idea here is that scripts use references to
database instances (conn_id) instead of hard coding hostname, logins and
passwords when using operators or hooks.
"""
[docs] __tablename__ = "connection"
[docs] id = Column(Integer(), primary_key=True)
[docs] conn_id = Column(String(ID_LEN))
[docs] conn_type = Column(String(500))
[docs] host = Column(String(500))
[docs] schema = Column(String(500))
[docs] login = Column(String(500))
[docs] _password = Column('password', String(5000))
[docs] port = Column(Integer())
[docs] is_encrypted = Column(Boolean, unique=False, default=False)
[docs] _types = [
('docker', 'Docker Registry',),
('fs', 'File (path)'),
('ftp', 'FTP',),
('google_cloud_platform', 'Google Cloud Platform'),
('hdfs', 'HDFS',),
('http', 'HTTP',),
('hive_cli', 'Hive Client Wrapper',),
('hive_metastore', 'Hive Metastore Thrift',),
('hiveserver2', 'Hive Server 2 Thrift',),
('jdbc', 'Jdbc Connection',),
('jenkins', 'Jenkins'),
('mysql', 'MySQL',),
('postgres', 'Postgres',),
('oracle', 'Oracle',),
('vertica', 'Vertica',),
('presto', 'Presto',),
('s3', 'S3',),
('samba', 'Samba',),
('sqlite', 'Sqlite',),
('ssh', 'SSH',),
('cloudant', 'IBM Cloudant',),
('mssql', 'Microsoft SQL Server'),
('mesos_framework-id', 'Mesos Framework ID'),
('jira', 'JIRA',),
('redis', 'Redis',),
('wasb', 'Azure Blob Storage'),
('databricks', 'Databricks',),
('aws', 'Amazon Web Services',),
('emr', 'Elastic MapReduce',),
('snowflake', 'Snowflake',),
('segment', 'Segment',),
('azure_data_lake', 'Azure Data Lake'),
('azure_container_instances', 'Azure Container Instances'),
('azure_cosmos', 'Azure CosmosDB'),
('cassandra', 'Cassandra',),
('qubole', 'Qubole'),
('mongo', 'MongoDB'),
('gcpcloudsql', 'Google Cloud SQL'),
]
def __init__(
self, conn_id=None, conn_type=None,
host=None, login=None, password=None,
schema=None, port=None, extra=None,
uri=None):
self.conn_id = conn_id
if uri:
self.parse_from_uri(uri)
else:
self.conn_type = conn_type
self.host = host
self.login = login
self.password = password
self.schema = schema
self.port = port
self.extra = extra
[docs] def parse_from_uri(self, uri):
uri_parts = urlparse(uri)
conn_type = uri_parts.scheme
if conn_type == 'postgresql':
conn_type = 'postgres'
elif '-' in conn_type:
conn_type = conn_type.replace('-', '_')
self.conn_type = conn_type
self.host = parse_netloc_to_hostname(uri_parts)
quoted_schema = uri_parts.path[1:]
self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema
self.login = unquote(uri_parts.username) \
if uri_parts.username else uri_parts.username
self.password = unquote(uri_parts.password) \
if uri_parts.password else uri_parts.password
self.port = uri_parts.port
if uri_parts.query:
self.extra = json.dumps(dict(parse_qsl(uri_parts.query)))
[docs] def get_password(self):
if self._password and self.is_encrypted:
fernet = get_fernet()
if not fernet.is_encrypted:
raise AirflowException(
"Can't decrypt encrypted password for login={}, \
FERNET_KEY configuration is missing".format(self.login))
return fernet.decrypt(bytes(self._password, 'utf-8')).decode()
else:
return self._password
[docs] def set_password(self, value):
if value:
fernet = get_fernet()
self._password = fernet.encrypt(bytes(value, 'utf-8')).decode()
self.is_encrypted = fernet.is_encrypted
@declared_attr
[docs] def password(cls):
return synonym('_password',
descriptor=property(cls.get_password, cls.set_password))
@declared_attr
[docs] def rotate_fernet_key(self):
fernet = get_fernet()
if self._password and self.is_encrypted:
self._password = fernet.rotate(self._password.encode('utf-8')).decode()
if self._extra and self.is_extra_encrypted:
self._extra = fernet.rotate(self._extra.encode('utf-8')).decode()
[docs] def get_hook(self):
try:
if self.conn_type == 'mysql':
from airflow.hooks.mysql_hook import MySqlHook
return MySqlHook(mysql_conn_id=self.conn_id)
elif self.conn_type == 'google_cloud_platform':
from airflow.contrib.hooks.bigquery_hook import BigQueryHook
return BigQueryHook(bigquery_conn_id=self.conn_id)
elif self.conn_type == 'postgres':
from airflow.hooks.postgres_hook import PostgresHook
return PostgresHook(postgres_conn_id=self.conn_id)
elif self.conn_type == 'hive_cli':
from airflow.hooks.hive_hooks import HiveCliHook
return HiveCliHook(hive_cli_conn_id=self.conn_id)
elif self.conn_type == 'presto':
from airflow.hooks.presto_hook import PrestoHook
return PrestoHook(presto_conn_id=self.conn_id)
elif self.conn_type == 'hiveserver2':
from airflow.hooks.hive_hooks import HiveServer2Hook
return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
elif self.conn_type == 'sqlite':
from airflow.hooks.sqlite_hook import SqliteHook
return SqliteHook(sqlite_conn_id=self.conn_id)
elif self.conn_type == 'jdbc':
from airflow.hooks.jdbc_hook import JdbcHook
return JdbcHook(jdbc_conn_id=self.conn_id)
elif self.conn_type == 'mssql':
from airflow.hooks.mssql_hook import MsSqlHook
return MsSqlHook(mssql_conn_id=self.conn_id)
elif self.conn_type == 'oracle':
from airflow.hooks.oracle_hook import OracleHook
return OracleHook(oracle_conn_id=self.conn_id)
elif self.conn_type == 'vertica':
from airflow.contrib.hooks.vertica_hook import VerticaHook
return VerticaHook(vertica_conn_id=self.conn_id)
elif self.conn_type == 'cloudant':
from airflow.contrib.hooks.cloudant_hook import CloudantHook
return CloudantHook(cloudant_conn_id=self.conn_id)
elif self.conn_type == 'jira':
from airflow.contrib.hooks.jira_hook import JiraHook
return JiraHook(jira_conn_id=self.conn_id)
elif self.conn_type == 'redis':
from airflow.contrib.hooks.redis_hook import RedisHook
return RedisHook(redis_conn_id=self.conn_id)
elif self.conn_type == 'wasb':
from airflow.contrib.hooks.wasb_hook import WasbHook
return WasbHook(wasb_conn_id=self.conn_id)
elif self.conn_type == 'docker':
from airflow.hooks.docker_hook import DockerHook
return DockerHook(docker_conn_id=self.conn_id)
elif self.conn_type == 'azure_data_lake':
from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook
return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
elif self.conn_type == 'azure_cosmos':
from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook
return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
elif self.conn_type == 'cassandra':
from airflow.contrib.hooks.cassandra_hook import CassandraHook
return CassandraHook(cassandra_conn_id=self.conn_id)
elif self.conn_type == 'mongo':
from airflow.contrib.hooks.mongo_hook import MongoHook
return MongoHook(conn_id=self.conn_id)
elif self.conn_type == 'gcpcloudsql':
from airflow.contrib.hooks.gcp_sql_hook import CloudSqlDatabaseHook
return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
except Exception:
pass
[docs] def __repr__(self):
return self.conn_id
[docs] def debug_info(self):
return ("id: {}. Host: {}, Port: {}, Schema: {}, "
"Login: {}, Password: {}, extra: {}".
format(self.conn_id,
self.host,
self.port,
self.schema,
self.login,
"XXXXXXXX" if self.password else None,
self.extra_dejson))
@property