Chat 访谈 - 新增 persona 系统(default / warm_listener / curious_guide)与 background_voice 语气层 - 回复长度由 compute_reply_plan 统一决策(brief / standard / expanded),融合信息密度启发式 - 输入净稿(input_normalize):编排层可选 rules/llm 归一用户口语后再喂模型与记忆检索 - 记忆证据注入:按用户话检索 memory evidence 并注入 prompt Memoir 回忆录 - 口述归一(oral_normalize):segment 原文保留,story 管线取派生净稿作叙事输入 - segment 入队批次门闸:累计字数 + 最长等待秒数,减少零碎提交 - fidelity_check / prompts / narrative_agent 微调 - Alembic 0005:清理跨章节 story 外键 Infra - Dockerfile 加入 ffmpeg - pyproject.toml 新增依赖并同步 uv.lock - .env.example / .env.production 补全新配置项 Tests - 新增 test_background_voice、test_chat_input_normalize、test_experience_regressions - 扩展 test_interview_prompts、test_interview_reply_length、test_story_route_oral_invariant Made-with: Cursor
936 lines
26 KiB
Python
936 lines
26 KiB
Python
"""Memory repository — MemorySource, MemoryChunk, MemoryFact, TimelineEvent data access."""
|
||
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
|
||
from sqlalchemy import delete, literal, or_, select, text, tuple_
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.features.memory.models import (
|
||
MemoryChunk,
|
||
MemoryCurationAction,
|
||
MemoryFact,
|
||
MemorySource,
|
||
MemorySummary,
|
||
TimelineEvent,
|
||
)
|
||
|
||
|
||
def _new_id() -> str:
    """Return a freshly generated random UUID (version 4) as a string."""
    fresh = uuid.uuid4()
    return str(fresh)
|
||
|
||
|
||
def create_source_sync(
    session: Session,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    captured_at: datetime | None = None,
) -> MemorySource:
    """Build a MemorySource row and stage it on the sync session.

    A falsy ``captured_at`` is replaced by the current UTC time. The row is
    only added to the session; the caller must commit.
    """
    record = MemorySource(
        id=_new_id(),
        user_id=user_id,
        source_type=source_type,
        raw_text=raw_text,
        conversation_id=conversation_id,
        captured_at=captured_at if captured_at else datetime.now(timezone.utc),
    )
    session.add(record)
    return record
|
||
|
||
|
||
async def create_source(
    db: AsyncSession,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    captured_at: datetime | None = None,
) -> MemorySource:
    """Build a MemorySource row and stage it on the async session.

    A falsy ``captured_at`` is replaced by the current UTC time. Nothing is
    flushed or committed here; the caller must commit.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "source_type": source_type,
        "raw_text": raw_text,
        "conversation_id": conversation_id,
        "captured_at": captured_at or datetime.now(timezone.utc),
    }
    record = MemorySource(**fields)
    db.add(record)
    return record
|
||
|
||
|
||
def create_chunk_sync(
    session: Session,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Build a MemoryChunk row and stage it on the sync session.

    The row gets a new id and is only added to the session; the caller
    must commit.
    """
    piece = MemoryChunk(
        id=_new_id(),
        source_id=source_id,
        user_id=user_id,
        content=content,
        chunk_index=chunk_index,
    )
    session.add(piece)
    return piece
