"""Tree-of-thought: multi-branch reasoning with evaluation and selection.

Tree-of-Thought (ToT) extends Chain-of-Thought by exploring multiple reasoning
paths and selecting the best one. This is useful for complex problems where the
first reasoning path may not be optimal.

Key concepts:
- Branch: A single reasoning path
- ThoughtNode: Hierarchical node with depth, unit_refs, children (Super Big Brain)
- Evaluation: Scoring each branch for quality
- Selection: Choosing the best branch based on evaluation
- Pruning: Discarding low-quality branches early
"""

import json
import math
import re
import uuid
from dataclasses import dataclass, field
from typing import Any

from fusionagi.adapters.base import LLMAdapter
from fusionagi.reasoning.cot import run_chain_of_thought, build_cot_messages
from fusionagi._logger import logger


@dataclass
class ThoughtNode:
    """Hierarchical reasoning node: supports arbitrary depth, subtree independence, unit_refs."""

    # Unique identifier; defaults to "node_" plus 12 hex characters of a uuid4.
    node_id: str = field(default_factory=lambda: f"node_{uuid.uuid4().hex[:12]}")
    # node_id of the parent, or None when the node has no parent.
    parent_id: str | None = None
    # Text of the thought held by this node.
    thought: str = ""
    # Thoughts accumulated along the path down to (and including) this node.
    trace: list[str] = field(default_factory=list)
    # Quality score; compared against the threshold in prune_subtree.
    score: float = 0.0
    children: list["ThoughtNode"] = field(default_factory=list)
    # Distance from the root; expand_node sets it to parent depth + 1.
    depth: int = 0
    unit_refs: list[str] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)


def expand_node(
    node: ThoughtNode,
    new_thought: str,
    unit_refs: list[str] | None = None,
) -> ThoughtNode:
    """Create a child node under ``node`` and return it.

    Args:
        node: Parent node; the new child is appended to ``node.children``.
        new_thought: Thought text for the child; also appended to the child's
            copy of the parent's trace.
        unit_refs: References for the child. When omitted *or empty*, the child
            receives a copy of the parent's ``unit_refs``.

    Returns:
        The newly created child node.
    """
    child = ThoughtNode(
        parent_id=node.node_id,
        thought=new_thought,
        trace=node.trace + [new_thought],
        depth=node.depth + 1,
        unit_refs=unit_refs or list(node.unit_refs),
    )
    node.children.append(child)
    return child


def prune_subtree(node: ThoughtNode, prune_threshold: float) -> ThoughtNode:
    """Remove descendants scoring below ``prune_threshold``, in place.

    A child that is removed takes its whole subtree with it. The score of
    ``node`` itself is never checked.

    Returns:
        The same ``node`` object, for chaining.
    """
    node.children = [c for c in node.children if c.score >= prune_threshold]
    for child in node.children:
        prune_subtree(child, prune_threshold)
    return node


def merge_subtrees(nodes: list[ThoughtNode], threshold: float = 0.8) -> ThoughtNode | None:
    """Fold close-scoring sibling nodes into the highest-scoring one.

    No textual similarity is computed: a node counts as an alternative when its
    score is at least ``threshold * best.score``. For each such node, the first
    200 characters of its thought are appended to the best node's thought after
    an ``[Alternative]`` marker.

    Note:
        The best node is mutated in place, so calling this repeatedly on the
        same nodes appends the alternatives again.

    Args:
        nodes: Candidate sibling nodes.
        threshold: Fraction of the best score a node must reach to be merged.

    Returns:
        ``None`` for an empty list, the sole node for a single-element list,
        otherwise the (mutated) highest-scoring node.
    """
    if not nodes:
        return None
    if len(nodes) == 1:
        return nodes[0]
    best = max(nodes, key=lambda n: n.score)
    cutoff = best.score * threshold
    best.thought += "".join(
        "\n[Alternative] " + n.thought[:200]
        for n in nodes
        if n is not best and n.score >= cutoff
    )
    return best


@dataclass
class ThoughtBranch:
    """A single reasoning branch in the tree."""

    branch_id: int
    thought: str
    trace: list[str]
    score: float = 0.0
    is_terminal: bool = False
    children: list["ThoughtBranch"] = field(default_factory=list)
    # _evaluate_branch stores the evaluator's explanation under "evaluation_reason".
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class ToTResult:
    """Result of Tree-of-Thought reasoning."""

    best_response: str
    best_trace: list[str]
    best_score: float
    all_branches: list[ThoughtBranch]
    total_llm_calls: int
    selection_reason: str


