Source code for airflow.providers.common.sql.operators.generic_transfer
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger

if TYPE_CHECKING:
    import jinja2

    try:
        from airflow.sdk.definitions.context import Context
    except ImportError:
        # TODO: Remove once provider drops support for Airflow 2
        from airflow.utils.context import Context
class GenericTransfer(BaseOperator):
    """
    Moves data from one connection to another.

    Both connections are assumed to provide the required methods in their
    respective hooks: the source hook needs to expose a `get_records` method,
    and the destination hook an `insert_rows` method.

    This is meant to be used on small-ish datasets that fit in memory.

    :param sql: SQL query to execute against the source database. (templated)
    :param destination_table: target table. (templated)
    :param source_conn_id: source connection. (templated)
    :param source_hook_params: source hook parameters.
    :param destination_conn_id: destination connection. (templated)
    :param destination_hook_params: destination hook parameters.
    :param preoperator: SQL statement or list of statements to be executed
        prior to loading the data. (templated)
    :param insert_args: extra params for `insert_rows` method.
    :param page_size: number of records to be read in paginated mode (optional).
    """
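    # NOTE: the rendered page omits the class-level attributes, the constructor,
    # and the two cached hook properties that the methods below rely on
    # (self.page_size, self.insert_args, self._paginated_sql_statement_format,
    # self.source_hook, self.destination_hook). The sketch below is a
    # reconstruction consistent with that usage; the exact defaults in a given
    # provider release may differ.
    template_fields: Sequence[str] = (
        "source_conn_id",
        "destination_conn_id",
        "sql",
        "destination_table",
        "preoperator",
        "insert_args",
        "page_size",
    )
    template_ext: Sequence[str] = (".sql", ".hql")
    template_fields_renderers = {"preoperator": "sql"}

    def __init__(
        self,
        *,
        sql: str,
        destination_table: str,
        source_conn_id: str,
        source_hook_params: dict | None = None,
        destination_conn_id: str,
        destination_hook_params: dict | None = None,
        preoperator: str | list[str] | None = None,
        insert_args: dict | None = None,
        page_size: int | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.sql = sql
        self.destination_table = destination_table
        self.source_conn_id = source_conn_id
        self.source_hook_params = source_hook_params
        self.destination_conn_id = destination_conn_id
        self.destination_hook_params = destination_hook_params
        self.preoperator = preoperator
        self.insert_args = insert_args or {}
        self.page_size = page_size
        # Offset-based paging template consumed by get_paginated_sql().
        self._paginated_sql_statement_format = "{} LIMIT {} OFFSET {}"

    @cached_property
    def source_hook(self) -> DbApiHook:
        return self.get_hook(conn_id=self.source_conn_id, hook_params=self.source_hook_params)

    @cached_property
    def destination_hook(self) -> DbApiHook:
        return self.get_hook(conn_id=self.destination_conn_id, hook_params=self.destination_hook_params)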
    @classmethod
    def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> DbApiHook:
        """
        Return DbApiHook for this connection id.

        :param conn_id: connection id
        :param hook_params: hook parameters
        :return: DbApiHook for this connection
        """
        connection = BaseHook.get_connection(conn_id)
        hook = connection.get_hook(hook_params=hook_params)
        if not isinstance(hook, DbApiHook):
            raise RuntimeError(f"Hook for connection {conn_id!r} must be of type {DbApiHook.__name__}")
        return hook
    def get_paginated_sql(self, offset: int) -> str:
        """Format the paginated SQL statement using the current format."""
        return self._paginated_sql_statement_format.format(self.sql, self.page_size, offset)
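    # Example (a sketch, assuming the default "{} LIMIT {} OFFSET {}" format,
    # sql="SELECT * FROM src_table", and page_size=1000):
    #   self.get_paginated_sql(0)     -> "SELECT * FROM src_table LIMIT 1000 OFFSET 0"
    #   self.get_paginated_sql(1000)  -> "SELECT * FROM src_table LIMIT 1000 OFFSET 1000"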
    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        super().render_template_fields(context=context, jinja_env=jinja_env)

        # Make sure strings are converted to integers
        if isinstance(self.page_size, str):
            self.page_size = int(self.page_size)
        commit_every = self.insert_args.get("commit_every")
        if isinstance(commit_every, str):
            self.insert_args["commit_every"] = int(commit_every)
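    # Why the casts above are needed: templated fields render to strings, so a
    # value such as page_size="{{ var.value.page_size }}" arrives as "1000".
    # The pagination arithmetic in execute_complete() and the commit_every
    # argument of insert_rows() both expect integers.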
    def execute(self, context: Context):
        if self.preoperator:
            self.log.info("Running preoperator")
            self.log.info(self.preoperator)
            self.destination_hook.run(self.preoperator)

        if self.page_size and isinstance(self.sql, str):
            self.defer(
                trigger=SQLExecuteQueryTrigger(
                    conn_id=self.source_conn_id,
                    hook_params=self.source_hook_params,
                    sql=self.get_paginated_sql(0),
                ),
                method_name=self.execute_complete.__name__,
            )
        else:
            self.log.info("Extracting data from %s", self.source_conn_id)
            self.log.info("Executing: \n%s", self.sql)

            results = self.source_hook.get_records(self.sql)

            self.log.info("Inserting rows into %s", self.destination_conn_id)
            self.destination_hook.insert_rows(table=self.destination_table, rows=results, **self.insert_args)
    def execute_complete(
        self,
        context: Context,
        event: dict[Any, Any] | None = None,
    ) -> Any:
        if event:
            if event.get("status") == "failure":
                raise AirflowException(event.get("message"))

            results = event.get("results")

            if results:
                map_index = context["ti"].map_index
                offset = (
                    context["ti"].xcom_pull(
                        key="offset",
                        task_ids=self.task_id,
                        dag_id=self.dag_id,
                        map_indexes=map_index,
                        default=0,
                    )
                    + self.page_size
                )

                self.log.info("Offset increased to %d", offset)
                self.xcom_push(context=context, key="offset", value=offset)

                self.log.info("Inserting %d rows into %s", len(results), self.destination_conn_id)
                self.destination_hook.insert_rows(
                    table=self.destination_table, rows=results, **self.insert_args
                )
                self.log.info(
                    "Inserting %d rows into %s done!",
                    len(results),
                    self.destination_conn_id,
                )

                self.defer(
                    trigger=SQLExecuteQueryTrigger(
                        conn_id=self.source_conn_id,
                        hook_params=self.source_hook_params,
                        sql=self.get_paginated_sql(offset),
                    ),
                    method_name=self.execute_complete.__name__,
                )
            else:
                self.log.info(
                    "No more rows to fetch into %s; ending transfer.",
                    self.destination_table,
                )
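# Usage sketch (not part of the provider source; the DAG and connection ids are
# hypothetical). With page_size set and a plain SQL string, execute() defers to
# the triggerer, and execute_complete() keeps fetching LIMIT/OFFSET pages,
# tracking its position via an "offset" XCom, until the source returns no rows.
def _example_dag():
    from datetime import datetime

    from airflow import DAG

    with DAG("example_generic_transfer", start_date=datetime(2025, 1, 1), schedule=None) as dag:
        GenericTransfer(
            task_id="mysql_to_postgres",
            sql="SELECT * FROM source_table",
            destination_table="destination_table",
            source_conn_id="mysql_default",
            destination_conn_id="postgres_default",
            preoperator="TRUNCATE TABLE destination_table",
            insert_args={"commit_every": 1000},
            page_size=1000,
        )
    return dag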