|
||
|
||
|
||
async def create_chunk(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Build a MemoryChunk row and stage it on the async session.

    The row gets a new id and is only added to the session; the caller
    must commit.
    """
    fields = {
        "id": _new_id(),
        "source_id": source_id,
        "user_id": user_id,
        "content": content,
        "chunk_index": chunk_index,
    }
    piece = MemoryChunk(**fields)
    db.add(piece)
    return piece
|
||
|
||
|
||
def update_chunk_fts_sync(session: Session, chunk_id: str) -> None:
    """Recompute ``content_tsv`` for one chunk (sync).

    Runs an UPDATE that sets ``content_tsv`` to the 'simple' tsvector of the
    chunk's ``content``. The caller must commit.
    """
    refresh_sql = text(
        "UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
    )
    bind_params = {"id": chunk_id}
    session.execute(refresh_sql, bind_params)
|
||
|
||
|
||
def update_chunk_embedding_sync(
    session: Session, chunk_id: str, embedding: list[float]
) -> None:
    """Assign ``embedding`` to the chunk with ``chunk_id`` (sync).

    Does nothing when the chunk is not found. The caller must commit.
    """
    target = session.get(MemoryChunk, chunk_id)
    if not target:
        return
    target.embedding = embedding
|
||
|
||
|
||
async def update_chunk_embedding(
    db: AsyncSession, chunk_id: str, embedding: list[float]
) -> None:
    """Assign ``embedding`` to the chunk with ``chunk_id`` (async).

    Does nothing when the chunk is not found. The caller must commit.
    """
    target = await db.get(MemoryChunk, chunk_id)
    if not target:
        return
    target.embedding = embedding
|
||
|
||
|
||
async def update_chunk_fts(db: AsyncSession, chunk_id: str) -> None:
    """Recompute ``content_tsv`` for one chunk (async).

    Runs an UPDATE that sets ``content_tsv`` to the 'simple' tsvector of the
    chunk's ``content``. The caller must commit.
    """
    refresh_sql = text(
        "UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
    )
    bind_params = {"id": chunk_id}
    await db.execute(refresh_sql, bind_params)
|
||
|
||
|
||
async def search_chunks_fts(
    db: AsyncSession, user_id: str, query: str, limit: int = 20
) -> list[dict]:
    """Full-text search over a user's memory chunks (async).

    Returns a list of ``{id, content, chunk_index}`` dicts ordered by
    ``ts_rank_cd`` descending. Chunks whose ``is_excluded`` is true or whose
    ``content_tsv`` is NULL are skipped. A missing or blank ``query`` yields
    an empty list without touching the database.
    """
    needle = query.strip() if query else ""
    if not needle:
        return []
    fts_sql = text("""
        SELECT id, content, chunk_index
        FROM memory_chunks
        WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
          AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
        ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
        LIMIT :lim
    """)
    bind_params = {"user_id": user_id, "q": needle, "q2": needle, "lim": limit}
    outcome = await db.execute(fts_sql, bind_params)
    hits: list[dict] = []
    for record in outcome.mappings().all():
        hits.append(
            {key: record[key] for key in ("id", "content", "chunk_index")}
        )
    return hits
|
||
|
||
|
||
async def get_chunks_by_ids(
    db: AsyncSession, chunk_ids: list[str]
) -> list[MemoryChunk]:
    """Load chunks by id, returned in the order the ids were given.

    An empty ``chunk_ids`` returns an empty list without querying. Ids with
    no matching row are simply absent from the result.
    """
    if not chunk_ids:
        return []
    outcome = await db.execute(
        select(MemoryChunk).where(MemoryChunk.id.in_(chunk_ids))
    )
    # Rank of each requested id; a repeated id keeps its last position.
    position: dict[str, int] = {}
    for idx, cid in enumerate(chunk_ids):
        position[cid] = idx
    fetched = list(outcome.unique().scalars().all())
    fetched.sort(key=lambda chunk: position.get(chunk.id, 999))
    return fetched
|
||
|
||
|
||
async def get_facts_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[MemoryFact]:
    """Return the user's confirmed facts, newest first (async).

    Only facts whose ``status`` is ``"confirmed"`` are returned, ordered by
    ``created_at`` descending and capped at ``limit`` rows.
    """
    confirmed = select(MemoryFact).where(
        MemoryFact.user_id == user_id, MemoryFact.status == "confirmed"
    )
    newest_first = confirmed.order_by(MemoryFact.created_at.desc()).limit(limit)
    outcome = await db.execute(newest_first)
    facts = outcome.unique().scalars().all()
    return list(facts)
|
||
|
||
|
||
def get_facts_for_user_sync(
    session: Session, user_id: str, limit: int = 20
) -> list[MemoryFact]:
    """Return the user's confirmed facts, newest first (sync).

    Only facts whose ``status`` is ``"confirmed"`` are returned, ordered by
    ``created_at`` descending and capped at ``limit`` rows.
    """
    confirmed = select(MemoryFact).where(
        MemoryFact.user_id == user_id, MemoryFact.status == "confirmed"
    )
    newest_first = confirmed.order_by(MemoryFact.created_at.desc()).limit(limit)
    facts = session.execute(newest_first).unique().scalars().all()
    return list(facts)
|
||
|
||
|
||
def get_timeline_events_for_user_sync(
    session: Session, user_id: str, limit: int = 20
) -> list[TimelineEvent]:
    """Return the user's timeline events (sync).

    Events are ordered by ``event_year`` descending with NULL years last,
    then by ``created_at`` descending, and capped at ``limit`` rows.
    """
    year_desc_nulls_last = TimelineEvent.event_year.desc().nullslast()
    newest_created = TimelineEvent.created_at.desc()
    query = select(TimelineEvent).where(TimelineEvent.user_id == user_id)
    query = query.order_by(year_desc_nulls_last, newest_created).limit(limit)
    events = session.execute(query).unique().scalars().all()
    return list(events)
|
||
|
||
|
||
def search_chunks_fts_sync(
    session: Session, user_id: str, query: str, limit: int = 20
) -> list[dict]:
    """Full-text search over a user's memory chunks (sync, for Celery).

    Returns a list of ``{id, content, chunk_index}`` dicts ordered by
    ``ts_rank_cd`` descending. Chunks whose ``is_excluded`` is true or whose
    ``content_tsv`` is NULL are skipped. A missing or blank ``query`` yields
    an empty list without touching the database.
    """
    needle = query.strip() if query else ""
    if not needle:
        return []
    fts_sql = text("""
        SELECT id, content, chunk_index
        FROM memory_chunks
        WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
          AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
        ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
        LIMIT :lim
    """)
    bind_params = {"user_id": user_id, "q": needle, "q2": needle, "lim": limit}
    records = session.execute(fts_sql, bind_params).mappings().all()
    hits: list[dict] = []
    for record in records:
        hits.append(
            {key: record[key] for key in ("id", "content", "chunk_index")}
        )
    return hits
|
||
|
||
|
||
def search_facts_for_user_sync(
    session: Session, user_id: str, query: str, limit: int = 20
) -> list[MemoryFact]:
    """Find confirmed facts whose subject or predicate contains ``query`` (sync).

    Matching is a case-insensitive substring test (ILIKE) on ``subject`` and
    ``predicate``, newest first. When the query is blank, or nothing matches,
    the result of ``get_facts_for_user_sync`` is returned instead.
    """
    needle = (query or "").strip()
    if needle:
        like_pattern = f"%{needle}%"
        text_match = or_(
            MemoryFact.subject.ilike(like_pattern),
            MemoryFact.predicate.ilike(like_pattern),
        )
        lookup = (
            select(MemoryFact)
            .where(
                MemoryFact.user_id == user_id,
                MemoryFact.status == "confirmed",
                text_match,
            )
            .order_by(MemoryFact.created_at.desc())
            .limit(limit)
        )
        matched = list(session.execute(lookup).unique().scalars().all())
        if matched:
            return matched
    # Blank query or no match: fall back to the most recent confirmed facts.
    return get_facts_for_user_sync(session, user_id, limit)
|
||
|
||
|
||
async def search_facts_for_user_async(
    db: AsyncSession, user_id: str, query: str, limit: int = 20
) -> list[MemoryFact]:
    """Find confirmed facts whose subject or predicate contains ``query`` (async).

    Matching is a case-insensitive substring test (ILIKE) on ``subject`` and
    ``predicate``, newest first. When the query is blank, or nothing matches,
    the result of ``get_facts_for_user`` is returned instead.
    """
    needle = (query or "").strip()
    if needle:
        like_pattern = f"%{needle}%"
        text_match = or_(
            MemoryFact.subject.ilike(like_pattern),
            MemoryFact.predicate.ilike(like_pattern),
        )
        lookup = (
            select(MemoryFact)
            .where(
                MemoryFact.user_id == user_id,
                MemoryFact.status == "confirmed",
                text_match,
            )
            .order_by(MemoryFact.created_at.desc())
            .limit(limit)
        )
        outcome = await db.execute(lookup)
        matched = list(outcome.unique().scalars().all())
        if matched:
            return matched
    # Blank query or no match: fall back to the most recent confirmed facts.
    return await get_facts_for_user(db, user_id=user_id, limit=limit)
|
||
|
||
|
||
def search_timeline_events_for_user_sync(
    session: Session, user_id: str, query: str, limit: int = 20
) -> list[TimelineEvent]:
    """Find timeline events whose title or description contains ``query`` (sync).

    Matching is a case-insensitive substring test (ILIKE) on ``title`` and
    ``description``. Results are ordered by ``event_year`` descending with
    NULL years last, then by ``created_at`` descending, and capped at
    ``limit``. When the query is blank, or nothing matches, the result of
    ``get_timeline_events_for_user_sync`` is returned instead.
    """
    q = (query or "").strip()
    if not q:
        return get_timeline_events_for_user_sync(session, user_id, limit)
    pat = f"%{q}%"
    stmt = (
        select(TimelineEvent)
        .where(
            TimelineEvent.user_id == user_id,
            or_(
                TimelineEvent.title.ilike(pat),
                TimelineEvent.description.ilike(pat),
            ),
        )
        # created_at breaks ties between events of the same year so that the
        # rows kept by LIMIT are deterministic and agree with the ordering
        # used by get_timeline_events_for_user_sync.
        .order_by(
            TimelineEvent.event_year.desc().nullslast(),
            TimelineEvent.created_at.desc(),
        )
        .limit(limit)
    )
    rows = list(session.execute(stmt).unique().scalars().all())
    if rows:
        return rows
    return get_timeline_events_for_user_sync(session, user_id, limit)
|
||
|
||
|
||
async def search_timeline_events_for_user_async(
    db: AsyncSession, user_id: str, query: str, limit: int = 20
) -> list[TimelineEvent]:
    """Find timeline events whose title or description contains ``query`` (async).

    Matching is a case-insensitive substring test (ILIKE) on ``title`` and
    ``description``. Results are ordered by ``event_year`` descending with
    NULL years last, then by ``created_at`` descending, and capped at
    ``limit``. When the query is blank, or nothing matches, the result of
    ``get_timeline_events_for_user`` is returned instead.
    """
    q = (query or "").strip()
    if not q:
        return await get_timeline_events_for_user(db, user_id=user_id, limit=limit)
    pat = f"%{q}%"
    stmt = (
        select(TimelineEvent)
        .where(
            TimelineEvent.user_id == user_id,
            or_(
                TimelineEvent.title.ilike(pat),
                TimelineEvent.description.ilike(pat),
            ),
        )
        # created_at breaks ties between events of the same year so that the
        # rows kept by LIMIT are deterministic and agree with the ordering
        # used by get_timeline_events_for_user.
        .order_by(
            TimelineEvent.event_year.desc().nullslast(),
            TimelineEvent.created_at.desc(),
        )
        .limit(limit)
    )
    result = await db.execute(stmt)
    rows = list(result.unique().scalars().all())
    if rows:
        return rows
    return await get_timeline_events_for_user(db, user_id=user_id, limit=limit)
|
||
|
||
|
||
async def search_chunks_vector(
    db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20
) -> list[dict]:
    """Nearest-neighbour search over a user's chunk embeddings (async).

    Returns ``{id, content, chunk_index, distance}`` dicts ordered by the
    pgvector ``<=>`` cosine distance (1 - cosine similarity, lower is
    closer). Chunks whose ``is_excluded`` is true or that have no embedding
    are skipped. An empty ``query_embedding`` yields an empty list without
    querying.
    """
    if not query_embedding:
        return []
    knn_sql = text("""
        SELECT id, content, chunk_index,
               (embedding <=> CAST(:emb AS vector)) AS distance
        FROM memory_chunks
        WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
          AND embedding IS NOT NULL
        ORDER BY embedding <=> CAST(:emb2 AS vector)
        LIMIT :lim
    """)
    # The vector is passed as its text literal, e.g. "[0.1,0.2]", and cast in SQL.
    vector_literal = f"[{','.join(map(str, query_embedding))}]"
    bind_params = {
        "user_id": user_id,
        "emb": vector_literal,
        "emb2": vector_literal,
        "lim": limit,
    }
    outcome = await db.execute(knn_sql, bind_params)
    neighbours: list[dict] = []
    for record in outcome.mappings().all():
        neighbours.append(
            {
                "id": record["id"],
                "content": record["content"],
                "chunk_index": record["chunk_index"],
                "distance": float(record["distance"]),
            }
        )
    return neighbours
|
||
|
||
|
||
def list_summaries_for_evidence_sync(
    session: Session, *, user_id: str, q: str, limit: int
) -> list[dict]:
    """Return the latest rolling summary plus summaries matching ``q`` (sync).

    The most recently updated ``"rolling"`` summary (if any) comes first,
    followed by summaries whose ``content`` contains the query
    (case-insensitive ILIKE), newest first, without duplicates and capped at
    ``limit`` entries.

    The query is stripped before matching, and a missing or blank query
    returns an empty list, the same contract as
    ``list_summaries_for_evidence_async``. Without this guard a blank query
    would become the pattern ``%%`` and match every summary.

    Returns:
        A list of ``{id, summary_type, content, source_chunk_ids}`` dicts.
    """
    needle = (q or "").strip()
    # A non-positive limit can never yield rows, so skip the queries too.
    if not needle or limit <= 0:
        return []
    pat = f"%{needle}%"
    rolling = (
        session.execute(
            select(MemorySummary)
            .where(
                MemorySummary.user_id == user_id,
                MemorySummary.summary_type == "rolling",
            )
            .order_by(MemorySummary.updated_at.desc())
            .limit(1)
        )
        .unique()
        .scalar_one_or_none()
    )
    rows: list[MemorySummary] = []
    seen: set[str] = set()
    if rolling:
        rows.append(rolling)
        seen.add(rolling.id)
    rest = limit - len(rows)
    if rest > 0:
        stmt = (
            select(MemorySummary)
            .where(
                MemorySummary.user_id == user_id,
                MemorySummary.content.ilike(pat),
            )
            .order_by(MemorySummary.updated_at.desc())
            # Over-fetch by the number of rows already held so that dropping
            # a duplicate of the rolling summary still leaves ``rest`` rows.
            .limit(rest + len(seen))
        )
        for s in session.execute(stmt).unique().scalars().all():
            if s.id in seen:
                continue
            rows.append(s)
            seen.add(s.id)
            if len(rows) >= limit:
                break
    return [
        {
            "id": s.id,
            "summary_type": s.summary_type,
            "content": s.content,
            "source_chunk_ids": s.source_chunk_ids,
        }
        for s in rows[:limit]
    ]
|
||
|
||
|
||
def retrieve_evidence_sync(
    session: Session, user_id: str, query: str, *, top_k: int = 10
) -> dict:
    """
    Sync evidence retrieval for Celery tasks.

    Capability: chunks are retrieved with **FTS only** (unlike
    `HybridRetriever`, which fuses FTS and vector results with RRF; see
    `api/docs/memory-retrieval.md`). Facts and timeline events are matched
    against the query with ILIKE; for the fallback behaviour see the repo.

    This is a thin wrapper that delegates to
    `app.features.memory.evidence.retrieve_evidence_bundle_sync`.
    """
    # Imported at call time rather than at module import time.
    from app.features.memory.evidence import retrieve_evidence_bundle_sync

    bundle = retrieve_evidence_bundle_sync(session, user_id, query, top_k=top_k)
    return bundle
|
||
|
||
|
||
async def get_timeline_events_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[TimelineEvent]:
    """Return the user's timeline events (async).

    Events are ordered by ``event_year`` descending with NULL years last,
    then by ``created_at`` descending, and capped at ``limit`` rows.
    """
    year_desc_nulls_last = TimelineEvent.event_year.desc().nullslast()
    newest_created = TimelineEvent.created_at.desc()
    query = select(TimelineEvent).where(TimelineEvent.user_id == user_id)
    query = query.order_by(year_desc_nulls_last, newest_created).limit(limit)
    outcome = await db.execute(query)
    events = outcome.unique().scalars().all()
    return list(events)
|
||
|
||
|
||
async def list_storage_keys_for_conversation(
    db: AsyncSession, conversation_id: str
) -> list[str]:
    """Return the COS object keys recorded on a conversation's memory sources.

    Only sources with a non-NULL ``storage_key`` are considered. Falsy keys
    are dropped, duplicates are removed, and the result is sorted.
    """
    key_query = select(MemorySource.storage_key).where(
        MemorySource.conversation_id == conversation_id,
        MemorySource.storage_key.isnot(None),
    )
    outcome = await db.execute(key_query)
    distinct_keys: set[str] = set()
    for key in outcome.scalars().all():
        if key:
            distinct_keys.add(key)
    return sorted(distinct_keys)
|
||
|
||
|
||
def list_chunks_for_source_sync(session: Session, source_id: str) -> list[MemoryChunk]:
    """Return all chunks of one memory source, ordered by ``chunk_index`` ascending."""
    by_source = select(MemoryChunk).where(MemoryChunk.source_id == source_id)
    in_order = by_source.order_by(MemoryChunk.chunk_index.asc())
    chunks = session.execute(in_order).unique().scalars().all()
    return list(chunks)
|
||
|
||
|
||
def create_memory_summary_sync(
    session: Session,
    *,
    user_id: str,
    summary_type: str,
    content: str,
    source_chunk_ids: list[str] | None = None,
) -> MemorySummary:
    """Build a MemorySummary row with a new id and stage it on the sync session.

    The row is only added to the session; committing is left to the caller.
    """
    summary = MemorySummary(
        id=_new_id(),
        user_id=user_id,
        summary_type=summary_type,
        content=content,
        source_chunk_ids=source_chunk_ids,
    )
    session.add(summary)
    return summary
|
||
|
||
|
||
async def create_memory_summary(
    db: AsyncSession,
    *,
    user_id: str,
    summary_type: str,
    content: str,
    source_chunk_ids: list[str] | None = None,
) -> MemorySummary:
    """Build a MemorySummary row with a new id and stage it on the async session.

    The row is only added to the session; committing is left to the caller.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "summary_type": summary_type,
        "content": content,
        "source_chunk_ids": source_chunk_ids,
    }
    summary = MemorySummary(**fields)
    db.add(summary)
    return summary