# System prompts for ToT
TOT_GENERATION_SYSTEM = """You are exploring different approaches to solve a problem.
Generate a distinct reasoning approach. Be creative and consider different angles.
State your thought process clearly step by step."""

TOT_EVALUATION_SYSTEM = """You are evaluating the quality of reasoning approaches.
Score each approach from 0 to 1 based on:
- Logical soundness (is the reasoning valid?)
- Completeness (does it address all aspects?)
- Clarity (is it easy to follow?)
- Practicality (can it be implemented?)

Output ONLY a JSON object: {"score": 0.X, "reason": "brief explanation"}"""

# Score used when the evaluator's reply cannot be parsed at all.
_DEFAULT_SCORE = 0.5

# Free-text fallback: a standalone number in [0, 1] ("0.85", ".5", "1", "1.0", "0").
# The lookarounds stop it from matching a digit inside a larger number, so
# "10" or "2015" are not read as scores.
_SCORE_PATTERN = re.compile(r"(?<![\d.])(?:0?\.\d+|1(?:\.0+)?|0)(?!\d|\.\d)")


def _clamp_unit(value: float) -> float:
    """Clamp ``value`` to the closed interval [0, 1]."""
    return max(0.0, min(1.0, value))


def _parse_json_score(response: str) -> tuple[float, Any] | None:
    """Extract ``(score, reason)`` from a JSON object embedded in ``response``.

    Returns ``None`` when there is no parsable object or its score is unusable
    (wrong type, non-numeric string, or NaN), so the caller can fall back to
    another parsing strategy.
    """
    start = response.find("{")
    end = response.rfind("}")
    if start == -1 or end < start:
        return None
    try:
        payload = json.loads(response[start:end + 1])
        score = float(payload.get("score", _DEFAULT_SCORE))
    except (ValueError, TypeError, AttributeError):
        # ValueError covers json.JSONDecodeError and non-numeric strings;
        # TypeError covers null/list/object scores.
        return None
    if math.isnan(score):
        # NaN would slip through min()/max() clamping and come out as 1.0.
        return None
    return _clamp_unit(score), payload.get("reason", "")


def _generate_branch(
    adapter: LLMAdapter,
    query: str,
    context: str | None,
    branch_num: int,
    previous_branches: list[ThoughtBranch],
    **kwargs: Any,
) -> ThoughtBranch:
    """Generate a single reasoning branch.

    Args:
        adapter: LLM adapter used for the completion call.
        query: The problem being reasoned about.
        context: Optional context appended to the user message.
        branch_num: Identifier assigned to the new branch.
        previous_branches: Branches generated so far; the first 100 characters
            of each are shown to the model to encourage a different approach.
        **kwargs: Passed through to ``adapter.complete()``.

    Returns:
        An unscored branch whose thought and trace hold the model's response.
    """
    # Build prompt that encourages diverse thinking
    diversity_hint = ""
    if previous_branches:
        prev_summaries = [
            f"Approach {b.branch_id}: {b.thought[:100]}..."
            for b in previous_branches
        ]
        diversity_hint = "\n\nPrevious approaches tried:\n" + "\n".join(prev_summaries)
        diversity_hint += "\n\nGenerate a DIFFERENT approach."

    messages = [
        {"role": "system", "content": TOT_GENERATION_SYSTEM},
        {
            "role": "user",
            "content": f"Query: {query}{diversity_hint}"
            + (f"\n\nContext: {context}" if context else ""),
        },
    ]

    response = adapter.complete(messages, **kwargs)

    return ThoughtBranch(
        branch_id=branch_num,
        thought=response,
        trace=[response],
    )


def _evaluate_branch(
    adapter: LLMAdapter,
    branch: ThoughtBranch,
    query: str,
    **kwargs: Any,
) -> float:
    """Evaluate a reasoning branch and return a score in [0, 1].

    The evaluator's reply is parsed in two stages: first as a JSON object with
    ``score`` and ``reason`` keys, then, failing that, as the first standalone
    number between 0 and 1 found in the text. If neither works the default
    score of 0.5 is returned.

    Side effect:
        When the JSON stage succeeds, the reason is stored in
        ``branch.metadata["evaluation_reason"]``. ``branch.score`` is *not*
        set here; the caller assigns the returned value.
    """
    messages = [
        {"role": "system", "content": TOT_EVALUATION_SYSTEM},
        {
            "role": "user",
            "content": f"Query: {query}\n\nReasoning approach:\n{branch.thought}\n\nScore this approach.",
        },
    ]

    response = adapter.complete(messages, **kwargs)

    parsed = _parse_json_score(response)
    if parsed is not None:
        score, reason = parsed
        branch.metadata["evaluation_reason"] = reason
        return score

    # Fallback: try to extract a number
    match = _SCORE_PATTERN.search(response)
    if match:
        return _clamp_unit(float(match.group()))

    return _DEFAULT_SCORE  # Default score if parsing fails


