# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ReAct tool-calling agent with LangChain -- research and report pipeline.

Demonstrates the "agent as a task" pattern using a LangChain ReAct agent
that autonomously decides which tools to call, composed with common.ai's
:class:`~airflow.providers.common.ai.operators.llm.LLMOperator` for report
formatting and AIP-90 HITL operators for human review.

Unlike RAG examples (fixed pipeline: retrieve then synthesize), this
agent's tool-call sequence is determined by the LLM at runtime. The agent
might call no tools or ten tools depending on the question. This is the
canonical "agent as a task" pattern: Airflow handles scheduling, retry,
connections, and the surrounding workflow; the LangChain agent handles
internal reasoning.

``example_langchain_tool_agent`` (manual trigger):

.. code-block:: text

    prompt_review (HITLEntryOperator)
      -> prepare_tools (@task)
      -> run_research_agent (@task)
      -> format_report (LLMOperator)
      -> report_approval (ApprovalOperator)

**What this makes visible that running an agent alone hides:**

* The question goes through human review before the agent runs.
* The agent's raw findings are a visible XCom value between tasks.
* Report formatting is a separate, independently retryable LLM call.
* The formatted report requires human approval before delivery.

**Contrast with AIP-99's AgentOperator:**

AIP-99's ``AgentOperator`` uses PydanticAI for agent execution. This
example uses LangChain's ``create_agent`` with LangChain-native ``@tool``
definitions. Users with existing LangChain tools (700+ integrations)
can use them directly without rewriting them as PydanticAI tools.

Before running:

1. Install LangChain packages::

       pip install langchain langchain-openai langchain-text-splitters \\
           langchain-community faiss-cpu

2. Create an LLM connection of type ``langchain`` named ``langchain_default``
   (or the value of ``LLM_CONN_ID`` below) for your chosen model provider.
   Set ``password`` to your API key; the ``host`` field is optional and only
   needed for custom endpoints such as Ollama.
3. Optionally place a knowledge base directory at ``DOCS_PATH`` and a
   survey CSV at ``SURVEY_CSV_PATH``. If ``DOCS_PATH`` is empty, sample
   documents about Apache Airflow are created automatically. The FAISS index
   is built once and then reused from ``INDEX_PERSIST_DIR``; delete that
   directory to rebuild the index after changing the documents.
