Files
life-echo/api/app/features/conversation/service.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

392 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Conversation service — 对话编排(列表、创建、结束、删除、消息、整理)。"""
import asyncio
import uuid
from datetime import datetime, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat.personas import agent_name
from app.core.cos_url_keys import (
collect_cos_keys_from_conversation_history,
collect_cos_keys_from_tts_url_list,
extract_cos_object_key_if_owned,
)
from app.core.db import transactional
from app.core.errors import (
AuthorizationError,
BadRequestError,
NotFoundError,
QuotaExceededError,
)
from app.core.logging import get_logger
from app.core.redis import redis_service
from app.core.storage_purge import delete_object_storage_keys_best_effort
from app.features.conversation import repo
from app.features.conversation.history_store import ConversationHistoryStore
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.session_history import (
conversation_messages_to_redis_history,
)
from app.features.conversation.tts_delivery import apply_presigned_tts_urls_to_messages
from app.features.memory import repo as memory_repo
from app.features.quota.service import QuotaService
from app.features.user.models import User
from app.ports.storage import ObjectStorage
from app.tasks.memoir_tasks import (
dispatch_pending_memoir_phase2_for_user,
process_memoir_phase1,
)
# Module-level logger; log calls in this file pass values for `{}` placeholders.
logger = get_logger(__name__)
def _datetime_to_timestamp_ms(value: datetime | None) -> int:
if value is None:
return int(datetime.now(timezone.utc).timestamp() * 1000)
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return int(value.timestamp() * 1000)
def _message_timestamp_ms(msg: dict, fallback: datetime | None) -> int:
raw_timestamp = msg.get("timestamp")
if isinstance(raw_timestamp, (int, float)):
return int(raw_timestamp)
if isinstance(raw_timestamp, str):
try:
return int(
datetime.fromisoformat(raw_timestamp.replace("Z", "+00:00")).timestamp()
* 1000
)
except ValueError:
pass
return _datetime_to_timestamp_ms(fallback)
def _latest_message_time_ms(conversation: Conversation, history: list[dict]) -> int:
    """Return the epoch-millisecond time of the conversation's latest activity.

    Preference order: the conversation's ``last_message_at`` attribute when it
    is present and truthy, then the timestamp of the last history entry (with
    ``started_at`` as that entry's fallback), then ``started_at`` itself.
    """
    if last_at := getattr(conversation, "last_message_at", None):
        return _datetime_to_timestamp_ms(last_at)

    started = conversation.started_at
    if not history:
        return _datetime_to_timestamp_ms(started)
    return _message_timestamp_ms(history[-1], started)
def _build_messages_from_history(
    conversation_id: str,
    history: list[dict],
    fallback_timestamp: datetime | None,
) -> list[dict]:
    """Convert stored history entries into client-facing message dicts.

    Each emitted message gets the id ``"{conversation_id}_msg_{index}"`` where
    ``index`` is the entry's position in ``history``. Human audio entries that
    share a ``voiceSessionId`` are emitted once (first occurrence wins), so ids
    may have gaps. Any role other than ``"human"`` is reported as
    ``"assistant"``.

    NOTE(review): the source this was reconstructed from had lost its
    indentation. The ``durationSeconds`` lookup is placed under the human
    voice-session branch and ``durableMessageId`` under the ``"ai"`` branch —
    confirm against the canonical file.
    """
    built: list[dict] = []
    emitted_voice_sessions: set[str] = set()

    for position, entry in enumerate(history):
        role = entry.get("role")
        kind = entry.get("messageType", "text")
        session_id = entry.get("voiceSessionId")
        from_user = role == "human"

        # Collapse repeated human audio entries of one voice session into one.
        if from_user and kind == "audio" and session_id:
            if session_id in emitted_voice_sessions:
                continue
            emitted_voice_sessions.add(session_id)

        payload: dict = {
            "id": f"{conversation_id}_msg_{position}",
            "conversationId": conversation_id,
            "content": entry.get("content", ""),
            "senderType": "user" if from_user else "assistant",
            "timestamp": _message_timestamp_ms(entry, fallback_timestamp),
            "messageType": kind,
        }

        if session_id and from_user:
            payload["voiceSessionId"] = session_id
            duration = entry.get("durationSeconds")
            if isinstance(duration, (int, float)) and duration > 0:
                payload["durationSeconds"] = int(duration)

        if role == "ai":
            tts_urls = entry.get("ttsAudioUrls")
            if isinstance(tts_urls, list) and tts_urls:
                # Keep only string entries; anything else is dropped.
                payload["ttsAudioUrls"] = [
                    url for url in tts_urls if isinstance(url, str)
                ]
            durable_id = entry.get("durableMessageId")
            if isinstance(durable_id, str) and durable_id:
                payload["durableMessageId"] = durable_id

        built.append(payload)

    return built
class ConversationService:
    """Conversation orchestration: list, create, end, delete, messages, organize.

    All database work goes through the injected ``AsyncSession``; writes are
    wrapped in ``transactional(self._db)``. Conversation history is read from
    ``redis_service`` first and rebuilt from the database when the cache is
    empty or unreadable.
    """

    def __init__(
        self,
        db: AsyncSession,
        quota_service: QuotaService,
        *,
        object_storage: ObjectStorage | None = None,
    ):
        """Store the collaborators; ``object_storage`` is optional."""
        self._db = db
        self._quota = quota_service
        self._object_storage = object_storage

    async def ensure_ws_connection(
        self, conversation_id: str, user_id: str
    ) -> tuple[Conversation | None, str]:
        """Load or create the conversation for a WebSocket connection.

        Returns ``(conversation, err)``. ``err`` is ``""`` on success;
        otherwise it is ``"forbidden"`` (owned by another user) or
        ``"deleted"`` (soft-deleted) and the conversation is ``None``.

        NOTE(review): the lookup and the insert are separate steps; confirm
        that two simultaneous connections with the same new id are handled
        upstream.
        """
        conv = await self._db.get(Conversation, conversation_id)
        if not conv:
            conv = Conversation(
                id=conversation_id,
                user_id=user_id,
                started_at=datetime.now(timezone.utc),
                status="active",
            )
            async with transactional(self._db):
                self._db.add(conv)
            await self._db.refresh(conv)
            return conv, ""
        if conv.user_id != user_id:
            return None, "forbidden"
        if conv.deleted_at is not None:
            return None, "deleted"
        return conv, ""

    async def create_user_segment(
        self,
        conversation: Conversation,
        user_id: str,
        text: str,
        *,
        audio_url: str | None = None,
        audio_duration_seconds: int | None = None,
    ) -> Segment:
        """Persist a new unprocessed user segment and bump ``last_message_at``.

        Raises:
            AuthorizationError: if ``conversation`` belongs to another user.
        """
        if conversation.user_id != user_id:
            raise AuthorizationError("无权访问此对话")
        segment = Segment(
            id=str(uuid.uuid4()),
            conversation_id=conversation.id,
            user_input_text=text,
            audio_url=audio_url,
            audio_duration_seconds=audio_duration_seconds,
            processed=False,
        )
        async with transactional(self._db):
            self._db.add(segment)
            conversation.last_message_at = datetime.now(timezone.utc)
        await self._db.refresh(segment)
        return segment

    async def _clear_history(self, conversation_id: str) -> None:
        """Best-effort removal of the cached history; failures are only logged."""
        try:
            await redis_service.clear_conversation_history(conversation_id)
        except Exception as e:
            logger.debug("清空会话历史失败: {}", e)

    async def ensure_redis_history_from_db(self, conversation_id: str) -> list[dict]:
        """Return the conversation history, preferring the Redis cache.

        Used by the WebSocket flow and ``get_messages``. If the cache yields a
        non-empty history its TTL is extended and it is returned. Otherwise the
        history is rebuilt from the database's conversation messages and
        written back to the cache. Cache read/write/TTL failures are logged and
        never raised; database errors propagate.
        """
        try:
            history = await redis_service.get_conversation_history(conversation_id)
        except Exception as exc:
            logger.warning("conversation history cache read skipped: {}", exc)
            history = []
        if history:
            try:
                await redis_service.extend_session_ttl(conversation_id)
            except Exception as exc:
                logger.debug("conversation history ttl extend skipped: {}", exc)
            return history

        rows = await repo.get_conversation_messages(conversation_id, self._db)
        if rows:
            rebuilt = conversation_messages_to_redis_history(rows)
            try:
                await redis_service.set_conversation_history(conversation_id, rebuilt)
            except Exception as exc:
                logger.warning("conversation history cache write skipped: {}", exc)
            return rebuilt
        return []

    async def record_ai_only_turn(
        self, conversation_id: str, texts: list[str]
    ) -> str | None:
        """Delegate to ``ConversationHistoryStore.record_ai_only_turn``."""
        return await ConversationHistoryStore(self._db).record_ai_only_turn(
            conversation_id, texts
        )

    async def list_for_user(self, user_id: str) -> list[dict]:
        """Return list-view summaries for the user's conversations.

        History is loaded per conversation on a best-effort basis: if it
        cannot be loaded the row is still returned, with the preview falling
        back to the conversation summary and ``hasUserMessage`` set to False.
        """
        conversations = await repo.get_user_conversations(user_id, self._db)
        # Fetch language once for fallback title localization (no per-row N+1).
        user_obj = await self._db.get(User, user_id)
        raw_lang = getattr(user_obj, "language_preference", "zh") if user_obj else "zh"
        lang = str(raw_lang or "zh").strip().lower()
        fallback_title = agent_name(lang)

        result: list[dict] = []
        for conv in conversations:
            history: list[dict] = []
            try:
                history = await self.ensure_redis_history_from_db(conv.id)
            except Exception as exc:
                # Keep listing the remaining conversations, but leave a trace.
                logger.warning(
                    "conversation list history load skipped: conversation_id={}, error={}",
                    conv.id,
                    exc,
                )

            # Only slice real strings: a stored None/non-string content must
            # not raise and take the whole listing down with it.
            last_content = history[-1].get("content") if history else None
            latest_message = (
                last_content[:50] if isinstance(last_content, str) else None
            )
            has_user_message = any(msg.get("role") == "human" for msg in history)
            result.append(
                {
                    "id": conv.id,
                    "title": (conv.summary or "")[:30] or fallback_title,
                    "avatarUrl": None,
                    "latestMessagePreview": latest_message or conv.summary,
                    "latestMessageTime": _latest_message_time_ms(conv, history),
                    # Time the conversation was first created (ms); lets the
                    # client tell "say hello" from "continue the conversation"
                    # by calendar day.
                    "startedAt": _datetime_to_timestamp_ms(conv.started_at),
                    "unreadCount": 0,
                    "isDefaultAssistant": conv.summary is None,
                    "hasUserMessage": has_user_message,
                }
            )
        return result

    async def create(self, user_id: str) -> dict:
        """Create a new active conversation and return its basic fields."""
        conv = Conversation(
            id=str(uuid.uuid4()),
            user_id=user_id,
            started_at=datetime.now(timezone.utc),
            status="active",
        )
        async with transactional(self._db):
            repo.add_conversation(conv, self._db)
        await self._db.refresh(conv)
        return {
            "id": conv.id,
            "user_id": conv.user_id,
            "started_at": conv.started_at.isoformat(),
            "status": conv.status,
        }

    async def get_or_404(self, conversation_id: str, user_id: str) -> Conversation:
        """Return the user's live conversation.

        Raises:
            NotFoundError: if it does not exist, belongs to another user, or
                is soft-deleted (the three cases are deliberately not
                distinguished in the error).
        """
        conv = await repo.get_conversation(conversation_id, self._db)
        if not conv or conv.user_id != user_id or conv.deleted_at is not None:
            raise NotFoundError("Conversation not found")
        return conv

    async def get_one(self, conversation_id: str, user_id: str) -> dict:
        """Return the conversation's detail fields; raises like ``get_or_404``."""
        conv = await self.get_or_404(conversation_id, user_id)
        return {
            "id": conv.id,
            "user_id": conv.user_id,
            "started_at": conv.started_at.isoformat(),
            "ended_at": conv.ended_at.isoformat() if conv.ended_at else None,
            "duration_seconds": conv.duration_seconds,
            "summary": conv.summary,
            "status": conv.status,
            "current_topic": conv.current_topic,
            "conversation_stage": conv.conversation_stage,
        }

    async def end(self, conversation_id: str, user_id: str) -> dict:
        """Mark the conversation as ended and record its duration in seconds.

        Raises like ``get_or_404``.
        """
        conv = await self.get_or_404(conversation_id, user_id)
        async with transactional(self._db):
            ended_at = datetime.now(timezone.utc)
            conv.status = "ended"
            conv.ended_at = ended_at
            started_at = conv.started_at
            if started_at:
                if started_at.tzinfo is None:
                    # Subtracting a naive value from an aware one raises
                    # TypeError; treat naive as UTC like the module helpers do.
                    started_at = started_at.replace(tzinfo=timezone.utc)
                conv.duration_seconds = int((ended_at - started_at).total_seconds())
        return {
            "id": conv.id,
            "status": conv.status,
            "ended_at": conv.ended_at.isoformat(),
            "duration_seconds": conv.duration_seconds,
        }

    async def delete(self, conversation_id: str, user_id: str) -> None:
        """Soft-delete the conversation and purge its stored objects.

        Object-storage keys are gathered from the memory repo, the cached
        history and the segments' audio/TTS URLs. The cached history is then
        cleared, ``deleted_at`` is set in a transaction, and finally the keys
        are handed to the best-effort storage purge. Raises like
        ``get_or_404``.
        """
        conv = await self.get_or_404(conversation_id, user_id)

        cos_keys: set[str] = set(
            await memory_repo.list_storage_keys_for_conversation(
                self._db, conversation_id
            )
        )
        try:
            hist = await redis_service.get_conversation_history(conversation_id)
            cos_keys |= collect_cos_keys_from_conversation_history(hist)
        except Exception as exc:
            # Deletion must still proceed, but objects referenced only by the
            # cached history may be left behind — make that visible.
            logger.warning(
                "conversation delete cached history scan skipped: conversation_id={}, error={}",
                conversation_id,
                exc,
            )

        segments = await repo.get_segments_for_conversation(conversation_id, self._db)
        for seg in segments:
            k = extract_cos_object_key_if_owned(seg.audio_url)
            if k:
                cos_keys.add(k)
            raw_tts = getattr(seg, "tts_audio_urls", None)
            if isinstance(raw_tts, list):
                cos_keys |= collect_cos_keys_from_tts_url_list(
                    [str(x) for x in raw_tts if isinstance(x, str)]
                )

        await self._clear_history(conversation_id)
        async with transactional(self._db):
            conv.deleted_at = datetime.now(timezone.utc)
        # Purge runs after the soft-delete block; sorted for stable ordering.
        delete_object_storage_keys_best_effort(
            self._object_storage,
            sorted(cos_keys),
            log_prefix=f"conversation_soft_delete id={conversation_id}",
        )

    async def get_messages(self, conversation_id: str, user_id: str) -> list[dict]:
        """Return the conversation's messages in client format.

        Ownership/existence errors from ``get_or_404`` propagate. Any failure
        while loading or converting the history is logged and yields ``[]``.
        """
        conv = await self.get_or_404(conversation_id, user_id)
        try:
            history = await self.ensure_redis_history_from_db(conversation_id)
            messages = _build_messages_from_history(
                conversation_id=conversation_id,
                history=history,
                fallback_timestamp=conv.started_at,
            )
            apply_presigned_tts_urls_to_messages(messages, self._object_storage)
            return messages
        except Exception as exc:
            # Still degrade to an empty list, but never silently: an empty
            # result must be distinguishable from a failure in the logs.
            logger.warning(
                "conversation messages load failed: conversation_id={}, error={}",
                conversation_id,
                exc,
            )
            return []

    async def align_conversation_stage_from_memoir(
        self, conversation: Conversation, memoir_stage: str
    ) -> None:
        """Align conversation_stage with memoir state without regressing stage order.

        A blank ``memoir_stage`` is ignored. The stage is written when the
        conversation has no stage yet, or when the memoir stage's order is not
        lower than the current one (unknown stages rank as -1).
        """
        # Imported locally, as in the original — presumably to avoid an import
        # cycle; confirm before moving it to module level.
        from app.agents.stage_constants import STAGE_TO_ORDER

        ms = (memoir_stage or "").strip()
        if not ms:
            return
        cs = (conversation.conversation_stage or "").strip()
        async with transactional(self._db):
            if not cs or STAGE_TO_ORDER.get(ms, -1) >= STAGE_TO_ORDER.get(cs, -1):
                conversation.conversation_stage = ms

    async def organize(
        self, conversation_id: str, user_id: str, subscription_type: str
    ) -> dict:
        """Submit memoir organization tasks for the conversation.

        Raises:
            NotFoundError: via ``get_or_404``.
            BadRequestError: if there is nothing pending for phase 1 or 2.
            QuotaExceededError: if the quota service refuses the submission.
        """
        conv = await self.get_or_404(conversation_id, user_id)
        pending_p1 = await repo.get_segments_pending_phase1(conversation_id, self._db)
        has_p2 = await repo.conversation_has_pending_phase2(conversation_id, self._db)
        if not pending_p1 and not has_p2:
            raise BadRequestError("该对话没有可整理的内容")

        can_submit, quota_message = await self._quota.check_can_submit_organize(
            user_id, subscription_type
        )
        if not can_submit:
            raise QuotaExceededError(quota_message)

        if pending_p1:
            segment_ids = [s.id for s in pending_p1]
            # .delay() is a synchronous call; run it off the event loop, the
            # same way the phase-2 dispatch below is handled.
            await asyncio.to_thread(
                process_memoir_phase1.delay, conv.user_id, segment_ids
            )
            logger.info(
                "手动触发 Phase1: conversation_id={}, segments={}",
                conversation_id,
                len(segment_ids),
            )
        await asyncio.to_thread(dispatch_pending_memoir_phase2_for_user, conv.user_id)
        return {
            "message": "对话整理任务已提交",
            "conversation_id": conversation_id,
            "segments_count": len(pending_p1),
        }