Source code for airflow.providers.amazon.aws.executors.batch.utils
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List

from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys
from airflow.utils.state import State

if TYPE_CHECKING:
    from airflow.models.taskinstance import TaskInstanceKey
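# NOTE: The extracted listing omits the module-level definitions that the
# methods and signatures below depend on (the command/config type aliases,
# the BatchJobInfo record, and the BatchJob class that owns get_job_state()
# and __repr__). The following is a minimal sketch reconstructed from how the
# rest of this module uses them; the alias names, the STATE_MAPPINGS entries,
# and the BatchJob constructor shown here are assumptions, not verbatim source.

CommandType = List[str]
ExecutorConfigType = Dict[str, Any]


@dataclass
class BatchJobInfo:
    """Information about a submitted Batch job, as stored by the collection below."""

    cmd: CommandType
    queue: str
    config: ExecutorConfigType


class BatchJob:
    """Represents a single AWS Batch job and its status."""

    # Batch job statuses mapped to Airflow states (assumed mapping); anything
    # unmapped falls back to State.QUEUED in get_job_state() below.
    STATE_MAPPINGS = {
        "SUBMITTED": State.QUEUED,
        "PENDING": State.QUEUED,
        "RUNNABLE": State.QUEUED,
        "STARTING": State.QUEUED,
        "RUNNING": State.RUNNING,
        "SUCCEEDED": State.SUCCESS,
        "FAILED": State.FAILED,
    }

    def __init__(self, job_id: str, status: str):
        self.job_id = job_id
        self.status = status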
    def get_job_state(self) -> str:
        """Return the state of the job."""
        return self.STATE_MAPPINGS.get(self.status, State.QUEUED)

    def __repr__(self):
        """Return a visual representation of the Job status."""
        return f"({self.job_id} -> {self.status}, {self.get_job_state()})"
class BatchJobCollection:
    """A collection to manage running Batch Jobs."""

    def __init__(self):
        self.key_to_id: dict[TaskInstanceKey, str] = {}
        self.id_to_key: dict[str, TaskInstanceKey] = {}
        self.id_to_failure_counts: dict[str, int] = defaultdict(int)
        self.id_to_job_info: dict[str, BatchJobInfo] = {}
    def add_job(
        self,
        job_id: str,
        airflow_task_key: TaskInstanceKey,
        airflow_cmd: CommandType,
        queue: str,
        exec_config: ExecutorConfigType,
        attempt_number: int,
    ):
        """Add a job to the collection."""
        self.key_to_id[airflow_task_key] = job_id
        self.id_to_key[job_id] = airflow_task_key
        self.id_to_failure_counts[job_id] = attempt_number
        self.id_to_job_info[job_id] = BatchJobInfo(cmd=airflow_cmd, queue=queue, config=exec_config)
    def pop_by_id(self, job_id: str) -> TaskInstanceKey:
        """Delete a job from the collection, based on its Batch job ID."""
        task_key = self.id_to_key[job_id]
        del self.key_to_id[task_key]
        del self.id_to_key[job_id]
        del self.id_to_failure_counts[job_id]
        return task_key
    def failure_count_by_id(self, job_id: str) -> int:
        """Get the number of times a job has failed given a Batch Job Id."""
        return self.id_to_failure_counts[job_id]
    def increment_failure_count(self, job_id: str):
        """Increment the failure counter given a Batch Job Id."""
        self.id_to_failure_counts[job_id] += 1
    def get_all_jobs(self) -> list[str]:
        """Get all Batch job IDs in the collection."""
        return list(self.id_to_key.keys())
    def __len__(self):
        """Return the number of jobs in collection."""
        return len(self.key_to_id)
class BatchSubmitJobKwargsConfigKeys(BaseConfigKeys):
    """Keys loaded into the config which are valid Batch submit_job kwargs."""
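# Usage sketch (illustrative only, not part of the module): it shows how an
# executor might track one Batch job through BatchJobCollection. The
# TaskInstanceKey field values and the job ID below are made-up examples.

def _example_usage() -> None:
    from airflow.models.taskinstance import TaskInstanceKey

    collection = BatchJobCollection()
    key = TaskInstanceKey(
        dag_id="example_dag", task_id="example_task", run_id="manual__2024-01-01", try_number=1
    )
    collection.add_job(
        job_id="00000000-0000-0000-0000-000000000000",
        airflow_task_key=key,
        airflow_cmd=["airflow", "tasks", "run", "example_dag", "example_task", "manual__2024-01-01"],
        queue="default",
        exec_config={},
        attempt_number=1,
    )
    # One job tracked; its failure count starts at the submitted attempt number.
    assert len(collection) == 1
    assert collection.failure_count_by_id("00000000-0000-0000-0000-000000000000") == 1
    collection.increment_failure_count("00000000-0000-0000-0000-000000000000")
    assert collection.failure_count_by_id("00000000-0000-0000-0000-000000000000") == 2
    # Popping by Batch job ID returns the Airflow task key and removes the job.
    assert collection.pop_by_id("00000000-0000-0000-0000-000000000000") == key
    assert len(collection) == 0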