feat(api): Memory compaction 管线与调度修复,同步环境变量示例

Memory compaction(近重复 chunk 软排除)
- 新增 compaction 调度:Redis debounce、scheduler gate、增量游标;任务结束时 finalize,避免 gate 长期占用并处理运行期新 trigger。
- Celery memory_compaction_run:debounce 未到点则 retry;用户级 Redis 锁;成功路径更新游标并 finalize;异常时释放 scheduler gate 并 self.retry,避免静默卡死调度与瞬时失败不重试。
- compaction_service:多层判定 + canonical 打分;无 embedding 时停止前移游标(awaiting_embeddings);curation details 补全 trigger 等上下文。
- ingest_transcript_sync:同步路径尽力写入 embedding,与异步 ingest 行为对齐,避免 compaction 永远扫不到无向量 chunk。
- repo:新增 update_chunk_embedding_sync。
测试
- 扩展 test_memory_compaction:调度合并、finalize、ingest embedding、无向量游标、异常路径 gate+retry 等回归用例。
This commit is contained in:
Kevin
2026-03-30 10:46:35 +08:00
parent 0d999cb769
commit e884409410
15 changed files with 1699 additions and 7 deletions

View File

@@ -156,6 +156,20 @@ class Settings(BaseSettings):
memory_enrichment_enabled: bool = True
memory_enrichment_max_chars: int = Field(default=12000, ge=1000, le=100_000)
# ── Memory compaction(近重复 chunk 软排除;事件触发 + Redis 防抖 + 用户锁)──
memory_compaction_enabled: bool = False
memory_compaction_debounce_seconds: int = Field(default=105, ge=10, le=3600)
memory_compaction_lock_ttl_seconds: int = Field(default=600, ge=60, le=7200)
memory_compaction_chunk_similarity_threshold: float = Field(
default=0.92, ge=0.5, le=0.999
)
memory_compaction_min_layers_for_exclude: int = Field(default=2, ge=1, le=3)
memory_compaction_max_chunks_per_run: int = Field(default=200, ge=1, le=10_000)
memory_compaction_max_excludes_per_run: int = Field(default=50, ge=1, le=1000)
memory_compaction_max_neighbors_per_chunk: int = Field(default=25, ge=5, le=100)
memory_compaction_text_jaccard_min: float = Field(default=0.55, ge=0.0, le=1.0)
memory_compaction_metadata_event_year_window: int = Field(default=1, ge=0, le=50)
# ── Liblib ───────────────────────────────────────────────
liblib_access_key: str = ""
liblib_secret_key: str = ""

View File

@@ -0,0 +1,230 @@
"""
Memory compaction 调度:Redis debounce_until 主路径 + 尽量少发 Celery 任务。
每次触发将「最早可执行时间」推后;仅首个触发在窗口内注册一次 delayed task。
任务执行时若尚未到 debounce 截止时间则 self.retry 延后重试。
"""
from __future__ import annotations
import json
import math
import threading
import time
from datetime import datetime, timezone
from typing import Any
import redis
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
_DEBOUNCE_KEY = "memory_compaction:debounce_until:{user_id}"
_SCHEDULER_KEY = "memory_compaction:scheduler_gate:{user_id}"
_CURSOR_KEY = "memory_compaction:chunk_cursor:{user_id}"
# 与 memory_chunks.id 字典序比较用(首跑起点)
_CHUNK_CURSOR_ID_ZERO = "00000000-0000-0000-0000-000000000000"
_redis_client: redis.Redis | None = None
_redis_lock = threading.Lock()
def _get_redis() -> redis.Redis:
    """Return the process-wide Redis client, creating it on first use.

    The client is built from ``settings.redis_url`` with
    ``decode_responses=True`` and cached in the module-level
    ``_redis_client``.  Creation happens under ``_redis_lock`` and the cache is
    re-checked inside the lock, so concurrent first callers share one client
    instead of each building their own.
    """
    global _redis_client
    cached = _redis_client
    if cached is not None:
        return cached
    with _redis_lock:
        # Another thread may have created the client while we waited.
        if _redis_client is None:
            _redis_client = redis.from_url(
                settings.redis_url, decode_responses=True
            )
        return _redis_client
def _debounce_key_ttl_seconds() -> int:
    """Return the TTL (seconds) used for the debounce and scheduler-gate keys.

    Both keys are written with this same TTL -- the configured debounce window
    plus a 900 second margin -- so the gate does not expire before the
    debounce key, which would let a duplicate ``apply_async`` through.
    """
    debounce_window = int(settings.memory_compaction_debounce_seconds)
    return debounce_window + 900
def debounce_key(user_id: str) -> str:
    """Build the Redis key holding the debounce deadline for ``user_id``."""
    key = _DEBOUNCE_KEY.format(user_id=user_id)
    return key
def scheduler_key(user_id: str) -> str:
    """Build the Redis key used as the scheduler gate for ``user_id``."""
    key = _SCHEDULER_KEY.format(user_id=user_id)
    return key
def chunk_cursor_key(user_id: str) -> str:
    """Build the Redis key storing the incremental chunk cursor for ``user_id``."""
    key = _CURSOR_KEY.format(user_id=user_id)
    return key
def read_debounce_deadline_ts(user_id: str) -> float | None:
    """Read the stored debounce deadline for ``user_id``.

    The compaction task uses this value to decide whether it must retry later
    (when the current time is still before the deadline).

    Returns:
        The stored value parsed as a float, or ``None`` when the key is absent
        or its value cannot be parsed as a float.
    """
    stored = _get_redis().get(debounce_key(user_id))
    if stored is None:
        return None
    try:
        deadline = float(stored)
    except ValueError:
        return None
    return deadline
def clear_debounce_deadline(user_id: str) -> None:
    """Delete the debounce-deadline key for ``user_id`` (best effort).

    Any error raised while talking to Redis is logged and swallowed so cleanup
    never fails the caller; the key is written with a TTL, so a missed delete
    eventually expires on its own.
    """
    try:
        _get_redis().delete(debounce_key(user_id))
    except Exception as exc:
        # Deliberately best-effort, but leave a trace instead of failing silently.
        logger.warning(
            "memory_compaction clear_debounce_deadline failed user_id={} err={}",
            user_id,
            exc,
        )
