""" Memory compaction:增量 chunk 近重复检测,软排除(is_excluded + MemoryCurationAction)。 仅依赖 repo / settings;供 async MemoryService 调用。 """ from __future__ import annotations import re import time from datetime import datetime, timezone from typing import Any from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.core.logging import get_logger from app.features.memory.models import MemoryChunk, MemorySource from app.features.memory.repo import ( create_curation_action, get_first_chunk_after_cursor, get_memory_chunk_for_user, list_incremental_chunks_for_compaction, mark_facts_stale_for_excluded_chunk, search_nearest_chunks_for_compaction, set_chunk_excluded, ) logger = get_logger(__name__) _WS_RE = re.compile(r"\s+") _PUNCT_RE = re.compile(r"[^\w\s\u4e00-\u9fff]", re.UNICODE) def _normalize_text(s: str) -> str: t = (s or "").strip().lower() t = _PUNCT_RE.sub(" ", t) return _WS_RE.sub(" ", t).strip() def _word_jaccard(a: str, b: str) -> float: wa = set(_normalize_text(a).split()) wb = set(_normalize_text(b).split()) if not wa or not wb: return 0.0 return len(wa & wb) / len(wa | wb) def text_layer_match(a: str, b: str, *, jaccard_min: float) -> bool: if _word_jaccard(a, b) >= jaccard_min: return True na, nb = _normalize_text(a), _normalize_text(b) if len(na) >= 12 and len(nb) >= 12 and (na in nb or nb in na): return True return False def metadata_layer_match( c: MemoryChunk, n: dict[str, Any], *, event_year_window: int, ) -> bool: if c.source_id == n.get("source_id"): return True cy, ny = c.event_year, n.get("event_year") if cy is not None and ny is not None: if abs(int(cy) - int(ny)) <= event_year_window: return True return False def embedding_layer_match(distance: float, *, similarity_threshold: float) -> bool: sim = 1.0 - distance return sim >= similarity_threshold def canonical_score( *, content: str, metadata_json: Any, source_type: str | None, ) -> float: score = float(len(content or "")) bonus = 0.0 if source_type in ("draft", "story"): bonus += 8.0 if isinstance(metadata_json, dict): bonus += min(len(metadata_json), 24) * 0.15 return score + bonus async def _source_type_for_chunk(db: AsyncSession, chunk: MemoryChunk) -> str | None: src = await db.get(MemorySource, chunk.source_id) return src.source_type if src else None def _build_curation_details( *, ctx: dict[str, Any], winner_id: str, layers: int, distance: float, ) -> dict[str, Any]: details: dict[str, Any] = { "reason": "near_duplicate", "similar_to_chunk_id": winner_id, "layers": layers, "cosine_distance": distance, } for key in ( "trigger_source", "trigger_time", "pipeline_run_id", "memoir_correlation_id", "request_id", "story_ids", "story_dispatch_ids", "chapter_ids", "recomposed_chapter_ids", "chapters_to_enqueue", "candidate_chunk_ids", "candidate_source_ids", ): value = ctx.get(key) if value is not None: details[key] = value return details def count_duplicate_layers( *, chunk: MemoryChunk, neighbor: dict[str, Any], distance: float, similarity_threshold: float, jaccard_min: float, event_year_window: int, ) -> int: n = 0 if embedding_layer_match(distance, similarity_threshold=similarity_threshold): n += 1 if text_layer_match( chunk.content or "", neighbor.get("content") or "", jaccard_min=jaccard_min ): n += 1 if metadata_layer_match(chunk, neighbor, event_year_window=event_year_window): n += 1 return n async def _advance_cursor_past_excluded_only( db: AsyncSession, user_id: str, cursor_ts: datetime, cursor_id: str, *, max_steps: int, ) -> tuple[datetime, str] | None: """ 当「待处理增量」为空但游标之后仍有 chunk 时,通常是因为连续 excluded。 将游标推进到这些 excluded 的最后一条,使后续运行能扫到未 excluded 的 chunk。 若游标之后没有任何 chunk,返回 None。 若游标之后第一条即非 excluded,返回 None(不应与空增量并存;由调用方打日志)。 """ advanced = False ts, cid = cursor_ts, cursor_id for _ in range(max_steps): nxt = await get_first_chunk_after_cursor( db, user_id=user_id, after_cursor_ts=ts, after_chunk_id=cid, ) if nxt is None: return (ts, cid) if advanced else None ex = bool(nxt.is_excluded) if not ex: if advanced: return (ts, cid) logger.warning( "memory_compaction_cursor_anomaly user_id={} cursor_ts={} cursor_id={} " "next_chunk_id={} next_is_excluded=false", user_id, cursor_ts, cursor_id, nxt.id, ) return None nxt_ts = nxt.created_at or datetime.min.replace(tzinfo=timezone.utc) if nxt_ts.tzinfo is None: nxt_ts = nxt_ts.replace(tzinfo=timezone.utc) ts, cid = nxt_ts, nxt.id advanced = True logger.warning( "memory_compaction_cursor_skip_cap user_id={} steps={}", user_id, max_steps, ) return (ts, cid) if advanced else None async def run_memory_compaction( db: AsyncSession, user_id: str, context: dict[str, Any] | None ) -> dict[str, Any]: """ 对增量 chunk 做近重复软排除;调用方负责 commit。 Returns: 结构化结果,供日志与任务返回值使用。 """ ctx = dict(context or {}) t0 = time.perf_counter() pair = ctx.get("_cursor_pair_override") if pair is None: from app.core.memory_compaction_schedule import get_incremental_cursor_pair cursor_ts, cursor_id = get_incremental_cursor_pair(user_id) else: cursor_ts, cursor_id = pair # type: ignore[misc] cand_chunks = ctx.get("candidate_chunk_ids") if cand_chunks is not None: cand_chunks = [str(x) for x in cand_chunks] cand_sources = ctx.get("candidate_source_ids") if cand_sources is not None: cand_sources = [str(x) for x in cand_sources] has_candidate_filter = "candidate_chunk_ids" in ctx or "candidate_source_ids" in ctx max_chunks = settings.memory_compaction_max_chunks_per_run incremental = await list_incremental_chunks_for_compaction( db, user_id=user_id, after_cursor_ts=cursor_ts, after_chunk_id=cursor_id, limit=max_chunks, candidate_chunk_ids=cand_chunks, candidate_source_ids=cand_sources, ) if not incremental: ms = (time.perf_counter() - t0) * 1000 # 候选 id/source 收窄时,空集可能仅表示「交集为空」,不能盲目前进全局游标 if not has_candidate_filter: tail = await _advance_cursor_past_excluded_only( db, user_id, cursor_ts, cursor_id, max_steps=max( settings.memory_compaction_max_chunks_per_run * 2, 500, ), ) if tail is not None: new_ts, new_id = tail logger.info( "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 " "duration_ms={:.1f} skipped_reason=empty_incremental_cursor_advanced " "trigger_source={}", user_id, ms, ctx.get("trigger_source", ""), ) return { "chunks_scanned": 0, "chunks_excluded": 0, "candidates_considered": 0, "new_cursor_ts": new_ts.isoformat(), "new_cursor_id": new_id, "duration_ms": round(ms, 1), "skipped_reason": "empty_incremental_cursor_advanced", } logger.info( "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 " "duration_ms={:.1f} skipped_reason=empty_incremental trigger_source={}", user_id, ms, ctx.get("trigger_source", ""), ) return { "chunks_scanned": 0, "chunks_excluded": 0, "candidates_considered": 0, "new_cursor_ts": None, "new_cursor_id": None, "duration_ms": round(ms, 1), "skipped_reason": "empty_incremental", } sim_th = settings.memory_compaction_chunk_similarity_threshold min_layers = settings.memory_compaction_min_layers_for_exclude jaccard_min = settings.memory_compaction_text_jaccard_min year_w = settings.memory_compaction_metadata_event_year_window max_neighbors = settings.memory_compaction_max_neighbors_per_chunk max_excludes = settings.memory_compaction_max_excludes_per_run local_excluded: set[str] = set() excludes_done = 0 candidates_considered = 0 # 稳定顺序:先处理较早写入的 chunk incremental_sorted = sorted( incremental, key=lambda c: (c.created_at or datetime.min.replace(tzinfo=timezone.utc), c.id), ) last_cursor_chunk: MemoryChunk | None = None chunks_scanned_this_run = 0 for chunk in incremental_sorted: if excludes_done >= max_excludes: break if chunk.id in local_excluded: last_cursor_chunk = chunk continue row = await get_memory_chunk_for_user(db, chunk.id, user_id) if row is None or row.is_excluded: last_cursor_chunk = chunk continue emb = row.embedding if emb is None: ms = (time.perf_counter() - t0) * 1000 logger.info( "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} " "candidates={} duration_ms={:.1f} skipped_reason=awaiting_embeddings " "pending_chunk_id={} trigger_source={}", user_id, chunks_scanned_this_run, excludes_done, candidates_considered, ms, row.id, ctx.get("trigger_source", ""), ) return { "chunks_scanned": chunks_scanned_this_run, "chunks_excluded": excludes_done, "candidates_considered": candidates_considered, "new_cursor_ts": ( last_cursor_chunk.created_at.isoformat() if last_cursor_chunk and last_cursor_chunk.created_at else None ), "new_cursor_id": last_cursor_chunk.id if last_cursor_chunk else None, "duration_ms": round(ms, 1), "skipped_reason": "awaiting_embeddings", "pending_chunk_id": row.id, } chunks_scanned_this_run += 1 st_c = await _source_type_for_chunk(db, row) neighbors = await search_nearest_chunks_for_compaction( db, user_id=user_id, chunk_id=row.id, query_embedding=list(emb), limit=max_neighbors, ) for nb in neighbors: if excludes_done >= max_excludes: break nid = nb["id"] if nid == row.id or nid in local_excluded: continue other = await get_memory_chunk_for_user(db, nid, user_id) if other is None or other.is_excluded: continue dist = float(nb["distance"]) layers = count_duplicate_layers( chunk=row, neighbor=nb, distance=dist, similarity_threshold=sim_th, jaccard_min=jaccard_min, event_year_window=year_w, ) if layers < min_layers: continue candidates_considered += 1 sc = canonical_score( content=row.content or "", metadata_json=row.metadata_json, source_type=st_c, ) sn = canonical_score( content=nb.get("content") or "", metadata_json=nb.get("metadata_json"), source_type=nb.get("source_type"), ) t_row = row.created_at or datetime.min.replace(tzinfo=timezone.utc) t_nb = nb.get("created_at") or datetime.min.replace(tzinfo=timezone.utc) if t_row.tzinfo is None: t_row = t_row.replace(tzinfo=timezone.utc) if t_nb.tzinfo is None: t_nb = t_nb.replace(tzinfo=timezone.utc) if sc > sn: loser_id, winner_id = nid, row.id elif sn > sc: loser_id, winner_id = row.id, nid else: # 同分:保留更早写入 if t_row <= t_nb: loser_id, winner_id = nid, row.id else: loser_id, winner_id = row.id, nid loser = await get_memory_chunk_for_user(db, loser_id, user_id) if loser is None or loser.is_excluded: continue ok = await set_chunk_excluded(db, loser_id, user_id, True) if not ok: continue stale_n = await mark_facts_stale_for_excluded_chunk( db, user_id=user_id, chunk_id=loser_id ) if stale_n: logger.info( "memory_compaction_facts_staled user_id={} chunk_id={} count={}", user_id, loser_id, stale_n, ) await create_curation_action( db, user_id=user_id, action_type="exclude", target_type="chunk", target_id=loser_id, details=_build_curation_details( ctx=ctx, winner_id=winner_id, layers=layers, distance=dist, ), ) excludes_done += 1 local_excluded.add(loser_id) logger.info( "memory_compaction_exclude user_id={} excluded_chunk_id={} " "kept_chunk_id={} layers={} distance={:.4f}", user_id, loser_id, winner_id, layers, dist, ) if loser_id == row.id: break last_cursor_chunk = chunk if last_cursor_chunk is None: ms = (time.perf_counter() - t0) * 1000 logger.warning( "memory_compaction_no_cursor_chunk user_id={} incremental_n={}", user_id, len(incremental_sorted), ) return { "chunks_scanned": 0, "chunks_excluded": excludes_done, "candidates_considered": candidates_considered, "new_cursor_ts": None, "new_cursor_id": None, "duration_ms": round(ms, 1), "skipped_reason": "no_cursor_chunk", } new_cursor_ts = last_cursor_chunk.created_at or datetime.min.replace( tzinfo=timezone.utc ) if new_cursor_ts.tzinfo is None: new_cursor_ts = new_cursor_ts.replace(tzinfo=timezone.utc) new_cursor_id = last_cursor_chunk.id ms = (time.perf_counter() - t0) * 1000 logger.info( "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} " "candidates={} duration_ms={:.1f} trigger_source={}", user_id, chunks_scanned_this_run, excludes_done, candidates_considered, ms, ctx.get("trigger_source", ""), ) return { "chunks_scanned": chunks_scanned_this_run, "chunks_excluded": excludes_done, "candidates_considered": candidates_considered, "new_cursor_ts": new_cursor_ts.isoformat(), "new_cursor_id": new_cursor_id, "duration_ms": round(ms, 1), "skipped_reason": None, }