Source code for airflow.providers.apache.spark.hooks.spark_jdbc_script

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import argparse
from typing import Any

from pyspark.sql import SparkSession

# Values accepted for the ``-cmdType`` command-line argument; ``_run_spark``
# compares against them to pick the transfer direction.
SPARK_WRITE_TO_JDBC: str = "spark_to_jdbc"
SPARK_READ_FROM_JDBC: str = "jdbc_to_spark"
def set_common_options(
    spark_source: Any,
    url: str = "localhost:5432",
    jdbc_table: str = "default.default",
    user: str = "root",
    password: str = "root",
    driver: str = "driver",
) -> Any:
    """
    Configure a Spark reader or writer for a JDBC connection.

    Selects the ``jdbc`` format on ``spark_source`` and then sets the ``url``,
    ``dbtable``, ``user``, ``password`` and ``driver`` options, in that order.

    :param spark_source: Spark source, here is Spark reader or writer
    :param url: JDBC resource url
    :param jdbc_table: JDBC resource table name
    :param user: JDBC resource user name
    :param password: JDBC resource password
    :param driver: JDBC resource driver
    :return: the object returned by the last ``option`` call in the chain
    """
    return (
        spark_source.format("jdbc")
        .option("url", url)
        .option("dbtable", jdbc_table)
        .option("user", user)
        .option("password", password)
        .option("driver", driver)
    )
def spark_write_to_jdbc(
    spark_session: SparkSession,
    url: str,
    user: str,
    password: str,
    metastore_table: str,
    jdbc_table: str,
    driver: Any,
    truncate: bool,
    save_mode: str,
    batch_size: int,
    num_partitions: int,
    create_table_column_types: str,
) -> None:
    """
    Transfer data from Spark to JDBC source.

    Reads ``metastore_table`` through ``spark_session.table`` and saves it via
    a JDBC writer. Each write-specific option is only set when its argument
    is truthy.

    :param spark_session: Spark session used to read the metastore table
    :param url: JDBC resource url
    :param user: JDBC resource user name
    :param password: JDBC resource password
    :param metastore_table: name of the Spark table to read from
    :param jdbc_table: JDBC resource table name to write to
    :param driver: JDBC resource driver
    :param truncate: value for the writer's ``truncate`` option
    :param save_mode: mode passed to the writer's ``save`` call
    :param batch_size: value for the writer's ``batchsize`` option
    :param num_partitions: value for the writer's ``numPartitions`` option
    :param create_table_column_types: value for the writer's
        ``createTableColumnTypes`` option
    """
    writer = spark_session.table(metastore_table).write

    # First set the options shared with the read path.
    writer = set_common_options(writer, url, jdbc_table, user, password, driver)

    # Then set the write-specific options that were actually supplied.
    if truncate:
        writer = writer.option("truncate", truncate)
    if batch_size:
        writer = writer.option("batchsize", batch_size)
    if num_partitions:
        writer = writer.option("numPartitions", num_partitions)
    if create_table_column_types:
        writer = writer.option("createTableColumnTypes", create_table_column_types)

    writer.save(mode=save_mode)
def spark_read_from_jdbc(
    spark_session: SparkSession,
    url: str,
    user: str,
    password: str,
    metastore_table: str,
    jdbc_table: str,
    driver: Any,
    save_mode: str,
    save_format: str,
    fetch_size: int,
    num_partitions: int,
    partition_column: str,
    lower_bound: str,
    upper_bound: str,
) -> None:
    """
    Transfer data from JDBC source to Spark.

    Loads ``jdbc_table`` through a JDBC reader and saves the result as the
    Spark table ``metastore_table``.

    :param spark_session: Spark session that provides the reader
    :param url: JDBC resource url
    :param user: JDBC resource user name
    :param password: JDBC resource password
    :param metastore_table: name of the Spark table to save into
    :param jdbc_table: JDBC resource table name to read from
    :param driver: JDBC resource driver
    :param save_mode: mode passed to ``saveAsTable``
    :param save_format: format passed to ``saveAsTable``
    :param fetch_size: value for the reader's ``fetchsize`` option; only set
        when truthy
    :param num_partitions: value for the reader's ``numPartitions`` option;
        only set when truthy
    :param partition_column: value for the reader's ``partitionColumn`` option
    :param lower_bound: value for the reader's ``lowerBound`` option
    :param upper_bound: value for the reader's ``upperBound`` option

    The three partitioning options are only set when all of
    ``partition_column``, ``lower_bound`` and ``upper_bound`` are provided.
    """
    # First set the options shared with the write path.
    reader = set_common_options(spark_session.read, url, jdbc_table, user, password, driver)

    # Then set the read-specific options that were actually supplied.
    if fetch_size:
        reader = reader.option("fetchsize", fetch_size)
    if num_partitions:
        reader = reader.option("numPartitions", num_partitions)

    # Test for "not provided" explicitly instead of relying on truthiness, so
    # that a numeric bound of 0 does not silently disable partitioning.
    partitioning = (partition_column, lower_bound, upper_bound)
    if all(value is not None and value != "" for value in partitioning):
        reader = (
            reader.option("partitionColumn", partition_column)
            .option("lowerBound", lower_bound)
            .option("upperBound", upper_bound)
        )

    reader.load().write.saveAsTable(metastore_table, format=save_format, mode=save_mode)
