Files
life-echo/api/app/features/conversation/chat_turn.py

117 lines
3.6 KiB
Python
Raw Normal View History

"""Conversation chat turn boundary.
This module gives the WebSocket pipeline a small, explicit contract for one
user turn. It deliberately keeps the existing ``ChatOrchestrator`` behavior
intact while making the runtime inputs/outputs visible and testable.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat import ChatOrchestrator
from app.agents.chat.agent_turn import AgentChatTurn
@dataclass(frozen=True)
class ChatTurnInput:
    """Everything the caller supplies about one inbound user message.

    Nothing here refers to the transport (WebSocket, HTTP, ...). A turn can
    therefore be built directly in tests. Instances are frozen: fields
    cannot be reassigned after construction.
    """

    # Identifier of the conversation this message belongs to.
    conversation_id: str
    # The user's message text (for voice input, presumably the transcript;
    # confirm against the caller).
    user_message: str
    # True when the message arrived through a voice session.
    is_from_voice: bool = False
    # Voice session identifier, if any; None for non-voice turns.
    voice_session_id: str | None = None
    # When the user sent the message; None if the caller did not record it.
    user_message_timestamp: datetime | None = None
    # Length of the spoken audio in whole seconds, when known.
    audio_duration_seconds: int | None = None
    # Caller-side override: when True, the turn result always has
    # skip_tts=True, whatever the orchestrator decided.
    force_skip_tts: bool = False
@dataclass(frozen=True)
class ChatTurnContext:
    """Runtime collaborators a turn needs but does not own.

    These are handed through unchanged to the orchestrator when the turn
    executes. The instance is frozen, so the references cannot be reassigned.
    """

    # Async SQLAlchemy session used for the turn.
    db: AsyncSession
    # User and conversation records, when the caller has them loaded;
    # their concrete types are not constrained here.
    user: Any | None
    conversation: Any | None
    # Profile helper callables injected by the caller. Only their
    # declared signatures are known from this module.
    apply_extracted_profile_fn: Callable[..., Any]
    get_missing_profile_fields_fn: Callable[[Any], list[str]]
    get_filled_profile_fields_fn: Callable[[Any], dict[str, Any]]
@dataclass(frozen=True)
class ChatTurnDecision:
    """Metadata describing how a turn was routed, exposed for observability."""

    # Label of the engine that executed the turn.
    engine: str = "ChatOrchestrator"
    # Routing hint; "auto" is the only value set in this module. Its
    # meaning is not defined here (TODO confirm with the consumer).
    route_hint: str = "auto"
    # When True, the turn result is forced to skip TTS.
    force_skip_tts: bool = False
@dataclass(frozen=True)
class ChatTurnResult:
    """Outcome of one chat turn, shaped for persistence and delivery.

    ``from_agent_turn`` adapts the orchestrator's turn object into this
    shape and applies the decision's TTS override.
    """

    # Reply messages, in order.
    messages: list[str]
    # Whether text-to-speech should be skipped for these messages.
    skip_tts: bool
    # Optional metadata passed through unchanged from the orchestrator turn.
    memory_retrieval_trace: dict[str, Any] | None = None
    interview_state_meta: dict[str, Any] | None = None
    # One default instance is shared by all results. This is safe because
    # ChatTurnDecision is frozen.
    decision: ChatTurnDecision = ChatTurnDecision()

    @classmethod
    def from_agent_turn(
        cls,
        turn: AgentChatTurn,
        *,
        decision: ChatTurnDecision,
    ) -> "ChatTurnResult":
        """Build a result from an orchestrator turn.

        Args:
            turn: The orchestrator's turn; its ``messages``, ``skip_tts``,
                ``memory_retrieval_trace`` and ``interview_state_meta`` are read.
            decision: Routing metadata to attach to the result.

        Returns:
            A new ``ChatTurnResult``. ``messages`` is a fresh list (``None``
            becomes ``[]``), so the result does not alias the turn's list.
            ``skip_tts`` is True if either the turn requested it or
            ``decision.force_skip_tts`` is set.
        """
        reply_messages = list(turn.messages or [])
        tts_suppressed = bool(turn.skip_tts or decision.force_skip_tts)
        return cls(
            messages=reply_messages,
            skip_tts=tts_suppressed,
            memory_retrieval_trace=turn.memory_retrieval_trace,
            interview_state_meta=turn.interview_state_meta,
            decision=decision,
        )
class ChatTurnService:
    """Execute one chat turn through ``ChatOrchestrator`` behind a typed contract.

    This is a thin adapter. It unpacks ``ChatTurnInput`` and ``ChatTurnContext``
    into keyword arguments for ``ChatOrchestrator.process_user_message``. It
    then wraps the returned turn with ``ChatTurnResult.from_agent_turn``.
    """

    def __init__(self, orchestrator: ChatOrchestrator | None = None) -> None:
        """Create the service.

        Args:
            orchestrator: Orchestrator to delegate to. When ``None`` (the
                default), a new ``ChatOrchestrator()`` is constructed.
                Passing an instance allows a stub to be injected.
        """
        # Compare against None explicitly. The previous
        # ``orchestrator or ChatOrchestrator()`` silently replaced an injected
        # orchestrator whenever that object happened to be falsy (e.g. it
        # defines ``__bool__`` or ``__len__``).
        self._orchestrator = (
            ChatOrchestrator() if orchestrator is None else orchestrator
        )

    async def process_turn(
        self,
        turn_input: ChatTurnInput,
        context: ChatTurnContext,
    ) -> ChatTurnResult:
        """Run one user turn and return its transport-neutral result.

        Args:
            turn_input: The user's message and its per-turn flags.
            context: Database session, loaded records and profile helpers
                forwarded to the orchestrator.

        Returns:
            A ``ChatTurnResult`` built from the orchestrator's turn. Its
            ``decision`` carries ``turn_input.force_skip_tts``.

        Raises:
            Any exception raised by ``process_user_message``; nothing is
            caught here.
        """
        # force_skip_tts is not forwarded to the orchestrator. It is applied
        # afterwards, via the decision, when the result is built.
        decision = ChatTurnDecision(force_skip_tts=turn_input.force_skip_tts)
        turn = await self._orchestrator.process_user_message(
            conversation_id=turn_input.conversation_id,
            user_message=turn_input.user_message,
            user=context.user,
            conversation=context.conversation,
            is_from_voice=turn_input.is_from_voice,
            voice_session_id=turn_input.voice_session_id,
            db=context.db,
            apply_extracted_profile_fn=context.apply_extracted_profile_fn,
            get_missing_profile_fields_fn=context.get_missing_profile_fields_fn,
            get_filled_profile_fields_fn=context.get_filled_profile_fields_fn,
            user_message_timestamp=turn_input.user_message_timestamp,
            audio_duration_seconds=turn_input.audio_duration_seconds,
        )
        return ChatTurnResult.from_agent_turn(turn, decision=decision)
# Public names of the chat-turn boundary module.
__all__ = [
    "ChatTurnContext", "ChatTurnDecision", "ChatTurnInput",
    "ChatTurnResult", "ChatTurnService",
]