Files
life-echo/api/app/features/evaluation/eval_trace_repo.py
yangshilin e1341c6d18 feat:
1. 建立问题库大纲,对应每个人生阶段槽位
2. 鼓励使用更生活化的交流语言共情与总结
3. 降低评审模型可能发生截断的概率
4. 成稿质量维度强化情感表达和上下文连贯性
2026-04-09 15:32:35 +08:00

320 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""评测取证 repo所有查询显式携带 user_id避免跨租户串数据。"""
from __future__ import annotations
from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.features.conversation.models import Conversation, ConversationMessage, Segment
from app.features.memoir.models import Chapter, ChapterStoryLink
from app.features.memory.models import (
MemoryChunk,
MemoryFact,
MemorySource,
MemorySummary,
TimelineEvent,
)
from app.features.story.models import Story, StoryEvidenceLink
def normalize_source_segment_ids(raw: object) -> list[str]:
    """Normalize ``chapter.source_segments`` into a clean list of segment ids.

    Historically the column stores ``list[str]``. Each element is converted
    with ``str()`` and whitespace-stripped. Empty strings and ``None`` entries
    are dropped, because ``str(None)`` would otherwise yield the bogus id
    ``"None"``. Duplicates are removed, keeping first-occurrence order.

    Args:
        raw: The raw column value. Anything other than a non-empty ``list``
            (``None``, ``""``, dicts, scalars, ...) yields ``[]``.

    Returns:
        Ordered, de-duplicated, non-empty segment id strings.
    """
    if not raw or not isinstance(raw, list):
        return []
    cleaned = (str(item).strip() for item in raw if item is not None)
    # dict.fromkeys preserves insertion order -> order-preserving dedupe.
    return list(dict.fromkeys(s for s in cleaned if s))
async def get_chapter_for_eval_trace(
    db: AsyncSession, *, user_id: str, chapter_id: str
) -> Chapter | None:
    """Load a single chapter owned by ``user_id``.

    ``current_evidence_snapshot`` is eager-loaded with a joined load. The
    result is passed through ``unique()`` before extraction, as SQLAlchemy
    requires when a joined eager load targets a collection.

    Returns:
        The chapter, or ``None`` if it does not exist or belongs to another
        user.
    """
    ownership = (Chapter.id == chapter_id, Chapter.user_id == user_id)
    eager_snapshot = joinedload(Chapter.current_evidence_snapshot)
    query = select(Chapter).options(eager_snapshot).where(*ownership)
    rows = await db.execute(query)
    return rows.unique().scalar_one_or_none()
async def get_story_for_eval_trace(
    db: AsyncSession, *, user_id: str, story_id: str
) -> Story | None:
    """Load a single story owned by ``user_id`` together with its evidence links.

    ``evidence_links`` is eager-loaded with a joined load. The joined rows are
    collapsed with ``unique()`` before the single story is extracted.

    Returns:
        The story, or ``None`` if it does not exist or belongs to another
        user.
    """
    ownership = (Story.id == story_id, Story.user_id == user_id)
    eager_links = joinedload(Story.evidence_links)
    query = select(Story).options(eager_links).where(*ownership)
    rows = await db.execute(query)
    return rows.unique().scalar_one_or_none()
async def list_chapter_ids_for_story(
    db: AsyncSession, *, user_id: str, story_id: str
) -> list[str]:
    """Return the ids of ``user_id``'s chapters that link to ``story_id``.

    The link table is filtered by story only. Tenant isolation comes from
    joining to ``Chapter`` and requiring ``Chapter.user_id == user_id``.
    """
    link_to_chapter = ChapterStoryLink.chapter_id == Chapter.id
    query = (
        select(ChapterStoryLink.chapter_id)
        .join(Chapter, link_to_chapter)
        .where(Chapter.user_id == user_id, ChapterStoryLink.story_id == story_id)
    )
    return list((await db.execute(query)).scalars().all())
async def fetch_segments_for_user(
    db: AsyncSession, *, user_id: str, segment_ids: list[str]
) -> list[Segment]:
    """Fetch the user's segments, ordered like ``segment_ids``.

    Only segments whose conversation belongs to ``user_id`` and is not
    soft-deleted (``deleted_at IS NULL``) are returned. Unknown or foreign ids
    are silently dropped.

    Args:
        db: Async session.
        user_id: Owning user, enforced via the join to ``Conversation``.
        segment_ids: Requested ids. Output follows the position of each id's
            first occurrence in this list.

    Returns:
        Matching segments in request order. Any row whose id cannot be matched
        back to the request sorts last, keeping its ``created_at`` order.
    """
    if not segment_ids:
        return []
    stmt = (
        select(Segment)
        .join(Conversation, Segment.conversation_id == Conversation.id)
        .where(
            Segment.id.in_(segment_ids),
            Conversation.user_id == user_id,
            Conversation.deleted_at.is_(None),
        )
        .order_by(Segment.created_at)
    )
    result = await db.execute(stmt)
    rows = list(result.scalars().all())
    # Key by str() on both sides so ordering still works if the ORM returns
    # non-str ids (e.g. UUID objects). A duplicated id keeps its FIRST position.
    position: dict[str, int] = {}
    for idx, sid in enumerate(segment_ids):
        position.setdefault(str(sid), idx)
    # Unmatched rows sort after every requested position. This replaces a
    # magic 9999 that broke for requests longer than 9999 ids. sorted() is
    # stable, so those rows keep the SQL created_at order.
    tail = len(segment_ids)
    return sorted(rows, key=lambda seg: position.get(str(seg.id), tail))
async def _first_message_id_by_segment(
    db: AsyncSession, *, user_id: str, segment_ids: list[str], role: str
) -> dict[str, str]:
    """Map ``str(segment_id)`` to ``str(id)`` of its earliest ``role`` message."""
    stmt = (
        select(ConversationMessage.segment_id, ConversationMessage.id)
        .join(Conversation, ConversationMessage.conversation_id == Conversation.id)
        .where(
            ConversationMessage.segment_id.in_(segment_ids),
            ConversationMessage.role == role,
            Conversation.user_id == user_id,
            Conversation.deleted_at.is_(None),
        )
        # id is a tie-breaker so equal timestamps still yield a stable "first".
        .order_by(ConversationMessage.created_at, ConversationMessage.id)
    )
    result = await db.execute(stmt)
    first: dict[str, str] = {}
    for seg_id, msg_id in result.all():
        if seg_id:
            # setdefault keeps the earliest row, since rows arrive sorted.
            first.setdefault(str(seg_id), str(msg_id))
    return first


async def fetch_turn_refs_for_segments(
    db: AsyncSession, *, user_id: str, segment_ids: list[str]
) -> dict[str, dict[str, str | None]]:
    """Pair each segment with its first human and first AI message ids.

    "First" means the earliest by message ``created_at``. Only messages in the
    user's non-deleted conversations are considered.

    Args:
        db: Async session.
        user_id: Owning user, enforced via the join to ``Conversation``.
        segment_ids: Segments to resolve.

    Returns:
        ``{str(segment_id): {"user_message_id": ..., "assistant_message_id": ...}}``
        with one entry per requested segment. A side with no matching message
        is ``None``.
    """
    if not segment_ids:
        return {}
    human_first = await _first_message_id_by_segment(
        db, user_id=user_id, segment_ids=segment_ids, role="human"
    )
    ai_first = await _first_message_id_by_segment(
        db, user_id=user_id, segment_ids=segment_ids, role="ai"
    )
    return {
        str(sid): {
            "user_message_id": human_first.get(str(sid)),
            "assistant_message_id": ai_first.get(str(sid)),
        }
        for sid in segment_ids
    }
async def fetch_ai_messages_for_segments(
    db: AsyncSession, *, user_id: str, segment_ids: list[str]
) -> dict[str, str]:
    """Map each segment id to the content of its earliest AI message.

    Reads ``ConversationMessage`` rows with ``role == "ai"`` from the user's
    non-deleted conversations. The original note says this source is
    preferred as the "durable log", presumably over a cached or transient
    copy (TODO: confirm with callers).

    Args:
        db: Async session.
        user_id: Owning user, enforced via the join to ``Conversation``.
        segment_ids: Segments to resolve.

    Returns:
        ``{str(segment_id): content}`` for segments with at least one AI
        message. ``None`` content becomes ``""``. Segments without AI
        messages are absent.
    """
    if not segment_ids:
        return {}
    stmt = (
        select(ConversationMessage.segment_id, ConversationMessage.content)
        .join(Conversation, ConversationMessage.conversation_id == Conversation.id)
        .where(
            ConversationMessage.segment_id.in_(segment_ids),
            ConversationMessage.role == "ai",
            Conversation.user_id == user_id,
            Conversation.deleted_at.is_(None),
        )
        # id is a tie-breaker so equal timestamps still yield a stable "first".
        .order_by(ConversationMessage.created_at, ConversationMessage.id)
    )
    result = await db.execute(stmt)
    out: dict[str, str] = {}
    for seg_id, content in result.all():
        if not seg_id:
            continue
        # Test membership with the same str() key that is stored. Checking the
        # raw seg_id let later messages overwrite the earliest one whenever
        # the driver returned non-str ids (e.g. UUID objects).
        out.setdefault(str(seg_id), str(content or ""))
    return out
def _summary_ids_touching_chunks(
    summaries: list[MemorySummary], chunk_ids: list[str]
) -> list[str]:
    """Keep rolling summaries plus those whose ``source_chunk_ids`` overlap ``chunk_ids``."""
    # Compare as strings on both sides; source_chunk_ids entries are str()ed.
    wanted = {str(cid) for cid in chunk_ids}
    picked: list[str] = []
    for sm in summaries:
        if sm.summary_type == "rolling":
            picked.append(sm.id)
            continue
        sources = sm.source_chunk_ids or []
        if isinstance(sources, list) and wanted.intersection(str(x) for x in sources):
            picked.append(sm.id)
    return picked


async def fetch_memory_closure_for_conversations(
    db: AsyncSession,
    *,
    user_id: str,
    conversation_ids: list[str],
    summary_limit: int = 12,
) -> tuple[list[str], list[str], list[str], list[str]]:
    """Collect memory ids derived from the given conversations.

    Every query is restricted to ``user_id``. The lookup path is:

    * ``MemorySource`` rows by ``conversation_id``.
    * ``MemoryChunk`` rows by ``source_id``, excluding ``is_excluded``.
    * ``MemoryFact`` rows by ``source_chunk_id``, where status is NULL or not
      ``"stale"``.
    * ``TimelineEvent`` rows by ``memory_source_id``.
    * Summaries (a lightweight approximation): among the user's
      ``summary_limit`` most recently updated summaries, keep every
      ``"rolling"`` one plus any whose ``source_chunk_ids`` overlap the chunks
      above. Rolling summaries are included even if unrelated to these
      conversations.

    Args:
        db: Async session.
        user_id: Owning user.
        conversation_ids: Conversations to trace. Falsy entries are ignored.
        summary_limit: How many recent summaries to consider (default 12).

    Returns:
        ``(chunk_ids, fact_ids, timeline_event_ids, summary_ids)``. All four
        are empty when no memory source matches.
    """
    conv_ids = list({c for c in conversation_ids if c})
    if not conv_ids:
        return [], [], [], []

    # Only ids are needed downstream, so select id columns instead of loading
    # full ORM rows (chunks in particular may carry large payloads).
    src_result = await db.execute(
        select(MemorySource.id).where(
            MemorySource.user_id == user_id,
            MemorySource.conversation_id.in_(conv_ids),
        )
    )
    source_ids = list(src_result.scalars().all())
    if not source_ids:
        return [], [], [], []

    ch_result = await db.execute(
        select(MemoryChunk.id).where(
            MemoryChunk.user_id == user_id,
            MemoryChunk.source_id.in_(source_ids),
            MemoryChunk.is_excluded.is_(False),
        )
    )
    chunk_ids = list(ch_result.scalars().all())

    fact_ids: list[str] = []
    if chunk_ids:
        f_result = await db.execute(
            select(MemoryFact.id).where(
                MemoryFact.user_id == user_id,
                MemoryFact.source_chunk_id.in_(chunk_ids),
                or_(MemoryFact.status.is_(None), MemoryFact.status != "stale"),
            )
        )
        fact_ids = list(f_result.scalars().all())

    te_result = await db.execute(
        select(TimelineEvent.id).where(
            TimelineEvent.user_id == user_id,
            TimelineEvent.memory_source_id.in_(source_ids),
        )
    )
    timeline_ids = list(te_result.scalars().all())

    sum_result = await db.execute(
        select(MemorySummary)
        .where(MemorySummary.user_id == user_id)
        .order_by(MemorySummary.updated_at.desc())
        .limit(summary_limit)
    )
    summary_ids = _summary_ids_touching_chunks(
        list(sum_result.scalars().all()), chunk_ids
    )
    return chunk_ids, fact_ids, timeline_ids, summary_ids
async def load_chunks_by_ids(
    db: AsyncSession, *, user_id: str, chunk_ids: list[str]
) -> list[MemoryChunk]:
    """Load the user's non-excluded memory chunks with the given ids.

    Returns ``[]`` without querying when ``chunk_ids`` is empty. Ids that are
    missing, excluded, or owned by another user are dropped. Row order is
    whatever the database returns.
    """
    if chunk_ids:
        filters = (
            MemoryChunk.user_id == user_id,
            MemoryChunk.id.in_(chunk_ids),
            MemoryChunk.is_excluded.is_(False),
        )
        found = await db.execute(select(MemoryChunk).where(*filters))
        return list(found.scalars().all())
    return []
async def load_facts_by_ids(
    db: AsyncSession, *, user_id: str, fact_ids: list[str]
) -> list[MemoryFact]:
    """Load the user's memory facts with the given ids, skipping stale ones.

    A fact qualifies when its status is NULL or anything other than
    ``"stale"``. An empty ``fact_ids`` short-circuits to ``[]``.
    """
    if len(fact_ids) == 0 if fact_ids is not None else True:
        return []
    not_stale = or_(MemoryFact.status.is_(None), MemoryFact.status != "stale")
    query = select(MemoryFact).where(
        MemoryFact.user_id == user_id,
        MemoryFact.id.in_(fact_ids),
        not_stale,
    )
    found = await db.execute(query)
    return list(found.scalars().all())
async def load_timeline_by_ids(
    db: AsyncSession, *, user_id: str, event_ids: list[str]
) -> list[TimelineEvent]:
    """Load the user's timeline events with the given ids.

    An empty ``event_ids`` short-circuits to ``[]`` without querying.
    """
    if not event_ids:
        return []
    owned = TimelineEvent.user_id == user_id
    requested = TimelineEvent.id.in_(event_ids)
    found = await db.execute(select(TimelineEvent).where(owned, requested))
    return list(found.scalars().all())
async def load_summaries_by_ids(
    db: AsyncSession, *, user_id: str, summary_ids: list[str]
) -> list[MemorySummary]:
    """Load the user's memory summaries with the given ids.

    An empty ``summary_ids`` short-circuits to ``[]`` without querying.
    """
    if not summary_ids:
        return []
    owned = MemorySummary.user_id == user_id
    requested = MemorySummary.id.in_(summary_ids)
    found = await db.execute(select(MemorySummary).where(owned, requested))
    return list(found.scalars().all())
def story_link_ids_by_type(
    links: list[StoryEvidenceLink],
) -> tuple[list[str], list[str], list[str], list[str]]:
    """Bucket story evidence-link ids by evidence type.

    Both ``evidence_type`` and ``evidence_id`` are whitespace-stripped, with
    ``None`` treated as ``""``. Links with an empty id or an unrecognized type
    are skipped. Ids keep their link order within each bucket.

    Returns:
        ``(chunk_ids, fact_ids, timeline_event_ids, summary_ids)``.
    """
    buckets: dict[str, list[str]] = {
        "chunk": [],
        "fact": [],
        "timeline_event": [],
        "summary": [],
    }
    for link in links:
        kind = (link.evidence_type or "").strip()
        evidence_id = (link.evidence_id or "").strip()
        target = buckets.get(kind)
        if evidence_id and target is not None:
            target.append(evidence_id)
    return (
        buckets["chunk"],
        buckets["fact"],
        buckets["timeline_event"],
        buckets["summary"],
    )