Source code for airflow.hooks.postgres_hook

# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from contextlib import closing

from airflow.hooks.dbapi_hook import DbApiHook


class PostgresHook(DbApiHook):
    """
    Interact with Postgres.

    You can specify ssl parameters in the extra field of your connection
    as ``{"sslmode": "require", "sslcert": "/path/to/cert.pem", etc}``.
    You can also choose a cursor as ``{"cursor": "dictcursor"}``; refer to
    psycopg2.extras for more details.

    Note: For Redshift, use keepalives_idle in the extra connection
    parameters and set it to less than 300 seconds.

    Note: For AWS IAM authentication, use iam in the extra connection
    parameters and set it to true. Leave the password field empty. This
    will use the "aws_default" connection to get the temporary token
    unless you override it in extras.
    extras example: ``{"iam":true, "aws_conn_id":"my_aws_conn"}``
    For Redshift, also set redshift to true in the extra connection
    parameters. The cluster-identifier is extracted from the beginning of
    the host field, so it is optional; it can however be overridden in the
    extra field.
    extras example: ``{"iam":true, "redshift":true,
    "cluster-identifier": "my_cluster_id"}``
    """
    conn_name_attr = 'postgres_conn_id'
    default_conn_name = 'postgres_default'
    supports_autocommit = True

    def __init__(self, *args, **kwargs):
        super(PostgresHook, self).__init__(*args, **kwargs)
        self.schema = kwargs.pop("schema", None)
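    # Usage sketch (illustrative, not part of the hook): the optional
    # ``schema`` kwarg overrides the schema stored on the Airflow
    # connection when get_conn() builds the psycopg2 connect arguments, e.g.
    #   hook = PostgresHook(postgres_conn_id='postgres_default',
    #                       schema='my_other_db')
    # where 'my_other_db' is a hypothetical database name.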
    def _get_cursor(self, raw_cursor):
        _cursor = raw_cursor.lower()
        if _cursor == 'dictcursor':
            return psycopg2.extras.DictCursor
        if _cursor == 'realdictcursor':
            return psycopg2.extras.RealDictCursor
        if _cursor == 'namedtuplecursor':
            return psycopg2.extras.NamedTupleCursor
        raise ValueError('Invalid cursor passed {}'.format(_cursor))
    def get_conn(self):
        conn_id = getattr(self, self.conn_name_attr)
        conn = self.get_connection(conn_id)

        # check for authentication via AWS IAM
        if conn.extra_dejson.get('iam', False):
            conn.login, conn.password, conn.port = self.get_iam_token(conn)

        conn_args = dict(
            host=conn.host,
            user=conn.login,
            password=conn.password,
            dbname=self.schema or conn.schema,
            port=conn.port)
        raw_cursor = conn.extra_dejson.get('cursor', False)
        if raw_cursor:
            conn_args['cursor_factory'] = self._get_cursor(raw_cursor)
        # check for ssl parameters in conn.extra
        for arg_name, arg_val in conn.extra_dejson.items():
            if arg_name in ['sslmode', 'sslcert', 'sslkey',
                            'sslrootcert', 'sslcrl', 'application_name',
                            'keepalives_idle']:
                conn_args[arg_name] = arg_val

        self.conn = psycopg2.connect(**conn_args)
        return self.conn
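    # Illustrative sketch (not part of the hook): for a connection whose
    # Extra field is set to
    #   {"sslmode": "require", "cursor": "dictcursor", "keepalives_idle": 60}
    # the connect call above ends up roughly equivalent to
    #   psycopg2.connect(host=..., user=..., password=..., dbname=...,
    #                    port=..., cursor_factory=psycopg2.extras.DictCursor,
    #                    sslmode='require', keepalives_idle=60)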
    def copy_expert(self, sql, filename, open=open):
        """
        Executes SQL using psycopg2's copy_expert method.
        Necessary to execute a COPY command without access to a superuser.

        Note: if this method is called with a "COPY FROM" statement and
        the specified input file does not exist, it creates an empty
        file and no data is loaded, but the operation succeeds.
        So if users want to be aware when the input file does not exist,
        they have to check its existence by themselves.
        """
        if not os.path.isfile(filename):
            with open(filename, 'w'):
                pass

        with open(filename, 'r+') as f:
            with closing(self.get_conn()) as conn:
                with closing(conn.cursor()) as cur:
                    cur.copy_expert(sql, f)
                    f.truncate(f.tell())
                    conn.commit()
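    # Usage sketch (connection id, table and file names are illustrative):
    #   hook = PostgresHook(postgres_conn_id='postgres_default')
    #   hook.copy_expert("COPY my_table TO STDOUT WITH CSV HEADER",
    #                    '/tmp/my_table.csv')
    # Because the data is streamed through the client connection, this works
    # without the superuser rights that a server-side COPY ... TO '/path'
    # would require.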
    def bulk_load(self, table, tmp_file):
        """
        Loads a tab-delimited file into a database table
        """
        self.copy_expert("COPY {table} FROM STDIN".format(table=table), tmp_file)
    def bulk_dump(self, table, tmp_file):
        """
        Dumps a database table into a tab-delimited file
        """
        self.copy_expert("COPY {table} TO STDOUT".format(table=table), tmp_file)
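    # Usage sketch (hypothetical table and file names; both methods expect
    # tab-delimited text, the default format of COPY without options):
    #   hook = PostgresHook(postgres_conn_id='postgres_default')
    #   hook.bulk_dump('source_table', '/tmp/source_table.tsv')
    #   hook.bulk_load('target_table', '/tmp/source_table.tsv')
    # Note that the table name is interpolated into the SQL unescaped, so
    # it should come from trusted code rather than user input.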
    @staticmethod
    def _serialize_cell(cell, conn):
        """
        Postgresql will adapt all arguments to the execute() method internally,
        hence we return cell without any conversion.

        See http://initd.org/psycopg/docs/advanced.html#adapting-new-types for
        more information.

        :param cell: The cell to insert into the table
        :type cell: object
        :param conn: The database connection
        :type conn: connection object
        :return: The cell
        :rtype: object
        """
        return cell
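    # For example, psycopg2 adapts Python datetime, Decimal and list values
    # passed as execute() parameters on its own, so no stringification is
    # needed here (in contrast to DbApiHook's default _serialize_cell, which
    # converts cells to strings).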
    def get_iam_token(self, conn):
        """
        Uses AwsHook to retrieve a temporary password to connect to Postgres
        or Redshift. A port is required; if none is set on the connection,
        the service default is used (5432 for Postgres, 5439 for Redshift).
        """
        from airflow.contrib.hooks.aws_hook import AwsHook

        redshift = conn.extra_dejson.get('redshift', False)
        aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default')
        aws_hook = AwsHook(aws_conn_id)
        login = conn.login
        if conn.port is None:
            port = 5439 if redshift else 5432
        else:
            port = conn.port
        if redshift:
            # Pull the cluster-identifier from the beginning of the Redshift URL
            # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com
            # returns my-cluster
            cluster_identifier = conn.extra_dejson.get(
                'cluster-identifier', conn.host.split('.')[0])
            client = aws_hook.get_client_type('redshift')
            cluster_creds = client.get_cluster_credentials(
                DbUser=conn.login,
                DbName=self.schema or conn.schema,
                ClusterIdentifier=cluster_identifier,
                AutoCreate=False)
            token = cluster_creds['DbPassword']
            login = cluster_creds['DbUser']
        else:
            client = aws_hook.get_client_type('rds')
            token = client.generate_db_auth_token(conn.host, port, conn.login)
        return login, token, port
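    # Illustrative extras (values are hypothetical): IAM auth is driven
    # entirely by the connection's Extra field.
    #   RDS Postgres:  {"iam": true, "aws_conn_id": "my_aws_conn"}
    #   Redshift:      {"iam": true, "redshift": true,
    #                   "cluster-identifier": "my_cluster_id"}
    # The connection's password field stays empty; get_conn() replaces it
    # with the temporary token returned here.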
