#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Optional, Union
from airflow.exceptions import AirflowException
from airflow.operators.sql import BaseSQLOperator
def parse_boolean(val: str) -> bool:
    """Try to parse a string into boolean.

    Matching is case-insensitive: ``y``, ``yes``, ``t``, ``true``, ``on`` and ``1``
    map to ``True``; ``n``, ``no``, ``f``, ``false``, ``off`` and ``0`` map to ``False``.

    :param val: the string to interpret.
    :return: the parsed boolean (the function never returns a string).
    :raises ValueError: if the input is not a valid true- or false-like string value.
    """
    val = val.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if val in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError(f"{val!r} is not a boolean-like string value")
def _get_failed_tests(checks):
return [
f"\tCheck: {check},\n\tCheck Values: {check_values}\n"
for check, check_values in checks.items()
if not check_values["success"]
]
class SQLColumnCheckOperator(BaseSQLOperator):
    """
    Performs one or more of the templated checks in the column_checks dictionary.

    Checks are performed on a per-column basis specified by the column_mapping.
    Each check can take one or more of the following options:

    - equal_to: an exact value to equal, cannot be used with other comparison options
    - greater_than: value that result should be strictly greater than
    - less_than: value that results should be strictly less than
    - geq_to: value that results should be greater than or equal to
    - leq_to: value that results should be less than or equal to
    - tolerance: the fraction (e.g. ``0.1`` for 10%) of the expected value's magnitude
      by which the result may deviate; it always loosens the check's bound(s)

    A check whose query result is NULL (for example ``MIN`` over an empty or
    all-NULL column) is reported as failed.

    :param table: the table to run checks on
    :param column_mapping: the dictionary of columns and their associated checks, e.g.

    .. code-block:: python

        {
            "col_name": {
                "null_check": {
                    "equal_to": 0,
                },
                "min": {
                    "greater_than": 5,
                    "leq_to": 10,
                    "tolerance": 0.2,
                },
                "max": {"less_than": 1000, "geq_to": 10, "tolerance": 0.01},
            }
        }

    :param conn_id: the connection ID used to connect to the database
    :param database: name of database which overwrite the defined one in connection

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SQLColumnCheckOperator`
    """

    # SQL fragment per supported check; the literal word "column" is substituted
    # with the real column name in execute().
    column_checks = {
        "null_check": "SUM(CASE WHEN column IS NULL THEN 1 ELSE 0 END) AS column_null_check",
        "distinct_check": "COUNT(DISTINCT(column)) AS column_distinct_check",
        "unique_check": "COUNT(column) - COUNT(DISTINCT(column)) AS column_unique_check",
        "min": "MIN(column) AS column_min",
        "max": "MAX(column) AS column_max",
    }

    def __init__(
        self,
        *,
        table: str,
        column_mapping: Dict[str, Dict[str, Any]],
        conn_id: Optional[str] = None,
        database: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(conn_id=conn_id, database=database, **kwargs)
        # Validate up front so a misconfigured mapping fails when the operator is
        # constructed rather than when it runs.
        for checks in column_mapping.values():
            for check, check_values in checks.items():
                self._column_mapping_validation(check, check_values)

        self.table = table
        self.column_mapping = column_mapping
        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
        self.sql = f"SELECT * FROM {self.table};"

    def execute(self, context=None):
        """Run one aggregate query per column and compare each result with its check.

        Each check's values dict in ``column_mapping`` is updated in place with
        ``result`` and ``success`` keys.

        :raises AirflowException: if a query returns no row, or if any check fails.
        """
        hook = self.get_db_hook()
        failed_tests = []
        for column, checks in self.column_mapping.items():
            check_names = list(checks)
            checks_sql = ",".join(self.column_checks[check].replace("column", column) for check in check_names)
            self.sql = f"SELECT {checks_sql} FROM {self.table};"
            records = hook.get_first(self.sql)

            if not records:
                raise AirflowException(f"The following query returned zero rows: {self.sql}")

            self.log.info("Record: %s", records)

            # The query selects one aggregate per check, in check_names order.
            for check, result in zip(check_names, records):
                check_values = checks[check]
                check_values["result"] = result
                check_values["success"] = self._get_match(check_values, result, check_values.get("tolerance"))

            column_failures = _get_failed_tests(checks)
            if column_failures:
                # Keep each column's own query/results next to its failures; the
                # query and records variables are overwritten by the next column.
                failed_tests.append(
                    f"Column: {column}\nQuery:\n{self.sql}\nResults:\n{records!s}\n{''.join(column_failures)}"
                )

        if failed_tests:
            raise AirflowException("Test failed.\nThe following tests have failed:\n" + "\n".join(failed_tests))

        self.log.info("All tests have passed")

    @staticmethod
    def _tolerance_bounds(value, tolerance):
        """Return ``(lower, upper)`` bounds around ``value`` widened by ``tolerance``.

        Without a tolerance both bounds are ``value`` itself (no arithmetic is done,
        so non-numeric values such as dates still compare directly). With one, the
        bounds are ``value * (1 - tolerance)`` and ``value * (1 + tolerance)``,
        ordered so that ``lower <= upper``. For negative values the two multipliers
        swap roles, which keeps the tolerance loosening the check instead of
        tightening it (or, for ``equal_to``, producing an empty range).
        """
        if tolerance is None:
            return value, value
        scaled = (value * (1 - tolerance), value * (1 + tolerance))
        return min(scaled), max(scaled)

    def _get_match(self, check_values, record, tolerance=None) -> bool:
        """Return whether ``record`` satisfies every comparison in ``check_values``.

        :param check_values: the check's options (``greater_than``/``geq_to``,
            ``less_than``/``leq_to`` and/or ``equal_to``).
        :param record: the aggregate value returned by the database.
        :param tolerance: optional fraction by which the bounds are widened.
        :return: ``False`` for a NULL result, otherwise the combined comparison.
        """
        if record is None:
            # A NULL aggregate cannot satisfy a comparison; treat it as a failed
            # check instead of letting ``None >= x`` raise TypeError.
            return False

        match_boolean = True
        if "geq_to" in check_values:
            lower, _ = self._tolerance_bounds(check_values["geq_to"], tolerance)
            match_boolean = record >= lower
        elif "greater_than" in check_values:
            lower, _ = self._tolerance_bounds(check_values["greater_than"], tolerance)
            match_boolean = record > lower

        if "leq_to" in check_values:
            _, upper = self._tolerance_bounds(check_values["leq_to"], tolerance)
            match_boolean = record <= upper and match_boolean
        elif "less_than" in check_values:
            _, upper = self._tolerance_bounds(check_values["less_than"], tolerance)
            match_boolean = record < upper and match_boolean

        if "equal_to" in check_values:
            expected = check_values["equal_to"]
            if tolerance is not None:
                lower, upper = self._tolerance_bounds(expected, tolerance)
                match_boolean = lower <= record <= upper and match_boolean
            else:
                match_boolean = record == expected and match_boolean
        return match_boolean

    def _column_mapping_validation(self, check, check_values):
        """Reject unknown check names and inconsistent comparison options.

        :raises AirflowException: if ``check`` is not a key of ``column_checks``.
        :raises ValueError: if no comparison is given, the bounds overlap, two
            options for the same side are combined, or ``equal_to`` is mixed with a
            range comparison.
        """
        if check not in self.column_checks:
            raise AirflowException(f"Invalid column check: {check}.")
        if (
            "greater_than" not in check_values
            and "geq_to" not in check_values
            and "less_than" not in check_values
            and "leq_to" not in check_values
            and "equal_to" not in check_values
        ):
            raise ValueError(
                "Please provide one or more of: less_than, leq_to, "
                "greater_than, geq_to, or equal_to in the check's dict."
            )

        if "greater_than" in check_values and "less_than" in check_values:
            if check_values["greater_than"] >= check_values["less_than"]:
                raise ValueError(
                    "greater_than should be strictly less than "
                    "less_than. Use geq_to or leq_to for "
                    "overlapping equality."
                )

        if "greater_than" in check_values and "leq_to" in check_values:
            if check_values["greater_than"] >= check_values["leq_to"]:
                raise ValueError(
                    "greater_than must be strictly less than leq_to. "
                    "Use geq_to with leq_to for overlapping equality."
                )

        if "geq_to" in check_values and "less_than" in check_values:
            if check_values["geq_to"] >= check_values["less_than"]:
                raise ValueError(
                    "geq_to should be strictly less than less_than. "
                    "Use leq_to with geq_to for overlapping equality."
                )

        if "geq_to" in check_values and "leq_to" in check_values:
            if check_values["geq_to"] > check_values["leq_to"]:
                raise ValueError("geq_to should be less than or equal to leq_to.")

        if "greater_than" in check_values and "geq_to" in check_values:
            raise ValueError("Only supply one of greater_than or geq_to.")

        if "less_than" in check_values and "leq_to" in check_values:
            raise ValueError("Only supply one of less_than or leq_to.")

        if (
            "greater_than" in check_values
            or "geq_to" in check_values
            or "less_than" in check_values
            or "leq_to" in check_values
        ) and "equal_to" in check_values:
            raise ValueError(
                "equal_to cannot be passed with a greater or less than "
                "function. To specify 'greater than or equal to' or "
                "'less than or equal to', use geq_to or leq_to."
            )
class SQLTableCheckOperator(BaseSQLOperator):
    """
    Performs one or more of the checks provided in the checks dictionary.

    Checks should be written to return a boolean result. Each statement is wrapped
    in ``CASE WHEN <statement> THEN 1 ELSE 0 END`` and reduced with ``MIN``, so a
    check passes only when its statement is true for every row selected.

    :param table: the table to run checks on
    :param checks: the dictionary of checks, e.g.:

    .. code-block:: python

        {
            "row_count_check": {"check_statement": "COUNT(*) = 1000"},
            "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
        }

    :param conn_id: the connection ID used to connect to the database
    :param database: name of database which overwrite the defined one in connection

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SQLTableCheckOperator`
    """

    # Placeholders "check_statement" and "check_name" are substituted in execute().
    sql_check_template = "CASE WHEN check_statement THEN 1 ELSE 0 END AS check_name"
    sql_min_template = "MIN(check_name)"

    def __init__(
        self,
        *,
        table: str,
        checks: Dict[str, Dict[str, Any]],
        conn_id: Optional[str] = None,
        database: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(conn_id=conn_id, database=database, **kwargs)

        self.table = table
        self.checks = checks
        # OpenLineage needs a valid SQL query with the input/output table(s) to parse
        self.sql = f"SELECT * FROM {self.table};"

    def execute(self, context=None):
        """Evaluate all checks in a single query and fail if any of them is false.

        Each entry of ``checks`` is updated in place with a ``success`` key.

        :raises AirflowException: if the query returns no row or any check fails.
        :raises ValueError: if a result cannot be parsed as a boolean. NOTE(review):
            this presumably happens when the table has no rows and ``MIN`` yields
            NULL, which is rendered as ``"None"``.
        """
        hook = self.get_db_hook()
        check_names = list(self.checks)
        check_mins_sql = ",".join(
            self.sql_min_template.replace("check_name", check_name) for check_name in check_names
        )
        checks_sql = ",".join(
            self.sql_check_template.replace("check_statement", value["check_statement"]).replace(
                "check_name", check_name
            )
            for check_name, value in self.checks.items()
        )
        # MySQL, SQL Server and PostgreSQL < 16 require an alias on a derived table;
        # the alias is written without AS so Oracle accepts it as well.
        self.sql = f"SELECT {check_mins_sql} FROM (SELECT {checks_sql} FROM {self.table}) check_table;"
        records = hook.get_first(self.sql)

        if not records:
            raise AirflowException(f"The following query returned zero rows: {self.sql}")

        self.log.info("Record: %s", records)

        # The outer SELECT yields one MIN(...) per check, in check_names order, so
        # each check must take its own column rather than the last one.
        for check_name, result in zip(check_names, records):
            self.checks[check_name]["success"] = parse_boolean(str(result))

        failed_tests = _get_failed_tests(self.checks)
        if failed_tests:
            raise AirflowException(
                f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
                "The following tests have failed:"
                f"\n{''.join(failed_tests)}"
            )

        self.log.info("All tests have passed")