Route all memory ingest/retrieve/enrichment/compaction through async MemoryService. Remove legacy sync memory implementations (ingest/retrieve/compaction); Celery and memoir Phase2 call asyncio.run into MemoryService-backed helpers. Memoir Phase1 batch ingest uses MemoryService.ingest_transcripts_batch; drop chapters. evidence_bundle_json mirror (Alembic 0015). Evaluation uses snapshot/link-only bundles; raise EvidenceClosureMissing instead of partial/fallback lineage tiers. Split memoir state into NarrativeCoverageState and InterviewControlState; delete the _interview_meta_store adapter layer. Remove rolling-query and recent-fact fallback settings from config and evidence assembly. Update judges, docs, tests, and PlaygroundPage alignment. Made-with: Cursor
248 lines
7.7 KiB
Python
248 lines
7.7 KiB
Python
"""评测取证 repo:所有查询显式携带 user_id,避免跨租户串数据。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from sqlalchemy import or_, select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import joinedload
|
||
|
||
from app.features.conversation.models import Conversation, ConversationMessage, Segment
|
||
from app.features.memoir.models import Chapter, ChapterStoryLink
|
||
from app.features.memory.models import (
|
||
MemoryChunk,
|
||
MemoryFact,
|
||
MemorySummary,
|
||
TimelineEvent,
|
||
)
|
||
from app.features.story.models import Story, StoryEvidenceLink
|
||
|
||
|
||
def normalize_source_segment_ids(raw: object) -> list[str]:
    """Normalize a stored ``chapter.source_segments`` value into segment IDs.

    Historically the column holds ``list[str]``. Each element is stringified
    and stripped; blank entries and ``None`` entries are dropped, and
    duplicates are removed while keeping first-occurrence order.

    Args:
        raw: The raw stored value. Anything falsy or not a ``list`` yields ``[]``.

    Returns:
        Ordered, de-duplicated, non-empty segment ID strings.
    """
    if not raw:
        return []
    if not isinstance(raw, list):
        return []
    # Skip None explicitly: str(None) would otherwise leak the fake id "None".
    cleaned = (str(x).strip() for x in raw if x is not None)
    # dict.fromkeys preserves insertion order, giving an order-preserving dedupe.
    return list(dict.fromkeys(s for s in cleaned if s))
|
||
|
||
|
||
async def get_chapter_for_eval_trace(
    db: AsyncSession, *, user_id: str, chapter_id: str
) -> Chapter | None:
    """Load a single chapter owned by ``user_id`` for evaluation tracing.

    The chapter's ``current_evidence_snapshot`` relationship is eager-loaded
    with a joined load. Returns ``None`` when no chapter with ``chapter_id``
    belongs to this user.
    """
    ownership_filter = (Chapter.id == chapter_id, Chapter.user_id == user_id)
    snapshot_loader = joinedload(Chapter.current_evidence_snapshot)
    query = select(Chapter).where(*ownership_filter).options(snapshot_loader)
    rows = await db.execute(query)
    return rows.unique().scalar_one_or_none()
|
||
|
||
|
||
async def get_story_for_eval_trace(
    db: AsyncSession, *, user_id: str, story_id: str
) -> Story | None:
    """Load a single story owned by ``user_id`` with its evidence links eager-loaded.

    Returns ``None`` if ``story_id`` does not exist or belongs to another user.
    """
    links_loader = joinedload(Story.evidence_links)
    query = (
        select(Story)
        .options(links_loader)
        .where(Story.user_id == user_id, Story.id == story_id)
    )
    rows = await db.execute(query)
    # unique() collapses parent rows duplicated by the joined eager load;
    # SQLAlchemy requires it when joined-loading a collection (evidence_links
    # is presumably a collection).
    return rows.unique().scalar_one_or_none()
|
||
|
||
|
||
async def list_chapter_ids_for_story(
    db: AsyncSession, *, user_id: str, story_id: str
) -> list[str]:
    """Return IDs of ``user_id``'s chapters that are linked to ``story_id``.

    Ownership is enforced by joining the link table to ``Chapter`` and
    filtering on ``Chapter.user_id``. No ORDER BY is applied.
    """
    query = (
        select(ChapterStoryLink.chapter_id)
        .join(Chapter, Chapter.id == ChapterStoryLink.chapter_id)
        .where(Chapter.user_id == user_id, ChapterStoryLink.story_id == story_id)
    )
    rows = await db.execute(query)
    chapter_ids = rows.scalars().all()
    return list(chapter_ids)
|
||
|
||
|
||
async def fetch_segments_for_user(
    db: AsyncSession, *, user_id: str, segment_ids: list[str]
) -> list[Segment]:
    """Load requested segments that belong to ``user_id``'s live conversations.

    Segments whose conversation belongs to another user, or is soft-deleted
    (``deleted_at`` set), are silently omitted.

    Args:
        db: Async session used for the query.
        user_id: Owner whose conversations may be read.
        segment_ids: Segment IDs to load; an empty list short-circuits to ``[]``.

    Returns:
        Segments in the order of ``segment_ids``. For duplicated IDs the first
        occurrence decides the position. Rows whose ID cannot be matched back
        to the request keep their ``created_at`` order after all matched rows.
    """
    if not segment_ids:
        return []
    stmt = (
        select(Segment)
        .join(Conversation, Segment.conversation_id == Conversation.id)
        .where(
            Segment.id.in_(segment_ids),
            Conversation.user_id == user_id,
            Conversation.deleted_at.is_(None),
        )
        # created_at order is the tiebreak for unmatched rows (sorted() is stable).
        .order_by(Segment.created_at)
    )
    result = await db.execute(stmt)
    rows = list(result.scalars().all())

    # Key by str on both sides so str inputs still match non-str primary keys;
    # setdefault keeps the FIRST index of a duplicated id.
    position: dict[str, int] = {}
    for idx, sid in enumerate(segment_ids):
        position.setdefault(str(sid), idx)
    # Sentinel past every real index, no matter how many ids were requested.
    unmatched = len(segment_ids)
    return sorted(rows, key=lambda s: position.get(str(s.id), unmatched))
|
||
|
||
|
||
async def _first_message_id_per_segment(
    db: AsyncSession, *, user_id: str, segment_ids: list[str], role: str
) -> dict[str, str]:
    """Map ``str(segment_id)`` to ``str(message_id)`` of the earliest ``role`` message."""
    stmt = (
        select(ConversationMessage.segment_id, ConversationMessage.id)
        .join(Conversation, ConversationMessage.conversation_id == Conversation.id)
        .where(
            ConversationMessage.segment_id.in_(segment_ids),
            ConversationMessage.role == role,
            Conversation.user_id == user_id,
            Conversation.deleted_at.is_(None),
        )
        # id as a secondary key makes "earliest" deterministic when several
        # messages share the same created_at.
        .order_by(ConversationMessage.created_at, ConversationMessage.id)
    )
    result = await db.execute(stmt)
    first: dict[str, str] = {}
    for seg_id, mid in result.all():
        if seg_id:
            # setdefault keeps the first (earliest) row seen for each segment.
            first.setdefault(str(seg_id), str(mid))
    return first


async def fetch_turn_refs_for_segments(
    db: AsyncSession, *, user_id: str, segment_ids: list[str]
) -> dict[str, dict[str, str | None]]:
    """Map each requested segment to its earliest human and earliest AI message IDs.

    Only messages in ``user_id``'s non-deleted conversations are considered.
    The human and AI messages are chosen independently (each is simply the
    earliest of its role by ``created_at``); they are not matched as a turn pair.

    Args:
        db: Async session used for the queries.
        user_id: Owner whose conversations may be read.
        segment_ids: Segment IDs to resolve; an empty list short-circuits to ``{}``.

    Returns:
        ``{str(segment_id): {"user_message_id": ..., "assistant_message_id": ...}}``
        with one entry per requested ID; a value is ``None`` when the segment
        has no visible message of that role.
    """
    if not segment_ids:
        return {}
    human_ids = await _first_message_id_per_segment(
        db, user_id=user_id, segment_ids=segment_ids, role="human"
    )
    ai_ids = await _first_message_id_per_segment(
        db, user_id=user_id, segment_ids=segment_ids, role="ai"
    )
    refs: dict[str, dict[str, str | None]] = {}
    for sid in segment_ids:
        key = str(sid)
        refs[key] = {
            "user_message_id": human_ids.get(key),
            "assistant_message_id": ai_ids.get(key),
        }
    return refs
|
||
|
||
|
||
async def fetch_ai_messages_for_segments(
    db: AsyncSession, *, user_id: str, segment_ids: list[str]
) -> dict[str, str]:
    """Map each segment to the text of its earliest AI message.

    Only messages in ``user_id``'s non-deleted conversations are considered.
    Text is read from ``ConversationMessage.content``.

    NOTE(review): an earlier note said this "prefers the durable log"; the
    query reads only ``ConversationMessage`` — confirm that is the durable store.

    Args:
        db: Async session used for the query.
        user_id: Owner whose conversations may be read.
        segment_ids: Segment IDs to resolve; an empty list short-circuits to ``{}``.

    Returns:
        ``{str(segment_id): text}``. Segments without a visible AI message are
        absent, and a ``None`` content is returned as ``""``.
    """
    if not segment_ids:
        return {}
    stmt = (
        select(ConversationMessage.segment_id, ConversationMessage.content)
        .join(Conversation, ConversationMessage.conversation_id == Conversation.id)
        .where(
            ConversationMessage.segment_id.in_(segment_ids),
            ConversationMessage.role == "ai",
            Conversation.user_id == user_id,
            Conversation.deleted_at.is_(None),
        )
        # id as a secondary key makes "earliest" deterministic on created_at ties.
        .order_by(ConversationMessage.created_at, ConversationMessage.id)
    )
    result = await db.execute(stmt)
    out: dict[str, str] = {}
    for seg_id, content in result.all():
        if not seg_id:
            continue
        key = str(seg_id)
        # Check membership with the same stringified key that is stored; testing
        # the raw value let later messages overwrite the earliest one whenever
        # segment_id was not already a str.
        if key not in out:
            out[key] = str(content or "")
    return out
|
||
|
||
|
||
async def load_chunks_by_ids(
    db: AsyncSession, *, user_id: str, chunk_ids: list[str]
) -> list[MemoryChunk]:
    """Fetch ``user_id``'s memory chunks whose IDs appear in ``chunk_ids``.

    Only rows whose ``is_excluded`` is exactly false are returned; SQL
    ``IS false`` does not match NULL, so NULL rows are dropped too. IDs owned
    by other users are silently skipped. No ORDER BY is applied.
    """
    if not chunk_ids:
        return []
    conditions = (
        MemoryChunk.user_id == user_id,
        MemoryChunk.id.in_(chunk_ids),
        MemoryChunk.is_excluded.is_(False),
    )
    rows = await db.execute(select(MemoryChunk).where(*conditions))
    chunks = rows.scalars().all()
    return list(chunks)
|
||
|
||
|
||
async def load_facts_by_ids(
    db: AsyncSession, *, user_id: str, fact_ids: list[str]
) -> list[MemoryFact]:
    """Fetch ``user_id``'s memory facts in ``fact_ids`` that are not marked stale.

    A fact is kept when its ``status`` is NULL or anything other than
    ``"stale"``. IDs owned by other users are silently skipped. No ORDER BY
    is applied.
    """
    if not fact_ids:
        return []
    # NULL status must be allowed explicitly: in SQL, `status != 'stale'`
    # evaluates to NULL (not true) for NULL rows and would filter them out.
    not_stale = or_(MemoryFact.status.is_(None), MemoryFact.status != "stale")
    query = select(MemoryFact).where(
        MemoryFact.user_id == user_id,
        MemoryFact.id.in_(fact_ids),
        not_stale,
    )
    rows = await db.execute(query)
    facts = rows.scalars().all()
    return list(facts)
|
||
|
||
|
||
async def load_timeline_by_ids(
    db: AsyncSession, *, user_id: str, event_ids: list[str]
) -> list[TimelineEvent]:
    """Fetch the timeline events in ``event_ids`` that belong to ``user_id``.

    IDs owned by other users are silently skipped. No ORDER BY is applied.
    """
    if not event_ids:
        return []
    owned_by_user = TimelineEvent.user_id == user_id
    is_requested = TimelineEvent.id.in_(event_ids)
    rows = await db.execute(select(TimelineEvent).where(owned_by_user, is_requested))
    events = rows.scalars().all()
    return list(events)
|
||
|
||
|
||
async def load_summaries_by_ids(
    db: AsyncSession, *, user_id: str, summary_ids: list[str]
) -> list[MemorySummary]:
    """Fetch the memory summaries in ``summary_ids`` that belong to ``user_id``.

    IDs owned by other users are silently skipped. No ORDER BY is applied.
    """
    if not summary_ids:
        return []
    owned_by_user = MemorySummary.user_id == user_id
    is_requested = MemorySummary.id.in_(summary_ids)
    rows = await db.execute(select(MemorySummary).where(owned_by_user, is_requested))
    summaries = rows.scalars().all()
    return list(summaries)
|
||
|
||
|
||
def story_link_ids_by_type(
    links: list[StoryEvidenceLink],
) -> tuple[list[str], list[str], list[str], list[str]]:
    """Split story evidence links into per-type ID lists.

    Both ``evidence_type`` and ``evidence_id`` are stripped (``None`` is
    treated as ``""``). Links with a blank ID or an unrecognized type are
    skipped.

    Returns:
        ``(chunk_ids, fact_ids, timeline_event_ids, summary_ids)``, each in the
        order the links were given (duplicates are kept).
    """
    buckets: dict[str, list[str]] = {
        "chunk": [],
        "fact": [],
        "timeline_event": [],
        "summary": [],
    }
    for link in links:
        kind = (link.evidence_type or "").strip()
        evidence_id = (link.evidence_id or "").strip()
        if not evidence_id:
            continue
        target = buckets.get(kind)
        if target is not None:
            target.append(evidence_id)
    return (
        buckets["chunk"],
        buckets["fact"],
        buckets["timeline_event"],
        buckets["summary"],
    )
|