393 lines
13 KiB
Python
393 lines
13 KiB
Python
"""Conversation management and natural language tuning."""
|
|
|
|
import uuid
|
|
from typing import Any, Literal
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from fusionagi._time import utc_now_iso
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
class ConversationStyle(BaseModel):
    """Configuration for conversation style and personality.

    Two categorical settings (``formality`` and ``verbosity``) take fixed
    string literals, ``personality_traits`` is a free-form list of labels, and
    the remaining numeric dials are floats validated to lie in ``[0.0, 1.0]``.
    """

    # --- categorical settings ------------------------------------------------
    formality: Literal["casual", "neutral", "formal"] = Field(
        "neutral", description="Conversation formality level"
    )
    verbosity: Literal["concise", "balanced", "detailed"] = Field(
        "balanced", description="Response length preference"
    )
    personality_traits: list[str] = Field(
        default_factory=list,
        description="Personality traits (e.g., friendly, professional, humorous)",
    )

    # --- numeric dials, each constrained to the closed range [0.0, 1.0] -----
    empathy_level: float = Field(
        0.7,
        ge=0.0,
        le=1.0,
        description="Emotional responsiveness (0=robotic, 1=highly empathetic)",
    )
    proactivity: float = Field(
        0.5,
        ge=0.0,
        le=1.0,
        description="Tendency to offer suggestions (0=reactive, 1=proactive)",
    )
    humor_level: float = Field(
        0.3,
        ge=0.0,
        le=1.0,
        description="Use of humor (0=serious, 1=playful)",
    )
    technical_depth: float = Field(
        0.5,
        ge=0.0,
        le=1.0,
        description="Technical detail level (0=simple, 1=expert)",
    )
