# Source code for airflow.hooks.presto_hook
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from builtins import str
import prestodb
from prestodb.exceptions import DatabaseError
from prestodb.transaction import IsolationLevel
from airflow.hooks.dbapi_hook import DbApiHook
class PrestoException(Exception):
    """Raised by PrestoHook in place of a driver-level ``DatabaseError``."""
class PrestoHook(DbApiHook):
    """
    Interact with Presto through the ``prestodb`` client library.

    Connection details are read from the Airflow connection referenced by
    ``presto_conn_id``. The connection's ``extra`` JSON may carry the
    optional keys ``source`` (default ``'airflow'``), ``protocol``
    (default ``'http'``), ``catalog`` (default ``'hive'``) and
    ``isolation_level`` (default ``'AUTOCOMMIT'``).

    >>> ph = PrestoHook()
    >>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
    >>> ph.get_records(sql)
    [[340698]]
    """

    # Name of the attribute that holds the connection id, and the
    # connection id used when none is supplied.
    conn_name_attr = 'presto_conn_id'
    default_conn_name = 'presto_default'

    def get_conn(self):
        """Returns a connection object"""
        db = self.get_connection(self.presto_conn_id)
        # Basic authentication is only set up when a password is configured;
        # otherwise the client connects without an auth handler.
        auth = None
        if db.password:
            auth = prestodb.auth.BasicAuthentication(db.login, db.password)
        return prestodb.dbapi.connect(
            host=db.host,
            port=db.port,
            user=db.login,
            source=db.extra_dejson.get('source', 'airflow'),
            http_scheme=db.extra_dejson.get('protocol', 'http'),
            catalog=db.extra_dejson.get('catalog', 'hive'),
            schema=db.schema,
            auth=auth,
            isolation_level=self.get_isolation_level()
        )

    def get_isolation_level(self):
        """
        Returns an isolation level

        The level name is taken from the ``isolation_level`` key of the
        connection extras and upper-cased. A missing key, or a name that is
        not an attribute of ``IsolationLevel``, yields
        ``IsolationLevel.AUTOCOMMIT``.
        """
        db = self.get_connection(self.presto_conn_id)
        isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper()
        return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT)

    @staticmethod
    def _strip_sql(sql):
        """
        Strips surrounding whitespace and any trailing semicolons from
        a statement.
        """
        return sql.strip().rstrip(';')

    @staticmethod
    def _get_pretty_exception_message(e):
        """
        Parses some DatabaseError to provide a better error message

        When ``e.message`` is a dict holding both ``errorName`` and
        ``message``, the result is ``'<errorName>: <message>'``. In every
        other case the result is ``str(e)``.
        """
        payload = getattr(e, 'message', None)
        # Only a dict payload supports the key lookups below. On a plain
        # string, ``in`` would be a substring test and the subsequent
        # indexing would raise TypeError while reporting the real error.
        if (isinstance(payload, dict) and
                'errorName' in payload and
                'message' in payload):
            return '{name}: {message}'.format(
                name=payload['errorName'],
                message=payload['message'])
        return str(e)

    def get_records(self, hql, parameters=None):
        """
        Get a set of records from Presto

        :param hql: The statement to execute; it is passed through
            ``_strip_sql`` first.
        :param parameters: Query parameters, forwarded unchanged to the
            parent implementation.
        :raises PrestoException: if the parent call raises ``DatabaseError``
        """
        try:
            return super(PrestoHook, self).get_records(
                self._strip_sql(hql), parameters)
        except DatabaseError as e:
            raise PrestoException(self._get_pretty_exception_message(e))

    def get_first(self, hql, parameters=None):
        """
        Returns only the first row, regardless of how many rows the query
        returns.

        :param hql: The statement to execute; it is passed through
            ``_strip_sql`` first.
        :param parameters: Query parameters, forwarded unchanged to the
            parent implementation.
        :raises PrestoException: if the parent call raises ``DatabaseError``
        """
        try:
            return super(PrestoHook, self).get_first(
                self._strip_sql(hql), parameters)
        except DatabaseError as e:
            raise PrestoException(self._get_pretty_exception_message(e))

    def get_pandas_df(self, hql, parameters=None):
        """
        Get a pandas dataframe from a sql query.

        :param hql: The statement to execute; it is passed through
            ``_strip_sql`` first.
        :param parameters: Query parameters handed to ``cursor.execute``.
        :return: A dataframe whose columns are named after the cursor
            description, or an empty dataframe (without columns) when the
            query returns no rows.
        :raises PrestoException: if executing or fetching raises
            ``DatabaseError``
        """
        # Imported lazily so pandas is only needed when this method is used.
        import pandas
        cursor = self.get_cursor()
        try:
            cursor.execute(self._strip_sql(hql), parameters)
            data = cursor.fetchall()
        except DatabaseError as e:
            raise PrestoException(self._get_pretty_exception_message(e))
        column_descriptions = cursor.description
        if data:
            df = pandas.DataFrame(data)
            # The first item of each description entry is the column name.
            df.columns = [c[0] for c in column_descriptions]
        else:
            df = pandas.DataFrame()
        return df

    def run(self, hql, parameters=None):
        """
        Execute the statement against Presto. Can be used to create views.

        The statement is passed through ``_strip_sql`` and then handed to
        the parent implementation. Unlike ``get_records`` and ``get_first``,
        this method does not translate ``DatabaseError`` into
        ``PrestoException``.
        """
        return super(PrestoHook, self).run(self._strip_sql(hql), parameters)

    def insert_rows(self, table, rows, target_fields=None, commit_every=0):
        """
        A generic way to insert a set of tuples into a table.

        When the configured isolation level is ``AUTOCOMMIT``, the given
        ``commit_every`` is ignored and replaced by 0, and an informational
        message is logged.

        :param table: Name of the target table
        :type table: str
        :param rows: The rows to insert into the table
        :type rows: iterable of tuples
        :param target_fields: The names of the columns to fill in the table
        :type target_fields: iterable of strings
        :param commit_every: The maximum number of rows to insert in one
            transaction. Set to 0 to insert all rows in one transaction.
        :type commit_every: int
        """
        if self.get_isolation_level() == IsolationLevel.AUTOCOMMIT:
            self.log.info(
                'Transactions are not enable in presto connection. '
                'Please use the isolation_level property to enable it. '
                'Falling back to insert all rows in one transaction.'
            )
            commit_every = 0
        super(PrestoHook, self).insert_rows(table, rows, target_fields, commit_every)