Files
life-echo/api/app/features/memory/compaction_service.py
Kevin e884409410 feat(api): Memory compaction 管线与调度修复,同步环境变量示例
Memory compaction(近重复 chunk 软排除)
- 新增 compaction 调度:Redis debounce、scheduler gate、增量游标;任务结束时 finalize,避免 gate 长期占用并处理运行期新 trigger。
- Celery memory_compaction_run:debounce 未到点则 retry;用户级 Redis 锁;成功路径更新游标并 finalize;异常时释放 scheduler gate 并 self.retry,避免静默卡死调度与瞬时失败不重试。
- compaction_service:多层判定 + canonical 打分;无 embedding 时停止前移游标(awaiting_embeddings);curation details 补全 trigger 等上下文。
- ingest_transcript_sync:同步路径尽力写入 embedding,与异步 ingest 行为对齐,避免 compaction 永远扫不到无向量 chunk。
- repo:新增 update_chunk_embedding_sync。
测试
- 扩展 test_memory_compaction:调度合并、finalize、ingest embedding、无向量游标、异常路径 gate+retry 等回归用例
2026-03-30 10:46:35 +08:00

489 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
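"""
Memory compaction: incremental near-duplicate chunk detection and soft exclusion
(is_excluded + MemoryCurationAction).

Depends only on repo / settings; intended to be called from synchronous Celery tasks.
"""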
from __future__ import annotations
import re
import time
from datetime import datetime, timezone
from typing import Any
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.logging import get_logger
from app.features.memory.models import MemoryChunk, MemorySource
from app.features.memory.repo import (
create_curation_action_sync,
get_first_chunk_after_cursor_sync,
get_memory_chunk_sync,
list_incremental_chunks_for_compaction_sync,
search_nearest_chunks_for_compaction_sync,
set_chunk_excluded_sync,
)
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Matches a run of one or more whitespace characters; _normalize_text uses it
# to collapse such runs into a single space.
_WS_RE = re.compile(r"\s+")
# Matches any single character that is neither a word character, whitespace,
# nor in the U+4E00-U+9FFF range; _normalize_text replaces each match with a space.
_PUNCT_RE = re.compile(r"[^\w\s\u4e00-\u9fff]", re.UNICODE)
def _normalize_text(s: str) -> str:
    """Return *s* lower-cased, with punctuation blanked and whitespace collapsed.

    A falsy input (``None`` or ``""``) yields the empty string.
    """
    lowered = ("" if not s else s).strip().lower()
    # Punctuation becomes a space first so that adjacent words stay separated.
    without_punct = _PUNCT_RE.sub(" ", lowered)
    collapsed = _WS_RE.sub(" ", without_punct)
    return collapsed.strip()
def _word_jaccard(a: str, b: str) -> float:
    """Return the Jaccard similarity of the whitespace-separated token sets of *a* and *b*.

    Both inputs are normalized with ``_normalize_text`` before tokenizing.
    Returns ``0.0`` when either token set is empty.
    """
    tokens_a = set(_normalize_text(a).split())
    tokens_b = set(_normalize_text(b).split())
    if not (tokens_a and tokens_b):
        return 0.0
    shared = tokens_a.intersection(tokens_b)
    combined = tokens_a.union(tokens_b)
    return len(shared) / len(combined)
def text_layer_match(a: str, b: str, *, jaccard_min: float) -> bool:
    """Return whether *a* and *b* count as textual near-duplicates.

    They match when their word-level Jaccard similarity reaches *jaccard_min*,
    or when both normalized texts are at least 12 characters long and one
    contains the other.
    """
    if _word_jaccard(a, b) >= jaccard_min:
        return True
    norm_a = _normalize_text(a)
    norm_b = _normalize_text(b)
    # The containment test is only applied when both texts are long enough.
    if min(len(norm_a), len(norm_b)) < 12:
        return False
    return norm_a in norm_b or norm_b in norm_a
def metadata_layer_match(
    c: MemoryChunk,
    n: dict[str, Any],
    *,
    event_year_window: int,
) -> bool:
    """Return whether chunk *c* and neighbor row *n* agree on metadata.

    They match when both carry the same (non-``None``) ``source_id``, or when
    both have an ``event_year`` and the two years differ by at most
    *event_year_window*.

    Args:
        c: The chunk being examined; ``source_id`` and ``event_year`` are read.
        n: Neighbor row as a mapping; ``"source_id"`` and ``"event_year"`` are read.
        event_year_window: Maximum allowed absolute difference between event years.
    """
    source_id = c.source_id
    # A missing source id on both sides must not count as "same source":
    # ``None == None`` would otherwise award a duplicate layer for free.
    if source_id is not None and source_id == n.get("source_id"):
        return True
    chunk_year, neighbor_year = c.event_year, n.get("event_year")
    if chunk_year is not None and neighbor_year is not None:
        if abs(int(chunk_year) - int(neighbor_year)) <= event_year_window:
            return True
    return False
def embedding_layer_match(distance: float, *, similarity_threshold: float) -> bool:
    """Return whether ``1.0 - distance`` reaches *similarity_threshold*."""
    return similarity_threshold <= 1.0 - distance
def canonical_score(
    *,
    content: str,
    metadata_json: Any,
    source_type: str | None,
) -> float:
    """Score a chunk for canonical selection among near-duplicates.

    The base score is the content length. A bonus of 8.0 is added when
    *source_type* is ``"draft"`` or ``"story"``, plus 0.15 per metadata key
    (counting at most 24 keys) when *metadata_json* is a dict.
    """
    bonus = 8.0 if source_type in ("draft", "story") else 0.0
    if isinstance(metadata_json, dict):
        counted_keys = min(len(metadata_json), 24)
        bonus += counted_keys * 0.15
    base = float(len(content or ""))
    return base + bonus
def _source_type_for_chunk(session: Session, chunk: MemoryChunk) -> str | None:
    """Look up the chunk's ``MemorySource`` by ``source_id`` and return its ``source_type``.

    Returns ``None`` when the session finds no such source.
    """
    source = session.get(MemorySource, chunk.source_id)
    if not source:
        return None
    return source.source_type
def _build_curation_details(
*,
ctx: dict[str, Any],
winner_id: str,
layers: int,
distance: float,
) -> dict[str, Any]:
details: dict[str, Any] = {
"reason": "near_duplicate",
"similar_to_chunk_id": winner_id,
"layers": layers,
"cosine_distance": distance,
}
for key in (
"trigger_source",
"trigger_time",
"pipeline_run_id",
"request_id",
"story_ids",
"story_dispatch_ids",
"chapter_ids",
"recomposed_chapter_ids",
"chapters_to_enqueue",
"candidate_chunk_ids",
"candidate_source_ids",
):
value = ctx.get(key)
if value is not None:
details[key] = value
return details
def count_duplicate_layers(
    *,
    chunk: MemoryChunk,
    neighbor: dict[str, Any],
    distance: float,
    similarity_threshold: float,
    jaccard_min: float,
    event_year_window: int,
) -> int:
    """Count how many of the three duplicate layers match for *chunk* and *neighbor*.

    The layers are evaluated in order: embedding similarity, text similarity,
    and metadata agreement. The result is an integer from 0 to 3.
    """
    embedding_hit = embedding_layer_match(
        distance, similarity_threshold=similarity_threshold
    )
    text_hit = text_layer_match(
        chunk.content or "",
        neighbor.get("content") or "",
        jaccard_min=jaccard_min,
    )
    metadata_hit = metadata_layer_match(
        chunk, neighbor, event_year_window=event_year_window
    )
    return sum(1 for hit in (embedding_hit, text_hit, metadata_hit) if hit)
def _advance_cursor_past_excluded_only_sync(
    session: Session,
    user_id: str,
    cursor_ts: datetime,
    cursor_id: str,
    *,
    max_steps: int,
) -> tuple[datetime, str] | None:
    """Walk the cursor forward over a run of excluded chunks.

    Intended for the case where the pending incremental set is empty although
    chunks still exist after the cursor, which usually means they are all
    excluded. The cursor is moved to the last chunk of that excluded run so a
    later run can reach the non-excluded chunks behind it.

    Returns:
        The new ``(timestamp, chunk_id)`` cursor when at least one excluded
        chunk was skipped. ``None`` when nothing follows the cursor, or when
        the very first following chunk is not excluded (logged as an anomaly,
        since that should not coincide with an empty incremental set).
        At most *max_steps* chunks are inspected; hitting that cap is logged.
    """
    pos_ts, pos_id = cursor_ts, cursor_id
    moved = False
    for _step in range(max_steps):
        following = get_first_chunk_after_cursor_sync(
            session,
            user_id=user_id,
            after_cursor_ts=pos_ts,
            after_chunk_id=pos_id,
        )
        if following is None:
            return (pos_ts, pos_id) if moved else None
        if following.is_excluded:
            stamp = following.created_at or datetime.min.replace(tzinfo=timezone.utc)
            if stamp.tzinfo is None:
                # Naive timestamps are tagged as UTC.
                stamp = stamp.replace(tzinfo=timezone.utc)
            pos_ts, pos_id, moved = stamp, following.id, True
            continue
        # First non-excluded chunk: stop here.
        if moved:
            return (pos_ts, pos_id)
        logger.warning(
            "memory_compaction_cursor_anomaly user_id={} cursor_ts={} cursor_id={} "
            "next_chunk_id={} next_is_excluded=false",
            user_id,
            cursor_ts,
            cursor_id,
            following.id,
        )
        return None
    logger.warning(
        "memory_compaction_cursor_skip_cap user_id={} steps={}",
        user_id,
        max_steps,
    )
    return (pos_ts, pos_id) if moved else None
def _compaction_cursor_ts(chunk: MemoryChunk) -> datetime:
    """Return the chunk's ``created_at`` as a timezone-aware cursor timestamp.

    A missing ``created_at`` becomes ``datetime.min`` in UTC and a naive value
    is tagged as UTC, so values can be compared and serialized consistently.
    """
    ts = chunk.created_at or datetime.min.replace(tzinfo=timezone.utc)
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=timezone.utc)
    return ts


def run_memory_compaction_sync(
    session: Session, user_id: str, context: dict[str, Any] | None
) -> dict[str, Any]:
    """Soft-exclude near-duplicate chunks among the user's incremental chunks.

    The caller is responsible for committing the session.

    Args:
        session: Synchronous SQLAlchemy session passed to the repo helpers.
        user_id: Owner of the chunks to compact.
        context: Optional trigger context. Keys read here are
            ``_cursor_pair_override``, ``candidate_chunk_ids``,
            ``candidate_source_ids`` and ``trigger_source``; further keys are
            copied into the curation details.

    Returns:
        A structured result for logging and as the task return value, with
        keys ``chunks_scanned``, ``chunks_excluded``, ``candidates_considered``,
        ``new_cursor_ts`` (ISO string or ``None``), ``new_cursor_id``,
        ``duration_ms`` and ``skipped_reason`` (plus ``pending_chunk_id`` when
        the run stops at a chunk without an embedding).
    """
    ctx = dict(context or {})
    t0 = time.perf_counter()
    trigger_source = ctx.get("trigger_source", "")

    pair = ctx.get("_cursor_pair_override")
    if pair is None:
        # NOTE(review): imported at call time; presumably to avoid an import
        # cycle with the scheduling module — confirm before moving it.
        from app.core.memory_compaction_schedule import get_incremental_cursor_pair

        cursor_ts, cursor_id = get_incremental_cursor_pair(user_id)
    else:
        cursor_ts, cursor_id = pair  # type: ignore[misc]

    cand_chunks = ctx.get("candidate_chunk_ids")
    if cand_chunks is not None:
        cand_chunks = [str(x) for x in cand_chunks]
    cand_sources = ctx.get("candidate_source_ids")
    if cand_sources is not None:
        cand_sources = [str(x) for x in cand_sources]
    # A key that is present but None passes no filter to the query below, so
    # only a non-None value counts as narrowing the candidate set.
    has_candidate_filter = cand_chunks is not None or cand_sources is not None

    max_chunks = settings.memory_compaction_max_chunks_per_run
    incremental = list_incremental_chunks_for_compaction_sync(
        session,
        user_id=user_id,
        after_cursor_ts=cursor_ts,
        after_chunk_id=cursor_id,
        limit=max_chunks,
        candidate_chunk_ids=cand_chunks,
        candidate_source_ids=cand_sources,
    )
    if not incremental:
        ms = (time.perf_counter() - t0) * 1000
        # With a candidate id/source filter, an empty result may only mean an
        # empty intersection, so the global cursor must not be advanced blindly.
        if not has_candidate_filter:
            tail = _advance_cursor_past_excluded_only_sync(
                session,
                user_id,
                cursor_ts,
                cursor_id,
                max_steps=max(
                    settings.memory_compaction_max_chunks_per_run * 2,
                    500,
                ),
            )
            if tail is not None:
                new_ts, new_id = tail
                logger.info(
                    "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 "
                    "duration_ms={:.1f} skipped_reason=empty_incremental_cursor_advanced "
                    "trigger_source={}",
                    user_id,
                    ms,
                    trigger_source,
                )
                return {
                    "chunks_scanned": 0,
                    "chunks_excluded": 0,
                    "candidates_considered": 0,
                    "new_cursor_ts": new_ts.isoformat(),
                    "new_cursor_id": new_id,
                    "duration_ms": round(ms, 1),
                    "skipped_reason": "empty_incremental_cursor_advanced",
                }
        logger.info(
            "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 "
            "duration_ms={:.1f} skipped_reason=empty_incremental trigger_source={}",
            user_id,
            ms,
            trigger_source,
        )
        return {
            "chunks_scanned": 0,
            "chunks_excluded": 0,
            "candidates_considered": 0,
            "new_cursor_ts": None,
            "new_cursor_id": None,
            "duration_ms": round(ms, 1),
            "skipped_reason": "empty_incremental",
        }

    sim_th = settings.memory_compaction_chunk_similarity_threshold
    min_layers = settings.memory_compaction_min_layers_for_exclude
    jaccard_min = settings.memory_compaction_text_jaccard_min
    year_w = settings.memory_compaction_metadata_event_year_window
    max_neighbors = settings.memory_compaction_max_neighbors_per_chunk
    max_excludes = settings.memory_compaction_max_excludes_per_run

    local_excluded: set[str] = set()
    excludes_done = 0
    candidates_considered = 0

    # Stable order: process earlier-written chunks first. Timestamps are
    # normalized so naive and aware values never get compared directly.
    incremental_sorted = sorted(
        incremental,
        key=lambda c: (_compaction_cursor_ts(c), c.id),
    )

    last_cursor_chunk: MemoryChunk | None = None
    chunks_scanned_this_run = 0

    for chunk in incremental_sorted:
        if excludes_done >= max_excludes:
            break
        if chunk.id in local_excluded:
            last_cursor_chunk = chunk
            continue
        row = get_memory_chunk_sync(session, chunk.id, user_id)
        if row is None or row.is_excluded:
            last_cursor_chunk = chunk
            continue

        emb = row.embedding
        if emb is None:
            # Stop here without moving the cursor past this chunk, so it is
            # scanned again once its embedding exists.
            ms = (time.perf_counter() - t0) * 1000
            logger.info(
                "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
                "candidates={} duration_ms={:.1f} skipped_reason=awaiting_embeddings "
                "pending_chunk_id={} trigger_source={}",
                user_id,
                chunks_scanned_this_run,
                excludes_done,
                candidates_considered,
                ms,
                row.id,
                trigger_source,
            )
            # Same timestamp normalization as the regular return path, so the
            # cursor pair is always either fully set or fully None.
            if last_cursor_chunk is not None:
                partial_ts = _compaction_cursor_ts(last_cursor_chunk).isoformat()
                partial_id = last_cursor_chunk.id
            else:
                partial_ts = None
                partial_id = None
            return {
                "chunks_scanned": chunks_scanned_this_run,
                "chunks_excluded": excludes_done,
                "candidates_considered": candidates_considered,
                "new_cursor_ts": partial_ts,
                "new_cursor_id": partial_id,
                "duration_ms": round(ms, 1),
                "skipped_reason": "awaiting_embeddings",
                "pending_chunk_id": row.id,
            }

        chunks_scanned_this_run += 1
        st_c = _source_type_for_chunk(session, row)
        # The current chunk's score and timestamp do not depend on the
        # neighbor, so they are computed once per chunk.
        sc = canonical_score(
            content=row.content or "",
            metadata_json=row.metadata_json,
            source_type=st_c,
        )
        t_row = _compaction_cursor_ts(row)

        neighbors = search_nearest_chunks_for_compaction_sync(
            session,
            user_id=user_id,
            chunk_id=row.id,
            query_embedding=list(emb),
            limit=max_neighbors,
        )
        for nb in neighbors:
            # NOTE(review): when this cap triggers mid-chunk, the cursor still
            # advances past the chunk below although some neighbors were not
            # examined — confirm this is intended.
            if excludes_done >= max_excludes:
                break
            nid = nb["id"]
            if nid == row.id or nid in local_excluded:
                continue
            other = get_memory_chunk_sync(session, nid, user_id)
            if other is None or other.is_excluded:
                continue
            dist = float(nb["distance"])
            layers = count_duplicate_layers(
                chunk=row,
                neighbor=nb,
                distance=dist,
                similarity_threshold=sim_th,
                jaccard_min=jaccard_min,
                event_year_window=year_w,
            )
            if layers < min_layers:
                continue
            candidates_considered += 1

            sn = canonical_score(
                content=nb.get("content") or "",
                metadata_json=nb.get("metadata_json"),
                source_type=nb.get("source_type"),
            )
            t_nb = nb.get("created_at") or datetime.min.replace(tzinfo=timezone.utc)
            if t_nb.tzinfo is None:
                t_nb = t_nb.replace(tzinfo=timezone.utc)

            if sc > sn:
                loser_id, winner_id = nid, row.id
            elif sn > sc:
                loser_id, winner_id = row.id, nid
            elif t_row <= t_nb:
                # Equal scores: keep the chunk that was written earlier.
                loser_id, winner_id = nid, row.id
            else:
                loser_id, winner_id = row.id, nid

            loser = get_memory_chunk_sync(session, loser_id, user_id)
            if loser is None or loser.is_excluded:
                continue
            ok = set_chunk_excluded_sync(session, loser_id, user_id, True)
            if not ok:
                continue
            create_curation_action_sync(
                session,
                user_id=user_id,
                action_type="exclude",
                target_type="chunk",
                target_id=loser_id,
                details=_build_curation_details(
                    ctx=ctx,
                    winner_id=winner_id,
                    layers=layers,
                    distance=dist,
                ),
            )
            excludes_done += 1
            local_excluded.add(loser_id)
            logger.info(
                "memory_compaction_exclude user_id={} excluded_chunk_id={} "
                "kept_chunk_id={} layers={} distance={:.4f}",
                user_id,
                loser_id,
                winner_id,
                layers,
                dist,
            )
            if loser_id == row.id:
                # The current chunk itself was excluded; stop comparing it.
                break
        last_cursor_chunk = chunk

    if last_cursor_chunk is None:
        ms = (time.perf_counter() - t0) * 1000
        logger.warning(
            "memory_compaction_no_cursor_chunk user_id={} incremental_n={}",
            user_id,
            len(incremental_sorted),
        )
        return {
            "chunks_scanned": 0,
            "chunks_excluded": excludes_done,
            "candidates_considered": candidates_considered,
            "new_cursor_ts": None,
            "new_cursor_id": None,
            "duration_ms": round(ms, 1),
            "skipped_reason": "no_cursor_chunk",
        }

    new_cursor_ts = _compaction_cursor_ts(last_cursor_chunk)
    new_cursor_id = last_cursor_chunk.id
    ms = (time.perf_counter() - t0) * 1000
    logger.info(
        "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
        "candidates={} duration_ms={:.1f} trigger_source={}",
        user_id,
        chunks_scanned_this_run,
        excludes_done,
        candidates_considered,
        ms,
        trigger_source,
    )
    return {
        "chunks_scanned": chunks_scanned_this_run,
        "chunks_excluded": excludes_done,
        "candidates_considered": candidates_considered,
        "new_cursor_ts": new_cursor_ts.isoformat(),
        "new_cursor_id": new_cursor_id,
        "duration_ms": round(ms, 1),
        "skipped_reason": None,
    }