445 lines
17 KiB
Python
445 lines
17 KiB
Python
"""Graph Orchestrator - LangGraph-based multi-agent workflow orchestration."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import TypedDict
|
|
|
|
from langgraph.graph import END, START, StateGraph
|
|
|
|
# Module-level logger; AppFactoryOrchestrator uses it to report agent,
# dev-dispatch, QA and workspace-cleanup failures.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AppFactoryState(TypedDict):
    """Global state passed between the nodes of the orchestration graph.

    Task ids are normalised to ``str`` wherever the orchestrator stores them:
    the entries of ``completed_tasks`` and the keys of ``active_tasks`` and
    ``blocked_tasks``.
    """

    # Raw prompt given to AppFactoryOrchestrator.run.
    user_input: str
    # PRD produced by the PM agent; empty when there was no input or expansion failed.
    prd: str
    # Latest list returned by the task agent's get_unblocked_tasks(). It is
    # refreshed on every task_node pass, so it is NOT necessarily every task.
    tasks: list[dict]
    # task_id -> {status, container_id, worktree_path} for every dispatched task.
    active_tasks: dict[str, dict]
    # Ids of tasks whose dev run succeeded (QA may remove them again on rejection).
    completed_tasks: list[str]
    # task_id -> human-readable reason the task is waiting.
    blocked_tasks: dict[str, str]
    # Pending clarification dicts (requesting_agent, task_id, question, context).
    clarification_requests: list[dict]
    # Architecture summary handed to dev agents.
    global_architecture: str
    # Number of task_node passes so far; safety counter against infinite loops.
    iteration_count: int
    # Loop stops once iteration_count reaches this value (run() defaults to 50).
    max_iterations: int
    # Accumulated error messages.
    errors: list[str]
|
|
|
|
|
|
class AppFactoryOrchestrator:
    """Main LangGraph state machine for the App Factory.

    Flow: ``pm_node`` expands the user prompt into a PRD, then the graph loops
    ``task_node -> dev_dispatch_node -> qa_node``. It detours through
    ``clarification_node`` when agents raise questions or tasks are blocked.
    The loop ends when every known task is completed, nothing more can be
    dispatched, or ``max_iterations`` task_node passes have run.

    Every collaborator is optional. When one is ``None``, the node that uses it
    takes a mock / no-op path. This lets the graph run without agents,
    containers or git worktrees (e.g. in tests).
    """

    # Most graph nodes that can execute per task_node pass: the longest cycle is
    # task_node -> dev_dispatch_node -> qa_node -> clarification_node.
    _NODES_PER_ITERATION = 4
    # Extra LangGraph steps allowed on top of the loop (pm_node plus slack).
    _RECURSION_HEADROOM = 10

    def __init__(
        self,
        pm_agent=None,
        task_agent=None,
        dev_manager=None,
        qa_agent=None,
        workspace_manager=None,
        observability=None,
    ):
        """Store the collaborating agents/managers; any of them may be ``None``."""
        self.pm_agent = pm_agent
        self.task_agent = task_agent
        self.dev_manager = dev_manager
        self.qa_agent = qa_agent
        self.workspace_manager = workspace_manager
        self.observability = observability

    def build_graph(self):
        """Build and compile the orchestration ``StateGraph``.

        Returns:
            The result of ``StateGraph.compile()`` (the compiled graph, not the
            ``StateGraph`` builder); :meth:`run` calls its ``ainvoke``.
        """
        graph = StateGraph(AppFactoryState)

        graph.add_node("pm_node", self._pm_node)
        graph.add_node("task_node", self._task_node)
        graph.add_node("dev_dispatch_node", self._dev_dispatch_node)
        graph.add_node("qa_node", self._qa_node)
        graph.add_node("clarification_node", self._clarification_node)

        graph.add_edge(START, "pm_node")
        graph.add_conditional_edges(
            "pm_node",
            self._should_continue_after_pm,
            {
                "task_node": "task_node",
                "end": END,
            },
        )
        graph.add_conditional_edges(
            "task_node",
            self._should_continue_after_tasks,
            {
                "dev_dispatch": "dev_dispatch_node",
                "end": END,
                "clarification": "clarification_node",
            },
        )
        graph.add_edge("dev_dispatch_node", "qa_node")
        graph.add_conditional_edges(
            "qa_node",
            self._should_continue_after_qa,
            {
                "task_node": "task_node",
                "clarification": "clarification_node",
                "end": END,
            },
        )
        graph.add_edge("clarification_node", "task_node")

        return graph.compile()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _task_id(task: dict) -> str:
        """Return ``task["id"]`` as ``str`` ('' when the task has no id)."""
        return str(task.get("id", ""))

    def _dispatchable_tasks(self, state: dict) -> list:
        """Return the tasks in ``state["tasks"]`` that dev agents should run next.

        A task qualifies when all of the following hold:
          * it is not in ``completed_tasks`` and its status is not ``"done"``;
          * every dependency id is in ``completed_tasks``;
          * it has never been dispatched, or QA sent it back for another attempt
            (it is listed in ``blocked_tasks``).

        The routing function and the dispatch node share this helper, so the
        router never sends the graph to dispatch for tasks dispatch would skip.
        """
        completed = set(state.get("completed_tasks", []))
        attempted = state.get("active_tasks", {})
        retry_pending = state.get("blocked_tasks", {})

        ready = []
        for task in state.get("tasks", []):
            tid = self._task_id(task)
            if tid in completed or task.get("status") == "done":
                continue
            if tid in attempted and tid not in retry_pending:
                # Already ran and was not handed back for a retry.
                continue
            if all(str(dep) in completed for dep in task.get("dependencies", [])):
                ready.append(task)
        return ready

    async def _cleanup_workspace(self, tid: str) -> None:
        """Best-effort removal of a task's workspace; failures are only logged."""
        if not self.workspace_manager:
            return
        try:
            await self.workspace_manager.cleanup_workspace(tid)
        except Exception as e:
            logger.warning("Workspace cleanup failed for task %s: %s", tid, e)

    # ------------------------------------------------------------------ routing

    def _should_continue_after_pm(self, state: dict) -> str:
        """Route after pm_node: ``'task_node'`` when a non-blank PRD exists, else ``'end'``."""
        prd = state.get("prd", "")
        if prd and prd.strip():
            return "task_node"
        # PM failure (or empty prompt) yields no PRD and should terminate cleanly.
        return "end"

    def _should_continue_after_tasks(self, state: dict) -> str:
        """Route after task_node: ``'dev_dispatch'`` | ``'clarification'`` | ``'end'``."""
        if state.get("iteration_count", 0) >= state.get("max_iterations", 50):
            return "end"

        completed = set(state.get("completed_tasks", []))
        all_task_ids = {self._task_id(t) for t in state.get("tasks", [])}
        if all_task_ids and all_task_ids <= completed:
            return "end"

        if self._dispatchable_tasks(state):
            return "dev_dispatch"

        # Nothing runnable: let the PM agent look at blocked tasks / open questions.
        if state.get("blocked_tasks") or state.get("clarification_requests"):
            return "clarification"

        # No tasks at all, or everything left has already been attempted.
        return "end"

    def _should_continue_after_qa(self, state: dict) -> str:
        """Route after qa_node: ``'task_node'`` | ``'clarification'`` | ``'end'``."""
        if state.get("iteration_count", 0) >= state.get("max_iterations", 50):
            return "end"
        if state.get("clarification_requests"):
            return "clarification"
        # Loop back to check for newly unblocked tasks.
        return "task_node"

    # -------------------------------------------------------------------- nodes

    async def _pm_node(self, state: dict) -> dict:
        """Expand ``user_input`` into a PRD via the PM agent.

        Returns a partial state update with ``prd`` (empty on failure) and, on
        failure, ``errors`` extended with the reason.
        """
        if self.observability:
            self.observability.log_state_transition("start", "pm_node")

        user_input = state.get("user_input", "")
        if not user_input:
            return {"prd": "", "errors": state.get("errors", []) + ["No user input provided"]}

        if self.pm_agent is None:
            return {"prd": f"Mock PRD for: {user_input}"}

        try:
            prd = await self.pm_agent.expand_prompt_to_prd(user_input)
            return {"prd": prd}
        except Exception as e:
            logger.error("PM agent failed: %s", e)
            return {"prd": "", "errors": state.get("errors", []) + [f"PM agent error: {e}"]}

    async def _task_node(self, state: dict) -> dict:
        """Parse the PRD (first pass only) and refresh the unblocked task list.

        Always increments ``iteration_count``. Once it reaches
        ``max_iterations``, the node records an error and does no further work.
        """
        if self.observability:
            self.observability.log_state_transition("pm_node/qa_node/clarification_node", "task_node")

        iteration_count = state.get("iteration_count", 0) + 1
        updates = {"iteration_count": iteration_count}

        if iteration_count >= state.get("max_iterations", 50):
            updates["errors"] = state.get("errors", []) + ["Max iterations reached"]
            return updates

        if self.task_agent is None:
            return updates

        try:
            # Parse the PRD only on the very first pass. ``tasks`` cannot be used
            # as the marker: it legitimately becomes empty again later (e.g. when
            # nothing is unblocked), and re-parsing would duplicate every task.
            if iteration_count == 1:
                prd = state.get("prd", "")
                if prd:
                    await self.task_agent.parse_prd(prd)
            updates["tasks"] = await self.task_agent.get_unblocked_tasks()
        except Exception as e:
            logger.error("Task agent failed: %s", e)
            updates["errors"] = state.get("errors", []) + [f"Task agent error: {e}"]

        return updates

    async def _dev_dispatch_node(self, state: dict) -> dict:
        """Run dev agents concurrently for every dispatchable task.

        Successful tasks are appended to ``completed_tasks``. ``needs_clarification``
        results become clarification requests, and any other status is recorded
        in ``errors``. Tasks that are re-dispatched after a QA rejection are
        removed from ``blocked_tasks``.
        """
        if self.observability:
            self.observability.log_state_transition("task_node", "dev_dispatch_node")

        to_execute = self._dispatchable_tasks(state)
        if not to_execute:
            return {}

        active_tasks = dict(state.get("active_tasks", {}))
        # Copy as a list to keep the existing order (this list is persisted by
        # save_state); round-tripping through a set would scramble it.
        completed = list(state.get("completed_tasks", []))
        blocked_tasks = dict(state.get("blocked_tasks", {}))
        errors = list(state.get("errors", []))
        clarification_requests = list(state.get("clarification_requests", []))
        global_arch = state.get("global_architecture", "")

        # Tasks being (re)dispatched are no longer waiting for a retry.
        for task in to_execute:
            blocked_tasks.pop(self._task_id(task), None)

        if self.dev_manager is None or self.workspace_manager is None:
            # Mock execution for testing: every task succeeds immediately.
            for task in to_execute:
                tid = self._task_id(task)
                active_tasks[tid] = {"status": "success", "container_id": "mock", "worktree_path": "/mock"}
                completed.append(tid)
            return {
                "active_tasks": active_tasks,
                "completed_tasks": completed,
                "blocked_tasks": blocked_tasks,
            }

        async def _execute_single(task):
            """Provision a workspace and run one task; never raises (errors become a failed result)."""
            tid = self._task_id(task)
            worktree_path = None
            container_id = ""
            try:
                worktree_path = await self.workspace_manager.create_worktree(tid)
                container = await self.workspace_manager.spin_up_clean_room(worktree_path, tid)
                container_id = container.id

                if self.task_agent:
                    await self.task_agent.update_task_status(tid, "in-progress")

                result = await self.dev_manager.execute_with_retry(
                    task, container_id, worktree_path, global_arch
                )
            except Exception as e:
                logger.error("Dev dispatch failed for task %s: %s", tid, e)
                result = {"status": "failed", "output": str(e), "files_changed": [], "exit_code": -1}
            return tid, result, worktree_path, container_id

        results = await asyncio.gather(
            *(_execute_single(t) for t in to_execute), return_exceptions=True
        )

        for item in results:
            # BaseException also covers asyncio.CancelledError, which gather
            # returns (rather than raises) when return_exceptions=True.
            if isinstance(item, BaseException):
                errors.append(f"Dev dispatch exception: {item}")
                continue

            tid, result, worktree_path, container_id = item
            status = result.get("status", "failed")
            active_tasks[tid] = {
                "status": status,
                # Fall back to the container we actually started.
                "container_id": result.get("container_id") or container_id,
                "worktree_path": worktree_path or "",
            }

            if status == "success":
                completed.append(tid)
            elif status == "needs_clarification":
                output = str(result.get("output", ""))
                clarification_requests.append({
                    "requesting_agent": "dev_agent",
                    "task_id": tid,
                    "question": f"Task {tid} failed after retries. Output: {output[:500]}",
                    "context": output[:1000],
                })
            else:
                errors.append(f"Dev agent failed for task {tid} (status '{status}')")

        return {
            "active_tasks": active_tasks,
            "completed_tasks": completed,
            "blocked_tasks": blocked_tasks,
            "errors": errors,
            "clarification_requests": clarification_requests,
        }

    async def _qa_node(self, state: dict) -> dict:
        """Review and merge every task whose dev run succeeded.

        * ``"merged"``: the task is marked done in the task agent.
        * Any other verdict: the task leaves ``completed_tasks``. If QA retries
          are exhausted, a clarification request is raised; otherwise the task
          is put in ``blocked_tasks`` so dev dispatch runs it again.

        The workspace is cleaned up once a verdict has been recorded.
        """
        if self.observability:
            self.observability.log_state_transition("dev_dispatch_node", "qa_node")

        active_tasks = dict(state.get("active_tasks", {}))
        completed = list(state.get("completed_tasks", []))
        errors = list(state.get("errors", []))
        clarification_requests = list(state.get("clarification_requests", []))
        blocked_tasks = dict(state.get("blocked_tasks", {}))

        completed_ids = set(completed)
        tasks_for_qa = [
            (tid, info)
            for tid, info in active_tasks.items()
            if info.get("status") == "success" and tid in completed_ids
        ]
        if not tasks_for_qa or self.qa_agent is None:
            return {}

        tasks_by_id = {self._task_id(t): t for t in state.get("tasks", [])}
        # Not every QA agent is guaranteed to expose max_retries; default to 3.
        max_retries = getattr(self.qa_agent, "max_retries", 3)

        for tid, info in tasks_for_qa:
            worktree_path = info.get("worktree_path", "")
            if not worktree_path:
                continue

            verdict_recorded = False
            try:
                qa_result = await self.qa_agent.review_and_merge(
                    tid, worktree_path, task=tasks_by_id.get(tid)
                )
                qa_status = qa_result.get("status", "")
                # Record the verdict before any further await. A later failure
                # (e.g. in update_task_status) must not leave the task at
                # "success", or it would be reviewed/merged again. Copy the entry
                # instead of mutating the dict shared with the incoming state.
                active_tasks[tid] = {**info, "status": qa_status}
                verdict_recorded = True

                if qa_status == "merged":
                    if self.task_agent:
                        await self.task_agent.update_task_status(tid, "done")
                else:
                    if qa_result.get("retry_count", 0) >= max_retries:
                        clarification_requests.append({
                            "requesting_agent": "qa_agent",
                            "task_id": tid,
                            "question": f"QA failed for task {tid} with status '{qa_status}'",
                            "context": str(qa_result),
                        })
                    else:
                        # Marks the task for another dev attempt.
                        blocked_tasks[tid] = f"QA {qa_status}: {qa_result}"
                    # Rejected work was not merged, so it must not satisfy
                    # dependencies or count towards "all tasks done".
                    if tid in completed:
                        completed.remove(tid)
            except Exception as e:
                logger.error("QA failed for task %s: %s", tid, e)
                errors.append(f"QA error for task {tid}: {e}")

            # If review_and_merge itself raised, the task is still "success" and
            # is reviewed again next pass, which needs the worktree.
            if verdict_recorded:
                await self._cleanup_workspace(tid)

        return {
            "active_tasks": active_tasks,
            "completed_tasks": completed,
            "errors": errors,
            "clarification_requests": clarification_requests,
            "blocked_tasks": blocked_tasks,
        }

    async def _clarification_node(self, state: dict) -> dict:
        """Send pending clarification requests to the PM agent.

        A resolved request unblocks its task (removes it from ``blocked_tasks``).
        A request whose handling raised stays pending and is recorded in ``errors``.
        """
        if self.observability:
            self.observability.log_state_transition("task_node/qa_node", "clarification_node")

        requests = list(state.get("clarification_requests", []))
        if not requests:
            return {"clarification_requests": []}

        if self.pm_agent is None:
            # Clear requests without processing for testing.
            return {"clarification_requests": [], "blocked_tasks": {}}

        blocked_tasks = dict(state.get("blocked_tasks", {}))
        errors = list(state.get("errors", []))
        remaining = []

        for req in requests:
            try:
                answer = await self.pm_agent.handle_clarification_request(req)
                tid = req.get("task_id", "")
            except Exception as e:
                logger.error("Clarification failed: %s", e)
                errors.append(f"Clarification error: {e}")
                remaining.append(req)
                continue

            if tid:
                blocked_tasks.pop(tid, None)
            # NOTE(review): the answer is not stored in state, so dev agents do
            # not see it yet; log it so it is at least not silently lost.
            logger.info("Clarification resolved for task %s: %s", tid, answer)

        return {
            "clarification_requests": remaining,
            "blocked_tasks": blocked_tasks,
            "errors": errors,
        }

    # ------------------------------------------------------------ entry points

    async def run(self, user_input: str, max_iterations: int = 50) -> dict:
        """Build the graph, execute it from a fresh state, persist and return the final state.

        Args:
            user_input: Prompt handed to the PM agent.
            max_iterations: Cap on task_node passes (the loop safety counter).
        """
        compiled = self.build_graph()

        initial_state = {
            "user_input": user_input,
            "prd": "",
            "tasks": [],
            "active_tasks": {},
            "completed_tasks": [],
            "blocked_tasks": {},
            "clarification_requests": [],
            "global_architecture": "",
            "iteration_count": 0,
            "max_iterations": max_iterations,
            "errors": [],
        }

        if self.observability:
            self.observability.log_state_transition("init", "run")

        # LangGraph aborts with GraphRecursionError once a run exceeds
        # `recursion_limit` steps. Its default (25 in the current docs) is
        # reached after roughly 8 task/dev/qa passes, long before
        # max_iterations can stop the loop, and the exception would also skip
        # save_state. Size the limit from the loop bound instead.
        recursion_limit = (
            self._NODES_PER_ITERATION * max(max_iterations, 1) + self._RECURSION_HEADROOM
        )
        result = await compiled.ainvoke(initial_state, config={"recursion_limit": recursion_limit})

        self.save_state(result)
        return result

    def save_state(self, state: dict, path: str = "app_factory/data/state.json"):
        """Write ``state`` to ``path`` as JSON.

        Values that ``json`` cannot encode are stored as their ``str()`` form.
        The parent directory is created when ``path`` has one.
        """
        directory = os.path.dirname(path)
        # os.makedirs("") raises, so skip it for a bare file name.
        if directory:
            os.makedirs(directory, exist_ok=True)

        serializable = {}
        for key, value in state.items():
            try:
                json.dumps(value)
            except (TypeError, ValueError):
                value = str(value)
            serializable[key] = value

        with open(path, "w", encoding="utf-8") as f:
            json.dump(serializable, f, indent=2)

    def load_state(self, path: str = "app_factory/data/state.json") -> dict:
        """Read a state dict previously written by :meth:`save_state`."""
        with open(path, encoding="utf-8") as f:
            return json.load(f)
|