Files
life-echo/api/app/features/memory/enrichment.py

278 lines
8.6 KiB
Python
Raw Normal View History

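"""
Memory enrichment after transcript ingest: session/rolling summaries, facts and
timeline events.

Invoked from Celery (sync) and from MemoryService.ingest (async); a failure is
only logged and does not block the main ingest flow.
"""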
from __future__ import annotations
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.core.logging import get_logger
from app.features.memory.extractor import (
extract_facts_from_transcript_async,
extract_facts_from_transcript_sync,
)
from app.features.memory.models import MemoryChunk, MemorySummary
from app.features.memory.repo import (
create_memory_fact,
create_memory_fact_sync,
create_memory_summary,
create_memory_summary_sync,
create_timeline_event,
create_timeline_event_sync,
delete_timeline_events_by_memory_source,
delete_timeline_events_by_memory_source_sync,
list_chunks_for_source_sync,
upsert_rolling_summary_sync,
)
from app.features.memory.summarizer import (
generate_rolling_summary_async,
generate_rolling_summary_sync,
generate_session_summary_async,
generate_session_summary_sync,
)
from app.features.memory.enrichment_pipeline import dedupe_key, normalize_object_json
from app.features.memory.timeline import (
build_timeline_events_from_facts_async,
build_timeline_events_from_facts_sync,
)
logger = get_logger(__name__)


def _resolve_llm_sync() -> Any | None:
    """Best-effort lookup of the configured LangChain LLM.

    Returns the provider's ``langchain_llm``. If anything goes wrong while
    importing or querying the provider, returns ``None`` after logging a
    warning, so enrichment can be skipped instead of failing the caller.
    """
    try:
        # Imported lazily so this module can be loaded without the DI container.
        from app.core.dependencies import get_llm_provider

        provider = get_llm_provider()
        langchain_llm = provider.langchain_llm
    except Exception as exc:  # deliberate best-effort boundary
        logger.warning("memory enrichment 无法获取 LLM: {}", exc)
        return None
    return langchain_llm
_DEFAULT_FACT_CONFIDENCE = 0.75


def _number_chunks(chunk_ids: list[Any], chunk_texts: list[str]) -> str:
    """Join chunk texts into one prompt, each prefixed by ``[chunk_id=<id>]``.

    The tag lets the fact extractor cite which chunk a fact came from.
    """
    return "\n\n".join(
        f"[chunk_id={cid}]\n{txt}" for cid, txt in zip(chunk_ids, chunk_texts)
    )


def _latest_rolling_summary_stmt(user_id: str):
    """Build the query for the user's most recently updated rolling summary."""
    return (
        select(MemorySummary)
        .where(
            MemorySummary.user_id == user_id,
            MemorySummary.summary_type == "rolling",
        )
        .order_by(MemorySummary.updated_at.desc())
        .limit(1)
    )


def _coerce_confidence(value: Any, default: float = _DEFAULT_FACT_CONFIDENCE) -> float:
    """Convert an LLM-supplied confidence into a ``float``.

    Missing/blank values, and values that cannot be parsed as a number
    (e.g. ``"high"``), fall back to *default* instead of aborting the whole
    enrichment run. An explicit ``0`` is preserved as ``0.0`` rather than
    being replaced by *default*.
    """
    if value is None or (isinstance(value, str) and not value.strip()):
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _resolve_source_chunk_id(raw: Any, ids_by_str: dict[str, Any], fallback: Any) -> Any:
    """Map the chunk id cited by the LLM onto a real chunk id of this source.

    Matching is done on the ``str`` form because the id comes back as text
    inside the LLM output, while chunk ids may be non-string objects
    (NOTE(review): confirm the MemoryChunk.id type). Unknown ids fall back to
    *fallback* (the first chunk); a missing/blank id yields ``None``.
    """
    if raw is None or (isinstance(raw, str) and not raw.strip()):
        return None
    return ids_by_str.get(str(raw).strip(), fallback)


def _unique_facts(raw_facts: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
    """Drop facts whose ``dedupe_key`` was already seen, keeping first occurrences."""
    seen: set[tuple] = set()
    unique: list[dict[str, Any]] = []
    for fact in raw_facts or ():
        key = dedupe_key(fact)
        if key in seen:
            continue
        seen.add(key)
        unique.append(fact)
    return unique


def _fact_create_kwargs(
    fact: dict[str, Any], ids_by_str: dict[str, Any], fallback_chunk_id: Any
) -> dict[str, Any]:
    """Build the keyword arguments for ``create_memory_fact[_sync]`` from a raw fact."""
    return {
        "fact_type": fact.get("fact_type") or "event",
        "subject": fact.get("subject"),
        "predicate": fact.get("predicate"),
        "object_json": normalize_object_json(fact.get("object_json")),
        "confidence": _coerce_confidence(fact.get("confidence")),
        "source_chunk_id": _resolve_source_chunk_id(
            fact.get("source_chunk_id"), ids_by_str, fallback_chunk_id
        ),
        "status": "confirmed",
    }


def _fact_row_to_dict(row: Any) -> dict[str, Any]:
    """Project a stored fact row into the dict shape the timeline builder receives."""
    return {
        "id": row.id,
        "fact_type": row.fact_type,
        "subject": row.subject,
        "predicate": row.predicate,
        "object_json": row.object_json,
    }


def _timeline_event_kwargs(ev: dict[str, Any]) -> dict[str, Any] | None:
    """Build ``create_timeline_event[_sync]`` kwargs, or ``None`` if *ev* lacks a title.

    Skipping (with a warning) avoids raising halfway through the insert loop,
    after the source's previous timeline events have already been deleted.
    """
    title = ev.get("title")
    if not title:
        logger.warning("memory enrichment: skipping timeline event without title: {}", ev)
        return None
    return {
        "event_year": ev.get("event_year"),
        "event_date": ev.get("event_date"),
        "title": title,
        "description": ev.get("description"),
        "source_fact_ids": ev.get("source_fact_ids") or None,
    }


def enrich_memory_after_ingest_sync(
    session: Session,
    user_id: str,
    source_id: str,
    llm: Any | None = None,
) -> None:
    """Enrich memory for one ingested source (synchronous variant).

    Steps: write a "session" summary of the source's chunks, update the user's
    rolling summary, extract de-duplicated facts, and rebuild the timeline
    events tied to ``source_id`` from the newly inserted facts.

    Args:
        session: Synchronous SQLAlchemy session used for every read/write.
        user_id: Owner of the summaries, facts and timeline events.
        source_id: Memory source whose chunks are enriched.
        llm: LangChain LLM to use; resolved via the provider when ``None``.

    Returns early (doing nothing) when enrichment is disabled, no LLM is
    available, or the source has no chunks. Other errors from the LLM or repo
    helpers propagate to the caller.
    """
    from app.core.config import settings

    if not settings.memory_enrichment_enabled:
        return
    if llm is None:
        llm = _resolve_llm_sync()
    if not llm:
        return

    chunks = list_chunks_for_source_sync(session, source_id)
    if not chunks:
        return
    chunk_texts = [chunk.content for chunk in chunks]
    chunk_ids = [chunk.id for chunk in chunks]

    # 1) Summary of this transcript alone.
    session_summary_text = generate_session_summary_sync(llm, chunk_texts)
    if session_summary_text:
        create_memory_summary_sync(
            session,
            user_id=user_id,
            summary_type="session",
            content=session_summary_text,
            source_chunk_ids=chunk_ids,
        )

    # 2) Fold the new chunks into the user's latest rolling summary.
    existing_rolling = (
        session.execute(_latest_rolling_summary_stmt(user_id))
        .unique()
        .scalar_one_or_none()
    )
    existing_text = existing_rolling.content if existing_rolling else None
    rolling_text = generate_rolling_summary_sync(llm, existing_text, chunk_texts)
    if rolling_text:
        upsert_rolling_summary_sync(
            session,
            user_id=user_id,
            content=rolling_text,
            source_chunk_ids=chunk_ids,
        )

    # 3) Extract and store de-duplicated facts.
    raw_facts = extract_facts_from_transcript_sync(
        llm, _number_chunks(chunk_ids, chunk_texts)
    )
    ids_by_str = {str(cid): cid for cid in chunk_ids}
    inserted: list[dict[str, Any]] = []
    for fact in _unique_facts(raw_facts):
        row = create_memory_fact_sync(
            session,
            user_id=user_id,
            **_fact_create_kwargs(fact, ids_by_str, chunk_ids[0]),
        )
        inserted.append(_fact_row_to_dict(row))

    # 4) Rebuild this source's timeline only when new facts exist, so a failed
    #    or empty extraction does not wipe existing events.
    if not inserted:
        return
    delete_timeline_events_by_memory_source_sync(
        session, user_id=user_id, memory_source_id=source_id
    )
    for ev in build_timeline_events_from_facts_sync(llm, inserted) or ():
        ev_kwargs = _timeline_event_kwargs(ev)
        if ev_kwargs is None:
            continue
        create_timeline_event_sync(
            session, user_id=user_id, memory_source_id=source_id, **ev_kwargs
        )


async def enrich_memory_after_ingest_async(
    db: AsyncSession,
    user_id: str,
    source_id: str,
    llm: Any | None = None,
) -> None:
    """Enrich memory for one ingested source (asynchronous variant).

    Same steps as :func:`enrich_memory_after_ingest_sync`, using the async
    session and async LLM/repo helpers. The rolling summary is updated in place
    on the loaded row (or created when none exists); flushing/committing that
    change is left to the caller's session handling.

    Args:
        db: Async SQLAlchemy session used for every read/write.
        user_id: Owner of the summaries, facts and timeline events.
        source_id: Memory source whose chunks are enriched.
        llm: LangChain LLM to use; resolved via the provider when ``None``.
    """
    from app.core.config import settings

    if not settings.memory_enrichment_enabled:
        return
    if llm is None:
        llm = _resolve_llm_sync()
    if not llm:
        return

    chunk_result = await db.execute(
        select(MemoryChunk)
        .where(MemoryChunk.source_id == source_id)
        .order_by(MemoryChunk.chunk_index.asc())
    )
    chunks = list(chunk_result.unique().scalars().all())
    if not chunks:
        return
    chunk_texts = [chunk.content for chunk in chunks]
    chunk_ids = [chunk.id for chunk in chunks]

    # 1) Summary of this transcript alone.
    session_summary_text = await generate_session_summary_async(llm, chunk_texts)
    if session_summary_text:
        await create_memory_summary(
            db,
            user_id=user_id,
            summary_type="session",
            content=session_summary_text,
            source_chunk_ids=chunk_ids,
        )

    # 2) Fold the new chunks into the user's latest rolling summary.
    rolling_result = await db.execute(_latest_rolling_summary_stmt(user_id))
    existing_row = rolling_result.unique().scalar_one_or_none()
    existing_text = existing_row.content if existing_row else None
    rolling_text = await generate_rolling_summary_async(llm, existing_text, chunk_texts)
    if rolling_text:
        if existing_row is not None:
            existing_row.content = rolling_text
            existing_row.source_chunk_ids = chunk_ids
        else:
            await create_memory_summary(
                db,
                user_id=user_id,
                summary_type="rolling",
                content=rolling_text,
                source_chunk_ids=chunk_ids,
            )

    # 3) Extract and store de-duplicated facts.
    raw_facts = await extract_facts_from_transcript_async(
        llm, _number_chunks(chunk_ids, chunk_texts)
    )
    ids_by_str = {str(cid): cid for cid in chunk_ids}
    inserted: list[dict[str, Any]] = []
    for fact in _unique_facts(raw_facts):
        row = await create_memory_fact(
            db,
            user_id=user_id,
            **_fact_create_kwargs(fact, ids_by_str, chunk_ids[0]),
        )
        inserted.append(_fact_row_to_dict(row))

    # 4) Rebuild this source's timeline only when new facts exist.
    if not inserted:
        return
    await delete_timeline_events_by_memory_source(
        db, user_id=user_id, memory_source_id=source_id
    )
    events = await build_timeline_events_from_facts_async(llm, inserted)
    for ev in events or ():
        ev_kwargs = _timeline_event_kwargs(ev)
        if ev_kwargs is None:
            continue
        await create_timeline_event(
            db, user_id=user_id, memory_source_id=source_id, **ev_kwargs
        )