"""Tree-of-thought: multi-branch reasoning with evaluation and selection.

Tree-of-Thought (ToT) extends Chain-of-Thought by exploring multiple reasoning paths
and selecting the best one. This is useful for complex problems where the first
reasoning path may not be optimal.

Key concepts:
- Branch: A single reasoning path
- ThoughtNode: Hierarchical node with depth, unit_refs, children (Super Big Brain)
- Evaluation: Scoring each branch for quality
- Selection: Choosing the best branch based on evaluation
- Pruning: Discarding low-quality branches early
"""
|
|
|
|
import json
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from fusionagi.adapters.base import LLMAdapter
|
|
from fusionagi.reasoning.cot import run_chain_of_thought, build_cot_messages
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
@dataclass
class ThoughtNode:
    """Hierarchical reasoning node: supports arbitrary depth, subtree independence, unit_refs."""

    # Unique identifier, e.g. "node_3fa92b81c0de" (random 12-hex suffix).
    node_id: str = field(default_factory=lambda: f"node_{uuid.uuid4().hex[:12]}")
    # node_id of the parent node; None for a root node.
    parent_id: str | None = None
    # The reasoning text produced at this node.
    thought: str = ""
    # Accumulated thoughts from the root down to (and including) this node.
    trace: list[str] = field(default_factory=list)
    # Evaluation score; 0.0 until scored (prune_subtree compares against this).
    score: float = 0.0
    # Child nodes expanded from this one (see expand_node).
    children: list["ThoughtNode"] = field(default_factory=list)
    # Distance from the root (root is 0; expand_node sets parent depth + 1).
    depth: int = 0
    # Identifiers of referenced units — presumably memory/knowledge units
    # elsewhere in the system; TODO confirm against callers.
    unit_refs: list[str] = field(default_factory=list)
    # Free-form extra data attached to the node.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
def expand_node(
    node: ThoughtNode,
    new_thought: str,
    unit_refs: list[str] | None = None,
) -> ThoughtNode:
    """Attach and return a new child of *node* carrying *new_thought*.

    The child extends the parent's trace with the new thought, sits one
    level deeper, and — unless explicit *unit_refs* are supplied — inherits
    a copy of the parent's unit_refs.
    """
    inherited_refs = unit_refs if unit_refs else list(node.unit_refs)
    extended_trace = [*node.trace, new_thought]
    child = ThoughtNode(
        parent_id=node.node_id,
        thought=new_thought,
        trace=extended_trace,
        depth=node.depth + 1,
        unit_refs=inherited_refs,
    )
    node.children.append(child)
    return child
|
|
|
|
|
|
def prune_subtree(node: ThoughtNode, prune_threshold: float) -> ThoughtNode:
    """Recursively drop descendants scoring below *prune_threshold*.

    Mutates *node* in place (children lists are rebuilt) and returns it
    for call-chaining convenience.
    """
    survivors = []
    for child in node.children:
        if child.score >= prune_threshold:
            survivors.append(child)
    node.children = survivors
    # Recurse only into the children that survived the cut.
    for child in survivors:
        prune_subtree(child, prune_threshold)
    return node
|
|
|
|
|
|
def merge_subtrees(nodes: list[ThoughtNode], threshold: float = 0.8) -> ThoughtNode | None:
    """Fold sibling nodes into the highest-scoring one.

    NOTE(review): despite the docstring's original claim, no text-similarity
    comparison happens here — a sibling is "merged" (its thought appended as
    an [Alternative] snippet) whenever its score is at least ``threshold``
    times the best sibling's score. Mutates the winner's ``thought`` in place.

    Returns None for an empty list, the sole node for a singleton list,
    otherwise the (possibly augmented) best-scoring node.
    """
    if not nodes:
        return None
    if len(nodes) == 1:
        return nodes[0]
    # Winner is simply the max-score sibling.
    best = max(nodes, key=lambda n: n.score)
    for n in nodes:
        if n is best:
            continue
        # Close runners-up contribute a truncated summary of their thought.
        if n.score >= best.score * threshold:
            best.thought += "\n[Alternative] " + n.thought[:200]
    return best