def _select_best_branch(branches: list[ThoughtBranch]) -> tuple[ThoughtBranch, str]:
    """Select the best branch based on scores.

    Ties are resolved in favour of the branch that appears first in
    ``branches`` (the sort is stable).

    Returns:
        ``(best_branch, reason)`` where ``reason`` is a human-readable note on
        how decisive the choice was.

    Raises:
        ValueError: If ``branches`` is empty.
    """
    if not branches:
        raise ValueError("No branches to select from")

    if len(branches) == 1:
        return branches[0], "Only one branch available"

    # Sort by score descending; at least two entries exist from here on.
    best, runner_up = sorted(branches, key=lambda b: b.score, reverse=True)[:2]

    # Check if there's a clear winner
    if best.score - runner_up.score > 0.2:
        reason = f"Clear winner with score {best.score:.2f} (next best: {runner_up.score:.2f})"
    else:
        reason = f"Selected highest score {best.score:.2f} among close alternatives"

    return best, reason


def _refine_branches(
    adapter: LLMAdapter,
    query: str,
    branches: list[ThoughtBranch],
    depth: int,
    complete_kwargs: dict[str, Any],
) -> tuple[list[ThoughtBranch], int]:
    """Run ``depth - 1`` refinement rounds over ``branches``.

    In each round every branch is rewritten by the model using its score and
    evaluation feedback, the rewrite is scored, and whichever of the two
    versions scores strictly higher is kept (the existing branch wins ties).

    Args:
        adapter: LLM adapter for generation and evaluation.
        query: The problem being reasoned about.
        branches: Scored branches to refine; the list itself is not mutated.
        depth: Total depth; values of 1 or less perform no refinement.
        complete_kwargs: Keyword arguments forwarded to ``adapter.complete()``.

    Returns:
        ``(branches_after_refinement, llm_calls_made)``.
    """
    llm_calls = 0
    for round_num in range(1, depth):
        survivors: list[ThoughtBranch] = []
        for branch in branches:
            refinement_prompt = (
                f"Original approach:\n{branch.thought}\n\n"
                f"Score: {branch.score:.2f}\n"
                f"Feedback: {branch.metadata.get('evaluation_reason', 'N/A')}\n\n"
                "Improve this approach based on the feedback. "
                "Make it more complete and rigorous."
            )
            messages = [
                {"role": "system", "content": TOT_GENERATION_SYSTEM},
                {"role": "user", "content": f"Query: {query}\n\n{refinement_prompt}"},
            ]
            refined_thought = adapter.complete(messages, **complete_kwargs)

            candidate = ThoughtBranch(
                branch_id=branch.branch_id,
                thought=refined_thought,
                trace=branch.trace + [f"[Refinement {round_num}] {refined_thought}"],
            )
            candidate.score = _evaluate_branch(adapter, candidate, query, **complete_kwargs)
            llm_calls += 2  # one generation + one evaluation

            # Keep the better version
            survivors.append(candidate if candidate.score > branch.score else branch)
        branches = survivors
    return branches, llm_calls