def release_scheduler_gate(user_id: str) -> None:
    """Delete the scheduler-gate key for ``user_id`` (best effort).

    Any error raised while talking to Redis is logged and swallowed so cleanup
    never fails the caller.  A gate that could not be deleted keeps blocking
    new task registrations until its TTL expires, which is why the failure is
    logged rather than dropped.
    """
    try:
        _get_redis().delete(scheduler_key(user_id))
    except Exception as exc:
        # Deliberately best-effort, but leave a trace instead of failing silently.
        logger.warning(
            "memory_compaction release_scheduler_gate failed user_id={} err={}",
            user_id,
            exc,
        )
def _enqueue_memory_compaction_task(
    user_id: str, context: dict[str, Any] | None, *, countdown: int
) -> None:
    """Dispatch ``memory_compaction_run`` through ``apply_async``.

    Args:
        user_id: Passed as the task's first positional argument.
        context: Passed as the second positional argument; a falsy value is
            replaced by an empty dict.
        countdown: Delay in seconds before execution; negative values are
            clamped to 0.
    """
    # Imported at call time; the tasks module itself imports from this module,
    # so a top-level import here would presumably be circular.
    from app.tasks.memory_compaction_tasks import memory_compaction_run

    delay_seconds = countdown if countdown > 0 else 0
    task_args = [user_id, context if context else {}]
    memory_compaction_run.apply_async(args=task_args, countdown=delay_seconds)
def _schedule_for_existing_deadline(
    user_id: str, context: dict[str, Any] | None, *, deadline_ts: float
) -> bool:
    """Store ``deadline_ts`` and enqueue one delayed task if the gate is free.

    The deadline is (re)written to the debounce key, then the scheduler gate is
    claimed with ``SET NX``.  Debounce key and gate share the same TTL.

    Returns:
        ``True`` when a task was enqueued; ``False`` when the gate was already
        held and nothing was enqueued.

    Raises:
        Exception: whatever the enqueue step raised; the gate is released
            before the exception is re-raised.
    """
    client = _get_redis()
    expiry = _debounce_key_ttl_seconds()
    client.set(debounce_key(user_id), str(deadline_ts), ex=expiry)

    gate_acquired = client.set(scheduler_key(user_id), "1", nx=True, ex=expiry)
    if not gate_acquired:
        return False

    # Round up so the task does not fire before the deadline.
    seconds_left = math.ceil(deadline_ts - time.time())
    countdown = max(0, seconds_left)
    try:
        _enqueue_memory_compaction_task(user_id, context, countdown=countdown)
        logger.info(
            "memory_compaction scheduled: user_id={} countdown_s={} deadline_ts={}",
            user_id,
            countdown,
            deadline_ts,
        )
    except Exception as exc:
        logger.warning(
            "memory_compaction schedule failed user_id={} err={}", user_id, exc
        )
        # Nothing was enqueued, so the gate must not stay held.
        release_scheduler_gate(user_id)
        raise
    return True
def get_incremental_cursor_pair(user_id: str) -> tuple[datetime, str]:
    """Load the incremental cursor ``(created_at, chunk_id)`` for ``user_id``.

    Chunks whose ``(created_at, id)`` pair sorts strictly after the cursor are
    the ones still to be processed.

    Two stored formats are accepted:
      * a JSON object with ``ts`` (or ``cursor_ts``) and ``id`` (or
        ``chunk_id``);
      * a legacy bare ISO timestamp, paired with the all-zero UUID.

    Returns:
        The cursor pair; naive timestamps are treated as UTC.  When the key is
        missing or the value cannot be parsed, the pair
        ``(1970-01-01T00:00:00+00:00, zero UUID)`` is returned.
    """
    fallback = (datetime(1970, 1, 1, tzinfo=timezone.utc), _CHUNK_CURSOR_ID_ZERO)
    stored = _get_redis().get(chunk_cursor_key(user_id))
    if not stored:
        return fallback
    try:
        if stored.strip().startswith("{"):
            payload = json.loads(stored)
            ts_text = payload.get("ts") or payload.get("cursor_ts")
            chunk_id = (
                payload.get("id")
                or payload.get("chunk_id")
                or _CHUNK_CURSOR_ID_ZERO
            )
            if not isinstance(ts_text, str):
                return fallback
        else:
            # Legacy format: the value is just the ISO timestamp.
            ts_text, chunk_id = stored, _CHUNK_CURSOR_ID_ZERO
        parsed = datetime.fromisoformat(ts_text)
    except (ValueError, json.JSONDecodeError, TypeError):
        return fallback
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed, str(chunk_id)
def set_incremental_cursor_pair(user_id: str, dt: datetime, chunk_id: str) -> None:
    """Persist the incremental cursor ``(dt, chunk_id)`` for ``user_id``.

    The cursor is stored as a JSON object ``{"ts": <ISO timestamp>, "id":
    <chunk id>}`` with an expiry of 400 days.  A naive ``dt`` is treated as
    UTC before serialisation.
    """
    client = _get_redis()
    aware = dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
    cursor = {"ts": aware.isoformat(), "id": chunk_id}
    serialized = json.dumps(cursor, ensure_ascii=False)
    client.set(chunk_cursor_key(user_id), serialized, ex=86400 * 400)
def finalize_memory_compaction_run(
    user_id: str,
    *,
    observed_deadline_ts: float | None,
    context: dict[str, Any] | None,
) -> None:
    """Release the scheduler gate once a compaction task has finished.

    The current debounce deadline is read first, then the gate is released.
    Afterwards:
      * no deadline stored -> nothing else to do;
      * the stored deadline is later than the one the task observed when it
        started (a newer trigger pushed it back while the task ran) -> one
        more task is registered for that deadline;
      * otherwise -> the consumed debounce deadline is cleared.

    Args:
        user_id: User whose scheduling state is finalized.
        observed_deadline_ts: Deadline the task read at start, or ``None``.
        context: Context forwarded to the re-registered task, if any.
    """
    latest = read_debounce_deadline_ts(user_id)
    release_scheduler_gate(user_id)
    if latest is None:
        return

    # The 1e-6 slack keeps float noise from looking like a newer trigger.
    pushed_back = (
        observed_deadline_ts is not None
        and latest > observed_deadline_ts + 1e-6
    )
    if not pushed_back:
        clear_debounce_deadline(user_id)
        return
    _schedule_for_existing_deadline(user_id, context, deadline_ts=latest)