|
|
|
|
|
|
@dataclass
class ThoughtBranch:
    """A single reasoning branch in the tree."""

    # Zero-based index of the branch within one ToT run.
    branch_id: int
    # The full reasoning text generated for this branch.
    thought: str
    # History of thoughts (initial generation plus any refinements).
    trace: list[str]
    # Evaluation score in [0, 1]; 0.0 until scored by _evaluate_branch.
    score: float = 0.0
    # NOTE(review): never set anywhere in this module — presumably marks a
    # branch that should not be expanded further; confirm before relying on it.
    is_terminal: bool = False
    # Sub-branches (unused by the flat runners in this module).
    children: list["ThoughtBranch"] = field(default_factory=list)
    # Extra data, e.g. "evaluation_reason" stored during scoring.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class ToTResult:
    """Result of Tree-of-Thought reasoning."""

    # Thought text of the winning branch.
    best_response: str
    # Trace (generation history) of the winning branch.
    best_trace: list[str]
    # Evaluation score of the winning branch.
    best_score: float
    # Every branch generated, including ones later pruned.
    all_branches: list[ThoughtBranch]
    # Total adapter.complete() invocations made during the run.
    total_llm_calls: int
    # Human-readable explanation of why the winner was chosen.
    selection_reason: str
|
|
|
|
|
|
# System prompts for ToT

# Prompt used when generating each candidate reasoning branch.
TOT_GENERATION_SYSTEM = """You are exploring different approaches to solve a problem.
Generate a distinct reasoning approach. Be creative and consider different angles.
State your thought process clearly step by step."""

# Prompt used when scoring a branch; the model is asked for strict JSON so
# _evaluate_branch can parse a {"score": ..., "reason": ...} object.
TOT_EVALUATION_SYSTEM = """You are evaluating the quality of reasoning approaches.
Score each approach from 0 to 1 based on:
- Logical soundness (is the reasoning valid?)
- Completeness (does it address all aspects?)
- Clarity (is it easy to follow?)
- Practicality (can it be implemented?)
Output ONLY a JSON object: {"score": 0.X, "reason": "brief explanation"}"""
|
|
|
|
|
|
def _generate_branch(
    adapter: LLMAdapter,
    query: str,
    context: str | None,
    branch_num: int,
    previous_branches: list[ThoughtBranch],
    **kwargs: Any,
) -> ThoughtBranch:
    """Generate a single reasoning branch.

    Args:
        adapter: LLM adapter used for the completion call.
        query: The question or problem being reasoned about.
        context: Optional extra context appended to the user message.
        branch_num: Index assigned to the new branch.
        previous_branches: Branches already generated; summarized in the
            prompt so the model produces a distinct approach.
        **kwargs: Passed through to ``adapter.complete()``.

    Returns:
        A new, unscored ThoughtBranch containing the model's response.
    """
    # Summarize earlier approaches so the model is nudged toward diversity.
    diversity_hint = ""
    if previous_branches:
        prev_summaries = [
            f"Approach {b.branch_id}: {b.thought[:100]}..."
            for b in previous_branches
        ]
        # Plain string literal: the previous f-prefix had no placeholders (F541).
        diversity_hint = "\n\nPrevious approaches tried:\n" + "\n".join(prev_summaries)
        diversity_hint += "\n\nGenerate a DIFFERENT approach."

    messages = [
        {"role": "system", "content": TOT_GENERATION_SYSTEM},
        {
            "role": "user",
            "content": f"Query: {query}{diversity_hint}" + (f"\n\nContext: {context}" if context else ""),
        },
    ]

    response = adapter.complete(messages, **kwargs)

    return ThoughtBranch(
        branch_id=branch_num,
        thought=response,
        trace=[response],
    )
