# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import time
from typing import TYPE_CHECKING, Any, Sequence
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
if TYPE_CHECKING:
    from airflow.utils.context import Context
[docs]class RedshiftCreateClusterOperator(BaseOperator):
    """Creates a new cluster with the specified parameters.
    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RedshiftCreateClusterOperator`
    :param cluster_identifier:  A unique identifier for the cluster.
    :param node_type: The node type to be provisioned for the cluster.
            Valid Values: ``ds2.xlarge``, ``ds2.8xlarge``, ``dc1.large``,
            ``dc1.8xlarge``, ``dc2.large``, ``dc2.8xlarge``, ``ra3.xlplus``,
            ``ra3.4xlarge``, and ``ra3.16xlarge``.
    :param master_username: The username associated with the admin user account for
        the cluster that is being created.
    :param master_user_password: The password associated with the admin user account for
        the cluster that is being created.
    :param cluster_type: The type of the cluster ``single-node`` or ``multi-node``.
        The default value is ``multi-node``.
    :param db_name: The name of the first database to be created when the cluster is created.
    :param number_of_nodes: The number of compute nodes in the cluster.
        This param require when ``cluster_type`` is ``multi-node``.
    :param cluster_security_groups: A list of security groups to be associated with this cluster.
    :param vpc_security_group_ids: A list of  VPC security groups to be associated with the cluster.
    :param cluster_subnet_group_name: The name of a cluster subnet group to be associated with this cluster.
    :param availability_zone: The EC2 Availability Zone (AZ).
    :param preferred_maintenance_window: The time range (in UTC) during which automated cluster
        maintenance can occur.
    :param cluster_parameter_group_name: The name of the parameter group to be associated with this cluster.
    :param automated_snapshot_retention_period: The number of days that automated snapshots are retained.
        The default value is ``1``.
    :param manual_snapshot_retention_period: The default number of days to retain a manual snapshot.
    :param port: The port number on which the cluster accepts incoming connections.
        The Default value is ``5439``.
    :param cluster_version: The version of a Redshift engine software that you want to deploy on the cluster.
    :param allow_version_upgrade: Whether major version upgrades can be applied during the maintenance window.
        The Default value is ``True``.
    :parma publicly_accessible: Whether cluster can be accessed from a public network.
    :parma encrypted: Whether data in the cluster is encrypted at rest.
        The default value is ``False``.
    :parma hsm_client_certificate_identifier: Name of the HSM client certificate
        the Amazon Redshift cluster uses to retrieve the data.
    :parma hsm_configuration_identifier: Name of the HSM configuration
    :parma elastic_ip: The Elastic IP (EIP) address for the cluster.
    :parma tags: A list of tag instances
    :parma kms_key_id: KMS key id of encryption key.
    :param enhanced_vpc_routing: Whether to create the cluster with enhanced VPC routing enabled
        Default value is ``False``.
    :param additional_info: Reserved
    :param iam_roles: A list of IAM roles that can be used by the cluster to access other AWS services.
    :param maintenance_track_name: Name of the maintenance track for the cluster.
    :param snapshot_schedule_identifier: A  unique identifier for the snapshot schedule.
    :param availability_zone_relocation: Enable relocation for a Redshift cluster
        between Availability Zones after the cluster is created.
    :param aqua_configuration_status: The cluster is configured to use AQUA .
    :param default_iam_role_arn: ARN for the IAM role.
    :param aws_conn_id: str = The Airflow connection used for AWS credentials.
        The default connection id is ``aws_default``.
    :param wait_for_completion: Whether wait for the cluster to be in ``available`` state
    :param max_attempt: The maximum number of attempts to be made. Default: 5
    :param poll_interval: The amount of time in seconds to wait between attempts. Default: 60
    """
