Source code for tests.system.openai.example_trigger_batch_operator
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Any, Literal

from airflow.decorators import dag, task
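
# Connection used for every OpenAI call below; each Pokémon in POKEMONS becomes
# one dynamically mapped generate_messages task instance.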
OPENAI_CONN_ID = "openai_default"

POKEMONS = [
    "pikachu",
    "charmander",
    "bulbasaur",
]


@dag(
    schedule=None,
    catchup=False,
)
def openai_batch_chat_completions():
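    """
    Build chat-completion requests for each Pokémon, upload them as an OpenAI batch
    input file, trigger the batch with OpenAITriggerBatchOperator (deferrable), and
    finally delete the batch output file.
    """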
    @task
    def generate_messages(pokemon, **context) -> list[dict[str, Any]]:
        return [{"role": "user", "content": f"Describe the info about {pokemon}?"}]
    @task
    def batch_upload(messages_batch, **context) -> str:
        import tempfile
        import uuid

        from pydantic import BaseModel, Field

        from airflow.providers.openai.hooks.openai import OpenAIHook

        # Pydantic models describing one line of the OpenAI batch input JSONL file.
        class RequestBody(BaseModel):
            model: str
            messages: list[dict[str, Any]]
            max_tokens: int = Field(default=1000)

        class BatchModel(BaseModel):
            custom_id: str
            method: Literal["POST"]
            url: Literal["/v1/chat/completions"]
            body: RequestBody

        model = "gpt-4o-mini"
        max_tokens = 1000
        hook = OpenAIHook(conn_id=OPENAI_CONN_ID)
        # Write one request per line, then upload the closed file for batch processing.
        with tempfile.NamedTemporaryFile(mode="w", delete=False) as file:
            for messages in messages_batch:
                file.write(
                    BatchModel(
                        custom_id=str(uuid.uuid4()),
                        method="POST",
                        url="/v1/chat/completions",
                        body=RequestBody(
                            model=model,
                            max_tokens=max_tokens,
                            messages=messages,
                        ),
                    ).model_dump_json()
                    + "\n"
                )
        batch_file = hook.upload_file(file.name, purpose="batch")
        return batch_file.id
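
    # After the batch has run, delete the output file OpenAI generated for it so the
    # example does not leave artifacts behind in the OpenAI account.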
    @task
    def cleanup_batch_output_file(batch_id, **context):
        from airflow.providers.openai.hooks.openai import OpenAIHook

        hook = OpenAIHook(conn_id=OPENAI_CONN_ID)
        batch = hook.get_batch(batch_id)
        if batch.output_file_id:
            hook.delete_file(batch.output_file_id)
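
    # Wire the tasks together: one generate_messages instance per Pokémon, then a single
    # upload whose file id feeds the trigger operator below.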
    messages = generate_messages.expand(pokemon=POKEMONS)
    batch_file_id = batch_upload(messages_batch=messages)

    # [START howto_operator_openai_trigger_operator]
    from airflow.providers.openai.operators.openai import OpenAITriggerBatchOperator

    batch_id = OpenAITriggerBatchOperator(
        task_id="batch_operator_deferred",
        conn_id=OPENAI_CONN_ID,
        file_id=batch_file_id,
        endpoint="/v1/chat/completions",
        deferrable=True,
    )
    # [END howto_operator_openai_trigger_operator]
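
    # The operator pushes the batch id to XCom; pull it from there for the cleanup task.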
    cleanup_batch_output = cleanup_batch_output_file(
        batch_id="{{ ti.xcom_pull(task_ids='batch_operator_deferred', key='return_value') }}"
    )

    batch_id >> cleanup_batch_output


dag = openai_batch_chat_completions()
from tests_common.test_utils.system_tests import get_test_run # noqa: E402
# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)