Files
life-echo/api/app/features/memory/repo.py
2026-03-20 15:15:35 +08:00

320 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Memory repository — MemorySource, MemoryChunk, MemoryFact, TimelineEvent data access."""
import uuid
from datetime import datetime, timezone
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.features.memory.models import (
MemoryChunk,
MemoryFact,
MemorySource,
TimelineEvent,
)
def _new_id() -> str:
return str(uuid.uuid4())
def create_source_sync(
    session: Session,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    captured_at: datetime | None = None,
) -> MemorySource:
    """Stage a new MemorySource on a synchronous session.

    The object is only added to ``session``; committing is the caller's job.
    ``captured_at`` defaults to the current UTC time when not supplied.
    Returns the (uncommitted) MemorySource instance.
    """
    when = captured_at or datetime.now(timezone.utc)
    record = MemorySource(
        id=_new_id(),
        user_id=user_id,
        source_type=source_type,
        raw_text=raw_text,
        conversation_id=conversation_id,
        captured_at=when,
    )
    session.add(record)
    return record
async def create_source(
    db: AsyncSession,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    captured_at: datetime | None = None,
) -> MemorySource:
    """Stage a new MemorySource on an async session.

    The object is only added to ``db``; committing is the caller's job.
    ``captured_at`` defaults to the current UTC time when not supplied.
    Returns the (uncommitted) MemorySource instance.
    """
    when = captured_at or datetime.now(timezone.utc)
    record = MemorySource(
        id=_new_id(),
        user_id=user_id,
        source_type=source_type,
        raw_text=raw_text,
        conversation_id=conversation_id,
        captured_at=when,
    )
    db.add(record)
    return record
def create_chunk_sync(
    session: Session,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Stage a new MemoryChunk belonging to ``source_id`` on a sync session.

    The object is only added to ``session``; the caller must commit.
    Returns the (uncommitted) MemoryChunk instance.
    """
    piece = MemoryChunk(
        id=_new_id(),
        source_id=source_id,
        user_id=user_id,
        content=content,
        chunk_index=chunk_index,
    )
    session.add(piece)
    return piece
