Files
life-echo/api/app/features/memory/compaction_service.py

502 lines
16 KiB
Python
Raw Normal View History

"""
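Memory compaction: near-duplicate detection over incremental chunks, with soft exclusion (is_excluded) plus a MemoryCurationAction record.
Depends only on repo / settings; intended to be called from the async MemoryService.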
"""
from __future__ import annotations
import re
import time
from datetime import datetime, timezone
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.features.memory.models import MemoryChunk, MemorySource
from app.features.memory.constants import memory
from app.features.memory.repo import (
create_curation_action,
get_first_chunk_after_cursor,
get_memory_chunk_for_user,
list_incremental_chunks_for_compaction,
mark_facts_stale_for_excluded_chunk,
search_nearest_chunks_for_compaction,
set_chunk_excluded,
)
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
# Matches any run of whitespace; used by _normalize_text to collapse it to one space.
_WS_RE = re.compile(r"\s+")
# Matches any single character that is not a word character, not whitespace and
# not in the U+4E00-U+9FFF range; _normalize_text replaces each match with a space.
_PUNCT_RE = re.compile(r"[^\w\s\u4e00-\u9fff]", re.UNICODE)
def _normalize_text(s: str) -> str:
    """Return *s* lowercased, with punctuation turned into spaces and whitespace collapsed.

    A falsy input (``None`` or ``""``) is treated as the empty string.
    """
    lowered = (s or "").strip().lower()
    without_punct = _PUNCT_RE.sub(" ", lowered)
    collapsed = _WS_RE.sub(" ", without_punct)
    return collapsed.strip()
def _word_jaccard(a: str, b: str) -> float:
    """Return the Jaccard similarity of the whitespace-separated token sets of *a* and *b*.

    Both strings are normalized first. If either token set is empty the result is 0.0.
    """
    tokens_a = set(_normalize_text(a).split())
    tokens_b = set(_normalize_text(b).split())
    if not (tokens_a and tokens_b):
        return 0.0
    shared = tokens_a & tokens_b
    combined = tokens_a | tokens_b
    return len(shared) / len(combined)
def text_layer_match(a: str, b: str, *, jaccard_min: float) -> bool:
    """Return True when *a* and *b* look like the same text.

    Two independent criteria are tried in order:

    1. the word-level Jaccard similarity is at least ``jaccard_min``;
    2. both normalized strings have at least 12 characters and one of them
       contains the other.
    """
    if _word_jaccard(a, b) >= jaccard_min:
        return True

    norm_a = _normalize_text(a)
    norm_b = _normalize_text(b)
    # Containment is only trusted for reasonably long strings.
    if len(norm_a) < 12 or len(norm_b) < 12:
        return False
    return norm_a in norm_b or norm_b in norm_a
def metadata_layer_match(
    c: MemoryChunk,
    n: dict[str, Any],
    *,
    event_year_window: int,
) -> bool:
    """Return True when chunk *c* and neighbor row *n* agree on metadata.

    A match is either of:

    * both refer to the same source (``c.source_id`` equals ``n["source_id"]``);
    * both carry an event year and the years differ by at most
      ``event_year_window``.

    Args:
        c: The chunk being examined; ``source_id`` and ``event_year`` are read.
        n: Neighbor row as a mapping; the ``"source_id"`` and ``"event_year"``
            keys are read when present.
        event_year_window: Maximum absolute difference in years, inclusive.

    Returns:
        ``True`` if either criterion holds, otherwise ``False``.
    """
    source_id = c.source_id
    # A missing source id on both sides is not evidence of a shared source:
    # without this guard ``None == None`` would count as a metadata match.
    if source_id is not None and source_id == n.get("source_id"):
        return True

    chunk_year, neighbor_year = c.event_year, n.get("event_year")
    if chunk_year is None or neighbor_year is None:
        return False
    return abs(int(chunk_year) - int(neighbor_year)) <= event_year_window
def embedding_layer_match(distance: float, *, similarity_threshold: float) -> bool:
    """Return True when the similarity ``1.0 - distance`` reaches ``similarity_threshold``.

    The comparison is inclusive: a similarity exactly equal to the threshold matches.
    """
    return (1.0 - distance) >= similarity_threshold
def canonical_score(
    *,
    content: str,
    metadata_json: Any,
    source_type: str | None,
) -> float:
    """Score how suitable a chunk is to be kept as the canonical copy.

    The score is the content length in characters, plus:

    * 8.0 when ``source_type`` is ``"draft"`` or ``"story"``;
    * 0.15 per metadata key (counting at most 24 keys) when ``metadata_json``
      is a dict.

    A falsy ``content`` counts as length zero.
    """
    length_score = float(len(content or ""))
    source_bonus = 8.0 if source_type in ("draft", "story") else 0.0
    if isinstance(metadata_json, dict):
        metadata_bonus = min(len(metadata_json), 24) * 0.15
    else:
        metadata_bonus = 0.0
    return length_score + (source_bonus + metadata_bonus)
async def _source_type_for_chunk(db: AsyncSession, chunk: MemoryChunk) -> str | None:
    """Look up the MemorySource for ``chunk.source_id`` and return its ``source_type``.

    Returns ``None`` when the lookup yields nothing.
    """
    source = await db.get(MemorySource, chunk.source_id)
    if not source:
        return None
    return source.source_type
def _build_curation_details(
*,
ctx: dict[str, Any],
winner_id: str,
layers: int,
distance: float,
) -> dict[str, Any]:
details: dict[str, Any] = {
"reason": "near_duplicate",
"similar_to_chunk_id": winner_id,
"layers": layers,
"cosine_distance": distance,
}
for key in (
"trigger_source",
"trigger_time",
"pipeline_run_id",
"memoir_correlation_id",
"request_id",
"story_ids",
"story_dispatch_ids",
"chapter_ids",
"recomposed_chapter_ids",
"chapters_to_enqueue",
"candidate_chunk_ids",
"candidate_source_ids",
):
value = ctx.get(key)
if value is not None:
details[key] = value
return details
def count_duplicate_layers(
    *,
    chunk: MemoryChunk,
    neighbor: dict[str, Any],
    distance: float,
    similarity_threshold: float,
    jaccard_min: float,
    event_year_window: int,
) -> int:
    """Count how many of the three duplicate signals fire for *chunk* vs *neighbor*.

    The signals are the embedding layer, the text layer and the metadata
    layer; each contributes one to the result, so the return value is 0..3.
    Missing content on either side is treated as the empty string.
    """
    embedding_hit = embedding_layer_match(
        distance, similarity_threshold=similarity_threshold
    )
    text_hit = text_layer_match(
        chunk.content or "",
        neighbor.get("content") or "",
        jaccard_min=jaccard_min,
    )
    metadata_hit = metadata_layer_match(
        chunk, neighbor, event_year_window=event_year_window
    )
    return sum(1 for hit in (embedding_hit, text_hit, metadata_hit) if hit)
async def _advance_cursor_past_excluded_only(
    db: AsyncSession,
    user_id: str,
    cursor_ts: datetime,
    cursor_id: str,
    *,
    max_steps: int,
) -> tuple[datetime, str] | None:
    """
    Walk the cursor forward over a run of consecutive excluded chunks.

    Used when the pending incremental set is empty although chunks still exist
    after the cursor, which usually means those chunks are all excluded. The
    cursor is moved to the last chunk of that excluded run so later runs can
    reach the non-excluded chunks behind it.

    Returns:
        The new ``(timestamp, chunk_id)`` cursor if at least one excluded
        chunk was stepped over, otherwise ``None``. In particular ``None`` is
        returned when nothing follows the cursor, and when the very first
        chunk after the cursor is not excluded (an anomaly that is logged
        here as a warning). At most ``max_steps`` chunks are stepped over;
        hitting that cap is logged as a warning as well.
    """
    current_ts, current_id = cursor_ts, cursor_id
    moved = False

    for _step in range(max_steps):
        following = await get_first_chunk_after_cursor(
            db,
            user_id=user_id,
            after_cursor_ts=current_ts,
            after_chunk_id=current_id,
        )
        if following is None:
            # Ran off the end of the user's chunks.
            return (current_ts, current_id) if moved else None

        if not bool(following.is_excluded):
            if not moved:
                # A non-excluded chunk directly after the cursor should have
                # shown up in the incremental set.
                logger.warning(
                    "memory_compaction_cursor_anomaly user_id={} cursor_ts={} cursor_id={} "
                    "next_chunk_id={} next_is_excluded=false",
                    user_id,
                    cursor_ts,
                    cursor_id,
                    following.id,
                )
                return None
            return (current_ts, current_id)

        # Step onto the excluded chunk; naive timestamps are taken as UTC.
        following_ts = following.created_at or datetime.min.replace(tzinfo=timezone.utc)
        if following_ts.tzinfo is None:
            following_ts = following_ts.replace(tzinfo=timezone.utc)
        current_ts, current_id = following_ts, following.id
        moved = True

    logger.warning(
        "memory_compaction_cursor_skip_cap user_id={} steps={}",
        user_id,
        max_steps,
    )
    return (current_ts, current_id) if moved else None
def _as_aware_utc(dt: datetime | None) -> datetime:
    """Return *dt* as a timezone-aware datetime, taking naive values as UTC.

    ``None`` maps to the timezone-aware minimum datetime.
    """
    if dt is None:
        return datetime.min.replace(tzinfo=timezone.utc)
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt


async def run_memory_compaction(
    db: AsyncSession, user_id: str, context: dict[str, Any] | None
) -> dict[str, Any]:
    """
    Soft-exclude near-duplicate chunks among the user's incremental chunks.

    Nothing is committed here; the caller is responsible for committing.

    Args:
        db: Session used for every repo call.
        user_id: Owner of the chunks to compact.
        context: Optional run context. Keys read here are
            ``_cursor_pair_override`` (a ``(timestamp, chunk_id)`` pair used
            instead of the scheduled cursor), ``candidate_chunk_ids``,
            ``candidate_source_ids`` and ``trigger_source``. The whole context
            is also handed to ``_build_curation_details``.

    Returns:
        A structured result for logging and task return values, with the keys
        ``chunks_scanned``, ``chunks_excluded``, ``candidates_considered``,
        ``new_cursor_ts`` (ISO string or ``None``), ``new_cursor_id``,
        ``duration_ms`` and ``skipped_reason``; the ``awaiting_embeddings``
        result additionally carries ``pending_chunk_id``.
    """
    ctx = dict(context or {})
    t0 = time.perf_counter()
    trigger_source = ctx.get("trigger_source", "")

    pair = ctx.get("_cursor_pair_override")
    if pair is None:
        # NOTE(review): imported at call time, presumably to avoid an import
        # cycle with the schedule module - confirm before hoisting.
        from app.core.memory_compaction_schedule import get_incremental_cursor_pair

        cursor_ts, cursor_id = get_incremental_cursor_pair(user_id)
    else:
        cursor_ts, cursor_id = pair  # type: ignore[misc]

    cand_chunks = ctx.get("candidate_chunk_ids")
    if cand_chunks is not None:
        cand_chunks = [str(x) for x in cand_chunks]
    cand_sources = ctx.get("candidate_source_ids")
    if cand_sources is not None:
        cand_sources = [str(x) for x in cand_sources]
    # A key that is present but None narrows nothing, so it must not count as
    # a filter (otherwise the cursor could never be advanced below).
    has_candidate_filter = cand_chunks is not None or cand_sources is not None

    max_chunks = memory.compaction_max_chunks_per_run
    incremental = await list_incremental_chunks_for_compaction(
        db,
        user_id=user_id,
        after_cursor_ts=cursor_ts,
        after_chunk_id=cursor_id,
        limit=max_chunks,
        candidate_chunk_ids=cand_chunks,
        candidate_source_ids=cand_sources,
    )

    if not incremental:
        ms = (time.perf_counter() - t0) * 1000
        # With a candidate id/source filter an empty set may only mean "empty
        # intersection", so the global cursor must not be advanced blindly.
        if not has_candidate_filter:
            tail = await _advance_cursor_past_excluded_only(
                db,
                user_id,
                cursor_ts,
                cursor_id,
                max_steps=max(max_chunks * 2, 500),
            )
            if tail is not None:
                new_ts, new_id = tail
                logger.info(
                    "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 "
                    "duration_ms={:.1f} skipped_reason=empty_incremental_cursor_advanced "
                    "trigger_source={}",
                    user_id,
                    ms,
                    trigger_source,
                )
                return {
                    "chunks_scanned": 0,
                    "chunks_excluded": 0,
                    "candidates_considered": 0,
                    "new_cursor_ts": new_ts.isoformat(),
                    "new_cursor_id": new_id,
                    "duration_ms": round(ms, 1),
                    "skipped_reason": "empty_incremental_cursor_advanced",
                }
        logger.info(
            "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 "
            "duration_ms={:.1f} skipped_reason=empty_incremental trigger_source={}",
            user_id,
            ms,
            trigger_source,
        )
        return {
            "chunks_scanned": 0,
            "chunks_excluded": 0,
            "candidates_considered": 0,
            "new_cursor_ts": None,
            "new_cursor_id": None,
            "duration_ms": round(ms, 1),
            "skipped_reason": "empty_incremental",
        }

    sim_th = memory.compaction_chunk_similarity_threshold
    min_layers = memory.compaction_min_layers_for_exclude
    jaccard_min = memory.compaction_text_jaccard_min
    year_w = memory.compaction_metadata_event_year_window
    max_neighbors = memory.compaction_max_neighbors_per_chunk
    max_excludes = memory.compaction_max_excludes_per_run

    local_excluded: set[str] = set()
    excludes_done = 0
    candidates_considered = 0

    # Stable order: process earlier-written chunks first. Timestamps are
    # normalized so naive values and the aware fallback can be compared.
    incremental_sorted = sorted(
        incremental,
        key=lambda c: (_as_aware_utc(c.created_at), c.id),
    )

    last_cursor_chunk: MemoryChunk | None = None
    chunks_scanned_this_run = 0

    for chunk in incremental_sorted:
        if excludes_done >= max_excludes:
            break
        if chunk.id in local_excluded:
            last_cursor_chunk = chunk
            continue
        row = await get_memory_chunk_for_user(db, chunk.id, user_id)
        if row is None or row.is_excluded:
            last_cursor_chunk = chunk
            continue

        emb = row.embedding
        if emb is None:
            # Stop here and report the cursor reached so far; this chunk is
            # reported as pending and the cursor is not moved past it.
            ms = (time.perf_counter() - t0) * 1000
            logger.info(
                "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
                "candidates={} duration_ms={:.1f} skipped_reason=awaiting_embeddings "
                "pending_chunk_id={} trigger_source={}",
                user_id,
                chunks_scanned_this_run,
                excludes_done,
                candidates_considered,
                ms,
                row.id,
                trigger_source,
            )
            # NOTE(review): when the last processed chunk has no created_at the
            # timestamp stays None while the id is still reported - confirm
            # how callers treat that pair.
            partial_cursor_ts = None
            if last_cursor_chunk is not None and last_cursor_chunk.created_at:
                # Same UTC normalization as the regular return path.
                partial_cursor_ts = _as_aware_utc(
                    last_cursor_chunk.created_at
                ).isoformat()
            return {
                "chunks_scanned": chunks_scanned_this_run,
                "chunks_excluded": excludes_done,
                "candidates_considered": candidates_considered,
                "new_cursor_ts": partial_cursor_ts,
                "new_cursor_id": last_cursor_chunk.id if last_cursor_chunk else None,
                "duration_ms": round(ms, 1),
                "skipped_reason": "awaiting_embeddings",
                "pending_chunk_id": row.id,
            }

        chunks_scanned_this_run += 1
        st_c = await _source_type_for_chunk(db, row)
        neighbors = await search_nearest_chunks_for_compaction(
            db,
            user_id=user_id,
            chunk_id=row.id,
            query_embedding=list(emb),
            limit=max_neighbors,
        )

        # Both values depend only on the current chunk, so compute them once.
        row_score = canonical_score(
            content=row.content or "",
            metadata_json=row.metadata_json,
            source_type=st_c,
        )
        t_row = _as_aware_utc(row.created_at)

        for nb in neighbors:
            if excludes_done >= max_excludes:
                break
            nid = nb["id"]
            if nid == row.id or nid in local_excluded:
                continue
            other = await get_memory_chunk_for_user(db, nid, user_id)
            if other is None or other.is_excluded:
                continue

            dist = float(nb["distance"])
            layers = count_duplicate_layers(
                chunk=row,
                neighbor=nb,
                distance=dist,
                similarity_threshold=sim_th,
                jaccard_min=jaccard_min,
                event_year_window=year_w,
            )
            if layers < min_layers:
                continue
            candidates_considered += 1

            nb_score = canonical_score(
                content=nb.get("content") or "",
                metadata_json=nb.get("metadata_json"),
                source_type=nb.get("source_type"),
            )
            t_nb = _as_aware_utc(nb.get("created_at"))

            # The higher canonical score is kept; on a tie the earlier-written
            # chunk is kept.
            if row_score > nb_score:
                row_wins = True
            elif nb_score > row_score:
                row_wins = False
            else:
                row_wins = t_row <= t_nb
            loser_id, winner_id = (nid, row.id) if row_wins else (row.id, nid)

            # Re-read the loser and skip it if it is gone or already excluded.
            loser = await get_memory_chunk_for_user(db, loser_id, user_id)
            if loser is None or loser.is_excluded:
                continue
            ok = await set_chunk_excluded(db, loser_id, user_id, True)
            if not ok:
                continue

            stale_n = await mark_facts_stale_for_excluded_chunk(
                db, user_id=user_id, chunk_id=loser_id
            )
            if stale_n:
                logger.info(
                    "memory_compaction_facts_staled user_id={} chunk_id={} count={}",
                    user_id,
                    loser_id,
                    stale_n,
                )
            await create_curation_action(
                db,
                user_id=user_id,
                action_type="exclude",
                target_type="chunk",
                target_id=loser_id,
                details=_build_curation_details(
                    ctx=ctx,
                    winner_id=winner_id,
                    layers=layers,
                    distance=dist,
                ),
            )
            excludes_done += 1
            local_excluded.add(loser_id)
            logger.info(
                "memory_compaction_exclude user_id={} excluded_chunk_id={} "
                "kept_chunk_id={} layers={} distance={:.4f}",
                user_id,
                loser_id,
                winner_id,
                layers,
                dist,
            )
            if loser_id == row.id:
                # The current chunk was just excluded; stop comparing it.
                break

        last_cursor_chunk = chunk

    if last_cursor_chunk is None:
        ms = (time.perf_counter() - t0) * 1000
        logger.warning(
            "memory_compaction_no_cursor_chunk user_id={} incremental_n={}",
            user_id,
            len(incremental_sorted),
        )
        return {
            "chunks_scanned": 0,
            "chunks_excluded": excludes_done,
            "candidates_considered": candidates_considered,
            "new_cursor_ts": None,
            "new_cursor_id": None,
            "duration_ms": round(ms, 1),
            "skipped_reason": "no_cursor_chunk",
        }

    new_cursor_ts = _as_aware_utc(last_cursor_chunk.created_at)
    new_cursor_id = last_cursor_chunk.id
    ms = (time.perf_counter() - t0) * 1000
    logger.info(
        "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
        "candidates={} duration_ms={:.1f} trigger_source={}",
        user_id,
        chunks_scanned_this_run,
        excludes_done,
        candidates_considered,
        ms,
        trigger_source,
    )
    return {
        "chunks_scanned": chunks_scanned_this_run,
        "chunks_excluded": excludes_done,
        "candidates_considered": candidates_considered,
        "new_cursor_ts": new_cursor_ts.isoformat(),
        "new_cursor_id": new_cursor_id,
        "duration_ms": round(ms, 1),
        "skipped_reason": None,
    }