Source code for airflow.providers.google.cloud.hooks.vision
## Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""This module contains a Google Cloud Vision Hook."""from__future__importannotationsfromcopyimportdeepcopyfromfunctoolsimportcached_propertyfromtypingimportTYPE_CHECKING,Any,Callable,Sequencefromgoogle.api_core.gapic_v1.methodimportDEFAULT,_MethodDefaultfromgoogle.cloud.vision_v1import(AnnotateImageRequest,Image,ImageAnnotatorClient,Product,ProductSearchClient,ProductSet,ReferenceImage,)fromgoogle.protobuf.json_formatimportMessageToDictfromairflow.exceptionsimportAirflowExceptionfromairflow.providers.google.common.constsimportCLIENT_INFOfromairflow.providers.google.common.hooks.base_googleimportPROVIDE_PROJECT_ID,GoogleBaseHookifTYPE_CHECKING:fromgoogle.api_core.retryimportRetryfromgoogle.protobufimportfield_mask_pb2
[docs]ERR_DIFF_NAMES="""The {label} name provided in the object ({explicit_name}) is different than the name created from the input parameters ({constructed_name}). Please either: 1) Remove the {label} name, 2) Remove the location and {id_label} parameters, 3) Unify the {label} name and input parameters. """
[docs]ERR_UNABLE_TO_CREATE="""Unable to determine the {label} name. Please either set the name directly in the {label} object or provide the `location` and `{id_label}` parameters. """
[docs]classNameDeterminer:"""Helper class to determine entity name."""def__init__(self,label:str,id_label:str,get_path:Callable[[str,str,str],str])->None:self.label=labelself.id_label=id_labelself.get_path=get_path
[docs]defget_entity_with_name(self,entity:Any,entity_id:str|None,location:str|None,project_id:str)->Any:""" Check if entity has the `name` attribute set. * If so, no action is taken. * If not, and the name can be constructed from other parameters provided, it is created and filled in the entity. * If both the entity's 'name' attribute is set and the name can be constructed from other parameters provided: * If they are the same - no action is taken * if they are different - an exception is thrown. :param entity: Entity :param entity_id: Entity id :param location: Location :param project_id: The id of Google Cloud Vision project. :return: The same entity or entity with new name :raises: AirflowException """entity=deepcopy(entity)explicit_name=getattr(entity,"name")iflocationandentity_id:# Necessary parameters to construct the name are present. Checking for conflict with explicit nameconstructed_name=self.get_path(project_id,location,entity_id)ifnotexplicit_name:entity.name=constructed_namereturnentityifexplicit_name!=constructed_name:raiseAirflowException(ERR_DIFF_NAMES.format(label=self.label,explicit_name=explicit_name,constructed_name=constructed_name,id_label=self.id_label,))# Not enough parameters to construct the name. Trying to use the name from Product / ProductSet.ifexplicit_name:returnentityelse:raiseAirflowException(ERR_UNABLE_TO_CREATE.format(label=self.label,id_label=self.id_label))
[docs]classCloudVisionHook(GoogleBaseHook):""" Hook for Google Cloud Vision APIs. All the methods in the hook where project_id is used must be called with keyword arguments rather than positional. """_client:ProductSearchClient|None
def__init__(self,gcp_conn_id:str="google_cloud_default",impersonation_chain:str|Sequence[str]|None=None,**kwargs,)->None:ifkwargs.get("delegate_to")isnotNone:raiseRuntimeError("The `delegate_to` parameter has been deprecated before and finally removed in this version"" of Google Provider. You MUST convert it to `impersonate_chain`")super().__init__(gcp_conn_id=gcp_conn_id,impersonation_chain=impersonation_chain,)self._client=None
[docs]defget_conn(self)->ProductSearchClient:""" Retrieves connection to Cloud Vision. :return: Google Cloud Vision client object. """ifnotself._client:self._client=ProductSearchClient(credentials=self.get_credentials(),client_info=CLIENT_INFO)returnself._client
@cached_property
[docs]defannotator_client(self)->ImageAnnotatorClient:""" Creates ImageAnnotatorClient. :return: Google Image Annotator client object. """returnImageAnnotatorClient(credentials=self.get_credentials())
[docs]defcreate_product_set(self,location:str,product_set:ProductSet|None,project_id:str=PROVIDE_PROJECT_ID,product_set_id:str|None=None,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),)->str:""" Create product set. For the documentation see: :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator`. """client=self.get_conn()parent=f"projects/{project_id}/locations/{location}"self.log.info("Creating a new ProductSet under the parent: %s",parent)response=client.create_product_set(parent=parent,product_set=product_set,product_set_id=product_set_id,retry=retry,timeout=timeout,metadata=metadata,)self.log.info("ProductSet created: %s",response.nameifresponseelse"")self.log.debug("ProductSet created:\n%s",response)ifnotproduct_set_id:# Product set id was generated by the APIproduct_set_id=self._get_autogenerated_id(response)self.log.info("Extracted autogenerated ProductSet ID from the response: %s",product_set_id)returnproduct_set_id
@GoogleBaseHook.fallback_to_default_project_id
[docs]defget_product_set(self,location:str,product_set_id:str,project_id:str=PROVIDE_PROJECT_ID,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),)->dict:""" Get product set. For the documentation see: :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator`. """client=self.get_conn()name=ProductSearchClient.product_set_path(project_id,location,product_set_id)self.log.info("Retrieving ProductSet: %s",name)response=client.get_product_set(name=name,retry=retry,timeout=timeout,metadata=metadata)self.log.info("ProductSet retrieved.")self.log.debug("ProductSet retrieved:\n%s",response)returnMessageToDict(response._pb)
@GoogleBaseHook.fallback_to_default_project_id
[docs]defupdate_product_set(self,product_set:dict|ProductSet,project_id:str=PROVIDE_PROJECT_ID,location:str|None=None,product_set_id:str|None=None,update_mask:dict|field_mask_pb2.FieldMask|None=None,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),)->dict:""" Update product set. For the documentation see: :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator`. """client=self.get_conn()product_set=self.product_set_name_determiner.get_entity_with_name(product_set,product_set_id,location,project_id)ifisinstance(product_set,dict):product_set=ProductSet(product_set)self.log.info("Updating ProductSet: %s",product_set.name)response=client.update_product_set(product_set=product_set,update_mask=update_mask,# type: ignoreretry=retry,timeout=timeout,metadata=metadata,)self.log.info("ProductSet updated: %s",response.nameifresponseelse"")self.log.debug("ProductSet updated:\n%s",response)returnMessageToDict(response._pb)
@GoogleBaseHook.fallback_to_default_project_id
[docs]defdelete_product_set(self,location:str,product_set_id:str,project_id:str=PROVIDE_PROJECT_ID,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),)->None:""" Delete product set. For the documentation see: :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator`. """client=self.get_conn()name=ProductSearchClient.product_set_path(project_id,location,product_set_id)self.log.info("Deleting ProductSet: %s",name)client.delete_product_set(name=name,retry=retry,timeout=timeout,metadata=metadata)self.log.info("ProductSet with the name [%s] deleted.",name)
@GoogleBaseHook.fallback_to_default_project_id
[docs]defcreate_product(self,location:str,product:dict|Product,project_id:str=PROVIDE_PROJECT_ID,product_id:str|None=None,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),):""" Create product. For the documentation see: :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator`. """client=self.get_conn()parent=f"projects/{project_id}/locations/{location}"self.log.info("Creating a new Product under the parent: %s",parent)ifisinstance(product,dict):product=Product(product)response=client.create_product(parent=parent,product=product,product_id=product_id,retry=retry,timeout=timeout,metadata=metadata,)self.log.info("Product created: %s",response.nameifresponseelse"")self.log.debug("Product created:\n%s",response)ifnotproduct_id:# Product id was generated by the APIproduct_id=self._get_autogenerated_id(response)self.log.info("Extracted autogenerated Product ID from the response: %s",product_id)returnproduct_id
@GoogleBaseHook.fallback_to_default_project_id
[docs]defget_product(self,location:str,product_id:str,project_id:str=PROVIDE_PROJECT_ID,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),):""" Get product. For the documentation see: :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator`. """client=self.get_conn()name=ProductSearchClient.product_path(project_id,location,product_id)self.log.info("Retrieving Product: %s",name)response=client.get_product(name=name,retry=retry,timeout=timeout,metadata=metadata)self.log.info("Product retrieved.")self.log.debug("Product retrieved:\n%s",response)returnMessageToDict(response._pb)
@GoogleBaseHook.fallback_to_default_project_id
[docs]defupdate_product(self,product:dict|Product,project_id:str=PROVIDE_PROJECT_ID,location:str|None=None,product_id:str|None=None,update_mask:dict|field_mask_pb2.FieldMask|None=None,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),):""" Update product. For the documentation see: :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator`. """client=self.get_conn()product=self.product_name_determiner.get_entity_with_name(product,product_id,location,project_id)ifisinstance(product,dict):product=Product(product)self.log.info("Updating ProductSet: %s",product.name)response=client.update_product(product=product,update_mask=update_mask,# type: ignoreretry=retry,timeout=timeout,metadata=metadata,)self.log.info("Product updated: %s",response.nameifresponseelse"")self.log.debug("Product updated:\n%s",response)returnMessageToDict(response._pb)
@GoogleBaseHook.fallback_to_default_project_id
[docs]defdelete_product(self,location:str,product_id:str,project_id:str=PROVIDE_PROJECT_ID,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),)->None:""" Delete product. For the documentation see: :class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator`. """client=self.get_conn()name=ProductSearchClient.product_path(project_id,location,product_id)self.log.info("Deleting ProductSet: %s",name)client.delete_product(name=name,retry=retry,timeout=timeout,metadata=metadata)self.log.info("Product with the name [%s] deleted:",name)
@GoogleBaseHook.fallback_to_default_project_id
[docs]defcreate_reference_image(self,location:str,product_id:str,reference_image:dict|ReferenceImage,project_id:str,reference_image_id:str|None=None,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),)->str:""" Create reference image. For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator`. """client=self.get_conn()self.log.info("Creating ReferenceImage")parent=ProductSearchClient.product_path(project=project_id,location=location,product=product_id)ifisinstance(reference_image,dict):reference_image=ReferenceImage(reference_image)response=client.create_reference_image(parent=parent,reference_image=reference_image,reference_image_id=reference_image_id,retry=retry,timeout=timeout,metadata=metadata,)self.log.info("ReferenceImage created: %s",response.nameifresponseelse"")self.log.debug("ReferenceImage created:\n%s",response)ifnotreference_image_id:# Reference image id was generated by the APIreference_image_id=self._get_autogenerated_id(response)self.log.info("Extracted autogenerated ReferenceImage ID from the response: %s",reference_image_id)returnreference_image_id
@GoogleBaseHook.fallback_to_default_project_id
[docs]defdelete_reference_image(self,location:str,product_id:str,reference_image_id:str,project_id:str,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),)->None:""" Delete reference image. For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDeleteReferenceImageOperator`. """client=self.get_conn()self.log.info("Deleting ReferenceImage")name=ProductSearchClient.reference_image_path(project=project_id,location=location,product=product_id,reference_image=reference_image_id)client.delete_reference_image(name=name,retry=retry,timeout=timeout,metadata=metadata,)self.log.info("ReferenceImage with the name [%s] deleted.",name)
@GoogleBaseHook.fallback_to_default_project_id
[docs]defadd_product_to_product_set(self,product_set_id:str,product_id:str,project_id:str,location:str,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),)->None:""" Add product to product set. For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionAddProductToProductSetOperator`. """client=self.get_conn()product_name=ProductSearchClient.product_path(project_id,location,product_id)product_set_name=ProductSearchClient.product_set_path(project_id,location,product_set_id)self.log.info("Add Product[name=%s] to Product Set[name=%s]",product_name,product_set_name)client.add_product_to_product_set(name=product_set_name,product=product_name,retry=retry,timeout=timeout,metadata=metadata)self.log.info("Product added to Product Set")
@GoogleBaseHook.fallback_to_default_project_id
[docs]defremove_product_from_product_set(self,product_set_id:str,product_id:str,project_id:str,location:str,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,metadata:Sequence[tuple[str,str]]=(),)->None:""" Remove product from product set. For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionRemoveProductFromProductSetOperator`. """client=self.get_conn()product_name=ProductSearchClient.product_path(project_id,location,product_id)product_set_name=ProductSearchClient.product_set_path(project_id,location,product_set_id)self.log.info("Remove Product[name=%s] from Product Set[name=%s]",product_name,product_set_name)client.remove_product_from_product_set(name=product_set_name,product=product_name,retry=retry,timeout=timeout,metadata=metadata)self.log.info("Product removed from Product Set")
[docs]defannotate_image(self,request:dict|AnnotateImageRequest,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,)->dict:""" Annotate image. For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator`. """client=self.annotator_clientself.log.info("Annotating image")response=client.annotate_image(request=request,retry=retry,timeout=timeout)self.log.info("Image annotated")returnMessageToDict(response._pb)
@GoogleBaseHook.quota_retry()
[docs]defbatch_annotate_images(self,requests:list[dict]|list[AnnotateImageRequest],retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,)->dict:""" Batch annotate images. For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator`. """client=self.annotator_clientself.log.info("Annotating images")requests=list(map(AnnotateImageRequest,requests))response=client.batch_annotate_images(requests=requests,retry=retry,timeout=timeout)self.log.info("Images annotated")returnMessageToDict(response._pb)
@GoogleBaseHook.quota_retry()
[docs]deftext_detection(self,image:dict|Image,max_results:int|None=None,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,additional_properties:dict|None=None,)->dict:""" Text detection. For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDetectTextOperator`. """client=self.annotator_clientself.log.info("Detecting text")ifadditional_propertiesisNone:additional_properties={}response=client.text_detection(image=image,max_results=max_results,retry=retry,timeout=timeout,**additional_properties)response=MessageToDict(response._pb)self._check_for_error(response)self.log.info("Text detection finished")returnresponse
@GoogleBaseHook.quota_retry()
[docs]defdocument_text_detection(self,image:dict|Image,max_results:int|None=None,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,additional_properties:dict|None=None,)->dict:""" Document text detection. For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator`. """client=self.annotator_clientself.log.info("Detecting document text")ifadditional_propertiesisNone:additional_properties={}response=client.document_text_detection(image=image,max_results=max_results,retry=retry,timeout=timeout,**additional_properties)response=MessageToDict(response._pb)self._check_for_error(response)self.log.info("Document text detection finished")returnresponse
@GoogleBaseHook.quota_retry()
[docs]deflabel_detection(self,image:dict|Image,max_results:int|None=None,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,additional_properties:dict|None=None,)->dict:""" Label detection. For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDetectImageLabelsOperator`. """client=self.annotator_clientself.log.info("Detecting labels")ifadditional_propertiesisNone:additional_properties={}response=client.label_detection(image=image,max_results=max_results,retry=retry,timeout=timeout,**additional_properties)response=MessageToDict(response._pb)self._check_for_error(response)self.log.info("Labels detection finished")returnresponse
@GoogleBaseHook.quota_retry()
[docs]defsafe_search_detection(self,image:dict|Image,max_results:int|None=None,retry:Retry|_MethodDefault=DEFAULT,timeout:float|None=None,additional_properties:dict|None=None,)->dict:""" Safe search detection. For the documentation see: :py:class:`~airflow.providers.google.cloud.operators.vision.CloudVisionDetectImageSafeSearchOperator`. """client=self.annotator_clientself.log.info("Detecting safe search")ifadditional_propertiesisNone:additional_properties={}response=client.safe_search_detection(image=image,max_results=max_results,retry=retry,timeout=timeout,**additional_properties)response=MessageToDict(response._pb)self._check_for_error(response)self.log.info("Safe search detection finished")returnresponse
@staticmethoddef_get_autogenerated_id(response)->str:try:name=response.nameexceptAttributeErrorase:raiseAirflowException(f"Unable to get name from response... [{response}]\n{e}")if"/"notinname:raiseAirflowException(f"Unable to get id from name... [{name}]")returnname.rsplit("/",1)[1]