def run_tree_of_thought(
    adapter: LLMAdapter,
    query: str,
    context: str | None = None,
    max_branches: int = 3,
    depth: int = 1,
    prune_threshold: float = 0.3,
    **kwargs: Any,
) -> tuple[str, list[str]]:
    """
    Run Tree-of-Thought reasoning with multiple branches.

    All branches are generated first, then all are evaluated, then pruned,
    optionally refined, and the highest-scoring one is returned. If
    ``max_branches`` is 1 (or less), or every branch is pruned, the call is
    delegated to ``run_chain_of_thought`` instead.

    Args:
        adapter: LLM adapter for generation and evaluation.
        query: The question or problem to reason about.
        context: Optional context to include.
        max_branches: Maximum number of reasoning branches to explore.
        depth: Number of refinement iterations (1 = single pass, 2+ = iterative refinement).
        prune_threshold: Minimum score to keep a branch (branches below are pruned).
        **kwargs: Additional arguments passed to adapter.complete().

    Returns:
        Tuple of (best_response, trace_list).
    """
    if max_branches < 1:
        max_branches = 1

    if max_branches == 1:
        # Fall back to simple CoT for single branch
        return run_chain_of_thought(adapter, query, context=context, **kwargs)

    logger.info(
        "Starting Tree-of-Thought",
        extra={"query_length": len(query), "max_branches": max_branches, "depth": depth},
    )

    total_llm_calls = 0
    branches: list[ThoughtBranch] = []

    # Generate initial branches; each one sees the earlier ones for diversity.
    for i in range(max_branches):
        branches.append(_generate_branch(adapter, query, context, i, branches, **kwargs))
        total_llm_calls += 1

    # Evaluate all branches
    for branch in branches:
        branch.score = _evaluate_branch(adapter, branch, query, **kwargs)
        total_llm_calls += 1

    # Prune low-quality branches
    branches = [b for b in branches if b.score >= prune_threshold]

    if not branches:
        # All branches pruned - fall back to CoT
        logger.warning("All ToT branches pruned, falling back to CoT")
        return run_chain_of_thought(adapter, query, context=context, **kwargs)

    # Iterative refinement for depth > 1
    branches, refine_calls = _refine_branches(adapter, query, branches, depth, kwargs)
    total_llm_calls += refine_calls

    # Select the best branch
    best_branch, selection_reason = _select_best_branch(branches)

    logger.info(
        "Tree-of-Thought completed",
        extra={
            "best_score": best_branch.score,
            "total_branches": len(branches),
            "total_llm_calls": total_llm_calls,
        },
    )

    # Build comprehensive trace
    trace = [
        f"[ToT Branch {best_branch.branch_id}] Score: {best_branch.score:.2f}",
        best_branch.thought,
    ]
    if best_branch.metadata.get("evaluation_reason"):
        trace.append(f"[Evaluation] {best_branch.metadata['evaluation_reason']}")
    trace.append(f"[Selection] {selection_reason}")

    return best_branch.thought, trace


def run_tree_of_thought_detailed(
    adapter: LLMAdapter,
    query: str,
    context: str | None = None,
    max_branches: int = 3,
    depth: int = 1,
    prune_threshold: float = 0.3,
    **kwargs: Any,
) -> ToTResult:
    """
    Run Tree-of-Thought and return detailed results including all branches.

    Takes the same arguments as ``run_tree_of_thought`` but returns a
    ``ToTResult`` with full information. Differences from that function:

    - Each branch is evaluated immediately after it is generated.
    - If every branch falls below ``prune_threshold``, the highest-scoring
      one is used anyway instead of falling back to Chain-of-Thought.
    - ``all_branches`` lists every generated branch, pruned ones included;
      branches improved by refinement appear in their refined form.

    Returns:
        A ``ToTResult`` describing the winning branch, every branch explored,
        the number of LLM calls made, and why the winner was selected.
    """
    if max_branches < 1:
        max_branches = 1

    if max_branches == 1:
        response, trace = run_chain_of_thought(adapter, query, context=context, **kwargs)
        single_branch = ThoughtBranch(branch_id=0, thought=response, trace=trace, score=0.5)
        return ToTResult(
            best_response=response,
            best_trace=trace,
            best_score=0.5,
            all_branches=[single_branch],
            # NOTE(review): assumes run_chain_of_thought makes exactly one LLM call - confirm.
            total_llm_calls=1,
            selection_reason="Single branch (CoT mode)",
        )

    total_llm_calls = 0
    branches: list[ThoughtBranch] = []

    # Generate and evaluate branches
    for i in range(max_branches):
        branch = _generate_branch(adapter, query, context, i, branches, **kwargs)
        branch.score = _evaluate_branch(adapter, branch, query, **kwargs)
        total_llm_calls += 2
        branches.append(branch)

    all_branches = list(branches)  # Keep all for result

    # Prune
    branches = [b for b in branches if b.score >= prune_threshold]

    if not branches:
        # Use best of all branches even if below threshold
        branches = sorted(all_branches, key=lambda b: b.score, reverse=True)[:1]

    # Iterative refinement for depth > 1 (no-op at the default depth of 1)
    if depth > 1:
        branches, refine_calls = _refine_branches(adapter, query, branches, depth, kwargs)
        total_llm_calls += refine_calls
        # Show refined survivors in place of their originals so the winner is
        # always present in all_branches.
        current = {b.branch_id: b for b in branches}
        all_branches = [current.get(b.branch_id, b) for b in all_branches]

    # Select best
    best_branch, selection_reason = _select_best_branch(branches)

    return ToTResult(
        best_response=best_branch.thought,
        best_trace=best_branch.trace,
        best_score=best_branch.score,
        all_branches=all_branches,
        total_llm_calls=total_llm_calls,
        selection_reason=selection_reason,
    )