聊天和回忆录证据检索都走 pgvector,去掉 Postgres FTS/content_tsv,新迁移删掉 content_tsv 列(部署要先 alembic upgrade)。
Embedding 端口增加 is_available(),聊天和回忆录日志用统一方式表示向量是否真能调用。 记忆整理(compaction)支持 Beat 定期扫用户; 事实抽取提示与 subject 归一化,减少同一人多种称呼;
This commit is contained in:
@@ -7,6 +7,9 @@ import asyncio
|
||||
from zai import ZhipuAiClient
|
||||
|
||||
from app.core.embedding import MEMORY_EMBEDDING_DIMENSION
|
||||
from app.core.logging import get_logger
|
||||
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
# 单次请求最多 64 条文本(智谱 Embedding-3 文档)
|
||||
_EMBED_BATCH_SIZE = 64
|
||||
@@ -22,6 +25,9 @@ class ZhipuEmbeddingProvider:
|
||||
) -> None:
|
||||
self._model = model
|
||||
if not api_key:
|
||||
_logger.warning(
|
||||
"ZhipuEmbeddingProvider: api_key 为空,embedding 将不可用(记忆检索与 ingest 向量写入会降级)"
|
||||
)
|
||||
self._client = None
|
||||
elif base_url:
|
||||
self._client = ZhipuAiClient(
|
||||
@@ -31,6 +37,9 @@ class ZhipuEmbeddingProvider:
|
||||
else:
|
||||
self._client = ZhipuAiClient(api_key=api_key)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._client is not None
|
||||
|
||||
def _create_vectors_sync(self, texts: list[str]) -> list[list[float]]:
|
||||
assert self._client is not None
|
||||
resp = self._client.embeddings.create(
|
||||
@@ -54,3 +63,16 @@ class ZhipuEmbeddingProvider:
|
||||
part = await asyncio.to_thread(self._create_vectors_sync, batch)
|
||||
out.extend(part)
|
||||
return out
|
||||
|
||||
def embed_text_sync(self, text: str) -> list[float]:
|
||||
vecs = self.embed_texts_sync([text])
|
||||
return vecs[0] if vecs else []
|
||||
|
||||
def embed_texts_sync(self, texts: list[str]) -> list[list[float]]:
|
||||
if not self._client or not texts:
|
||||
return []
|
||||
out: list[list[float]] = []
|
||||
for i in range(0, len(texts), _EMBED_BATCH_SIZE):
|
||||
batch = texts[i : i + _EMBED_BATCH_SIZE]
|
||||
out.extend(self._create_vectors_sync(batch))
|
||||
return out
|
||||
|
||||
@@ -68,9 +68,22 @@ async def _fetch_interview_memory_evidence(
|
||||
):
|
||||
return ""
|
||||
try:
|
||||
ms = MemoryService(db, embedding_provider=get_embedding_provider())
|
||||
emb = get_embedding_provider()
|
||||
ms = MemoryService(db, embedding_provider=emb)
|
||||
bundle = await ms.retrieve(user_id, msg, top_k=settings.chat_memory_top_k)
|
||||
text = format_evidence_chunks_for_prompt(bundle.model_dump())
|
||||
bd = bundle.model_dump()
|
||||
vector_ok = emb.is_available()
|
||||
logger.info(
|
||||
"memory_evidence_retrieved user_id={} chunks={} facts={} summaries={} timeline={} stories={} vector_ok={}",
|
||||
user_id,
|
||||
len(bd.get("relevant_chunks") or []),
|
||||
len(bd.get("relevant_facts") or []),
|
||||
len(bd.get("relevant_summaries") or []),
|
||||
len(bd.get("timeline_hints") or []),
|
||||
len(bd.get("relevant_stories") or []),
|
||||
vector_ok,
|
||||
)
|
||||
text = format_evidence_chunks_for_prompt(bd)
|
||||
t = (text or "").strip()
|
||||
if not t:
|
||||
return ""
|
||||
|
||||
@@ -267,16 +267,16 @@ class Settings(BaseSettings):
|
||||
memoir_phase2_singleflight_immediate: bool = True
|
||||
|
||||
# ── Memory 检索与富化 ─────────────────────────────────────
|
||||
# True:query 为空时仍返回 rolling 摘要 + 最近事实/时间线(无 chunk FTS)
|
||||
# True:query 为空时仍返回 rolling 摘要 + 最近事实/时间线(无 chunk 向量检索)
|
||||
memory_evidence_empty_query_include_rolling: bool = False
|
||||
# False:跳过 ingest 后 LLM 富化(摘要/事实/时间线)
|
||||
memory_enrichment_enabled: bool = True
|
||||
memory_enrichment_max_chars: int = Field(default=12000, ge=1000, le=100_000)
|
||||
# True:事实 FTS 未命中时退回「最近 confirmed 事实」(易引入无关/矛盾事实;默认关)
|
||||
# True:事实 ILIKE 未命中时退回「最近 confirmed 事实」(易引入无关/矛盾事实;默认关)
|
||||
memory_fact_search_use_recent_fallback: bool = False
|
||||
|
||||
# ── Memory compaction(近重复 chunk 软排除;事件触发 + Redis 防抖 + 用户锁)──
|
||||
memory_compaction_enabled: bool = False
|
||||
# ── Memory compaction(近重复 chunk 软排除;事件触发 + Redis 防抖 + 用户锁;需 worker + Beat 跑 sweep)──
|
||||
memory_compaction_enabled: bool = True
|
||||
memory_compaction_debounce_seconds: int = Field(default=105, ge=10, le=3600)
|
||||
memory_compaction_lock_ttl_seconds: int = Field(default=600, ge=60, le=7200)
|
||||
memory_compaction_chunk_similarity_threshold: float = Field(
|
||||
@@ -288,6 +288,8 @@ class Settings(BaseSettings):
|
||||
memory_compaction_max_neighbors_per_chunk: int = Field(default=25, ge=5, le=100)
|
||||
memory_compaction_text_jaccard_min: float = Field(default=0.55, ge=0.0, le=1.0)
|
||||
memory_compaction_metadata_event_year_window: int = Field(default=1, ge=0, le=50)
|
||||
# Beat sweep:扫描最近 N 小时内有新 chunk 的用户并调度 compaction
|
||||
memory_compaction_sweep_recent_hours: int = Field(default=24, ge=1, le=168)
|
||||
|
||||
# ── Liblib ───────────────────────────────────────────────
|
||||
liblib_access_key: str = ""
|
||||
|
||||
@@ -74,7 +74,7 @@ class MemoirService:
|
||||
self._object_storage = object_storage
|
||||
|
||||
async def get_evidence(self, user_id: str, query: str, *, top_k: int = 10) -> dict:
|
||||
"""通过 MemoryService → HybridRetriever 获取证据(含向量时与 Celery 的 FTS-only 路径不同)。"""
|
||||
"""通过 MemoryService → HybridRetriever 获取证据(向量 chunks,与 Celery 叙事路径一致)。"""
|
||||
if self._memory is None:
|
||||
return {
|
||||
"relevant_chunks": [],
|
||||
|
||||
@@ -33,6 +33,7 @@ from app.agents.memoir.story_route_agent import (
|
||||
)
|
||||
from app.agents.state_schema import MemoirStateSchema
|
||||
from app.core.config import settings
|
||||
from app.core.dependencies import get_embedding_provider
|
||||
from app.core.logging import get_logger
|
||||
from app.features.memoir.cover_eligibility import chapter_needs_cover_enqueue
|
||||
from app.features.memoir.memoir_images.settings import MemoirImageSettings
|
||||
@@ -714,8 +715,16 @@ def run_story_pipeline_for_category_batch(
|
||||
top_k = int(settings.evidence_top_k_default)
|
||||
if n_units > int(settings.evidence_large_batch_threshold):
|
||||
top_k = int(settings.evidence_top_k_large_batch)
|
||||
emb = get_embedding_provider()
|
||||
embedding_available = emb.is_available()
|
||||
try:
|
||||
evidence = retrieve_evidence_sync(session, user_id, combined_text, top_k=top_k)
|
||||
evidence = retrieve_evidence_sync(
|
||||
session,
|
||||
user_id,
|
||||
combined_text,
|
||||
top_k=top_k,
|
||||
embedding_provider=emb,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Evidence 检索跳过: {}", e)
|
||||
evidence = {
|
||||
@@ -726,6 +735,16 @@ def run_story_pipeline_for_category_batch(
|
||||
"relevant_stories": [],
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"memoir_evidence_retrieved user_id={} chunks={} facts={} summaries={} stories={} vector_ok={}",
|
||||
user_id,
|
||||
len(evidence.get("relevant_chunks") or []),
|
||||
len(evidence.get("relevant_facts") or []),
|
||||
len(evidence.get("relevant_summaries") or []),
|
||||
len(evidence.get("relevant_stories") or []),
|
||||
embedding_available,
|
||||
)
|
||||
|
||||
evidence_text = format_evidence_chunks_for_prompt(evidence)
|
||||
oral_for_memoir = normalize_oral_for_memoir(combined_text, llm=llm)
|
||||
ct_raw = (combined_text or "").strip()
|
||||
|
||||
@@ -36,7 +36,12 @@ from app.features.memory.summarizer import (
|
||||
generate_session_summary_async,
|
||||
generate_session_summary_sync,
|
||||
)
|
||||
from app.features.memory.enrichment_pipeline import dedupe_key, normalize_object_json
|
||||
from app.features.memory.enrichment_pipeline import (
|
||||
dedupe_key,
|
||||
normalize_object_json,
|
||||
normalize_subject,
|
||||
)
|
||||
from app.features.user.models import User
|
||||
from app.features.memory.timeline import (
|
||||
build_timeline_events_from_facts_async,
|
||||
build_timeline_events_from_facts_sync,
|
||||
@@ -69,6 +74,10 @@ def enrich_memory_after_ingest_sync(
|
||||
llm = _resolve_llm_sync()
|
||||
if not llm:
|
||||
return
|
||||
narrator_name: str | None = None
|
||||
u_row = session.get(User, user_id)
|
||||
if u_row and (u_row.nickname or "").strip():
|
||||
narrator_name = (u_row.nickname or "").strip()
|
||||
chunks = list_chunks_for_source_sync(session, source_id)
|
||||
if not chunks:
|
||||
return
|
||||
@@ -111,11 +120,13 @@ def enrich_memory_after_ingest_sync(
|
||||
source_chunk_ids=chunk_ids,
|
||||
)
|
||||
|
||||
raw_facts = extract_facts_from_transcript_sync(llm, numbered)
|
||||
raw_facts = extract_facts_from_transcript_sync(
|
||||
llm, numbered, narrator_name=narrator_name
|
||||
)
|
||||
seen: set[tuple] = set()
|
||||
inserted: list[dict] = []
|
||||
for f in raw_facts:
|
||||
key = dedupe_key(f)
|
||||
key = dedupe_key(f, narrator_name=narrator_name)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
@@ -126,7 +137,7 @@ def enrich_memory_after_ingest_sync(
|
||||
session,
|
||||
user_id=user_id,
|
||||
fact_type=f.get("fact_type") or "event",
|
||||
subject=f.get("subject"),
|
||||
subject=normalize_subject(f.get("subject"), narrator_name),
|
||||
predicate=f.get("predicate"),
|
||||
object_json=normalize_object_json(f.get("object_json")),
|
||||
confidence=float(f.get("confidence") or 0.75),
|
||||
@@ -175,6 +186,10 @@ async def enrich_memory_after_ingest_async(
|
||||
llm = _resolve_llm_sync()
|
||||
if not llm:
|
||||
return
|
||||
narrator_name: str | None = None
|
||||
u_row = await db.get(User, user_id)
|
||||
if u_row and (u_row.nickname or "").strip():
|
||||
narrator_name = (u_row.nickname or "").strip()
|
||||
stmt = (
|
||||
select(MemoryChunk)
|
||||
.where(MemoryChunk.source_id == source_id)
|
||||
@@ -227,11 +242,13 @@ async def enrich_memory_after_ingest_async(
|
||||
source_chunk_ids=chunk_ids,
|
||||
)
|
||||
|
||||
raw_facts = await extract_facts_from_transcript_async(llm, numbered)
|
||||
raw_facts = await extract_facts_from_transcript_async(
|
||||
llm, numbered, narrator_name=narrator_name
|
||||
)
|
||||
seen: set[tuple] = set()
|
||||
inserted: list[dict] = []
|
||||
for f in raw_facts:
|
||||
key = dedupe_key(f)
|
||||
key = dedupe_key(f, narrator_name=narrator_name)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
@@ -242,7 +259,7 @@ async def enrich_memory_after_ingest_async(
|
||||
db,
|
||||
user_id=user_id,
|
||||
fact_type=f.get("fact_type") or "event",
|
||||
subject=f.get("subject"),
|
||||
subject=normalize_subject(f.get("subject"), narrator_name),
|
||||
predicate=f.get("predicate"),
|
||||
object_json=normalize_object_json(f.get("object_json")),
|
||||
confidence=float(f.get("confidence") or 0.75),
|
||||
|
||||
@@ -5,10 +5,34 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
# 叙述者常见别名 — 归一化到 narrator_name 或「叙述者」
|
||||
_NARRATOR_ALIASES: frozenset[str] = frozenset(
|
||||
{
|
||||
"我",
|
||||
"本人",
|
||||
"人物",
|
||||
"叙述者",
|
||||
"讲述者",
|
||||
"老人",
|
||||
"自己",
|
||||
"咱们",
|
||||
}
|
||||
)
|
||||
|
||||
def dedupe_key(f: dict) -> tuple:
|
||||
s = f.get("subject") or ""
|
||||
p = f.get("predicate") or ""
|
||||
|
||||
def normalize_subject(subject: str | None, narrator_name: str | None = None) -> str:
|
||||
"""将代词/泛称映射为统一 subject,便于去重与检索。"""
|
||||
s = (subject or "").strip()
|
||||
if not s:
|
||||
return narrator_name or "叙述者"
|
||||
if s in _NARRATOR_ALIASES:
|
||||
return narrator_name or "叙述者"
|
||||
return s
|
||||
|
||||
|
||||
def dedupe_key(f: dict, *, narrator_name: str | None = None) -> tuple:
|
||||
s = normalize_subject(f.get("subject"), narrator_name)
|
||||
p = (f.get("predicate") or "").strip()
|
||||
o = f.get("object_json")
|
||||
try:
|
||||
oj = json.dumps(o, sort_keys=True, ensure_ascii=False) if o is not None else ""
|
||||
|
||||
@@ -4,22 +4,24 @@
|
||||
权威层级(可靠性 hardening):
|
||||
- **Chunk 原文**(未 excluded)为首要证据;rolling 摘要/故事摘录为便利视图,不得压过冲突的 chunk。
|
||||
- **MemoryFact**:`confirmed` 为检索默认集;`candidate` 可被上游提升;`stale` 由 compaction 等标出,检索时应排除。
|
||||
- 事实 FTS 无命中时是否退回「最近事实」由 `memory_fact_search_use_recent_fallback` 控制(默认可避免串台)。
|
||||
- 事实 ILIKE 无命中时是否退回「最近事实」由 `memory_fact_search_use_recent_fallback` 控制(默认可避免串台)。
|
||||
|
||||
Celery 使用 sync;`HybridRetriever` 使用 async + RRF chunk 合并。
|
||||
Celery 使用 sync + 向量 chunks;`HybridRetriever` 使用 async + 向量 chunks。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.features.memory.repo import (
|
||||
list_summaries_for_evidence_async,
|
||||
list_summaries_for_evidence_sync,
|
||||
search_chunks_fts,
|
||||
search_chunks_fts_sync,
|
||||
search_chunks_vector_sync,
|
||||
search_facts_for_user_async,
|
||||
search_facts_for_user_sync,
|
||||
search_timeline_events_for_user_async,
|
||||
@@ -30,6 +32,11 @@ from app.features.story.repo import (
|
||||
list_recent_stories_for_evidence_sync,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.ports.embedding import EmbeddingProvider
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
EMPTY_EVIDENCE_BUNDLE: dict = {
|
||||
"relevant_chunks": [],
|
||||
"relevant_summaries": [],
|
||||
@@ -119,7 +126,7 @@ async def fetch_evidence_metadata_async(
|
||||
|
||||
|
||||
def _empty_query_bundle_sync(session: Session, user_id: str, top_k: int) -> dict:
|
||||
"""无 FTS query 时的「浏览」降级:rolling 摘要 + 事实/时间线 fallback。"""
|
||||
"""空 query 时的「浏览」降级:rolling 摘要 + 事实/时间线 fallback。"""
|
||||
from app.features.memory.models import MemorySummary
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -204,19 +211,50 @@ async def _empty_query_bundle_async(db: AsyncSession, user_id: str, top_k: int)
|
||||
|
||||
|
||||
def retrieve_evidence_bundle_sync(
|
||||
session: Session, user_id: str, query: str, *, top_k: int = 10
|
||||
session: Session,
|
||||
user_id: str,
|
||||
query: str,
|
||||
*,
|
||||
top_k: int = 10,
|
||||
embedding_provider: "EmbeddingProvider | None" = None,
|
||||
) -> dict:
|
||||
"""Celery / 叙事流水线:FTS-only chunks + 元数据。"""
|
||||
"""Celery / 叙事流水线:向量 chunks + 元数据(需 embedding_provider)。"""
|
||||
if not query or not query.strip():
|
||||
if settings.memory_evidence_empty_query_include_rolling:
|
||||
return _empty_query_bundle_sync(session, user_id, top_k)
|
||||
return dict(EMPTY_EVIDENCE_BUNDLE)
|
||||
q = query.strip()
|
||||
chunk_rows = search_chunks_fts_sync(session, user_id, q, top_k)
|
||||
relevant_chunks = [
|
||||
{"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]}
|
||||
for r in chunk_rows
|
||||
]
|
||||
relevant_chunks: list[dict] = []
|
||||
if embedding_provider is not None:
|
||||
try:
|
||||
q_emb = embedding_provider.embed_text_sync(q)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"retrieve_evidence_bundle_sync embed failed user_id={} err={}",
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
q_emb = []
|
||||
if q_emb:
|
||||
chunk_rows = search_chunks_vector_sync(session, user_id, q_emb, top_k)
|
||||
relevant_chunks = [
|
||||
{
|
||||
"id": r["id"],
|
||||
"content": r["content"],
|
||||
"chunk_index": r["chunk_index"],
|
||||
}
|
||||
for r in chunk_rows
|
||||
]
|
||||
else:
|
||||
logger.warning(
|
||||
"retrieve_evidence_bundle_sync empty_query_embedding user_id={}",
|
||||
user_id,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"retrieve_evidence_bundle_sync no_embedding_provider user_id={}",
|
||||
user_id,
|
||||
)
|
||||
meta = fetch_evidence_metadata_sync(session, user_id, q, top_k)
|
||||
return {
|
||||
"relevant_chunks": relevant_chunks,
|
||||
@@ -233,7 +271,7 @@ async def retrieve_evidence_bundle_async(
|
||||
merged_chunk_dicts: list[dict],
|
||||
) -> dict:
|
||||
"""
|
||||
异步路径:chunk 已由调用方 RRF 合并;此处只拼元数据。
|
||||
异步路径:chunk 已由调用方(如 HybridRetriever)向量检索填入;此处只拼元数据。
|
||||
|
||||
merged_chunk_dicts: [{"id","content","chunk_index"}, ...]
|
||||
"""
|
||||
|
||||
@@ -21,18 +21,39 @@ def _max_transcript_chars() -> int:
|
||||
return settings.memory_enrichment_max_chars
|
||||
|
||||
|
||||
def extract_facts_from_transcript_sync(llm: Any, numbered_blocks: str) -> list[dict]:
|
||||
def _facts_extraction_instructions(narrator_label: str) -> str:
|
||||
return (
|
||||
"你是回忆录事实抽取助手。用户正在口述人生回忆,所有内容默认是**过去发生的事**,"
|
||||
"而非当前或未来计划(除非原文明确说「现在」「打算」「准备将要」等)。\n\n"
|
||||
"## 抽取规则\n"
|
||||
"1. subject 必须用明确的人名或固定称谓:\n"
|
||||
f" - 叙述者本人统一用「{narrator_label}」\n"
|
||||
" - 其他人用全名或稳定专名(如「王伟」),禁止用「他」「她」「我」「我们大伙」等代词作 subject;"
|
||||
"若代词在上下文中可唯一解析为某人,则 subject 写该人姓名/专名\n"
|
||||
"2. 事件、职务变动、地点迁移等一律按**过去回忆**理解;travel/调动/命令类表述勿写成「即将要做」"
|
||||
"除非原文明确为未来时态\n"
|
||||
"3. 若可推断大约年代或人生阶段,将 approximate_era 写入 object_json(与 value 等字段并存),"
|
||||
'例如 "1990年代"、"2001年"、"退休后"、"30岁前后"\n'
|
||||
"4. fact_type: person|event|relation|place|milestone\n"
|
||||
"5. predicate:简短中文谓语(如「出生地」「担任职务」「调往」)\n"
|
||||
"6. object_json:字符串或对象;可含 value、approximate_era 等\n"
|
||||
"7. confidence 0..1;source_chunk_id 必须等于某段 [chunk_id=...] 中的 id\n\n"
|
||||
'只输出 JSON:{"facts":[...]},无事实则 {"facts":[]}。\n\n'
|
||||
)
|
||||
|
||||
|
||||
def extract_facts_from_transcript_sync(
|
||||
llm: Any,
|
||||
numbered_blocks: str,
|
||||
*,
|
||||
narrator_name: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""同步:带 chunk_id 标记的文本 → 事实列表。"""
|
||||
if not llm or not (numbered_blocks or "").strip():
|
||||
return []
|
||||
text = numbered_blocks.strip()[: _max_transcript_chars()]
|
||||
prompt = (
|
||||
"你是回忆录记忆抽取助手。阅读下列带 [chunk_id=...] 的文本块,抽取可核查的事实。\n"
|
||||
"每个事实含 fact_type: person|event|relation|place|milestone;subject;predicate;"
|
||||
"object_json(可为字符串或对象);confidence 0..1;source_chunk_id 必须等于某段的 chunk id。\n"
|
||||
'只输出 JSON:{"facts":[...]},无事实则 {"facts":[]}。\n\n'
|
||||
f"{text}"
|
||||
)
|
||||
narrator_label = (narrator_name or "").strip() or "叙述者"
|
||||
prompt = _facts_extraction_instructions(narrator_label) + text
|
||||
try:
|
||||
raw = invoke_json_object(
|
||||
llm,
|
||||
@@ -50,19 +71,17 @@ def extract_facts_from_transcript_sync(llm: Any, numbered_blocks: str) -> list[d
|
||||
|
||||
|
||||
async def extract_facts_from_transcript_async(
|
||||
llm: Any, numbered_blocks: str
|
||||
llm: Any,
|
||||
numbered_blocks: str,
|
||||
*,
|
||||
narrator_name: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""异步版。"""
|
||||
if not llm or not (numbered_blocks or "").strip():
|
||||
return []
|
||||
text = numbered_blocks.strip()[: _max_transcript_chars()]
|
||||
prompt = (
|
||||
"你是回忆录记忆抽取助手。阅读下列带 [chunk_id=...] 的文本块,抽取可核查的事实。\n"
|
||||
"每个事实含 fact_type: person|event|relation|place|milestone;subject;predicate;"
|
||||
"object_json;confidence 0..1;source_chunk_id 必须等于某段的 chunk id。\n"
|
||||
'只输出 JSON:{"facts":[...]},无事实则 {"facts":[]}。\n\n'
|
||||
f"{text}"
|
||||
)
|
||||
narrator_label = (narrator_name or "").strip() or "叙述者"
|
||||
prompt = _facts_extraction_instructions(narrator_label) + text
|
||||
try:
|
||||
raw = await ainvoke_json_object(
|
||||
llm,
|
||||
@@ -81,11 +100,24 @@ async def extract_facts_from_transcript_async(
|
||||
|
||||
async def extract_facts(chunk_text: str, *, user_id: str) -> list[dict]:
|
||||
"""兼容旧接口:单块文本(无 chunk id 时传空 source_chunk_id)。"""
|
||||
from app.core.db import AsyncSessionLocal
|
||||
from app.core.dependencies import get_llm_provider_fast
|
||||
from app.features.user.models import User
|
||||
|
||||
llm = get_llm_provider_fast().langchain_llm
|
||||
narrator_name: str | None = None
|
||||
try:
|
||||
async with AsyncSessionLocal() as db:
|
||||
u = await db.get(User, user_id)
|
||||
if u and (u.nickname or "").strip():
|
||||
narrator_name = (u.nickname or "").strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
blocks = f"[chunk_id=null]\n{chunk_text}"
|
||||
facts = await extract_facts_from_transcript_async(llm, blocks)
|
||||
facts = await extract_facts_from_transcript_async(
|
||||
llm, blocks, narrator_name=narrator_name
|
||||
)
|
||||
for f in facts:
|
||||
if f.get("source_chunk_id") in (None, "null", ""):
|
||||
f["source_chunk_id"] = None
|
||||
|
||||
@@ -10,7 +10,6 @@ from sqlalchemy import (
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import TSVECTOR as TSVector
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.core.db import Base, utc_now
|
||||
@@ -46,8 +45,6 @@ class MemoryChunk(Base):
|
||||
content = Column(Text, nullable=False)
|
||||
# pgvector embedding — Alembic migration 负责 CREATE EXTENSION vector 及列类型
|
||||
embedding = Column(pgvector_type, nullable=True)
|
||||
# PostgreSQL FTS — Alembic migration 负责 generated tsvector 列 + GIN index
|
||||
content_tsv = Column(TSVector, nullable=True)
|
||||
chunk_index = Column(Integer, nullable=False)
|
||||
speaker = Column(String, nullable=True)
|
||||
event_year = Column(Integer, nullable=True)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Memory repository — MemorySource, MemoryChunk, MemoryFact, TimelineEvent data access."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import delete, literal, or_, select, text, tuple_, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -16,6 +17,9 @@ from app.features.memory.models import (
|
||||
TimelineEvent,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.ports.embedding import EmbeddingProvider
|
||||
|
||||
|
||||
def _new_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
@@ -105,16 +109,6 @@ async def create_chunk(
|
||||
return chunk
|
||||
|
||||
|
||||
def update_chunk_fts_sync(session: Session, chunk_id: str) -> None:
|
||||
"""Populate content_tsv for FTS (sync). Caller must commit."""
|
||||
session.execute(
|
||||
text(
|
||||
"UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
|
||||
),
|
||||
{"id": chunk_id},
|
||||
)
|
||||
|
||||
|
||||
def update_chunk_embedding_sync(
|
||||
session: Session, chunk_id: str, embedding: list[float]
|
||||
) -> None:
|
||||
@@ -133,39 +127,6 @@ async def update_chunk_embedding(
|
||||
chunk.embedding = embedding
|
||||
|
||||
|
||||
async def update_chunk_fts(db: AsyncSession, chunk_id: str) -> None:
|
||||
"""Populate content_tsv for FTS. Caller must commit."""
|
||||
await db.execute(
|
||||
text(
|
||||
"UPDATE memory_chunks SET content_tsv = to_tsvector('simple', content) WHERE id = :id"
|
||||
),
|
||||
{"id": chunk_id},
|
||||
)
|
||||
|
||||
|
||||
async def search_chunks_fts(
|
||||
db: AsyncSession, user_id: str, query: str, limit: int = 20
|
||||
) -> list[dict]:
|
||||
"""FTS search on memory_chunks. Returns list of {id, content, chunk_index}."""
|
||||
if not query or not query.strip():
|
||||
return []
|
||||
q = query.strip()
|
||||
stmt = text("""
|
||||
SELECT id, content, chunk_index
|
||||
FROM memory_chunks
|
||||
WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
|
||||
AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
|
||||
ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
|
||||
LIMIT :lim
|
||||
""")
|
||||
result = await db.execute(stmt, {"user_id": user_id, "q": q, "q2": q, "lim": limit})
|
||||
rows = result.mappings().all()
|
||||
return [
|
||||
{"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
async def get_chunks_by_ids(
|
||||
db: AsyncSession, chunk_ids: list[str]
|
||||
) -> list[MemoryChunk]:
|
||||
@@ -219,29 +180,6 @@ def get_timeline_events_for_user_sync(
|
||||
return list(session.execute(stmt).unique().scalars().all())
|
||||
|
||||
|
||||
def search_chunks_fts_sync(
|
||||
session: Session, user_id: str, query: str, limit: int = 20
|
||||
) -> list[dict]:
|
||||
"""FTS on memory_chunks(sync,Celery)。"""
|
||||
if not query or not query.strip():
|
||||
return []
|
||||
q = query.strip()
|
||||
stmt = text("""
|
||||
SELECT id, content, chunk_index
|
||||
FROM memory_chunks
|
||||
WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
|
||||
AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q)
|
||||
ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC
|
||||
LIMIT :lim
|
||||
""")
|
||||
result = session.execute(stmt, {"user_id": user_id, "q": q, "q2": q, "lim": limit})
|
||||
rows = result.mappings().all()
|
||||
return [
|
||||
{"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
def search_facts_for_user_sync(
|
||||
session: Session, user_id: str, query: str, limit: int = 20
|
||||
) -> list[MemoryFact]:
|
||||
@@ -401,6 +339,49 @@ async def search_chunks_vector(
|
||||
]
|
||||
|
||||
|
||||
def search_chunks_vector_sync(
|
||||
session: Session, user_id: str, query_embedding: list[float], limit: int = 20
|
||||
) -> list[dict]:
|
||||
"""pgvector 余弦距离检索(sync,Celery)。返回 {id, content, chunk_index, distance}。"""
|
||||
if not query_embedding:
|
||||
return []
|
||||
stmt = text("""
|
||||
SELECT id, content, chunk_index,
|
||||
(embedding <=> CAST(:emb AS vector)) AS distance
|
||||
FROM memory_chunks
|
||||
WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false)
|
||||
AND embedding IS NOT NULL
|
||||
ORDER BY embedding <=> CAST(:emb2 AS vector)
|
||||
LIMIT :lim
|
||||
""")
|
||||
emb_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
|
||||
result = session.execute(
|
||||
stmt,
|
||||
{"user_id": user_id, "emb": emb_str, "emb2": emb_str, "lim": limit},
|
||||
)
|
||||
rows = result.mappings().all()
|
||||
return [
|
||||
{
|
||||
"id": r["id"],
|
||||
"content": r["content"],
|
||||
"chunk_index": r["chunk_index"],
|
||||
"distance": float(r["distance"]),
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
def list_users_with_recent_chunks_sync(session: Session, *, hours: int) -> list[str]:
|
||||
"""最近 N 小时内有新 chunk 的用户 id(Beat compaction 扫描)。"""
|
||||
if hours < 1:
|
||||
hours = 1
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
stmt = (
|
||||
select(MemoryChunk.user_id).where(MemoryChunk.created_at >= cutoff).distinct()
|
||||
)
|
||||
return list(session.execute(stmt).scalars().all())
|
||||
|
||||
|
||||
def list_summaries_for_evidence_sync(
|
||||
session: Session, *, user_id: str, q: str, limit: int
|
||||
) -> list[dict]:
|
||||
@@ -453,17 +434,27 @@ def list_summaries_for_evidence_sync(
|
||||
|
||||
|
||||
def retrieve_evidence_sync(
|
||||
session: Session, user_id: str, query: str, *, top_k: int = 10
|
||||
session: Session,
|
||||
user_id: str,
|
||||
query: str,
|
||||
*,
|
||||
top_k: int = 10,
|
||||
embedding_provider: "EmbeddingProvider | None" = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Sync evidence retrieval for Celery tasks.
|
||||
|
||||
能力:**仅 FTS** 检索 chunks(与 `HybridRetriever` 的 FTS+向量 RRF 不同,见
|
||||
`api/docs/memory-retrieval.md`);facts/timeline 按 query ILIKE;fallback 见 repo。
|
||||
chunks:**向量**(pgvector)与异步 `HybridRetriever` 对齐;facts/timeline 按 query ILIKE。
|
||||
"""
|
||||
from app.features.memory.evidence import retrieve_evidence_bundle_sync
|
||||
|
||||
return retrieve_evidence_bundle_sync(session, user_id, query, top_k=top_k)
|
||||
return retrieve_evidence_bundle_sync(
|
||||
session,
|
||||
user_id,
|
||||
query,
|
||||
top_k=top_k,
|
||||
embedding_provider=embedding_provider,
|
||||
)
|
||||
|
||||
|
||||
async def get_timeline_events_for_user(
|
||||
|
||||
@@ -1,31 +1,17 @@
|
||||
"""Hybrid retriever — metadata filter + FTS + vector retrieval + score fusion."""
|
||||
"""Hybrid retriever — 向量检索 + 元数据证据包。"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.features.memory.evidence import retrieve_evidence_bundle_async
|
||||
from app.features.memory.repo import search_chunks_fts, search_chunks_vector
|
||||
from app.features.memory.repo import search_chunks_vector
|
||||
from app.ports.embedding import EmbeddingProvider
|
||||
|
||||
|
||||
def _rrf_merge(
|
||||
fts_items: list[dict], vector_items: list[dict], k: int = 60
|
||||
) -> list[dict]:
|
||||
"""Reciprocal Rank Fusion. Merge FTS and vector results by id."""
|
||||
scores: dict[str, float] = {}
|
||||
for rank, item in enumerate(fts_items):
|
||||
cid = item["id"]
|
||||
scores[cid] = scores.get(cid, 0) + 1 / (k + rank + 1)
|
||||
for rank, item in enumerate(vector_items):
|
||||
cid = item["id"]
|
||||
scores[cid] = scores.get(cid, 0) + 1 / (k + rank + 1)
|
||||
|
||||
all_items = {x["id"]: x for x in fts_items + vector_items}
|
||||
sorted_ids = sorted(scores.keys(), key=lambda i: scores[i], reverse=True)
|
||||
return [all_items[i] for i in sorted_ids]
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HybridRetriever:
|
||||
"""Combine FTS, vector, and metadata filter into evidence bundle."""
|
||||
"""向量 chunk 检索 + facts/timeline/summaries/stories。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -51,27 +37,27 @@ class HybridRetriever:
|
||||
)
|
||||
|
||||
q = query.strip()
|
||||
fts_chunks = await search_chunks_fts(
|
||||
self._db, user_id=user_id, query=query, limit=top_k * 2
|
||||
)
|
||||
|
||||
vector_chunks: list[dict] = []
|
||||
if self._embedding and q:
|
||||
merged_chunk_dicts: list[dict] = []
|
||||
if self._embedding:
|
||||
q_emb = await self._embedding.embed_text(q)
|
||||
if q_emb:
|
||||
vector_chunks = await search_chunks_vector(
|
||||
self._db, user_id=user_id, query_embedding=q_emb, limit=top_k * 2
|
||||
vector_rows = await search_chunks_vector(
|
||||
self._db, user_id, q_emb, limit=top_k
|
||||
)
|
||||
|
||||
merged = _rrf_merge(fts_chunks, vector_chunks)[:top_k]
|
||||
merged_chunk_dicts = [
|
||||
{
|
||||
"id": c["id"],
|
||||
"content": c["content"],
|
||||
"chunk_index": c.get("chunk_index", 0),
|
||||
}
|
||||
for c in merged
|
||||
]
|
||||
merged_chunk_dicts = [
|
||||
{
|
||||
"id": c["id"],
|
||||
"content": c["content"],
|
||||
"chunk_index": c.get("chunk_index", 0),
|
||||
}
|
||||
for c in vector_rows
|
||||
]
|
||||
else:
|
||||
logger.warning(
|
||||
"HybridRetriever empty_query_embedding user_id={}", user_id
|
||||
)
|
||||
else:
|
||||
logger.warning("HybridRetriever no_embedding_provider user_id={}", user_id)
|
||||
|
||||
return await retrieve_evidence_bundle_async(
|
||||
self._db,
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
"""
|
||||
MemoryService — conversation / memoir 的统一门面。
|
||||
|
||||
- ingest_transcript: transcript -> memory_sources, chunks, embedding, FTS
|
||||
- ingest_transcript: transcript -> memory_sources, chunks, embedding
|
||||
- ingest 后可选:LLM 富化(session/rolling 摘要、事实、时间线)
|
||||
- retrieve: 委托 HybridRetriever 返回 evidence bundle(FTS + 可选向量 RRF)
|
||||
- retrieve: 委托 HybridRetriever 返回 evidence bundle(向量 chunks)
|
||||
|
||||
Celery 侧使用 `ingest_transcript_sync` + `retrieve_evidence_sync`,与异步路径差异见
|
||||
Celery 侧使用 `ingest_transcript_sync` + `retrieve_evidence_sync`,与异步路径对齐见
|
||||
`api/docs/memory-retrieval.md`。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
@@ -23,7 +21,6 @@ from app.features.memory.repo import (
|
||||
set_chunk_excluded,
|
||||
set_memory_fact_status,
|
||||
update_chunk_embedding,
|
||||
update_chunk_fts,
|
||||
)
|
||||
from app.ports.embedding import EmbeddingProvider
|
||||
|
||||
@@ -45,7 +42,7 @@ class MemoryService:
|
||||
) -> str:
|
||||
"""
|
||||
Ingest conversation transcript into memory.
|
||||
Creates MemorySource, chunks, populates embedding + FTS.
|
||||
Creates MemorySource, chunks, populates embedding.
|
||||
Returns source_id.
|
||||
"""
|
||||
if not transcript or not transcript.strip():
|
||||
@@ -73,10 +70,6 @@ class MemoryService:
|
||||
|
||||
await self._db.flush()
|
||||
|
||||
# FTS: populate content_tsv
|
||||
for chunk_id, _ in chunk_records:
|
||||
await update_chunk_fts(self._db, chunk_id)
|
||||
|
||||
# Embedding: 若有 provider 则写入
|
||||
if self._embedding and chunk_records:
|
||||
texts = [c for _, c in chunk_records]
|
||||
@@ -186,7 +179,7 @@ def ingest_transcript_sync(
|
||||
) -> str:
|
||||
"""
|
||||
Sync transcript ingest for Celery tasks.
|
||||
Creates source + chunks + FTS, and best-effort populates embeddings.
|
||||
Creates source + chunks, and best-effort populates embeddings.
|
||||
Returns source_id.
|
||||
"""
|
||||
from app.core.dependencies import get_embedding_provider
|
||||
@@ -195,7 +188,6 @@ def ingest_transcript_sync(
|
||||
create_chunk_sync,
|
||||
create_source_sync,
|
||||
update_chunk_embedding_sync,
|
||||
update_chunk_fts_sync,
|
||||
)
|
||||
|
||||
if not transcript or not transcript.strip():
|
||||
@@ -222,13 +214,12 @@ def ingest_transcript_sync(
|
||||
)
|
||||
session.flush()
|
||||
chunk_records.append((chunk.id, content))
|
||||
update_chunk_fts_sync(session, chunk.id)
|
||||
|
||||
try:
|
||||
embedding_provider = get_embedding_provider()
|
||||
if chunk_records and embedding_provider is not None:
|
||||
texts = [content for _, content in chunk_records]
|
||||
embeddings = asyncio.run(embedding_provider.embed_texts(texts))
|
||||
embeddings = embedding_provider.embed_texts_sync(texts)
|
||||
for (chunk_id, _), emb in zip(chunk_records, embeddings):
|
||||
if emb:
|
||||
update_chunk_embedding_sync(session, chunk_id, emb)
|
||||
|
||||
@@ -5,6 +5,10 @@ from typing import Protocol, runtime_checkable
|
||||
|
||||
@runtime_checkable
|
||||
class EmbeddingProvider(Protocol):
|
||||
def is_available(self) -> bool:
|
||||
"""进程内 embedding 已配置且可发起调用(无 key / 未初始化 client 时为 False)。"""
|
||||
...
|
||||
|
||||
async def embed_text(self, text: str) -> list[float]:
|
||||
"""Embed a single text into a vector."""
|
||||
...
|
||||
@@ -12,3 +16,11 @@ class EmbeddingProvider(Protocol):
|
||||
async def embed_texts(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed multiple texts into vectors."""
|
||||
...
|
||||
|
||||
def embed_text_sync(self, text: str) -> list[float]:
|
||||
"""同步嵌入单条文本(Celery / sync DB 会话)。"""
|
||||
...
|
||||
|
||||
def embed_texts_sync(self, texts: list[str]) -> list[list[float]]:
|
||||
"""同步嵌入多条文本(Celery / sync DB 会话)。"""
|
||||
...
|
||||
|
||||
@@ -63,11 +63,9 @@ celery_app.conf.update(
|
||||
# 不设置自定义队列路由,使用 Celery 默认队列
|
||||
)
|
||||
|
||||
# 定时任务配置(如果需要)
|
||||
celery_app.conf.beat_schedule = {
|
||||
# 示例:每小时清理过期会话
|
||||
# "cleanup-expired-sessions": {
|
||||
# "task": "app.tasks.cleanup.cleanup_sessions",
|
||||
# "schedule": 3600.0,
|
||||
# },
|
||||
"memory-compaction-sweep": {
|
||||
"task": "app.tasks.memory_compaction_tasks.memory_compaction_sweep",
|
||||
"schedule": 6 * 3600.0,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -15,14 +15,33 @@ from app.core.memory_compaction_schedule import (
|
||||
finalize_memory_compaction_run,
|
||||
read_debounce_deadline_ts,
|
||||
release_scheduler_gate,
|
||||
schedule_memory_compaction_run,
|
||||
set_incremental_cursor_pair,
|
||||
)
|
||||
from app.core.redis_lock import acquire_redis_lock, release_redis_lock
|
||||
from app.features.memory.compaction_service import run_memory_compaction_sync
|
||||
from app.features.memory.repo import list_users_with_recent_chunks_sync
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@shared_task
|
||||
def memory_compaction_sweep() -> dict[str, Any]:
|
||||
"""Beat:为近期有记忆写入的用户调度 compaction(debounce 仍由 schedule 合并)。"""
|
||||
if not settings.memory_compaction_enabled:
|
||||
return {"skipped": True, "reason": "disabled"}
|
||||
hours = int(settings.memory_compaction_sweep_recent_hours)
|
||||
with get_sync_db() as session:
|
||||
user_ids = list_users_with_recent_chunks_sync(session, hours=hours)
|
||||
ctx_base: dict[str, Any] = {"trigger_source": "beat", "sweep_hours": hours}
|
||||
for uid in user_ids:
|
||||
schedule_memory_compaction_run(uid, dict(ctx_base))
|
||||
logger.info(
|
||||
"memory_compaction_sweep hours={} scheduled_users={}", hours, len(user_ids)
|
||||
)
|
||||
return {"scheduled": len(user_ids), "user_ids": user_ids}
|
||||
|
||||
|
||||
@shared_task(bind=True, max_retries=12, default_retry_delay=20)
|
||||
def memory_compaction_run(
|
||||
self, user_id: str, context: dict[str, Any] | None = None
|
||||
|
||||
Reference in New Issue
Block a user