"""评测取证 repo:所有查询显式携带 user_id,避免跨租户串数据。""" from __future__ import annotations from sqlalchemy import or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from app.features.conversation.models import Conversation, ConversationMessage, Segment from app.features.memoir.models import Chapter, ChapterStoryLink from app.features.memory.models import ( MemoryChunk, MemoryFact, MemorySummary, TimelineEvent, ) from app.features.story.models import Story, StoryEvidenceLink def normalize_source_segment_ids(raw: object) -> list[str]: """chapter.source_segments:历史为 list[str]。""" if not raw: return [] if isinstance(raw, list): out: list[str] = [] for x in raw: s = str(x).strip() if s: out.append(s) # 保序去重 seen: set[str] = set() deduped: list[str] = [] for s in out: if s not in seen: seen.add(s) deduped.append(s) return deduped return [] async def get_chapter_for_eval_trace( db: AsyncSession, *, user_id: str, chapter_id: str ) -> Chapter | None: stmt = ( select(Chapter) .where(Chapter.id == chapter_id, Chapter.user_id == user_id) .options(joinedload(Chapter.current_evidence_snapshot)) ) result = await db.execute(stmt) return result.unique().scalar_one_or_none() async def get_story_for_eval_trace( db: AsyncSession, *, user_id: str, story_id: str ) -> Story | None: stmt = ( select(Story) .where(Story.id == story_id, Story.user_id == user_id) .options(joinedload(Story.evidence_links)) ) result = await db.execute(stmt) return result.unique().scalar_one_or_none() async def list_chapter_ids_for_story( db: AsyncSession, *, user_id: str, story_id: str ) -> list[str]: stmt = ( select(ChapterStoryLink.chapter_id) .join(Chapter, ChapterStoryLink.chapter_id == Chapter.id) .where(ChapterStoryLink.story_id == story_id, Chapter.user_id == user_id) ) result = await db.execute(stmt) return list(result.scalars().all()) async def fetch_segments_for_user( db: AsyncSession, *, user_id: str, segment_ids: list[str] ) -> list[Segment]: if not segment_ids: return [] stmt = ( select(Segment) .join(Conversation, Segment.conversation_id == Conversation.id) .where( Segment.id.in_(segment_ids), Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) .order_by(Segment.created_at) ) result = await db.execute(stmt) rows = list(result.scalars().all()) order = {sid: i for i, sid in enumerate(segment_ids)} return sorted(rows, key=lambda s: order.get(s.id, 9999)) async def fetch_turn_refs_for_segments( db: AsyncSession, *, user_id: str, segment_ids: list[str] ) -> dict[str, dict[str, str | None]]: """ segment_id -> { user_message_id, assistant_message_id }(按 message created_at 取首条配对)。 """ if not segment_ids: return {} human_stmt = ( select(ConversationMessage.segment_id, ConversationMessage.id) .join(Conversation, ConversationMessage.conversation_id == Conversation.id) .where( ConversationMessage.segment_id.in_(segment_ids), ConversationMessage.role == "human", Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) .order_by(ConversationMessage.created_at) ) ai_stmt = ( select(ConversationMessage.segment_id, ConversationMessage.id) .join(Conversation, ConversationMessage.conversation_id == Conversation.id) .where( ConversationMessage.segment_id.in_(segment_ids), ConversationMessage.role == "ai", Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) .order_by(ConversationMessage.created_at) ) h_result = await db.execute(human_stmt) u_map: dict[str, str] = {} for seg_id, mid in h_result.all(): if seg_id and str(seg_id) not in u_map: u_map[str(seg_id)] = str(mid) a_result = await db.execute(ai_stmt) a_map: dict[str, str] = {} for seg_id, mid in a_result.all(): if seg_id and str(seg_id) not in a_map: a_map[str(seg_id)] = str(mid) out: dict[str, dict[str, str | None]] = {} for sid in segment_ids: ss = str(sid) out[ss] = { "user_message_id": u_map.get(ss), "assistant_message_id": a_map.get(ss), } return out async def fetch_ai_messages_for_segments( db: AsyncSession, *, user_id: str, segment_ids: list[str] ) -> dict[str, str]: """segment_id -> AI 消息正文(优先 durable log)。""" if not segment_ids: return {} stmt = ( select(ConversationMessage.segment_id, ConversationMessage.content) .join(Conversation, ConversationMessage.conversation_id == Conversation.id) .where( ConversationMessage.segment_id.in_(segment_ids), ConversationMessage.role == "ai", Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) .order_by(ConversationMessage.created_at) ) result = await db.execute(stmt) out: dict[str, str] = {} for seg_id, content in result.all(): if seg_id and seg_id not in out: out[str(seg_id)] = str(content or "") return out async def load_chunks_by_ids( db: AsyncSession, *, user_id: str, chunk_ids: list[str] ) -> list[MemoryChunk]: if not chunk_ids: return [] stmt = select(MemoryChunk).where( MemoryChunk.user_id == user_id, MemoryChunk.id.in_(chunk_ids), MemoryChunk.is_excluded.is_(False), ) result = await db.execute(stmt) return list(result.scalars().all()) async def load_facts_by_ids( db: AsyncSession, *, user_id: str, fact_ids: list[str] ) -> list[MemoryFact]: if not fact_ids: return [] stmt = select(MemoryFact).where( MemoryFact.user_id == user_id, MemoryFact.id.in_(fact_ids), or_(MemoryFact.status.is_(None), MemoryFact.status != "stale"), ) result = await db.execute(stmt) return list(result.scalars().all()) async def load_timeline_by_ids( db: AsyncSession, *, user_id: str, event_ids: list[str] ) -> list[TimelineEvent]: if not event_ids: return [] stmt = select(TimelineEvent).where( TimelineEvent.user_id == user_id, TimelineEvent.id.in_(event_ids), ) result = await db.execute(stmt) return list(result.scalars().all()) async def load_summaries_by_ids( db: AsyncSession, *, user_id: str, summary_ids: list[str] ) -> list[MemorySummary]: if not summary_ids: return [] stmt = select(MemorySummary).where( MemorySummary.user_id == user_id, MemorySummary.id.in_(summary_ids), ) result = await db.execute(stmt) return list(result.scalars().all()) def story_link_ids_by_type( links: list[StoryEvidenceLink], ) -> tuple[list[str], list[str], list[str], list[str]]: chunks: list[str] = [] facts: list[str] = [] timelines: list[str] = [] summaries: list[str] = [] for ln in links: et = (ln.evidence_type or "").strip() eid = (ln.evidence_id or "").strip() if not eid: continue if et == "chunk": chunks.append(eid) elif et == "fact": facts.append(eid) elif et == "timeline_event": timelines.append(eid) elif et == "summary": summaries.append(eid) return chunks, facts, timelines, summaries