|
||
|
||
|
||
def get_latest_rolling_summary_sync(
    session: Session, user_id: str
) -> MemorySummary | None:
    """Return the user's most recently updated ``"rolling"`` summary, or None."""
    rolling_only = select(MemorySummary).where(
        MemorySummary.user_id == user_id,
        MemorySummary.summary_type == "rolling",
    )
    latest = rolling_only.order_by(MemorySummary.updated_at.desc()).limit(1)
    outcome = session.execute(latest)
    return outcome.unique().scalar_one_or_none()
|
||
|
||
|
||
def upsert_rolling_summary_sync(
    session: Session,
    *,
    user_id: str,
    content: str,
    source_chunk_ids: list[str] | None = None,
) -> MemorySummary:
    """Update the user's latest rolling summary, or create one if none exists.

    On update, ``content`` is always overwritten; ``source_chunk_ids`` is
    overwritten only when it is not None. The new or modified row is
    returned; nothing is committed here.
    """
    current = get_latest_rolling_summary_sync(session, user_id)
    if not current:
        return create_memory_summary_sync(
            session,
            user_id=user_id,
            summary_type="rolling",
            content=content,
            source_chunk_ids=source_chunk_ids,
        )
    current.content = content
    if source_chunk_ids is not None:
        current.source_chunk_ids = source_chunk_ids
    return current