def schedule_memory_compaction_run(
    user_id: str, context: dict[str, Any] | None
) -> None:
    """Record a compaction trigger and dispatch at most one delayed task.

    Called after memoir processing / chapter recomposition succeeds.  The
    debounce deadline becomes ``now + debounce window`` (never earlier than a
    deadline already stored).  If the scheduler gate is already held, the
    trigger is merged into the pending task; otherwise a delayed task is
    registered for the new deadline.

    Does nothing when ``settings.memory_compaction_enabled`` is false.
    """
    if not settings.memory_compaction_enabled:
        return

    client = _get_redis()
    deadline = time.time() + float(settings.memory_compaction_debounce_seconds)

    previous = client.get(debounce_key(user_id))
    if previous is not None:
        try:
            deadline = max(deadline, float(previous))
        except ValueError:
            # Unparseable stored value: keep the freshly computed deadline.
            pass

    client.set(
        debounce_key(user_id), str(deadline), ex=_debounce_key_ttl_seconds()
    )

    if client.get(scheduler_key(user_id)) is None:
        _schedule_for_existing_deadline(user_id, context, deadline_ts=deadline)
        return

    logger.debug(
        "memory_compaction schedule merged: user_id={} deadline_ts={}",
        user_id,
        deadline,
    )

View File

@@ -1,34 +1,49 @@
"""Small Redis lock helpers for background tasks."""
from dataclasses import dataclass
import threading
import uuid
from dataclasses import dataclass
import redis
from app.core.config import settings
_redis_lock_client: redis.Redis | None = None
_redis_lock_init_lock = threading.Lock()
def _get_redis_lock_client() -> redis.Redis:
    """Return the process-wide Redis client used by the lock helpers.

    The client is created once from ``settings.redis_url`` with
    ``decode_responses=False`` so stored values stay bytes, matching the byte
    lock tokens.  Creation is guarded by ``_redis_lock_init_lock`` and
    re-checked inside the lock.
    """
    global _redis_lock_client
    cached = _redis_lock_client
    if cached is not None:
        return cached
    with _redis_lock_init_lock:
        # Another thread may have created the client while we waited.
        if _redis_lock_client is None:
            _redis_lock_client = redis.from_url(
                settings.redis_url, decode_responses=False
            )
        return _redis_lock_client
@dataclass(frozen=True)
class RedisLockHandle:
    """Proof of ownership for a lock taken with ``acquire_redis_lock``."""

    # Redis key under which the lock is stored.
    key: str
    # Random per-acquisition token written as the key's value.
    token: bytes


def acquire_redis_lock(key: str, *, ttl_seconds: int) -> RedisLockHandle | None:
    """Acquire a single-owner Redis lock or return None when unavailable.

    The lock is taken with ``SET key token NX EX ttl_seconds`` on the shared
    process-wide client, so no new connection is created per call.

    Args:
        key: Redis key representing the lock.
        ttl_seconds: Expiry of the lock key in seconds.

    Returns:
        A handle carrying the key and the random token on success, ``None``
        when the key is already set.
    """
    client = _get_redis_lock_client()
    token = uuid.uuid4().hex.encode("utf-8")
    if not client.set(key, token, nx=True, ex=ttl_seconds):
        return None
    return RedisLockHandle(key=key, token=token)
