"""Voice interface: speech-to-text, text-to-speech, voice library management."""
|
|
|
|
import uuid
|
|
from typing import Any, Literal, Protocol, runtime_checkable
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from fusionagi._time import utc_now_iso
|
|
from fusionagi.interfaces.base import InterfaceAdapter, InterfaceCapabilities, InterfaceMessage, ModalityType
|
|
from fusionagi._logger import logger
|
|
|
|
|
|
@runtime_checkable
|
|
class TTSAdapter(Protocol):
|
|
"""Protocol for TTS providers (ElevenLabs, Azure, system, etc.). Integrate by injecting an implementation."""
|
|
|
|
async def synthesize(self, text: str, voice_id: str | None = None, **kwargs: Any) -> bytes | None:
|
|
"""Synthesize text to audio. Returns raw audio bytes or None if not available."""
|
|
...
|
|
|
|
|
|
@runtime_checkable
|
|
class STTAdapter(Protocol):
|
|
"""Protocol for STT providers (Whisper, Azure, Google, etc.). Integrate by injecting an implementation."""
|
|
|
|
async def transcribe(self, audio_data: bytes | None = None, timeout_seconds: float | None = None, **kwargs: Any) -> str | None:
|
|
"""Transcribe audio to text. Returns transcribed text or None if timeout/unavailable."""
|
|
...
|
|
|
|
|
|
class VoiceProfile(BaseModel):
    """Voice profile for text-to-speech synthesis."""

    # Stable identifier, auto-generated as "voice_" followed by 8 hex chars.
    id: str = Field(default_factory=lambda: f"voice_{uuid.uuid4().hex[:8]}")
    # Display name; required (no default).
    name: str = Field(description="Human-readable voice name")
    # Language tag; used as a filter key by VoiceLibrary.list_voices.
    language: str = Field(default="en-US", description="Language code (e.g., en-US, es-ES)")
    # Optional perceived gender; None means unspecified.
    gender: Literal["male", "female", "neutral"] | None = Field(default=None)
    # Optional perceived age bracket; None means unspecified.
    age_range: Literal["child", "young_adult", "adult", "senior"] | None = Field(default=None)
    # Free-form delivery style; also a filter key by VoiceLibrary.list_voices.
    style: str | None = Field(default=None, description="Voice style (e.g., friendly, professional, calm)")
    # Pitch multiplier, validated to [0.5, 2.0]; 1.0 is presumably the provider's natural pitch.
    pitch: float = Field(default=1.0, ge=0.5, le=2.0, description="Pitch multiplier")
    # Speed multiplier, validated to [0.5, 2.0]; 1.0 is presumably the provider's natural rate.
    speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed multiplier")
    # Which TTS backend this profile targets.
    provider: str = Field(default="system", description="TTS provider (e.g., system, elevenlabs, azure)")
    # Backend-specific voice identifier (e.g., an ElevenLabs voice ID), if any.
    provider_voice_id: str | None = Field(default=None, description="Provider-specific voice ID")
    # Arbitrary extra data attached by callers.
    metadata: dict[str, Any] = Field(default_factory=dict)
    # ISO-style UTC creation timestamp produced by fusionagi._time.utc_now_iso.
    created_at: str = Field(default_factory=utc_now_iso)