|
||
|
||
|
||
def create_memory_fact_sync(
    session: Session,
    *,
    user_id: str,
    fact_type: str,
    subject: str | None,
    predicate: str | None,
    object_json: dict | None,
    confidence: float,
    source_chunk_id: str | None,
    status: str = "confirmed",
) -> MemoryFact:
    """Build a MemoryFact row with a new id and stage it on the sync session.

    ``status`` defaults to ``"confirmed"``. The row is only added to the
    session; committing is left to the caller.
    """
    fact = MemoryFact(
        id=_new_id(),
        user_id=user_id,
        fact_type=fact_type,
        subject=subject,
        predicate=predicate,
        object_json=object_json,
        confidence=confidence,
        source_chunk_id=source_chunk_id,
        status=status,
    )
    session.add(fact)
    return fact
|
||
|
||
|
||
async def create_memory_fact(
    db: AsyncSession,
    *,
    user_id: str,
    fact_type: str,
    subject: str | None,
    predicate: str | None,
    object_json: dict | None,
    confidence: float,
    source_chunk_id: str | None,
    status: str = "confirmed",
) -> MemoryFact:
    """Build a MemoryFact row with a new id and stage it on the async session.

    ``status`` defaults to ``"confirmed"``. The row is only added to the
    session; committing is left to the caller.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "fact_type": fact_type,
        "subject": subject,
        "predicate": predicate,
        "object_json": object_json,
        "confidence": confidence,
        "source_chunk_id": source_chunk_id,
        "status": status,
    }
    fact = MemoryFact(**fields)
    db.add(fact)
    return fact
|
||
|
||
|
||
async def get_memory_fact_for_user(
    db: AsyncSession, fact_id: str, user_id: str
) -> MemoryFact | None:
    """Return the fact with ``fact_id`` only if it belongs to ``user_id``.

    Returns None when the fact does not exist or is owned by another user.
    """
    fact = await db.get(MemoryFact, fact_id)
    if fact is not None and fact.user_id == user_id:
        return fact
    return None
|
||
|
||
|
||
async def set_memory_fact_status(
    db: AsyncSession, fact_id: str, user_id: str, status: str
) -> bool:
    """Set ``status`` on a fact owned by ``user_id``.

    Returns True when the fact was found and updated, and False when
    ``get_memory_fact_for_user`` returns no fact. Nothing is committed here.
    """
    fact = await get_memory_fact_for_user(db, fact_id, user_id)
    found = fact is not None
    if found:
        fact.status = status
    return found
|
||
|
||
|
||
def delete_timeline_events_by_memory_source_sync(
    session: Session, *, user_id: str, memory_source_id: str
) -> int:
    """Delete the user's timeline events tied to one memory source (sync).

    Returns the statement's reported row count, or 0 when that count is
    falsy. Nothing is committed here.
    """
    removal = delete(TimelineEvent).where(
        TimelineEvent.user_id == user_id,
        TimelineEvent.memory_source_id == memory_source_id,
    )
    affected = session.execute(removal).rowcount
    return affected if affected else 0
|
||
|
||
|
||
async def delete_timeline_events_by_memory_source(
    db: AsyncSession, *, user_id: str, memory_source_id: str
) -> int:
    """Delete the user's timeline events tied to one memory source (async).

    Returns the statement's reported row count, or 0 when that count is
    falsy. Nothing is committed here.
    """
    removal = delete(TimelineEvent).where(
        TimelineEvent.user_id == user_id,
        TimelineEvent.memory_source_id == memory_source_id,
    )
    outcome = await db.execute(removal)
    affected = outcome.rowcount
    return affected if affected else 0
|
||
|
||
|
||
def create_timeline_event_sync(
    session: Session,
    *,
    user_id: str,
    event_year: int | None,
    event_date: str | None,
    title: str,
    description: str | None,
    person_refs: list | None = None,
    source_fact_ids: list[str] | None = None,
    memory_source_id: str | None = None,
) -> TimelineEvent:
    """Build a TimelineEvent row with a new id and stage it on the sync session.

    The row is only added to the session; committing is left to the caller.
    """
    event = TimelineEvent(
        id=_new_id(),
        user_id=user_id,
        memory_source_id=memory_source_id,
        event_year=event_year,
        event_date=event_date,
        title=title,
        description=description,
        person_refs=person_refs,
        source_fact_ids=source_fact_ids,
    )
    session.add(event)
    return event
|
||
|
||
|
||
async def create_timeline_event(
    db: AsyncSession,
    *,
    user_id: str,
    event_year: int | None,
    event_date: str | None,
    title: str,
    description: str | None,
    person_refs: list | None = None,
    source_fact_ids: list[str] | None = None,
    memory_source_id: str | None = None,
) -> TimelineEvent:
    """Build a TimelineEvent row with a new id and stage it on the async session.

    The row is only added to the session; committing is left to the caller.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "memory_source_id": memory_source_id,
        "event_year": event_year,
        "event_date": event_date,
        "title": title,
        "description": description,
        "person_refs": person_refs,
        "source_fact_ids": source_fact_ids,
    }
    event = TimelineEvent(**fields)
    db.add(event)
    return event