[docs]    template_fields: Sequence[str] = (
        "cluster_identifier",
        "cluster_type",
        "node_type",
        "number_of_nodes",
        "vpc_security_group_ids", 
    )
    def __init__(
        self,
        *,
        cluster_identifier: str,
        node_type: str,
        master_username: str,
        master_user_password: str,
        cluster_type: str = "multi-node",
        db_name: str = "dev",
        number_of_nodes: int = 1,
        cluster_security_groups: list[str] | None = None,
        vpc_security_group_ids: list[str] | None = None,
        cluster_subnet_group_name: str | None = None,
        availability_zone: str | None = None,
        preferred_maintenance_window: str | None = None,
        cluster_parameter_group_name: str | None = None,
        automated_snapshot_retention_period: int = 1,
        manual_snapshot_retention_period: int | None = None,
        port: int = 5439,
        cluster_version: str = "1.0",
        allow_version_upgrade: bool = True,
        publicly_accessible: bool = True,
        encrypted: bool = False,
        hsm_client_certificate_identifier: str | None = None,
        hsm_configuration_identifier: str | None = None,
        elastic_ip: str | None = None,
        tags: list[Any] | None = None,
        kms_key_id: str | None = None,
        enhanced_vpc_routing: bool = False,
        additional_info: str | None = None,
        iam_roles: list[str] | None = None,
        maintenance_track_name: str | None = None,
        snapshot_schedule_identifier: str | None = None,
        availability_zone_relocation: bool | None = None,
        aqua_configuration_status: str | None = None,
        default_iam_role_arn: str | None = None,
        aws_conn_id: str = "aws_default",
        wait_for_completion: bool = False,
        max_attempt: int = 5,
        poll_interval: int = 60,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cluster_identifier = cluster_identifier
        self.node_type = node_type
        self.master_username = master_username
        self.master_user_password = master_user_password
        self.cluster_type = cluster_type
        self.db_name = db_name
        self.number_of_nodes = number_of_nodes
        self.cluster_security_groups = cluster_security_groups
        self.vpc_security_group_ids = vpc_security_group_ids
        self.cluster_subnet_group_name = cluster_subnet_group_name
        self.availability_zone = availability_zone
        self.preferred_maintenance_window = preferred_maintenance_window
        self.cluster_parameter_group_name = cluster_parameter_group_name
        self.automated_snapshot_retention_period = automated_snapshot_retention_period
        self.manual_snapshot_retention_period = manual_snapshot_retention_period
        self.port = port
        self.cluster_version = cluster_version
        self.allow_version_upgrade = allow_version_upgrade
        self.publicly_accessible = publicly_accessible
        self.encrypted = encrypted
        self.hsm_client_certificate_identifier = hsm_client_certificate_identifier
        self.hsm_configuration_identifier = hsm_configuration_identifier
        self.elastic_ip = elastic_ip
        self.tags = tags
        self.kms_key_id = kms_key_id
        self.enhanced_vpc_routing = enhanced_vpc_routing
        self.additional_info = additional_info
        self.iam_roles = iam_roles
        self.maintenance_track_name = maintenance_track_name
        self.snapshot_schedule_identifier = snapshot_schedule_identifier
        self.availability_zone_relocation = availability_zone_relocation
        self.aqua_configuration_status = aqua_configuration_status
        self.default_iam_role_arn = default_iam_role_arn
        self.aws_conn_id = aws_conn_id
        self.wait_for_completion = wait_for_completion
        self.max_attempt = max_attempt
        self.poll_interval = poll_interval
        self.kwargs = kwargs
[docs]    def execute(self, context: Context):
        redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
        self.log.info("Creating Redshift cluster %s", self.cluster_identifier)
        params: dict[str, Any] = {}
        if self.db_name:
            params["DBName"] = self.db_name
        if self.cluster_type:
            params["ClusterType"] = self.cluster_type
            if self.cluster_type == "multi-node":
                params["NumberOfNodes"] = self.number_of_nodes
        if self.cluster_security_groups:
            params["ClusterSecurityGroups"] = self.cluster_security_groups
        if self.vpc_security_group_ids:
            params["VpcSecurityGroupIds"] = self.vpc_security_group_ids
        if self.cluster_subnet_group_name:
            params["ClusterSubnetGroupName"] = self.cluster_subnet_group_name
        if self.availability_zone:
            params["AvailabilityZone"] = self.availability_zone
        if self.preferred_maintenance_window:
            params["PreferredMaintenanceWindow"] = self.preferred_maintenance_window
        if self.cluster_parameter_group_name:
            params["ClusterParameterGroupName"] = self.cluster_parameter_group_name
        if self.automated_snapshot_retention_period:
            params["AutomatedSnapshotRetentionPeriod"] = self.automated_snapshot_retention_period
        if self.manual_snapshot_retention_period:
            params["ManualSnapshotRetentionPeriod"] = self.manual_snapshot_retention_period
        if self.port:
            params["Port"] = self.port
        if self.cluster_version:
            params["ClusterVersion"] = self.cluster_version
        if self.allow_version_upgrade:
            params["AllowVersionUpgrade"] = self.allow_version_upgrade
        if self.publicly_accessible:
            params["PubliclyAccessible"] = self.publicly_accessible
        if self.encrypted:
            params["Encrypted"] = self.encrypted
        if self.hsm_client_certificate_identifier:
            params["HsmClientCertificateIdentifier"] = self.hsm_client_certificate_identifier
        if self.hsm_configuration_identifier:
            params["HsmConfigurationIdentifier"] = self.hsm_configuration_identifier
        if self.elastic_ip:
            params["ElasticIp"] = self.elastic_ip
        if self.tags:
            params["Tags"] = self.tags
        if self.kms_key_id:
            params["KmsKeyId"] = self.kms_key_id
        if self.enhanced_vpc_routing:
            params["EnhancedVpcRouting"] = self.enhanced_vpc_routing
        if self.additional_info:
            params["AdditionalInfo"] = self.additional_info
        if self.iam_roles:
            params["IamRoles"] = self.iam_roles
        if self.maintenance_track_name:
            params["MaintenanceTrackName"] = self.maintenance_track_name
        if self.snapshot_schedule_identifier:
            params["SnapshotScheduleIdentifier"] = self.snapshot_schedule_identifier
        if self.availability_zone_relocation:
            params["AvailabilityZoneRelocation"] = self.availability_zone_relocation
        if self.aqua_configuration_status:
            params["AquaConfigurationStatus"] = self.aqua_configuration_status
        if self.default_iam_role_arn:
            params["DefaultIamRoleArn"] = self.default_iam_role_arn
        cluster = redshift_hook.create_cluster(
            self.cluster_identifier,
            self.node_type,
            self.master_username,
            self.master_user_password,
            params,
        )
        if self.wait_for_completion:
            redshift_hook.get_conn().get_waiter("cluster_available").wait(
                ClusterIdentifier=self.cluster_identifier,
                WaiterConfig={
                    "Delay": self.poll_interval,
                    "MaxAttempts": self.max_attempt,
                },
            )
        self.log.info("Created Redshift cluster %s", self.cluster_identifier)
        self.log.info(cluster)  
[docs]class RedshiftCreateClusterSnapshotOperator(BaseOperator):
    """
    Creates a manual snapshot of the specified cluster. The cluster must be in the available state
    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RedshiftCreateClusterSnapshotOperator`
    :param snapshot_identifier: A unique identifier for the snapshot that you are requesting
    :param cluster_identifier: The cluster identifier for which you want a snapshot
    :param retention_period: The number of days that a manual snapshot is retained.
        If the value is -1, the manual snapshot is retained indefinitely.
    :param wait_for_completion: Whether wait for the cluster snapshot to be in ``available`` state
    :param poll_interval: Time (in seconds) to wait between two consecutive calls to check state
    :param max_attempt: The maximum number of attempts to be made to check the state
    :param aws_conn_id: The Airflow connection used for AWS credentials.
        The default connection id is ``aws_default``
    """
[docs]    template_fields: Sequence[str] = (
        "cluster_identifier",
        "snapshot_identifier", 
    )
    def __init__(
        self,
        *,
        snapshot_identifier: str,
        cluster_identifier: str,
        retention_period: int = -1,
        wait_for_completion: bool = False,
        poll_interval: int = 15,
        max_attempt: int = 20,
        aws_conn_id: str = "aws_default",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.snapshot_identifier = snapshot_identifier
        self.cluster_identifier = cluster_identifier
        self.retention_period = retention_period
        self.wait_for_completion = wait_for_completion
        self.poll_interval = poll_interval
        self.max_attempt = max_attempt
        self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id)
[docs]    def execute(self, context: Context) -> Any:
        cluster_state = self.redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
        if cluster_state != "available":
            raise AirflowException(
                "Redshift cluster must be in available state. "
                f"Redshift cluster current state is {cluster_state}"
            )
        self.redshift_hook.create_cluster_snapshot(
            cluster_identifier=self.cluster_identifier,
            snapshot_identifier=self.snapshot_identifier,
            retention_period=self.retention_period,
        )
        if self.wait_for_completion:
            self.redshift_hook.get_conn().get_waiter("snapshot_available").wait(
                ClusterIdentifier=self.cluster_identifier,
                WaiterConfig={
                    "Delay": self.poll_interval,
                    "MaxAttempts": self.max_attempt,  
                },
            )
[docs]class RedshiftDeleteClusterSnapshotOperator(BaseOperator):
    """
    Deletes the specified manual snapshot
    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RedshiftDeleteClusterSnapshotOperator`
    :param snapshot_identifier: A unique identifier for the snapshot that you are requesting
    :param cluster_identifier: The unique identifier of the cluster the snapshot was created from
    :param wait_for_completion: Whether wait for cluster deletion or not
        The default value is ``True``
    :param aws_conn_id: The Airflow connection used for AWS credentials.
        The default connection id is ``aws_default``
    :param poll_interval: Time (in seconds) to wait between two consecutive calls to check snapshot state
    """
[docs]    template_fields: Sequence[str] = (
        "cluster_identifier",
        "snapshot_identifier", 
    )
    def __init__(
        self,
        *,
        snapshot_identifier: str,
        cluster_identifier: str,
        wait_for_completion: bool = True,
        aws_conn_id: str = "aws_default",
        poll_interval: int = 10,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.snapshot_identifier = snapshot_identifier
        self.cluster_identifier = cluster_identifier
        self.wait_for_completion = wait_for_completion
        self.poll_interval = poll_interval
        self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id)
[docs]    def execute(self, context: Context) -> Any:
        self.redshift_hook.get_conn().delete_cluster_snapshot(
            SnapshotClusterIdentifier=self.cluster_identifier,
            SnapshotIdentifier=self.snapshot_identifier,
        )
        if self.wait_for_completion:
            while self.get_status() is not None:
                time.sleep(self.poll_interval) 
[docs]    def get_status(self) -> str:
        return self.redshift_hook.get_cluster_snapshot_status(
            snapshot_identifier=self.snapshot_identifier,  
        )
[docs]class RedshiftResumeClusterOperator(BaseOperator):
    """
    Resume a paused AWS Redshift Cluster
    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RedshiftResumeClusterOperator`
    :param cluster_identifier:  Unique identifier of the AWS Redshift cluster
    :param aws_conn_id: The Airflow connection used for AWS credentials.
        The default connection id is ``aws_default``
    :param deferrable: Run operator in deferrable mode
    :param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
    """
[docs]    template_fields: Sequence[str] = ("cluster_identifier",) 
    def __init__(
        self,
        *,
        cluster_identifier: str,
        aws_conn_id: str = "aws_default",
        deferrable: bool = False,
        poll_interval: int = 10,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cluster_identifier = cluster_identifier
        self.aws_conn_id = aws_conn_id
        self.deferrable = deferrable
        self.poll_interval = poll_interval
        # These parameters are added to address an issue with the boto3 API where the API
        # prematurely reports the cluster as available to receive requests. This causes the cluster
        # to reject initial attempts to resume the cluster despite reporting the correct state.
        self._attempts = 10
        self._attempt_interval = 15
[docs]    def execute(self, context: Context):
        redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
        if self.deferrable:
            self.defer(
                timeout=self.execution_timeout,
                trigger=RedshiftClusterTrigger(
                    task_id=self.task_id,
                    poll_interval=self.poll_interval,
                    aws_conn_id=self.aws_conn_id,
                    cluster_identifier=self.cluster_identifier,
                    attempts=self._attempts,
                    operation_type="pause_cluster",
                ),
                method_name="execute_complete",
            )
        else:
            while self._attempts >= 1:
                try:
                    redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
                    return
                except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
                    self._attempts = self._attempts - 1
                    if self._attempts > 0:
                        self.log.error("Unable to resume cluster. %d attempts remaining.", self._attempts)
                        time.sleep(self._attempt_interval)
                    else:
                        raise error 
[docs]    def execute_complete(self, context: Context, event: Any = None) -> None:
        """
        Callback for when the trigger fires - returns immediately.
        Relies on trigger to throw an exception, otherwise it assumes execution was
        successful.
        """
        if event:
            if "status" in event and event["status"] == "error":
                msg = f"{event['status']}: {event['message']}"
                raise AirflowException(msg)
            elif "status" in event and event["status"] == "success":
                self.log.info("%s completed successfully.", self.task_id)
                self.log.info("Paused cluster successfully")
        else:
            raise AirflowException("No event received from trigger")  
[docs]class RedshiftPauseClusterOperator(BaseOperator):
    """
    Pause an AWS Redshift Cluster if it has status `available`.
    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RedshiftPauseClusterOperator`
    :param cluster_identifier: id of the AWS Redshift Cluster
    :param aws_conn_id: aws connection to use
    :param deferrable: Run operator in the deferrable mode. This mode requires an additional aiobotocore>=
    """
[docs]    template_fields: Sequence[str] = ("cluster_identifier",) 
    def __init__(
        self,
        *,
        cluster_identifier: str,
        aws_conn_id: str = "aws_default",
        deferrable: bool = False,
        poll_interval: int = 10,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cluster_identifier = cluster_identifier
        self.aws_conn_id = aws_conn_id
        self.deferrable = deferrable
        self.poll_interval = poll_interval
        # These parameters are added to address an issue with the boto3 API where the API
        # prematurely reports the cluster as available to receive requests. This causes the cluster
        # to reject initial attempts to pause the cluster despite reporting the correct state.
        self._attempts = 10
        self._attempt_interval = 15
[docs]    def execute(self, context: Context):
        redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
        if self.deferrable:
            self.defer(
                timeout=self.execution_timeout,
                trigger=RedshiftClusterTrigger(
                    task_id=self.task_id,
                    poll_interval=self.poll_interval,
                    aws_conn_id=self.aws_conn_id,
                    cluster_identifier=self.cluster_identifier,
                    attempts=self._attempts,
                    operation_type="pause_cluster",
                ),
                method_name="execute_complete",
            )
        else:
            while self._attempts >= 1:
                try:
                    redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
                    return
                except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
                    self._attempts = self._attempts - 1
                    if self._attempts > 0:
                        self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
                        time.sleep(self._attempt_interval)
                    else:
                        raise error 
[docs]    def execute_complete(self, context: Context, event: Any = None) -> None:
        """
        Callback for when the trigger fires - returns immediately.
        Relies on trigger to throw an exception, otherwise it assumes execution was
        successful.
        """
        if event:
            if "status" in event and event["status"] == "error":
                msg = f"{event['status']}: {event['message']}"
                raise AirflowException(msg)
            elif "status" in event and event["status"] == "success":
                self.log.info("%s completed successfully.", self.task_id)
                self.log.info("Paused cluster successfully")
        else:
            raise AirflowException("No event received from trigger")  
[docs]class RedshiftDeleteClusterOperator(BaseOperator):
    """
    Delete an AWS Redshift cluster.
    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RedshiftDeleteClusterOperator`
    :param cluster_identifier: unique identifier of a cluster
    :param skip_final_cluster_snapshot: determines cluster snapshot creation
    :param final_cluster_snapshot_identifier: name of final cluster snapshot
    :param wait_for_completion: Whether wait for cluster deletion or not
        The default value is ``True``
    :param aws_conn_id: aws connection to use
    :param poll_interval: Time (in seconds) to wait between two consecutive calls to check cluster state
    """
[docs]    template_fields: Sequence[str] = ("cluster_identifier",) 
    def __init__(
        self,
        *,
        cluster_identifier: str,
        skip_final_cluster_snapshot: bool = True,
        final_cluster_snapshot_identifier: str | None = None,
        wait_for_completion: bool = True,
        aws_conn_id: str = "aws_default",
        poll_interval: float = 30.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cluster_identifier = cluster_identifier
        self.skip_final_cluster_snapshot = skip_final_cluster_snapshot
        self.final_cluster_snapshot_identifier = final_cluster_snapshot_identifier
        self.wait_for_completion = wait_for_completion
        self.poll_interval = poll_interval
        # These parameters are added to keep trying if there is a running operation in the cluster
        # If there is a running operation in the cluster while trying to delete it, a InvalidClusterStateFault
        # is thrown. In such case, retrying
        self._attempts = 10
        self._attempt_interval = 15
        self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id)
[docs]    def execute(self, context: Context):
        while self._attempts >= 1:
            try:
                self.redshift_hook.delete_cluster(
                    cluster_identifier=self.cluster_identifier,
                    skip_final_cluster_snapshot=self.skip_final_cluster_snapshot,
                    final_cluster_snapshot_identifier=self.final_cluster_snapshot_identifier,
                )
                break
            except self.redshift_hook.get_conn().exceptions.InvalidClusterStateFault:
                self._attempts = self._attempts - 1
                if self._attempts > 0:
                    self.log.error("Unable to delete cluster. %d attempts remaining.", self._attempts)
                    time.sleep(self._attempt_interval)
                else:
                    raise
        if self.wait_for_completion:
            waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted")
            waiter.wait(
                ClusterIdentifier=self.cluster_identifier,
                WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": 30},  
            )