Files
life-echo/api/app/features/memory/compaction_service.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

502 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Memory compaction: near-duplicate detection over incremental chunks, applied as
a soft exclusion (``is_excluded`` + ``MemoryCurationAction``).

Depends only on the repo layer and settings; called by the async MemoryService.
"""
from __future__ import annotations
import re
import time
from datetime import datetime, timezone
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.features.memory.models import MemoryChunk, MemorySource
from app.features.memory.constants import memory
from app.features.memory.repo import (
create_curation_action,
get_first_chunk_after_cursor,
get_memory_chunk_for_user,
list_incremental_chunks_for_compaction,
mark_facts_stale_for_excluded_chunk,
search_nearest_chunks_for_compaction,
set_chunk_excluded,
)
logger = get_logger(__name__)
# Any run of whitespace; text normalization collapses it to a single space.
_WS_RE = re.compile(r"\s+")
# Any character that is not a word character, whitespace or a CJK ideograph
# (U+4E00-U+9FFF); text normalization replaces each match with a space.
_PUNCT_RE = re.compile(r"[^\w\s\u4e00-\u9fff]", re.UNICODE)
def _normalize_text(s: str) -> str:
    """Normalize *s* for textual comparison.

    Strips and lower-cases the text, turns punctuation into spaces and
    collapses whitespace runs to single spaces. ``None``/empty input yields
    an empty string.
    """
    text = (s or "").strip()
    text = text.lower()
    text = _PUNCT_RE.sub(" ", text)
    text = _WS_RE.sub(" ", text)
    return text.strip()
def _word_jaccard(a: str, b: str) -> float:
    """Jaccard similarity of the whitespace-separated tokens of two texts.

    Both texts are normalized first; returns 0.0 when either token set is
    empty.
    """
    tokens_a = set(_normalize_text(a).split())
    tokens_b = set(_normalize_text(b).split())
    if tokens_a and tokens_b:
        shared = tokens_a & tokens_b
        combined = tokens_a | tokens_b
        return len(shared) / len(combined)
    return 0.0
def text_layer_match(a: str, b: str, *, jaccard_min: float) -> bool:
    """Decide whether two texts match on the text layer.

    True if the word-level Jaccard similarity reaches *jaccard_min*, or if,
    after normalization, both texts are at least 12 characters long and one
    contains the other.

    NOTE(review): tokens are split on whitespace only, so CJK text without
    spaces yields a few long tokens and the Jaccard check is coarse there.
    """
    if _word_jaccard(a, b) >= jaccard_min:
        return True
    left = _normalize_text(a)
    right = _normalize_text(b)
    long_enough = min(len(left), len(right)) >= 12
    return long_enough and (left in right or right in left)
def metadata_layer_match(
    c: MemoryChunk,
    n: dict[str, Any],
    *,
    event_year_window: int,
) -> bool:
    """Return True when chunk *c* and neighbour *n* agree on metadata.

    Matches when both come from the same known source, or when both carry an
    ``event_year`` and those years differ by at most *event_year_window*.

    Args:
        c: The chunk being compacted.
        n: Neighbour row from the nearest-chunk search; the ``source_id`` and
            ``event_year`` keys are optional.
        event_year_window: Maximum allowed absolute difference in years.
    """
    source_id = c.source_id
    # A missing id on both sides (None == None) is not evidence of a shared
    # source, so only a real id may satisfy this layer.
    if source_id is not None and source_id == n.get("source_id"):
        return True
    cy, ny = c.event_year, n.get("event_year")
    if cy is None or ny is None:
        return False
    return abs(int(cy) - int(ny)) <= event_year_window
def embedding_layer_match(distance: float, *, similarity_threshold: float) -> bool:
    """True when the similarity ``1 - distance`` reaches *similarity_threshold*."""
    return similarity_threshold <= 1.0 - distance
def canonical_score(
    *,
    content: str,
    metadata_json: Any,
    source_type: str | None,
) -> float:
    """Heuristic score used to decide which of two duplicates to keep.

    The score is the content length, plus 8.0 for ``draft``/``story`` sources,
    plus 0.15 per metadata key when *metadata_json* is a dict (counting at
    most 24 keys).
    """
    length_part = float(len(content or ""))
    source_bonus = 8.0 if source_type in ("draft", "story") else 0.0
    if isinstance(metadata_json, dict):
        metadata_bonus = min(len(metadata_json), 24) * 0.15
    else:
        metadata_bonus = 0.0
    return length_part + (source_bonus + metadata_bonus)
async def _source_type_for_chunk(db: AsyncSession, chunk: MemoryChunk) -> str | None:
    """Return the ``source_type`` of the MemorySource owning *chunk*.

    Returns ``None`` when ``db.get`` finds no source row for ``chunk.source_id``.
    """
    source = await db.get(MemorySource, chunk.source_id)
    if not source:
        return None
    return source.source_type
def _build_curation_details(
*,
ctx: dict[str, Any],
winner_id: str,
layers: int,
distance: float,
) -> dict[str, Any]:
details: dict[str, Any] = {
"reason": "near_duplicate",
"similar_to_chunk_id": winner_id,
"layers": layers,
"cosine_distance": distance,
}
for key in (
"trigger_source",
"trigger_time",
"pipeline_run_id",
"memoir_correlation_id",
"request_id",
"story_ids",
"story_dispatch_ids",
"chapter_ids",
"recomposed_chapter_ids",
"chapters_to_enqueue",
"candidate_chunk_ids",
"candidate_source_ids",
):
value = ctx.get(key)
if value is not None:
details[key] = value
return details
def count_duplicate_layers(
    *,
    chunk: MemoryChunk,
    neighbor: dict[str, Any],
    distance: float,
    similarity_threshold: float,
    jaccard_min: float,
    event_year_window: int,
) -> int:
    """Count how many duplicate-detection layers match for a chunk/neighbour pair.

    The layers are embedding similarity, text overlap and metadata agreement.
    All three are always evaluated, so the result is in ``0..3``.
    """
    layer_hits = [
        embedding_layer_match(distance, similarity_threshold=similarity_threshold),
        text_layer_match(
            chunk.content or "",
            neighbor.get("content") or "",
            jaccard_min=jaccard_min,
        ),
        metadata_layer_match(chunk, neighbor, event_year_window=event_year_window),
    ]
    return sum(1 for hit in layer_hits if hit)
async def _advance_cursor_past_excluded_only(
    db: AsyncSession,
    user_id: str,
    cursor_ts: datetime,
    cursor_id: str,
    *,
    max_steps: int,
) -> tuple[datetime, str] | None:
    """Push the cursor past a run of excluded chunks that directly follows it.

    Meant for the case where the pending incremental set is empty although
    chunks still exist after the cursor, typically because they are all
    excluded. The cursor is moved onto the last consecutive excluded chunk so
    later runs can reach the non-excluded chunks beyond it.

    Returns:
        The new ``(created_at, chunk_id)`` cursor (timezone-aware timestamp),
        or ``None`` when nothing was skipped. That happens when no chunk follows
        the cursor, or when the first following chunk is not excluded. The
        latter is inconsistent with an empty incremental set and is logged
        here as an anomaly. If *max_steps* lookups pass without resolving, a
        warning is logged and the position reached so far is returned.
    """
    # Stays None until at least one excluded chunk has been stepped over.
    reached: tuple[datetime, str] | None = None
    probe_ts, probe_id = cursor_ts, cursor_id
    for _step in range(max_steps):
        following = await get_first_chunk_after_cursor(
            db,
            user_id=user_id,
            after_cursor_ts=probe_ts,
            after_chunk_id=probe_id,
        )
        if following is None:
            # Nothing left after the current position.
            return reached
        if not following.is_excluded:
            if reached is None:
                # The very first chunk after the cursor is not excluded, yet
                # this helper was asked to skip excluded chunks.
                logger.warning(
                    "memory_compaction_cursor_anomaly user_id={} cursor_ts={} cursor_id={} "
                    "next_chunk_id={} next_is_excluded=false",
                    user_id,
                    cursor_ts,
                    cursor_id,
                    following.id,
                )
            return reached
        stamp = following.created_at or datetime.min.replace(tzinfo=timezone.utc)
        if stamp.tzinfo is None:
            stamp = stamp.replace(tzinfo=timezone.utc)
        probe_ts, probe_id = stamp, following.id
        reached = (probe_ts, probe_id)
    logger.warning(
        "memory_compaction_cursor_skip_cap user_id={} steps={}",
        user_id,
        max_steps,
    )
    return reached
def _utc_or_min(ts: datetime | None) -> datetime:
    """Return *ts* as a timezone-aware datetime usable for ordering and cursors.

    ``None`` becomes ``datetime.min`` tagged UTC, and a naive value is tagged
    UTC (the convention used elsewhere in this module). Aware values pass
    through unchanged.
    """
    if ts is None:
        return datetime.min.replace(tzinfo=timezone.utc)
    if ts.tzinfo is None:
        return ts.replace(tzinfo=timezone.utc)
    return ts


def _elapsed_ms(t0: float) -> float:
    """Milliseconds elapsed since the ``time.perf_counter()`` reading *t0*."""
    return (time.perf_counter() - t0) * 1000


def _cursor_fields(chunk: MemoryChunk | None) -> tuple[str | None, str | None]:
    """Return the ``(new_cursor_ts, new_cursor_id)`` pair for a result dict.

    The timestamp is normalized with ``_utc_or_min``, so every return path
    reports the cursor in the same ISO-8601, timezone-aware form.
    """
    if chunk is None:
        return None, None
    return _utc_or_min(chunk.created_at).isoformat(), chunk.id


def _compaction_result(
    *,
    scanned: int,
    excluded: int,
    candidates: int,
    cursor: tuple[str | None, str | None],
    ms: float,
    skipped_reason: str | None,
    **extra: Any,
) -> dict[str, Any]:
    """Assemble the structured result returned by ``run_memory_compaction``."""
    new_cursor_ts, new_cursor_id = cursor
    result: dict[str, Any] = {
        "chunks_scanned": scanned,
        "chunks_excluded": excluded,
        "candidates_considered": candidates,
        "new_cursor_ts": new_cursor_ts,
        "new_cursor_id": new_cursor_id,
        "duration_ms": round(ms, 1),
        "skipped_reason": skipped_reason,
    }
    result.update(extra)
    return result


def _pick_loser_winner(
    *,
    row_id: str,
    row_score: float,
    row_created: datetime | None,
    nb_id: str,
    nb_score: float,
    nb_created: datetime | None,
) -> tuple[str, str]:
    """Return ``(loser_id, winner_id)`` for a duplicate pair.

    The higher ``canonical_score`` wins. On a tie, the earlier-written chunk is
    kept, and the current chunk (*row*) wins if the timestamps are also equal.
    """
    if row_score > nb_score:
        return nb_id, row_id
    if nb_score > row_score:
        return row_id, nb_id
    if _utc_or_min(row_created) <= _utc_or_min(nb_created):
        return nb_id, row_id
    return row_id, nb_id


async def _finish_empty_incremental(
    db: AsyncSession,
    user_id: str,
    cursor_ts: datetime,
    cursor_id: str,
    *,
    has_candidate_filter: bool,
    max_chunks: int,
    trigger_source: Any,
    t0: float,
) -> dict[str, Any]:
    """Build the result for a run that found no pending incremental chunk.

    Without candidate narrowing, the cursor is first pushed past any run of
    excluded chunks so later runs are not stuck behind them.
    """
    # With candidate id/source narrowing, an empty result may only mean the
    # intersection is empty, so the global cursor must not be moved blindly.
    if not has_candidate_filter:
        tail = await _advance_cursor_past_excluded_only(
            db,
            user_id,
            cursor_ts,
            cursor_id,
            max_steps=max(max_chunks * 2, 500),
        )
        if tail is not None:
            new_ts, new_id = tail
            ms = _elapsed_ms(t0)
            logger.info(
                "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 "
                "duration_ms={:.1f} skipped_reason=empty_incremental_cursor_advanced "
                "trigger_source={}",
                user_id,
                ms,
                trigger_source,
            )
            return _compaction_result(
                scanned=0,
                excluded=0,
                candidates=0,
                cursor=(new_ts.isoformat(), new_id),
                ms=ms,
                skipped_reason="empty_incremental_cursor_advanced",
            )
    ms = _elapsed_ms(t0)
    logger.info(
        "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 "
        "duration_ms={:.1f} skipped_reason=empty_incremental trigger_source={}",
        user_id,
        ms,
        trigger_source,
    )
    return _compaction_result(
        scanned=0,
        excluded=0,
        candidates=0,
        cursor=(None, None),
        ms=ms,
        skipped_reason="empty_incremental",
    )


async def run_memory_compaction(
    db: AsyncSession, user_id: str, context: dict[str, Any] | None
) -> dict[str, Any]:
    """
    Soft-exclude near-duplicates among the user's incremental chunks.

    Chunks after the incremental cursor are visited oldest first. For each
    chunk with an embedding, its nearest neighbours are checked on three
    layers (embedding, text, metadata). Pairs matching at least
    ``compaction_min_layers_for_exclude`` layers are resolved by
    ``canonical_score``. The loser is marked excluded, its facts are marked
    stale, and a curation action is recorded. The caller is responsible for
    committing.

    Args:
        db: Async session used for every read and write.
        user_id: Owner of the chunks being compacted.
        context: Optional run context. Keys read here include
            ``_cursor_pair_override`` (an ``(ts, id)`` pair used instead of
            the stored cursor), ``candidate_chunk_ids`` /
            ``candidate_source_ids`` (narrow the incremental query) and
            ``trigger_source`` (logged). Tracing keys are also copied into the
            curation-action details.

    Returns:
        Structured result for logging and the task return value:
        ``chunks_scanned``, ``chunks_excluded``, ``candidates_considered``,
        ``new_cursor_ts`` (ISO-8601) / ``new_cursor_id`` (both ``None`` when no
        new cursor position is reported), ``duration_ms`` and
        ``skipped_reason``. ``pending_chunk_id`` is added when the run
        stopped at a chunk that has no embedding yet.
    """
    ctx = dict(context or {})
    t0 = time.perf_counter()
    trigger_source = ctx.get("trigger_source", "")

    pair = ctx.get("_cursor_pair_override")
    if pair is None:
        from app.core.memory_compaction_schedule import get_incremental_cursor_pair

        cursor_ts, cursor_id = get_incremental_cursor_pair(user_id)
    else:
        cursor_ts, cursor_id = pair  # type: ignore[misc]

    cand_chunks = ctx.get("candidate_chunk_ids")
    if cand_chunks is not None:
        cand_chunks = [str(x) for x in cand_chunks]
    cand_sources = ctx.get("candidate_source_ids")
    if cand_sources is not None:
        cand_sources = [str(x) for x in cand_sources]
    # Only a filter that is actually passed to the query (a list, possibly
    # empty) counts; a key present with value None narrows nothing and must
    # not block cursor advancement.
    has_candidate_filter = cand_chunks is not None or cand_sources is not None

    max_chunks = memory.compaction_max_chunks_per_run
    incremental = await list_incremental_chunks_for_compaction(
        db,
        user_id=user_id,
        after_cursor_ts=cursor_ts,
        after_chunk_id=cursor_id,
        limit=max_chunks,
        candidate_chunk_ids=cand_chunks,
        candidate_source_ids=cand_sources,
    )
    if not incremental:
        return await _finish_empty_incremental(
            db,
            user_id,
            cursor_ts,
            cursor_id,
            has_candidate_filter=has_candidate_filter,
            max_chunks=max_chunks,
            trigger_source=trigger_source,
            t0=t0,
        )

    sim_th = memory.compaction_chunk_similarity_threshold
    min_layers = memory.compaction_min_layers_for_exclude
    jaccard_min = memory.compaction_text_jaccard_min
    year_w = memory.compaction_metadata_event_year_window
    max_neighbors = memory.compaction_max_neighbors_per_chunk
    max_excludes = memory.compaction_max_excludes_per_run

    local_excluded: set[str] = set()
    excludes_done = 0
    candidates_considered = 0
    chunks_scanned_this_run = 0
    # Last chunk whose processing is complete; the new cursor points at it.
    last_cursor_chunk: MemoryChunk | None = None

    # Stable order: earlier-written chunks first. Timestamps are normalized so
    # None / naive / aware values never get compared against each other.
    incremental_sorted = sorted(
        incremental,
        key=lambda c: (_utc_or_min(c.created_at), c.id),
    )
    for chunk in incremental_sorted:
        if excludes_done >= max_excludes:
            break
        if chunk.id in local_excluded:
            last_cursor_chunk = chunk
            continue
        row = await get_memory_chunk_for_user(db, chunk.id, user_id)
        if row is None or row.is_excluded:
            last_cursor_chunk = chunk
            continue
        emb = row.embedding
        if emb is None:
            # Stop before this chunk so it is revisited once it has an embedding.
            ms = _elapsed_ms(t0)
            logger.info(
                "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
                "candidates={} duration_ms={:.1f} skipped_reason=awaiting_embeddings "
                "pending_chunk_id={} trigger_source={}",
                user_id,
                chunks_scanned_this_run,
                excludes_done,
                candidates_considered,
                ms,
                row.id,
                trigger_source,
            )
            return _compaction_result(
                scanned=chunks_scanned_this_run,
                excluded=excludes_done,
                candidates=candidates_considered,
                cursor=_cursor_fields(last_cursor_chunk),
                ms=ms,
                skipped_reason="awaiting_embeddings",
                pending_chunk_id=row.id,
            )
        chunks_scanned_this_run += 1
        st_c = await _source_type_for_chunk(db, row)
        # The row's own score is the same for every neighbour, so compute it once.
        sc = canonical_score(
            content=row.content or "",
            metadata_json=row.metadata_json,
            source_type=st_c,
        )
        neighbors = await search_nearest_chunks_for_compaction(
            db,
            user_id=user_id,
            chunk_id=row.id,
            query_embedding=list(emb),
            limit=max_neighbors,
        )
        scan_interrupted = False
        for nb in neighbors:
            if excludes_done >= max_excludes:
                scan_interrupted = True
                break
            nid = nb["id"]
            if nid == row.id or nid in local_excluded:
                continue
            other = await get_memory_chunk_for_user(db, nid, user_id)
            if other is None or other.is_excluded:
                continue
            dist = float(nb["distance"])
            layers = count_duplicate_layers(
                chunk=row,
                neighbor=nb,
                distance=dist,
                similarity_threshold=sim_th,
                jaccard_min=jaccard_min,
                event_year_window=year_w,
            )
            if layers < min_layers:
                continue
            candidates_considered += 1
            sn = canonical_score(
                content=nb.get("content") or "",
                metadata_json=nb.get("metadata_json"),
                source_type=nb.get("source_type"),
            )
            loser_id, winner_id = _pick_loser_winner(
                row_id=row.id,
                row_score=sc,
                row_created=row.created_at,
                nb_id=nid,
                nb_score=sn,
                nb_created=nb.get("created_at"),
            )
            # The loser is either `row` or `other`. Both were just verified to
            # exist and not be excluded, so no re-fetch is needed here.
            ok = await set_chunk_excluded(db, loser_id, user_id, True)
            if not ok:
                continue
            stale_n = await mark_facts_stale_for_excluded_chunk(
                db, user_id=user_id, chunk_id=loser_id
            )
            if stale_n:
                logger.info(
                    "memory_compaction_facts_staled user_id={} chunk_id={} count={}",
                    user_id,
                    loser_id,
                    stale_n,
                )
            await create_curation_action(
                db,
                user_id=user_id,
                action_type="exclude",
                target_type="chunk",
                target_id=loser_id,
                details=_build_curation_details(
                    ctx=ctx,
                    winner_id=winner_id,
                    layers=layers,
                    distance=dist,
                ),
            )
            excludes_done += 1
            local_excluded.add(loser_id)
            logger.info(
                "memory_compaction_exclude user_id={} excluded_chunk_id={} "
                "kept_chunk_id={} layers={} distance={:.4f}",
                user_id,
                loser_id,
                winner_id,
                layers,
                dist,
            )
            if loser_id == row.id:
                # The current chunk lost; its remaining neighbours no longer matter.
                break
        if scan_interrupted:
            # The per-run exclude cap was hit before every neighbour of this
            # chunk was checked. Keep the cursor before it so the next run
            # rescans it instead of silently skipping those neighbours.
            break
        last_cursor_chunk = chunk

    if last_cursor_chunk is None:
        # Reached when no chunk completed, e.g. the exclude cap stopped the
        # scan of the very first chunk.
        ms = _elapsed_ms(t0)
        logger.warning(
            "memory_compaction_no_cursor_chunk user_id={} incremental_n={}",
            user_id,
            len(incremental_sorted),
        )
        return _compaction_result(
            scanned=chunks_scanned_this_run,
            excluded=excludes_done,
            candidates=candidates_considered,
            cursor=(None, None),
            ms=ms,
            skipped_reason="no_cursor_chunk",
        )

    ms = _elapsed_ms(t0)
    logger.info(
        "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
        "candidates={} duration_ms={:.1f} trigger_source={}",
        user_id,
        chunks_scanned_this_run,
        excludes_done,
        candidates_considered,
        ms,
        trigger_source,
    )
    return _compaction_result(
        scanned=chunks_scanned_this_run,
        excluded=excludes_done,
        candidates=candidates_considered,
        cursor=_cursor_fields(last_cursor_chunk),
        ms=ms,
        skipped_reason=None,
    )