|
||
|
||
|
||
def create_curation_action_sync(
    session: Session,
    *,
    user_id: str,
    action_type: str,
    target_type: str,
    target_id: str,
    details: dict | None = None,
) -> MemoryCurationAction:
    """Build a MemoryCurationAction row with a new id and stage it (sync).

    The row is only added to the session; committing is left to the caller.
    """
    action = MemoryCurationAction(
        id=_new_id(),
        user_id=user_id,
        action_type=action_type,
        target_type=target_type,
        target_id=target_id,
        details=details,
    )
    session.add(action)
    return action
|
||
|
||
|
||
async def create_curation_action(
    db: AsyncSession,
    *,
    user_id: str,
    action_type: str,
    target_type: str,
    target_id: str,
    details: dict | None = None,
) -> MemoryCurationAction:
    """Build a MemoryCurationAction row with a new id and stage it (async).

    The row is only added to the session; committing is left to the caller.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "action_type": action_type,
        "target_type": target_type,
        "target_id": target_id,
        "details": details,
    }
    action = MemoryCurationAction(**fields)
    db.add(action)
    return action
|
||
|
||
|
||
async def get_memory_chunk_for_user(
    db: AsyncSession, chunk_id: str, user_id: str
) -> MemoryChunk | None:
    """Return the chunk with ``chunk_id`` only if it belongs to ``user_id`` (async).

    Returns None when the chunk does not exist or is owned by another user.
    """
    chunk = await db.get(MemoryChunk, chunk_id)
    if chunk is not None and chunk.user_id == user_id:
        return chunk
    return None
|
||
|
||
|
||
def get_memory_chunk_sync(
    session: Session, chunk_id: str, user_id: str
) -> MemoryChunk | None:
    """Return the chunk with ``chunk_id`` only if it belongs to ``user_id`` (sync).

    Returns None when the chunk does not exist or is owned by another user.
    """
    chunk = session.get(MemoryChunk, chunk_id)
    if chunk is not None and chunk.user_id == user_id:
        return chunk
    return None
|
||
|
||
|
||
def set_chunk_excluded_sync(
    session: Session, chunk_id: str, user_id: str, excluded: bool
) -> bool:
    """Set ``is_excluded`` on a chunk owned by ``user_id`` (sync).

    Returns True when the chunk was found and updated, and False when
    ``get_memory_chunk_sync`` returns no chunk. Nothing is committed here.
    """
    chunk = get_memory_chunk_sync(session, chunk_id, user_id)
    found = chunk is not None
    if found:
        chunk.is_excluded = excluded
    return found
|
||
|
||
|
||
def list_incremental_chunks_for_compaction_sync(
    session: Session,
    *,
    user_id: str,
    after_cursor_ts: datetime,
    after_chunk_id: str,
    limit: int,
    candidate_chunk_ids: list[str] | None = None,
    candidate_source_ids: list[str] | None = None,
) -> list[MemoryChunk]:
    """Return the user's non-excluded chunks that lie after a cursor.

    A chunk qualifies when its ``(created_at, id)`` pair compares greater
    than ``(after_cursor_ts, after_chunk_id)`` and its ``is_excluded`` is
    false or NULL. Results are ordered by ``(created_at, id)`` ascending and
    capped at ``limit``. A non-empty ``candidate_chunk_ids`` or
    ``candidate_source_ids`` further restricts the result to those ids; an
    empty or None list applies no restriction.
    """
    cursor_key = tuple_(literal(after_cursor_ts), literal(after_chunk_id))
    row_key = tuple_(MemoryChunk.created_at, MemoryChunk.id)
    not_excluded = or_(
        MemoryChunk.is_excluded.is_(False), MemoryChunk.is_excluded.is_(None)
    )
    conditions = [
        MemoryChunk.user_id == user_id,
        row_key > cursor_key,
        not_excluded,
    ]
    if candidate_chunk_ids:
        conditions.append(MemoryChunk.id.in_(candidate_chunk_ids))
    if candidate_source_ids:
        conditions.append(MemoryChunk.source_id.in_(candidate_source_ids))
    query = (
        select(MemoryChunk)
        .where(*conditions)
        .order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc())
        .limit(limit)
    )
    return list(session.execute(query).unique().scalars().all())
|
||
|
||
|
||
def get_first_chunk_after_cursor_sync(
    session: Session,
    *,
    user_id: str,
    after_cursor_ts: datetime,
    after_chunk_id: str,
) -> MemoryChunk | None:
    """Return the user's first chunk after the cursor, or None.

    "After" means the chunk's ``(created_at, id)`` pair compares greater than
    ``(after_cursor_ts, after_chunk_id)``; the smallest such pair wins.
    Excluded chunks are included on purpose: this lookup is used to advance
    the cursor when an incremental batch turns out to be empty.
    """
    stmt = (
        select(MemoryChunk)
        .where(
            MemoryChunk.user_id == user_id,
            tuple_(MemoryChunk.created_at, MemoryChunk.id)
            > tuple_(literal(after_cursor_ts), literal(after_chunk_id)),
        )
        .order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc())
        .limit(1)
    )
    # unique() matches every other entity query in this module; SQLAlchemy
    # requires it when an entity eagerly joins a collection.
    # NOTE(review): MemoryChunk's loader strategy is not visible from here.
    return session.execute(stmt).unique().scalars().first()
|
||
|
||
|
||
def search_nearest_chunks_for_compaction_sync(
    session: Session,
    *,
    user_id: str,
    chunk_id: str,
    query_embedding: list[float],
    limit: int,
) -> list[dict]:
    """
    Return the top-K nearest chunks by cosine distance, excluding ``chunk_id``.

    The pgvector ``<=>`` operator is cosine distance. Chunks whose
    ``is_excluded`` is true or that have no embedding are skipped. An empty
    ``query_embedding`` yields an empty list without querying.

    Each result dict has: id, content, source_id, event_year, metadata_json,
    source_type, created_at, distance.
    """
    if not query_embedding:
        return []
    knn_sql = text("""
        SELECT mc.id, mc.content, mc.source_id, mc.event_year, mc.metadata_json,
               ms.source_type, mc.created_at,
               (mc.embedding <=> CAST(:emb AS vector)) AS distance
        FROM memory_chunks mc
        JOIN memory_sources ms ON ms.id = mc.source_id
        WHERE mc.user_id = :user_id
          AND (mc.is_excluded IS NOT TRUE OR mc.is_excluded = false)
          AND mc.embedding IS NOT NULL
          AND mc.id != :chunk_id
        ORDER BY mc.embedding <=> CAST(:emb2 AS vector)
        LIMIT :lim
    """)
    # The vector is passed as its text literal, e.g. "[0.1,0.2]", and cast in SQL.
    vector_literal = f"[{','.join(map(str, query_embedding))}]"
    bind_params = {
        "user_id": user_id,
        "chunk_id": chunk_id,
        "emb": vector_literal,
        "emb2": vector_literal,
        "lim": limit,
    }
    records = session.execute(knn_sql, bind_params).mappings().all()
    passthrough_keys = (
        "id",
        "content",
        "source_id",
        "event_year",
        "metadata_json",
        "source_type",
        "created_at",
    )
    neighbours: list[dict] = []
    for record in records:
        entry = {key: record[key] for key in passthrough_keys}
        entry["distance"] = float(record["distance"])
        neighbours.append(entry)
    return neighbours
|
||
|
||
|
||
async def set_chunk_excluded(
    db: AsyncSession, chunk_id: str, user_id: str, excluded: bool
) -> bool:
    """Set ``is_excluded`` on a chunk owned by ``user_id`` (async).

    Returns True when the chunk was found and updated, and False when
    ``get_memory_chunk_for_user`` returns no chunk. Nothing is committed here.
    """
    chunk = await get_memory_chunk_for_user(db, chunk_id, user_id)
    found = chunk is not None
    if found:
        chunk.is_excluded = excluded
    return found
|
||
|
||
|
||
async def list_summaries_for_evidence_async(
    db: AsyncSession, *, user_id: str, q: str, limit: int
) -> list[dict]:
    """Return the latest rolling summary plus summaries matching ``q`` (async).

    A missing or blank query returns an empty list. Otherwise the most
    recently updated ``"rolling"`` summary (if any) comes first, followed by
    summaries whose ``content`` contains the stripped query (case-insensitive
    ILIKE), newest first, without duplicates and capped at ``limit`` entries.

    Returns:
        A list of ``{id, summary_type, content, source_chunk_ids}`` dicts.
    """
    needle = (q or "").strip()
    if not needle:
        return []
    like_pattern = "%" + needle + "%"

    latest_rolling_query = (
        select(MemorySummary)
        .where(
            MemorySummary.user_id == user_id,
            MemorySummary.summary_type == "rolling",
        )
        .order_by(MemorySummary.updated_at.desc())
        .limit(1)
    )
    rolling_outcome = await db.execute(latest_rolling_query)
    latest_rolling = rolling_outcome.unique().scalar_one_or_none()

    picked: list[MemorySummary] = []
    picked_ids: set[str] = set()
    if latest_rolling:
        picked.append(latest_rolling)
        picked_ids.add(latest_rolling.id)

    remaining = limit - len(picked)
    if remaining > 0:
        matching_query = (
            select(MemorySummary)
            .where(
                MemorySummary.user_id == user_id,
                MemorySummary.content.ilike(like_pattern),
            )
            .order_by(MemorySummary.updated_at.desc())
            # Over-fetch by the number of rows already held so that dropping
            # a duplicate of the rolling summary still leaves enough rows.
            .limit(remaining + len(picked_ids))
        )
        matching_outcome = await db.execute(matching_query)
        for candidate in matching_outcome.unique().scalars().all():
            if candidate.id in picked_ids:
                continue
            picked.append(candidate)
            picked_ids.add(candidate.id)
            if len(picked) >= limit:
                break

    payload: list[dict] = []
    for summary in picked[:limit]:
        payload.append(
            {
                "id": summary.id,
                "summary_type": summary.summary_type,
                "content": summary.content,
                "source_chunk_ids": summary.source_chunk_ids,
            }
        )
    return payload
|