2026-03-22 16:45:57 +08:00
|
|
|
|
"""Conversation repository — Conversation, turn log, and Segment data access."""
|
2026-03-18 17:18:23 +08:00
|
|
|
|
|
2026-04-08 16:50:53 +08:00
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
from sqlalchemy import func, select
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
2026-03-22 16:45:57 +08:00
|
|
|
|
from app.features.conversation.models import Conversation, ConversationMessage, Segment
|
2026-03-18 17:18:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
async def get_conversation(
    conversation_id: str, db: AsyncSession
) -> Conversation | None:
    """Look up a single conversation by primary key.

    Returns None when no row with ``conversation_id`` exists. This is a plain
    primary-key lookup, so soft-deleted rows (``deleted_at`` set) are returned
    as well.
    """
    conversation = await db.get(Conversation, conversation_id)
    return conversation
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-08 16:50:53 +08:00
|
|
|
|
async def set_playground_conversation_judge_json(
    conversation_id: str,
    db: AsyncSession,
    payload: dict[str, Any] | None,
) -> Conversation | None:
    """Overwrite ``playground_conversation_judge_json`` on one conversation.

    The value is only assigned on the session-tracked ORM object; nothing is
    flushed or committed here, so the caller owns the transaction. Passing
    ``payload=None`` stores None.

    Returns:
        The updated conversation, or None if the conversation does not exist.
    """
    conversation = await get_conversation(conversation_id, db)
    if conversation is not None:
        conversation.playground_conversation_judge_json = payload
    return conversation
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
async def get_user_conversations(user_id: str, db: AsyncSession) -> list[Conversation]:
    """Return a user's non-deleted conversations, most recently active first.

    The sort key is ``last_message_at``, falling back to ``started_at`` when
    ``last_message_at`` is NULL.
    """
    activity_at = func.coalesce(Conversation.last_message_at, Conversation.started_at)
    not_deleted = Conversation.deleted_at.is_(None)
    query = (
        select(Conversation)
        .where(Conversation.user_id == user_id, not_deleted)
        .order_by(activity_at.desc())
    )
    rows = (await db.execute(query)).scalars().all()
    return list(rows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_conversation(conv: Conversation, db: AsyncSession) -> None:
    """Stage ``conv`` for insertion in the session; no flush or commit happens here."""
    db.add(conv)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-22 16:45:57 +08:00
|
|
|
|
def add_conversation_message(msg: ConversationMessage, db: AsyncSession) -> None:
    """Stage one turn-log message for insertion; the caller flushes/commits."""
    db.add(msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_conversation_messages(
    conversation_id: str, db: AsyncSession
) -> list[ConversationMessage]:
    """Return every message of a conversation in ``created_at`` ascending order."""
    in_conversation = ConversationMessage.conversation_id == conversation_id
    query = select(ConversationMessage).where(in_conversation).order_by(
        ConversationMessage.created_at
    )
    result = await db.execute(query)
    messages = result.scalars().all()
    return list(messages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def set_latest_ai_message_tts_audio_urls(
    conversation_id: str,
    db: AsyncSession,
    *,
    tts_audio_urls: list[str],
    segment_id: str | None = None,
) -> ConversationMessage | None:
    """Attach TTS audio URLs to the most recent AI message of a conversation.

    Args:
        conversation_id: Conversation whose messages are searched.
        db: Active session. The new value is only assigned on the ORM object;
            nothing is flushed or committed here.
        tts_audio_urls: URLs to store. A copy is assigned, so later mutation
            of the caller's list does not leak into the row.
        segment_id: When given, only AI messages with this ``segment_id`` are
            considered.

    Returns:
        The updated message, or None when no matching AI message exists.
    """
    stmt = select(ConversationMessage).where(
        ConversationMessage.conversation_id == conversation_id,
        ConversationMessage.role == "ai",
    )
    if segment_id is not None:
        stmt = stmt.where(ConversationMessage.segment_id == segment_id)
    # Only the newest row is needed. LIMIT 1 keeps the database (and the
    # buffered result returned by AsyncSession.execute) from materialising
    # every AI message just to keep the first one.
    stmt = stmt.order_by(ConversationMessage.created_at.desc()).limit(1)
    result = await db.execute(stmt)
    row = result.scalars().first()
    if row is None:
        return None
    row.tts_audio_urls = list(tts_audio_urls)
    return row
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
async def get_segments_for_conversation(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Return all segments of a conversation, oldest first (``created_at`` ascending)."""
    by_conversation = select(Segment).where(Segment.conversation_id == conversation_id)
    ordered = by_conversation.order_by(Segment.created_at)
    segments = (await db.execute(ordered)).scalars().all()
    return list(segments)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
async def get_segments_for_organize(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Select segments to organize, keeping the legacy semantics.

    Prefer the segments that have not finished Phase 1 yet; if there are none,
    fall back to every segment of this conversation.
    """
    pending = await get_segments_pending_phase1(conversation_id, db)
    # The fallback query only runs when nothing is pending.
    return pending if pending else await get_segments_for_conversation(conversation_id, db)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_segments_pending_phase1(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Segments of this conversation that have not yet run Phase 1 classification.

    A segment is pending when ``topic_category`` is NULL and both ``narrated``
    and ``processed`` are false. Results are ordered oldest first.
    """
    pending_filters = (
        Segment.conversation_id == conversation_id,
        Segment.topic_category.is_(None),
        Segment.narrated.is_(False),
        Segment.processed.is_(False),
    )
    query = select(Segment).where(*pending_filters).order_by(Segment.created_at)
    rows = (await db.execute(query)).scalars().all()
    return list(rows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def conversation_has_pending_phase2(
    conversation_id: str, db: AsyncSession
) -> bool:
    """Report whether this conversation has segments waiting for Phase 2 narration.

    A segment qualifies when Phase 1 has assigned a ``topic_category``, it is
    not yet ``narrated``, and it is not flagged ``skip_narrative``. Only
    segments of ``conversation_id`` are considered.
    """
    # Existence probe: fetching at most one id lets the database stop at the
    # first match instead of counting every qualifying row.
    stmt = (
        select(Segment.id)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.topic_category.is_not(None),
            Segment.narrated.is_(False),
            Segment.skip_narrative.is_(False),
        )
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar() is not None
|
2026-03-18 17:18:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def count_segments_for_user(user_id: str, db: AsyncSession) -> int:
    """Count segments across all of a user's non-deleted conversations.

    Segments are attributed to the user through their conversation's
    ``user_id``. Returns 0 when the query yields no value.
    """
    owned_live_conversation = (
        Conversation.user_id == user_id,
        Conversation.deleted_at.is_(None),
    )
    stmt = (
        select(func.count(Segment.id))
        .select_from(Segment)
        .join(Conversation, Segment.conversation_id == Conversation.id)
        .where(*owned_live_conversation)
    )
    total = (await db.execute(stmt)).scalar()
    return total or 0
|