Files
life-echo/api/app/features/memory/repo.py
Kevin a3f61fcc0f feat(api+app): 对话阶段化、回忆录流水线与客户端会话体验
- DB: segments 用户输入文本(Alembic 0002)
- Chat: 阶段检测/阶段提示/回复限制,编排与访谈/画像 prompts 调整
- Memoir: 忠实度检查 agent,叙事与分类等链路更新
- Core: agent 日志、Alembic 启动、LangChain/日志/配置等
- Story: time_hints;Memory 检索与相关测试
- Expo: 助手头像、会话页与消息拆分、实时会话与文案/i18n
- Docs/scripts/tests: 迁移脚本、LLM JSON/记忆检索文档、新增单测
2026-03-26 12:13:36 +08:00

323 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Memory repository — MemorySource, MemoryChunk, MemoryFact, TimelineEvent data access."""
import uuid
from datetime import datetime, timezone
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.features.memory.models import (
MemoryChunk,
MemoryFact,
MemorySource,
TimelineEvent,
)
def _new_id() -> str:
    """Return a freshly generated random UUID (version 4) as a string."""
    fresh = uuid.uuid4()
    return str(fresh)
def create_source_sync(
    session: Session,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    captured_at: datetime | None = None,
) -> MemorySource:
    """Build a MemorySource and stage it on a synchronous session.

    The row gets a new UUID string as its id. When ``captured_at`` is not
    supplied, the current UTC time is used instead. The object is only added
    to the session; the caller is responsible for committing.

    Returns:
        The newly constructed MemorySource instance.
    """
    record = MemorySource(
        id=_new_id(),
        user_id=user_id,
        source_type=source_type,
        raw_text=raw_text,
        conversation_id=conversation_id,
        captured_at=captured_at if captured_at else datetime.now(timezone.utc),
    )
    session.add(record)
    return record
async def create_source(
    db: AsyncSession,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    captured_at: datetime | None = None,
) -> MemorySource:
    """Build a MemorySource and stage it on an async session.

    The row gets a new UUID string as its id, and ``captured_at`` falls back
    to the current UTC time when omitted. Nothing is awaited here: the object
    is only added to the session, so the caller must commit.

    Returns:
        The newly constructed MemorySource instance.
    """
    record = MemorySource(
        id=_new_id(),
        user_id=user_id,
        source_type=source_type,
        raw_text=raw_text,
        conversation_id=conversation_id,
        captured_at=captured_at if captured_at else datetime.now(timezone.utc),
    )
    db.add(record)
    return record
def create_chunk_sync(
    session: Session,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Build a MemoryChunk and stage it on a synchronous session.

    The chunk gets a new UUID string as its id and is only added to the
    session; the caller is responsible for committing.

    Returns:
        The newly constructed MemoryChunk instance.
    """
    piece = MemoryChunk(
        id=_new_id(),
        source_id=source_id,
        user_id=user_id,
        content=content,
        chunk_index=chunk_index,
    )
    session.add(piece)
    return piece
