Source code for airflow.providers.apache.hive.operators.hive_stats

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.presto.hooks.presto import PrestoHook

if TYPE_CHECKING:
    from airflow.utils.context import Context


class HiveStatsCollectionOperator(BaseOperator):
    """
    Gathers partition statistics using a dynamically generated Presto
    query and inserts the stats into a MySQL table with the format
    below. Stats overwrite themselves if you rerun the same
    date/partition. ::

        CREATE TABLE hive_stats (
            ds VARCHAR(16),
            dttm VARCHAR(64),
            table_name VARCHAR(500),
            partition_repr TEXT,
            col VARCHAR(500),
            metric VARCHAR(200),
            value BIGINT
        );

    :param metastore_conn_id: Reference to the
        :ref:`Hive Metastore connection id <howto/connection:hive_metastore>`.
    :type metastore_conn_id: str
    :param table: the source table, in the format ``database.table_name``. (templated)
    :type table: str
    :param partition: the source partition. (templated)
    :type partition: dict of {col:value}
    :param extra_exprs: dict of expressions to run against the table where
        keys are metric names and values are Presto-compatible expressions
    :type extra_exprs: dict
    :param excluded_columns: list of columns to exclude; consider
        excluding blobs, large json columns, ...
    :type excluded_columns: list
    :param assignment_func: a function that receives a column name and
        a type, and returns a dict of metric names and Presto expressions.
        If None is returned, the global defaults are applied. If an
        empty dictionary is returned, no stats are computed for that
        column.
    :type assignment_func: function
    """
    template_fields: Sequence[str] = ('table', 'partition', 'ds', 'dttm')
    ui_color = '#aff7a6'

    def __init__(
        self,
        *,
        table: str,
        partition: Any,
        extra_exprs: Optional[Dict[str, Any]] = None,
        excluded_columns: Optional[List[str]] = None,
        assignment_func: Optional[Callable[[str, str], Optional[Dict[Any, Any]]]] = None,
        metastore_conn_id: str = 'metastore_default',
        presto_conn_id: str = 'presto_default',
        mysql_conn_id: str = 'airflow_db',
        **kwargs: Any,
    ) -> None:
        if 'col_blacklist' in kwargs:
            warnings.warn(
                'col_blacklist kwarg passed to {c} (task_id: {t}) is deprecated, please rename it to '
                'excluded_columns instead'.format(c=self.__class__.__name__, t=kwargs.get('task_id')),
                category=FutureWarning,
                stacklevel=2,
            )
            excluded_columns = kwargs.pop('col_blacklist')
        super().__init__(**kwargs)
        self.table = table
        self.partition = partition
        self.extra_exprs = extra_exprs or {}
        self.excluded_columns = excluded_columns or []  # type: List[str]
        self.metastore_conn_id = metastore_conn_id
        self.presto_conn_id = presto_conn_id
        self.mysql_conn_id = mysql_conn_id
        self.assignment_func = assignment_func
        self.ds = '{{ ds }}'
        self.dttm = '{{ execution_date.isoformat() }}'

    def get_default_exprs(self, col: str, col_type: str) -> Dict[Any, Any]:
        """Get default expressions."""
        if col in self.excluded_columns:
            return {}
        exp = {(col, 'non_null'): f"COUNT({col})"}
        if col_type in ['double', 'int', 'bigint', 'float']:
            exp[(col, 'sum')] = f'SUM({col})'
            exp[(col, 'min')] = f'MIN({col})'
            exp[(col, 'max')] = f'MAX({col})'
            exp[(col, 'avg')] = f'AVG({col})'
        elif col_type == 'boolean':
            exp[(col, 'true')] = f'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)'
            exp[(col, 'false')] = f'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)'
        elif col_type in ['string']:
            exp[(col, 'len')] = f'SUM(CAST(LENGTH({col}) AS BIGINT))'
            exp[(col, 'approx_distinct')] = f'APPROX_DISTINCT({col})'
        return exp

    def execute(self, context: "Context") -> None:
        metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
        table = metastore.get_table(table_name=self.table)
        field_types = {col.name: col.type for col in table.sd.cols}

        # Build the metric expressions: a global row count, then per-column
        # expressions from assignment_func (falling back to the defaults when
        # it returns None) or from get_default_exprs.
        exprs: Any = {('', 'count'): 'COUNT(*)'}
        for col, col_type in list(field_types.items()):
            if self.assignment_func:
                assign_exprs = self.assignment_func(col, col_type)
                if assign_exprs is None:
                    assign_exprs = self.get_default_exprs(col, col_type)
            else:
                assign_exprs = self.get_default_exprs(col, col_type)
            exprs.update(assign_exprs)
        exprs.update(self.extra_exprs)
        exprs = OrderedDict(exprs)
        exprs_str = ",\n        ".join(v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items())

        where_clause_ = [f"{k} = '{v}'" for k, v in self.partition.items()]
        where_clause = " AND\n        ".join(where_clause_)
        sql = f"SELECT {exprs_str} FROM {self.table} WHERE {where_clause};"

        presto = PrestoHook(presto_conn_id=self.presto_conn_id)
        self.log.info('Executing SQL check: %s', sql)
        row = presto.get_first(hql=sql)
        self.log.info("Record: %s", row)
        if not row:
            raise AirflowException("The query returned None")

        part_json = json.dumps(self.partition, sort_keys=True)

        # Make the task idempotent: drop stats written by a previous run for
        # the same table/partition/execution date before re-inserting.
        self.log.info("Deleting rows from previous runs if they exist")
        mysql = MySqlHook(self.mysql_conn_id)
        sql = f"""
        SELECT 1 FROM hive_stats
        WHERE
            table_name='{self.table}' AND
            partition_repr='{part_json}' AND
            dttm='{self.dttm}'
        LIMIT 1;
        """
        if mysql.get_records(sql):
            sql = f"""
            DELETE FROM hive_stats
            WHERE
                table_name='{self.table}' AND
                partition_repr='{part_json}' AND
                dttm='{self.dttm}';
            """
            mysql.run(sql)

        # Pivot the single result row into one (col, metric, value) record
        # per expression and load them into the stats table.
        self.log.info("Pivoting and loading cells into the Airflow db")
        rows = [
            (self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1]) for r in zip(exprs, row)
        ]
        mysql.insert_rows(
            table='hive_stats',
            rows=rows,
            target_fields=[
                'ds',
                'dttm',
                'table_name',
                'partition_repr',
                'col',
                'metric',
                'value',
            ],
        )
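
A minimal usage sketch, assuming a working Airflow deployment with the Hive metastore, Presto, and MySQL connections configured and a hive_stats table created as in the docstring above. The DAG id, source table, partition spec, and the only_string_stats helper are illustrative assumptions; the operator, its parameters, and the connection-id defaults come from the source above.

# Hypothetical example DAG; not part of the module above.
from datetime import datetime

from airflow import DAG
from airflow.providers.apache.hive.operators.hive_stats import HiveStatsCollectionOperator


def only_string_stats(col, col_type):
    # Hypothetical assignment_func: keep only the approximate distinct
    # count for string columns, and fall back to the operator defaults
    # (by returning None) for every other column type.
    if col_type == 'string':
        return {(col, 'approx_distinct'): f'APPROX_DISTINCT({col})'}
    return None


with DAG(
    dag_id='hive_stats_example',  # hypothetical DAG id
    start_date=datetime(2021, 1, 1),
    schedule_interval='@daily',
    catchup=False,
) as dag:
    collect_stats = HiveStatsCollectionOperator(
        task_id='gather_partition_stats',
        table='default.mytable',  # hypothetical database.table_name
        partition={'ds': '{{ ds }}'},  # {col: value} spec for the source partition
        assignment_func=only_string_stats,
        metastore_conn_id='metastore_default',
        presto_conn_id='presto_default',
        mysql_conn_id='airflow_db',
    )

Because table and partition are templated fields, the partition value above is rendered per run, so each daily run collects stats for its own partition and overwrites any stats from a rerun of the same date.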
