# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import cx_Oracle
from airflow.hooks.dbapi_hook import DbApiHook
from builtins import str
from past.builtins import basestring
from datetime import datetime
import numpy
class OracleHook(DbApiHook):
    """
    Interact with Oracle SQL.
    """
    conn_name_attr = 'oracle_conn_id'
    default_conn_name = 'oracle_default'
    supports_autocommit = False

    def get_conn(self):
        """
        Returns an Oracle connection object created with ``cx_Oracle.connect``.

        Login and password come from the Airflow connection
        ``self.oracle_conn_id``; its port defaults to 1521 when unset.

        Optional parameters for using a custom DSN connection
        (instead of using a server alias from tnsnames.ora).
        The dsn (data source name) is the TNS entry
        (from the Oracle names server or tnsnames.ora file)
        or is a string like the one returned from makedsn().

        :param dsn: the host address for the Oracle server
        :param sid: the Oracle SID; combined with ``dsn`` when
            ``service_name`` is not given
        :param service_name: the db_unique_name of the database
            that you are connecting to (CONNECT_DATA part of TNS);
            combined with ``dsn`` when ``sid`` is not given

        If ``dsn`` is missing, or both/neither of ``sid`` and ``service_name``
        are set, the connection's ``host`` is passed to cx_Oracle as the DSN.

        You can set these parameters in the extra fields of your connection
        as in ``{ "dsn":"some.host.address" , "service_name":"some.service.name" }``

        Other extras that are honoured: ``encoding``, ``nencoding`` (defaults
        to ``encoding`` when only that is given), ``threaded``, ``events``,
        ``mode`` (sysdba, sysasm, sysoper, sysbkp, sysdgd, syskmt, sysrac),
        ``purity`` (new, self, default) and ``module`` (assigned to the
        returned connection's ``module`` attribute). Unrecognised ``mode`` or
        ``purity`` values are ignored.

        See more param detail in
        `cx_Oracle.connect <https://cx-oracle.readthedocs.io/en/latest/module.html#cx_Oracle.connect>`_
        """
        airflow_conn = self.get_connection(self.oracle_conn_id)
        # Read the extras once instead of re-parsing them for every key.
        extra = airflow_conn.extra_dejson
        conn_config = {
            'user': airflow_conn.login,
            'password': airflow_conn.password
        }
        dsn = extra.get('dsn', None)
        sid = extra.get('sid', None)
        mod = extra.get('module', None)
        service_name = extra.get('service_name', None)
        port = airflow_conn.port if airflow_conn.port else 1521
        if dsn and sid and not service_name:
            conn_config['dsn'] = cx_Oracle.makedsn(dsn, port, sid)
        elif dsn and service_name and not sid:
            conn_config['dsn'] = cx_Oracle.makedsn(dsn, port, service_name=service_name)
        else:
            conn_config['dsn'] = airflow_conn.host

        if 'encoding' in extra:
            conn_config['encoding'] = extra.get('encoding')
            # If `encoding` is specified but `nencoding` is not, `nencoding`
            # should use the same value (overridden just below otherwise), inspired by
            # https://github.com/oracle/python-cx_Oracle/issues/157#issuecomment-371877993
            conn_config['nencoding'] = extra.get('encoding')
        if 'nencoding' in extra:
            conn_config['nencoding'] = extra.get('nencoding')
        for passthrough_key in ('threaded', 'events'):
            if passthrough_key in extra:
                conn_config[passthrough_key] = extra.get(passthrough_key)

        # Map extras to cx_Oracle constant *names*; the attribute is only
        # looked up for the selected value, so constants missing from older
        # cx_Oracle releases are not touched unless actually requested.
        mode_attr = {
            'sysdba': 'SYSDBA',
            'sysasm': 'SYSASM',
            'sysoper': 'SYSOPER',
            'sysbkp': 'SYSBKP',
            'sysdgd': 'SYSDGD',
            'syskmt': 'SYSKMT',
            'sysrac': 'SYSRAC',
        }.get((extra.get('mode') or '').lower())
        if mode_attr is not None:
            conn_config['mode'] = getattr(cx_Oracle, mode_attr)

        purity_attr = {
            'new': 'ATTR_PURITY_NEW',
            'self': 'ATTR_PURITY_SELF',
            'default': 'ATTR_PURITY_DEFAULT',
        }.get((extra.get('purity') or '').lower())
        if purity_attr is not None:
            conn_config['purity'] = getattr(cx_Oracle, purity_attr)

        oracle_conn = cx_Oracle.connect(**conn_config)
        if mod is not None:
            oracle_conn.module = mod
        return oracle_conn

    @staticmethod
    def _to_sql_literal(cell):
        """Render one cell value as an Oracle SQL literal for ``insert_rows``."""
        if isinstance(cell, basestring):
            # Double embedded single quotes so the value stays inside the literal.
            return "'" + str(cell).replace("'", "''") + "'"
        if cell is None:
            return 'NULL'
        # numpy.float32/float16 are not Python float subclasses, and
        # numpy.float64's exact type is not `float`, so an exact type check
        # would let their NaNs through as the bare token `nan`. Coerce every
        # float NaN flavour to NULL.
        if isinstance(cell, (float, numpy.floating)) and numpy.isnan(cell):
            return 'NULL'
        if isinstance(cell, numpy.datetime64):
            return "'" + str(cell) + "'"
        if isinstance(cell, datetime):
            # Second precision only: microseconds and tzinfo are dropped.
            return ("to_date('" +
                    cell.strftime('%Y-%m-%d %H:%M:%S') +
                    "','YYYY-MM-DD HH24:MI:SS')")
        return str(cell)

    def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
        """
        A generic way to insert a set of tuples into a table, one INSERT
        statement per row, committing every ``commit_every`` rows and once
        more at the end.

        Changes from standard DbApiHook implementation:

        - Oracle SQL queries in cx_Oracle can not be terminated with a semicolon (`;`)
        - Replace NaN values (Python or numpy floats) with NULL using
          `numpy.isnan` (only float cells are tested, so strings never reach it)
        - Coerce datetime cells to an Oracle DATE via ``to_date`` during insert
          (fractional seconds are dropped)

        Values are rendered as SQL literals; ``table`` and ``target_fields``
        are interpolated verbatim, so only pass trusted identifiers.

        :param table: target Oracle table, use dot notation to target a
            specific database
        :type table: str
        :param rows: the rows to insert into the table
        :type rows: iterable of tuples
        :param target_fields: the names of the columns to fill in the table
        :type target_fields: iterable of str
        :param commit_every: the maximum number of rows to insert in one transaction
            Default 1000, Set greater than 0.
            Set 1 to insert each row in each single transaction
        :type commit_every: int
        """
        if target_fields:
            target_fields = '({})'.format(', '.join(target_fields))
        else:
            target_fields = ''
        conn = self.get_conn()
        cur = conn.cursor()
        row_count = 0
        try:
            if self.supports_autocommit:
                cur.execute('SET autocommit = 0')
            conn.commit()
            for row in rows:
                row_count += 1
                values = ','.join(self._to_sql_literal(cell) for cell in row)
                # NOTE(review): Oracle documents that the APPEND hint applies to
                # INSERT ... SELECT only, so it is likely ignored for VALUES.
                sql = 'INSERT /*+ APPEND */ INTO {0} {1} VALUES ({2})'.format(
                    table, target_fields, values)
                cur.execute(sql)
                if row_count % commit_every == 0:
                    conn.commit()
                    self.log.info('Loaded %s rows into %s so far', row_count, table)
            conn.commit()
        finally:
            # Release the cursor and connection even if an INSERT fails.
            cur.close()
            conn.close()
        self.log.info('Done loading. Loaded a total of %s rows', row_count)

    def bulk_insert_rows(self, table, rows, target_fields=None, commit_every=5000):
        """
        A performant bulk insert for cx_Oracle
        that uses prepared statements via `executemany()`.
        For best performance, pass in `rows` as an iterator; lists and other
        iterables work as well.

        :param table: target Oracle table, use dot notation to target a
            specific database
        :type table: str
        :param rows: the rows to insert into the table
        :type rows: iterable of tuples
        :param target_fields: the names of the columns to fill in the table, default None.
            If None, each row must list its values in the same order as the
            table's columns, and the first row determines the column count
        :type target_fields: iterable of str Or None
        :param commit_every: the maximum number of rows to insert in one transaction
            Default 5000. Set greater than 0. Set 1 to insert each row in each transaction
        :type commit_every: int
        :raises ValueError: if ``rows`` is None or empty (an exhausted iterator
            is only detected when ``target_fields`` is not given)
        """
        import itertools

        if not rows:
            raise ValueError("parameter rows could not be None or empty iterable")
        rows = iter(rows)
        if target_fields:
            # Materialise so an iterator of names can be both joined and counted.
            target_fields = list(target_fields)
            column_count = len(target_fields)
        else:
            # Peek at the first row to learn the column count without
            # requiring `rows` to support indexing, then put it back.
            try:
                first_row = next(rows)
            except StopIteration:
                raise ValueError("parameter rows could not be None or empty iterable")
            rows = itertools.chain([first_row], rows)
            column_count = len(first_row)
        prepared_stm = 'insert into {tablename} {columns} values ({values})'.format(
            tablename=table,
            columns='({})'.format(', '.join(target_fields)) if target_fields else '',
            values=', '.join(':%s' % i for i in range(1, column_count + 1)),
        )

        conn = self.get_conn()
        cursor = conn.cursor()
        try:
            # Prepare once; executemany(None, ...) reuses the prepared statement.
            cursor.prepare(prepared_stm)
            row_count = 0
            row_chunk = []
            for row in rows:
                row_chunk.append(row)
                row_count += 1
                if row_count % commit_every == 0:
                    cursor.executemany(None, row_chunk)
                    conn.commit()
                    self.log.info('[%s] inserted %s rows', table, row_count)
                    row_chunk = []
            # Commit the leftover partial chunk, if any (skip an empty batch
            # when the row count is an exact multiple of commit_every).
            if row_chunk:
                cursor.executemany(None, row_chunk)
                conn.commit()
                self.log.info('[%s] inserted %s rows', table, row_count)
        finally:
            cursor.close()
            conn.close()