|
|
|
|
|
|
def _evaluate_branch(
    adapter: LLMAdapter,
    branch: ThoughtBranch,
    query: str,
    **kwargs: Any,
) -> float:
    """Evaluate a reasoning branch and return a score in [0, 1].

    Asks the LLM to score the branch (see TOT_EVALUATION_SYSTEM) and then
    parses the response in three stages: strict JSON, a bare number found
    by regex, and finally a neutral 0.5 default.

    Side effect: stores the model's justification in
    ``branch.metadata["evaluation_reason"]`` when JSON parsing succeeds.

    Args:
        adapter: LLM adapter used for the evaluation call.
        branch: Branch to score (mutated: metadata may gain a reason).
        query: The original query, included in the evaluation prompt.
        **kwargs: Passed through to ``adapter.complete()``.

    Returns:
        Score clamped to [0.0, 1.0].
    """
    messages = [
        {"role": "system", "content": TOT_EVALUATION_SYSTEM},
        {
            "role": "user",
            "content": f"Query: {query}\n\nReasoning approach:\n{branch.thought}\n\nScore this approach.",
        },
    ]

    response = adapter.complete(messages, **kwargs)

    # Stage 1: extract and parse the JSON object the prompt asked for.
    try:
        if "{" in response:
            json_start = response.index("{")
            json_end = response.rindex("}") + 1
            json_str = response[json_start:json_end]
            result = json.loads(json_str)
            score = float(result.get("score", 0.5))
            branch.metadata["evaluation_reason"] = result.get("reason", "")
            return max(0.0, min(1.0, score))  # Clamp to [0, 1]
    # TypeError covers a JSON "score" of null / list / object, which
    # previously escaped this handler and crashed the whole ToT run.
    except (json.JSONDecodeError, ValueError, KeyError, TypeError):
        pass

    # Stage 2: fall back to the first number-like token in the response.
    import re

    numbers = re.findall(r"0?\.\d+|1\.0|[01]", response)
    if numbers:
        try:
            return max(0.0, min(1.0, float(numbers[0])))
        except ValueError:
            pass

    # Stage 3: neutral default when nothing parseable was found.
    return 0.5
|
|
|
|
|
|
def _select_best_branch(branches: list[ThoughtBranch]) -> tuple[ThoughtBranch, str]:
|
|
"""Select the best branch based on scores."""
|
|
if not branches:
|
|
raise ValueError("No branches to select from")
|
|
|
|
if len(branches) == 1:
|
|
return branches[0], "Only one branch available"
|
|
|
|
# Sort by score descending
|
|
sorted_branches = sorted(branches, key=lambda b: b.score, reverse=True)
|
|
best = sorted_branches[0]
|
|
|
|
# Check if there's a clear winner
|
|
if len(sorted_branches) > 1:
|
|
score_diff = best.score - sorted_branches[1].score
|
|
if score_diff > 0.2:
|
|
reason = f"Clear winner with score {best.score:.2f} (next best: {sorted_branches[1].score:.2f})"
|
|
else:
|
|
reason = f"Selected highest score {best.score:.2f} among close alternatives"
|
|
else:
|
|
reason = f"Single branch with score {best.score:.2f}"
|
|
|
|
return best, reason
|
|
|
|
|
|
def run_tree_of_thought(
    adapter: LLMAdapter,
    query: str,
    context: str | None = None,
    max_branches: int = 3,
    depth: int = 1,
    prune_threshold: float = 0.3,
    **kwargs: Any,
) -> tuple[str, list[str]]:
    """
    Run Tree-of-Thought reasoning with multiple branches.

    Args:
        adapter: LLM adapter for generation and evaluation.
        query: The question or problem to reason about.
        context: Optional context to include.
        max_branches: Maximum number of reasoning branches to explore.
        depth: Number of refinement iterations (1 = single pass, 2+ = iterative refinement).
        prune_threshold: Minimum score to keep a branch (branches below are pruned).
        **kwargs: Additional arguments passed to adapter.complete().

    Returns:
        Tuple of (best_response, trace_list).
    """
    max_branches = max(max_branches, 1)

    # A single branch degenerates to plain Chain-of-Thought.
    if max_branches == 1:
        return run_chain_of_thought(adapter, query, context=context, **kwargs)

    logger.info(
        "Starting Tree-of-Thought",
        extra={"query_length": len(query), "max_branches": max_branches, "depth": depth},
    )

    llm_calls = 0
    candidates: list[ThoughtBranch] = []

    # Phase 1: generate the initial branches. Each generation sees the
    # earlier branches so the model is pushed toward distinct approaches.
    for idx in range(max_branches):
        candidates.append(_generate_branch(adapter, query, context, idx, candidates, **kwargs))
        llm_calls += 1

    # Phase 2: score every branch.
    for candidate in candidates:
        candidate.score = _evaluate_branch(adapter, candidate, query, **kwargs)
        llm_calls += 1

    # Phase 3: drop branches scoring below the pruning threshold.
    candidates = [c for c in candidates if c.score >= prune_threshold]

    if not candidates:
        # Every branch was pruned - degrade gracefully to plain CoT.
        logger.warning("All ToT branches pruned, falling back to CoT")
        return run_chain_of_thought(adapter, query, context=context, **kwargs)

    # Phase 4 (depth > 1 only): iterative refinement. Each round asks the
    # model to improve a branch using its evaluation feedback and keeps
    # whichever version scored higher.
    for round_no in range(1, depth):
        next_round: list[ThoughtBranch] = []
        for candidate in candidates:
            refinement_prompt = f"""Original approach:
{candidate.thought}

Score: {candidate.score:.2f}
Feedback: {candidate.metadata.get('evaluation_reason', 'N/A')}

Improve this approach based on the feedback. Make it more complete and rigorous."""

            messages = [
                {"role": "system", "content": TOT_GENERATION_SYSTEM},
                {"role": "user", "content": f"Query: {query}\n\n{refinement_prompt}"},
            ]

            improved_thought = adapter.complete(messages, **kwargs)
            llm_calls += 1

            improved = ThoughtBranch(
                branch_id=candidate.branch_id,
                thought=improved_thought,
                trace=candidate.trace + [f"[Refinement {round_no}] {improved_thought}"],
            )
            improved.score = _evaluate_branch(adapter, improved, query, **kwargs)
            llm_calls += 1

            # Keep whichever version evaluated better.
            next_round.append(improved if improved.score > candidate.score else candidate)

        candidates = next_round

    # Phase 5: pick the winner.
    best_branch, selection_reason = _select_best_branch(candidates)

    logger.info(
        "Tree-of-Thought completed",
        extra={
            "best_score": best_branch.score,
            "total_branches": len(candidates),
            "total_llm_calls": llm_calls,
        },
    )

    # Assemble a human-readable trace of the winning path.
    trace = [
        f"[ToT Branch {best_branch.branch_id}] Score: {best_branch.score:.2f}",
        best_branch.thought,
    ]
    evaluation_reason = best_branch.metadata.get("evaluation_reason")
    if evaluation_reason:
        trace.append(f"[Evaluation] {evaluation_reason}")
    trace.append(f"[Selection] {selection_reason}")

    return best_branch.thought, trace
