Source code for tests.system.openai.example_trigger_batch_operator
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Any, Literal

from airflow.decorators import dag, task
# NOTE(review): ``dag`` is imported at module level but no ``@dag(...)``
# decorator precedes this function in this copy, and ``OPENAI_CONN_ID`` /
# ``POKEMONS`` are not defined in the visible text. Confirm they exist in the
# real file; without ``@dag`` the module-level call simply runs this function
# and returns None.
def openai_batch_chat_completions():
    """Example pipeline: run chat completions through the OpenAI Batch API.

    Steps wired up below:

    1. ``generate_messages`` is dynamically mapped over ``POKEMONS`` and builds
       one single-message chat conversation per entry.
    2. ``batch_upload`` writes every conversation as one JSON line in the batch
       request format (``custom_id`` / ``method`` / ``url`` / ``body``),
       uploads the file with ``purpose="batch"`` and returns the file id.
    3. ``OpenAITriggerBatchOperator`` (``deferrable=True``) is given that file
       id and the ``/v1/chat/completions`` endpoint; its XCom return value is
       used downstream as the batch id.
    4. ``cleanup_batch_output_file`` deletes the batch's output file, if the
       batch reports one.
    """

    @task
    def generate_messages(pokemon, **context) -> list[dict[str, Any]]:
        """Return a one-message user conversation asking about ``pokemon``."""
        return [{"role": "user", "content": f"Describe the info about {pokemon}?"}]

    @task
    def batch_upload(messages_batch, **context) -> str:
        """Serialize ``messages_batch`` to a JSONL batch file, upload it, return its id.

        :param messages_batch: iterable of conversations; each conversation (a
            list of message dicts) becomes one request line with a fresh
            random ``custom_id``.
        :return: the id of the uploaded file.
        """
        import os
        import tempfile
        import uuid

        from pydantic import BaseModel, Field

        from airflow.providers.openai.hooks.openai import OpenAIHook

        class RequestBody(BaseModel):
            model: str
            messages: list[dict[str, Any]]
            max_tokens: int = Field(default=1000)

        class BatchModel(BaseModel):
            custom_id: str
            method: Literal["POST"]
            url: Literal["/v1/chat/completions"]
            body: RequestBody

        model = "gpt-4o-mini"
        max_tokens = 1000
        hook = OpenAIHook(conn_id=OPENAI_CONN_ID)

        # delete=False keeps the file after the ``with`` block closes it; the
        # hook is only handed the path, so the writes must be flushed (by
        # closing) before uploading. Explicit UTF-8 keeps non-ASCII message
        # text independent of the platform's default locale encoding.
        with tempfile.NamedTemporaryFile(
            mode="w", encoding="utf-8", delete=False
        ) as file:
            for messages in messages_batch:
                file.write(
                    BatchModel(
                        custom_id=str(uuid.uuid4()),
                        method="POST",
                        url="/v1/chat/completions",
                        body=RequestBody(
                            model=model,
                            max_tokens=max_tokens,
                            messages=messages,
                        ),
                    ).model_dump_json()
                    + "\n"
                )
        try:
            batch_file = hook.upload_file(file.name, purpose="batch")
        finally:
            # With delete=False nothing else removes the temp file; without
            # this every run would leave a file behind on the worker's disk.
            os.unlink(file.name)
        return batch_file.id

    @task
    def cleanup_batch_output_file(batch_id, **context):
        """Delete the output file of batch ``batch_id`` when the batch has one."""
        from airflow.providers.openai.hooks.openai import OpenAIHook

        hook = OpenAIHook(conn_id=OPENAI_CONN_ID)
        batch = hook.get_batch(batch_id)
        if batch.output_file_id:
            hook.delete_file(batch.output_file_id)

    messages = generate_messages.expand(pokemon=POKEMONS)
    batch_file_id = batch_upload(messages_batch=messages)

    # [START howto_operator_openai_trigger_operator]
    from airflow.providers.openai.operators.openai import OpenAITriggerBatchOperator

    batch_id = OpenAITriggerBatchOperator(
        task_id="batch_operator_deferred",
        conn_id=OPENAI_CONN_ID,
        file_id=batch_file_id,
        endpoint="/v1/chat/completions",
        deferrable=True,
    )
    # [END howto_operator_openai_trigger_operator]

    # The batch id is pulled from the trigger operator's XCom through a Jinja
    # template string rather than an XComArg, which is why the ordering is
    # declared explicitly with ``>>`` below.
    cleanup_batch_output = cleanup_batch_output_file(
        batch_id="{{ ti.xcom_pull(task_ids='batch_operator_deferred', key='return_value') }}"
    )

    batch_id >> cleanup_batch_output
# Invoke the example function when this module is imported.
openai_batch_chat_completions()


# Imported after the example code; E402 (module-level import not at top of
# file) is silenced deliberately.
from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)