def _parse_arguments(args: list[str] | None = None) -> Any:
    """
    Parse the command-line arguments of the Spark-JDBC script.

    Every option is optional and is stored without type conversion, so the
    parsed values are strings, or ``None`` for options that were not given.

    :param args: argument list to parse; ``None`` lets ``argparse`` fall back
        to the process command line
    :return: the namespace produced by ``ArgumentParser.parse_args``
    """
    parser = argparse.ArgumentParser(description="Spark-JDBC")
    # (command-line flag, attribute name on the parsed namespace)
    cli_options = (
        ("-cmdType", "cmd_type"),
        ("-url", "url"),
        ("-user", "user"),
        ("-password", "password"),
        ("-metastoreTable", "metastore_table"),
        ("-jdbcTable", "jdbc_table"),
        ("-jdbcDriver", "jdbc_driver"),
        ("-jdbcTruncate", "truncate"),
        ("-saveMode", "save_mode"),
        ("-saveFormat", "save_format"),
        ("-batchsize", "batch_size"),
        ("-fetchsize", "fetch_size"),
        ("-name", "name"),
        ("-numPartitions", "num_partitions"),
        ("-partitionColumn", "partition_column"),
        ("-lowerBound", "lower_bound"),
        ("-upperBound", "upper_bound"),
        ("-createTableColumnTypes", "create_table_column_types"),
    )
    for flag, dest in cli_options:
        parser.add_argument(flag, dest=dest, action="store")
    return parser.parse_args(args=args)


def _create_spark_session(arguments: Any) -> SparkSession:
    """
    Return a Hive-enabled Spark session named after ``arguments.name``.

    Uses ``getOrCreate``, so an already existing session may be returned.
    """
    return SparkSession.builder.appName(arguments.name).enableHiveSupport().getOrCreate()


def _run_spark(arguments: Any) -> None:
    """
    Run the transfer selected by ``arguments.cmd_type``.

    :param arguments: parsed arguments, as returned by ``_parse_arguments``
    :raises ValueError: if ``arguments.cmd_type`` is neither
        ``SPARK_WRITE_TO_JDBC`` nor ``SPARK_READ_FROM_JDBC``
    """
    cmd_type = arguments.cmd_type
    # Fail fast on an unknown command type instead of starting a Spark session
    # and then silently transferring nothing.
    if cmd_type not in (SPARK_WRITE_TO_JDBC, SPARK_READ_FROM_JDBC):
        raise ValueError(
            f"Unsupported -cmdType {cmd_type!r}; expected "
            f"{SPARK_WRITE_TO_JDBC!r} or {SPARK_READ_FROM_JDBC!r}"
        )

    spark = _create_spark_session(arguments)
    if cmd_type == SPARK_WRITE_TO_JDBC:
        spark_write_to_jdbc(
            spark,
            arguments.url,
            arguments.user,
            arguments.password,
            arguments.metastore_table,
            arguments.jdbc_table,
            arguments.jdbc_driver,
            arguments.truncate,
            arguments.save_mode,
            arguments.batch_size,
            arguments.num_partitions,
            arguments.create_table_column_types,
        )
    else:
        spark_read_from_jdbc(
            spark,
            arguments.url,
            arguments.user,
            arguments.password,
            arguments.metastore_table,
            arguments.jdbc_table,
            arguments.jdbc_driver,
            arguments.save_mode,
            arguments.save_format,
            arguments.fetch_size,
            arguments.num_partitions,
            arguments.partition_column,
            arguments.lower_bound,
            arguments.upper_bound,
        )


if __name__ == "__main__":  # pragma: no cover
    _run_spark(arguments=_parse_arguments())

Was this entry helpful?