# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from functools import wraps

import os
import contextlib

from airflow import settings
from airflow.utils.log.logging_mixin import LoggingMixin

log = LoggingMixin().log


@contextlib.contextmanager
def create_session():
    """
    Context manager that creates a session and tears it down on exit.
    """
    session = settings.Session()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
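

# Illustrative usage sketch for create_session(); commit, rollback and close
# are handled by the context manager. ``MyModel`` is a hypothetical
# declarative model, not something defined in this module:
#
#     with create_session() as session:
#         session.add(MyModel(name='example'))
#         rows = session.query(MyModel).count()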


def provide_session(func):
    """
    Function decorator that provides a session if it isn't provided.
    If you want to reuse a session or run the function as part of a
    database transaction, pass it to the function; if not, this wrapper
    will create one and close it for you.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        arg_session = 'session'

        func_params = func.__code__.co_varnames
        session_in_args = arg_session in func_params and \
            func_params.index(arg_session) < len(args)
        session_in_kwargs = arg_session in kwargs

        if session_in_kwargs or session_in_args:
            return func(*args, **kwargs)
        else:
            with create_session() as session:
                kwargs[arg_session] = session
                return func(*args, **kwargs)

    return wrapper
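

# Illustrative sketch of a function wrapped with @provide_session; the
# ``count_dag_runs`` helper below is hypothetical, not part of this module:
#
#     from airflow.models import DagRun
#
#     @provide_session
#     def count_dag_runs(dag_id, session=None):
#         return session.query(DagRun).filter(DagRun.dag_id == dag_id).count()
#
#     count_dag_runs('example_dag')                  # wrapper creates a session
#     with create_session() as session:              # or reuse an existing one
#         count_dag_runs('example_dag', session=session)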


@provide_session
def merge_conn(conn, session=None):
    """Add ``conn`` to the database if no connection with the same conn_id exists yet."""
    from airflow.models.connection import Connection
    if not session.query(Connection).filter(Connection.conn_id == conn.conn_id).first():
        session.add(conn)
        session.commit()
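

# Illustrative sketch of seeding an extra default connection via merge_conn
# (the conn_id and host below are made-up example values):
#
#     from airflow.models.connection import Connection
#
#     merge_conn(
#         Connection(
#             conn_id='analytics_postgres', conn_type='postgres',
#             host='db.example.com', schema='analytics', port=5432))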


def initdb(rbac=False):
    session = settings.Session()

    from airflow import models
    from airflow.models.connection import Connection
    upgradedb()
    merge_conn(
        Connection(
            conn_id='airflow_db', conn_type='mysql',
            host='mysql', login='root', password='',
            schema='airflow'))
    merge_conn(
        Connection(
            conn_id='beeline_default', conn_type='beeline', port=10000,
            host='localhost', extra="{\"use_beeline\": true, \"auth\": \"\"}",
            schema='default'))
    merge_conn(
        Connection(
            conn_id='bigquery_default', conn_type='google_cloud_platform',
            schema='default'))
    merge_conn(
        Connection(
            conn_id='local_mysql', conn_type='mysql',
            host='localhost', login='airflow', password='airflow',
            schema='airflow'))
    merge_conn(
        Connection(
            conn_id='presto_default', conn_type='presto',
            host='localhost',
            schema='hive', port=3400))
    merge_conn(
        Connection(
            conn_id='google_cloud_default', conn_type='google_cloud_platform',
            schema='default'))
    merge_conn(
        Connection(
            conn_id='hive_cli_default', conn_type='hive_cli',
            schema='default'))
    merge_conn(
        Connection(
            conn_id='hiveserver2_default', conn_type='hiveserver2',
            host='localhost',
            schema='default', port=10000))
    merge_conn(
        Connection(
            conn_id='metastore_default', conn_type='hive_metastore',
            host='localhost', extra="{\"authMechanism\": \"PLAIN\"}",
            port=9083))
    merge_conn(
        Connection(
            conn_id='mongo_default', conn_type='mongo',
            host='mongo', port=27017))
    merge_conn(
        Connection(
            conn_id='mysql_default', conn_type='mysql',
            login='root',
            schema='airflow',
            host='mysql'))
    merge_conn(
        Connection(
            conn_id='postgres_default', conn_type='postgres',
            login='postgres',
            password='airflow',
            schema='airflow',
            host='postgres'))
    merge_conn(
        Connection(
            conn_id='sqlite_default', conn_type='sqlite',
            host='/tmp/sqlite_default.db'))
    merge_conn(
        Connection(
            conn_id='http_default', conn_type='http',
            host='https://www.google.com/'))
    merge_conn(
        Connection(
            conn_id='mssql_default', conn_type='mssql',
            host='localhost', port=1433))
    merge_conn(
        Connection(
            conn_id='vertica_default', conn_type='vertica',
            host='localhost', port=5433))
    merge_conn(
        Connection(
            conn_id='wasb_default', conn_type='wasb',
            extra='{"sas_token": null}'))
    merge_conn(
        Connection(
            conn_id='webhdfs_default', conn_type='hdfs',
            host='localhost', port=50070))
    merge_conn(
        Connection(
            conn_id='ssh_default', conn_type='ssh',
            host='localhost'))
    merge_conn(
        Connection(
            conn_id='sftp_default', conn_type='sftp',
            host='localhost', port=22, login='airflow',
            extra='''
                {"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}
            '''))
    merge_conn(
        Connection(
            conn_id='fs_default', conn_type='fs',
            extra='{"path": "/"}'))
    merge_conn(
        Connection(
            conn_id='aws_default', conn_type='aws',
            extra='{"region_name": "us-east-1"}'))
    merge_conn(
        Connection(
            conn_id='spark_default', conn_type='spark',
            host='yarn', extra='{"queue": "root.default"}'))
    merge_conn(
        Connection(
            conn_id='druid_broker_default', conn_type='druid',
            host='druid-broker', port=8082,
            extra='{"endpoint": "druid/v2/sql"}'))
    merge_conn(
        Connection(
            conn_id='druid_ingest_default', conn_type='druid',
            host='druid-overlord', port=8081,
            extra='{"endpoint": "druid/indexer/v1/task"}'))
    merge_conn(
        Connection(
            conn_id='redis_default', conn_type='redis',
            host='redis', port=6379,
            extra='{"db": 0}'))
    merge_conn(
        Connection(
            conn_id='sqoop_default', conn_type='sqoop',
            host='rmdbs', extra=''))
    merge_conn(
        Connection(
            conn_id='emr_default', conn_type='emr',
            extra='''
                {   "Name": "default_job_flow_name",
                    "LogUri": "s3://my-emr-log-bucket/default_job_flow_location",
                    "ReleaseLabel": "emr-4.6.0",
                    "Instances": {
                        "Ec2KeyName": "mykey",
                        "Ec2SubnetId": "somesubnet",
                        "InstanceGroups": [
                            {
                                "Name": "Master nodes",
                                "Market": "ON_DEMAND",
                                "InstanceRole": "MASTER",
                                "InstanceType": "r3.2xlarge",
                                "InstanceCount": 1
                            },
                            {
                                "Name": "Slave nodes",
                                "Market": "ON_DEMAND",
                                "InstanceRole": "CORE",
                                "InstanceType": "r3.2xlarge",
                                "InstanceCount": 1
                            }
                        ],
                        "TerminationProtected": false,
                        "KeepJobFlowAliveWhenNoSteps": false
                    },
                    "Applications": [
                        { "Name": "Spark" }
                    ],
                    "VisibleToAllUsers": true,
                    "JobFlowRole": "EMR_EC2_DefaultRole",
                    "ServiceRole": "EMR_DefaultRole",
                    "Tags": [
                        {
                            "Key": "app",
                            "Value": "analytics"
                        },
                        {
                            "Key": "environment",
                            "Value": "development"
                        }
                    ]
                }
            '''))
    merge_conn(
        Connection(
            conn_id='databricks_default', conn_type='databricks',
            host='localhost'))
    merge_conn(
        Connection(
            conn_id='qubole_default', conn_type='qubole',
            host='localhost'))
    merge_conn(
        Connection(
            conn_id='segment_default', conn_type='segment',
            extra='{"write_key": "my-segment-write-key"}'))
    merge_conn(
        Connection(
            conn_id='azure_data_lake_default', conn_type='azure_data_lake',
            extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }'))
    merge_conn(
        Connection(
            conn_id='azure_cosmos_default', conn_type='azure_cosmos',
            extra='{"database_name": "<DATABASE_NAME>", "collection_name": "<COLLECTION_NAME>" }'))
    merge_conn(
        Connection(
            conn_id='azure_container_instances_default', conn_type='azure_container_instances',
            extra='{"tenantId": "<TENANT>", "subscriptionId": "<SUBSCRIPTION ID>" }'))
    merge_conn(
        Connection(
            conn_id='cassandra_default', conn_type='cassandra',
            host='cassandra', port=9042))
    merge_conn(
        Connection(
            conn_id='dingding_default', conn_type='http',
            host='', password=''))
    merge_conn(
        Connection(
            conn_id='opsgenie_default', conn_type='http',
            host='', password=''))

    # Known event types
    KET = models.KnownEventType
    if not session.query(KET).filter(KET.know_event_type == 'Holiday').first():
        session.add(KET(know_event_type='Holiday'))
    if not session.query(KET).filter(KET.know_event_type == 'Outage').first():
        session.add(KET(know_event_type='Outage'))
    if not session.query(KET).filter(
            KET.know_event_type == 'Natural Disaster').first():
        session.add(KET(know_event_type='Natural Disaster'))
    if not session.query(KET).filter(
            KET.know_event_type == 'Marketing Campaign').first():
        session.add(KET(know_event_type='Marketing Campaign'))
    session.commit()

    dagbag = models.DagBag()
    # Save individual DAGs in the ORM
    for dag in dagbag.dags.values():
        dag.sync_to_db()
    # Deactivate the unknown ones
    models.DAG.deactivate_unknown_dags(dagbag.dags.keys())

    Chart = models.Chart
    chart_label = "Airflow task instance by type"
    chart = session.query(Chart).filter(Chart.label == chart_label).first()
    if not chart:
        chart = Chart(
            label=chart_label,
            conn_id='airflow_db',
            chart_type='bar',
            x_is_date=False,
            sql=(
                "SELECT state, COUNT(1) as number "
                "FROM task_instance "
                "WHERE dag_id LIKE 'example%' "
                "GROUP BY state"),
        )
        session.add(chart)
        session.commit()

    if rbac:
        from flask_appbuilder.security.sqla import models
        from flask_appbuilder.models.sqla import Base
        Base.metadata.create_all(settings.engine)


def upgradedb():
    # alembic adds significant import time, so we import it lazily
    from alembic import command
    from alembic.config import Config

    log.info("Creating tables")

    current_dir = os.path.dirname(os.path.abspath(__file__))
    package_dir = os.path.normpath(os.path.join(current_dir, '..'))
    directory = os.path.join(package_dir, 'migrations')
    config = Config(os.path.join(package_dir, 'alembic.ini'))
    # Escape '%' so that alembic's ConfigParser-backed Config does not try to
    # interpolate it (e.g. when the connection URI contains URL-encoded characters).
    config.set_main_option('script_location', directory.replace('%', '%%'))
    config.set_main_option('sqlalchemy.url', settings.SQL_ALCHEMY_CONN.replace('%', '%%'))
    command.upgrade(config, 'heads')


def resetdb(rbac):
    """
    Clear out the database
    """
    from airflow import models
    # alembic adds significant import time, so we import it lazily
    from alembic.migration import MigrationContext

    log.info("Dropping tables that exist")

    models.base.Base.metadata.drop_all(settings.engine)
    mc = MigrationContext.configure(settings.engine)
    if mc._version.exists(settings.engine):
        mc._version.drop(settings.engine)

    if rbac:
        # drop rbac security tables
        from flask_appbuilder.security.sqla import models
        from flask_appbuilder.models.sqla import Base
        Base.metadata.drop_all(settings.engine)
    initdb(rbac)
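

# Illustrative usage sketch of the entry points above (these functions are
# typically driven by the Airflow CLI rather than called directly):
#
#     upgradedb()          # apply pending alembic migrations only
#     initdb(rbac=False)   # migrate, then seed default connections and charts
#     resetdb(rbac=False)  # drop everything and rebuild via initdb()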