async def create_chunk(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Stage a new MemoryChunk belonging to ``source_id`` on an async session.

    The object is only added to ``db``; the caller must commit.
    Returns the (uncommitted) MemoryChunk instance.
    """
    piece = MemoryChunk(
        id=_new_id(),
        source_id=source_id,
        user_id=user_id,
        content=content,
        chunk_index=chunk_index,
    )
    db.add(piece)
    return piece
def update_chunk_fts_sync(session: Session, chunk_id: str) -> None:
    """Recompute ``content_tsv`` (the FTS vector) for one chunk on a sync session.

    Runs an UPDATE against memory_chunks using the 'simple' text-search
    configuration. Nothing is committed here; the caller must commit.
    """
    refresh_tsv = text(
        "UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
    )
    params = {"id": chunk_id}
    session.execute(refresh_tsv, params)
async def update_chunk_embedding(
    db: AsyncSession, chunk_id: str, embedding: list[float]
) -> None:
    """Set the ``embedding`` attribute of the chunk with id ``chunk_id``.

    Silently does nothing when no such chunk is found. Caller must commit.
    """
    target = await db.get(MemoryChunk, chunk_id)
    if not target:
        return
    target.embedding = embedding
async def update_chunk_fts(db: AsyncSession, chunk_id: str) -> None:
    """Recompute ``content_tsv`` (the FTS vector) for one chunk on an async session.

    Runs an UPDATE against memory_chunks using the 'simple' text-search
    configuration. Nothing is committed here; the caller must commit.
    """
    refresh_tsv = text(
        "UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
    )
    params = {"id": chunk_id}
    await db.execute(refresh_tsv, params)
async def search_chunks_fts(
    db: AsyncSession, user_id: str, query: str, limit: int = 20
) -> list[dict]:
    """Full-text search over one user's memory chunks.

    Matches ``query`` with ``plainto_tsquery('simple', ...)`` against the
    precomputed ``content_tsv`` column and ranks hits by ``ts_rank_cd``
    (best first). Chunks flagged ``is_excluded`` and chunks whose
    ``content_tsv`` has not been populated are skipped.

    Args:
        db: Async session used to run the query.
        user_id: Owner whose chunks are searched.
        query: Free-text query; surrounding whitespace is ignored.
        limit: Maximum number of hits.

    Returns:
        A list of ``{"id", "content", "chunk_index"}`` dicts. The list is
        empty (and no query is issued) for a blank query or ``limit <= 0``.
    """
    q = (query or "").strip()
    # Postgres rejects a negative LIMIT, and LIMIT 0 can never return rows.
    if not q or limit <= 0:
        return []
    # `is_excluded IS NOT TRUE` already admits both FALSE and NULL.
    stmt = text("""
        SELECT id, content, chunk_index
        FROM memory_chunks
        WHERE user_id = :user_id AND is_excluded IS NOT TRUE
          AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
        ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
        LIMIT :lim
    """)
    result = await db.execute(
        stmt, {"user_id": user_id, "q": q, "q2": q, "lim": limit}
    )
    return [
        {"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]}
        for r in result.mappings().all()
    ]
def _order_by_requested(
    chunks: list[MemoryChunk], chunk_ids: list[str]
) -> list[MemoryChunk]:
    """Sort ``chunks`` by the position where each chunk's id first appears in ``chunk_ids``.

    Using the first occurrence keeps a duplicated id at its earliest (best)
    position. Chunks whose id is not in ``chunk_ids`` sort after all others.
    """
    position: dict[str, int] = {}
    for idx, cid in enumerate(chunk_ids):
        position.setdefault(cid, idx)
    missing = len(chunk_ids)
    return sorted(chunks, key=lambda chunk: position.get(chunk.id, missing))


async def get_chunks_by_ids(
    db: AsyncSession, chunk_ids: list[str]
) -> list[MemoryChunk]:
    """Fetch chunks by id, returned in the order the ids were requested.

    Ids that match no row are simply absent from the result; a duplicated id
    yields its chunk once, at the id's first position. An empty ``chunk_ids``
    returns ``[]`` without querying.
    """
    if not chunk_ids:
        return []
    stmt = select(MemoryChunk).where(MemoryChunk.id.in_(chunk_ids))
    result = await db.execute(stmt)
    chunks = list(result.unique().scalars().all())
    return _order_by_requested(chunks, chunk_ids)
async def get_facts_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[MemoryFact]:
    """Return up to ``limit`` of the user's confirmed facts, newest first by created_at."""
    is_confirmed = MemoryFact.status == "confirmed"
    newest_first = MemoryFact.created_at.desc()
    stmt = (
        select(MemoryFact)
        .where(MemoryFact.user_id == user_id, is_confirmed)
        .order_by(newest_first)
        .limit(limit)
    )
    rows = (await db.execute(stmt)).unique().scalars().all()
    return list(rows)
async def search_chunks_vector(
    db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20
) -> list[dict]:
    """Nearest-neighbour search over one user's chunk embeddings.

    Orders the user's non-excluded, embedded chunks by pgvector cosine
    distance (``<=>``; 1 - cosine similarity, lower is closer) to
    ``query_embedding``.

    Args:
        db: Async session used to run the query.
        user_id: Owner whose chunks are searched.
        query_embedding: Query vector; sent to Postgres as a ``'[x,y,...]'``
            literal and cast to ``vector``.
        limit: Maximum number of hits.

    Returns:
        A list of ``{"id", "content", "chunk_index", "distance"}`` dicts,
        closest first. Empty (and no query issued) when ``query_embedding``
        is empty or ``limit <= 0``.
    """
    # Postgres rejects a negative LIMIT, and LIMIT 0 can never return rows.
    if not query_embedding or limit <= 0:
        return []
    # Use CAST(... AS vector) rather than `:emb::vector`: SQLAlchemy's text()
    # bind-parameter parser does not handle a `::` cast placed directly after
    # a parameter name (the parameter name gets mis-parsed), so binding fails.
    stmt = text("""
        SELECT id, content, chunk_index,
               (embedding <=> CAST(:emb AS vector)) AS distance
        FROM memory_chunks
        WHERE user_id = :user_id AND is_excluded IS NOT TRUE
          AND embedding IS NOT NULL
        ORDER BY embedding <=> CAST(:emb2 AS vector)
        LIMIT :lim
    """)
    emb_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
    result = await db.execute(
        stmt,
        {"user_id": user_id, "emb": emb_str, "emb2": emb_str, "lim": limit},
    )
    return [
        {
            "id": r["id"],
            "content": r["content"],
            "chunk_index": r["chunk_index"],
            "distance": float(r["distance"]),
        }
        for r in result.mappings().all()
    ]
def _empty_evidence_bundle() -> dict:
    """Return a fresh evidence bundle with every section set to an empty list."""
    return {
        "relevant_chunks": [],
        "relevant_summaries": [],
        "relevant_facts": [],
        "timeline_hints": [],
        "relevant_stories": [],
    }


def _fts_chunk_hits_sync(
    session: Session, user_id: str, q: str, limit: int
) -> list[dict]:
    """Return the user's non-excluded chunks matching ``q`` via FTS, best ts_rank_cd first."""
    # `is_excluded IS NOT TRUE` already admits both FALSE and NULL.
    stmt = text("""
        SELECT id, content, chunk_index
        FROM memory_chunks
        WHERE user_id = :user_id AND is_excluded IS NOT TRUE
          AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
        ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
        LIMIT :lim
    """)
    rows = session.execute(
        stmt, {"user_id": user_id, "q": q, "q2": q, "lim": limit}
    ).mappings().all()
    return [
        {"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]}
        for r in rows
    ]


def _recent_fact_hints_sync(session: Session, user_id: str, limit: int) -> list[dict]:
    """Return the user's most recently created confirmed facts as plain dicts."""
    stmt = (
        select(MemoryFact)
        .where(MemoryFact.user_id == user_id, MemoryFact.status == "confirmed")
        .order_by(MemoryFact.created_at.desc())
        .limit(limit)
    )
    facts = session.execute(stmt).unique().scalars().all()
    return [
        {
            "id": f.id,
            "fact_type": f.fact_type,
            "subject": f.subject,
            "predicate": f.predicate,
            "object_json": f.object_json,
        }
        for f in facts
    ]


def _timeline_hints_sync(session: Session, user_id: str, limit: int) -> list[dict]:
    """Return the user's timeline events (latest year first) as plain dicts."""
    stmt = (
        select(TimelineEvent)
        .where(TimelineEvent.user_id == user_id)
        # created_at tiebreak (as in get_timeline_events_for_user) keeps the
        # order deterministic for events sharing a year or with no year.
        .order_by(
            TimelineEvent.event_year.desc().nullslast(), TimelineEvent.created_at.desc()
        )
        .limit(limit)
    )
    events = session.execute(stmt).unique().scalars().all()
    return [
        {
            "id": e.id,
            "event_year": e.event_year,
            "event_date": e.event_date,
            "title": e.title,
            "description": e.description,
        }
        for e in events
    ]


def retrieve_evidence_sync(
    session: Session, user_id: str, query: str, *, top_k: int = 10
) -> dict:
    """
    Sync evidence retrieval for Celery tasks.

    FTS only (no vector search). Builds a bundle with these keys:

    - ``relevant_chunks``: up to ``top_k`` chunks matching ``query`` via FTS.
    - ``relevant_facts``: up to ``top_k`` most recently created confirmed
      facts (not filtered by ``query``).
    - ``timeline_hints``: up to ``top_k`` timeline events, latest year first
      (not filtered by ``query``).
    - ``relevant_summaries`` / ``relevant_stories``: always empty here.

    A blank ``query`` or ``top_k <= 0`` returns the all-empty bundle without
    touching the database.
    """
    q = (query or "").strip()
    bundle = _empty_evidence_bundle()
    # Postgres rejects a negative LIMIT; LIMIT 0 would return nothing anyway.
    if not q or top_k <= 0:
        return bundle
    bundle["relevant_chunks"] = _fts_chunk_hits_sync(session, user_id, q, top_k)
    bundle["relevant_facts"] = _recent_fact_hints_sync(session, user_id, top_k)
    bundle["timeline_hints"] = _timeline_hints_sync(session, user_id, top_k)
    return bundle
async def get_timeline_events_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[TimelineEvent]:
    """Return up to ``limit`` timeline events for the user.

    Ordered by event_year descending (events without a year last), then by
    created_at descending.
    """
    by_year = TimelineEvent.event_year.desc().nullslast()
    by_creation = TimelineEvent.created_at.desc()
    stmt = (
        select(TimelineEvent)
        .where(TimelineEvent.user_id == user_id)
        .order_by(by_year, by_creation)
        .limit(limit)
    )
    rows = (await db.execute(stmt)).unique().scalars().all()
    return list(rows)
async def list_storage_keys_for_conversation(
    db: AsyncSession, conversation_id: str
) -> list[str]:
    """Return the COS object keys recorded on memory_sources linked to a conversation, if any.

    NULL and empty keys are dropped; the result is de-duplicated and sorted.
    """
    stmt = select(MemorySource.storage_key).where(
        MemorySource.conversation_id == conversation_id,
        MemorySource.storage_key.isnot(None),
    )
    keys = (await db.execute(stmt)).scalars().all()
    unique_keys = {key for key in keys if key}
    return sorted(unique_keys)