async def create_chunk(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Build a MemoryChunk and stage it on an async session.

    The chunk gets a new UUID string as its id. Nothing is awaited here: the
    object is only added to the session, so the caller must commit.

    Returns:
        The newly constructed MemoryChunk instance.
    """
    piece = MemoryChunk(
        id=_new_id(),
        source_id=source_id,
        user_id=user_id,
        content=content,
        chunk_index=chunk_index,
    )
    db.add(piece)
    return piece
def update_chunk_fts_sync(session: Session, chunk_id: str) -> None:
    """Refresh the full-text-search vector of one chunk (sync).

    Issues an UPDATE that recomputes ``content_tsv`` from ``content`` with the
    'simple' text search configuration for the row whose id is ``chunk_id``.
    The caller must commit.
    """
    refresh_sql = text(
        "UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
    )
    params = {"id": chunk_id}
    session.execute(refresh_sql, params)
async def update_chunk_embedding(
    db: AsyncSession, chunk_id: str, embedding: list[float]
) -> None:
    """Set the embedding of an existing chunk.

    Looks the chunk up by primary key and assigns ``embedding`` to it. If no
    chunk is found, the call does nothing. The caller must commit.
    """
    target = await db.get(MemoryChunk, chunk_id)
    if not target:
        # Unknown id: nothing to update.
        return
    target.embedding = embedding
async def update_chunk_fts(db: AsyncSession, chunk_id: str) -> None:
    """Refresh the full-text-search vector of one chunk.

    Issues an UPDATE that recomputes ``content_tsv`` from ``content`` with the
    'simple' text search configuration for the row whose id is ``chunk_id``.
    The caller must commit.
    """
    refresh_sql = text(
        "UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
    )
    params = {"id": chunk_id}
    await db.execute(refresh_sql, params)
async def search_chunks_fts(
    db: AsyncSession, user_id: str, query: str, limit: int = 20
) -> list[dict]:
    """Full-text search over a user's memory chunks.

    The query is stripped of surrounding whitespace; an empty or blank query
    returns an empty list without touching the database. Otherwise,
    non-excluded chunks of ``user_id`` whose ``content_tsv`` matches the query
    are ranked by ``ts_rank_cd`` (best first) and capped at ``limit`` rows.

    Returns:
        A list of dicts with the keys ``id``, ``content`` and ``chunk_index``.
    """
    needle = query.strip() if query else ""
    if not needle:
        return []
    fts_sql = text("""
SELECT id, content, chunk_index
FROM memory_chunks
WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
LIMIT :lim
""")
    # The search text is bound twice: once for matching, once for ranking.
    params = {"user_id": user_id, "q": needle, "q2": needle, "lim": limit}
    outcome = await db.execute(fts_sql, params)
    wanted = ("id", "content", "chunk_index")
    return [{key: row[key] for key in wanted} for row in outcome.mappings().all()]
async def get_chunks_by_ids(
    db: AsyncSession, chunk_ids: list[str]
) -> list[MemoryChunk]:
    """Load chunks by primary key, ordered like ``chunk_ids``.

    An empty ``chunk_ids`` returns an empty list without querying. Ids with no
    matching row are simply absent from the result. When an id occurs more
    than once in ``chunk_ids``, its last position decides the ordering.
    """
    if not chunk_ids:
        return []
    lookup = select(MemoryChunk).where(MemoryChunk.id.in_(chunk_ids))
    fetched = await db.execute(lookup)
    found = list(fetched.unique().scalars().all())
    # Rank of each requested id; later duplicates overwrite earlier ones.
    position = {}
    for rank, cid in enumerate(chunk_ids):
        position[cid] = rank
    # Stable sort; an id missing from the map falls back to rank 999.
    found.sort(key=lambda chunk: position.get(chunk.id, 999))
    return found
async def get_facts_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[MemoryFact]:
    """Return the user's most recently created confirmed facts.

    Only facts whose ``status`` is "confirmed" are selected, newest
    ``created_at`` first, capped at ``limit`` rows.
    """
    confirmed = select(MemoryFact).where(
        MemoryFact.user_id == user_id, MemoryFact.status == "confirmed"
    )
    newest_first = confirmed.order_by(MemoryFact.created_at.desc()).limit(limit)
    fetched = await db.execute(newest_first)
    return list(fetched.unique().scalars().all())
async def search_chunks_vector(
    db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20
) -> list[dict]:
    """Vector similarity search over a user's memory chunks.

    Non-excluded chunks of ``user_id`` that have an embedding are ranked by
    pgvector cosine distance to ``query_embedding`` (1 - cosine similarity,
    so lower is better) and capped at ``limit`` rows.

    Args:
        db: Async session used to run the query.
        user_id: Owner of the chunks to search.
        query_embedding: Query vector; ``None`` or an empty sequence returns
            an empty list without touching the database.
        limit: Maximum number of rows to return.

    Returns:
        A list of dicts with the keys ``id``, ``content``, ``chunk_index``
        and ``distance`` (a float), closest first.
    """
    # Use len() rather than truthiness so any sized sequence is accepted.
    if query_embedding is None or len(query_embedding) == 0:
        return []
    # CAST(... AS vector) is used instead of the "::vector" shorthand: a
    # ":name" placeholder immediately followed by a colon is not reliably
    # recognised as a bind parameter by SQLAlchemy's text() parser.
    stmt = text("""
        SELECT id, content, chunk_index,
               (embedding <=> CAST(:emb AS vector)) AS distance
        FROM memory_chunks
        WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
          AND embedding IS NOT NULL
        ORDER BY embedding <=> CAST(:emb2 AS vector)
        LIMIT :lim
    """)
    # The vector is sent as text in the form "[v1,v2,...]" and cast in SQL.
    emb_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
    result = await db.execute(
        stmt,
        {"user_id": user_id, "emb": emb_str, "emb2": emb_str, "lim": limit},
    )
    return [
        {
            "id": row["id"],
            "content": row["content"],
            "chunk_index": row["chunk_index"],
            "distance": float(row["distance"]),
        }
        for row in result.mappings().all()
    ]
def retrieve_evidence_sync(
    session: Session, user_id: str, query: str, *, top_k: int = 10
) -> dict:
    """
    Synchronous evidence retrieval for Celery tasks.

    Capabilities: chunks are retrieved with **FTS only** (unlike
    `HybridRetriever`, which fuses FTS and vector results with RRF; see
    `api/docs/memory-retrieval.md`), plus confirmed facts and timeline events.
    Ingest runs before the story pipeline inside the Celery task and has
    already committed, so chunks that were just written are visible here.

    A blank query returns the same structure with every list empty and runs
    no queries. ``relevant_summaries`` and ``relevant_stories`` are always
    empty lists. Each of the three lookups is capped at ``top_k`` rows.
    """
    evidence: dict = {
        "relevant_chunks": [],
        "relevant_summaries": [],
        "relevant_facts": [],
        "timeline_hints": [],
        "relevant_stories": [],
    }
    needle = query.strip() if query else ""
    if not needle:
        return evidence

    # 1) Chunks: full-text match ranked by ts_rank_cd, best first.
    fts_sql = text("""
SELECT id, content, chunk_index
FROM memory_chunks
WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
LIMIT :lim
""")
    fts_params = {"user_id": user_id, "q": needle, "q2": needle, "lim": top_k}
    chunk_rows = session.execute(fts_sql, fts_params).mappings().all()
    chunk_keys = ("id", "content", "chunk_index")
    evidence["relevant_chunks"] = [
        {key: row[key] for key in chunk_keys} for row in chunk_rows
    ]

    # 2) Facts: confirmed only, most recently created first.
    confirmed = select(MemoryFact).where(
        MemoryFact.user_id == user_id, MemoryFact.status == "confirmed"
    )
    facts_query = confirmed.order_by(MemoryFact.created_at.desc()).limit(top_k)
    fact_rows = list(session.execute(facts_query).unique().scalars().all())
    fact_attrs = ("id", "fact_type", "subject", "predicate", "object_json")
    evidence["relevant_facts"] = [
        {name: getattr(fact, name) for name in fact_attrs} for fact in fact_rows
    ]

    # 3) Timeline: latest event_year first, rows without a year last.
    events_query = (
        select(TimelineEvent)
        .where(TimelineEvent.user_id == user_id)
        .order_by(TimelineEvent.event_year.desc().nullslast())
        .limit(top_k)
    )
    event_rows = list(session.execute(events_query).unique().scalars().all())
    event_attrs = ("id", "event_year", "event_date", "title", "description")
    evidence["timeline_hints"] = [
        {name: getattr(event, name) for name in event_attrs} for event in event_rows
    ]

    return evidence
async def get_timeline_events_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[TimelineEvent]:
    """Return a user's timeline events, latest year first.

    Events are ordered by ``event_year`` descending with missing years last,
    then by ``created_at`` descending, and capped at ``limit`` rows.
    """
    by_year = TimelineEvent.event_year.desc().nullslast()
    by_creation = TimelineEvent.created_at.desc()
    events_query = (
        select(TimelineEvent)
        .where(TimelineEvent.user_id == user_id)
        .order_by(by_year, by_creation)
        .limit(limit)
    )
    fetched = await db.execute(events_query)
    return list(fetched.unique().scalars().all())
async def list_storage_keys_for_conversation(
    db: AsyncSession, conversation_id: str
) -> list[str]:
    """List the COS object keys recorded on a conversation's memory sources.

    Selects ``storage_key`` from the memory sources linked to
    ``conversation_id`` where a key is set, drops empty values and duplicates,
    and returns the remaining keys in sorted order (possibly an empty list).
    """
    keys_query = select(MemorySource.storage_key).where(
        MemorySource.conversation_id == conversation_id,
        MemorySource.storage_key.isnot(None),
    )
    fetched = await db.execute(keys_query)
    # filter(None, ...) keeps only truthy keys; set() removes duplicates.
    distinct_keys = set(filter(None, fetched.scalars().all()))
    return sorted(distinct_keys)