|
|
|
|
|
|
class ConversationContext(BaseModel):
    """Context for a conversation session.

    Holds the identity, style, language and domain of one session, together
    with the size of the rolling turn window kept for it.
    """

    # Full 128-bit UUID hex keeps generated session IDs collision-resistant.
    session_id: str = Field(default_factory=lambda: f"session_{uuid.uuid4().hex}")
    user_id: str | None = Field(default=None)
    style: ConversationStyle = Field(default_factory=ConversationStyle)
    language: str = Field(default="en", description="Primary language code")
    domain: str | None = Field(default=None, description="Domain/topic of conversation")
    # ge=0: a negative window is meaningless and would yield surprising slices
    # when history is trimmed to this length; 0 means "keep no turns".
    history_length: int = Field(
        default=10, ge=0, description="Number of turns to maintain in context"
    )
    # Set from the project helper ``utc_now_iso`` (presumably an ISO-8601 UTC
    # string, per its name -- confirm against fusionagi._time).
    started_at: str = Field(default_factory=utc_now_iso)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class ConversationTurn(BaseModel):
    """A single turn in a conversation.

    ``session_id`` links the turn to the session whose history it belongs to;
    ``sentiment`` and ``confidence`` are optional, range-validated scores.
    """

    # Use the full UUID hex, matching ConversationContext.session_id. A short
    # 8-character prefix carries only 32 bits and collides after tens of
    # thousands of turns (birthday bound).
    turn_id: str = Field(default_factory=lambda: f"turn_{uuid.uuid4().hex}")
    session_id: str
    speaker: Literal["user", "agent", "system"]
    content: str
    intent: str | None = Field(default=None, description="Detected intent")
    sentiment: float | None = Field(
        default=None,
        ge=-1.0,
        le=1.0,
        description="Sentiment score (-1=negative, 0=neutral, 1=positive)",
    )
    confidence: float | None = Field(default=None, ge=0.0, le=1.0)
    # Set from the project helper ``utc_now_iso`` (presumably an ISO-8601 UTC
    # string, per its name -- confirm against fusionagi._time).
    timestamp: str = Field(default_factory=utc_now_iso)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class ConversationTuner:
    """
    Conversation tuner for natural language interaction.

    Allows admin to configure conversation style, personality, and behavior
    for different contexts, users, or agents. Keeps a registry of named
    styles plus one default style, and derives context-specific styles from
    them via ``tune_for_context``.
    """

    # Domain-specific presets used by ``_apply_domain_tuning``. Looked up by
    # the lower-cased, whitespace-stripped domain name. Defined once at class
    # level rather than rebuilt on every call; treat as read-only.
    _DOMAIN_PRESETS: dict[str, dict[str, Any]] = {
        "technical": {
            "formality": "formal",
            "technical_depth": 0.9,
            "verbosity": "detailed",
            "humor_level": 0.1,
        },
        "customer_support": {
            "formality": "neutral",
            "empathy_level": 0.9,
            "proactivity": 0.8,
            "verbosity": "balanced",
        },
        "casual_chat": {
            "formality": "casual",
            "humor_level": 0.7,
            "empathy_level": 0.8,
            "technical_depth": 0.3,
        },
        "education": {
            "formality": "neutral",
            "verbosity": "detailed",
            "technical_depth": 0.6,
            "proactivity": 0.7,
        },
    }

    def __init__(self) -> None:
        self._styles: dict[str, ConversationStyle] = {}
        self._default_style = ConversationStyle()
        logger.info("ConversationTuner initialized")

    def register_style(self, name: str, style: ConversationStyle) -> None:
        """
        Register a named conversation style.

        Registering an existing name replaces the previous style.

        Args:
            name: Style name (e.g., "customer_support", "technical_expert").
            style: Conversation style configuration (stored by reference).
        """
        self._styles[name] = style
        # "name" is a reserved logging.LogRecord attribute: if ``logger`` is a
        # stdlib Logger, extra={"name": ...} raises KeyError. Use a distinct key.
        logger.info("Conversation style registered", extra={"style_name": name})

    def get_style(self, name: str) -> ConversationStyle | None:
        """Get a conversation style by name, or None if it is not registered."""
        return self._styles.get(name)

    def list_styles(self) -> list[str]:
        """List all registered style names in registration order."""
        return list(self._styles.keys())

    def set_default_style(self, style: ConversationStyle) -> None:
        """Set the default conversation style (stored by reference)."""
        self._default_style = style
        logger.info("Default conversation style updated")

    def get_default_style(self) -> ConversationStyle:
        """Get the default conversation style."""
        return self._default_style

    def tune_for_context(
        self,
        base_style: ConversationStyle | None = None,
        domain: str | None = None,
        user_preferences: dict[str, Any] | None = None,
    ) -> ConversationStyle:
        """
        Tune conversation style for a specific context.

        The result is always a new object; neither ``base_style`` nor the
        tuner's default style is modified.

        Args:
            base_style: Base style to start from (uses default if None).
            domain: Domain/topic to optimize for; unknown domains are ignored.
            user_preferences: User-specific field overrides, applied after the
                domain preset. Keys that are not style fields are ignored, and
                values that fail the field's validation are skipped with a
                warning.

        Returns:
            Tuned conversation style.
        """
        # Copy whichever style we start from so callers' objects (and
        # registered styles passed in as base_style) are never mutated.
        source = base_style if base_style is not None else self._default_style
        style = source.model_copy(deep=True)

        # Apply domain-specific tuning
        if domain:
            style = self._apply_domain_tuning(style, domain)

        # Apply user preferences
        if user_preferences:
            style = self._apply_user_preferences(style, user_preferences)

        logger.info(
            "Conversation style tuned",
            extra={"domain": domain, "has_user_prefs": bool(user_preferences)},
        )
        return style

    def _apply_user_preferences(
        self, style: ConversationStyle, preferences: dict[str, Any]
    ) -> ConversationStyle:
        """Return ``style`` with each valid field override from ``preferences`` applied."""
        model_cls = type(style)
        fields = model_cls.model_fields
        for key, value in preferences.items():
            # Only real model fields are tunable; ``hasattr`` would also match
            # methods such as ``model_dump``, and assigning those raises.
            if key not in fields:
                continue
            try:
                # Re-validate so range (ge/le) and Literal constraints hold;
                # plain setattr on a pydantic model skips validation by default.
                style = model_cls.model_validate({**style.model_dump(), key: value})
            except ValueError:  # pydantic's ValidationError subclasses ValueError
                logger.warning(
                    "Invalid conversation style preference ignored",
                    extra={"preference": key},
                )
        return style

    def _apply_domain_tuning(self, style: ConversationStyle, domain: str) -> ConversationStyle:
        """
        Apply domain-specific tuning to a conversation style.

        The preset's values are assigned onto ``style`` in place; unknown
        domains leave it unchanged.

        Args:
            style: Base conversation style (modified in place).
            domain: Domain to tune for (matched case-insensitively).

        Returns:
            The same, tuned conversation style object.
        """
        preset = self._DOMAIN_PRESETS.get(domain.strip().lower())
        if preset:
            for key, value in preset.items():
                setattr(style, key, value)

        return style
