first commit
This commit is contained in:
444
app_factory/core/graph.py
Normal file
444
app_factory/core/graph.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""Graph Orchestrator - LangGraph-based multi-agent workflow orchestration."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TypedDict
|
||||
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppFactoryState(TypedDict):
    """Global state passed through the orchestration graph.

    Each graph node returns a partial dict of these keys; LangGraph merges
    the partial update into the running state between nodes.
    """

    user_input: str  # Raw prompt from the user that seeds the whole run
    prd: str  # Product requirements document produced by the PM agent
    tasks: list[dict]  # All tasks from task-master
    active_tasks: dict[str, dict]  # task_id -> {status, container_id, worktree_path}
    completed_tasks: list[str]  # List of completed task_ids
    blocked_tasks: dict[str, str]  # task_id -> reason
    clarification_requests: list[dict]  # Pending clarification dicts
    global_architecture: str  # Architecture summary for dev agents
    iteration_count: int  # Safety counter to prevent infinite loops
    max_iterations: int  # Max loop iterations (default 50)
    errors: list[str]  # Error log
|
||||
|
||||
|
||||
class AppFactoryOrchestrator:
    """Main LangGraph state machine for the App Factory.

    Wires five async nodes (PM -> task planning -> concurrent dev dispatch ->
    QA -> clarification) into a compiled StateGraph over ``AppFactoryState``.
    Every collaborator is optional: when an agent/manager is ``None`` the
    corresponding node takes a mock path, which keeps the graph testable
    without real agents, containers, or worktrees.
    """

    def __init__(
        self,
        pm_agent=None,
        task_agent=None,
        dev_manager=None,
        qa_agent=None,
        workspace_manager=None,
        observability=None,
    ):
        # All collaborators are duck-typed; each may be None (mock mode).
        self.pm_agent = pm_agent
        self.task_agent = task_agent
        self.dev_manager = dev_manager
        self.qa_agent = qa_agent
        self.workspace_manager = workspace_manager
        self.observability = observability

    def build_graph(self) -> StateGraph:
        """Build and compile the LangGraph StateGraph with nodes and edges.

        Returns:
            The compiled graph (LangGraph ``CompiledGraph``), ready for
            ``ainvoke`` with an ``AppFactoryState``-shaped dict.
        """
        graph = StateGraph(AppFactoryState)

        graph.add_node("pm_node", self._pm_node)
        graph.add_node("task_node", self._task_node)
        graph.add_node("dev_dispatch_node", self._dev_dispatch_node)
        graph.add_node("qa_node", self._qa_node)
        graph.add_node("clarification_node", self._clarification_node)

        # Entry: every run starts with the PM expanding the prompt into a PRD.
        graph.add_edge(START, "pm_node")
        graph.add_conditional_edges(
            "pm_node",
            self._should_continue_after_pm,
            {
                "task_node": "task_node",
                "end": END,
            },
        )
        graph.add_conditional_edges(
            "task_node",
            self._should_continue_after_tasks,
            {
                "dev_dispatch": "dev_dispatch_node",
                "end": END,
                "clarification": "clarification_node",
            },
        )
        # Dev work always flows into QA; QA decides where to go next.
        graph.add_edge("dev_dispatch_node", "qa_node")
        graph.add_conditional_edges(
            "qa_node",
            self._should_continue_after_qa,
            {
                "task_node": "task_node",
                "clarification": "clarification_node",
                "end": END,
            },
        )
        # After clarification, loop back to re-plan with the new answers.
        graph.add_edge("clarification_node", "task_node")

        return graph.compile()

    def _should_continue_after_pm(self, state: dict) -> str:
        """Routing function after pm_node: 'task_node' | 'end'."""
        prd = state.get("prd", "")
        if prd and prd.strip():
            return "task_node"

        # PM failure (or empty prompt) yields no PRD and should terminate cleanly.
        return "end"

    def _should_continue_after_tasks(self, state: dict) -> str:
        """Routing function after task_node: 'dev_dispatch' | 'end' | 'clarification'."""
        # Hard stop: the iteration budget bounds the task/dev/qa loop.
        if state.get("iteration_count", 0) >= state.get("max_iterations", 50):
            return "end"

        tasks = state.get("tasks", [])
        completed = set(state.get("completed_tasks", []))
        # IDs are compared as strings throughout (task-master may emit ints).
        all_task_ids = {str(t.get("id", "")) for t in tasks}

        # Check if all tasks are done
        if all_task_ids and all_task_ids <= completed:
            return "end"

        # Check for unblocked tasks (pending tasks with all deps done)
        unblocked = []
        for t in tasks:
            if str(t.get("id", "")) in completed:
                continue
            if t.get("status") == "done":
                continue
            deps = [str(d) for d in t.get("dependencies", [])]
            if all(d in completed for d in deps):
                unblocked.append(t)

        if unblocked:
            return "dev_dispatch"

        # No unblocked tasks - if there are blocked ones, try clarification
        if state.get("blocked_tasks") or state.get("clarification_requests"):
            return "clarification"

        # No tasks at all or nothing left to do
        return "end"

    def _should_continue_after_qa(self, state: dict) -> str:
        """Routing function after qa_node: 'task_node' | 'clarification' | 'end'."""
        if state.get("iteration_count", 0) >= state.get("max_iterations", 50):
            return "end"

        # Pending questions take priority over scheduling more dev work.
        if state.get("clarification_requests"):
            return "clarification"

        # Loop back to check for newly unblocked tasks
        return "task_node"

    async def _pm_node(self, state: dict) -> dict:
        """Call PM agent to expand user input into a PRD.

        Returns a partial state update: ``prd`` (empty string on failure,
        which routes the graph to END) and possibly an appended ``errors``.
        """
        if self.observability:
            self.observability.log_state_transition("start", "pm_node")

        user_input = state.get("user_input", "")
        if not user_input:
            return {"prd": "", "errors": state.get("errors", []) + ["No user input provided"]}

        # Mock path for tests: no PM agent configured.
        if self.pm_agent is None:
            return {"prd": f"Mock PRD for: {user_input}"}

        try:
            prd = await self.pm_agent.expand_prompt_to_prd(user_input)
            return {"prd": prd}
        except Exception as e:
            logger.error("PM agent failed: %s", e)
            # Empty PRD signals _should_continue_after_pm to terminate.
            return {"prd": "", "errors": state.get("errors", []) + [f"PM agent error: {e}"]}

    async def _task_node(self, state: dict) -> dict:
        """Parse PRD into tasks or get unblocked tasks. Increments iteration_count.

        First pass (no tasks yet) feeds the PRD to the task agent; later
        passes just refresh the unblocked-task list. Note the state's
        ``tasks`` is overwritten with only the *unblocked* tasks here.
        """
        if self.observability:
            self.observability.log_state_transition("pm_node/qa_node/clarification_node", "task_node")

        # This is the loop's single counter increment; routing functions
        # read it to enforce max_iterations.
        iteration_count = state.get("iteration_count", 0) + 1
        updates = {"iteration_count": iteration_count}

        if iteration_count >= state.get("max_iterations", 50):
            updates["errors"] = state.get("errors", []) + ["Max iterations reached"]
            return updates

        # Mock path: no task agent means tasks stay as-is.
        if self.task_agent is None:
            return updates

        try:
            existing_tasks = state.get("tasks", [])
            if not existing_tasks:
                # First pass - parse the PRD
                prd = state.get("prd", "")
                if prd:
                    await self.task_agent.parse_prd(prd)
                unblocked = await self.task_agent.get_unblocked_tasks()
                updates["tasks"] = unblocked
            else:
                # Subsequent passes - refresh unblocked tasks
                unblocked = await self.task_agent.get_unblocked_tasks()
                updates["tasks"] = unblocked
        except Exception as e:
            logger.error("Task agent failed: %s", e)
            updates["errors"] = state.get("errors", []) + [f"Task agent error: {e}"]

        return updates

    async def _dev_dispatch_node(self, state: dict) -> dict:
        """Dispatch dev agents concurrently for unblocked tasks.

        For each runnable task: create a git worktree, spin up an isolated
        container, run the dev agent with retry, and record the outcome in
        ``active_tasks``. Successes are appended to ``completed_tasks``;
        exhausted retries become clarification requests.
        """
        if self.observability:
            self.observability.log_state_transition("task_node", "dev_dispatch_node")

        tasks = state.get("tasks", [])
        completed = set(state.get("completed_tasks", []))
        # Copy mutable state before mutating so the node returns fresh objects.
        active_tasks = dict(state.get("active_tasks", {}))
        errors = list(state.get("errors", []))
        clarification_requests = list(state.get("clarification_requests", []))
        global_arch = state.get("global_architecture", "")

        # Filter to unblocked, not-yet-completed tasks
        to_execute = []
        for t in tasks:
            tid = str(t.get("id", ""))
            # Skip tasks already done or already dispatched in a prior pass.
            if tid in completed or tid in active_tasks:
                continue
            deps = [str(d) for d in t.get("dependencies", [])]
            if all(d in completed for d in deps):
                to_execute.append(t)

        if not to_execute:
            return {}

        if self.dev_manager is None or self.workspace_manager is None:
            # Mock execution for testing
            # NOTE(review): list(set) loses the original completed_tasks
            # ordering and is not deterministic across runs - confirm callers
            # never rely on order.
            new_completed = list(completed)
            for t in to_execute:
                tid = str(t.get("id", ""))
                active_tasks[tid] = {"status": "success", "container_id": "mock", "worktree_path": "/mock"}
                new_completed.append(tid)
            return {"active_tasks": active_tasks, "completed_tasks": new_completed}

        async def _execute_single(task):
            # Runs one task end-to-end; never raises - failures are encoded
            # in the returned result dict so gather() stays simple.
            tid = str(task.get("id", ""))
            worktree_path = None
            container = None
            try:
                worktree_path = await self.workspace_manager.create_worktree(tid)
                container = await self.workspace_manager.spin_up_clean_room(worktree_path, tid)
                container_id = container.id

                if self.task_agent:
                    await self.task_agent.update_task_status(tid, "in-progress")

                result = await self.dev_manager.execute_with_retry(
                    task, container_id, worktree_path, global_arch
                )
                return tid, result, worktree_path
            except Exception as e:
                logger.error("Dev dispatch failed for task %s: %s", tid, e)
                return tid, {"status": "failed", "output": str(e), "files_changed": [], "exit_code": -1}, worktree_path

        # Execute concurrently
        results = await asyncio.gather(*[_execute_single(t) for t in to_execute], return_exceptions=True)

        # See NOTE(review) above on list(set) ordering.
        new_completed = list(completed)
        for item in results:
            # return_exceptions=True surfaces unexpected failures as values.
            if isinstance(item, Exception):
                errors.append(f"Dev dispatch exception: {item}")
                continue

            tid, result, worktree_path = item
            status = result.get("status", "failed")
            active_tasks[tid] = {
                "status": status,
                # NOTE(review): execute_with_retry's result may not include
                # "container_id" - confirm against DevManager, else this is "".
                "container_id": result.get("container_id", ""),
                "worktree_path": worktree_path or "",
            }

            if status == "success":
                new_completed.append(tid)
            elif status == "needs_clarification":
                clarification_requests.append({
                    "requesting_agent": "dev_agent",
                    "task_id": tid,
                    "question": f"Task {tid} failed after retries. Output: {result.get('output', '')[:500]}",
                    "context": result.get("output", "")[:1000],
                })

        return {
            "active_tasks": active_tasks,
            "completed_tasks": new_completed,
            "errors": errors,
            "clarification_requests": clarification_requests,
        }

    async def _qa_node(self, state: dict) -> dict:
        """Run QA on completed dev tasks.

        Reviews and merges each dev-successful task's worktree. On QA
        failure the task is either escalated to clarification (retries
        exhausted) or pulled back out of ``completed_tasks`` for retry.
        Workspaces are cleaned up after QA regardless of outcome.
        """
        if self.observability:
            self.observability.log_state_transition("dev_dispatch_node", "qa_node")

        active_tasks = dict(state.get("active_tasks", {}))
        completed = list(state.get("completed_tasks", []))
        errors = list(state.get("errors", []))
        clarification_requests = list(state.get("clarification_requests", []))
        blocked_tasks = dict(state.get("blocked_tasks", {}))

        # Find tasks that were successfully completed by dev and need QA
        tasks_for_qa = []
        for tid, info in active_tasks.items():
            if info.get("status") == "success" and tid in completed:
                tasks_for_qa.append((tid, info))

        # Mock path: with no QA agent, dev successes stand unreviewed.
        if not tasks_for_qa or self.qa_agent is None:
            return {}

        for tid, info in tasks_for_qa:
            worktree_path = info.get("worktree_path", "")
            if not worktree_path:
                continue

            try:
                # Find the task dict for context
                task_dict = None
                for t in state.get("tasks", []):
                    if str(t.get("id", "")) == tid:
                        task_dict = t
                        break

                qa_result = await self.qa_agent.review_and_merge(tid, worktree_path, task=task_dict)
                qa_status = qa_result.get("status", "")

                if qa_status == "merged":
                    # Successfully merged - update task status
                    if self.task_agent:
                        await self.task_agent.update_task_status(tid, "done")
                    active_tasks[tid]["status"] = "merged"
                else:
                    # QA failed - may need clarification or retry
                    retry_count = qa_result.get("retry_count", 0)
                    # qa_agent is non-None here; the "else 3" guard is defensive.
                    if retry_count >= (self.qa_agent.max_retries if self.qa_agent else 3):
                        clarification_requests.append({
                            "requesting_agent": "qa_agent",
                            "task_id": tid,
                            "question": f"QA failed for task {tid} with status '{qa_status}'",
                            "context": str(qa_result),
                        })
                    else:
                        blocked_tasks[tid] = f"QA {qa_status}: {qa_result}"
                        # Remove from completed so it can be retried
                        if tid in completed:
                            completed.remove(tid)
                    active_tasks[tid]["status"] = qa_status

                # Cleanup workspace after QA
                if self.workspace_manager:
                    try:
                        await self.workspace_manager.cleanup_workspace(tid)
                    except Exception as e:
                        # Best-effort cleanup: a leaked workspace should not
                        # fail the QA pass.
                        logger.warning("Workspace cleanup failed for task %s: %s", tid, e)

            except Exception as e:
                logger.error("QA failed for task %s: %s", tid, e)
                errors.append(f"QA error for task {tid}: {e}")

        return {
            "active_tasks": active_tasks,
            "completed_tasks": completed,
            "errors": errors,
            "clarification_requests": clarification_requests,
            "blocked_tasks": blocked_tasks,
        }

    async def _clarification_node(self, state: dict) -> dict:
        """Handle clarification requests via PM agent.

        Each answered request unblocks its task; failed requests stay in
        ``clarification_requests`` for the next pass.
        """
        if self.observability:
            self.observability.log_state_transition("task_node/qa_node", "clarification_node")

        requests = list(state.get("clarification_requests", []))
        blocked_tasks = dict(state.get("blocked_tasks", {}))
        errors = list(state.get("errors", []))

        if not requests:
            return {"clarification_requests": []}

        if self.pm_agent is None:
            # Clear requests without processing for testing
            return {"clarification_requests": [], "blocked_tasks": {}}

        # ``resolved`` is collected but not returned - presumably for future
        # use or debugging; only ``remaining`` feeds back into state.
        resolved = []
        remaining = []

        for req in requests:
            try:
                answer = await self.pm_agent.handle_clarification_request(req)
                tid = req.get("task_id", "")
                # An answer unblocks the associated task, if any.
                if tid and tid in blocked_tasks:
                    del blocked_tasks[tid]
                resolved.append({"request": req, "answer": answer})
            except Exception as e:
                logger.error("Clarification failed: %s", e)
                errors.append(f"Clarification error: {e}")
                remaining.append(req)

        return {
            "clarification_requests": remaining,
            "blocked_tasks": blocked_tasks,
            "errors": errors,
        }

    async def run(self, user_input: str) -> dict:
        """Build graph and execute with initial state.

        Args:
            user_input: The raw user prompt that seeds the PM agent.

        Returns:
            The final merged state dict after the graph reaches END; it is
            also persisted to disk via ``save_state``.
        """
        compiled = self.build_graph()

        initial_state = {
            "user_input": user_input,
            "prd": "",
            "tasks": [],
            "active_tasks": {},
            "completed_tasks": [],
            "blocked_tasks": {},
            "clarification_requests": [],
            "global_architecture": "",
            "iteration_count": 0,
            "max_iterations": 50,
            "errors": [],
        }

        if self.observability:
            self.observability.log_state_transition("init", "run")

        result = await compiled.ainvoke(initial_state)

        self.save_state(result)
        return result

    def save_state(self, state: dict, path: str = "app_factory/data/state.json"):
        """Persist state to disk.

        Values that are not JSON-serializable are stringified rather than
        dropped, so the snapshot is always written but may be lossy.
        The default path is relative to the current working directory.
        """
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # Convert to JSON-serializable form
        serializable = {}
        for k, v in state.items():
            try:
                json.dumps(v)
                serializable[k] = v
            except (TypeError, ValueError):
                serializable[k] = str(v)

        with open(path, "w") as f:
            json.dump(serializable, f, indent=2)

    def load_state(self, path: str = "app_factory/data/state.json") -> dict:
        """Load state from disk.

        Raises:
            FileNotFoundError: If no snapshot exists at ``path``.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        with open(path) as f:
            return json.load(f)
|
||||
Reference in New Issue
Block a user