配置 SSOT(TOML + .env) 统一错误契约 Auth 与事务边界 Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client 可观测性(OpenTelemetry + LGTM)
502 lines
16 KiB
Python
502 lines
16 KiB
Python
"""
|
||
Memory compaction:增量 chunk 近重复检测,软排除(is_excluded + MemoryCurationAction)。
|
||
|
||
仅依赖 repo / settings;供 async MemoryService 调用。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import re
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from typing import Any
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.config import settings
|
||
from app.core.logging import get_logger
|
||
from app.features.memory.models import MemoryChunk, MemorySource
|
||
from app.features.memory.constants import memory
|
||
from app.features.memory.repo import (
|
||
create_curation_action,
|
||
get_first_chunk_after_cursor,
|
||
get_memory_chunk_for_user,
|
||
list_incremental_chunks_for_compaction,
|
||
mark_facts_stale_for_excluded_chunk,
|
||
search_nearest_chunks_for_compaction,
|
||
set_chunk_excluded,
|
||
)
|
||
|
||
# Module-level logger; messages in this module use ``{}`` placeholders with positional args.
logger = get_logger(__name__)
|
||
|
||
# Collapses any run of whitespace into a single space.
_WS_RE = re.compile(r"\s+")

# Everything that is not a word character, whitespace or a CJK Unified Ideograph is
# treated as punctuation. (Unicode ``\w`` already matches CJK ideographs; the explicit
# range is kept for readability.)
_PUNCT_RE = re.compile(r"[^\w\s\u4e00-\u9fff]", re.UNICODE)

# Minimum normalized length before substring containment counts as a text match;
# very short strings would "contain" each other far too easily.
_MIN_CONTAINMENT_LEN = 12


def _normalize_text(s: str) -> str:
    """Lower-case *s*, turn punctuation into spaces and collapse whitespace.

    ``None`` or empty input yields ``""``.
    """
    t = (s or "").strip().lower()
    t = _PUNCT_RE.sub(" ", t)
    return _WS_RE.sub(" ", t).strip()


def _token_jaccard(na: str, nb: str) -> float:
    """Jaccard similarity of the whitespace tokens of two already-normalized strings.

    Returns 0.0 when either side has no tokens.
    """
    wa = set(na.split())
    wb = set(nb.split())
    if not wa or not wb:
        return 0.0
    return len(wa & wb) / len(wa | wb)


def _word_jaccard(a: str, b: str) -> float:
    """Word-level Jaccard similarity of *a* and *b* after ``_normalize_text``.

    NOTE: tokens are whitespace-delimited, so unsegmented CJK text produces one
    token per punctuation-separated clause rather than per word.
    """
    return _token_jaccard(_normalize_text(a), _normalize_text(b))


def text_layer_match(
    a: str,
    b: str,
    *,
    jaccard_min: float,
    min_containment_len: int = _MIN_CONTAINMENT_LEN,
) -> bool:
    """Text layer of the near-duplicate check.

    Both strings are normalized once. They match when their word Jaccard similarity
    is at least *jaccard_min*, or when both normalized strings are at least
    *min_containment_len* characters long and one is a substring of the other.
    """
    # Normalize once and reuse for both checks (this runs once per neighbour).
    na, nb = _normalize_text(a), _normalize_text(b)
    if _token_jaccard(na, nb) >= jaccard_min:
        return True
    return (
        len(na) >= min_containment_len
        and len(nb) >= min_containment_len
        and (na in nb or nb in na)
    )
|
||
|
||
|
||
def metadata_layer_match(
    c: MemoryChunk,
    n: dict[str, Any],
    *,
    event_year_window: int,
) -> bool:
    """Metadata layer of the near-duplicate check.

    Returns True when chunk *c* and neighbour row *n* come from the same (non-null)
    source, or when both carry an ``event_year`` and the years differ by at most
    *event_year_window*. A missing source id is not treated as "same source", and
    event years that cannot be converted with ``int`` count as unknown instead of
    aborting the whole compaction run.
    """
    source_id = c.source_id
    # Guard against None == None: a missing source id is not evidence of a shared source.
    if source_id is not None and source_id == n.get("source_id"):
        return True
    cy, ny = c.event_year, n.get("event_year")
    if cy is None or ny is None:
        return False
    try:
        return abs(int(cy) - int(ny)) <= event_year_window
    except (TypeError, ValueError):
        # Malformed year metadata: treat as unknown rather than crash the run.
        return False
|
||
|
||
|
||
def embedding_layer_match(distance: float, *, similarity_threshold: float) -> bool:
    """Embedding layer: True when ``1 - distance`` reaches *similarity_threshold*.

    *distance* is the neighbour-search distance that callers record as
    ``cosine_distance``, so ``1 - distance`` is the cosine similarity.
    """
    return (1.0 - distance) >= similarity_threshold
|
||
|
||
|
||
def canonical_score(
    *,
    content: str,
    metadata_json: Any,
    source_type: str | None,
) -> float:
    """Score a chunk as the surviving copy of a near-duplicate pair (higher is better).

    The base is the content length. Chunks whose source type is ``"draft"`` or
    ``"story"`` get +8.0, and dict metadata adds 0.15 per key, counting at most 24 keys.
    """
    length_score = float(len(content or ""))
    extra = 8.0 if source_type in ("draft", "story") else 0.0
    if isinstance(metadata_json, dict):
        extra += 0.15 * min(24, len(metadata_json))
    return length_score + extra
|
||
|
||
|
||
async def _source_type_for_chunk(db: AsyncSession, chunk: MemoryChunk) -> str | None:
    """Return the ``source_type`` of the MemorySource owning *chunk*.

    Loads the source by primary key ``chunk.source_id`` and returns None when no
    such row is found.
    """
    source = await db.get(MemorySource, chunk.source_id)
    if not source:
        return None
    return source.source_type
|
||
|
||
|
||
def _build_curation_details(
    *,
    ctx: dict[str, Any],
    winner_id: str,
    layers: int,
    distance: float,
) -> dict[str, Any]:
    """Build the ``details`` payload for a near-duplicate exclude curation action.

    The fixed fields describe the match. A whitelist of tracing/trigger keys is then
    copied from *ctx* in a fixed order, skipping keys that are absent or None.
    """
    passthrough_keys = (
        "trigger_source",
        "trigger_time",
        "pipeline_run_id",
        "memoir_correlation_id",
        "request_id",
        "story_ids",
        "story_dispatch_ids",
        "chapter_ids",
        "recomposed_chapter_ids",
        "chapters_to_enqueue",
        "candidate_chunk_ids",
        "candidate_source_ids",
    )
    details: dict[str, Any] = {
        "reason": "near_duplicate",
        "similar_to_chunk_id": winner_id,
        "layers": layers,
        "cosine_distance": distance,
    }
    looked_up = ((key, ctx.get(key)) for key in passthrough_keys)
    details.update((key, value) for key, value in looked_up if value is not None)
    return details
|
||
|
||
|
||
def count_duplicate_layers(
    *,
    chunk: MemoryChunk,
    neighbor: dict[str, Any],
    distance: float,
    similarity_threshold: float,
    jaccard_min: float,
    event_year_window: int,
) -> int:
    """Count how many duplicate layers (embedding, text, metadata) agree (0-3).

    All three layer checks are always evaluated, in that order.
    """
    layer_hits = (
        embedding_layer_match(distance, similarity_threshold=similarity_threshold),
        text_layer_match(
            chunk.content or "", neighbor.get("content") or "", jaccard_min=jaccard_min
        ),
        metadata_layer_match(chunk, neighbor, event_year_window=event_year_window),
    )
    return sum(1 for hit in layer_hits if hit)
|
||
|
||
|
||
async def _advance_cursor_past_excluded_only(
    db: AsyncSession,
    user_id: str,
    cursor_ts: datetime,
    cursor_id: str,
    *,
    max_steps: int,
) -> tuple[datetime, str] | None:
    """
    Advance the incremental cursor past a run of excluded chunks.

    Used when the pending incremental set is empty but chunks still exist after the
    cursor, which usually means the chunks right after it are all excluded. Walks
    forward one chunk at a time (at most ``max_steps`` lookups) so that later runs
    can reach the first non-excluded chunk.

    Returns:
        ``(created_at, chunk_id)`` of the last excluded chunk stepped over, or None
        when nothing was stepped over. That happens when no chunk exists after the
        cursor, or when the very first chunk after it is not excluded. The second
        case should not coexist with an empty incremental set, so this function
        logs a ``memory_compaction_cursor_anomaly`` warning for it. Exhausting
        ``max_steps`` logs ``memory_compaction_cursor_skip_cap`` and returns the
        position reached (or None if ``max_steps`` allowed no step at all).
    """
    advanced = False
    ts, cid = cursor_ts, cursor_id
    for _ in range(max_steps):
        # NOTE(review): assumes get_first_chunk_after_cursor orders by the same
        # (created_at, id) key as the incremental cursor -- confirm in repo.
        nxt = await get_first_chunk_after_cursor(
            db,
            user_id=user_id,
            after_cursor_ts=ts,
            after_chunk_id=cid,
        )
        if nxt is None:
            # No more chunks: keep whatever position was reached.
            return (ts, cid) if advanced else None
        ex = bool(nxt.is_excluded)
        if not ex:
            if advanced:
                # Stop right before the first non-excluded chunk.
                return (ts, cid)
            logger.warning(
                "memory_compaction_cursor_anomaly user_id={} cursor_ts={} cursor_id={} "
                "next_chunk_id={} next_is_excluded=false",
                user_id,
                cursor_ts,
                cursor_id,
                nxt.id,
            )
            return None
        # Step over this excluded chunk; a missing timestamp becomes datetime.min
        # and a naive one is tagged as UTC.
        nxt_ts = nxt.created_at or datetime.min.replace(tzinfo=timezone.utc)
        if nxt_ts.tzinfo is None:
            nxt_ts = nxt_ts.replace(tzinfo=timezone.utc)
        ts, cid = nxt_ts, nxt.id
        advanced = True
    # Step budget exhausted before reaching a non-excluded chunk or the end.
    logger.warning(
        "memory_compaction_cursor_skip_cap user_id={} steps={}",
        user_id,
        max_steps,
    )
    return (ts, cid) if advanced else None
|
||
|
||
|
||
def _ensure_aware(value: datetime | None) -> datetime:
    """Return a timezone-aware timestamp usable for ordering and cursor serialization.

    ``None`` maps to ``datetime.min`` in UTC so it sorts before any real timestamp.
    A naive value is treated as UTC and gets ``tzinfo=timezone.utc`` attached.
    An aware value is returned unchanged.
    """
    if value is None:
        return datetime.min.replace(tzinfo=timezone.utc)
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value


def _pick_loser_and_winner(
    *,
    row_id: str,
    neighbor_id: str,
    row_score: float,
    neighbor_score: float,
    row_ts: datetime,
    neighbor_ts: datetime,
) -> tuple[str, str]:
    """Decide which chunk of a near-duplicate pair to exclude.

    Returns ``(loser_id, winner_id)``. The higher ``canonical_score`` is kept. On a
    tie the earlier-written chunk is kept, and the current row is kept when the
    timestamps are also equal.
    """
    if row_score > neighbor_score:
        return neighbor_id, row_id
    if neighbor_score > row_score:
        return row_id, neighbor_id
    # Equal scores: keep the earlier write.
    if row_ts <= neighbor_ts:
        return neighbor_id, row_id
    return row_id, neighbor_id


async def _empty_incremental_result(
    db: AsyncSession,
    user_id: str,
    *,
    cursor_ts: datetime,
    cursor_id: str,
    ctx: dict[str, Any],
    has_candidate_filter: bool,
    max_chunks: int,
    t0: float,
) -> dict[str, Any]:
    """Build the run result when there are no incremental chunks to scan.

    Without a candidate filter, try to push the cursor past a run of excluded chunks
    via ``_advance_cursor_past_excluded_only``. With a candidate filter, an empty list
    may only mean "empty intersection with the filter", so the global cursor must not
    move.
    """
    tail: tuple[datetime, str] | None = None
    if not has_candidate_filter:
        tail = await _advance_cursor_past_excluded_only(
            db,
            user_id,
            cursor_ts,
            cursor_id,
            max_steps=max(max_chunks * 2, 500),
        )
    ms = (time.perf_counter() - t0) * 1000
    trigger_source = ctx.get("trigger_source", "")
    if tail is not None:
        new_ts, new_id = tail
        logger.info(
            "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 "
            "duration_ms={:.1f} skipped_reason=empty_incremental_cursor_advanced "
            "trigger_source={}",
            user_id,
            ms,
            trigger_source,
        )
        return {
            "chunks_scanned": 0,
            "chunks_excluded": 0,
            "candidates_considered": 0,
            "new_cursor_ts": new_ts.isoformat(),
            "new_cursor_id": new_id,
            "duration_ms": round(ms, 1),
            "skipped_reason": "empty_incremental_cursor_advanced",
        }
    logger.info(
        "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 "
        "duration_ms={:.1f} skipped_reason=empty_incremental trigger_source={}",
        user_id,
        ms,
        trigger_source,
    )
    return {
        "chunks_scanned": 0,
        "chunks_excluded": 0,
        "candidates_considered": 0,
        "new_cursor_ts": None,
        "new_cursor_id": None,
        "duration_ms": round(ms, 1),
        "skipped_reason": "empty_incremental",
    }


async def run_memory_compaction(
    db: AsyncSession, user_id: str, context: dict[str, Any] | None
) -> dict[str, Any]:
    """Soft-exclude near-duplicate incremental chunks for one user.

    Scans the user's chunks after the incremental cursor, oldest first, and compares
    each with its nearest neighbours. When at least
    ``compaction_min_layers_for_exclude`` duplicate layers agree (embedding / text /
    metadata), the lower-scoring chunk of the pair is excluded. Facts linked to it
    are then marked stale, and an ``exclude`` curation action is recorded. The
    caller is responsible for committing the session.

    Args:
        db: Session used for all reads and writes (not committed here).
        user_id: Owner whose chunks are compacted.
        context: Optional run context. Recognized keys:
            - ``_cursor_pair_override``: ``(cursor_ts, cursor_id)`` to use instead of
              the stored cursor.
            - ``candidate_chunk_ids`` / ``candidate_source_ids``: narrow the
              incremental scan; ``None`` means no filter.
            - ``trigger_source`` and the tracing keys copied into curation details.

    Returns:
        A structured result for logs and task return values, with these keys:
        ``chunks_scanned``, ``chunks_excluded``, ``candidates_considered``,
        ``new_cursor_ts`` (ISO-8601, or None when the cursor should not move),
        ``new_cursor_id``, ``duration_ms`` and ``skipped_reason``. For
        ``awaiting_embeddings`` it also includes ``pending_chunk_id``.
    """
    ctx = dict(context or {})
    t0 = time.perf_counter()
    trigger_source = ctx.get("trigger_source", "")

    pair = ctx.get("_cursor_pair_override")
    if pair is None:
        from app.core.memory_compaction_schedule import get_incremental_cursor_pair

        cursor_ts, cursor_id = get_incremental_cursor_pair(user_id)
    else:
        cursor_ts, cursor_id = pair  # type: ignore[misc]

    cand_chunks = ctx.get("candidate_chunk_ids")
    if cand_chunks is not None:
        cand_chunks = [str(x) for x in cand_chunks]
    cand_sources = ctx.get("candidate_source_ids")
    if cand_sources is not None:
        cand_sources = [str(x) for x in cand_sources]
    # Only a real (non-None) filter narrows the query. A key explicitly set to None
    # applies no filter, so it must not block cursor advancement on an empty scan.
    has_candidate_filter = cand_chunks is not None or cand_sources is not None

    max_chunks = memory.compaction_max_chunks_per_run
    incremental = await list_incremental_chunks_for_compaction(
        db,
        user_id=user_id,
        after_cursor_ts=cursor_ts,
        after_chunk_id=cursor_id,
        limit=max_chunks,
        candidate_chunk_ids=cand_chunks,
        candidate_source_ids=cand_sources,
    )
    if not incremental:
        return await _empty_incremental_result(
            db,
            user_id,
            cursor_ts=cursor_ts,
            cursor_id=cursor_id,
            ctx=ctx,
            has_candidate_filter=has_candidate_filter,
            max_chunks=max_chunks,
            t0=t0,
        )

    sim_th = memory.compaction_chunk_similarity_threshold
    min_layers = memory.compaction_min_layers_for_exclude
    jaccard_min = memory.compaction_text_jaccard_min
    year_w = memory.compaction_metadata_event_year_window
    max_neighbors = memory.compaction_max_neighbors_per_chunk
    max_excludes = memory.compaction_max_excludes_per_run

    local_excluded: set[str] = set()
    excludes_done = 0
    candidates_considered = 0
    chunks_scanned_this_run = 0
    # Last chunk that was completely handled; the new cursor points at it.
    last_cursor_chunk: MemoryChunk | None = None
    # Set when the exclude cap interrupts a chunk before all its neighbours were checked.
    stopped_mid_chunk = False

    # Stable order: older writes first. Timestamps are normalized so a mix of
    # naive / aware / None values cannot raise TypeError while sorting.
    incremental_sorted = sorted(
        incremental, key=lambda c: (_ensure_aware(c.created_at), c.id)
    )

    for chunk in incremental_sorted:
        if excludes_done >= max_excludes:
            break
        if chunk.id in local_excluded:
            last_cursor_chunk = chunk
            continue
        row = await get_memory_chunk_for_user(db, chunk.id, user_id)
        if row is None or row.is_excluded:
            last_cursor_chunk = chunk
            continue
        emb = row.embedding
        if emb is None:
            # Stop here so the cursor never moves past a chunk without an embedding yet.
            ms = (time.perf_counter() - t0) * 1000
            logger.info(
                "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
                "candidates={} duration_ms={:.1f} skipped_reason=awaiting_embeddings "
                "pending_chunk_id={} trigger_source={}",
                user_id,
                chunks_scanned_this_run,
                excludes_done,
                candidates_considered,
                ms,
                row.id,
                trigger_source,
            )
            return {
                "chunks_scanned": chunks_scanned_this_run,
                "chunks_excluded": excludes_done,
                "candidates_considered": candidates_considered,
                # Normalized like the regular path so every returned cursor has one shape.
                "new_cursor_ts": (
                    _ensure_aware(last_cursor_chunk.created_at).isoformat()
                    if last_cursor_chunk is not None
                    else None
                ),
                "new_cursor_id": (
                    last_cursor_chunk.id if last_cursor_chunk is not None else None
                ),
                "duration_ms": round(ms, 1),
                "skipped_reason": "awaiting_embeddings",
                "pending_chunk_id": row.id,
            }

        chunks_scanned_this_run += 1
        st_c = await _source_type_for_chunk(db, row)
        # Row-side inputs do not change across neighbours; compute them once.
        row_score = canonical_score(
            content=row.content or "",
            metadata_json=row.metadata_json,
            source_type=st_c,
        )
        row_ts = _ensure_aware(row.created_at)
        neighbors = await search_nearest_chunks_for_compaction(
            db,
            user_id=user_id,
            chunk_id=row.id,
            query_embedding=list(emb),
            limit=max_neighbors,
        )

        for nb in neighbors:
            if excludes_done >= max_excludes:
                # Neighbours remain unchecked: do not move the cursor past this chunk,
                # otherwise its remaining (older) duplicates would never be compared.
                stopped_mid_chunk = True
                break
            nid = nb["id"]
            if nid == row.id or nid in local_excluded:
                continue
            other = await get_memory_chunk_for_user(db, nid, user_id)
            if other is None or other.is_excluded:
                continue

            dist = float(nb["distance"])
            layers = count_duplicate_layers(
                chunk=row,
                neighbor=nb,
                distance=dist,
                similarity_threshold=sim_th,
                jaccard_min=jaccard_min,
                event_year_window=year_w,
            )
            if layers < min_layers:
                continue

            candidates_considered += 1
            neighbor_score = canonical_score(
                content=nb.get("content") or "",
                metadata_json=nb.get("metadata_json"),
                source_type=nb.get("source_type"),
            )
            loser_id, winner_id = _pick_loser_and_winner(
                row_id=row.id,
                neighbor_id=nid,
                row_score=row_score,
                neighbor_score=neighbor_score,
                row_ts=row_ts,
                neighbor_ts=_ensure_aware(nb.get("created_at")),
            )

            loser = await get_memory_chunk_for_user(db, loser_id, user_id)
            if loser is None or loser.is_excluded:
                continue

            ok = await set_chunk_excluded(db, loser_id, user_id, True)
            if not ok:
                continue
            stale_n = await mark_facts_stale_for_excluded_chunk(
                db, user_id=user_id, chunk_id=loser_id
            )
            if stale_n:
                logger.info(
                    "memory_compaction_facts_staled user_id={} chunk_id={} count={}",
                    user_id,
                    loser_id,
                    stale_n,
                )
            await create_curation_action(
                db,
                user_id=user_id,
                action_type="exclude",
                target_type="chunk",
                target_id=loser_id,
                details=_build_curation_details(
                    ctx=ctx,
                    winner_id=winner_id,
                    layers=layers,
                    distance=dist,
                ),
            )
            excludes_done += 1
            local_excluded.add(loser_id)
            logger.info(
                "memory_compaction_exclude user_id={} excluded_chunk_id={} "
                "kept_chunk_id={} layers={} distance={:.4f}",
                user_id,
                loser_id,
                winner_id,
                layers,
                dist,
            )
            if loser_id == row.id:
                # The current chunk itself was excluded; its other neighbours are moot.
                break

        if stopped_mid_chunk:
            break
        last_cursor_chunk = chunk

    if last_cursor_chunk is None:
        ms = (time.perf_counter() - t0) * 1000
        if stopped_mid_chunk:
            # Expected when the exclude cap is reached on the first chunk: keep the
            # cursor where it is so the next run finishes that chunk.
            skipped_reason = "max_excludes_reached"
            logger.info(
                "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
                "candidates={} duration_ms={:.1f} skipped_reason=max_excludes_reached "
                "trigger_source={}",
                user_id,
                chunks_scanned_this_run,
                excludes_done,
                candidates_considered,
                ms,
                trigger_source,
            )
        else:
            skipped_reason = "no_cursor_chunk"
            logger.warning(
                "memory_compaction_no_cursor_chunk user_id={} incremental_n={}",
                user_id,
                len(incremental_sorted),
            )
        return {
            "chunks_scanned": chunks_scanned_this_run,
            "chunks_excluded": excludes_done,
            "candidates_considered": candidates_considered,
            "new_cursor_ts": None,
            "new_cursor_id": None,
            "duration_ms": round(ms, 1),
            "skipped_reason": skipped_reason,
        }

    new_cursor_ts = _ensure_aware(last_cursor_chunk.created_at)
    new_cursor_id = last_cursor_chunk.id

    ms = (time.perf_counter() - t0) * 1000
    logger.info(
        "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
        "candidates={} duration_ms={:.1f} trigger_source={}",
        user_id,
        chunks_scanned_this_run,
        excludes_done,
        candidates_considered,
        ms,
        trigger_source,
    )
    return {
        "chunks_scanned": chunks_scanned_this_run,
        "chunks_excluded": excludes_done,
        "candidates_considered": candidates_considered,
        "new_cursor_ts": new_cursor_ts.isoformat(),
        "new_cursor_id": new_cursor_id,
        "duration_ms": round(ms, 1),
        "skipped_reason": None,
    }
|