Files
life-echo/api/app/features/conversation/repo.py
Kevin 78b61c076e feat(eval): Playground GLM 评分落库并可恢复
在 conversations 表增加 playground_conversation_judge_json,流式/非流式对话评审结束后写入最近一次快照(整体分、逐轮分、对比文案、错误与基线文件名等)。新增只读 GET 供前端按会话拉取;评测台 Playground 切换会话时自动恢复,并提示基线是否和当时一致。
2026-04-08 16:51:08 +08:00

152 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Conversation repository — Conversation, turn log, and Segment data access."""
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.features.conversation.models import Conversation, ConversationMessage, Segment
async def get_conversation(
    conversation_id: str, db: AsyncSession
) -> Conversation | None:
    """Load one conversation by primary key.

    Returns ``None`` when no row with ``conversation_id`` exists. This is a
    plain primary-key lookup, so rows whose ``deleted_at`` is set are still
    returned.
    """
    conversation = await db.get(Conversation, conversation_id)
    return conversation
async def set_playground_conversation_judge_json(
    conversation_id: str,
    db: AsyncSession,
    payload: dict[str, Any] | None,
) -> Conversation | None:
    """Store (or clear, by passing ``None``) the Playground judge snapshot.

    The value is assigned to ``playground_conversation_judge_json`` on the
    loaded conversation. This function does not flush or commit; persisting
    the change is left to the caller.

    Returns the updated conversation, or ``None`` when it does not exist.
    """
    conversation = await get_conversation(conversation_id, db)
    if conversation is not None:
        conversation.playground_conversation_judge_json = payload
    return conversation
async def get_user_conversations(user_id: str, db: AsyncSession) -> list[Conversation]:
    """List a user's conversations that are not soft-deleted, most recent first.

    Recency is ``last_message_at``, falling back to ``started_at`` when the
    former is NULL.
    """
    recency = func.coalesce(Conversation.last_message_at, Conversation.started_at)
    query = select(Conversation).where(
        Conversation.user_id == user_id,
        Conversation.deleted_at.is_(None),
    )
    query = query.order_by(recency.desc())
    rows = await db.execute(query)
    return list(rows.scalars().all())
def add_conversation(conv: Conversation, db: AsyncSession) -> None:
    """Register a new conversation with the session.

    Nothing is flushed or committed here; the caller owns the transaction.
    """
    db.add(conv)
def add_conversation_message(msg: ConversationMessage, db: AsyncSession) -> None:
    """Register a new conversation message with the session.

    Nothing is flushed or committed here; the caller owns the transaction.
    """
    db.add(msg)
async def get_conversation_messages(
    conversation_id: str, db: AsyncSession
) -> list[ConversationMessage]:
    """Return every message of a conversation, ordered by ascending ``created_at``."""
    same_conversation = ConversationMessage.conversation_id == conversation_id
    query = (
        select(ConversationMessage)
        .where(same_conversation)
        .order_by(ConversationMessage.created_at)
    )
    messages = (await db.execute(query)).scalars().all()
    return list(messages)
async def set_latest_ai_message_tts_audio_urls(
    conversation_id: str,
    db: AsyncSession,
    *,
    tts_audio_urls: list[str],
    segment_id: str | None = None,
) -> ConversationMessage | None:
    """Attach TTS audio URLs to the newest AI message of a conversation.

    Args:
        conversation_id: Conversation whose messages are searched.
        db: Async session. The new value is only assigned on the loaded
            instance; nothing is flushed or committed here.
        tts_audio_urls: URLs to store. A copy of the list is stored so later
            mutation of the caller's list does not affect the message.
        segment_id: When given, only AI messages linked to this segment are
            considered.

    Returns:
        The updated message, or ``None`` when no matching AI message exists.
    """
    stmt = select(ConversationMessage).where(
        ConversationMessage.conversation_id == conversation_id,
        ConversationMessage.role == "ai",
    )
    if segment_id is not None:
        stmt = stmt.where(ConversationMessage.segment_id == segment_id)
    # Only the newest row is used, so LIMIT 1 avoids fetching every AI
    # message of the conversation just to keep the first one.
    stmt = stmt.order_by(ConversationMessage.created_at.desc()).limit(1)
    result = await db.execute(stmt)
    row = result.scalars().first()
    if row is None:
        return None
    row.tts_audio_urls = list(tts_audio_urls)
    return row
async def get_segments_for_conversation(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Return all segments of a conversation, ordered by ascending ``created_at``."""
    query = select(Segment).where(Segment.conversation_id == conversation_id)
    query = query.order_by(Segment.created_at)
    result = await db.execute(query)
    return list(result.scalars().all())
async def get_segments_for_organize(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Pick the segments to organize, keeping the legacy semantics.

    Prefers the segments still awaiting Phase 1. If there are none, it falls
    back to every segment of the conversation.
    """
    pending = await get_segments_pending_phase1(conversation_id, db)
    # An empty list is falsy, which triggers the fallback query.
    return pending or await get_segments_for_conversation(conversation_id, db)
async def get_segments_pending_phase1(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Return the conversation's segments that have not yet run Phase 1 classification.

    A segment qualifies when ``topic_category`` is NULL and both ``narrated``
    and ``processed`` are false. The checks render as SQL ``IS false``, so NULL
    flags do not match. Results are ordered by ascending ``created_at``.
    """
    conditions = (
        Segment.conversation_id == conversation_id,
        Segment.topic_category.is_(None),
        Segment.narrated.is_(False),
        Segment.processed.is_(False),
    )
    query = select(Segment).where(*conditions).order_by(Segment.created_at)
    segments = (await db.execute(query)).scalars().all()
    return list(segments)
async def conversation_has_pending_phase2(
    conversation_id: str, db: AsyncSession
) -> bool:
    """Tell whether the conversation has segments that finished Phase 1 but await the narrative.

    A segment counts when it has a non-NULL ``topic_category``, is not
    ``narrated`` and is not flagged ``skip_narrative``. Only segments of
    ``conversation_id`` are considered.
    """
    # Pure existence check: select at most one id instead of COUNT(*), so the
    # database can stop at the first matching row.
    stmt = (
        select(Segment.id)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.topic_category.isnot(None),
            Segment.narrated.is_(False),
            Segment.skip_narrative.is_(False),
        )
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.first() is not None
async def count_segments_for_user(user_id: str, db: AsyncSession) -> int:
    """Count segments across all of a user's conversations that are not soft-deleted."""
    owned_by_user = (
        Conversation.user_id == user_id,
        Conversation.deleted_at.is_(None),
    )
    query = (
        select(func.count(Segment.id))
        .select_from(Segment)
        .join(Conversation, Segment.conversation_id == Conversation.id)
        .where(*owned_by_user)
    )
    total = (await db.execute(query)).scalar()
    return total if total else 0