Source code for airflow.providers.postgres.hooks.postgres
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
import warnings
from contextlib import closing
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Iterable, Union

import psycopg2
import psycopg2.extensions
import psycopg2.extras
from deprecated import deprecated
from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
from sqlalchemy.engine import URL

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
    from psycopg2.extensions import connection

    from airflow.models.connection import Connection
    from airflow.providers.openlineage.sqlparser import DatabaseInfo

class PostgresHook(DbApiHook):
    """
    Interact with Postgres.

    You can specify ssl parameters in the extra field of your connection
    as ``{"sslmode": "require", "sslcert": "/path/to/cert.pem", etc}``.
    You can also choose a cursor as ``{"cursor": "dictcursor"}``. Refer to
    psycopg2.extras for more details.

    Note: For Redshift, use keepalives_idle in the extra connection parameters
    and set it to less than 300 seconds.

    Note: For AWS IAM authentication, use iam in the extra connection parameters
    and set it to true. Leave the password field empty. This will use the
    "aws_default" connection to get the temporary token unless you override it
    in extras.
    extras example: ``{"iam":true, "aws_conn_id":"my_aws_conn"}``
    For Redshift, also use redshift in the extra connection parameters and
    set it to true. The cluster-identifier is extracted from the beginning of
    the host field, so it is optional. It can however be overridden in the
    extra field.
    extras example: ``{"iam":true, "redshift":true, "cluster-identifier": "my_cluster_id"}``

    :param postgres_conn_id: The :ref:`postgres conn id <howto/connection:postgres>`
        reference to a specific postgres database.
    :param options: Optional. Specifies command-line options to send to the server
        at connection start. For example, setting this to ``-c search_path=myschema``
        sets the session's value of the ``search_path`` to ``myschema``.
    """
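
    # Illustrative usage sketch (not part of the provider source): instantiating the
    # hook against an Airflow connection and running a query. The connection id
    # "postgres_default" and the schema name "myschema" are placeholder values.
    #
    #     from airflow.providers.postgres.hooks.postgres import PostgresHook
    #
    #     hook = PostgresHook(
    #         postgres_conn_id="postgres_default",
    #         options="-c search_path=myschema",
    #     )
    #     records = hook.get_records("SELECT 1")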

    def __init__(self, *args, options: str | None = None, **kwargs) -> None:
        if "schema" in kwargs:
            warnings.warn(
                'The "schema" arg has been renamed to "database" as it contained the database name.'
                'Please use "database" to set the database name.',
                AirflowProviderDeprecationWarning,
                stacklevel=2,
            )
            kwargs["database"] = kwargs["schema"]
        super().__init__(*args, **kwargs)
        self.connection: Connection | None = kwargs.pop("connection", None)
        self.conn: connection = None
        self.database: str | None = kwargs.pop("database", None)
        self.options = options

    @property
    @deprecated(
        reason=(
            'The "schema" variable has been renamed to "database" as it contained the database name.'
            'Please use "database" to get the database name.'
        ),
        category=AirflowProviderDeprecationWarning,
    )
    def schema(self):
        return self.database

    @schema.setter
    @deprecated(
        reason=(
            'The "schema" variable has been renamed to "database" as it contained the database name.'
            'Please use "database" to set the database name.'
        ),
        category=AirflowProviderDeprecationWarning,
    )
    def schema(self, value):
        self.database = value

    @property
    def sqlalchemy_url(self) -> URL:
        conn = self.get_connection(self.get_conn_id())
        return URL.create(
            drivername="postgresql",
            username=conn.login,
            password=conn.password,
            host=conn.host,
            port=conn.port,
            database=self.database or conn.schema,
        )

    def get_conn(self) -> connection:
        """Establish a connection to a postgres database."""
        conn_id = self.get_conn_id()
        conn = deepcopy(self.connection or self.get_connection(conn_id))

        # check for authentication via AWS IAM
        if conn.extra_dejson.get("iam", False):
            conn.login, conn.password, conn.port = self.get_iam_token(conn)

        conn_args = {
            "host": conn.host,
            "user": conn.login,
            "password": conn.password,
            "dbname": self.database or conn.schema,
            "port": conn.port,
        }
        raw_cursor = conn.extra_dejson.get("cursor", False)
        if raw_cursor:
            conn_args["cursor_factory"] = self._get_cursor(raw_cursor)

        if self.options:
            conn_args["options"] = self.options

        for arg_name, arg_val in conn.extra_dejson.items():
            if arg_name not in [
                "iam",
                "redshift",
                "cursor",
                "cluster-identifier",
                "aws_conn_id",
            ]:
                conn_args[arg_name] = arg_val

        self.conn = psycopg2.connect(**conn_args)
        return self.conn
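
    # Illustrative example (assumed connection setup, not part of the provider source):
    # with a connection whose extra field is
    #     {"sslmode": "require", "cursor": "dictcursor", "connect_timeout": 10}
    # get_conn() forwards sslmode and connect_timeout straight to psycopg2.connect(),
    # while "cursor" is turned into a cursor_factory via the hook's _get_cursor() helper.
    # Given a hook instance as sketched above:
    #
    #     with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cur:
    #         cur.execute("SELECT version()")
    #         print(cur.fetchone())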

    def copy_expert(self, sql: str, filename: str) -> None:
        """
        Execute SQL using psycopg2's ``copy_expert`` method.

        Necessary to execute COPY command without access to a superuser.

        Note: if this method is called with a "COPY FROM" statement and
        the specified input file does not exist, it creates an empty
        file and no data is loaded, but the operation succeeds. So if users
        want to be aware when the input file does not exist, they have to
        check its existence by themselves.
        """
        self.log.info("Running copy expert: %s, filename: %s", sql, filename)
        if not os.path.isfile(filename):
            with open(filename, "w"):
                pass

        with open(filename, "r+") as file, closing(self.get_conn()) as conn, closing(conn.cursor()) as cur:
            cur.copy_expert(sql, file)
            file.truncate(file.tell())
            conn.commit()
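
    # Illustrative usage (placeholder table name and path): exporting a table to a
    # local CSV file and loading it back. COPY runs with the connection's own
    # privileges, so no superuser access is needed.
    #
    #     hook.copy_expert("COPY my_table TO STDOUT WITH CSV HEADER", "/tmp/my_table.csv")
    #     hook.copy_expert("COPY my_table FROM STDIN WITH CSV HEADER", "/tmp/my_table.csv")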

    def get_uri(self) -> str:
        """
        Extract the URI from the connection.

        :return: the extracted URI in Sqlalchemy URI format.
        """
        return self.sqlalchemy_url.render_as_string(hide_password=False)
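
    # Illustrative result (placeholder credentials): for a connection with login
    # "user", password "pass", host "localhost", port 5432 and database "airflow",
    # get_uri() would render something like
    #
    #     postgresql://user:pass@localhost:5432/airflow
    #
    # with the password included because hide_password=False.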

    def bulk_load(self, table: str, tmp_file: str) -> None:
        """Load a tab-delimited file into a database table."""
        self.copy_expert(f"COPY {table} FROM STDIN", tmp_file)

    def bulk_dump(self, table: str, tmp_file: str) -> None:
        """Dump a database table into a tab-delimited file."""
        self.copy_expert(f"COPY {table} TO STDOUT", tmp_file)
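
    # Illustrative usage (placeholder table names and path): round-tripping a table
    # through a tab-delimited file; both helpers delegate to copy_expert() above.
    #
    #     hook.bulk_dump("my_table", "/tmp/my_table.tsv")
    #     hook.bulk_load("my_table_copy", "/tmp/my_table.tsv")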

    @staticmethod
    def _serialize_cell(cell: object, conn: connection | None = None) -> Any:
        """
        Serialize a cell.

        PostgreSQL adapts all arguments to the ``execute()`` method internally,
        hence we return the cell without any conversion.

        See http://initd.org/psycopg/docs/advanced.html#adapting-new-types for
        more information.

        :param cell: The cell to insert into the table
        :param conn: The database connection
        :return: The cell
        """
        return cell

    def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
        """
        Get the IAM token.

        This uses AWSHook to retrieve a temporary password to connect to
        Postgres or Redshift. Port is required. If none is provided, the default
        5432 is used.
        """
        try:
            from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
        except ImportError:
            from airflow.exceptions import AirflowException

            raise AirflowException(
                "apache-airflow-providers-amazon not installed, run: "
                "pip install 'apache-airflow-providers-postgres[amazon]'."
            )

        aws_conn_id = conn.extra_dejson.get("aws_conn_id", "aws_default")
        login = conn.login
        if conn.extra_dejson.get("redshift", False):
            port = conn.port or 5439
            # Pull the cluster-identifier from the beginning of the Redshift URL
            # ex. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
            cluster_identifier = conn.extra_dejson.get("cluster-identifier", conn.host.split(".")[0])
            redshift_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="redshift").conn
            # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
            cluster_creds = redshift_client.get_cluster_credentials(
                DbUser=login,
                DbName=self.database or conn.schema,
                ClusterIdentifier=cluster_identifier,
                AutoCreate=False,
            )
            token = cluster_creds["DbPassword"]
            login = cluster_creds["DbUser"]
        else:
            port = conn.port or 5432
            rds_client = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="rds").conn
            # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.generate_db_auth_token
            token = rds_client.generate_db_auth_token(conn.host, port, conn.login)
        return login, token, port
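
    # Illustrative connection extras (placeholder values) that trigger this code path.
    # For RDS:
    #     {"iam": true, "aws_conn_id": "my_aws_conn"}
    # For Redshift, where the cluster identifier defaults to the first label of the host:
    #     {"iam": true, "redshift": true, "cluster-identifier": "my_cluster_id"}
    # In both cases the connection's password field is left empty and the temporary
    # token returned here is used instead.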

    def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None:
        """
        Get the table's primary key.

        :param table: Name of the target table
        :param schema: Name of the target schema, public by default
        :return: Primary key columns list
        """
        sql = """
            select kcu.column_name
            from information_schema.table_constraints tco
                join information_schema.key_column_usage kcu
                    on kcu.constraint_name = tco.constraint_name
                        and kcu.constraint_schema = tco.constraint_schema
                        and kcu.constraint_name = tco.constraint_name
            where tco.constraint_type = 'PRIMARY KEY'
                and kcu.table_schema = %s
                and kcu.table_name = %s
        """
        pk_columns = [row[0] for row in self.get_records(sql, (schema, table))]
        return pk_columns or None
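
    # Illustrative usage (placeholder table): for a table created as
    #     CREATE TABLE users (id BIGINT, email TEXT, PRIMARY KEY (id, email))
    # this would typically return ["id", "email"]; for a table without a primary
    # key it returns None.
    #
    #     pk_cols = hook.get_table_primary_key("users", schema="public")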

    def _generate_insert_sql(
        self, table: str, values: tuple[str, ...], target_fields: Iterable[str], replace: bool, **kwargs
    ) -> str:
        """
        Generate the INSERT SQL statement.

        The REPLACE variant is specific to the PostgreSQL syntax.

        :param table: Name of the target table
        :param values: The row to insert into the table
        :param target_fields: The names of the columns to fill in the table
        :param replace: Whether to replace instead of insert
        :param replace_index: the column or list of column names to act as
            index for the ON CONFLICT clause
        :return: The generated INSERT or REPLACE SQL statement
        """
        placeholders = [
            self.placeholder,
        ] * len(values)
        replace_index = kwargs.get("replace_index")

        if target_fields:
            target_fields_fragment = ", ".join(target_fields)
            target_fields_fragment = f"({target_fields_fragment})"
        else:
            target_fields_fragment = ""
        sql = f"INSERT INTO {table}{target_fields_fragment} VALUES ({','.join(placeholders)})"

        if replace:
            if not target_fields:
                raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names")
            if not replace_index:
                raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an unique index")
            if isinstance(replace_index, str):
                replace_index = [replace_index]

            on_conflict_str = f" ON CONFLICT ({', '.join(replace_index)})"
            replace_target = [f for f in target_fields if f not in replace_index]

            if replace_target:
                replace_target_str = ", ".join(f"{col} = excluded.{col}" for col in replace_target)
                sql += f"{on_conflict_str} DO UPDATE SET {replace_target_str}"
            else:
                sql += f"{on_conflict_str} DO NOTHING"
        return sql
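
    # Illustrative output (placeholder table and columns): with replace=True,
    # target_fields=["id", "name"], replace_index="id" and the default "%s"
    # placeholder, the generated statement is roughly
    #
    #     INSERT INTO my_table(id, name) VALUES (%s,%s) ON CONFLICT (id) DO UPDATE SET name = excluded.name
    #
    # This is the statement used when calling the inherited insert_rows() with
    # replace=True on a Postgres connection.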

    def get_openlineage_database_info(self, connection) -> DatabaseInfo:
        """Return Postgres/Redshift specific information for OpenLineage."""
        from airflow.providers.openlineage.sqlparser import DatabaseInfo

        is_redshift = connection.extra_dejson.get("redshift", False)

        if is_redshift:
            authority = self._get_openlineage_redshift_authority_part(connection)
        else:
            authority = DbApiHook.get_openlineage_authority_part(  # type: ignore[attr-defined]
                connection, default_port=5432
            )

        return DatabaseInfo(
            scheme="postgres" if not is_redshift else "redshift",
            authority=authority,
            database=self.database or connection.schema,
        )
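
    # Illustrative result (placeholder host and database): for a plain Postgres
    # connection to db.example.com:5432 with database "analytics", this returns a
    # DatabaseInfo with scheme="postgres", authority="db.example.com:5432" and
    # database="analytics". With "redshift": true in extras, the scheme becomes
    # "redshift" and the authority is built from the cluster identifier and region
    # (see the helper below).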

    def _get_openlineage_redshift_authority_part(self, connection) -> str:
        try:
            from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
        except ImportError:
            from airflow.exceptions import AirflowException

            raise AirflowException(
                "apache-airflow-providers-amazon not installed, run: "
                "pip install 'apache-airflow-providers-postgres[amazon]'."
            )
        aws_conn_id = connection.extra_dejson.get("aws_conn_id", "aws_default")

        port = connection.port or 5439
        cluster_identifier = connection.extra_dejson.get("cluster-identifier", connection.host.split(".")[0])
        region_name = AwsBaseHook(aws_conn_id=aws_conn_id).region_name

        return f"{cluster_identifier}.{region_name}:{port}"

    def get_openlineage_default_schema(self) -> str | None:
        """Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
        return self.get_first("SELECT CURRENT_SCHEMA;")[0]