Files
life-echo/api/app/features/conversation/repo.py

138 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Conversation repository — Conversation, turn log, and Segment data access."""
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.features.conversation.models import Conversation, ConversationMessage, Segment
async def get_conversation(
    conversation_id: str, db: AsyncSession
) -> Conversation | None:
    """Fetch a single conversation by its primary key.

    Returns ``None`` when no conversation has that id. This is a plain
    primary-key lookup: it does NOT exclude soft-deleted conversations
    (rows whose ``deleted_at`` is set).
    """
    found = await db.get(Conversation, conversation_id)
    return found
async def get_user_conversations(user_id: str, db: AsyncSession) -> list[Conversation]:
    """List a user's conversations that are not soft-deleted, most recent first.

    Recency is ``last_message_at``, falling back to ``started_at`` when
    ``last_message_at`` is NULL.
    """
    recency = func.coalesce(Conversation.last_message_at, Conversation.started_at)
    query = select(Conversation).where(
        Conversation.user_id == user_id,
        Conversation.deleted_at.is_(None),
    )
    query = query.order_by(recency.desc())
    rows = (await db.execute(query)).scalars().all()
    return list(rows)
def add_conversation(conv: Conversation, db: AsyncSession) -> None:
    """Stage ``conv`` for insertion in the given session.

    This function does not flush or commit; persisting the row is left to
    the caller's transaction handling.
    """
    db.add(conv)
def add_conversation_message(msg: ConversationMessage, db: AsyncSession) -> None:
    """Stage a conversation message (one entry of the turn log) for insertion.

    This function does not flush or commit; persisting the row is left to
    the caller's transaction handling.
    """
    db.add(msg)
async def get_conversation_messages(
    conversation_id: str, db: AsyncSession
) -> list[ConversationMessage]:
    """Return all messages of one conversation, ordered by ``created_at`` ascending.

    NOTE(review): only ``created_at`` is used for ordering, so messages that
    share an identical timestamp come back in database-defined order.
    """
    in_conversation = ConversationMessage.conversation_id == conversation_id
    query = (
        select(ConversationMessage)
        .where(in_conversation)
        .order_by(ConversationMessage.created_at)
    )
    messages = (await db.execute(query)).scalars().all()
    return list(messages)
async def set_latest_ai_message_tts_audio_urls(
    conversation_id: str,
    db: AsyncSession,
    *,
    tts_audio_urls: list[str],
    segment_id: str | None = None,
) -> ConversationMessage | None:
    """Store TTS audio URLs on the newest AI message of a conversation.

    Args:
        conversation_id: Conversation whose messages are searched.
        db: Active session. The new value is only set on the ORM object;
            this function does not flush or commit.
        tts_audio_urls: URLs to store on the message (a copy is stored).
        segment_id: When given, only AI messages with this ``segment_id``
            are considered.

    Returns:
        The updated message (newest by ``created_at``), or ``None`` when no
        matching AI message exists.
    """
    stmt = select(ConversationMessage).where(
        ConversationMessage.conversation_id == conversation_id,
        ConversationMessage.role == "ai",
    )
    if segment_id is not None:
        stmt = stmt.where(ConversationMessage.segment_id == segment_id)
    # Only the newest row is used. LIMIT 1 stops the database from sending
    # back every AI message of the conversation (AsyncSession results are
    # buffered, so without it all matching rows would be fetched).
    stmt = stmt.order_by(ConversationMessage.created_at.desc()).limit(1)
    result = await db.execute(stmt)
    row = result.scalars().first()
    if row is None:
        return None
    # Store a copy so the ORM attribute does not alias the caller's list.
    row.tts_audio_urls = list(tts_audio_urls)
    return row
async def get_segments_for_conversation(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Return every segment of a conversation, ordered by ``created_at`` ascending."""
    query = select(Segment).where(Segment.conversation_id == conversation_id)
    query = query.order_by(Segment.created_at)
    segments = (await db.execute(query)).scalars().all()
    return list(segments)
async def get_segments_for_organize(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Choose the segments to organize (kept for the legacy semantics).

    Prefers the segments still pending Phase 1; when none are pending,
    falls back to all segments of this conversation.
    """
    # `or` short-circuits: the fallback query runs only when the pending
    # list is empty.
    return (
        await get_segments_pending_phase1(conversation_id, db)
        or await get_segments_for_conversation(conversation_id, db)
    )
async def get_segments_pending_phase1(
    conversation_id: str, db: AsyncSession
) -> list[Segment]:
    """Return this conversation's segments that have not gone through Phase 1 classification.

    A segment qualifies when ``topic_category`` is NULL and both ``narrated``
    and ``processed`` are false. The flags are compared with ``IS false``,
    so a NULL flag does not match. Results are ordered by ``created_at``.
    """
    conditions = (
        Segment.conversation_id == conversation_id,
        Segment.topic_category.is_(None),
        Segment.narrated.is_(False),
        Segment.processed.is_(False),
    )
    query = select(Segment).where(*conditions).order_by(Segment.created_at)
    pending = (await db.execute(query)).scalars().all()
    return list(pending)
async def conversation_has_pending_phase2(
    conversation_id: str, db: AsyncSession
) -> bool:
    """Tell whether this conversation has segments waiting for narrative (Phase 2).

    A segment is pending when Phase 1 has set its ``topic_category`` but it
    is neither ``narrated`` nor marked ``skip_narrative``. Only segments of
    this conversation are considered.

    Returns:
        ``True`` if at least one such segment exists, else ``False``.
    """
    # Only existence matters: fetch at most one id instead of COUNT-ing
    # every matching row.
    stmt = (
        select(Segment.id)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.topic_category.isnot(None),
            Segment.narrated.is_(False),
            Segment.skip_narrative.is_(False),
        )
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.first() is not None
async def count_segments_for_user(user_id: str, db: AsyncSession) -> int:
    """Count the segments in all of a user's conversations that are not soft-deleted."""
    live_owned = (
        Conversation.user_id == user_id,
        Conversation.deleted_at.is_(None),
    )
    query = (
        select(func.count(Segment.id))
        .select_from(Segment)
        .join(Conversation, Segment.conversation_id == Conversation.id)
        .where(*live_owned)
    )
    total = (await db.execute(query)).scalar()
    return total or 0