Source code for airflow.providers.apache.hive.hooks.hive
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import csv
import os
import re
import socket
import subprocess
import time
from collections.abc import Iterable, Mapping
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import pandas as pd

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.security import utils
from airflow.utils.helpers import as_flattened_list
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING

def get_context_from_env_var() -> dict[Any, Any]:
    """
    Extract context from env variables (dag_id, task_id, etc.) for use in BashOperator and PythonOperator.

    :return: The context of interest.
    """
    return {
        format_map["default"]: os.environ.get(format_map["env_var_format"], "")
        for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
    }

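# Illustrative sketch (not part of the upstream module): inside a running task the
# returned mapping typically looks roughly like
#
#     {
#         "airflow.ctx.dag_id": "my_dag",
#         "airflow.ctx.task_id": "my_task",
#         ...
#     }
#
# with one entry per AIRFLOW_VAR_NAME_FORMAT_MAPPING key; each value falls back to ""
# when the corresponding AIRFLOW_CTX_* environment variable is unset. The concrete
# dag/task values above are placeholders.
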
class HiveCliHook(BaseHook):
    """
    Simple wrapper around the hive CLI.

    It also supports ``beeline``, a lighter CLI that runs JDBC and is
    replacing the heavier traditional CLI. To enable ``beeline``, set the
    use_beeline param in the extra field of your connection as in
    ``{ "use_beeline": true }``

    Note that you can also set default hive CLI parameters by passing ``hive_cli_params``
    space separated list of parameters to add to the hive command.

    The extra connection parameter ``auth`` gets passed as in the ``jdbc``
    connection string as is.

    :param hive_cli_conn_id: Reference to the
        :ref:`Hive CLI connection id <howto/connection:hive_cli>`.
    :param mapred_queue: queue used by the Hadoop Scheduler (Capacity or Fair)
    :param mapred_queue_priority: priority within the job queue.
        Possible settings include: VERY_HIGH, HIGH, NORMAL, LOW, VERY_LOW
    :param mapred_job_name: This name will appear in the jobtracker.
        This can make monitoring easier.
    :param hive_cli_params: Space separated list of hive command parameters to add to the
        hive command.
    :param proxy_user: Run HQL code as this user.
    """

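    # Illustrative connection sketch (not part of the upstream module): the extras
    # below enable beeline and set a proxy user. The conn_id, host, and user shown
    # here are placeholders.
    #
    #     from airflow.models import Connection
    #
    #     Connection(
    #         conn_id="hive_cli_example",
    #         conn_type="hive_cli",
    #         host="hiveserver2.example.com",
    #         port=10000,
    #         schema="default",
    #         extra='{"use_beeline": true, "proxy_user": "analyst"}',
    #     )
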
    @classmethod
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """Return connection widgets to add to Hive Client Wrapper connection form."""
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import BooleanField, StringField

        return {
            "use_beeline": BooleanField(lazy_gettext("Use Beeline"), default=True),
            "proxy_user": StringField(lazy_gettext("Proxy User"), widget=BS3TextFieldWidget(), default=""),
            "principal": StringField(
                lazy_gettext("Principal"), widget=BS3TextFieldWidget(), default="hive/_HOST@EXAMPLE.COM"
            ),
            "high_availability": BooleanField(lazy_gettext("High Availability mode"), default=False),
        }

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom UI field behaviour for Hive Client Wrapper connection."""
        return {
            "hidden_fields": ["extra"],
            "relabeling": {},
        }

    def _get_proxy_user(self) -> str:
        """Set the proper proxy_user value in case the user overwrites the default."""
        conn = self.conn
        if self.proxy_user is not None:
            return f"hive.server2.proxy.user={self.proxy_user}"
        proxy_user_value: str = conn.extra_dejson.get("proxy_user", "")
        if proxy_user_value != "":
            return f"hive.server2.proxy.user={proxy_user_value}"
        return ""

    def _prepare_cli_cmd(self) -> list[Any]:
        """Create the command list from available information."""
        conn = self.conn
        hive_bin = "hive"
        cmd_extra = []

        if self.use_beeline:
            hive_bin = "beeline"
            self._validate_beeline_parameters(conn)
            if self.high_availability:
                jdbc_url = f"jdbc:hive2://{conn.host}/{conn.schema}"
                self.log.info("High Availability selected, setting JDBC url as %s", jdbc_url)
            else:
                jdbc_url = f"jdbc:hive2://{conn.host}:{conn.port}/{conn.schema}"
                self.log.info("High Availability not selected, setting JDBC url as %s", jdbc_url)

            if conf.get("core", "security") == "kerberos":
                template = conn.extra_dejson.get("principal", "hive/_HOST@EXAMPLE.COM")
                if "_HOST" in template:
                    template = utils.replace_hostname_pattern(utils.get_components(template))
                proxy_user = self._get_proxy_user()
                if ";" in template:
                    raise RuntimeError("The principal should not contain the ';' character")
                if ";" in proxy_user:
                    raise RuntimeError("The proxy_user should not contain the ';' character")
                jdbc_url += f";principal={template};{proxy_user}"
                if self.high_availability:
                    if not jdbc_url.endswith(";"):
                        jdbc_url += ";"
                    jdbc_url += "serviceDiscoveryMode=zooKeeper;ssl=true;zooKeeperNamespace=hiveserver2"
            elif self.auth:
                jdbc_url += ";auth=" + self.auth

            jdbc_url = f'"{jdbc_url}"'

            cmd_extra += ["-u", jdbc_url]
            if conn.login:
                cmd_extra += ["-n", conn.login]
            if conn.password:
                cmd_extra += ["-p", conn.password]

        hive_params_list = self.hive_cli_params.split()

        return [hive_bin, *cmd_extra, *hive_params_list]

    def _validate_beeline_parameters(self, conn):
        if self.high_availability:
            if ";" in conn.schema:
                raise ValueError(
                    f"The schema used in beeline command ({conn.schema}) should not contain ';' character)"
                )
            return
        elif ":" in conn.host or "/" in conn.host or ";" in conn.host:
            raise ValueError(
                f"The host used in beeline command ({conn.host}) should not contain ':/;' characters)"
            )
        try:
            int_port = int(conn.port)
            if not 0 < int_port <= 65535:
                raise ValueError(
                    f"The port used in beeline command ({conn.port}) should be in range 0-65535)"
                )
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"The port used in beeline command ({conn.port}) should be a valid integer: {e})"
            )
        if ";" in conn.schema:
            raise ValueError(
                f"The schema used in beeline command ({conn.schema}) should not contain ';' character)"
            )

    @staticmethod
    def _prepare_hiveconf(d: dict[Any, Any]) -> list[Any]:
        """
        Prepare a list of hiveconf params from a dictionary of key value pairs.

        :param d:

        >>> hh = HiveCliHook()
        >>> hive_conf = {"hive.exec.dynamic.partition": "true",
        ... "hive.exec.dynamic.partition.mode": "nonstrict"}
        >>> hh._prepare_hiveconf(hive_conf)
        ["-hiveconf", "hive.exec.dynamic.partition=true",\
         "-hiveconf", "hive.exec.dynamic.partition.mode=nonstrict"]
        """
        if not d:
            return []
        return as_flattened_list(zip(["-hiveconf"] * len(d), [f"{k}={v}" for k, v in d.items()]))

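    # Illustrative sketch (not part of the upstream module): for a beeline-enabled,
    # non-kerberized connection, _prepare_cli_cmd() returns a command list shaped
    # roughly like the following (host, port, schema, and login are placeholders,
    # and the trailing flag stands in for whatever hive_cli_params contains):
    #
    #     ["beeline", "-u", '"jdbc:hive2://localhost:10000/default"', "-n", "hive_user", "--silent=true"]
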
    def run_cli(
        self,
        hql: str,
        schema: str | None = None,
        verbose: bool = True,
        hive_conf: dict[Any, Any] | None = None,
    ) -> Any:
        """
        Run an hql statement using the hive cli.

        If hive_conf is specified it should be a dict and the entries
        will be set as key/value pairs in HiveConf.

        :param hql: an hql (hive query language) statement to run with hive cli
        :param schema: Name of hive schema (database) to use
        :param verbose: Provides additional logging. Defaults to True.
        :param hive_conf: if specified these key value pairs will be passed
            to hive as ``-hiveconf "key"="value"``. Note that they will be
            passed after the ``hive_cli_params`` and thus will override
            whatever values are specified in the database.

        >>> hh = HiveCliHook()
        >>> result = hh.run_cli("USE airflow;")
        >>> ("OK" in result)
        True
        """
        conn = self.conn
        schema = schema or conn.schema

        invalid_chars_list = re.findall(r"[^a-z0-9_]", schema)
        if invalid_chars_list:
            invalid_chars = "".join(invalid_chars_list)
            raise RuntimeError(f"The schema `{schema}` contains invalid characters: {invalid_chars}")

        if schema:
            hql = f"USE {schema};\n{hql}"

        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
            hql += "\n"
            f.write(hql.encode("UTF-8"))
            f.flush()
            hive_cmd = self._prepare_cli_cmd()
            env_context = get_context_from_env_var()
            # Only extend the hive_conf if it is defined.
            if hive_conf:
                env_context.update(hive_conf)
            hive_conf_params = self._prepare_hiveconf(env_context)
            if self.mapred_queue:
                hive_conf_params.extend(
                    [
                        "-hiveconf",
                        f"mapreduce.job.queuename={self.mapred_queue}",
                        "-hiveconf",
                        f"mapred.job.queue.name={self.mapred_queue}",
                        "-hiveconf",
                        f"tez.queue.name={self.mapred_queue}",
                    ]
                )

            if self.mapred_queue_priority:
                hive_conf_params.extend(["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"])

            if self.mapred_job_name:
                hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"])

            hive_cmd.extend(hive_conf_params)
            hive_cmd.extend(["-f", f.name])

            if verbose:
                self.log.info("%s", " ".join(hive_cmd))
            sub_process: Any = subprocess.Popen(
                hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
            )
            self.sub_process = sub_process
            stdout = ""
            for line in iter(sub_process.stdout.readline, b""):
                line = line.decode()
                stdout += line
                if verbose:
                    self.log.info(line.strip())
            sub_process.wait()

            if sub_process.returncode:
                raise AirflowException(stdout)

            return stdout

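    # Usage sketch (not part of the upstream module): schema, queue, and hive_conf
    # values are placeholders; hive_conf entries are forwarded as -hiveconf pairs
    # after whatever hive_cli_params the connection already defines.
    #
    #     hook = HiveCliHook(hive_cli_conn_id="hive_cli_default", mapred_queue="etl")
    #     out = hook.run_cli(
    #         hql="SELECT COUNT(*) FROM my_table;",
    #         schema="analytics",
    #         hive_conf={"hive.exec.dynamic.partition": "true"},
    #     )
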
    def test_hql(self, hql: str) -> None:
        """Test an hql statement using the hive cli and EXPLAIN."""
        create, insert, other = [], [], []
        for query in hql.split(";"):  # naive
            query_original = query
            query = query.lower().strip()

            if query.startswith("create table"):
                create.append(query_original)
            elif query.startswith(("set ", "add jar ", "create temporary function")):
                other.append(query_original)
            elif query.startswith("insert"):
                insert.append(query_original)
        other_ = ";".join(other)
        for query_set in [create, insert]:
            for query in query_set:
                query_preview = " ".join(query.split())[:50]
                self.log.info("Testing HQL [%s (...)]", query_preview)
                if query_set == insert:
                    query = other_ + "; explain " + query
                else:
                    query = "explain " + query
                try:
                    self.run_cli(query, verbose=False)
                except AirflowException as e:
                    message = e.args[0].splitlines()[-2]
                    self.log.info(message)
                    error_loc = re.search(r"(\d+):(\d+)", message)
                    if error_loc:
                        lst = int(error_loc.group(1))
                        begin = max(lst - 2, 0)
                        end = min(lst + 3, len(query.splitlines()))
                        context = "\n".join(query.splitlines()[begin:end])
                        self.log.info("Context :\n%s", context)
                else:
                    self.log.info("SUCCESS")

    def load_df(
        self,
        df: pd.DataFrame,
        table: str,
        field_dict: dict[Any, Any] | None = None,
        delimiter: str = ",",
        encoding: str = "utf8",
        pandas_kwargs: Any = None,
        **kwargs: Any,
    ) -> None:
        """
        Load a pandas DataFrame into hive.

        Hive data types will be inferred if not passed but column names will
        not be sanitized.

        :param df: DataFrame to load into a Hive table
        :param table: target Hive table, use dot notation to target a
            specific database
        :param field_dict: mapping from column name to hive data type.
            Note that Python dict is ordered so it keeps columns' order.
        :param delimiter: field delimiter in the file
        :param encoding: str encoding to use when writing DataFrame to file
        :param pandas_kwargs: passed to DataFrame.to_csv
        :param kwargs: passed to self.load_file
        """

        def _infer_field_types_from_df(df: pd.DataFrame) -> dict[Any, Any]:
            dtype_kind_hive_type = {
                "b": "BOOLEAN",  # boolean
                "i": "BIGINT",  # signed integer
                "u": "BIGINT",  # unsigned integer
                "f": "DOUBLE",  # floating-point
                "c": "STRING",  # complex floating-point
                "M": "TIMESTAMP",  # datetime
                "O": "STRING",  # object
                "S": "STRING",  # (byte-)string
                "U": "STRING",  # Unicode
                "V": "STRING",  # void
            }

            order_type = {}
            for col, dtype in df.dtypes.items():
                order_type[col] = dtype_kind_hive_type[dtype.kind]
            return order_type

        if pandas_kwargs is None:
            pandas_kwargs = {}

        with (
            TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir,
            NamedTemporaryFile(dir=tmp_dir, mode="w") as f,
        ):
            if field_dict is None:
                field_dict = _infer_field_types_from_df(df)

            df.to_csv(
                path_or_buf=f,
                sep=delimiter,
                header=False,
                index=False,
                encoding=encoding,
                date_format="%Y-%m-%d %H:%M:%S",
                **pandas_kwargs,
            )
            f.flush()

            return self.load_file(
                filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs
            )

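    # Usage sketch (not part of the upstream module): the table name and columns
    # below are placeholders. With field_dict omitted, Hive types are inferred
    # from the DataFrame dtypes via the mapping above.
    #
    #     import pandas as pd
    #
    #     df = pd.DataFrame({"name": ["Ada", "Linus"], "age": [36, 54]})
    #     HiveCliHook().load_df(df, table="airflow.people", delimiter=",")
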
    def load_file(
        self,
        filepath: str,
        table: str,
        delimiter: str = ",",
        field_dict: dict[Any, Any] | None = None,
        create: bool = True,
        overwrite: bool = True,
        partition: dict[str, Any] | None = None,
        recreate: bool = False,
        tblproperties: dict[str, Any] | None = None,
    ) -> None:
        """
        Load a local file into Hive.

        Note that the table generated in Hive uses ``STORED AS textfile``
        which isn't the most efficient serialization format. If a
        large amount of data is loaded and/or if the table gets
        queried considerably, you may want to use this operator only to
        stage the data into a temporary table before loading it into its
        final destination using a ``HiveOperator``.

        :param filepath: local filepath of the file to load
        :param table: target Hive table, use dot notation to target a
            specific database
        :param delimiter: field delimiter in the file
        :param field_dict: A dictionary of the fields name in the file
            as keys and their Hive types as values.
            Note that Python dict is ordered so it keeps columns' order.
        :param create: whether to create the table if it doesn't exist
        :param overwrite: whether to overwrite the data in table or partition
        :param partition: target partition as a dict of partition columns
            and values
        :param recreate: whether to drop and recreate the table at every
            execution
        :param tblproperties: TBLPROPERTIES of the hive table being created
        """
        hql = ""
        if recreate:
            hql += f"DROP TABLE IF EXISTS {table};\n"
        if create or recreate:
            if field_dict is None:
                raise ValueError("Must provide a field dict when creating a table")
            fields = ",\n    ".join(f"`{k.strip('`')}` {v}" for k, v in field_dict.items())
            hql += f"CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
            if partition:
                pfields = ",\n    ".join(p + " STRING" for p in partition)
                hql += f"PARTITIONED BY ({pfields})\n"
            hql += "ROW FORMAT DELIMITED\n"
            hql += f"FIELDS TERMINATED BY '{delimiter}'\n"
            hql += "STORED AS textfile\n"
            if tblproperties is not None:
                tprops = ", ".join(f"'{k}'='{v}'" for k, v in tblproperties.items())
                hql += f"TBLPROPERTIES({tprops})\n"
            hql += ";"
            self.log.info(hql)
            self.run_cli(hql)
        hql = f"LOAD DATA LOCAL INPATH '{filepath}' "
        if overwrite:
            hql += "OVERWRITE "
        hql += f"INTO TABLE {table} "
        if partition:
            pvals = ", ".join(f"{k}='{v}'" for k, v in partition.items())
            hql += f"PARTITION ({pvals})"

        # Add a newline character as a workaround for https://issues.apache.org/jira/browse/HIVE-10541,
        hql += ";\n"

        self.log.info(hql)
        self.run_cli(hql)

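    # Usage sketch (not part of the upstream module): file path, table, and columns
    # are placeholders. With create=True this first issues a CREATE TABLE IF NOT
    # EXISTS built from field_dict, then a LOAD DATA LOCAL INPATH ... PARTITION (...)
    # statement for the given partition values.
    #
    #     HiveCliHook().load_file(
    #         filepath="/tmp/babynames.csv",
    #         table="airflow.static_babynames_partitioned",
    #         field_dict={"state": "STRING", "year": "INT", "name": "STRING", "num": "INT"},
    #         partition={"ds": "2015-01-01"},
    #     )
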
    def kill(self) -> None:
        """Kill Hive cli command."""
        if hasattr(self, "sub_process"):
            if self.sub_process.poll() is None:
                print("Killing the Hive job")
                self.sub_process.terminate()
                time.sleep(60)
                self.sub_process.kill()

class HiveMetastoreHook(BaseHook):
    """
    Wrapper to interact with the Hive Metastore.

    :param metastore_conn_id: reference to the
        :ref: `metastore thrift service connection id <howto/connection:hive_metastore>`.
    """

    # java short max val
    MAX_PART_COUNT = 32767

    def __getstate__(self) -> dict[str, Any]:
        """Serialize object and omit non-serializable attributes."""
        # This is for pickling to work despite the thrift hive client not
        # being picklable
        state = dict(self.__dict__)
        del state["metastore"]
        return state

    def __setstate__(self, d: dict[str, Any]) -> None:
        """Deserialize object and restore non-serializable attributes."""
        self.__dict__.update(d)
        self.__dict__["metastore"] = self.get_metastore_client()

    def get_metastore_client(self) -> Any:
        """Return a Hive thrift client."""
        import hmsclient
        from thrift.protocol import TBinaryProtocol
        from thrift.transport import TSocket, TTransport

        host = self._find_valid_host()
        conn = self.conn

        if not host:
            raise AirflowException("Failed to locate the valid server.")

        auth_mechanism = conn.extra_dejson.get("auth_mechanism", "NOSASL")

        if conf.get("core", "security") == "kerberos":
            auth_mechanism = conn.extra_dejson.get("auth_mechanism", "GSSAPI")
            kerberos_service_name = conn.extra_dejson.get("kerberos_service_name", "hive")

        conn_socket = TSocket.TSocket(host, conn.port)

        if conf.get("core", "security") == "kerberos" and auth_mechanism == "GSSAPI":
            try:
                import saslwrapper as sasl
            except ImportError:
                import sasl

            def sasl_factory() -> sasl.Client:
                sasl_client = sasl.Client()
                sasl_client.setAttr("host", host)
                sasl_client.setAttr("service", kerberos_service_name)
                sasl_client.init()
                return sasl_client

            from thrift_sasl import TSaslClientTransport

            transport = TSaslClientTransport(sasl_factory, "GSSAPI", conn_socket)
        else:
            transport = TTransport.TBufferedTransport(conn_socket)

        protocol = TBinaryProtocol.TBinaryProtocol(transport)

        return hmsclient.HMSClient(iprot=protocol)

    def _find_valid_host(self) -> Any:
        conn = self.conn
        hosts = conn.host.split(",")
        for host in hosts:
            host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.log.info("Trying to connect to %s:%s", host, conn.port)
            if host_socket.connect_ex((host, conn.port)) == 0:
                self.log.info("Connected to %s:%s", host, conn.port)
                host_socket.close()
                return host
            else:
                self.log.error("Could not connect to %s:%s", host, conn.port)
        return None

    def check_for_partition(self, schema: str, table: str, partition: str) -> bool:
        """
        Check whether a partition exists.

        :param schema: Name of hive schema (database) @table belongs to
        :param table: Name of hive table @partition belongs to
        :param partition: Expression that matches the partitions to check for (e.g. `a = 'b' AND c = 'd'`)

        >>> hh = HiveMetastoreHook()
        >>> t = "static_babynames_partitioned"
        >>> hh.check_for_partition("airflow", t, "ds='2015-01-01'")
        True
        """
        with self.metastore as client:
            partitions = client.get_partitions_by_filter(
                schema, table, partition, HiveMetastoreHook.MAX_PART_COUNT
            )

        return bool(partitions)

    def check_for_named_partition(self, schema: str, table: str, partition_name: str) -> Any:
        """
        Check whether a partition with a given name exists.

        :param schema: Name of hive schema (database) @table belongs to
        :param table: Name of hive table @partition belongs to
        :param partition_name: Name of the partitions to check for (eg `a=b/c=d`)

        >>> hh = HiveMetastoreHook()
        >>> t = "static_babynames_partitioned"
        >>> hh.check_for_named_partition("airflow", t, "ds=2015-01-01")
        True
        >>> hh.check_for_named_partition("airflow", t, "ds=xxx")
        False
        """
        with self.metastore as client:
            return client.check_for_named_partition(schema, table, partition_name)

    def get_table(self, table_name: str, db: str = "default") -> Any:
        """
        Get a metastore table object.

        >>> hh = HiveMetastoreHook()
        >>> t = hh.get_table(db="airflow", table_name="static_babynames")
        >>> t.tableName
        'static_babynames'
        >>> [col.name for col in t.sd.cols]
        ['state', 'year', 'name', 'gender', 'num']
        """
        if db == "default" and "." in table_name:
            db, table_name = table_name.split(".")[:2]
        with self.metastore as client:
            return client.get_table(dbname=db, tbl_name=table_name)

    def get_tables(self, db: str, pattern: str = "*") -> Any:
        """Get metastore table objects for tables matching the pattern."""
        with self.metastore as client:
            tables = client.get_tables(db_name=db, pattern=pattern)
            return client.get_table_objects_by_name(db, tables)

    def get_databases(self, pattern: str = "*") -> Any:
        """Get metastore databases that match the pattern."""
        with self.metastore as client:
            return client.get_databases(pattern)

    def get_partitions(self, schema: str, table_name: str, partition_filter: str | None = None) -> list[Any]:
        """
        Return a list of all partitions in a table.

        Works only for tables with less than 32767 partitions (java short max val).
        For subpartitioned tables, the number might easily exceed this.

        >>> hh = HiveMetastoreHook()
        >>> t = "static_babynames_partitioned"
        >>> parts = hh.get_partitions(schema="airflow", table_name=t)
        >>> len(parts)
        1
        >>> parts
        [{'ds': '2015-01-01'}]
        """
        with self.metastore as client:
            table = client.get_table(dbname=schema, tbl_name=table_name)
            if table.partitionKeys:
                if partition_filter:
                    parts = client.get_partitions_by_filter(
                        db_name=schema,
                        tbl_name=table_name,
                        filter=partition_filter,
                        max_parts=HiveMetastoreHook.MAX_PART_COUNT,
                    )
                else:
                    parts = client.get_partitions(
                        db_name=schema, tbl_name=table_name, max_parts=HiveMetastoreHook.MAX_PART_COUNT
                    )

                pnames = [p.name for p in table.partitionKeys]
                return [dict(zip(pnames, p.values)) for p in parts]
            else:
                raise AirflowException("The table isn't partitioned")

    @staticmethod
    def _get_max_partition_from_part_specs(
        part_specs: list[Any], partition_key: str | None, filter_map: dict[str, Any] | None
    ) -> Any:
        """
        Get max partition of partitions with partition_key from part specs.

        key:value pair in filter_map will be used to filter out partitions.

        :param part_specs: list of partition specs.
        :param partition_key: partition key name.
        :param filter_map: partition_key:partition_value map used for partition filtering,
            e.g. {'key1': 'value1', 'key2': 'value2'}.
            Only partitions matching all partition_key:partition_value
            pairs will be considered as candidates of max partition.
        :return: Max partition or None if part_specs is empty.
        """
        if not part_specs:
            return None

        # Assuming all specs have the same keys.
        if partition_key not in part_specs[0].keys():
            raise AirflowException(f"Provided partition_key {partition_key} is not in part_specs.")
        if filter_map and not set(filter_map).issubset(part_specs[0]):
            raise AirflowException(
                f"Keys in provided filter_map {', '.join(filter_map.keys())} "
                f"are not subset of part_spec keys: {', '.join(part_specs[0].keys())}"
            )

        return max(
            (
                p_dict[partition_key]
                for p_dict in part_specs
                if filter_map is None or all(item in p_dict.items() for item in filter_map.items())
            ),
            default=None,
        )

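    # Illustrative sketch of the filtering behaviour (the partition values below are
    # made up, not taken from the upstream module): only specs matching every
    # filter_map pair are candidates, and the lexicographic max wins.
    #
    #     part_specs = [{"ds": "2015-01-01", "hour": "00"}, {"ds": "2015-01-02", "hour": "00"}]
    #     HiveMetastoreHook._get_max_partition_from_part_specs(part_specs, "ds", {"hour": "00"})
    #     # -> "2015-01-02"
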
    def max_partition(
        self,
        schema: str,
        table_name: str,
        field: str | None = None,
        filter_map: dict[Any, Any] | None = None,
    ) -> Any:
        """
        Return the maximum value for all partitions with given field in a table.

        If only one partition key exists in the table, the key will be used as field.
        filter_map should be a partition_key:partition_value map and will be used to
        filter out partitions.

        :param schema: schema name.
        :param table_name: table name.
        :param field: partition key to get max partition from.
        :param filter_map: partition_key:partition_value map used for partition filtering.

        >>> hh = HiveMetastoreHook()
        >>> filter_map = {'ds': '2015-01-01'}
        >>> t = 'static_babynames_partitioned'
        >>> hh.max_partition(schema='airflow',\
        ... table_name=t, field='ds', filter_map=filter_map)
        '2015-01-01'
        """
        with self.metastore as client:
            table = client.get_table(dbname=schema, tbl_name=table_name)
            key_name_set = {key.name for key in table.partitionKeys}
            if len(table.partitionKeys) == 1:
                field = table.partitionKeys[0].name
            elif not field:
                raise AirflowException("Please specify the field you want the max value for.")
            elif field not in key_name_set:
                raise AirflowException("Provided field is not a partition key.")

            if filter_map and not set(filter_map.keys()).issubset(key_name_set):
                raise AirflowException("Provided filter_map contains keys that are not partition key.")

            part_names = client.get_partition_names(
                schema, table_name, max_parts=HiveMetastoreHook.MAX_PART_COUNT
            )
            part_specs = [client.partition_name_to_spec(part_name) for part_name in part_names]

        return HiveMetastoreHook._get_max_partition_from_part_specs(part_specs, field, filter_map)

    def drop_partitions(self, table_name, part_vals, delete_data=False, db="default"):
        """
        Drop partitions from the given table matching the part_vals input.

        :param table_name: table name.
        :param part_vals: list of partition specs.
        :param delete_data: Setting to control if underlying data has to be deleted
            in addition to dropping partitions.
        :param db: Name of hive schema (database) @table belongs to

        >>> hh = HiveMetastoreHook()
        >>> hh.drop_partitions(db='airflow', table_name='static_babynames',
        ... part_vals="['2020-05-01']")
        True
        """
        if self.table_exists(table_name, db):
            with self.metastore as client:
                self.log.info(
                    "Dropping partition of table %s.%s matching the spec: %s", db, table_name, part_vals
                )
                return client.drop_partition(db, table_name, part_vals, delete_data)
        else:
            self.log.info("Table %s.%s does not exist!", db, table_name)
            return False

class HiveServer2Hook(DbApiHook):
    """
    Wrapper around the pyhive library.

    Notes:
    * the default auth_mechanism is PLAIN, to override it you
      can specify it in the ``extra`` of your connection in the UI
    * the default for run_set_variable_statements is true, if you
      are using impala you may need to set it to false in the
      ``extra`` of your connection in the UI

    :param hiveserver2_conn_id: Reference to the
        :ref: `Hive Server2 thrift service connection id <howto/connection:hiveserver2>`.
    :param schema: Hive database name.
    """

    def get_conn(self, schema: str | None = None) -> Any:
        """Return a Hive connection object."""
        username: str | None = None
        password: str | None = None

        db = self.get_connection(self.hiveserver2_conn_id)  # type: ignore

        auth_mechanism = db.extra_dejson.get("auth_mechanism", "NONE")
        if auth_mechanism == "NONE" and db.login is None:
            # we need to give a username
            username = "airflow"
        kerberos_service_name = None
        if conf.get("core", "security") == "kerberos":
            auth_mechanism = db.extra_dejson.get("auth_mechanism", "KERBEROS")
            kerberos_service_name = db.extra_dejson.get("kerberos_service_name", "hive")

        # Password should be set if and only if in LDAP or CUSTOM mode
        if auth_mechanism in ("LDAP", "CUSTOM"):
            password = db.password

        from pyhive.hive import connect

        return connect(
            host=db.host,
            port=db.port,
            auth=auth_mechanism,
            kerberos_service_name=kerberos_service_name,
            username=db.login or username,
            password=password,
            database=schema or db.schema or "default",
        )

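    # Illustrative connection sketch (not part of the upstream module): a HiveServer2
    # connection whose extras switch auth_mechanism to LDAP; conn_id, host, and
    # credentials are placeholders. With LDAP or CUSTOM auth the stored password is
    # forwarded to pyhive.hive.connect.
    #
    #     from airflow.models import Connection
    #
    #     Connection(
    #         conn_id="hiveserver2_ldap",
    #         conn_type="hiveserver2",
    #         host="hs2.example.com",
    #         port=10000,
    #         login="analyst",
    #         password="***",
    #         extra='{"auth_mechanism": "LDAP"}',
    #     )
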
    def _get_results(
        self,
        sql: str | list[str],
        schema: str = "default",
        fetch_size: int | None = None,
        hive_conf: Iterable | Mapping | None = None,
    ) -> Any:
        from pyhive.exc import ProgrammingError

        if isinstance(sql, str):
            sql = [sql]
        previous_description = None
        with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing(conn.cursor()) as cur:
            cur.arraysize = fetch_size or 1000

            db = self.get_connection(self.hiveserver2_conn_id)  # type: ignore
            # Not all query services (e.g. impala) support the set command
            if db.extra_dejson.get("run_set_variable_statements", True):
                env_context = get_context_from_env_var()
                if hive_conf:
                    env_context.update(hive_conf)
                for k, v in env_context.items():
                    cur.execute(f"set {k}={v}")

            for statement in sql:
                cur.execute(statement)
                # we only get results of statements that return results
                lowered_statement = statement.lower().strip()
                if lowered_statement.startswith(("select", "with", "show")) or (
                    lowered_statement.startswith("set") and "=" not in lowered_statement
                ):
                    description = cur.description
                    if previous_description and previous_description != description:
                        message = f"""The statements are producing different descriptions:
                                      Current: {description!r}
                                      Previous: {previous_description!r}"""
                        raise ValueError(message)
                    elif not previous_description:
                        previous_description = description
                    yield description
                    try:
                        # DB API 2 raises when no results are returned
                        # we're silencing here as some statements in the list
                        # may be `SET` or DDL
                        yield from cur
                    except ProgrammingError:
                        self.log.debug("get_results returned no records")

    def get_results(
        self,
        sql: str | list[str],
        schema: str = "default",
        fetch_size: int | None = None,
        hive_conf: Iterable | Mapping | None = None,
    ) -> dict[str, Any]:
        """
        Get results of the provided hql in target schema.

        :param sql: hql to be executed.
        :param schema: target schema, default to 'default'.
        :param fetch_size: max size of result to fetch.
        :param hive_conf: hive_conf to execute along with the hql.
        :return: results of hql execution, dict with data (list of results) and header
        """
        results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
        header = next(results_iter)
        results = {"data": list(results_iter), "header": header}
        return results

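    # Illustrative sketch of the returned structure (the query, column types, and row
    # values below are made up): "header" is the cursor description yielded first by
    # _get_results, "data" is the list of fetched rows.
    #
    #     hook = HiveServer2Hook()
    #     res = hook.get_results("SELECT name, num FROM airflow.static_babynames LIMIT 2")
    #     res["header"]  # e.g. [("name", "STRING_TYPE", ...), ("num", "INT_TYPE", ...)]
    #     res["data"]    # e.g. [("Mary", 14), ("Anna", 12)]
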
    def to_csv(
        self,
        sql: str,
        csv_filepath: str,
        schema: str = "default",
        delimiter: str = ",",
        lineterminator: str = "\r\n",
        output_header: bool = True,
        fetch_size: int = 1000,
        hive_conf: dict[Any, Any] | None = None,
    ) -> None:
        """
        Execute hql in target schema and write results to a csv file.

        :param sql: hql to be executed.
        :param csv_filepath: filepath of csv to write results into.
        :param schema: target schema, default to 'default'.
        :param delimiter: delimiter of the csv file, default to ','.
        :param lineterminator: lineterminator of the csv file.
        :param output_header: header of the csv file, default to True.
        :param fetch_size: number of result rows to write into the csv file, default to 1000.
        :param hive_conf: hive_conf to execute along with the hql.
        """
        results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
        header = next(results_iter)
        message = None

        i = 0
        with open(csv_filepath, "w", encoding="utf-8") as file:
            writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator)
            try:
                if output_header:
                    self.log.debug("Cursor description is %s", header)
                    writer.writerow([c[0] for c in header])

                for i, row in enumerate(results_iter, 1):
                    writer.writerow(row)
                    if i % fetch_size == 0:
                        self.log.info("Written %s rows so far.", i)
            except ValueError as exception:
                message = str(exception)

        if message:
            # need to clean up the file first
            os.remove(csv_filepath)
            raise ValueError(message)

        self.log.info("Done. Loaded a total of %s rows.", i)

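    # Usage sketch (not part of the upstream module): the query and file path are
    # placeholders; progress is logged every fetch_size rows.
    #
    #     HiveServer2Hook().to_csv(
    #         sql="SELECT * FROM airflow.static_babynames LIMIT 1000",
    #         csv_filepath="/tmp/babynames.csv",
    #         output_header=True,
    #         fetch_size=500,
    #     )
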
    def get_records(
        self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
    ) -> Any:
        """
        Get a set of records from a Hive query; optionally pass a 'schema' kwarg to specify target schema.

        :param sql: hql to be executed.
        :param parameters: optional configuration passed to get_results
        :return: result of hive execution

        >>> hh = HiveServer2Hook()
        >>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
        >>> len(hh.get_records(sql))
        100
        """
        schema = kwargs["schema"] if "schema" in kwargs else "default"
        return self.get_results(sql, schema=schema, hive_conf=parameters)["data"]

    def get_pandas_df(  # type: ignore
        self,
        sql: str,
        schema: str = "default",
        hive_conf: dict[Any, Any] | None = None,
        **kwargs,
    ) -> pd.DataFrame:
        """
        Get a pandas dataframe from a Hive query.

        :param sql: hql to be executed.
        :param schema: target schema, default to 'default'.
        :param hive_conf: hive_conf to execute along with the hql.
        :param kwargs: (optional) passed into pandas.DataFrame constructor
        :return: pandas.DataFrame with the result of the hive execution

        >>> hh = HiveServer2Hook()
        >>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
        >>> df = hh.get_pandas_df(sql)
        >>> len(df.index)
        100
        """
        try:
            import pandas as pd
        except ImportError as e:
            from airflow.exceptions import AirflowOptionalProviderFeatureException

            raise AirflowOptionalProviderFeatureException(e)

        res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
        df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
        return df