Files
life-echo/api/app/features/memory/repo.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

539 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Memory repository — MemorySource, MemoryChunk, and MemoryFact data access."""
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import cast, literal, or_, select, text, tuple_, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlalchemy.types import String as SqlString
from app.features.memory.models import (
MemoryChunk,
MemoryCurationAction,
MemoryFact,
MemorySource,
MemorySummary,
)
def _new_id() -> str:
    """Return a freshly generated random UUID (version 4) in string form."""
    fresh = uuid.uuid4()
    return f"{fresh}"
async def create_source(
    db: AsyncSession,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    segment_id: str | None = None,
    captured_at: datetime | None = None,
    lineage_json: dict | None = None,
    primary_user_message_id: str | None = None,
) -> MemorySource:
    """Build a new ``MemorySource`` and add it to the session.

    The row gets a fresh id, and both ``embedding_status`` and
    ``enrichment_status`` start out as "pending". When ``captured_at`` is
    falsy, the current UTC time is used instead.

    The object is only added to the session; the caller must commit.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "source_type": source_type,
        "raw_text": raw_text,
        "embedding_status": "pending",
        "enrichment_status": "pending",
        "conversation_id": conversation_id,
        "segment_id": segment_id,
        "lineage_json": lineage_json,
        "primary_user_message_id": primary_user_message_id,
        # Fall back to "now" (UTC) when no capture time was supplied.
        "captured_at": captured_at or datetime.now(timezone.utc),
    }
    new_source = MemorySource(**fields)
    db.add(new_source)
    return new_source
async def get_transcript_source_by_segment_id(
    db: AsyncSession,
    *,
    user_id: str,
    segment_id: str,
) -> MemorySource | None:
    """Look up the user's "transcript" source for the given segment.

    Returns the matching ``MemorySource`` or ``None`` when nothing matches.
    ``scalar_one_or_none`` raises if more than one row matches.
    """
    criteria = (
        MemorySource.user_id == user_id,
        MemorySource.segment_id == segment_id,
        MemorySource.source_type == "transcript",
    )
    outcome = await db.execute(select(MemorySource).where(*criteria))
    return outcome.scalar_one_or_none()
def get_transcript_source_by_segment_id_sync(
    db: Session,
    *,
    user_id: str,
    segment_id: str,
) -> MemorySource | None:
    """Synchronous lookup of the user's "transcript" source for a segment.

    Same query as the async variant, but executed on a blocking ``Session``.
    Returns ``None`` when nothing matches; ``scalar_one_or_none`` raises if
    more than one row matches.
    """
    criteria = (
        MemorySource.user_id == user_id,
        MemorySource.segment_id == segment_id,
        MemorySource.source_type == "transcript",
    )
    outcome = db.execute(select(MemorySource).where(*criteria))
    return outcome.scalar_one_or_none()
