2026-04-30 16:22:55 +08:00
|
|
|
|
"""Memory repository — MemorySource, MemoryChunk, and MemoryFact data access."""
|
2026-03-20 10:30:07 +08:00
|
|
|
|
|
|
|
|
|
|
import uuid
|
2026-04-03 11:43:16 +08:00
|
|
|
|
from datetime import datetime, timedelta, timezone
|
2026-03-20 10:30:07 +08:00
|
|
|
|
|
2026-04-30 16:22:55 +08:00
|
|
|
|
from sqlalchemy import cast, literal, or_, select, text, tuple_, update
|
2026-03-20 10:30:07 +08:00
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
2026-05-22 13:44:50 +08:00
|
|
|
|
from sqlalchemy.orm import Session
|
2026-04-30 16:22:55 +08:00
|
|
|
|
from sqlalchemy.types import String as SqlString
|
2026-03-20 10:30:07 +08:00
|
|
|
|
|
|
|
|
|
|
from app.features.memory.models import (
|
|
|
|
|
|
MemoryChunk,
|
2026-03-27 16:01:28 +08:00
|
|
|
|
MemoryCurationAction,
|
2026-03-20 10:30:07 +08:00
|
|
|
|
MemoryFact,
|
|
|
|
|
|
MemorySource,
|
2026-03-27 16:01:28 +08:00
|
|
|
|
MemorySummary,
|
2026-03-20 10:30:07 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _new_id() -> str:
    """Return a freshly generated random (version 4) UUID in string form."""
    generated = uuid.uuid4()
    return str(generated)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def create_source(
    db: AsyncSession,
    *,
    user_id: str,
    source_type: str,
    raw_text: str | None = None,
    conversation_id: str | None = None,
    segment_id: str | None = None,
    captured_at: datetime | None = None,
    lineage_json: dict | None = None,
    primary_user_message_id: str | None = None,
) -> MemorySource:
    """Stage a new ``MemorySource`` on the session and return it.

    The row gets a fresh UUID id and starts with both ``embedding_status``
    and ``enrichment_status`` set to ``"pending"``. When ``captured_at`` is
    not supplied, the current UTC time is used instead.

    Nothing is flushed or committed here; the caller must commit.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "source_type": source_type,
        "raw_text": raw_text,
        "embedding_status": "pending",
        "enrichment_status": "pending",
        "conversation_id": conversation_id,
        "segment_id": segment_id,
        "lineage_json": lineage_json,
        "primary_user_message_id": primary_user_message_id,
        # Fall back to "now" (UTC) when the caller gave no capture time.
        "captured_at": captured_at if captured_at else datetime.now(timezone.utc),
    }
    new_source = MemorySource(**fields)
    db.add(new_source)
    return new_source
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-05-22 13:44:50 +08:00
|
|
|
|
async def get_transcript_source_by_segment_id(
    db: AsyncSession,
    *,
    user_id: str,
    segment_id: str,
) -> MemorySource | None:
    """Look up a user's ``"transcript"`` source for the given segment.

    Returns the matching ``MemorySource``, or ``None`` when no row matches.
    SQLAlchemy's ``scalar_one_or_none`` raises if several rows match.
    """
    filters = (
        MemorySource.user_id == user_id,
        MemorySource.segment_id == segment_id,
        MemorySource.source_type == "transcript",
    )
    query = select(MemorySource).where(*filters)
    outcome = await db.execute(query)
    return outcome.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_transcript_source_by_segment_id_sync(
    db: Session,
    *,
    user_id: str,
    segment_id: str,
) -> MemorySource | None:
    """Synchronous lookup of a user's ``"transcript"`` source for a segment.

    Runs the same query as the async variant, but on a blocking ``Session``.
    Returns the matching ``MemorySource``, or ``None`` when no row matches.
    SQLAlchemy's ``scalar_one_or_none`` raises if several rows match.
    """
    filters = (
        MemorySource.user_id == user_id,
        MemorySource.segment_id == segment_id,
        MemorySource.source_type == "transcript",
    )
    query = select(MemorySource).where(*filters)
    outcome = db.execute(query)
    return outcome.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-20 10:30:07 +08:00
|
|
|
|
async def create_chunk(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    content: str,
    chunk_index: int,
) -> MemoryChunk:
    """Stage a new ``MemoryChunk`` on the session and return it.

    The chunk gets a fresh UUID id and starts with ``embedding_status`` set
    to ``"pending"``. Nothing is flushed or committed; the caller must commit.
    """
    fields = {
        "id": _new_id(),
        "source_id": source_id,
        "user_id": user_id,
        "content": content,
        "chunk_index": chunk_index,
        "embedding_status": "pending",
    }
    new_chunk = MemoryChunk(**fields)
    db.add(new_chunk)
    return new_chunk
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def update_chunk_embedding(
    db: AsyncSession, chunk_id: str, embedding: list[float]
) -> None:
    """Store an embedding on a chunk and mark its embedding as successful.

    Sets ``embedding_status`` to ``"success"`` and clears ``embedding_error``.
    Does nothing when no chunk with ``chunk_id`` is found. Caller must commit.
    """
    target = await db.get(MemoryChunk, chunk_id)
    if not target:
        # Unknown chunk id: nothing to update.
        return
    target.embedding = embedding
    target.embedding_status = "success"
    target.embedding_error = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def set_chunk_embedding_status(
    db: AsyncSession,
    chunk_id: str,
    *,
    status: str,
    error: str | None = None,
) -> bool:
    """Set a chunk's embedding status and error text.

    Returns ``True`` when the chunk was found and updated, ``False`` when no
    chunk has the given id. The change is not committed here.
    """
    target = await db.get(MemoryChunk, chunk_id)
    if target is not None:
        target.embedding_status = status
        target.embedding_error = error
        return True
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def set_source_embedding_status(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    status: str,
    error: str | None = None,
) -> bool:
    """Set a source's embedding status and error text, scoped to its owner.

    Returns ``False`` without changing anything when the source does not
    exist or belongs to a different user; otherwise updates the row and
    returns ``True``. The change is not committed here.
    """
    record = await db.get(MemorySource, source_id)
    owned_by_user = record is not None and record.user_id == user_id
    if not owned_by_user:
        return False
    record.embedding_status = status
    record.embedding_error = error
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def set_source_enrichment_status(
    db: AsyncSession,
    *,
    source_id: str,
    user_id: str,
    status: str,
    error: str | None = None,
) -> bool:
    """Set a source's enrichment status and error text, scoped to its owner.

    Returns ``False`` without changing anything when the source does not
    exist or belongs to a different user; otherwise updates the row and
    returns ``True``. The change is not committed here.
    """
    record = await db.get(MemorySource, source_id)
    owned_by_user = record is not None and record.user_id == user_id
    if not owned_by_user:
        return False
    record.enrichment_status = status
    record.enrichment_error = error
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def list_chunks_for_source(
    db: AsyncSession,
    *,
    user_id: str,
    source_id: str,
    include_excluded: bool = True,
) -> list[MemoryChunk]:
    """List a user's chunks for one source, ordered by chunk index then id.

    When ``include_excluded`` is false, only chunks whose ``is_excluded`` is
    false or NULL are returned.
    """
    criteria = [
        MemoryChunk.user_id == user_id,
        MemoryChunk.source_id == source_id,
    ]
    if not include_excluded:
        # A NULL flag is treated the same as "not excluded".
        criteria.append(
            or_(MemoryChunk.is_excluded.is_(False), MemoryChunk.is_excluded.is_(None))
        )
    query = (
        select(MemoryChunk)
        .where(*criteria)
        .order_by(MemoryChunk.chunk_index.asc(), MemoryChunk.id.asc())
    )
    rows = await db.execute(query)
    return list(rows.unique().scalars().all())
|
2026-03-20 10:30:07 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_chunks_by_ids(
    db: AsyncSession, chunk_ids: list[str]
) -> list[MemoryChunk]:
    """Fetch chunks by id, returned in the order the ids were requested.

    Args:
        db: Session used to run the query.
        chunk_ids: Ids to load. An empty list returns ``[]`` without querying.

    Returns:
        The chunks that exist, sorted by the position of their id in
        ``chunk_ids``. Ids with no matching row are simply absent.

    NOTE(review): the query is not filtered by user; callers are responsible
    for making sure the ids belong to the right user.
    """
    if not chunk_ids:
        return []
    stmt = select(MemoryChunk).where(MemoryChunk.id.in_(chunk_ids))
    result = await db.execute(stmt)
    chunks = list(result.unique().scalars().all())
    # Position of each requested id; for a repeated id the last position wins.
    order = {cid: i for i, cid in enumerate(chunk_ids)}
    # A row whose id is not in the map must sort after every requested id.
    # len(chunk_ids) guarantees that for any list size, unlike a fixed number.
    fallback = len(chunk_ids)
    return sorted(chunks, key=lambda c: order.get(c.id, fallback))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_facts_for_user(
    db: AsyncSession, user_id: str, limit: int = 20
) -> list[MemoryFact]:
    """Return up to ``limit`` of the user's confirmed facts, newest first."""
    confirmed_for_user = (
        MemoryFact.user_id == user_id,
        MemoryFact.status == "confirmed",
    )
    query = select(MemoryFact).where(*confirmed_for_user)
    query = query.order_by(MemoryFact.created_at.desc()).limit(limit)
    rows = await db.execute(query)
    return list(rows.unique().scalars().all())
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-27 16:01:28 +08:00
|
|
|
|
async def search_facts_for_user_async(
    db: AsyncSession, user_id: str, query: str, limit: int = 20
) -> list[MemoryFact]:
    """Case-insensitive substring search over a user's confirmed facts.

    A fact matches when the search text occurs in its ``subject``, its
    ``predicate``, or the string form of its ``object_json``.

    Args:
        db: Session used to run the query.
        user_id: Owner of the facts to search.
        query: Text to look for. It is stripped first; an empty, blank or
            ``None`` value returns ``[]`` without querying.
        limit: Maximum number of facts to return.

    Returns:
        Matching facts, newest first.
    """
    q = (query or "").strip()
    if not q:
        return []
    # Escape the LIKE metacharacters so the text is matched literally;
    # otherwise "%" and "_" typed by the user would act as wildcards and a
    # backslash would act as an escape character.
    escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    pat = f"%{escaped}%"
    stmt = (
        select(MemoryFact)
        .where(
            MemoryFact.user_id == user_id,
            MemoryFact.status == "confirmed",
            or_(
                MemoryFact.subject.ilike(pat, escape="\\"),
                MemoryFact.predicate.ilike(pat, escape="\\"),
                cast(MemoryFact.object_json, SqlString).ilike(pat, escape="\\"),
            ),
        )
        .order_by(MemoryFact.created_at.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    return list(result.unique().scalars().all())
|
2026-04-03 10:12:59 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-04-30 14:11:46 +08:00
|
|
|
|
async def mark_facts_stale_for_excluded_chunk(
    db: AsyncSession, *, user_id: str, chunk_id: str
) -> int:
    """Mark a user's facts derived from one chunk as ``"stale"``.

    Only facts currently in status ``"confirmed"`` or ``"candidate"`` are
    touched. Returns the row count reported for the UPDATE, or 0 when the
    driver reports nothing. The change is not committed here.
    """
    stale_update = (
        update(MemoryFact)
        .where(
            MemoryFact.user_id == user_id,
            MemoryFact.source_chunk_id == chunk_id,
            MemoryFact.status.in_(["confirmed", "candidate"]),
        )
        .values(status="stale")
    )
    outcome = await db.execute(stale_update)
    affected = outcome.rowcount
    return int(affected) if affected else 0
|
2026-03-27 16:01:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-03-20 10:30:07 +08:00
|
|
|
|
async def search_chunks_vector(
    db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20
) -> list[dict]:
    """Find a user's chunks closest to an embedding.

    Chunks that are excluded or have no embedding are skipped. Each result is
    a dict with keys ``id``, ``content``, ``chunk_index`` and ``distance``,
    ordered by ascending distance. An empty ``query_embedding`` returns ``[]``
    without querying.
    """
    if not query_embedding:
        return []
    # The pgvector "<=>" operator yields cosine distance
    # (1 - cosine_similarity), so smaller values are closer matches.
    sql = text("""
        SELECT id, content, chunk_index,
               (embedding <=> CAST(:emb AS vector)) AS distance
        FROM memory_chunks
        WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
          AND embedding IS NOT NULL
        ORDER BY embedding <=> CAST(:emb2 AS vector)
        LIMIT :lim
    """)
    # The vector is sent as its text form "[v1,v2,...]" and cast in SQL.
    vector_literal = "[" + ",".join(str(x) for x in query_embedding) + "]"
    params = {
        "user_id": user_id,
        "emb": vector_literal,
        "emb2": vector_literal,
        "lim": limit,
    }
    outcome = await db.execute(sql, params)
    hits = []
    for row in outcome.mappings().all():
        hits.append(
            {
                "id": row["id"],
                "content": row["content"],
                "chunk_index": row["chunk_index"],
                "distance": float(row["distance"]),
            }
        )
    return hits
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-30 14:11:46 +08:00
|
|
|
|
async def list_users_with_recent_chunks(db: AsyncSession, *, hours: int) -> list[str]:
    """Ids of users with new chunks in the last N hours (Beat compaction scan).

    ``hours`` below 1 is raised to 1. The window is measured back from the
    current UTC time, and each user id appears at most once.
    """
    window_hours = max(hours, 1)
    oldest_allowed = datetime.now(timezone.utc) - timedelta(hours=window_hours)
    query = (
        select(MemoryChunk.user_id)
        .where(MemoryChunk.created_at >= oldest_allowed)
        .distinct()
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
|
2026-03-20 10:30:07 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-03-20 15:15:35 +08:00
|
|
|
|
async def list_storage_keys_for_conversation(
    db: AsyncSession, conversation_id: str
) -> list[str]:
    """COS object keys recorded on the conversation's memory_sources, if any.

    Returns the distinct non-empty ``storage_key`` values in sorted order.
    """
    query = select(MemorySource.storage_key).where(
        MemorySource.conversation_id == conversation_id,
        MemorySource.storage_key.isnot(None),
    )
    rows = await db.execute(query)
    # filter(None, ...) drops empty keys; the set removes duplicates.
    distinct_keys = set(filter(None, rows.scalars().all()))
    return sorted(distinct_keys)
|
2026-03-27 16:01:28 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def create_memory_summary(
    db: AsyncSession,
    *,
    user_id: str,
    summary_type: str,
    content: str,
    source_chunk_ids: list[str] | None = None,
) -> MemorySummary:
    """Stage a new ``MemorySummary`` on the session and return it.

    The row gets a fresh UUID id. Nothing is flushed or committed here.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "summary_type": summary_type,
        "content": content,
        "source_chunk_ids": source_chunk_ids,
    }
    summary = MemorySummary(**fields)
    db.add(summary)
    return summary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def create_memory_fact(
    db: AsyncSession,
    *,
    user_id: str,
    fact_type: str,
    subject: str | None,
    predicate: str | None,
    object_json: dict | None,
    confidence: float,
    source_chunk_id: str | None,
    status: str = "confirmed",
    lineage_json: dict | None = None,
) -> MemoryFact:
    """Stage a new ``MemoryFact`` on the session and return it.

    The row gets a fresh UUID id; ``status`` defaults to ``"confirmed"``.
    Nothing is flushed or committed here.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "fact_type": fact_type,
        "subject": subject,
        "predicate": predicate,
        "object_json": object_json,
        "confidence": confidence,
        "source_chunk_id": source_chunk_id,
        "status": status,
        "lineage_json": lineage_json,
    }
    fact = MemoryFact(**fields)
    db.add(fact)
    return fact
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_memory_fact_for_user(
    db: AsyncSession, fact_id: str, user_id: str
) -> MemoryFact | None:
    """Load a fact by id, but only if it belongs to ``user_id``.

    Returns ``None`` when the fact does not exist or is owned by another user.
    """
    fact = await db.get(MemoryFact, fact_id)
    if fact is not None and fact.user_id == user_id:
        return fact
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def set_memory_fact_status(
    db: AsyncSession, fact_id: str, user_id: str, status: str
) -> bool:
    """Change the status of a fact owned by ``user_id``.

    Returns ``True`` when the fact was found and updated, ``False`` when it
    does not exist or belongs to another user. The change is not committed.
    """
    fact = await get_memory_fact_for_user(db, fact_id, user_id)
    if fact is not None:
        fact.status = status
        return True
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def create_curation_action(
    db: AsyncSession,
    *,
    user_id: str,
    action_type: str,
    target_type: str,
    target_id: str,
    details: dict | None = None,
) -> MemoryCurationAction:
    """Stage a new ``MemoryCurationAction`` on the session and return it.

    The row gets a fresh UUID id. Nothing is flushed or committed here.
    """
    fields = {
        "id": _new_id(),
        "user_id": user_id,
        "action_type": action_type,
        "target_type": target_type,
        "target_id": target_id,
        "details": details,
    }
    action = MemoryCurationAction(**fields)
    db.add(action)
    return action
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_memory_chunk_for_user(
    db: AsyncSession, chunk_id: str, user_id: str
) -> MemoryChunk | None:
    """Load a chunk by id, but only if it belongs to ``user_id``.

    Returns ``None`` when the chunk does not exist or is owned by another user.
    """
    chunk = await db.get(MemoryChunk, chunk_id)
    if chunk is not None and chunk.user_id == user_id:
        return chunk
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-30 14:11:46 +08:00
|
|
|
|
async def list_incremental_chunks_for_compaction(
    db: AsyncSession,
    *,
    user_id: str,
    after_cursor_ts: datetime,
    after_chunk_id: str,
    limit: int,
    candidate_chunk_ids: list[str] | None = None,
    candidate_source_ids: list[str] | None = None,
) -> list[MemoryChunk]:
    """Page through a user's non-excluded chunks after a cursor.

    The cursor is the pair ``(created_at, id)``: only chunks strictly after
    ``(after_cursor_ts, after_chunk_id)`` are returned, ordered by
    ``created_at`` then ``id`` ascending, at most ``limit`` rows.

    ``candidate_chunk_ids`` and ``candidate_source_ids`` narrow the result to
    those chunk ids or source ids; ``None`` or an empty list applies no filter.
    """
    cursor = tuple_(literal(after_cursor_ts), literal(after_chunk_id))
    criteria = [
        MemoryChunk.user_id == user_id,
        tuple_(MemoryChunk.created_at, MemoryChunk.id) > cursor,
        # A NULL flag is treated the same as "not excluded".
        or_(MemoryChunk.is_excluded.is_(False), MemoryChunk.is_excluded.is_(None)),
    ]
    if candidate_chunk_ids:
        criteria.append(MemoryChunk.id.in_(candidate_chunk_ids))
    if candidate_source_ids:
        criteria.append(MemoryChunk.source_id.in_(candidate_source_ids))
    query = (
        select(MemoryChunk)
        .where(*criteria)
        .order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc())
        .limit(limit)
    )
    rows = await db.execute(query)
    return list(rows.unique().scalars().all())
|
2026-03-30 10:46:35 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-04-30 14:11:46 +08:00
|
|
|
|
async def get_first_chunk_after_cursor(
    db: AsyncSession,
    *,
    user_id: str,
    after_cursor_ts: datetime,
    after_chunk_id: str,
) -> MemoryChunk | None:
    """Return the user's first chunk strictly after a ``(created_at, id)`` cursor.

    Chunks are ordered by ``created_at`` then ``id`` ascending. This query
    does not filter on ``is_excluded``. Returns ``None`` when no chunk
    follows the cursor.
    """
    cursor = tuple_(literal(after_cursor_ts), literal(after_chunk_id))
    position = tuple_(MemoryChunk.created_at, MemoryChunk.id)
    query = (
        select(MemoryChunk)
        .where(MemoryChunk.user_id == user_id, position > cursor)
        .order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc())
        .limit(1)
    )
    rows = await db.execute(query)
    return rows.scalars().first()
|
2026-03-30 10:46:35 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-04-30 14:11:46 +08:00
|
|
|
|
async def search_nearest_chunks_for_compaction(
    db: AsyncSession,
    *,
    user_id: str,
    chunk_id: str,
    query_embedding: list[float],
    limit: int,
) -> list[dict]:
    """Find the user's chunks nearest to an embedding, excluding one chunk.

    Joins each chunk to its source to report ``source_type``. Chunks that are
    excluded, have no embedding, or have id ``chunk_id`` are skipped. Each
    result is a dict with keys ``id``, ``content``, ``source_id``,
    ``event_year``, ``metadata_json``, ``source_type``, ``created_at`` and
    ``distance``, ordered by ascending distance. An empty ``query_embedding``
    returns ``[]`` without querying.
    """
    if not query_embedding:
        return []
    sql = text("""
        SELECT mc.id, mc.content, mc.source_id, mc.event_year, mc.metadata_json,
               ms.source_type, mc.created_at,
               (mc.embedding <=> CAST(:emb AS vector)) AS distance
        FROM memory_chunks mc
        JOIN memory_sources ms ON ms.id = mc.source_id
        WHERE mc.user_id = :user_id
          AND (mc.is_excluded IS NOT TRUE OR mc.is_excluded = false)
          AND mc.embedding IS NOT NULL
          AND mc.id != :chunk_id
        ORDER BY mc.embedding <=> CAST(:emb2 AS vector)
        LIMIT :lim
    """)
    # The vector is sent as its text form "[v1,v2,...]" and cast in SQL.
    vector_literal = "[" + ",".join(str(x) for x in query_embedding) + "]"
    params = {
        "user_id": user_id,
        "chunk_id": chunk_id,
        "emb": vector_literal,
        "emb2": vector_literal,
        "lim": limit,
    }
    outcome = await db.execute(sql, params)
    # Columns copied through unchanged; "distance" is coerced to float below.
    passthrough = (
        "id",
        "content",
        "source_id",
        "event_year",
        "metadata_json",
        "source_type",
        "created_at",
    )
    neighbours = []
    for row in outcome.mappings().all():
        item = {name: row[name] for name in passthrough}
        item["distance"] = float(row["distance"])
        neighbours.append(item)
    return neighbours
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-27 16:01:28 +08:00
|
|
|
|
async def set_chunk_excluded(
    db: AsyncSession, chunk_id: str, user_id: str, excluded: bool
) -> bool:
    """Set the ``is_excluded`` flag on a chunk owned by ``user_id``.

    Returns ``True`` when the chunk was found and updated, ``False`` when it
    does not exist or belongs to another user. The change is not committed.
    """
    chunk = await get_memory_chunk_for_user(db, chunk_id, user_id)
    if chunk is not None:
        chunk.is_excluded = excluded
        return True
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def list_summaries_for_evidence_async(
    db: AsyncSession, *, user_id: str, q: str, limit: int
) -> list[dict]:
    """Case-insensitive substring search over a user's ``"session"`` summaries.

    Args:
        db: Session used to run the query.
        user_id: Owner of the summaries to search.
        q: Text to look for in the summary content. It is stripped first; an
            empty, blank or ``None`` value returns ``[]`` without querying.
        limit: Maximum number of summaries to return.

    Returns:
        One dict per match with keys ``id``, ``summary_type``, ``content``
        and ``source_chunk_ids``, most recently updated first.
    """
    needle = (q or "").strip()
    if not needle:
        return []
    # Escape the LIKE metacharacters so the text is matched literally;
    # otherwise "%" and "_" typed by the user would act as wildcards and a
    # backslash would act as an escape character.
    escaped = needle.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    pat = f"%{escaped}%"
    stmt = (
        select(MemorySummary)
        .where(
            MemorySummary.user_id == user_id,
            MemorySummary.summary_type == "session",
            MemorySummary.content.ilike(pat, escape="\\"),
        )
        .order_by(MemorySummary.updated_at.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    # The query already applies LIMIT, so no extra slicing is needed here.
    return [
        {
            "id": s.id,
            "summary_type": s.summary_type,
            "content": s.content,
            "source_chunk_ids": s.source_chunk_ids,
        }
        for s in result.unique().scalars().all()
    ]
|