Source code for airflow.providers.openlineage.utils.sql
# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.from__future__importannotationsimportloggingfromcollectionsimportdefaultdictfromcontextlibimportclosingfromenumimportIntEnumfromtypingimportTYPE_CHECKING,Dict,List,Optionalfromattrsimportdefinefromopenlineage.client.facetimportSchemaDatasetFacet,SchemaFieldfromopenlineage.client.runimportDatasetfromsqlalchemyimportColumn,MetaData,Table,and_,union_allifTYPE_CHECKING:fromsqlalchemy.engineimportEnginefromsqlalchemy.sqlimportClauseElementfromairflow.hooks.baseimportBaseHook
def to_dataset(self, namespace: str, database: str | None = None, schema: str | None = None) -> Dataset:
    """Build an OpenLineage ``Dataset`` describing this table.

    The dataset name has the form ``{database_name}.{table_schema}.{table_name}``.
    The values stored on this object take precedence; the ``database`` and
    ``schema`` arguments are only used as fallbacks when the stored value is
    falsy. Parts that are ``None`` are left out of the name.

    :param namespace: namespace assigned to the resulting dataset
    :param database: fallback database name
    :param schema: fallback schema name
    :return: dataset carrying a ``schema`` facet when at least one field is known,
        otherwise a dataset with no facets
    """
    name_parts = [self.database or database, self.schema or schema, self.table]
    qualified_name = ".".join(piece for piece in name_parts if piece is not None)

    # Only attach the schema facet when there is something to report.
    if len(self.fields) > 0:
        facets = {"schema": SchemaDatasetFacet(fields=self.fields)}
    else:
        facets = {}

    return Dataset(namespace=namespace, name=qualified_name, facets=facets)
def get_table_schemas(
    hook: BaseHook,
    namespace: str,
    schema: str | None,
    database: str | None,
    in_query: str | None,
    out_query: str | None,
) -> tuple[list[Dataset], list[Dataset]]:
    """Query the database for table schemas and return them as datasets.

    The connection is obtained from ``hook.get_conn()``. Providing the queries is
    the responsibility of the calling extractor. A query that is missing (falsy)
    is skipped and its result is an empty list.

    :param hook: hook used to open the connection
    :param namespace: namespace assigned to every returned dataset
    :param schema: fallback schema name passed on to ``to_dataset``
    :param database: fallback database name passed on to ``to_dataset``
    :param in_query: query returning column metadata of input tables
    :param out_query: query returning column metadata of output tables
    :return: tuple ``(input_datasets, output_datasets)``
    """
    # Nothing to look up when neither query was supplied; do not open a connection.
    if not (in_query or out_query):
        return [], []

    with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:

        def _datasets_for(query: str | None) -> list[Dataset]:
            # A missing query contributes no datasets and is never executed.
            if not query:
                return []
            cursor.execute(query)
            return [
                table_schema.to_dataset(namespace, database, schema)
                for table_schema in parse_query_result(cursor)
            ]

        # Same cursor is reused; the input query always runs before the output query.
        in_datasets = _datasets_for(in_query)
        out_datasets = _datasets_for(out_query)

    return in_datasets, out_datasets