def release_redis_lock(handle: RedisLockHandle | None) -> None:
"""Release the lock only if we still own it."""
if handle is None:
return
handle.client.eval(
_get_redis_lock_client().eval(
"""
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])

View File

@@ -0,0 +1,488 @@
"""
Memory compaction:增量 chunk 近重复检测,软排除(is_excluded + MemoryCurationAction)。
仅依赖 repo / settings,供 Celery 同步任务调用。
"""
from __future__ import annotations
import re
import time
from datetime import datetime, timezone
from typing import Any
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.logging import get_logger
from app.features.memory.models import MemoryChunk, MemorySource
from app.features.memory.repo import (
create_curation_action_sync,
get_first_chunk_after_cursor_sync,
get_memory_chunk_sync,
list_incremental_chunks_for_compaction_sync,
search_nearest_chunks_for_compaction_sync,
set_chunk_excluded_sync,
)
logger = get_logger(__name__)
_WS_RE = re.compile(r"\s+")
_PUNCT_RE = re.compile(r"[^\w\s\u4e00-\u9fff]", re.UNICODE)
def _normalize_text(s: str) -> str:
    """Normalise text for comparison.

    Strips and lower-cases the input (a falsy input is treated as ``""``),
    replaces every character matched by ``_PUNCT_RE`` with a space, then
    collapses whitespace runs to a single space and strips the result.
    """
    lowered = (s or "").strip().lower()
    without_punct = _PUNCT_RE.sub(" ", lowered)
    collapsed = _WS_RE.sub(" ", without_punct)
    return collapsed.strip()
def _word_jaccard(a: str, b: str) -> float:
    """Jaccard similarity of the whitespace-separated tokens of ``a`` and ``b``.

    Both inputs are normalised first.  Returns ``0.0`` when either side has no
    tokens.
    """
    tokens_a = set(_normalize_text(a).split())
    tokens_b = set(_normalize_text(b).split())
    if not (tokens_a and tokens_b):
        return 0.0
    shared = tokens_a.intersection(tokens_b)
    combined = tokens_a.union(tokens_b)
    return len(shared) / len(combined)
def text_layer_match(a: str, b: str, *, jaccard_min: float) -> bool:
    """Decide whether two texts match on the text layer.

    They match when their token Jaccard similarity reaches ``jaccard_min``, or
    when both normalised texts are at least 12 characters long and one
    contains the other.
    """
    if _word_jaccard(a, b) >= jaccard_min:
        return True
    norm_a = _normalize_text(a)
    norm_b = _normalize_text(b)
    # Containment only counts for reasonably long texts.
    long_enough = min(len(norm_a), len(norm_b)) >= 12
    return long_enough and (norm_a in norm_b or norm_b in norm_a)
def metadata_layer_match(
    c: MemoryChunk,
    n: dict[str, Any],
    *,
    event_year_window: int,
) -> bool:
    """Decide whether a chunk and a neighbour row match on the metadata layer.

    They match when both come from the same source, or when both carry an
    event year and the years differ by at most ``event_year_window``.

    Args:
        c: Chunk providing ``source_id`` and ``event_year``.
        n: Neighbour mapping read via the ``source_id`` / ``event_year`` keys.
        event_year_window: Maximum allowed absolute year difference.
    """
    if c.source_id == n.get("source_id"):
        return True
    chunk_year = c.event_year
    neighbor_year = n.get("event_year")
    if chunk_year is None or neighbor_year is None:
        return False
    return abs(int(chunk_year) - int(neighbor_year)) <= event_year_window
def embedding_layer_match(distance: float, *, similarity_threshold: float) -> bool:
    """Decide whether two chunks match on the embedding layer.

    ``distance`` is converted to a similarity as ``1.0 - distance``; the layer
    matches when that similarity is at least ``similarity_threshold``.
    """
    return (1.0 - distance) >= similarity_threshold
def canonical_score(
    *,
    content: str,
    metadata_json: Any,
    source_type: str | None,
) -> float:
    """Score how suitable a chunk is to be kept as the canonical copy.

    The score is the content length plus bonuses:
      * ``8.0`` when ``source_type`` is ``"draft"`` or ``"story"``;
      * ``0.15`` per metadata key (capped at 24 keys) when ``metadata_json``
        is a dict.

    A falsy ``content`` counts as length 0.
    """
    source_bonus = 8.0 if source_type in ("draft", "story") else 0.0
    if isinstance(metadata_json, dict):
        metadata_bonus = min(len(metadata_json), 24) * 0.15
    else:
        metadata_bonus = 0.0
    return float(len(content or "")) + (source_bonus + metadata_bonus)
def _source_type_for_chunk(session: Session, chunk: MemoryChunk) -> str | None:
    """Look up the ``source_type`` of the ``MemorySource`` a chunk belongs to.

    Returns ``None`` when the source row is not found.
    """
    source = session.get(MemorySource, chunk.source_id)
    if not source:
        return None
    return source.source_type
def _build_curation_details(
    *,
    ctx: dict[str, Any],
    winner_id: str,
    layers: int,
    distance: float,
) -> dict[str, Any]:
    """Assemble the ``details`` payload recorded with an exclusion action.

    The payload always carries the reason (``near_duplicate``), the id of the
    chunk that was kept, the number of matching layers and the cosine
    distance.  A fixed set of trigger-context keys is copied over from ``ctx``
    whenever their value is not ``None``.
    """
    passthrough_keys = (
        "trigger_source",
        "trigger_time",
        "pipeline_run_id",
        "request_id",
        "story_ids",
        "story_dispatch_ids",
        "chapter_ids",
        "recomposed_chapter_ids",
        "chapters_to_enqueue",
        "candidate_chunk_ids",
        "candidate_source_ids",
    )
    details: dict[str, Any] = {
        "reason": "near_duplicate",
        "similar_to_chunk_id": winner_id,
        "layers": layers,
        "cosine_distance": distance,
    }
    details.update(
        {key: ctx[key] for key in passthrough_keys if ctx.get(key) is not None}
    )
    return details
def count_duplicate_layers(
    *,
    chunk: MemoryChunk,
    neighbor: dict[str, Any],
    distance: float,
    similarity_threshold: float,
    jaccard_min: float,
    event_year_window: int,
) -> int:
    """Count how many of the three duplicate layers match (0 to 3).

    The layers are, in evaluation order: embedding similarity, text
    similarity, and metadata agreement.  Missing content on either side is
    treated as an empty string.
    """
    layer_hits = (
        embedding_layer_match(distance, similarity_threshold=similarity_threshold),
        text_layer_match(
            chunk.content or "",
            neighbor.get("content") or "",
            jaccard_min=jaccard_min,
        ),
        metadata_layer_match(chunk, neighbor, event_year_window=event_year_window),
    )
    return sum(1 for hit in layer_hits if hit)
def _advance_cursor_past_excluded_only_sync(
    session: Session,
    user_id: str,
    cursor_ts: datetime,
    cursor_id: str,
    *,
    max_steps: int,
) -> tuple[datetime, str] | None:
    """
    Walk the cursor forward over a run of excluded chunks.

    Used when the pending incremental set is empty although chunks still exist
    after the cursor, which usually means those chunks are all excluded.  The
    cursor is moved to the last excluded chunk of the run so later runs can
    reach chunks that are not excluded.

    Args:
        session: Session used to fetch the next chunk after the cursor.
        user_id: Owner of the chunks being walked.
        cursor_ts: Timestamp component of the starting cursor.
        cursor_id: Chunk-id component of the starting cursor.
        max_steps: Maximum number of chunks to step over in one call.

    Returns:
        The advanced ``(timestamp, chunk_id)`` cursor, or ``None`` when the
        cursor did not move: either nothing follows it, or the very next chunk
        is not excluded (logged as an anomaly).
    """
    advanced = False
    ts, cid = cursor_ts, cursor_id
    for _ in range(max_steps):
        nxt = get_first_chunk_after_cursor_sync(
            session,
            user_id=user_id,
            after_cursor_ts=ts,
            after_chunk_id=cid,
        )
        if nxt is None:
            # End of the user's chunks.
            return (ts, cid) if advanced else None
        ex = bool(nxt.is_excluded)
        if not ex:
            if advanced:
                # Stop on the last excluded chunk; the next run starts here.
                return (ts, cid)
            # The first chunk after the cursor is not excluded even though the
            # caller found no incremental chunks: report it and do not move.
            logger.warning(
                "memory_compaction_cursor_anomaly user_id={} cursor_ts={} cursor_id={} "
                "next_chunk_id={} next_is_excluded=false",
                user_id,
                cursor_ts,
                cursor_id,
                nxt.id,
            )
            return None
        # A missing created_at falls back to datetime.min; naive values are
        # treated as UTC.
        nxt_ts = nxt.created_at or datetime.min.replace(tzinfo=timezone.utc)
        if nxt_ts.tzinfo is None:
            nxt_ts = nxt_ts.replace(tzinfo=timezone.utc)
        ts, cid = nxt_ts, nxt.id
        advanced = True
    # Step cap reached while still inside the excluded run.
    logger.warning(
        "memory_compaction_cursor_skip_cap user_id={} steps={}",
        user_id,
        max_steps,
    )
    return (ts, cid) if advanced else None
def run_memory_compaction_sync(
    session: Session, user_id: str, context: dict[str, Any] | None
) -> dict[str, Any]:
    """
    Soft-exclude near-duplicate chunks among the chunks after the cursor.

    The caller is responsible for committing the session and for persisting
    the returned cursor.

    Args:
        session: Session used for all reads and writes.
        user_id: Owner of the chunks to compact.
        context: Optional trigger context.  Keys read here:
            ``_cursor_pair_override`` (cursor to use instead of the stored
            one), ``candidate_chunk_ids`` / ``candidate_source_ids`` (narrow
            the incremental set) and ``trigger_source`` (logging only).  The
            whole context is also passed on to the curation details.

    Returns:
        Structured result for logging and as the task return value, with the
        keys ``chunks_scanned``, ``chunks_excluded``,
        ``candidates_considered``, ``new_cursor_ts``, ``new_cursor_id``,
        ``duration_ms`` and ``skipped_reason`` (plus ``pending_chunk_id`` when
        the run stopped at a chunk that has no embedding yet).
    """
    ctx = dict(context or {})
    t0 = time.perf_counter()
    pair = ctx.get("_cursor_pair_override")
    if pair is None:
        from app.core.memory_compaction_schedule import get_incremental_cursor_pair

        cursor_ts, cursor_id = get_incremental_cursor_pair(user_id)
    else:
        cursor_ts, cursor_id = pair  # type: ignore[misc]
    cand_chunks = ctx.get("candidate_chunk_ids")
    if cand_chunks is not None:
        cand_chunks = [str(x) for x in cand_chunks]
    cand_sources = ctx.get("candidate_source_ids")
    if cand_sources is not None:
        cand_sources = [str(x) for x in cand_sources]
    # Key presence (not truthiness) decides whether a filter was requested.
    has_candidate_filter = "candidate_chunk_ids" in ctx or "candidate_source_ids" in ctx
    max_chunks = settings.memory_compaction_max_chunks_per_run
    incremental = list_incremental_chunks_for_compaction_sync(
        session,
        user_id=user_id,
        after_cursor_ts=cursor_ts,
        after_chunk_id=cursor_id,
        limit=max_chunks,
        candidate_chunk_ids=cand_chunks,
        candidate_source_ids=cand_sources,
    )
    if not incremental:
        ms = (time.perf_counter() - t0) * 1000
        # With a candidate id/source filter an empty result may only mean the
        # intersection is empty, so the global cursor must not be advanced
        # blindly.
        if not has_candidate_filter:
            tail = _advance_cursor_past_excluded_only_sync(
                session,
                user_id,
                cursor_ts,
                cursor_id,
                max_steps=max(
                    settings.memory_compaction_max_chunks_per_run * 2,
                    500,
                ),
            )
            if tail is not None:
                new_ts, new_id = tail
                logger.info(
                    "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 "
                    "duration_ms={:.1f} skipped_reason=empty_incremental_cursor_advanced "
                    "trigger_source={}",
                    user_id,
                    ms,
                    ctx.get("trigger_source", ""),
                )
                return {
                    "chunks_scanned": 0,
                    "chunks_excluded": 0,
                    "candidates_considered": 0,
                    "new_cursor_ts": new_ts.isoformat(),
                    "new_cursor_id": new_id,
                    "duration_ms": round(ms, 1),
                    "skipped_reason": "empty_incremental_cursor_advanced",
                }
        logger.info(
            "memory_compaction_done user_id={} chunks_scanned=0 chunks_excluded=0 "
            "duration_ms={:.1f} skipped_reason=empty_incremental trigger_source={}",
            user_id,
            ms,
            ctx.get("trigger_source", ""),
        )
        return {
            "chunks_scanned": 0,
            "chunks_excluded": 0,
            "candidates_considered": 0,
            "new_cursor_ts": None,
            "new_cursor_id": None,
            "duration_ms": round(ms, 1),
            "skipped_reason": "empty_incremental",
        }
    sim_th = settings.memory_compaction_chunk_similarity_threshold
    min_layers = settings.memory_compaction_min_layers_for_exclude
    jaccard_min = settings.memory_compaction_text_jaccard_min
    year_w = settings.memory_compaction_metadata_event_year_window
    max_neighbors = settings.memory_compaction_max_neighbors_per_chunk
    max_excludes = settings.memory_compaction_max_excludes_per_run
    # Ids excluded during this run, so they are skipped without re-deciding.
    local_excluded: set[str] = set()
    excludes_done = 0
    candidates_considered = 0
    # Stable order: process earlier-written chunks first.
    incremental_sorted = sorted(
        incremental,
        key=lambda c: (c.created_at or datetime.min.replace(tzinfo=timezone.utc), c.id),
    )
    # Last chunk the loop finished with; it becomes the new cursor.
    last_cursor_chunk: MemoryChunk | None = None
    chunks_scanned_this_run = 0
    for chunk in incremental_sorted:
        if excludes_done >= max_excludes:
            # Exclusion budget used up; remaining chunks wait for the next run.
            break
        if chunk.id in local_excluded:
            last_cursor_chunk = chunk
            continue
        row = get_memory_chunk_sync(session, chunk.id, user_id)
        if row is None or row.is_excluded:
            last_cursor_chunk = chunk
            continue
        emb = row.embedding
        if emb is None:
            # Stop here without moving the cursor past this chunk, so it is
            # picked up again once its embedding exists.
            ms = (time.perf_counter() - t0) * 1000
            logger.info(
                "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
                "candidates={} duration_ms={:.1f} skipped_reason=awaiting_embeddings "
                "pending_chunk_id={} trigger_source={}",
                user_id,
                chunks_scanned_this_run,
                excludes_done,
                candidates_considered,
                ms,
                row.id,
                ctx.get("trigger_source", ""),
            )
            return {
                "chunks_scanned": chunks_scanned_this_run,
                "chunks_excluded": excludes_done,
                "candidates_considered": candidates_considered,
                "new_cursor_ts": (
                    last_cursor_chunk.created_at.isoformat()
                    if last_cursor_chunk and last_cursor_chunk.created_at
                    else None
                ),
                "new_cursor_id": last_cursor_chunk.id if last_cursor_chunk else None,
                "duration_ms": round(ms, 1),
                "skipped_reason": "awaiting_embeddings",
                "pending_chunk_id": row.id,
            }
        chunks_scanned_this_run += 1
        st_c = _source_type_for_chunk(session, row)
        neighbors = search_nearest_chunks_for_compaction_sync(
            session,
            user_id=user_id,
            chunk_id=row.id,
            query_embedding=list(emb),
            limit=max_neighbors,
        )
        for nb in neighbors:
            if excludes_done >= max_excludes:
                break
            nid = nb["id"]
            if nid == row.id or nid in local_excluded:
                continue
            other = get_memory_chunk_sync(session, nid, user_id)
            if other is None or other.is_excluded:
                continue
            dist = float(nb["distance"])
            layers = count_duplicate_layers(
                chunk=row,
                neighbor=nb,
                distance=dist,
                similarity_threshold=sim_th,
                jaccard_min=jaccard_min,
                event_year_window=year_w,
            )
            if layers < min_layers:
                continue
            candidates_considered += 1
            # Higher canonical score wins; the loser is the one excluded.
            sc = canonical_score(
                content=row.content or "",
                metadata_json=row.metadata_json,
                source_type=st_c,
            )
            sn = canonical_score(
                content=nb.get("content") or "",
                metadata_json=nb.get("metadata_json"),
                source_type=nb.get("source_type"),
            )
            t_row = row.created_at or datetime.min.replace(tzinfo=timezone.utc)
            t_nb = nb.get("created_at") or datetime.min.replace(tzinfo=timezone.utc)
            if t_row.tzinfo is None:
                t_row = t_row.replace(tzinfo=timezone.utc)
            if t_nb.tzinfo is None:
                t_nb = t_nb.replace(tzinfo=timezone.utc)
            if sc > sn:
                loser_id, winner_id = nid, row.id
            elif sn > sc:
                loser_id, winner_id = row.id, nid
            else:
                # Tie: keep the chunk that was written earlier.
                if t_row <= t_nb:
                    loser_id, winner_id = nid, row.id
                else:
                    loser_id, winner_id = row.id, nid
            loser = get_memory_chunk_sync(session, loser_id, user_id)
            if loser is None or loser.is_excluded:
                continue
            ok = set_chunk_excluded_sync(session, loser_id, user_id, True)
            if not ok:
                continue
            create_curation_action_sync(
                session,
                user_id=user_id,
                action_type="exclude",
                target_type="chunk",
                target_id=loser_id,
                details=_build_curation_details(
                    ctx=ctx,
                    winner_id=winner_id,
                    layers=layers,
                    distance=dist,
                ),
            )
            excludes_done += 1
            local_excluded.add(loser_id)
            logger.info(
                "memory_compaction_exclude user_id={} excluded_chunk_id={} "
                "kept_chunk_id={} layers={} distance={:.4f}",
                user_id,
                loser_id,
                winner_id,
                layers,
                dist,
            )
            if loser_id == row.id:
                # The current chunk itself was excluded; stop comparing it.
                break
        last_cursor_chunk = chunk
    if last_cursor_chunk is None:
        # No chunk was finished, so there is no cursor to report.
        ms = (time.perf_counter() - t0) * 1000
        logger.warning(
            "memory_compaction_no_cursor_chunk user_id={} incremental_n={}",
            user_id,
            len(incremental_sorted),
        )
        return {
            "chunks_scanned": 0,
            "chunks_excluded": excludes_done,
            "candidates_considered": candidates_considered,
            "new_cursor_ts": None,
            "new_cursor_id": None,
            "duration_ms": round(ms, 1),
            "skipped_reason": "no_cursor_chunk",
        }
    new_cursor_ts = last_cursor_chunk.created_at or datetime.min.replace(
        tzinfo=timezone.utc
    )
    if new_cursor_ts.tzinfo is None:
        new_cursor_ts = new_cursor_ts.replace(tzinfo=timezone.utc)
    new_cursor_id = last_cursor_chunk.id
    ms = (time.perf_counter() - t0) * 1000
    logger.info(
        "memory_compaction_done user_id={} chunks_scanned={} chunks_excluded={} "
        "candidates={} duration_ms={:.1f} trigger_source={}",
        user_id,
        chunks_scanned_this_run,
        excludes_done,
        candidates_considered,
        ms,
        ctx.get("trigger_source", ""),
    )
    return {
        "chunks_scanned": chunks_scanned_this_run,
        "chunks_excluded": excludes_done,
        "candidates_considered": candidates_considered,
        "new_cursor_ts": new_cursor_ts.isoformat(),
        "new_cursor_id": new_cursor_id,
        "duration_ms": round(ms, 1),
        "skipped_reason": None,
    }

View File

@@ -3,7 +3,7 @@
import uuid
from datetime import datetime, timezone
from sqlalchemy import delete, or_, select, text
from sqlalchemy import delete, literal, or_, select, text, tuple_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
@@ -115,6 +115,15 @@ def update_chunk_fts_sync(session: Session, chunk_id: str) -> None:
)
def update_chunk_embedding_sync(
    session: Session, chunk_id: str, embedding: list[float]
) -> None:
    """Set the embedding of the chunk with ``chunk_id`` (sync).

    Does nothing when the chunk is not found.  The caller must commit.
    """
    chunk = session.get(MemoryChunk, chunk_id)
    if not chunk:
        return
    chunk.embedding = embedding
async def update_chunk_embedding(
db: AsyncSession, chunk_id: str, embedding: list[float]
) -> None:
@@ -741,6 +750,130 @@ async def get_memory_chunk_for_user(
return row
def get_memory_chunk_sync(
    session: Session, chunk_id: str, user_id: str
) -> MemoryChunk | None:
    """Fetch a chunk by id, but only if it belongs to ``user_id`` (sync).

    Returns ``None`` when the chunk does not exist or is owned by someone else.
    """
    row = session.get(MemoryChunk, chunk_id)
    owned = row is not None and row.user_id == user_id
    return row if owned else None
def set_chunk_excluded_sync(
    session: Session, chunk_id: str, user_id: str, excluded: bool
) -> bool:
    """Set ``is_excluded`` on a chunk owned by ``user_id`` (sync).

    Returns:
        ``True`` when the flag was assigned, ``False`` when the chunk was not
        found for that user.
    """
    target = get_memory_chunk_sync(session, chunk_id, user_id)
    if target is not None:
        target.is_excluded = excluded
        return True
    return False
def list_incremental_chunks_for_compaction_sync(
    session: Session,
    *,
    user_id: str,
    after_cursor_ts: datetime,
    after_chunk_id: str,
    limit: int,
    candidate_chunk_ids: list[str] | None = None,
    candidate_source_ids: list[str] | None = None,
) -> list[MemoryChunk]:
    """List the user's non-excluded chunks that sort after the cursor.

    A chunk qualifies when its ``(created_at, id)`` pair is greater than
    ``(after_cursor_ts, after_chunk_id)`` and ``is_excluded`` is false or
    NULL.  Non-empty ``candidate_chunk_ids`` / ``candidate_source_ids``
    further restrict the result to those ids.  Rows are ordered by
    ``created_at`` then ``id`` ascending and capped at ``limit``.
    """
    chunk_position = tuple_(MemoryChunk.created_at, MemoryChunk.id)
    cursor_position = tuple_(literal(after_cursor_ts), literal(after_chunk_id))
    not_excluded = or_(
        MemoryChunk.is_excluded.is_(False), MemoryChunk.is_excluded.is_(None)
    )
    conditions = [
        MemoryChunk.user_id == user_id,
        chunk_position > cursor_position,
        not_excluded,
    ]
    if candidate_chunk_ids:
        conditions.append(MemoryChunk.id.in_(candidate_chunk_ids))
    if candidate_source_ids:
        conditions.append(MemoryChunk.source_id.in_(candidate_source_ids))
    stmt = (
        select(MemoryChunk)
        .where(*conditions)
        .order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc())
        .limit(limit)
    )
    return list(session.execute(stmt).unique().scalars().all())
def get_first_chunk_after_cursor_sync(
    session: Session,
    *,
    user_id: str,
    after_cursor_ts: datetime,
    after_chunk_id: str,
) -> MemoryChunk | None:
    """Return the user's first chunk that sorts after the cursor, if any.

    Excluded chunks are included on purpose: this is used to move the cursor
    forward when the incremental set is empty.  Ordering is ``created_at``
    then ``id`` ascending.
    """
    cursor_position = tuple_(literal(after_cursor_ts), literal(after_chunk_id))
    stmt = select(MemoryChunk).where(
        MemoryChunk.user_id == user_id,
        tuple_(MemoryChunk.created_at, MemoryChunk.id) > cursor_position,
    )
    stmt = stmt.order_by(MemoryChunk.created_at.asc(), MemoryChunk.id.asc())
    stmt = stmt.limit(1)
    result = session.execute(stmt)
    return result.scalars().first()
def search_nearest_chunks_for_compaction_sync(
    session: Session,
    *,
    user_id: str,
    chunk_id: str,
    query_embedding: list[float],
    limit: int,
) -> list[dict]:
    """
    Return the Top-K nearest chunks by cosine distance, excluding the chunk itself.

    pgvector's ``<=>`` operator is cosine distance.  Only the user's chunks
    that are not excluded and have an embedding are considered.

    Args:
        session: Session used to run the query.
        user_id: Owner whose chunks are searched.
        chunk_id: Id of the query chunk; it is left out of the result.
        query_embedding: Embedding to search with; an empty value yields ``[]``
            without querying.
        limit: Maximum number of neighbours to return.

    Returns:
        A list of dicts with the keys ``id``, ``content``, ``source_id``,
        ``event_year``, ``metadata_json``, ``source_type``, ``created_at`` and
        ``distance`` (float), nearest first.
    """
    if not query_embedding:
        return []
    # CAST(:emb AS vector) is used instead of the ``:emb::vector`` shorthand:
    # a bind parameter immediately followed by ``::`` is not reliably
    # recognised by SQLAlchemy's text() parameter parser.
    stmt = text("""
        SELECT mc.id, mc.content, mc.source_id, mc.event_year, mc.metadata_json,
               ms.source_type, mc.created_at,
               (mc.embedding <=> CAST(:emb AS vector)) AS distance
        FROM memory_chunks mc
        JOIN memory_sources ms ON ms.id = mc.source_id
        WHERE mc.user_id = :user_id
          AND mc.is_excluded IS NOT TRUE
          AND mc.embedding IS NOT NULL
          AND mc.id != :chunk_id
        ORDER BY mc.embedding <=> CAST(:emb AS vector)
        LIMIT :lim
    """)
    # pgvector text input format: "[v1,v2,...]".
    emb_str = "[" + ",".join(str(x) for x in query_embedding) + "]"
    result = session.execute(
        stmt,
        {
            "user_id": user_id,
            "chunk_id": chunk_id,
            "emb": emb_str,
            "lim": limit,
        },
    )
    return [
        {
            "id": r["id"],
            "content": r["content"],
            "source_id": r["source_id"],
            "event_year": r["event_year"],
            "metadata_json": r["metadata_json"],
            "source_type": r["source_type"],
            "created_at": r["created_at"],
            "distance": float(r["distance"]),
        }
        for r in result.mappings().all()
    ]
async def set_chunk_excluded(
db: AsyncSession, chunk_id: str, user_id: str, excluded: bool
) -> bool:

View File

@@ -9,6 +9,8 @@ Celery 侧使用 `ingest_transcript_sync` + `retrieve_evidence_sync`,与异步
`api/docs/memory-retrieval.md`。
"""
import asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
@@ -184,13 +186,15 @@ def ingest_transcript_sync(
) -> str:
"""
Sync transcript ingest for Celery tasks.
Creates source + chunks + FTS. Skips embedding (async).
Creates source + chunks + FTS, and best-effort populates embeddings.
Returns source_id.
"""
from app.core.dependencies import get_embedding_provider
from app.features.memory.chunker import chunk_transcript
from app.features.memory.repo import (
create_chunk_sync,
create_source_sync,
update_chunk_embedding_sync,
update_chunk_fts_sync,
)
@@ -207,6 +211,7 @@ def ingest_transcript_sync(
session.flush()
chunks_text = chunk_transcript(transcript.strip())
chunk_records: list[tuple[str, str]] = []
for i, content in enumerate(chunks_text):
chunk = create_chunk_sync(
session,
@@ -216,8 +221,22 @@ def ingest_transcript_sync(
chunk_index=i,
)
session.flush()
chunk_records.append((chunk.id, content))
update_chunk_fts_sync(session, chunk.id)
try:
embedding_provider = get_embedding_provider()
if chunk_records and embedding_provider is not None:
texts = [content for _, content in chunk_records]
embeddings = asyncio.run(embedding_provider.embed_texts(texts))
for (chunk_id, _), emb in zip(chunk_records, embeddings):
if emb:
update_chunk_embedding_sync(session, chunk_id, emb)
except Exception as e:
logger.warning(
"memory embedding 跳过(sync): {} exc_type={}", e, type(e).__name__
)
try:
from app.core.config import settings
from app.features.memory.enrichment import enrich_memory_after_ingest_sync

View File

@@ -5,6 +5,7 @@ Celery 任务模块
from .celery_app import celery_app
from .chapter_cover_tasks import generate_chapter_cover
from .memoir_tasks import process_memoir_segments
from .memory_compaction_tasks import memory_compaction_run
from .story_image_tasks import generate_story_image
__all__ = [
@@ -12,4 +13,5 @@ __all__ = [
"process_memoir_segments",
"generate_chapter_cover",
"generate_story_image",
"memory_compaction_run",
]

View File

@@ -34,6 +34,7 @@ celery_app = Celery(
"app.tasks.story_image_tasks",
"app.tasks.chapter_cover_tasks",
"app.tasks.chapter_compose_tasks",
"app.tasks.memory_compaction_tasks",
],
)

View File

@@ -1,20 +1,27 @@
"""Celery:story 变更后重组关联章节的 canonical_markdown(物化视图)。"""
from datetime import datetime, timezone
from celery import shared_task
from sqlalchemy import select
from app.core.db import get_sync_db
from app.core.logging import get_logger
from app.core.memory_compaction_schedule import schedule_memory_compaction_run
from app.features.memoir import repo as memoir_repo
from app.features.memoir.models import Chapter, ChapterStoryLink
from app.features.story.models import Story
logger = get_logger(__name__)
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def recompose_chapters_for_story(self, story_id: str) -> dict:
user_id: str | None = None
try:
with get_sync_db() as session:
story = session.get(Story, story_id)
user_id = story.user_id if story else None
stmt = (
select(Chapter.id)
.join(
@@ -30,6 +37,17 @@ def recompose_chapters_for_story(self, story_id: str) -> dict:
for cid in ids:
memoir_repo.compose_chapter_from_story_links_sync(session, cid)
session.commit()
if user_id:
schedule_memory_compaction_run(
user_id,
{
"trigger_source": "chapter_recompose",
"trigger_time": datetime.now(timezone.utc).isoformat(),
"pipeline_run_id": str(self.request.id),
"story_ids": [story_id],
"recomposed_chapter_ids": ids,
},
)
logger.info(
"recompose_chapters_for_story: story={} recomposed_chapters={}",
story_id,

View File

@@ -18,6 +18,7 @@ from app.agents.state_schema import MemoirStateSchema, SlotData, default_state
from app.core.db import get_sync_db
from app.core.dependencies import get_llm_provider
from app.core.logging import get_logger
from app.core.memory_compaction_schedule import schedule_memory_compaction_run
from app.features.conversation.models import Segment
from app.features.memoir.cover_eligibility import (
chapter_needs_cover_enqueue,
@@ -418,6 +419,17 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
if try_enqueue_generate_chapter_cover(chapter_id, source="pipeline"):
logger.info(f"派发章节封面任务: chapter={chapter_id}")
schedule_memory_compaction_run(
user_id,
{
"trigger_source": "memoir_segments",
"trigger_time": datetime.now(timezone.utc).isoformat(),
"pipeline_run_id": str(task_id),
"story_dispatch_ids": sorted(story_dispatch_ids),
"chapters_to_enqueue": sorted(chapters_to_enqueue),
},
)
categories_processed = sorted(prepared.category_to_segments.keys())
logger.info(
"回忆录处理完成: user_id={} task_id={} segment_count={} "

View File

@@ -0,0 +1,79 @@
"""Celery:memory compaction(近重复 chunk 软排除)。"""
from __future__ import annotations
import time
from datetime import datetime
from typing import Any
from celery import shared_task
from app.core.config import settings
from app.core.db import get_sync_db
from app.core.logging import get_logger
from app.core.memory_compaction_schedule import (
finalize_memory_compaction_run,
read_debounce_deadline_ts,
release_scheduler_gate,
set_incremental_cursor_pair,
)
from app.core.redis_lock import acquire_redis_lock, release_redis_lock
from app.features.memory.compaction_service import run_memory_compaction_sync
logger = get_logger(__name__)
@shared_task(bind=True, max_retries=12, default_retry_delay=20)
def memory_compaction_run(
    self, user_id: str, context: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Celery task: run one memory-compaction pass for ``user_id``.

    Steps:
      1. Return a skipped result when compaction is disabled in settings.
      2. If the debounce deadline has not been reached yet, retry once the
         remaining time has elapsed.
      3. Take the per-user Redis lock; if it is held elsewhere, finalize the
         scheduling state and return a skipped result.
      4. Run the compaction, commit, store the new cursor when one was
         returned, then finalize the scheduling state.

    On any exception during step 4 the scheduler gate is released and the
    task is retried.  The Redis lock is always released.

    Args:
        user_id: User whose chunks are compacted.
        context: Optional trigger context forwarded to the compaction service
            and to any re-registered task.

    Returns:
        The compaction result dict, or ``{"skipped": True, "reason": ...}``.
    """
    import math  # local import: only needed for the debounce countdown

    if not settings.memory_compaction_enabled:
        return {"skipped": True, "reason": "disabled"}
    ctx = dict(context or {})
    deadline = read_debounce_deadline_ts(user_id)
    now = time.time()
    if deadline is not None and now < deadline:
        # Round up (as the scheduler does): truncating could wake the task
        # just before the deadline and spend another retry on the same wait.
        countdown = max(1, math.ceil(deadline - now))
        raise self.retry(countdown=countdown)
    lock = acquire_redis_lock(
        f"lock:memory_compaction:{user_id}",
        ttl_seconds=settings.memory_compaction_lock_ttl_seconds,
    )
    if lock is None:
        logger.info(
            "memory_compaction_skipped user_id={} skipped_reason=lock_not_acquired",
            user_id,
        )
        out = {"skipped": True, "reason": "lock_not_acquired"}
        finalize_memory_compaction_run(
            user_id,
            observed_deadline_ts=deadline,
            context=ctx,
        )
        return out
    try:
        with get_sync_db() as session:
            out = run_memory_compaction_sync(session, user_id, ctx)
            session.commit()
            # Only move the cursor when the run produced a complete pair.
            if out.get("new_cursor_ts") and out.get("new_cursor_id") is not None:
                set_incremental_cursor_pair(
                    user_id,
                    datetime.fromisoformat(out["new_cursor_ts"]),
                    str(out["new_cursor_id"]),
                )
        finalize_memory_compaction_run(
            user_id,
            observed_deadline_ts=deadline,
            context=ctx,
        )
        return out
    except Exception as exc:
        logger.warning("memory_compaction_run failed user_id={} err={}", user_id, exc)
        # Free the gate so scheduling is not blocked until its TTL expires.
        release_scheduler_gate(user_id)
        raise self.retry(exc=exc)
    finally:
        release_redis_lock(lock)