async def create_chunk(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Build a new ``MemoryChunk`` and add it to the session.

    The chunk gets a fresh id and starts with ``embedding_status`` set to
    "pending". The object is only added to the session; the caller must
    commit.
    """
    fields = {
        "id": _new_id(),
        "source_id": source_id,
        "user_id": user_id,
        "content": content,
        "chunk_index": chunk_index,
        "embedding_status": "pending",
    }
    new_chunk = MemoryChunk(**fields)
    db.add(new_chunk)
    return new_chunk
async def update_chunk_embedding(
    db: AsyncSession, chunk_id: str, embedding: list[float]
) -> None:
    """Store an embedding on a chunk and mark its embedding as successful.

    Does nothing when the chunk cannot be loaded. On success the chunk's
    ``embedding_status`` becomes "success" and ``embedding_error`` is cleared.
    The caller must commit.
    """
    target = await db.get(MemoryChunk, chunk_id)
    if not target:
        return
    target.embedding = embedding
    target.embedding_status = "success"
    target.embedding_error = None
async def set_chunk_embedding_status(
    db: AsyncSession,
    chunk_id: str,
    *,
    status: str,
    error: str | None = None,
) -> bool:
    """Set ``embedding_status`` / ``embedding_error`` on a chunk.

    Returns ``True`` when the chunk was found and updated, ``False`` when no
    chunk with ``chunk_id`` exists. No commit is issued here.
    """
    target = await db.get(MemoryChunk, chunk_id)
    found = target is not None
    if found:
        target.embedding_status = status
        target.embedding_error = error
    return found
async def set_source_embedding_status(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    status: str,
    error: str | None = None,
) -> bool:
    """Set ``embedding_status`` / ``embedding_error`` on a user's source.

    Returns ``False`` without modifying anything when the source does not
    exist or belongs to a different user; ``True`` after updating otherwise.
    No commit is issued here.
    """
    record = await db.get(MemorySource, source_id)
    owned_by_user = record is not None and record.user_id == user_id
    if owned_by_user:
        record.embedding_status = status
        record.embedding_error = error
        return True
    return False
async def set_source_enrichment_status(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    status: str,
    error: str | None = None,
) -> bool:
    """Set ``enrichment_status`` / ``enrichment_error`` on a user's source.

    Returns ``False`` without modifying anything when the source does not
    exist or belongs to a different user; ``True`` after updating otherwise.
    No commit is issued here.
    """
    record = await db.get(MemorySource, source_id)
    owned_by_user = record is not None and record.user_id == user_id
    if owned_by_user:
        record.enrichment_status = status
        record.enrichment_error = error
        return True
    return False
async def list_chunks_for_source(
    db: AsyncSession,
    *,
    user_id: str,
    source_id: str,
    include_excluded: bool = True,
) -> list[MemoryChunk]:
    """List a user's chunks for one source, ordered by chunk index then id.

    With ``include_excluded=False`` only chunks whose ``is_excluded`` flag is
    false or NULL are returned.
    """
    filters = [
        MemoryChunk.user_id == user_id,
        MemoryChunk.source_id == source_id,
    ]
    if not include_excluded:
        # Treat a NULL flag the same as "not excluded".
        filters.append(
            or_(MemoryChunk.is_excluded.is_(False), MemoryChunk.is_excluded.is_(None))
        )
    query = (
        select(MemoryChunk)
        .where(*filters)
        .order_by(MemoryChunk.chunk_index.asc(), MemoryChunk.id.asc())
    )
    outcome = await db.execute(query)
    return list(outcome.unique().scalars().all())
async def get_chunks_by_ids(
    db: AsyncSession, chunk_ids: list[str]
) -> list[MemoryChunk]:
    """Load chunks by id, returned in the order the ids appear in ``chunk_ids``.

    An empty ``chunk_ids`` yields ``[]`` without querying. Ids with no matching
    row are simply absent from the result.
    """
    if not chunk_ids:
        return []
    # Map each id to its position in the request (last occurrence wins).
    position = dict(zip(chunk_ids, range(len(chunk_ids))))
    outcome = await db.execute(
        select(MemoryChunk).where(MemoryChunk.id.in_(chunk_ids))
    )
    found = list(outcome.unique().scalars().all())
    found.sort(key=lambda chunk: position.get(chunk.id, 999))
    return found
async def get_facts_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[MemoryFact]:
    """Return up to ``limit`` of the user's "confirmed" facts, newest first."""
    only_confirmed = (
        MemoryFact.user_id == user_id,
        MemoryFact.status == "confirmed",
    )
    query = select(MemoryFact).where(*only_confirmed)
    query = query.order_by(MemoryFact.created_at.desc()).limit(limit)
    outcome = await db.execute(query)
    return list(outcome.unique().scalars().all())
async def search_facts_for_user_async(
    db: AsyncSession, user_id: str, query: str, limit: int = 20
) -> list[MemoryFact]:
    """Case-insensitive substring search over a user's "confirmed" facts.

    The trimmed ``query`` is matched literally against ``subject``,
    ``predicate`` and the text form of ``object_json``. Results are ordered
    newest first and capped at ``limit``.

    Returns ``[]`` without querying when ``query`` is empty, whitespace-only
    or ``None``.
    """
    needle = (query or "").strip()
    if not needle:
        return []
    # Escape LIKE metacharacters so user input such as "100%" or "a_b" is
    # matched literally instead of acting as a wildcard. An explicit escape
    # character also keeps a trailing backslash in the input from producing
    # an invalid pattern under the backend's default escape.
    escaped = needle.replace("/", "//").replace("%", "/%").replace("_", "/_")
    pattern = f"%{escaped}%"
    stmt = (
        select(MemoryFact)
        .where(
            MemoryFact.user_id == user_id,
            MemoryFact.status == "confirmed",
            or_(
                MemoryFact.subject.ilike(pattern, escape="/"),
                MemoryFact.predicate.ilike(pattern, escape="/"),
                cast(MemoryFact.object_json, SqlString).ilike(pattern, escape="/"),
            ),
        )
        .order_by(MemoryFact.created_at.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    return list(result.unique().scalars().all())
async def mark_facts_stale_for_excluded_chunk(
    db: AsyncSession, *, user_id: str, chunk_id: str
) -> int:
    """Mark a user's facts derived from ``chunk_id`` as "stale".

    Only facts currently in status "confirmed" or "candidate" are touched.
    Returns the number of rows reported as updated (0 when the driver
    reports nothing). No commit is issued here.
    """
    affected_statuses = ["confirmed", "candidate"]
    bulk_update = update(MemoryFact).where(
        MemoryFact.user_id == user_id,
        MemoryFact.source_chunk_id == chunk_id,
        MemoryFact.status.in_(affected_statuses),
    )
    outcome = await db.execute(bulk_update.values(status="stale"))
    updated = outcome.rowcount or 0
    return int(updated)
async def search_chunks_vector(
    db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20
) -> list[dict]:
    """Rank a user's chunks by vector distance to ``query_embedding``.

    Chunks flagged as excluded or lacking an embedding are skipped. Each
    result is a dict with ``id``, ``content``, ``chunk_index`` and
    ``distance`` (smaller means closer), nearest first, capped at ``limit``.
    An empty ``query_embedding`` yields ``[]`` without querying.
    """
    if not query_embedding:
        return []
    # Serialize the embedding as a "[v1,v2,...]" literal that the SQL casts
    # to the vector type.
    vector_literal = f"[{','.join(map(str, query_embedding))}]"
    bind_params = {
        "user_id": user_id,
        "emb": vector_literal,
        "emb2": vector_literal,
        "lim": limit,
    }
    # pgvector cosine distance: 1 - cosine_similarity, lower is better
    sql = text("""
SELECT id, content, chunk_index,
(embedding <=> CAST(:emb AS vector)) AS distance
FROM memory_chunks
WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
AND embedding IS NOT NULL
ORDER BY embedding <=> CAST(:emb2 AS vector)
LIMIT :lim
""")
    outcome = await db.execute(sql, bind_params)
    hits = []
    for row in outcome.mappings().all():
        hit = {key: row[key] for key in ("id", "content", "chunk_index")}
        hit["distance"] = float(row["distance"])
        hits.append(hit)
    return hits
async def list_users_with_recent_chunks(db: AsyncSession, *, hours: int) -> list[str]:
    """Return ids of users with a chunk created in the last ``hours`` hours.

    Intended for the Beat compaction scan. ``hours`` below 1 is raised to 1.
    Each user id appears at most once.
    """
    window_hours = max(hours, 1)
    threshold = datetime.now(timezone.utc) - timedelta(hours=window_hours)
    query = (
        select(MemoryChunk.user_id)
        .where(MemoryChunk.created_at >= threshold)
        .distinct()
    )
    outcome = await db.execute(query)
    return list(outcome.scalars().all())
async def list_storage_keys_for_conversation(
    db: AsyncSession, conversation_id: str
) -> list[str]:
    """Return the COS object keys recorded on a conversation's memory sources.

    Only non-NULL, non-empty ``storage_key`` values are kept; the result is
    de-duplicated and sorted. Returns ``[]`` when there are none.
    """
    query = select(MemorySource.storage_key).where(
        MemorySource.conversation_id == conversation_id,
        MemorySource.storage_key.isnot(None),
    )
    fetched = (await db.execute(query)).scalars().all()
    # filter(None, ...) drops falsy keys such as empty strings.
    distinct_keys = set(filter(None, fetched))
    return sorted(distinct_keys)
async def create_memory_summary(
    db: AsyncSession,
    *,
    user_id: str,
    summary_type: str,
    content: str,
    source_chunk_ids: list[str] | None = None,
) -> MemorySummary:
    """Build a new ``MemorySummary`` with a fresh id and add it to the session.

    The object is only added to the session; no commit is issued here.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "summary_type": summary_type,
        "content": content,
        "source_chunk_ids": source_chunk_ids,
    }
    summary = MemorySummary(**fields)
    db.add(summary)
    return summary
async def create_memory_fact(
    db: AsyncSession,
    *,
    user_id: str,
    fact_type: str,
    subject: str | None,
    predicate: str | None,
    object_json: dict | None,
    confidence: float,
    source_chunk_id: str | None,
    status: str = "confirmed",
    lineage_json: dict | None = None,
) -> MemoryFact:
    """Build a new ``MemoryFact`` with a fresh id and add it to the session.

    ``status`` defaults to "confirmed". The object is only added to the
    session; no commit is issued here.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "fact_type": fact_type,
        "subject": subject,
        "predicate": predicate,
        "object_json": object_json,
        "confidence": confidence,
        "source_chunk_id": source_chunk_id,
        "status": status,
        "lineage_json": lineage_json,
    }
    fact = MemoryFact(**fields)
    db.add(fact)
    return fact
async def get_memory_fact_for_user(
    db: AsyncSession, fact_id: str, user_id: str
) -> MemoryFact | None:
    """Load a fact by id, but only if it belongs to ``user_id``.

    Returns ``None`` when the fact does not exist or is owned by another user.
    """
    fact = await db.get(MemoryFact, fact_id)
    if fact is not None and fact.user_id == user_id:
        return fact
    return None
async def set_memory_fact_status(
    db: AsyncSession, fact_id: str, user_id: str, status: str
) -> bool:
    """Set ``status`` on a fact owned by ``user_id``.

    Returns ``True`` after updating, or ``False`` when the ownership-checked
    lookup finds no fact. No commit is issued here.
    """
    fact = await get_memory_fact_for_user(db, fact_id, user_id)
    updated = fact is not None
    if updated:
        fact.status = status
    return updated
async def create_curation_action(
    db: AsyncSession,
    *,
    user_id: str,
    action_type: str,
    target_type: str,
    target_id: str,
    details: dict | None = None,
) -> MemoryCurationAction:
    """Build a new ``MemoryCurationAction`` and add it to the session.

    The row gets a fresh id. The object is only added to the session; no
    commit is issued here.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "action_type": action_type,
        "target_type": target_type,
        "target_id": target_id,
        "details": details,
    }
    action = MemoryCurationAction(**fields)
    db.add(action)
    return action
async def get_memory_chunk_for_user(
    db: AsyncSession, chunk_id: str, user_id: str
) -> MemoryChunk | None:
    """Load a chunk by id, but only if it belongs to ``user_id``.

    Returns ``None`` when the chunk does not exist or is owned by another user.
    """
    chunk = await db.get(MemoryChunk, chunk_id)
    if chunk is not None and chunk.user_id == user_id:
        return chunk
    return None
async def list_incremental_chunks_for_compaction(
    db: AsyncSession,
    *,
    user_id: str,
    after_cursor_ts: datetime,
    after_chunk_id: str,
    limit: int,
    candidate_chunk_ids: list[str] | None = None,
    candidate_source_ids: list[str] | None = None,
) -> list[MemoryChunk]:
    """Fetch the next page of a user's non-excluded chunks after a cursor.

    The cursor is the pair ``(after_cursor_ts, after_chunk_id)``; only chunks
    whose ``(created_at, id)`` compares strictly greater are returned, in
    ascending ``(created_at, id)`` order, capped at ``limit``.

    ``candidate_chunk_ids`` / ``candidate_source_ids`` narrow the result when
    non-empty; an empty list or ``None`` applies no restriction.
    """
    cursor = tuple_(literal(after_cursor_ts), literal(after_chunk_id))
    filters = [
        MemoryChunk.user_id == user_id,
        tuple_(MemoryChunk.created_at, MemoryChunk.id) > cursor,
        # A NULL exclusion flag counts as "not excluded".
        or_(MemoryChunk.is_excluded.is_(False), MemoryChunk.is_excluded.is_(None)),
    ]
    if candidate_chunk_ids:
        filters.append(MemoryChunk.id.in_(candidate_chunk_ids))
    if candidate_source_ids:
        filters.append(MemoryChunk.source_id.in_(candidate_source_ids))
    query = (
        select(MemoryChunk)
        .where(*filters)
        .order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc())
        .limit(limit)
    )
    outcome = await db.execute(query)
    return list(outcome.unique().scalars().all())
async def get_first_chunk_after_cursor(
    db: AsyncSession,
    *,
    user_id: str,
    after_cursor_ts: datetime,
    after_chunk_id: str,
) -> MemoryChunk | None:
    """Return the user's earliest chunk strictly after the given cursor.

    Chunks are compared by ``(created_at, id)`` against
    ``(after_cursor_ts, after_chunk_id)``. No exclusion filter is applied
    here. Returns ``None`` when no chunk lies after the cursor.
    """
    cursor = tuple_(literal(after_cursor_ts), literal(after_chunk_id))
    past_cursor = tuple_(MemoryChunk.created_at, MemoryChunk.id) > cursor
    query = (
        select(MemoryChunk)
        .where(MemoryChunk.user_id == user_id, past_cursor)
        .order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc())
        .limit(1)
    )
    outcome = await db.execute(query)
    return outcome.scalars().first()
async def search_nearest_chunks_for_compaction(
    db: AsyncSession,
    *,
    user_id: str,
    chunk_id: str,
    query_embedding: list[float],
    limit: int,
) -> list[dict]:
    """Find the user's chunks nearest to ``query_embedding``, except ``chunk_id``.

    Chunks flagged as excluded or lacking an embedding are skipped. Each
    result is a dict with ``id``, ``content``, ``source_id``, ``event_year``,
    ``metadata_json``, ``source_type`` (taken from the joined source row),
    ``created_at`` and ``distance``, nearest first, capped at ``limit``.
    An empty ``query_embedding`` yields ``[]`` without querying.
    """
    if not query_embedding:
        return []
    # Serialize the embedding as a "[v1,v2,...]" literal that the SQL casts
    # to the vector type.
    vector_literal = f"[{','.join(map(str, query_embedding))}]"
    bind_params = {
        "user_id": user_id,
        "chunk_id": chunk_id,
        "emb": vector_literal,
        "emb2": vector_literal,
        "lim": limit,
    }
    sql = text("""
SELECT mc.id, mc.content, mc.source_id, mc.event_year, mc.metadata_json,
ms.source_type, mc.created_at,
(mc.embedding <=> CAST(:emb AS vector)) AS distance
FROM memory_chunks mc
JOIN memory_sources ms ON ms.id = mc.source_id
WHERE mc.user_id = :user_id
AND (mc.is_excluded IS NOT TRUE OR mc.is_excluded = false)
AND mc.embedding IS NOT NULL
AND mc.id != :chunk_id
ORDER BY mc.embedding <=> CAST(:emb2 AS vector)
LIMIT :lim
""")
    outcome = await db.execute(sql, bind_params)
    passthrough_columns = (
        "id",
        "content",
        "source_id",
        "event_year",
        "metadata_json",
        "source_type",
        "created_at",
    )
    neighbours = []
    for row in outcome.mappings().all():
        entry = {column: row[column] for column in passthrough_columns}
        entry["distance"] = float(row["distance"])
        neighbours.append(entry)
    return neighbours
async def set_chunk_excluded(
    db: AsyncSession, chunk_id: str, user_id: str, excluded: bool
) -> bool:
    """Set the ``is_excluded`` flag on a chunk owned by ``user_id``.

    Returns ``True`` after updating, or ``False`` when the ownership-checked
    lookup finds no chunk. No commit is issued here.
    """
    chunk = await get_memory_chunk_for_user(db, chunk_id, user_id)
    updated = chunk is not None
    if updated:
        chunk.is_excluded = excluded
    return updated
async def list_summaries_for_evidence_async(
    db: AsyncSession, *, user_id: str, q: str, limit: int
) -> list[dict]:
    """Case-insensitive substring search over a user's "session" summaries.

    The trimmed ``q`` is matched literally against the summary content.
    Results are ordered by ``updated_at`` descending and capped at ``limit``.
    Each result is a dict with ``id``, ``summary_type``, ``content`` and
    ``source_chunk_ids``.

    Returns ``[]`` without querying when ``q`` is empty, whitespace-only or
    ``None``.
    """
    needle = (q or "").strip()
    if not needle:
        return []
    # Escape LIKE metacharacters so user input such as "100%" or "a_b" is
    # matched literally instead of acting as a wildcard. An explicit escape
    # character also keeps a trailing backslash in the input from producing
    # an invalid pattern under the backend's default escape.
    escaped = needle.replace("/", "//").replace("%", "/%").replace("_", "/_")
    pattern = f"%{escaped}%"
    stmt = (
        select(MemorySummary)
        .where(
            MemorySummary.user_id == user_id,
            MemorySummary.summary_type == "session",
            MemorySummary.content.ilike(pattern, escape="/"),
        )
        .order_by(MemorySummary.updated_at.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    # The query already applies LIMIT, so no further slicing is needed.
    return [
        {
            "id": s.id,
            "summary_type": s.summary_type,
            "content": s.content,
            "source_chunk_ids": s.source_chunk_ids,
        }
        for s in result.unique().scalars().all()
    ]