# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from builtins import bytes
from typing import Any
from sqlalchemy import Column, Integer, String, Text, Boolean
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import synonym
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet, InvalidFernetToken
from airflow.utils.db import provide_session
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.secrets import get_variable
class Variable(Base, LoggingMixin):
    """
    ORM model for a key/value pair stored in the ``variable`` table.

    The value is written through :meth:`set_val`, which passes it through the
    helper returned by ``get_fernet()`` and records in ``is_encrypted`` whether
    that helper actually encrypted it. Read it back through :meth:`get_val`
    (or the ``val`` synonym) to get the decrypted value.
    """

    __tablename__ = "variable"
    # Private sentinel so that ``None`` can be passed as a real ``default_var``.
    __NO_DEFAULT_SENTINEL = object()

    id = Column(Integer, primary_key=True)
    key = Column(String(ID_LEN), unique=True)
    # Raw stored value (possibly ciphertext), mapped to the ``val`` column.
    _val = Column('val', Text)
    is_encrypted = Column(Boolean, unique=False, default=False)

    def __repr__(self):
        # NOTE(review): this prints the raw stored column value. It is
        # ciphertext only when ``is_encrypted`` is True; otherwise it is
        # presumably the plaintext value -- confirm that is acceptable
        # wherever reprs end up in logs.
        return '{} : {}'.format(self.key, self._val)

    def get_val(self):
        """
        Return the stored value, decrypting it when it is marked encrypted.

        :return: the decrypted string, the raw stored value when it is not
            encrypted (or is ``None``), or ``None`` if decryption fails.
            Failures are logged rather than raised.
        """
        if self._val is None or not self.is_encrypted:
            return self._val
        try:
            fernet = get_fernet()
            return fernet.decrypt(bytes(self._val, 'utf-8')).decode()
        except InvalidFernetToken:
            self.log.error("Can't decrypt _val for key=%s, invalid token "
                           "or value", self.key)
            return None
        except Exception:
            # NOTE(review): any other failure is reported as a missing
            # FERNET_KEY, which may hide the real cause.
            self.log.error("Can't decrypt _val for key=%s, FERNET_KEY "
                           "configuration missing", self.key)
            return None

    def set_val(self, value):
        """
        Store ``value``, passing it through the Fernet helper first.

        :param value: string to store, or ``None`` to clear the value.
        """
        if value is None:
            self._val = None
            self.is_encrypted = False
            return
        fernet = get_fernet()
        self._val = fernet.encrypt(bytes(value, 'utf-8')).decode()
        # The helper reports whether encrypt() really encrypted the value.
        self.is_encrypted = fernet.is_encrypted

    @declared_attr
    def val(cls):
        # Expose ``val`` as a synonym of the ``_val`` column; Python-level
        # reads and writes go through get_val / set_val.
        return synonym('_val',
                       descriptor=property(cls.get_val, cls.set_val))

    @classmethod
    def setdefault(cls, key, default, deserialize_json=False):
        """
        Like a Python builtin dict object, setdefault returns the current value
        for a key, and if it isn't there, stores the default value and returns it.

        :param key: Dict key for this Variable
        :type key: str
        :param default: Default value to set and return if the variable
            isn't already in the DB
        :type default: Mixed
        :param deserialize_json: Store this as a JSON encoded value in the DB
            and un-encode it when retrieving a value
        :return: Mixed
        :raises ValueError: if the variable is absent and ``default`` is None
        """
        obj = cls.get(key, default_var=None,
                      deserialize_json=deserialize_json)
        if obj is not None:
            return obj
        if default is None:
            raise ValueError('Default Value must be set')
        cls.set(key, default, serialize_json=deserialize_json)
        return default

    @classmethod
    def get(
        cls,
        key,  # type: str
        default_var=__NO_DEFAULT_SENTINEL,  # type: Any
        deserialize_json=False,  # type: bool
        session=None
    ):
        """
        Look up a variable's value through ``airflow.secrets.get_variable``.

        :param key: variable key
        :param default_var: value returned when the key is not found; if
            omitted, a missing key raises ``KeyError``
        :param deserialize_json: parse the found value with ``json.loads``
            (the default value is returned as-is)
        :param session: accepted for API compatibility; not used here
        :raises KeyError: if the key is missing and no default was given
        """
        var_val = get_variable(key=key)
        if var_val is None:
            if default_var is cls.__NO_DEFAULT_SENTINEL:
                raise KeyError('Variable {} does not exist'.format(key))
            return default_var
        if deserialize_json:
            return json.loads(var_val)
        return var_val

    @classmethod
    @provide_session
    def set(
        cls,
        key,  # type: str
        value,  # type: Any
        serialize_json=False,  # type: bool
        session=None
    ):
        """
        Store ``value`` under ``key``, replacing any existing row.

        :param key: variable key
        :param value: value to store; converted with ``str()`` unless
            ``serialize_json`` is set, in which case it is JSON-encoded
        :param serialize_json: JSON-encode ``value`` before storing
        :param session: SQLAlchemy session (supplied by ``provide_session``
            when omitted)
        """
        if serialize_json:
            stored_value = json.dumps(value, indent=2, separators=(',', ': '))
        else:
            stored_value = str(value)
        # Delete-then-insert so an existing key is overwritten; the flush
        # emits both statements within the caller's session.
        cls.delete(key, session=session)
        session.add(cls(key=key, val=stored_value))  # type: ignore
        session.flush()

    @classmethod
    @provide_session
    def delete(cls, key, session=None):
        """
        Delete the row(s) whose key equals ``key``.

        :param key: variable key
        :param session: SQLAlchemy session (supplied by ``provide_session``
            when omitted)
        """
        session.query(cls).filter(cls.key == key).delete()

    def rotate_fernet_key(self):
        """
        Re-encrypt the stored value using ``fernet.rotate``.

        Only values that are present and marked encrypted are touched, so the
        Fernet helper is fetched only when there is something to rotate.
        Presumably ``rotate`` re-encrypts under the primary key of a
        multi-key Fernet -- confirm against ``get_fernet``.
        """
        if self._val and self.is_encrypted:
            fernet = get_fernet()
            self._val = fernet.rotate(self._val.encode('utf-8')).decode()