|
|
|
|
|
|
class VoiceLibrary:
    """
    In-memory registry of TTS voice profiles.

    Supports registering, looking up, filtering, updating, and removing
    :class:`VoiceProfile` entries, and tracks a default voice (the first
    profile added, unless overridden via :meth:`set_default_voice`).
    """

    def __init__(self) -> None:
        self._voices: dict[str, VoiceProfile] = {}
        self._default_voice_id: str | None = None
        logger.info("VoiceLibrary initialized")

    def add_voice(self, profile: VoiceProfile) -> str:
        """
        Register a voice profile.

        The first profile ever registered automatically becomes the
        library default.

        Args:
            profile: Voice profile to register.

        Returns:
            The profile's voice ID.
        """
        if self._default_voice_id is None:
            self._default_voice_id = profile.id
        self._voices[profile.id] = profile
        logger.info("Voice added", extra={"voice_id": profile.id, "name": profile.name})
        return profile.id

    def remove_voice(self, voice_id: str) -> bool:
        """
        Delete a voice profile.

        If the removed voice was the default, an arbitrary remaining voice
        (or None, when the library is now empty) takes its place.

        Args:
            voice_id: ID of the voice to delete.

        Returns:
            True if the voice existed and was removed, False otherwise.
        """
        if self._voices.pop(voice_id, None) is None:
            return False
        if self._default_voice_id == voice_id:
            # Promote any surviving voice to default (dict order: oldest first).
            self._default_voice_id = next(iter(self._voices), None)
        logger.info("Voice removed", extra={"voice_id": voice_id})
        return True

    def get_voice(self, voice_id: str) -> VoiceProfile | None:
        """Look up a voice profile by ID; None when absent."""
        return self._voices.get(voice_id)

    def list_voices(
        self,
        language: str | None = None,
        gender: str | None = None,
        style: str | None = None,
    ) -> list[VoiceProfile]:
        """
        List voice profiles, optionally narrowed by attribute.

        Args:
            language: Keep only voices with this language code.
            gender: Keep only voices with this gender.
            style: Keep only voices with this style.

        Returns:
            Profiles matching every supplied filter.
        """
        return [
            voice
            for voice in self._voices.values()
            if (not language or voice.language == language)
            and (not gender or voice.gender == gender)
            and (not style or voice.style == style)
        ]

    def set_default_voice(self, voice_id: str) -> bool:
        """
        Choose the library's default voice.

        Args:
            voice_id: ID of an existing voice.

        Returns:
            True if the voice exists and was made default, False otherwise.
        """
        if voice_id not in self._voices:
            return False
        self._default_voice_id = voice_id
        logger.info("Default voice set", extra={"voice_id": voice_id})
        return True

    def get_default_voice(self) -> VoiceProfile | None:
        """Return the current default voice profile, if any."""
        return self._voices.get(self._default_voice_id) if self._default_voice_id else None

    def update_voice(self, voice_id: str, updates: dict[str, Any]) -> bool:
        """
        Apply field updates to an existing voice profile.

        Keys that do not correspond to an existing attribute are silently
        skipped.

        Args:
            voice_id: ID of the voice to modify.
            updates: Mapping of attribute name to new value.

        Returns:
            True if the voice exists, False otherwise.
        """
        profile = self._voices.get(voice_id)
        if profile is None:
            return False
        for attr, value in updates.items():
            if hasattr(profile, attr):
                setattr(profile, attr, value)
        logger.info("Voice updated", extra={"voice_id": voice_id, "updates": list(updates.keys())})
        return True
