聊天和回忆录证据检索都走 pgvector,去掉 Postgres FTS/content_tsv,新迁移删掉 content_tsv 列(部署要先 alembic upgrade)。

Embedding 端口增加 is_available(),聊天和回忆录日志用统一方式表示向量是否真能调用。

记忆整理(compaction)支持 Beat 定期扫用户;

事实抽取提示与 subject 归一化,减少同一人多种称呼;
This commit is contained in:
Kevin
2026-04-03 11:43:16 +08:00
parent b853b986dd
commit 41518bda11
26 changed files with 543 additions and 222 deletions

View File

@@ -7,6 +7,9 @@ import asyncio
from zai import ZhipuAiClient
from app.core.embedding import MEMORY_EMBEDDING_DIMENSION
from app.core.logging import get_logger
_logger = get_logger(__name__)
# 单次请求最多 64 条文本(智谱 Embedding-3 文档)
_EMBED_BATCH_SIZE = 64
@@ -22,6 +25,9 @@ class ZhipuEmbeddingProvider:
) -> None:
self._model = model
if not api_key:
_logger.warning(
"ZhipuEmbeddingProvider: api_key 为空embedding 将不可用(记忆检索与 ingest 向量写入会降级)"
)
self._client = None
elif base_url:
self._client = ZhipuAiClient(
@@ -31,6 +37,9 @@ class ZhipuEmbeddingProvider:
else:
self._client = ZhipuAiClient(api_key=api_key)
def is_available(self) -> bool:
return self._client is not None
def _create_vectors_sync(self, texts: list[str]) -> list[list[float]]:
assert self._client is not None
resp = self._client.embeddings.create(
@@ -54,3 +63,16 @@ class ZhipuEmbeddingProvider:
part = await asyncio.to_thread(self._create_vectors_sync, batch)
out.extend(part)
return out
def embed_text_sync(self, text: str) -> list[float]:
vecs = self.embed_texts_sync([text])
return vecs[0] if vecs else []
def embed_texts_sync(self, texts: list[str]) -> list[list[float]]:
if not self._client or not texts:
return []
out: list[list[float]] = []
for i in range(0, len(texts), _EMBED_BATCH_SIZE):
batch = texts[i : i + _EMBED_BATCH_SIZE]
out.extend(self._create_vectors_sync(batch))
return out

View File

@@ -68,9 +68,22 @@ async def _fetch_interview_memory_evidence(
):
return ""
try:
ms = MemoryService(db, embedding_provider=get_embedding_provider())
emb = get_embedding_provider()
ms = MemoryService(db, embedding_provider=emb)
bundle = await ms.retrieve(user_id, msg, top_k=settings.chat_memory_top_k)
text = format_evidence_chunks_for_prompt(bundle.model_dump())
bd = bundle.model_dump()
vector_ok = emb.is_available()
logger.info(
"memory_evidence_retrieved user_id={} chunks={} facts={} summaries={} timeline={} stories={} vector_ok={}",
user_id,
len(bd.get("relevant_chunks") or []),
len(bd.get("relevant_facts") or []),
len(bd.get("relevant_summaries") or []),
len(bd.get("timeline_hints") or []),
len(bd.get("relevant_stories") or []),
vector_ok,
)
text = format_evidence_chunks_for_prompt(bd)
t = (text or "").strip()
if not t:
return ""

View File

@@ -267,16 +267,16 @@ class Settings(BaseSettings):
memoir_phase2_singleflight_immediate: bool = True
# ── Memory 检索与富化 ─────────────────────────────────────
# Truequery 为空时仍返回 rolling 摘要 + 最近事实/时间线(无 chunk FTS
# Truequery 为空时仍返回 rolling 摘要 + 最近事实/时间线(无 chunk 向量检索
memory_evidence_empty_query_include_rolling: bool = False
# False跳过 ingest 后 LLM 富化(摘要/事实/时间线)
memory_enrichment_enabled: bool = True
memory_enrichment_max_chars: int = Field(default=12000, ge=1000, le=100_000)
# True事实 FTS 未命中时退回「最近 confirmed 事实」(易引入无关/矛盾事实;默认关)
# True事实 ILIKE 未命中时退回「最近 confirmed 事实」(易引入无关/矛盾事实;默认关)
memory_fact_search_use_recent_fallback: bool = False
# ── Memory compaction近重复 chunk 软排除;事件触发 + Redis 防抖 + 用户锁)──
memory_compaction_enabled: bool = False
# ── Memory compaction近重复 chunk 软排除;事件触发 + Redis 防抖 + 用户锁;需 worker + Beat 跑 sweep)──
memory_compaction_enabled: bool = True
memory_compaction_debounce_seconds: int = Field(default=105, ge=10, le=3600)
memory_compaction_lock_ttl_seconds: int = Field(default=600, ge=60, le=7200)
memory_compaction_chunk_similarity_threshold: float = Field(
@@ -288,6 +288,8 @@ class Settings(BaseSettings):
memory_compaction_max_neighbors_per_chunk: int = Field(default=25, ge=5, le=100)
memory_compaction_text_jaccard_min: float = Field(default=0.55, ge=0.0, le=1.0)
memory_compaction_metadata_event_year_window: int = Field(default=1, ge=0, le=50)
# Beat sweep扫描最近 N 小时内有新 chunk 的用户并调度 compaction
memory_compaction_sweep_recent_hours: int = Field(default=24, ge=1, le=168)
# ── Liblib ───────────────────────────────────────────────
liblib_access_key: str = ""

View File

@@ -74,7 +74,7 @@ class MemoirService:
self._object_storage = object_storage
async def get_evidence(self, user_id: str, query: str, *, top_k: int = 10) -> dict:
"""通过 MemoryService → HybridRetriever 获取证据(向量与 Celery 的 FTS-only 路径不同)。"""
"""通过 MemoryService → HybridRetriever 获取证据(向量 chunks与 Celery 叙事路径一致)。"""
if self._memory is None:
return {
"relevant_chunks": [],

View File

@@ -33,6 +33,7 @@ from app.agents.memoir.story_route_agent import (
)
from app.agents.state_schema import MemoirStateSchema
from app.core.config import settings
from app.core.dependencies import get_embedding_provider
from app.core.logging import get_logger
from app.features.memoir.cover_eligibility import chapter_needs_cover_enqueue
from app.features.memoir.memoir_images.settings import MemoirImageSettings
@@ -714,8 +715,16 @@ def run_story_pipeline_for_category_batch(
top_k = int(settings.evidence_top_k_default)
if n_units > int(settings.evidence_large_batch_threshold):
top_k = int(settings.evidence_top_k_large_batch)
emb = get_embedding_provider()
embedding_available = emb.is_available()
try:
evidence = retrieve_evidence_sync(session, user_id, combined_text, top_k=top_k)
evidence = retrieve_evidence_sync(
session,
user_id,
combined_text,
top_k=top_k,
embedding_provider=emb,
)
except Exception as e:
logger.warning("Evidence 检索跳过: {}", e)
evidence = {
@@ -726,6 +735,16 @@ def run_story_pipeline_for_category_batch(
"relevant_stories": [],
}
logger.info(
"memoir_evidence_retrieved user_id={} chunks={} facts={} summaries={} stories={} vector_ok={}",
user_id,
len(evidence.get("relevant_chunks") or []),
len(evidence.get("relevant_facts") or []),
len(evidence.get("relevant_summaries") or []),
len(evidence.get("relevant_stories") or []),
embedding_available,
)
evidence_text = format_evidence_chunks_for_prompt(evidence)
oral_for_memoir = normalize_oral_for_memoir(combined_text, llm=llm)
ct_raw = (combined_text or "").strip()

View File

@@ -36,7 +36,12 @@ from app.features.memory.summarizer import (
generate_session_summary_async,
generate_session_summary_sync,
)
from app.features.memory.enrichment_pipeline import dedupe_key, normalize_object_json
from app.features.memory.enrichment_pipeline import (
dedupe_key,
normalize_object_json,
normalize_subject,
)
from app.features.user.models import User
from app.features.memory.timeline import (
build_timeline_events_from_facts_async,
build_timeline_events_from_facts_sync,
@@ -69,6 +74,10 @@ def enrich_memory_after_ingest_sync(
llm = _resolve_llm_sync()
if not llm:
return
narrator_name: str | None = None
u_row = session.get(User, user_id)
if u_row and (u_row.nickname or "").strip():
narrator_name = (u_row.nickname or "").strip()
chunks = list_chunks_for_source_sync(session, source_id)
if not chunks:
return
@@ -111,11 +120,13 @@ def enrich_memory_after_ingest_sync(
source_chunk_ids=chunk_ids,
)
raw_facts = extract_facts_from_transcript_sync(llm, numbered)
raw_facts = extract_facts_from_transcript_sync(
llm, numbered, narrator_name=narrator_name
)
seen: set[tuple] = set()
inserted: list[dict] = []
for f in raw_facts:
key = dedupe_key(f)
key = dedupe_key(f, narrator_name=narrator_name)
if key in seen:
continue
seen.add(key)
@@ -126,7 +137,7 @@ def enrich_memory_after_ingest_sync(
session,
user_id=user_id,
fact_type=f.get("fact_type") or "event",
subject=f.get("subject"),
subject=normalize_subject(f.get("subject"), narrator_name),
predicate=f.get("predicate"),
object_json=normalize_object_json(f.get("object_json")),
confidence=float(f.get("confidence") or 0.75),
@@ -175,6 +186,10 @@ async def enrich_memory_after_ingest_async(
llm = _resolve_llm_sync()
if not llm:
return
narrator_name: str | None = None
u_row = await db.get(User, user_id)
if u_row and (u_row.nickname or "").strip():
narrator_name = (u_row.nickname or "").strip()
stmt = (
select(MemoryChunk)
.where(MemoryChunk.source_id == source_id)
@@ -227,11 +242,13 @@ async def enrich_memory_after_ingest_async(
source_chunk_ids=chunk_ids,
)
raw_facts = await extract_facts_from_transcript_async(llm, numbered)
raw_facts = await extract_facts_from_transcript_async(
llm, numbered, narrator_name=narrator_name
)
seen: set[tuple] = set()
inserted: list[dict] = []
for f in raw_facts:
key = dedupe_key(f)
key = dedupe_key(f, narrator_name=narrator_name)
if key in seen:
continue
seen.add(key)
@@ -242,7 +259,7 @@ async def enrich_memory_after_ingest_async(
db,
user_id=user_id,
fact_type=f.get("fact_type") or "event",
subject=f.get("subject"),
subject=normalize_subject(f.get("subject"), narrator_name),
predicate=f.get("predicate"),
object_json=normalize_object_json(f.get("object_json")),
confidence=float(f.get("confidence") or 0.75),

View File

@@ -5,10 +5,34 @@ from __future__ import annotations
import json
from typing import Any
# 叙述者常见别名 — 归一化到 narrator_name 或「叙述者」
_NARRATOR_ALIASES: frozenset[str] = frozenset(
{
"",
"本人",
"人物",
"叙述者",
"讲述者",
"老人",
"自己",
"咱们",
}
)
def dedupe_key(f: dict) -> tuple:
s = f.get("subject") or ""
p = f.get("predicate") or ""
def normalize_subject(subject: str | None, narrator_name: str | None = None) -> str:
"""将代词/泛称映射为统一 subject便于去重与检索。"""
s = (subject or "").strip()
if not s:
return narrator_name or "叙述者"
if s in _NARRATOR_ALIASES:
return narrator_name or "叙述者"
return s
def dedupe_key(f: dict, *, narrator_name: str | None = None) -> tuple:
s = normalize_subject(f.get("subject"), narrator_name)
p = (f.get("predicate") or "").strip()
o = f.get("object_json")
try:
oj = json.dumps(o, sort_keys=True, ensure_ascii=False) if o is not None else ""

View File

@@ -4,22 +4,24 @@
权威层级(可靠性 hardening
- **Chunk 原文**(未 excluded为首要证据rolling 摘要/故事摘录为便利视图,不得压过冲突的 chunk。
- **MemoryFact**`confirmed` 为检索默认集;`candidate` 可被上游提升;`stale` 由 compaction 等标出,检索时应排除。
- 事实 FTS 无命中时是否退回「最近事实」由 `memory_fact_search_use_recent_fallback` 控制(默认可避免串台)。
- 事实 ILIKE 无命中时是否退回「最近事实」由 `memory_fact_search_use_recent_fallback` 控制(默认可避免串台)。
Celery 使用 sync`HybridRetriever` 使用 async + RRF chunk 合并
Celery 使用 sync + 向量 chunks`HybridRetriever` 使用 async + 向量 chunks
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.logging import get_logger
from app.features.memory.repo import (
list_summaries_for_evidence_async,
list_summaries_for_evidence_sync,
search_chunks_fts,
search_chunks_fts_sync,
search_chunks_vector_sync,
search_facts_for_user_async,
search_facts_for_user_sync,
search_timeline_events_for_user_async,
@@ -30,6 +32,11 @@ from app.features.story.repo import (
list_recent_stories_for_evidence_sync,
)
if TYPE_CHECKING:
from app.ports.embedding import EmbeddingProvider
logger = get_logger(__name__)
EMPTY_EVIDENCE_BUNDLE: dict = {
"relevant_chunks": [],
"relevant_summaries": [],
@@ -119,7 +126,7 @@ async def fetch_evidence_metadata_async(
def _empty_query_bundle_sync(session: Session, user_id: str, top_k: int) -> dict:
"""无 FTS query 时的「浏览」降级rolling 摘要 + 事实/时间线 fallback。"""
""" query 时的「浏览」降级rolling 摘要 + 事实/时间线 fallback。"""
from app.features.memory.models import MemorySummary
from sqlalchemy import select
@@ -204,19 +211,50 @@ async def _empty_query_bundle_async(db: AsyncSession, user_id: str, top_k: int)
def retrieve_evidence_bundle_sync(
session: Session, user_id: str, query: str, *, top_k: int = 10
session: Session,
user_id: str,
query: str,
*,
top_k: int = 10,
embedding_provider: "EmbeddingProvider | None" = None,
) -> dict:
"""Celery / 叙事流水线:FTS-only chunks + 元数据。"""
"""Celery / 叙事流水线:向量 chunks + 元数据(需 embedding_provider"""
if not query or not query.strip():
if settings.memory_evidence_empty_query_include_rolling:
return _empty_query_bundle_sync(session, user_id, top_k)
return dict(EMPTY_EVIDENCE_BUNDLE)
q = query.strip()
chunk_rows = search_chunks_fts_sync(session, user_id, q, top_k)
relevant_chunks = [
{"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]}
for r in chunk_rows
]
relevant_chunks: list[dict] = []
if embedding_provider is not None:
try:
q_emb = embedding_provider.embed_text_sync(q)
except Exception as exc:
logger.warning(
"retrieve_evidence_bundle_sync embed failed user_id={} err={}",
user_id,
exc,
)
q_emb = []
if q_emb:
chunk_rows = search_chunks_vector_sync(session, user_id, q_emb, top_k)
relevant_chunks = [
{
"id": r["id"],
"content": r["content"],
"chunk_index": r["chunk_index"],
}
for r in chunk_rows
]
else:
logger.warning(
"retrieve_evidence_bundle_sync empty_query_embedding user_id={}",
user_id,
)
else:
logger.warning(
"retrieve_evidence_bundle_sync no_embedding_provider user_id={}",
user_id,
)
meta = fetch_evidence_metadata_sync(session, user_id, q, top_k)
return {
"relevant_chunks": relevant_chunks,
@@ -233,7 +271,7 @@ async def retrieve_evidence_bundle_async(
merged_chunk_dicts: list[dict],
) -> dict:
"""
异步路径chunk 已由调用方 RRF 合并;此处只拼元数据。
异步路径chunk 已由调用方(如 HybridRetriever向量检索填入;此处只拼元数据。
merged_chunk_dicts: [{"id","content","chunk_index"}, ...]
"""

View File

@@ -21,18 +21,39 @@ def _max_transcript_chars() -> int:
return settings.memory_enrichment_max_chars
def extract_facts_from_transcript_sync(llm: Any, numbered_blocks: str) -> list[dict]:
def _facts_extraction_instructions(narrator_label: str) -> str:
return (
"你是回忆录事实抽取助手。用户正在口述人生回忆,所有内容默认是**过去发生的事**"
"而非当前或未来计划(除非原文明确说「现在」「打算」「准备将要」等)。\n\n"
"## 抽取规则\n"
"1. subject 必须用明确的人名或固定称谓:\n"
f" - 叙述者本人统一用「{narrator_label}\n"
" - 其他人用全名或稳定专名(如「王伟」),禁止用「他」「她」「我」「我们大伙」等代词作 subject"
"若代词在上下文中可唯一解析为某人,则 subject 写该人姓名/专名\n"
"2. 事件、职务变动、地点迁移等一律按**过去回忆**理解travel/调动/命令类表述勿写成「即将要做」"
"除非原文明确为未来时态\n"
"3. 若可推断大约年代或人生阶段,将 approximate_era 写入 object_json与 value 等字段并存),"
'例如 "1990年代""2001年""退休后""30岁前后"\n'
"4. fact_type: person|event|relation|place|milestone\n"
"5. predicate简短中文谓语如「出生地」「担任职务」「调往」\n"
"6. object_json字符串或对象可含 value、approximate_era 等\n"
"7. confidence 0..1source_chunk_id 必须等于某段 [chunk_id=...] 中的 id\n\n"
'只输出 JSON{"facts":[...]},无事实则 {"facts":[]}。\n\n'
)
def extract_facts_from_transcript_sync(
llm: Any,
numbered_blocks: str,
*,
narrator_name: str | None = None,
) -> list[dict]:
"""同步:带 chunk_id 标记的文本 → 事实列表。"""
if not llm or not (numbered_blocks or "").strip():
return []
text = numbered_blocks.strip()[: _max_transcript_chars()]
prompt = (
"你是回忆录记忆抽取助手。阅读下列带 [chunk_id=...] 的文本块,抽取可核查的事实。\n"
"每个事实含 fact_type: person|event|relation|place|milestonesubjectpredicate"
"object_json可为字符串或对象confidence 0..1source_chunk_id 必须等于某段的 chunk id。\n"
'只输出 JSON{"facts":[...]},无事实则 {"facts":[]}。\n\n'
f"{text}"
)
narrator_label = (narrator_name or "").strip() or "叙述者"
prompt = _facts_extraction_instructions(narrator_label) + text
try:
raw = invoke_json_object(
llm,
@@ -50,19 +71,17 @@ def extract_facts_from_transcript_sync(llm: Any, numbered_blocks: str) -> list[d
async def extract_facts_from_transcript_async(
llm: Any, numbered_blocks: str
llm: Any,
numbered_blocks: str,
*,
narrator_name: str | None = None,
) -> list[dict]:
"""异步版。"""
if not llm or not (numbered_blocks or "").strip():
return []
text = numbered_blocks.strip()[: _max_transcript_chars()]
prompt = (
"你是回忆录记忆抽取助手。阅读下列带 [chunk_id=...] 的文本块,抽取可核查的事实。\n"
"每个事实含 fact_type: person|event|relation|place|milestonesubjectpredicate"
"object_jsonconfidence 0..1source_chunk_id 必须等于某段的 chunk id。\n"
'只输出 JSON{"facts":[...]},无事实则 {"facts":[]}。\n\n'
f"{text}"
)
narrator_label = (narrator_name or "").strip() or "叙述者"
prompt = _facts_extraction_instructions(narrator_label) + text
try:
raw = await ainvoke_json_object(
llm,
@@ -81,11 +100,24 @@ async def extract_facts_from_transcript_async(
async def extract_facts(chunk_text: str, *, user_id: str) -> list[dict]:
"""兼容旧接口:单块文本(无 chunk id 时传空 source_chunk_id"""
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_llm_provider_fast
from app.features.user.models import User
llm = get_llm_provider_fast().langchain_llm
narrator_name: str | None = None
try:
async with AsyncSessionLocal() as db:
u = await db.get(User, user_id)
if u and (u.nickname or "").strip():
narrator_name = (u.nickname or "").strip()
except Exception:
pass
blocks = f"[chunk_id=null]\n{chunk_text}"
facts = await extract_facts_from_transcript_async(llm, blocks)
facts = await extract_facts_from_transcript_async(
llm, blocks, narrator_name=narrator_name
)
for f in facts:
if f.get("source_chunk_id") in (None, "null", ""):
f["source_chunk_id"] = None

View File

@@ -10,7 +10,6 @@ from sqlalchemy import (
String,
Text,
)
from sqlalchemy.dialects.postgresql import TSVECTOR as TSVector
from sqlalchemy.orm import relationship
from app.core.db import Base, utc_now
@@ -46,8 +45,6 @@ class MemoryChunk(Base):
content = Column(Text, nullable=False)
# pgvector embedding — Alembic migration 负责 CREATE EXTENSION vector 及列类型
embedding = Column(pgvector_type, nullable=True)
# PostgreSQL FTS — Alembic migration 负责 generated tsvector 列 + GIN index
content_tsv = Column(TSVector, nullable=True)
chunk_index = Column(Integer, nullable=False)
speaker = Column(String, nullable=True)
event_year = Column(Integer, nullable=True)

View File

@@ -1,7 +1,8 @@
"""Memory repository — MemorySource, MemoryChunk, MemoryFact, TimelineEvent data access."""
import uuid
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING
from sqlalchemy import delete, literal, or_, select, text, tuple_, update
from sqlalchemy.ext.asyncio import AsyncSession
@@ -16,6 +17,9 @@ from app.features.memory.models import (
TimelineEvent,
)
if TYPE_CHECKING:
from app.ports.embedding import EmbeddingProvider
def _new_id() -> str:
return str(uuid.uuid4())
@@ -105,16 +109,6 @@ async def create_chunk(
return chunk
def update_chunk_fts_sync(session: Session, chunk_id: str) -> None:
"""Populate content_tsv for FTS (sync). Caller must commit."""
session.execute(
text(
"UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
),
{"id": chunk_id},
)
def update_chunk_embedding_sync(
session: Session, chunk_id: str, embedding: list[float]
) -> None:
@@ -133,39 +127,6 @@ async def update_chunk_embedding(
chunk.embedding = embedding
async def update_chunk_fts(db: AsyncSession, chunk_id: str) -> None:
"""Populate content_tsv for FTS. Caller must commit."""
await db.execute(
text(
"UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
),
{"id": chunk_id},
)
async def search_chunks_fts(
db: AsyncSession, user_id: str, query: str, limit: int = 20
) -> list[dict]:
"""FTS search on memory_chunks. Returns list of {id, content, chunk_index}."""
if not query or not query.strip():
return []
q = query.strip()
stmt = text("""
SELECT id, content, chunk_index
FROM memory_chunks
WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
LIMIT :lim
""")
result = await db.execute(stmt, {"user_id": user_id, "q": q, "q2": q, "lim": limit})
rows = result.mappings().all()
return [
{"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]}
for r in rows
]
async def get_chunks_by_ids(
db: AsyncSession, chunk_ids: list[str]
) -> list[MemoryChunk]:
@@ -219,29 +180,6 @@ def get_timeline_events_for_user_sync(
return list(session.execute(stmt).unique().scalars().all())
def search_chunks_fts_sync(
session: Session, user_id: str, query: str, limit: int = 20
) -> list[dict]:
"""FTS on memory_chunkssyncCelery"""
if not query or not query.strip():
return []
q = query.strip()
stmt = text("""
SELECT id, content, chunk_index
FROM memory_chunks
WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
LIMIT :lim
""")
result = session.execute(stmt, {"user_id": user_id, "q": q, "q2": q, "lim": limit})
rows = result.mappings().all()
return [
{"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]}
for r in rows
]
def search_facts_for_user_sync(
session: Session, user_id: str, query: str, limit: int = 20
) -> list[MemoryFact]:
@@ -401,6 +339,49 @@ async def search_chunks_vector(
]
def search_chunks_vector_sync(
session: Session, user_id: str, query_embedding: list[float], limit: int = 20
) -> list[dict]:
"""pgvector 余弦距离检索syncCelery。返回 {id, content, chunk_index, distance}。"""
if not query_embedding:
return []
stmt = text("""
SELECT id, content, chunk_index,
(embedding <=> CAST(:emb AS vector)) AS distance
FROM memory_chunks
WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
AND embedding IS NOT NULL
ORDER BY embedding <=> CAST(:emb2 AS vector)
LIMIT :lim
""")
emb_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
result = session.execute(
stmt,
{"user_id": user_id, "emb": emb_str, "emb2": emb_str, "lim": limit},
)
rows = result.mappings().all()
return [
{
"id": r["id"],
"content": r["content"],
"chunk_index": r["chunk_index"],
"distance": float(r["distance"]),
}
for r in rows
]
def list_users_with_recent_chunks_sync(session: Session, *, hours: int) -> list[str]:
"""最近 N 小时内有新 chunk 的用户 idBeat compaction 扫描)。"""
if hours < 1:
hours = 1
cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)
stmt = (
select(MemoryChunk.user_id).where(MemoryChunk.created_at >= cutoff).distinct()
)
return list(session.execute(stmt).scalars().all())
def list_summaries_for_evidence_sync(
session: Session, *, user_id: str, q: str, limit: int
) -> list[dict]:
@@ -453,17 +434,27 @@ def list_summaries_for_evidence_sync(
def retrieve_evidence_sync(
session: Session, user_id: str, query: str, *, top_k: int = 10
session: Session,
user_id: str,
query: str,
*,
top_k: int = 10,
embedding_provider: "EmbeddingProvider | None" = None,
) -> dict:
"""
Sync evidence retrieval for Celery tasks.
能力:**仅 FTS** 检索 chunks与 `HybridRetriever` 的 FTS+向量 RRF 不同,见
`api/docs/memory-retrieval.md`facts/timeline 按 query ILIKEfallback 见 repo。
chunks**向量**pgvector与异步 `HybridRetriever` 对齐facts/timeline 按 query ILIKE。
"""
from app.features.memory.evidence import retrieve_evidence_bundle_sync
return retrieve_evidence_bundle_sync(session, user_id, query, top_k=top_k)
return retrieve_evidence_bundle_sync(
session,
user_id,
query,
top_k=top_k,
embedding_provider=embedding_provider,
)
async def get_timeline_events_for_user(

View File

@@ -1,31 +1,17 @@
"""Hybrid retriever — metadata filter + FTS + vector retrieval + score fusion."""
"""Hybrid retriever — 向量检索 + 元数据证据包。"""
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.features.memory.evidence import retrieve_evidence_bundle_async
from app.features.memory.repo import search_chunks_fts, search_chunks_vector
from app.features.memory.repo import search_chunks_vector
from app.ports.embedding import EmbeddingProvider
def _rrf_merge(
fts_items: list[dict], vector_items: list[dict], k: int = 60
) -> list[dict]:
"""Reciprocal Rank Fusion. Merge FTS and vector results by id."""
scores: dict[str, float] = {}
for rank, item in enumerate(fts_items):
cid = item["id"]
scores[cid] = scores.get(cid, 0) + 1 / (k + rank + 1)
for rank, item in enumerate(vector_items):
cid = item["id"]
scores[cid] = scores.get(cid, 0) + 1 / (k + rank + 1)
all_items = {x["id"]: x for x in fts_items + vector_items}
sorted_ids = sorted(scores.keys(), key=lambda i: scores[i], reverse=True)
return [all_items[i] for i in sorted_ids]
logger = get_logger(__name__)
class HybridRetriever:
"""Combine FTS, vector, and metadata filter into evidence bundle."""
"""向量 chunk 检索 + facts/timeline/summaries/stories。"""
def __init__(
self,
@@ -51,27 +37,27 @@ class HybridRetriever:
)
q = query.strip()
fts_chunks = await search_chunks_fts(
self._db, user_id=user_id, query=query, limit=top_k * 2
)
vector_chunks: list[dict] = []
if self._embedding and q:
merged_chunk_dicts: list[dict] = []
if self._embedding:
q_emb = await self._embedding.embed_text(q)
if q_emb:
vector_chunks = await search_chunks_vector(
self._db, user_id=user_id, query_embedding=q_emb, limit=top_k * 2
vector_rows = await search_chunks_vector(
self._db, user_id, q_emb, limit=top_k
)
merged = _rrf_merge(fts_chunks, vector_chunks)[:top_k]
merged_chunk_dicts = [
{
"id": c["id"],
"content": c["content"],
"chunk_index": c.get("chunk_index", 0),
}
for c in merged
]
merged_chunk_dicts = [
{
"id": c["id"],
"content": c["content"],
"chunk_index": c.get("chunk_index", 0),
}
for c in vector_rows
]
else:
logger.warning(
"HybridRetriever empty_query_embedding user_id={}", user_id
)
else:
logger.warning("HybridRetriever no_embedding_provider user_id={}", user_id)
return await retrieve_evidence_bundle_async(
self._db,

View File

@@ -1,16 +1,14 @@
"""
MemoryService — conversation / memoir 的统一门面。
- ingest_transcript: transcript -> memory_sources, chunks, embedding, FTS
- ingest_transcript: transcript -> memory_sources, chunks, embedding
- ingest 后可选LLM 富化session/rolling 摘要、事实、时间线)
- retrieve: 委托 HybridRetriever 返回 evidence bundleFTS + 可选向量 RRF
- retrieve: 委托 HybridRetriever 返回 evidence bundle向量 chunks
Celery 侧使用 `ingest_transcript_sync` + `retrieve_evidence_sync`,与异步路径差异
Celery 侧使用 `ingest_transcript_sync` + `retrieve_evidence_sync`,与异步路径对齐
`api/docs/memory-retrieval.md`。
"""
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
@@ -23,7 +21,6 @@ from app.features.memory.repo import (
set_chunk_excluded,
set_memory_fact_status,
update_chunk_embedding,
update_chunk_fts,
)
from app.ports.embedding import EmbeddingProvider
@@ -45,7 +42,7 @@ class MemoryService:
) -> str:
"""
Ingest conversation transcript into memory.
Creates MemorySource, chunks, populates embedding + FTS.
Creates MemorySource, chunks, populates embedding.
Returns source_id.
"""
if not transcript or not transcript.strip():
@@ -73,10 +70,6 @@ class MemoryService:
await self._db.flush()
# FTS: populate content_tsv
for chunk_id, _ in chunk_records:
await update_chunk_fts(self._db, chunk_id)
# Embedding: 若有 provider 则写入
if self._embedding and chunk_records:
texts = [c for _, c in chunk_records]
@@ -186,7 +179,7 @@ def ingest_transcript_sync(
) -> str:
"""
Sync transcript ingest for Celery tasks.
Creates source + chunks + FTS, and best-effort populates embeddings.
Creates source + chunks, and best-effort populates embeddings.
Returns source_id.
"""
from app.core.dependencies import get_embedding_provider
@@ -195,7 +188,6 @@ def ingest_transcript_sync(
create_chunk_sync,
create_source_sync,
update_chunk_embedding_sync,
update_chunk_fts_sync,
)
if not transcript or not transcript.strip():
@@ -222,13 +214,12 @@ def ingest_transcript_sync(
)
session.flush()
chunk_records.append((chunk.id, content))
update_chunk_fts_sync(session, chunk.id)
try:
embedding_provider = get_embedding_provider()
if chunk_records and embedding_provider is not None:
texts = [content for _, content in chunk_records]
embeddings = asyncio.run(embedding_provider.embed_texts(texts))
embeddings = embedding_provider.embed_texts_sync(texts)
for (chunk_id, _), emb in zip(chunk_records, embeddings):
if emb:
update_chunk_embedding_sync(session, chunk_id, emb)

View File

@@ -5,6 +5,10 @@ from typing import Protocol, runtime_checkable
@runtime_checkable
class EmbeddingProvider(Protocol):
def is_available(self) -> bool:
"""进程内 embedding 已配置且可发起调用(无 key / 未初始化 client 时为 False"""
...
async def embed_text(self, text: str) -> list[float]:
"""Embed a single text into a vector."""
...
@@ -12,3 +16,11 @@ class EmbeddingProvider(Protocol):
async def embed_texts(self, texts: list[str]) -> list[list[float]]:
"""Embed multiple texts into vectors."""
...
def embed_text_sync(self, text: str) -> list[float]:
"""同步嵌入单条文本Celery / sync DB 会话)。"""
...
def embed_texts_sync(self, texts: list[str]) -> list[list[float]]:
"""同步嵌入多条文本Celery / sync DB 会话)。"""
...

View File

@@ -63,11 +63,9 @@ celery_app.conf.update(
# 不设置自定义队列路由,使用 Celery 默认队列
)
# 定时任务配置(如果需要)
celery_app.conf.beat_schedule = {
# 示例:每小时清理过期会话
# "cleanup-expired-sessions": {
# "task": "app.tasks.cleanup.cleanup_sessions",
# "schedule": 3600.0,
# },
"memory-compaction-sweep": {
"task": "app.tasks.memory_compaction_tasks.memory_compaction_sweep",
"schedule": 6 * 3600.0,
},
}

View File

@@ -15,14 +15,33 @@ from app.core.memory_compaction_schedule import (
finalize_memory_compaction_run,
read_debounce_deadline_ts,
release_scheduler_gate,
schedule_memory_compaction_run,
set_incremental_cursor_pair,
)
from app.core.redis_lock import acquire_redis_lock, release_redis_lock
from app.features.memory.compaction_service import run_memory_compaction_sync
from app.features.memory.repo import list_users_with_recent_chunks_sync
logger = get_logger(__name__)
@shared_task
def memory_compaction_sweep() -> dict[str, Any]:
"""Beat为近期有记忆写入的用户调度 compactiondebounce 仍由 schedule 合并)。"""
if not settings.memory_compaction_enabled:
return {"skipped": True, "reason": "disabled"}
hours = int(settings.memory_compaction_sweep_recent_hours)
with get_sync_db() as session:
user_ids = list_users_with_recent_chunks_sync(session, hours=hours)
ctx_base: dict[str, Any] = {"trigger_source": "beat", "sweep_hours": hours}
for uid in user_ids:
schedule_memory_compaction_run(uid, dict(ctx_base))
logger.info(
"memory_compaction_sweep hours={} scheduled_users={}", hours, len(user_ids)
)
return {"scheduled": len(user_ids), "user_ids": user_ids}
@shared_task(bind=True, max_retries=12, default_retry_delay=20)
def memory_compaction_run(
self, user_id: str, context: dict[str, Any] | None = None