数据库 - 新增迁移 0003:timeline_events.memory_source_id 外键 → memory_sources,便于按 ingest 源做时间线幂等 后端 - 记忆 - 新增 ingest 后 LLM 富化(摘要/事实/时间线),可配置开关与最大字符数 - 新增证据包组装:合并 chunk、摘要、事实、时间线、故事等检索结果;支持空 query 时是否仍带 rolling 等开关 - repo/retriever/service/router/schemas/summarizer/timeline/extractor 等扩展;文档 memory-retrieval.md 更新 后端 - 对话 WS - 增加 PING/PONG;分段 ASR 日志与空音频处理;转写失败与「无助手回复」错误提示更明确 - 助手多段回复持久化使用统一分隔符,与分段逻辑一致 后端 - Agent - reply_limits:按 [SPLIT] 与段落拆段,并保证非空 fallback,供 WS 与 TTS 多段下发 后端 - 回忆录任务 - transcript ingest 记录 source_id;任务成功结?
223 lines
6.3 KiB
Python
223 lines
6.3 KiB
Python
"""Story repository — Story, StoryVersion, StoryEvidenceLink data access."""
|
||
|
||
import uuid
from datetime import datetime, timezone

from sqlalchemy import delete, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from app.features.story.models import (
    Story,
    StoryEvidenceLink,
    StoryImageIntent,
    StoryVersion,
)
|
||
|
||
|
||
def _new_id() -> str:
    """Return a fresh random (version 4) UUID in its canonical hyphenated string form."""
    generated = uuid.uuid4()
    return f"{generated}"
|
||
|
||
|
||
async def create_story(
    db: AsyncSession,
    *,
    user_id: str,
    title: str,
    stage: str | None = None,
    story_type: str | None = None,
    summary: str | None = None,
    canonical_markdown: str | None = None,
) -> Story:
    """Build a new ``Story`` with a generated UUID and stage it on ``db``.

    The object is only added to the session (``db.add``); this function does not
    flush or commit, so the caller must commit. A ``None``/empty
    ``canonical_markdown`` is stored as ``""``.
    """
    markdown = canonical_markdown if canonical_markdown else ""
    new_story = Story(
        id=_new_id(),
        user_id=user_id,
        title=title,
        stage=stage,
        story_type=story_type,
        summary=summary,
        canonical_markdown=markdown,
    )
    db.add(new_story)
    return new_story
|
||
|
||
|
||
async def create_story_version(
    db: AsyncSession,
    *,
    story_id: str,
    version_no: int,
    markdown_snapshot: str,
    actor_type: str = "ai",
    source_type: str = "generate",
    parent_version_id: str | None = None,
    prompt_meta: dict | None = None,
) -> StoryVersion:
    """Build a ``StoryVersion`` snapshot for ``story_id`` and stage it on ``db``.

    ``actor_type`` defaults to ``"ai"`` and ``source_type`` to ``"generate"``.
    The object is only added to the session; the caller must commit.
    """
    columns = {
        "story_id": story_id,
        "version_no": version_no,
        "markdown_snapshot": markdown_snapshot,
        "actor_type": actor_type,
        "source_type": source_type,
        "parent_version_id": parent_version_id,
        "prompt_meta": prompt_meta,
    }
    snapshot = StoryVersion(id=_new_id(), **columns)
    db.add(snapshot)
    return snapshot
|
||
|
||
|
||
async def create_story_evidence_link(
    db: AsyncSession,
    *,
    story_id: str,
    evidence_type: str,
    evidence_id: str,
    role: str = "primary",
    weight: float | None = None,
) -> StoryEvidenceLink:
    """Link ``story_id`` to a piece of evidence and stage the link on ``db``.

    ``evidence_type``/``evidence_id`` identify the evidence record; ``role``
    defaults to ``"primary"`` and ``weight`` is optional. Caller must commit.
    """
    columns = {
        "story_id": story_id,
        "evidence_type": evidence_type,
        "evidence_id": evidence_id,
        "role": role,
        "weight": weight,
    }
    evidence_link = StoryEvidenceLink(id=_new_id(), **columns)
    db.add(evidence_link)
    return evidence_link
|
||
|
||
|
||
async def get_story_by_id(db: AsyncSession, story_id: str) -> Story | None:
    """Look up one story by primary key; returns ``None`` when no row exists."""
    found = await db.get(Story, story_id)
    return found
|
||
|
||
|
||
async def get_stories_for_user(
    db: AsyncSession, user_id: str, *, status: str | None = "active"
) -> list[Story]:
    """List a user's stories, newest first (by ``created_at``).

    ``status`` filters on ``Story.status`` (default ``"active"``); pass ``None``
    or an empty string to return stories of every status.
    """
    filters = [Story.user_id == user_id]
    if status:
        filters.append(Story.status == status)
    stmt = select(Story).where(*filters).order_by(Story.created_at.desc())
    rows = await db.execute(stmt)
    return list(rows.unique().scalars().all())
|
||
|
||
|
||
async def count_story_versions(db: AsyncSession, story_id: str) -> int:
    """Return the number of ``StoryVersion`` rows belonging to ``story_id`` (0 if none)."""
    stmt = select(func.count(StoryVersion.id)).where(StoryVersion.story_id == story_id)
    result = await db.execute(stmt)
    # An aggregate COUNT without GROUP BY always yields exactly one row,
    # so scalar_one() is the precise accessor here.
    return int(result.scalar_one())
