"""评测取证 repo:所有查询显式携带 user_id,避免跨租户串数据。""" from __future__ import annotations from sqlalchemy import or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from app.features.conversation.models import Conversation, ConversationMessage, Segment from app.features.memoir.models import Chapter, ChapterStoryLink from app.features.memory.models import ( MemoryChunk, MemoryFact, MemorySource, MemorySummary, TimelineEvent, ) from app.features.story.models import Story, StoryEvidenceLink def normalize_source_segment_ids(raw: object) -> list[str]: """chapter.source_segments:历史为 list[str]。""" if not raw: return [] if isinstance(raw, list): out: list[str] = [] for x in raw: s = str(x).strip() if s: out.append(s) # 保序去重 seen: set[str] = set() deduped: list[str] = [] for s in out: if s not in seen: seen.add(s) deduped.append(s) return deduped return [] async def get_chapter_for_eval_trace( db: AsyncSession, *, user_id: str, chapter_id: str ) -> Chapter | None: stmt = ( select(Chapter) .where(Chapter.id == chapter_id, Chapter.user_id == user_id) .options(joinedload(Chapter.current_evidence_snapshot)) ) result = await db.execute(stmt) return result.unique().scalar_one_or_none() async def get_story_for_eval_trace( db: AsyncSession, *, user_id: str, story_id: str ) -> Story | None: stmt = ( select(Story) .where(Story.id == story_id, Story.user_id == user_id) .options(joinedload(Story.evidence_links)) ) result = await db.execute(stmt) return result.unique().scalar_one_or_none() async def list_chapter_ids_for_story( db: AsyncSession, *, user_id: str, story_id: str ) -> list[str]: stmt = ( select(ChapterStoryLink.chapter_id) .join(Chapter, ChapterStoryLink.chapter_id == Chapter.id) .where(ChapterStoryLink.story_id == story_id, Chapter.user_id == user_id) ) result = await db.execute(stmt) return list(result.scalars().all()) async def fetch_segments_for_user( db: AsyncSession, *, user_id: str, segment_ids: list[str] ) -> list[Segment]: if not segment_ids: return [] stmt = ( select(Segment) .join(Conversation, Segment.conversation_id == Conversation.id) .where( Segment.id.in_(segment_ids), Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) .order_by(Segment.created_at) ) result = await db.execute(stmt) rows = list(result.scalars().all()) order = {sid: i for i, sid in enumerate(segment_ids)} return sorted(rows, key=lambda s: order.get(s.id, 9999)) async def fetch_turn_refs_for_segments( db: AsyncSession, *, user_id: str, segment_ids: list[str] ) -> dict[str, dict[str, str | None]]: """ segment_id -> { user_message_id, assistant_message_id }(按 message created_at 取首条配对)。 """ if not segment_ids: return {} human_stmt = ( select(ConversationMessage.segment_id, ConversationMessage.id) .join(Conversation, ConversationMessage.conversation_id == Conversation.id) .where( ConversationMessage.segment_id.in_(segment_ids), ConversationMessage.role == "human", Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) .order_by(ConversationMessage.created_at) ) ai_stmt = ( select(ConversationMessage.segment_id, ConversationMessage.id) .join(Conversation, ConversationMessage.conversation_id == Conversation.id) .where( ConversationMessage.segment_id.in_(segment_ids), ConversationMessage.role == "ai", Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) .order_by(ConversationMessage.created_at) ) h_result = await db.execute(human_stmt) u_map: dict[str, str] = {} for seg_id, mid in h_result.all(): if seg_id and str(seg_id) not in u_map: u_map[str(seg_id)] = str(mid) a_result = await db.execute(ai_stmt) a_map: dict[str, str] = {} for seg_id, mid in a_result.all(): if seg_id and str(seg_id) not in a_map: a_map[str(seg_id)] = str(mid) out: dict[str, dict[str, str | None]] = {} for sid in segment_ids: ss = str(sid) out[ss] = { "user_message_id": u_map.get(ss), "assistant_message_id": a_map.get(ss), } return out async def fetch_ai_messages_for_segments( db: AsyncSession, *, user_id: str, segment_ids: list[str] ) -> dict[str, str]: """segment_id -> AI 消息正文(优先 durable log)。""" if not segment_ids: return {} stmt = ( select(ConversationMessage.segment_id, ConversationMessage.content) .join(Conversation, ConversationMessage.conversation_id == Conversation.id) .where( ConversationMessage.segment_id.in_(segment_ids), ConversationMessage.role == "ai", Conversation.user_id == user_id, Conversation.deleted_at.is_(None), ) .order_by(ConversationMessage.created_at) ) result = await db.execute(stmt) out: dict[str, str] = {} for seg_id, content in result.all(): if seg_id and seg_id not in out: out[str(seg_id)] = str(content or "") return out async def fetch_memory_closure_for_conversations( db: AsyncSession, *, user_id: str, conversation_ids: list[str] ) -> tuple[list[str], list[str], list[str], list[str]]: """ 返回 (chunk_ids, fact_ids, timeline_event_ids, summary_ids),均限定 user_id。 路径:MemorySource(conversation_id) -> chunks;facts by source_chunk_id; timeline by memory_source_id;summaries 仅 rolling + 与会话 chunk 有交集的(轻量近似)。 """ if not conversation_ids: return [], [], [], [] conv_set = list({c for c in conversation_ids if c}) src_stmt = select(MemorySource).where( MemorySource.user_id == user_id, MemorySource.conversation_id.in_(conv_set), ) src_result = await db.execute(src_stmt) sources = list(src_result.scalars().all()) source_ids = [s.id for s in sources] if not source_ids: return [], [], [], [] ch_stmt = select(MemoryChunk).where( MemoryChunk.user_id == user_id, MemoryChunk.source_id.in_(source_ids), MemoryChunk.is_excluded.is_(False), ) ch_result = await db.execute(ch_stmt) chunks = list(ch_result.scalars().all()) chunk_ids = [c.id for c in chunks] if not chunk_ids: fact_rows: list[MemoryFact] = [] else: f_stmt = select(MemoryFact).where( MemoryFact.user_id == user_id, MemoryFact.source_chunk_id.in_(chunk_ids), or_(MemoryFact.status.is_(None), MemoryFact.status != "stale"), ) f_result = await db.execute(f_stmt) fact_rows = list(f_result.scalars().all()) fact_ids = [f.id for f in fact_rows] te_stmt = select(TimelineEvent).where( TimelineEvent.user_id == user_id, TimelineEvent.memory_source_id.in_(source_ids), ) te_result = await db.execute(te_stmt) ev_rows = list(te_result.scalars().all()) timeline_ids = [e.id for e in ev_rows] sum_stmt = ( select(MemorySummary) .where(MemorySummary.user_id == user_id) .order_by(MemorySummary.updated_at.desc()) .limit(12) ) sum_result = await db.execute(sum_stmt) summaries = list(sum_result.scalars().all()) chunk_set = set(chunk_ids) summary_ids: list[str] = [] for sm in summaries: if sm.summary_type == "rolling": summary_ids.append(sm.id) continue scids = sm.source_chunk_ids or [] if isinstance(scids, list) and chunk_set.intersection({str(x) for x in scids}): summary_ids.append(sm.id) return chunk_ids, fact_ids, timeline_ids, summary_ids async def load_chunks_by_ids( db: AsyncSession, *, user_id: str, chunk_ids: list[str] ) -> list[MemoryChunk]: if not chunk_ids: return [] stmt = select(MemoryChunk).where( MemoryChunk.user_id == user_id, MemoryChunk.id.in_(chunk_ids), MemoryChunk.is_excluded.is_(False), ) result = await db.execute(stmt) return list(result.scalars().all()) async def load_facts_by_ids( db: AsyncSession, *, user_id: str, fact_ids: list[str] ) -> list[MemoryFact]: if not fact_ids: return [] stmt = select(MemoryFact).where( MemoryFact.user_id == user_id, MemoryFact.id.in_(fact_ids), or_(MemoryFact.status.is_(None), MemoryFact.status != "stale"), ) result = await db.execute(stmt) return list(result.scalars().all()) async def load_timeline_by_ids( db: AsyncSession, *, user_id: str, event_ids: list[str] ) -> list[TimelineEvent]: if not event_ids: return [] stmt = select(TimelineEvent).where( TimelineEvent.user_id == user_id, TimelineEvent.id.in_(event_ids), ) result = await db.execute(stmt) return list(result.scalars().all()) async def load_summaries_by_ids( db: AsyncSession, *, user_id: str, summary_ids: list[str] ) -> list[MemorySummary]: if not summary_ids: return [] stmt = select(MemorySummary).where( MemorySummary.user_id == user_id, MemorySummary.id.in_(summary_ids), ) result = await db.execute(stmt) return list(result.scalars().all()) def story_link_ids_by_type(links: list[StoryEvidenceLink]) -> tuple[ list[str], list[str], list[str], list[str] ]: chunks: list[str] = [] facts: list[str] = [] timelines: list[str] = [] summaries: list[str] = [] for ln in links: et = (ln.evidence_type or "").strip() eid = (ln.evidence_id or "").strip() if not eid: continue if et == "chunk": chunks.append(eid) elif et == "fact": facts.append(eid) elif et == "timeline_event": timelines.append(eid) elif et == "summary": summaries.append(eid) return chunks, facts, timelines, summaries