"""
from __future__ import annotations
import datetime
import os
from airflow.providers.common.ai.operators.llm import LLMOperator
from airflow.providers.common.compat.sdk import dag, task
from airflow.providers.standard.operators.hitl import ApprovalOperator, HITLEntryOperator
from airflow.sdk import Param
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Airflow connection ID (type ``langchain``) for the model provider.
# Presumably consumed by the DAG body when creating the LangChain hook.
# Overridable via the environment, like the settings below.
LLM_CONN_ID = os.environ.get("LLM_CONN_ID", "langchain_default")

# Chat model identifier in "provider:model" form (see the default value).
LLM_MODEL = os.environ.get("LLM_MODEL", "openai:gpt-4o")

# Embedding model identifier. Not referenced by the helpers below; presumably
# passed to the hook by the DAG body -- TODO confirm.
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "openai:text-embedding-3-small")

# Directory of ``.txt``/``.md`` documents that _ensure_knowledge_base indexes.
DOCS_PATH = os.environ.get("DOCS_PATH", "/opt/airflow/data/rag_documents")

# User-survey CSV; presumably passed to _build_tools as ``survey_csv_path``.
SURVEY_CSV_PATH = os.environ.get(
    "SURVEY_CSV_PATH",
    "/opt/airflow/data/airflow-user-survey-2025.csv",
)

# Where the FAISS index is saved; an existing index here is reused as-is.
INDEX_PERSIST_DIR = os.environ.get("INDEX_PERSIST_DIR", "/opt/airflow/data/langchain_faiss_index")

# Presumably the default research question offered for human review.
DEFAULT_QUESTION = (
    "What percentage of Airflow users are on Kubernetes? "
    "Also check what the documentation says about the KubernetesExecutor."
)

# Demo corpus (file name -> file content) written into DOCS_PATH so the
# knowledge base works without any user-supplied documents.
SAMPLE_DOCUMENTS = {
    "apache_airflow_overview.txt": (
        "Apache Airflow is an open-source platform for programmatically authoring, "
        "scheduling, and monitoring workflows. Originally created at Airbnb in 2014, "
        "it graduated from the Apache Incubator in 2019. Airflow uses directed acyclic "
        "graphs (DAGs) to define workflows as Python code, making pipelines versionable, "
        "testable, and collaborative. The scheduler executes tasks on workers following "
        "the defined dependencies. Airflow is widely used for ETL/ELT pipelines, ML model "
        "training orchestration, and data warehouse management. As of Airflow 3.0, workers "
        "communicate exclusively through the Execution API and never access the metadata "
        "database directly, strengthening security and enabling horizontal scaling."
    ),
    "kubernetes_executor.txt": (
        "The KubernetesExecutor runs each Airflow task as a separate Kubernetes pod. "
        "This provides strong isolation between tasks, dynamic resource allocation, and "
        "the ability to use different Docker images per task. When a task is scheduled, "
        "the executor creates a pod spec, submits it to the Kubernetes API, and monitors "
        "the pod until completion. Resource requests and limits can be set per task via "
        "executor_config. The KubernetesExecutor is recommended for heterogeneous "
        "workloads where tasks have different resource requirements or dependencies. "
        "It scales to zero when no tasks are running, reducing infrastructure costs. "
        "In Airflow 3.0, pod specs are submitted via the Execution API."
    ),
    "operators_and_hooks.txt": (
        "Operators are the building blocks of Airflow tasks. Each operator defines a "
        "single unit of work: BashOperator runs shell commands, PythonOperator executes "
        "Python callables, and provider-specific operators interact with external systems "
        "(S3, BigQuery, Spark, etc.). Hooks are the connection layer between operators "
        "and external services. A hook manages authentication and provides methods to "
        "interact with a specific service. For example, S3Hook provides methods to read "
        "and write S3 objects, while PostgresHook connects to PostgreSQL databases."
    ),
    "connections_and_variables.txt": (
        "Connections store credentials and endpoint information for external services. "
        "Each connection has a type (e.g., postgres, aws, http), login, password, host, "
        "port, schema, and an extras JSON field for additional parameters. In Airflow 3.0, "
        "workers access connections through the Execution API using short-lived JWT tokens "
        "scoped to the running task instance. Variables are key-value pairs for storing "
        "configuration that may change between environments."
    ),
    "ai_operators.txt": (
        "Airflow's common.ai provider (AIP-99) adds first-class AI/LLM support. "
        "LLMOperator sends a prompt to any supported LLM and returns text or structured "
        "output via Pydantic models. AgentOperator runs multi-turn reasoning with tools "
        "(SQL, HTTP, MCP servers). LLMBranchOperator uses an LLM to choose downstream "
        "task branches. All operators support human-in-the-loop review, durable execution "
        "for long-running agents, usage limits for cost control, and connect to 20+ model "
        "providers through Airflow connections."
    ),
}

# Presumably the system prompt for the format_report LLMOperator.
REPORT_SYSTEM_PROMPT = (
    "You are a technical report writer. Format the research findings into a "
    "clear, well-structured report with sections and bullet points. Cite "
    "sources when available. Be concise but thorough."
)
# ---------------------------------------------------------------------------
# Helper: build or load the knowledge base FAISS index
# ---------------------------------------------------------------------------
def _list_document_files(docs_dir: str) -> list[str]:
    """Return the sorted ``.txt``/``.md`` regular-file names directly inside *docs_dir*."""
    return sorted(
        name
        for name in os.listdir(docs_dir)
        if name.endswith((".txt", ".md")) and os.path.isfile(os.path.join(docs_dir, name))
    )


def _ensure_knowledge_base(hook) -> str:
    """Build the knowledge-base FAISS index unless a saved one already exists.

    The index is built once and persisted to ``INDEX_PERSIST_DIR``. When both
    saved index files are present, this returns immediately without reading
    ``DOCS_PATH`` or importing LangChain. Otherwise it chunks, embeds, and
    saves every ``.txt``/``.md`` file in ``DOCS_PATH``. If ``DOCS_PATH`` holds
    no such files, it is first seeded with ``SAMPLE_DOCUMENTS``.

    :param hook: LangChain hook whose ``get_embedding_model()`` supplies the
        embedding model used to vectorise the document chunks.
    :return: The persist directory path (``INDEX_PERSIST_DIR``).
    """
    # FAISS.save_local writes index.faiss *and* index.pkl. If either is missing
    # (e.g. an interrupted build), rebuild now instead of failing later in
    # FAISS.load_local.
    if all(
        os.path.isfile(os.path.join(INDEX_PERSIST_DIR, name))
        for name in ("index.faiss", "index.pkl")
    ):
        return INDEX_PERSIST_DIR

    # Imported lazily so the cached-index fast path above skips these imports.
    from langchain_community.vectorstores import FAISS
    from langchain_core.documents import Document
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    os.makedirs(DOCS_PATH, exist_ok=True)
    doc_files = _list_document_files(DOCS_PATH)
    if not doc_files:
        # Seed the demo corpus only into an empty knowledge base, so the
        # user's own documents are not mixed with the samples.
        for filename, content in SAMPLE_DOCUMENTS.items():
            with open(os.path.join(DOCS_PATH, filename), "w", encoding="utf-8") as f:
                f.write(content)
        doc_files = _list_document_files(DOCS_PATH)

    docs = []
    for filename in doc_files:
        with open(os.path.join(DOCS_PATH, filename), encoding="utf-8") as f:
            docs.append(Document(page_content=f.read(), metadata={"source": filename}))

    splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
    chunks = splitter.split_documents(docs)
    vectorstore = FAISS.from_documents(chunks, hook.get_embedding_model())
    os.makedirs(INDEX_PERSIST_DIR, exist_ok=True)
    vectorstore.save_local(INDEX_PERSIST_DIR)
    print(f"Built FAISS index: {len(chunks)} chunks in {INDEX_PERSIST_DIR}")
    return INDEX_PERSIST_DIR
# ---------------------------------------------------------------------------
# Tool definitions (LangChain @tool decorator)
# ---------------------------------------------------------------------------
def _cell(row: dict, column: str) -> str:
    """Return the stripped value of *column* in a ``csv.DictReader`` row.

    ``DictReader`` fills the missing trailing fields of short rows with
    ``None`` (its default ``restval``). A plain ``row.get(column, "").strip()``
    would therefore raise ``AttributeError`` on ragged CSV exports.
    """
    return (row.get(column) or "").strip()


def _match_canned_response(query: str, responses: dict[str, str]) -> str | None:
    """Return the response whose keyword best matches *query*, or ``None``.

    Each keyword is split into words and scored by how many of those words
    appear as whole tokens in the lower-cased query. The highest score wins,
    ties go to the entry listed first, and a zero score means no match.
    Whole-token matching stops "3" from matching "2023". Scoring stops the
    word "airflow", shared by every keyword, from always selecting the first
    entry.
    """
    import re

    tokens = set(re.findall(r"[a-z0-9]+", query.lower()))
    best_response: str | None = None
    best_score = 0
    for keyword, response in responses.items():
        score = sum(word in tokens for word in keyword.split())
        if score > best_score:
            best_response, best_score = response, score
    return best_response


def _build_tools(hook, index_dir: str, survey_csv_path: str) -> list:
    """Construct the LangChain tools the ReAct agent may call.

    :param hook: LangChain hook whose ``get_embedding_model()`` must match the
        embedding model the FAISS index was built with.
    :param index_dir: Directory holding a FAISS index saved with ``save_local``.
    :param survey_csv_path: Path to the user-survey CSV. A missing file is
        reported to the agent as a tool result instead of raising.
    :return: ``[search_knowledge_base, query_survey_data, search_web,
        get_current_utc_time]``.
    """
    from langchain.tools import tool
    from langchain_community.vectorstores import FAISS

    # Build the vector store once and close over it -- the agent may invoke
    # search_knowledge_base many times, and reloading the FAISS index plus
    # re-initialising the embedding model on every call would be wasteful.
    # NOTE(security): load_local unpickles index.pkl, hence the
    # allow_dangerous_deserialization flag. Only point index_dir at an index
    # this DAG built itself (see _ensure_knowledge_base), never at untrusted files.
    vectorstore = FAISS.load_local(
        index_dir, hook.get_embedding_model(), allow_dangerous_deserialization=True
    )

    # -- Tool 1: Knowledge base search (vector retrieval) ------------------
    @tool
    def search_knowledge_base(query: str) -> str:
        """Search the internal knowledge base for relevant documentation.
        Use this for questions about Airflow features, architecture,
        operators, executors, connections, or best practices.
        """
        results = vectorstore.similarity_search(query, k=3)
        if not results:
            return "No relevant documents found in the knowledge base."
        formatted = []
        for i, doc in enumerate(results, 1):
            source = doc.metadata.get("source", "unknown")
            formatted.append(f"[{i}] Source: {source}\n{doc.page_content}")
        return "\n\n".join(formatted)

    # -- Tool 2: Survey data query ----------------------------------------
    @tool
    def query_survey_data(question: str) -> str:
        """Query the Airflow user survey dataset for adoption statistics.
        Answers questions about KubernetesExecutor usage, CeleryExecutor
        usage, and the Airflow version distribution of respondents.
        Pass a natural language question; the tool matches it by keyword
        against the survey columns and returns counts and percentages.
        """
        import csv
        from collections import Counter

        if not os.path.exists(survey_csv_path):
            return (
                "Survey data not available. The CSV file was not found at "
                f"{survey_csv_path}. Continuing with other tools."
            )
        # The csv module requires newline=""; utf-8-sig strips the BOM that
        # spreadsheet exports often prepend to the first header.
        with open(survey_csv_path, encoding="utf-8-sig", newline="") as f:
            reader = csv.DictReader(f)
            rows = list(reader)
            # Take columns from the header rather than rows[0].keys(). If the
            # first row has extra fields, its keys include DictReader's None restkey.
            columns = list(reader.fieldnames or [])
        if not rows:
            return "Survey data is empty."
        total = len(rows)
        q_lower = question.lower()

        def find_column(*needles: str) -> str | None:
            # First header that contains every needle, case-insensitively.
            return next((c for c in columns if all(n in c.lower() for n in needles)), None)

        def percent(count: int) -> float:
            # total > 0 here: the empty-file case returned above.
            return round(100 * count / total, 1)

        # (words that trigger the topic, substring of the column name, label).
        # Checked in order, so Kubernetes wins when both executors are mentioned.
        executor_topics = (
            (("kubernetes", "k8s"), "kubernetes", "KubernetesExecutor"),
            (("celery",), "celery", "CeleryExecutor"),
        )
        for triggers, needle, label in executor_topics:
            if not any(trigger in q_lower for trigger in triggers):
                continue
            column = find_column(needle)
            if column:
                users = sum(1 for r in rows if _cell(r, column))
                return (
                    f"{label} usage: {users} of {total} "
                    f"respondents ({percent(users)}%) indicated they use {label}."
                )

        if "version" in q_lower:
            column = find_column("version", "airflow")
            if column:
                # Blank answers are bucketed as "unknown" rather than "".
                counts = Counter(_cell(r, column) or "unknown" for r in rows)
                lines = [f" {v}: {c} ({percent(c)}%)" for v, c in counts.most_common(5)]
                return "Airflow version distribution (top 5):\n" + "\n".join(lines)

        # Nothing matched: describe the real columns and the supported topics,
        # so the agent can rephrase instead of retrying an unsupported question.
        preview = ", ".join(columns[:15])
        if len(columns) > 15:
            preview += f", ... and {len(columns) - 15} more"
        return (
            f"Survey dataset has {total} responses across {len(columns)} columns "
            f"({preview}). This tool recognises questions about KubernetesExecutor "
            "usage, CeleryExecutor usage, and the Airflow version distribution; "
            "ask about one of those topics."
        )

    # -- Tool 3: Web search (simulated) ------------------------------------
    web_responses = {
        "kubernetes airflow": (
            "Recent blog posts indicate KubernetesExecutor adoption has grown "
            "significantly since Airflow 2.0, with many large-scale deployments "
            "migrating from CeleryExecutor. Key advantages cited: pod-level "
            "isolation, dynamic scaling, and per-task resource configuration. "
            "Source: Airflow blog, Astronomer blog (2025-2026)."
        ),
        "airflow 3": (
            "Airflow 3.0 shipped in April 2025 with major architectural changes: "
            "Execution API (workers never access metadata DB directly), multi-team "
            "isolation, improved UI, and the common.ai provider for AI/LLM support. "
            "Source: airflow.apache.org release notes."
        ),
        "airflow adoption": (
            "The 2025 Airflow User Survey showed continued growth: 2,000+ responses, "
            "40% of respondents at companies with 1,000+ employees, 35% using cloud-managed "
            "Airflow (Astronomer, MWAA, Cloud Composer). Source: Airflow blog."
        ),
    }

    @tool
    def search_web(query: str) -> str:
        """Search the web for current information, news, or context.
        Use this for questions that need up-to-date external information
        not available in the knowledge base or survey data.
        """
        response = _match_canned_response(query, web_responses)
        if response is not None:
            return response
        return (
            f"Web search for '{query}' returned general results. "
            "For this demo, web search is simulated with canned responses. "
            "In production, use Tavily, Serper, or another search API "
            "configured via an Airflow connection."
        )

    # -- Tool 4: Current UTC time -----------------------------------------
    @tool
    def get_current_utc_time() -> str:
        """Return the current UTC date and time in ISO-8601 format.
        Use when the question depends on a current timestamp (e.g. "is the
        merge freeze active right now", "how recent is the survey data").
        LLMs cannot reliably know the wall-clock time on their own.
        """
        return datetime.datetime.now(datetime.timezone.utc).isoformat(timespec="seconds")

    return [search_knowledge_base, query_survey_data, search_web, get_current_utc_time]
# ---------------------------------------------------------------------------
# DAG: ReAct tool-calling agent with human review
# ---------------------------------------------------------------------------
# [START example_langchain_tool_agent]
@dag(tags=["example"])
# NOTE(review): the decorated ``def example_langchain_tool_agent():`` and its
# task wiring are missing from this view. A decorator must be followed by a
# ``def`` (or ``class``), so these lines do not parse as shown. Per the module
# docstring, the body chains prompt_review -> prepare_tools ->
# run_research_agent -> format_report -> report_approval. Restore it from the
# upstream source rather than reconstructing it here -- TODO confirm.
# [END example_langchain_tool_agent]
# Module-level call instantiates the DAG defined by the decorated function.
example_langchain_tool_agent()