|
||
|
||
|
||
async def create_story_image_intent(
    db: AsyncSession,
    *,
    story_id: str,
    story_version_id: str | None,
    caption: str,
    prompt_brief: str,
    style_profile: str | None = None,
) -> StoryImageIntent:
    """Stage a new image intent for a story.

    The intent is always created with ``intent_role="primary"`` and
    ``status="pending"``. It is only added to the session; caller must commit.
    """
    columns = {
        "story_id": story_id,
        "story_version_id": story_version_id,
        "intent_role": "primary",
        "caption": caption,
        "prompt_brief": prompt_brief,
        "style_profile": style_profile,
        "status": "pending",
    }
    image_intent = StoryImageIntent(id=_new_id(), **columns)
    db.add(image_intent)
    return image_intent
|
||
|
||
|
||
async def get_story_image_intent_by_story(
    db: AsyncSession, story_id: str, *, role: str = "primary"
) -> StoryImageIntent | None:
    """Return the story's image intent with the given ``role`` (default ``"primary"``).

    Returns ``None`` when none exists. ``scalar_one_or_none`` raises if more than
    one row matches — presumably (story_id, intent_role) is unique; confirm
    against the model/migrations.
    """
    stmt = select(StoryImageIntent).where(
        StoryImageIntent.story_id == story_id,
        StoryImageIntent.intent_role == role,
    )
    rows = await db.execute(stmt)
    return rows.unique().scalar_one_or_none()
|
||
|
||
|
||
async def delete_story_image_intents_by_story(
    db: AsyncSession,
    story_id: str,
    *,
    role: str = "primary",
    statuses: list[str] | None = None,
) -> int:
    """Delete a story's image intents for ``role`` and return the deleted row count.

    When ``statuses`` is ``None`` every intent with that role is removed;
    otherwise only intents whose status is listed are deleted (e.g. clear only
    pending/failed so that ones still processing are not interrupted).
    Does not commit.
    """
    conditions = [
        StoryImageIntent.story_id == story_id,
        StoryImageIntent.intent_role == role,
    ]
    if statuses is not None:
        conditions.append(StoryImageIntent.status.in_(statuses))
    outcome = await db.execute(delete(StoryImageIntent).where(*conditions))
    # rowcount may be None/0 depending on the driver; normalise to an int.
    return outcome.rowcount or 0
|
||
|
||
|
||
async def get_stories_by_ids(db: AsyncSession, story_ids: list[str]) -> list[Story]:
    """Fetch the stories whose IDs appear in ``story_ids``, in the caller's order.

    IDs without a matching row are skipped. If an ID is listed more than once it
    is returned once, positioned at its first occurrence.
    """
    if not story_ids:
        return []
    # First index of each ID; setdefault keeps the earliest position
    # (a dict comprehension would silently keep the last one).
    position: dict[str, int] = {}
    for index, sid in enumerate(story_ids):
        position.setdefault(sid, index)
    stmt = select(Story).where(Story.id.in_(list(position)))
    result = await db.execute(stmt)
    stories = list(result.unique().scalars().all())
    # Every returned id was requested, so the fallback is never hit in practice;
    # len(story_ids) still sorts any unexpected row after all requested ones.
    fallback = len(story_ids)
    return sorted(stories, key=lambda s: position.get(s.id, fallback))
|
||
|
||
|
||
# Escape character for LIKE/ILIKE patterns built from user text ("/" mirrors
# SQLAlchemy's own autoescape default and avoids backslash-quoting pitfalls).
_LIKE_ESCAPE_CHAR = "/"


def _escape_like(term: str) -> str:
    """Escape LIKE wildcards (``%``, ``_``) and the escape char so ``term`` matches literally."""
    return (
        term.replace(_LIKE_ESCAPE_CHAR, _LIKE_ESCAPE_CHAR * 2)
        .replace("%", _LIKE_ESCAPE_CHAR + "%")
        .replace("_", _LIKE_ESCAPE_CHAR + "_")
    )


def _recent_stories_for_evidence_stmt(user_id: str, query: str | None, limit: int):
    """Build the SELECT shared by the async and sync evidence lookups.

    Active stories of ``user_id``, optionally narrowed to those whose title or
    summary contains ``query`` (case-insensitive, literal substring), most
    recently updated first, capped at ``limit`` rows.
    """
    stmt = select(Story).where(Story.user_id == user_id, Story.status == "active")
    q = (query or "").strip()
    if q:
        pat = f"%{_escape_like(q)}%"
        stmt = stmt.where(
            or_(
                Story.title.ilike(pat, escape=_LIKE_ESCAPE_CHAR),
                Story.summary.ilike(pat, escape=_LIKE_ESCAPE_CHAR),
            )
        )
    return stmt.order_by(Story.updated_at.desc()).limit(limit)


async def list_recent_stories_for_evidence(
    db: AsyncSession,
    user_id: str,
    *,
    query: str | None = None,
    limit: int = 5,
) -> list[Story]:
    """For memory retrieval: active stories, optionally fuzzy-matched on title/summary.

    A blank/whitespace ``query`` disables the text filter. Wildcard characters
    in ``query`` are matched literally.
    """
    stmt = _recent_stories_for_evidence_stmt(user_id, query, limit)
    result = await db.execute(stmt)
    return list(result.unique().scalars().all())


def list_recent_stories_for_evidence_sync(
    session: Session,
    user_id: str,
    *,
    query: str | None = None,
    limit: int = 5,
) -> list[Story]:
    """Synchronous-session variant of `list_recent_stories_for_evidence` (Celery / retrieve_evidence_sync)."""
    stmt = _recent_stories_for_evidence_stmt(user_id, query, limit)
    result = session.execute(stmt)
    return list(result.unique().scalars().all())
|