|
|
|
|
|
|
class ConversationManager:
    """
    Conversation manager for maintaining conversation state and history.

    Manages conversation sessions, tracks turns, and provides context for
    natural language understanding and generation. Session contexts and turn
    histories are kept in in-memory dicts keyed by session ID.
    """

    def __init__(self, tuner: ConversationTuner | None = None) -> None:
        """
        Initialize conversation manager.

        Args:
            tuner: Conversation tuner for style management. A new
                ``ConversationTuner`` is created when omitted.
        """
        self.tuner = tuner if tuner is not None else ConversationTuner()
        self._sessions: dict[str, ConversationContext] = {}
        self._history: dict[str, list[ConversationTurn]] = {}
        logger.info("ConversationManager initialized")

    def create_session(
        self,
        user_id: str | None = None,
        style_name: str | None = None,
        language: str = "en",
        domain: str | None = None,
    ) -> str:
        """
        Create a new conversation session.

        Args:
            user_id: Optional user identifier.
            style_name: Optional registered style name (uses default if None).
            language: Primary language code.
            domain: Domain/topic of conversation.

        Returns:
            Session ID.

        Raises:
            ValueError: If ``style_name`` is given but not registered with the
                tuner.
        """
        if style_name:
            base_style = self.tuner.get_style(style_name)
            if base_style is None:
                raise ValueError(f"Unknown conversation style: {style_name!r}")
        else:
            base_style = self.tuner.get_default_style()

        context = ConversationContext(
            user_id=user_id,
            # Give each session its own copy; otherwise every session shares
            # (and cross-mutates) the tuner's default or registered style.
            style=base_style.model_copy(deep=True),
            language=language,
            domain=domain,
        )

        self._sessions[context.session_id] = context
        self._history[context.session_id] = []

        logger.info(
            "Conversation session created",
            extra={
                "session_id": context.session_id,
                "user_id": user_id,
                "domain": domain,
            },
        )
        return context.session_id

    def get_session(self, session_id: str) -> ConversationContext | None:
        """Get conversation context for a session, or None if not active."""
        return self._sessions.get(session_id)

    def add_turn(self, turn: ConversationTurn) -> None:
        """
        Add a turn to conversation history.

        Turns for unknown or ended sessions are rejected with a warning.
        History is trimmed to the session's ``history_length`` most recent
        turns.

        Args:
            turn: Conversation turn to add.
        """
        context = self._sessions.get(turn.session_id)
        if context is None:
            # Also covers ended sessions: their history is retained, but it
            # must not keep growing without a context to bound it.
            logger.warning("Session not found", extra={"session_id": turn.session_id})
            return

        history = self._history.setdefault(turn.session_id, [])
        history.append(turn)

        # Trim in place to the configured window. Compute the excess
        # explicitly: history[-0:] would keep everything for a 0 window.
        excess = len(history) - max(context.history_length, 0)
        if excess > 0:
            del history[:excess]

        logger.debug(
            "Turn added",
            extra={
                "session_id": turn.session_id,
                "speaker": turn.speaker,
                "content_length": len(turn.content),
            },
        )

    def get_history(self, session_id: str, limit: int | None = None) -> list[ConversationTurn]:
        """
        Get conversation history for a session.

        Args:
            session_id: Session identifier.
            limit: Optional maximum number of most recent turns to return;
                values <= 0 return an empty list.

        Returns:
            A new list of conversation turns (most recent last); empty if the
            session has no recorded history.
        """
        history = self._history.get(session_id, [])
        if limit is None:
            # Return a copy so callers cannot mutate the stored history.
            return list(history)
        if limit <= 0:
            return []
        return history[-limit:]

    def get_style_for_session(self, session_id: str) -> ConversationStyle | None:
        """
        Get the conversation style for a session.

        Args:
            session_id: Session identifier.

        Returns:
            Conversation style for the session, or None if session not found.
        """
        context = self._sessions.get(session_id)
        return context.style if context else None

    def update_style(self, session_id: str, style: ConversationStyle) -> bool:
        """
        Update conversation style for a session.

        A copy of ``style`` is stored, so later changes to the caller's object
        (e.g. the tuner's default style) do not leak into the session.

        Args:
            session_id: Session identifier.
            style: New conversation style.

        Returns:
            True if updated, False if session not found.
        """
        context = self._sessions.get(session_id)
        if context is None:
            return False
        context.style = style.model_copy(deep=True)
        logger.info("Session style updated", extra={"session_id": session_id})
        return True

    def end_session(self, session_id: str) -> bool:
        """
        End a conversation session.

        The session context is removed; its turn history is retained.

        Args:
            session_id: Session identifier.

        Returns:
            True if ended, False if not found.
        """
        if session_id in self._sessions:
            del self._sessions[session_id]
            # Keep history for analytics but could be cleaned up later
            logger.info("Session ended", extra={"session_id": session_id})
            return True
        return False

    def get_context_summary(
        self, session_id: str, recent_turn_count: int = 5
    ) -> dict[str, Any]:
        """
        Get a summary of conversation context for LLM prompting.

        Args:
            session_id: Session identifier.
            recent_turn_count: Number of most recent turns to include under
                ``recent_turns``; values <= 0 include none.

        Returns:
            Dictionary with context summary, or an empty dict if the session
            is not active.
        """
        context = self._sessions.get(session_id)
        if not context:
            return {}

        history = self._history.get(session_id, [])
        recent = history[-recent_turn_count:] if recent_turn_count > 0 else []

        return {
            "session_id": session_id,
            "user_id": context.user_id,
            "language": context.language,
            "domain": context.domain,
            "style": context.style.model_dump(),
            "turn_count": len(history),
            "recent_turns": [
                {"speaker": t.speaker, "content": t.content, "intent": t.intent}
                for t in recent
            ],
        }
|