def parse_query_result(cursor) -> list[TableSchema]:
    """Fetch all rows from a DB-API 2.0 cursor and group them into table schemas.

    Every row describes one column of one table. Rows are grouped by the
    dot-joined ``database.schema.table`` key (falsy parts omitted), and one
    :class:`TableSchema` is produced per distinct key. The fields of each
    table are ordered by sorting the ``(ordinal_position, field)`` pairs.

    :param cursor: cursor on which the query has already been executed
    :return: one :class:`TableSchema` per table, in order of first appearance
    """

    def _qualified_key(database_name, schema_name, table) -> str:
        # Falsy parts (e.g. a missing database) are dropped from the key.
        return ".".join(filter(None, [database_name, schema_name, table]))

    tables_by_key: dict = {}
    fields_by_key: dict = defaultdict(list)

    for record in cursor.fetchall():
        schema_name: str = record[ColumnIndex.SCHEMA]
        table: str = record[ColumnIndex.TABLE_NAME]
        field: SchemaField = SchemaField(
            name=record[ColumnIndex.COLUMN_NAME],
            type=record[ColumnIndex.UDT_NAME],
            description=None,
        )
        position = record[ColumnIndex.ORDINAL_POSITION]

        # The database column is optional: shorter rows simply do not carry it.
        try:
            database_name = record[ColumnIndex.DATABASE]
        except IndexError:
            database_name = None

        key = _qualified_key(database_name, schema_name, table)
        tables_by_key[key] = TableSchema(
            table=table, schema=schema_name, database=database_name, fields=[]
        )
        fields_by_key[key].append((position, field))

    for table_schema in tables_by_key.values():
        key = _qualified_key(table_schema.database, table_schema.schema, table_schema.table)
        ordered_pairs = sorted(fields_by_key[key])
        table_schema.fields = [column for _, column in ordered_pairs]

    return list(tables_by_key.values())
def create_information_schema_query(
    columns: list[str],
    information_schema_table_name: str,
    tables_hierarchy: TablesHierarchy,
    uppercase_names: bool = False,
    sqlalchemy_engine: Engine | None = None,
) -> str:
    """Create a query for getting table schemas from the information schema.

    One ``SELECT`` is built per database key of ``tables_hierarchy`` and the
    statements are combined with ``UNION ALL``. The result is rendered to a
    string with literal values inlined.

    :param columns: names of the columns to select from the information schema table
    :param information_schema_table_name: dot-separated ``<schema>.<table>`` name;
        it is unpacked into exactly two parts
    :param tables_hierarchy: mapping of database name to a mapping of schema name
        to table names; a falsy database key leaves the schema unprefixed
    :param uppercase_names: if True, table names are upper-cased in the filters
    :param sqlalchemy_engine: engine passed to ``compile()`` when rendering the
        statement; may be None
    :return: the compiled SQL text
    """
    # Leave MetaData unbound: its ``bind`` argument is deprecated in SQLAlchemy 1.4
    # and removed in 2.0. It is not needed here because the engine is passed
    # explicitly to ``compile()`` below.
    metadata = MetaData()
    select_statements = []
    for db, schema_mapping in tables_hierarchy.items():
        schema, table_name = information_schema_table_name.split(".")
        if db:
            # Qualify the information schema with the database it lives in.
            schema = f"{db}.{schema}"
        information_schema_table = Table(
            table_name, metadata, *[Column(column) for column in columns], schema=schema
        )
        filter_clauses = create_filter_clauses(schema_mapping, information_schema_table, uppercase_names)
        select_statements.append(information_schema_table.select().filter(*filter_clauses))
    return str(
        union_all(*select_statements).compile(sqlalchemy_engine, compile_kwargs={"literal_binds": True})
    )
def create_filter_clauses(
    schema_mapping: dict, information_schema_table: Table, uppercase_names: bool = False
) -> list[ClauseElement]:
    """
    Create comprehensive filter clauses for all tables in one database.

    One clause is produced per entry of ``schema_mapping``: a ``table_name IN (...)``
    condition, additionally AND-ed with ``table_schema == <schema>`` when the
    schema key is truthy.

    :param schema_mapping: a dictionary of schema names and list of tables in each
    :param information_schema_table: `sqlalchemy.Table` instance used to construct clauses
        For most SQL dbs it contains `table_name` and `table_schema` columns,
        therefore it is expected the table has them defined.
    :param uppercase_names: if True use table names uppercase
    :return: list of filter clauses, one per schema in ``schema_mapping``
    """
    filter_clauses = []
    for schema, tables in schema_mapping.items():
        # Hand ``in_()`` a concrete list rather than a one-shot generator.
        table_names = [name.upper() if uppercase_names else name for name in tables]
        filter_clause = information_schema_table.c.table_name.in_(table_names)
        if schema:
            filter_clause = and_(information_schema_table.c.table_schema == schema, filter_clause)
        filter_clauses.append(filter_clause)
    return filter_clauses