|
|
|
|
|
|
def run_tree_of_thought_detailed(
    adapter: LLMAdapter,
    query: str,
    context: str | None = None,
    max_branches: int = 3,
    depth: int = 1,
    prune_threshold: float = 0.3,
    **kwargs: Any,
) -> ToTResult:
    """
    Run Tree-of-Thought and return detailed results including all branches.

    Same as run_tree_of_thought but returns a ToTResult with full information.

    NOTE(review): ``depth`` is accepted but not used here (no refinement
    pass), unlike run_tree_of_thought — confirm whether that is intended.
    """
    max_branches = max(max_branches, 1)

    if max_branches == 1:
        # Degenerate case: plain CoT wrapped in a single-branch result.
        response, trace = run_chain_of_thought(adapter, query, context=context, **kwargs)
        solo = ThoughtBranch(branch_id=0, thought=response, trace=trace, score=0.5)
        return ToTResult(
            best_response=response,
            best_trace=trace,
            best_score=0.5,
            all_branches=[solo],
            total_llm_calls=1,
            selection_reason="Single branch (CoT mode)",
        )

    llm_calls = 0
    candidates: list[ThoughtBranch] = []

    # Generate and immediately score each branch in turn.
    for idx in range(max_branches):
        candidate = _generate_branch(adapter, query, context, idx, candidates, **kwargs)
        llm_calls += 1
        candidate.score = _evaluate_branch(adapter, candidate, query, **kwargs)
        llm_calls += 1
        candidates.append(candidate)

    # Preserve everything (including soon-to-be-pruned branches) for the result.
    all_branches = list(candidates)

    # Prune low scorers, but never end up empty: if everything falls below
    # the threshold, keep the single best branch anyway.
    candidates = [c for c in candidates if c.score >= prune_threshold]
    if not candidates:
        candidates = [max(all_branches, key=lambda b: b.score)]

    best_branch, selection_reason = _select_best_branch(candidates)

    return ToTResult(
        best_response=best_branch.thought,
        best_trace=best_branch.trace,
        best_score=best_branch.score,
        all_branches=all_branches,
        total_llm_calls=llm_calls,
        selection_reason=selection_reason,
    )
|