|
|
|
|
|
|
class VoiceInterface(InterfaceAdapter):
    """
    Interface adapter for spoken interaction.

    Responsibilities:
    - Text-to-speech output (via an injected TTSAdapter; logged stub otherwise)
    - Speech-to-text input (via an injected STTAdapter; logged stub otherwise)
    - Per-session voice selection backed by a VoiceLibrary
    """

    def __init__(
        self,
        name: str = "voice",
        voice_library: VoiceLibrary | None = None,
        stt_provider: str = "whisper",
        tts_provider: str = "system",
        tts_adapter: TTSAdapter | None = None,
        stt_adapter: STTAdapter | None = None,
    ) -> None:
        """
        Initialize the voice interface.

        Args:
            name: Interface name.
            voice_library: Library of TTS voice profiles; a fresh empty one
                is created when omitted.
            stt_provider: Speech-to-text provider label (whisper, azure, google, etc.).
            tts_provider: Text-to-speech provider label (system, elevenlabs, azure, etc.).
            tts_adapter: Optional injected TTS backend (ElevenLabs, Azure, etc.).
            stt_adapter: Optional injected STT backend (Whisper, Azure, etc.).
        """
        super().__init__(name)
        self.stt_provider = stt_provider
        self.tts_provider = tts_provider
        self.voice_library = voice_library if voice_library is not None else VoiceLibrary()
        self._tts_adapter = tts_adapter
        self._stt_adapter = stt_adapter
        self._active_voice_id: str | None = None
        logger.info(
            "VoiceInterface initialized",
            extra={"stt_provider": stt_provider, "tts_provider": tts_provider}
        )

    def capabilities(self) -> InterfaceCapabilities:
        """Describe this interface: voice-only, streaming, interruptible."""
        caps = InterfaceCapabilities(
            supported_modalities=[ModalityType.VOICE],
            supports_streaming=True,
            supports_interruption=True,
            supports_multimodal=False,
            latency_ms=200.0,  # Typical voice latency
            max_concurrent_sessions=10,
        )
        return caps

    def _resolve_voice(self, message: InterfaceMessage) -> VoiceProfile | None:
        """Pick the profile for a message: explicit metadata id, then session voice, then library default."""
        requested = message.metadata.get("voice_id", self._active_voice_id)
        profile = self.voice_library.get_voice(requested) if requested else None
        return profile or self.voice_library.get_default_voice()

    async def send(self, message: InterfaceMessage) -> None:
        """
        Send voice output (text-to-speech).

        Args:
            message: Message with text content to synthesize.
        """
        if not self.validate_message(message):
            logger.warning("Invalid message for voice interface", extra={"modality": message.modality})
            return

        profile = self._resolve_voice(message)
        text = message.content if isinstance(message.content, str) else str(message.content)
        voice_id = profile.id if profile else None

        if self._tts_adapter is None:
            logger.info(
                "TTS synthesis (stub; inject tts_adapter for ElevenLabs, Azure, etc.)",
                extra={"text_length": len(text), "voice_id": voice_id, "provider": self.tts_provider},
            )
            return

        try:
            audio_data = await self._tts_adapter.synthesize(text, voice_id=voice_id)
            if audio_data:
                logger.info(
                    "TTS synthesis (adapter)",
                    extra={"text_length": len(text), "voice_id": voice_id, "bytes": len(audio_data)},
                )
                # Inject: await self._play_audio(audio_data)
        except Exception as e:
            logger.exception("TTS adapter failed", extra={"error": str(e)})

    async def receive(self, timeout_seconds: float | None = None) -> InterfaceMessage | None:
        """
        Receive voice input (speech-to-text).

        Args:
            timeout_seconds: Optional timeout for listening.

        Returns:
            Message with transcribed text or None if timeout.
        """
        logger.info("STT listening", extra={"timeout": timeout_seconds, "provider": self.stt_provider})
        if self._stt_adapter is None:
            return None
        try:
            text = await self._stt_adapter.transcribe(audio_data=None, timeout_seconds=timeout_seconds)
        except Exception as e:
            logger.exception("STT adapter failed", extra={"error": str(e)})
            return None
        if not text:
            return None
        return InterfaceMessage(
            id=f"stt_{uuid.uuid4().hex[:8]}",
            modality=ModalityType.VOICE,
            content=text,
            metadata={"provider": self.stt_provider},
        )

    def set_active_voice(self, voice_id: str) -> bool:
        """
        Set the active voice for this interface session.

        Args:
            voice_id: ID of voice to use.

        Returns:
            True if voice exists, False otherwise.
        """
        if self.voice_library.get_voice(voice_id) is None:
            return False
        self._active_voice_id = voice_id
        logger.info("Active voice set", extra={"voice_id": voice_id})
        return True

    async def _synthesize_speech(self, text: str, voice: VoiceProfile | None) -> bytes:
        """
        Synthesize speech from text (to be implemented with an actual provider).

        Args:
            text: Text to synthesize.
            voice: Voice profile to use.

        Returns:
            Audio data as bytes.

        Raises:
            NotImplementedError: Always, until a provider is integrated.
        """
        # Integrate with TTS provider based on self.tts_provider:
        # - system: OS TTS (pyttsx3, etc.)
        # - elevenlabs: ElevenLabs API
        # - azure: Azure Cognitive Services
        # - google: Google Cloud TTS
        raise NotImplementedError("TTS provider integration required")

    async def _transcribe_speech(self, audio_data: bytes) -> str:
        """
        Transcribe speech to text (to be implemented with an actual provider).

        Args:
            audio_data: Audio data to transcribe.

        Returns:
            Transcribed text.

        Raises:
            NotImplementedError: Always, until a provider is integrated.
        """
        # Integrate with STT provider based on self.stt_provider:
        # - whisper: OpenAI Whisper (local or API)
        # - azure: Azure Cognitive Services
        # - google: Google Cloud Speech-to-Text
        # - deepgram: Deepgram API
        raise NotImplementedError("STT provider integration required")
|