"""Graph Orchestrator - LangGraph-based multi-agent workflow orchestration.""" import asyncio import json import logging import os from typing import TypedDict from langgraph.graph import END, START, StateGraph logger = logging.getLogger(__name__) class AppFactoryState(TypedDict): """Global state passed through the orchestration graph.""" user_input: str prd: str tasks: list # All tasks from task-master active_tasks: dict # task_id -> {status, container_id, worktree_path} completed_tasks: list # List of completed task_ids blocked_tasks: dict # task_id -> reason clarification_requests: list # Pending clarification dicts global_architecture: str # Architecture summary for dev agents iteration_count: int # Safety counter to prevent infinite loops max_iterations: int # Max loop iterations (default 50) errors: list # Error log class AppFactoryOrchestrator: """Main LangGraph state machine for the App Factory.""" def __init__( self, pm_agent=None, task_agent=None, dev_manager=None, qa_agent=None, workspace_manager=None, observability=None, ): self.pm_agent = pm_agent self.task_agent = task_agent self.dev_manager = dev_manager self.qa_agent = qa_agent self.workspace_manager = workspace_manager self.observability = observability def build_graph(self) -> StateGraph: """Build and compile the LangGraph StateGraph with nodes and edges.""" graph = StateGraph(AppFactoryState) graph.add_node("pm_node", self._pm_node) graph.add_node("task_node", self._task_node) graph.add_node("dev_dispatch_node", self._dev_dispatch_node) graph.add_node("qa_node", self._qa_node) graph.add_node("clarification_node", self._clarification_node) graph.add_edge(START, "pm_node") graph.add_conditional_edges( "pm_node", self._should_continue_after_pm, { "task_node": "task_node", "end": END, }, ) graph.add_conditional_edges( "task_node", self._should_continue_after_tasks, { "dev_dispatch": "dev_dispatch_node", "end": END, "clarification": "clarification_node", }, ) graph.add_edge("dev_dispatch_node", "qa_node") graph.add_conditional_edges( "qa_node", self._should_continue_after_qa, { "task_node": "task_node", "clarification": "clarification_node", "end": END, }, ) graph.add_edge("clarification_node", "task_node") return graph.compile() def _should_continue_after_pm(self, state: dict) -> str: """Routing function after pm_node: 'task_node' | 'end'.""" prd = state.get("prd", "") if prd and prd.strip(): return "task_node" # PM failure (or empty prompt) yields no PRD and should terminate cleanly. return "end" def _should_continue_after_tasks(self, state: dict) -> str: """Routing function after task_node: 'dev_dispatch' | 'end' | 'clarification'.""" if state.get("iteration_count", 0) >= state.get("max_iterations", 50): return "end" tasks = state.get("tasks", []) completed = set(state.get("completed_tasks", [])) all_task_ids = {str(t.get("id", "")) for t in tasks} # Check if all tasks are done if all_task_ids and all_task_ids <= completed: return "end" # Check for unblocked tasks (pending tasks with all deps done) unblocked = [] for t in tasks: if str(t.get("id", "")) in completed: continue if t.get("status") == "done": continue deps = [str(d) for d in t.get("dependencies", [])] if all(d in completed for d in deps): unblocked.append(t) if unblocked: return "dev_dispatch" # No unblocked tasks - if there are blocked ones, try clarification if state.get("blocked_tasks") or state.get("clarification_requests"): return "clarification" # No tasks at all or nothing left to do return "end" def _should_continue_after_qa(self, state: dict) -> str: """Routing function after qa_node: 'task_node' | 'clarification' | 'end'.""" if state.get("iteration_count", 0) >= state.get("max_iterations", 50): return "end" if state.get("clarification_requests"): return "clarification" # Loop back to check for newly unblocked tasks return "task_node" async def _pm_node(self, state: dict) -> dict: """Call PM agent to expand user input into a PRD.""" if self.observability: self.observability.log_state_transition("start", "pm_node") user_input = state.get("user_input", "") if not user_input: return {"prd": "", "errors": state.get("errors", []) + ["No user input provided"]} if self.pm_agent is None: return {"prd": f"Mock PRD for: {user_input}"} try: prd = await self.pm_agent.expand_prompt_to_prd(user_input) return {"prd": prd} except Exception as e: logger.error("PM agent failed: %s", e) return {"prd": "", "errors": state.get("errors", []) + [f"PM agent error: {e}"]} async def _task_node(self, state: dict) -> dict: """Parse PRD into tasks or get unblocked tasks. Increments iteration_count.""" if self.observability: self.observability.log_state_transition("pm_node/qa_node/clarification_node", "task_node") iteration_count = state.get("iteration_count", 0) + 1 updates = {"iteration_count": iteration_count} if iteration_count >= state.get("max_iterations", 50): updates["errors"] = state.get("errors", []) + ["Max iterations reached"] return updates if self.task_agent is None: return updates try: existing_tasks = state.get("tasks", []) if not existing_tasks: # First pass - parse the PRD prd = state.get("prd", "") if prd: await self.task_agent.parse_prd(prd) unblocked = await self.task_agent.get_unblocked_tasks() updates["tasks"] = unblocked else: # Subsequent passes - refresh unblocked tasks unblocked = await self.task_agent.get_unblocked_tasks() updates["tasks"] = unblocked except Exception as e: logger.error("Task agent failed: %s", e) updates["errors"] = state.get("errors", []) + [f"Task agent error: {e}"] return updates async def _dev_dispatch_node(self, state: dict) -> dict: """Dispatch dev agents concurrently for unblocked tasks.""" if self.observability: self.observability.log_state_transition("task_node", "dev_dispatch_node") tasks = state.get("tasks", []) completed = set(state.get("completed_tasks", [])) active_tasks = dict(state.get("active_tasks", {})) errors = list(state.get("errors", [])) clarification_requests = list(state.get("clarification_requests", [])) global_arch = state.get("global_architecture", "") # Filter to unblocked, not-yet-completed tasks to_execute = [] for t in tasks: tid = str(t.get("id", "")) if tid in completed or tid in active_tasks: continue deps = [str(d) for d in t.get("dependencies", [])] if all(d in completed for d in deps): to_execute.append(t) if not to_execute: return {} if self.dev_manager is None or self.workspace_manager is None: # Mock execution for testing new_completed = list(completed) for t in to_execute: tid = str(t.get("id", "")) active_tasks[tid] = {"status": "success", "container_id": "mock", "worktree_path": "/mock"} new_completed.append(tid) return {"active_tasks": active_tasks, "completed_tasks": new_completed} async def _execute_single(task): tid = str(task.get("id", "")) worktree_path = None container = None try: worktree_path = await self.workspace_manager.create_worktree(tid) container = await self.workspace_manager.spin_up_clean_room(worktree_path, tid) container_id = container.id if self.task_agent: await self.task_agent.update_task_status(tid, "in-progress") result = await self.dev_manager.execute_with_retry( task, container_id, worktree_path, global_arch ) return tid, result, worktree_path except Exception as e: logger.error("Dev dispatch failed for task %s: %s", tid, e) return tid, {"status": "failed", "output": str(e), "files_changed": [], "exit_code": -1}, worktree_path # Execute concurrently results = await asyncio.gather(*[_execute_single(t) for t in to_execute], return_exceptions=True) new_completed = list(completed) for item in results: if isinstance(item, Exception): errors.append(f"Dev dispatch exception: {item}") continue tid, result, worktree_path = item status = result.get("status", "failed") active_tasks[tid] = { "status": status, "container_id": result.get("container_id", ""), "worktree_path": worktree_path or "", } if status == "success": new_completed.append(tid) elif status == "needs_clarification": clarification_requests.append({ "requesting_agent": "dev_agent", "task_id": tid, "question": f"Task {tid} failed after retries. Output: {result.get('output', '')[:500]}", "context": result.get("output", "")[:1000], }) return { "active_tasks": active_tasks, "completed_tasks": new_completed, "errors": errors, "clarification_requests": clarification_requests, } async def _qa_node(self, state: dict) -> dict: """Run QA on completed dev tasks.""" if self.observability: self.observability.log_state_transition("dev_dispatch_node", "qa_node") active_tasks = dict(state.get("active_tasks", {})) completed = list(state.get("completed_tasks", [])) errors = list(state.get("errors", [])) clarification_requests = list(state.get("clarification_requests", [])) blocked_tasks = dict(state.get("blocked_tasks", {})) # Find tasks that were successfully completed by dev and need QA tasks_for_qa = [] for tid, info in active_tasks.items(): if info.get("status") == "success" and tid in completed: tasks_for_qa.append((tid, info)) if not tasks_for_qa or self.qa_agent is None: return {} for tid, info in tasks_for_qa: worktree_path = info.get("worktree_path", "") if not worktree_path: continue try: # Find the task dict for context task_dict = None for t in state.get("tasks", []): if str(t.get("id", "")) == tid: task_dict = t break qa_result = await self.qa_agent.review_and_merge(tid, worktree_path, task=task_dict) qa_status = qa_result.get("status", "") if qa_status == "merged": # Successfully merged - update task status if self.task_agent: await self.task_agent.update_task_status(tid, "done") active_tasks[tid]["status"] = "merged" else: # QA failed - may need clarification or retry retry_count = qa_result.get("retry_count", 0) if retry_count >= (self.qa_agent.max_retries if self.qa_agent else 3): clarification_requests.append({ "requesting_agent": "qa_agent", "task_id": tid, "question": f"QA failed for task {tid} with status '{qa_status}'", "context": str(qa_result), }) else: blocked_tasks[tid] = f"QA {qa_status}: {qa_result}" # Remove from completed so it can be retried if tid in completed: completed.remove(tid) active_tasks[tid]["status"] = qa_status # Cleanup workspace after QA if self.workspace_manager: try: await self.workspace_manager.cleanup_workspace(tid) except Exception as e: logger.warning("Workspace cleanup failed for task %s: %s", tid, e) except Exception as e: logger.error("QA failed for task %s: %s", tid, e) errors.append(f"QA error for task {tid}: {e}") return { "active_tasks": active_tasks, "completed_tasks": completed, "errors": errors, "clarification_requests": clarification_requests, "blocked_tasks": blocked_tasks, } async def _clarification_node(self, state: dict) -> dict: """Handle clarification requests via PM agent.""" if self.observability: self.observability.log_state_transition("task_node/qa_node", "clarification_node") requests = list(state.get("clarification_requests", [])) blocked_tasks = dict(state.get("blocked_tasks", {})) errors = list(state.get("errors", [])) if not requests: return {"clarification_requests": []} if self.pm_agent is None: # Clear requests without processing for testing return {"clarification_requests": [], "blocked_tasks": {}} resolved = [] remaining = [] for req in requests: try: answer = await self.pm_agent.handle_clarification_request(req) tid = req.get("task_id", "") if tid and tid in blocked_tasks: del blocked_tasks[tid] resolved.append({"request": req, "answer": answer}) except Exception as e: logger.error("Clarification failed: %s", e) errors.append(f"Clarification error: {e}") remaining.append(req) return { "clarification_requests": remaining, "blocked_tasks": blocked_tasks, "errors": errors, } async def run(self, user_input: str) -> dict: """Build graph and execute with initial state.""" compiled = self.build_graph() initial_state = { "user_input": user_input, "prd": "", "tasks": [], "active_tasks": {}, "completed_tasks": [], "blocked_tasks": {}, "clarification_requests": [], "global_architecture": "", "iteration_count": 0, "max_iterations": 50, "errors": [], } if self.observability: self.observability.log_state_transition("init", "run") result = await compiled.ainvoke(initial_state) self.save_state(result) return result def save_state(self, state: dict, path: str = "app_factory/data/state.json"): """Persist state to disk.""" os.makedirs(os.path.dirname(path), exist_ok=True) # Convert to JSON-serializable form serializable = {} for k, v in state.items(): try: json.dumps(v) serializable[k] = v except (TypeError, ValueError): serializable[k] = str(v) with open(path, "w") as f: json.dump(serializable, f, indent=2) def load_state(self, path: str = "app_factory/data/state.json") -> dict: """Load state from disk.""" with open(path) as f: return json.load(f)