From e4bf0710c73ba644ac654e0e4f5d59a99936b75c Mon Sep 17 00:00:00 2001 From: Kevin Date: Fri, 27 Mar 2026 16:01:28 +0800 Subject: [PATCH] =?UTF-8?q?feat(memory,conversation):=20=E8=AE=B0=E5=BF=86?= =?UTF-8?q?=E5=AF=8C=E5=8C=96/=E8=AF=81=E6=8D=AE=E5=8C=85=E3=80=81?= =?UTF-8?q?=E6=97=B6=E9=97=B4=E7=BA=BF=E5=B9=82=E7=AD=89=E5=AD=97=E6=AE=B5?= =?UTF-8?q?=E4=B8=8E=E5=AF=B9=E8=AF=9D=E5=88=86=E6=AE=B5=E5=85=A8=E9=93=BE?= =?UTF-8?q?=E8=B7=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 数据库 - 新增迁移 0003:timeline_events.memory_source_id 外键 → memory_sources,便于按 ingest 源做时间线幂等 后端 - 记忆 - 新增 ingest 后 LLM 富化(摘要/事实/时间线),可配置开关与最大字符数 - 新增证据包组装:合并 chunk、摘要、事实、时间线、故事等检索结果;支持空 query 时是否仍带 rolling 等开关 - repo/retriever/service/router/schemas/summarizer/timeline/extractor 等扩展;文档 memory-retrieval.md 更新 后端 - 对话 WS - 增加 PING/PONG;分段 ASR 日志与空音频处理;转写失败与「无助手回复」错误提示更明确 - 助手多段回复持久化使用统一分隔符,与分段逻辑一致 后端 - Agent - reply_limits:按 [SPLIT] 与段落拆段,并保证非空 fallback,供 WS 与 TTS 多段下发 后端 - 回忆录任务 - transcript ingest 记录 source_id;任务成功结? --- .github/workflows/legacy-data-migrate.yml | 5 +- api/.env.example | 2 +- .../0003_timeline_events_memory_source_id.py | 45 ++ api/app/adapters/asr/tencent_asr.py | 24 +- api/app/adapters/asr/whisper_local.py | 142 +++- api/app/agents/chat/interview_agent.py | 30 +- api/app/agents/chat/profile_agent.py | 32 +- api/app/agents/chat/prompts_conversation.py | 11 +- api/app/agents/chat/reply_limits.py | 43 ++ api/app/agents/memoir/classification_agent.py | 45 +- api/app/agents/memoir/fidelity_check_agent.py | 30 +- api/app/agents/memoir/narrative_agent.py | 2 +- api/app/agents/memoir/orchestrator.py | 8 +- api/app/agents/memoir/prompts.py | 100 ++- api/app/core/config.py | 9 +- .../features/conversation/history_store.py | 12 +- .../features/conversation/ws/message_types.py | 2 + api/app/features/conversation/ws/pipeline.py | 55 +- api/app/features/conversation/ws/router.py | 145 ++-- .../features/memoir/memoir_images/settings.py | 2 +- api/app/features/memoir/service.py | 4 +- .../features/memoir/story_pipeline_sync.py | 135 ++-- api/app/features/memory/curation.py | 18 +- api/app/features/memory/enrichment.py | 277 ++++++++ .../features/memory/enrichment_pipeline.py | 25 + api/app/features/memory/evidence.py | 244 +++++++ api/app/features/memory/extractor.py | 92 ++- api/app/features/memory/llm_schemas.py | 103 +++ api/app/features/memory/models.py | 6 + api/app/features/memory/repo.py | 620 ++++++++++++++++-- api/app/features/memory/retriever.py | 65 +- api/app/features/memory/router.py | 72 +- api/app/features/memory/schemas.py | 5 +- api/app/features/memory/service.py | 105 ++- api/app/features/memory/summarizer.py | 134 +++- api/app/features/memory/timeline.py | 76 ++- api/app/features/story/repo.py | 39 +- api/app/features/user/repo.py | 31 +- api/app/features/user/router.py | 3 +- api/app/features/user/schemas.py | 2 +- api/app/features/user/service.py | 40 +- api/app/main.py | 2 + api/app/tasks/memoir_tasks.py | 44 +- api/docs/memory-retrieval.md | 21 +- api/scripts/migrate_legacy_to_current.py | 137 +++- api/tests/test_classification_fragment.py | 8 +- api/tests/test_memory_evidence.py | 30 + api/tests/test_reply_segments.py | 25 + api/tests/test_whisper_local.py | 61 ++ app-expo/src/app/(main)/conversation/[id].tsx | 103 ++- app-expo/src/app/(main)/delete-data.tsx | 187 +++--- app-expo/src/core/audio/audio-focus.ts | 11 +- app-expo/src/core/ws/client.ts | 1 + app-expo/src/core/ws/types.ts | 2 + app-expo/src/features/conversation/hooks.ts | 47 +- .../features/conversation/message-split.ts | 70 +- .../features/conversation/realtime-session.ts | 94 ++- .../src/features/voice/hooks/use-player.ts | 14 +- .../src/features/voice/hooks/use-recorder.ts | 17 +- app-expo/src/features/voice/recorder.ts | 66 +- app-expo/src/features/voice/types.ts | 12 + app-expo/src/i18n/generated/resources.ts | 3 +- .../src/i18n/locales/en/conversation.json | 1 + app-expo/src/i18n/locales/en/profile.json | 2 +- .../src/i18n/locales/zh/conversation.json | 1 + app-expo/src/i18n/locales/zh/profile.json | 2 +- app-expo/tests/core/audio/audio-focus.test.ts | 27 + .../conversation/message-split.test.ts | 32 + .../tests/features/voice/recorder.test.ts | 51 ++ .../tests/features/voice/use-player.test.tsx | 50 ++ 70 files changed, 3404 insertions(+), 557 deletions(-) create mode 100644 api/alembic/versions/0003_timeline_events_memory_source_id.py create mode 100644 api/app/features/memory/enrichment.py create mode 100644 api/app/features/memory/enrichment_pipeline.py create mode 100644 api/app/features/memory/evidence.py create mode 100644 api/app/features/memory/llm_schemas.py create mode 100644 api/tests/test_memory_evidence.py create mode 100644 api/tests/test_reply_segments.py create mode 100644 api/tests/test_whisper_local.py create mode 100644 app-expo/tests/features/voice/recorder.test.ts create mode 100644 app-expo/tests/features/voice/use-player.test.tsx diff --git a/.github/workflows/legacy-data-migrate.yml b/.github/workflows/legacy-data-migrate.yml index ca0b229..8d86724 100644 --- a/.github/workflows/legacy-data-migrate.yml +++ b/.github/workflows/legacy-data-migrate.yml @@ -1,5 +1,7 @@ # 一次性:将旧 pg_dump 数据迁入当前 Alembic schema(api/scripts/migrate_legacy_to_current.py) # +# 目标库须已是 alembic upgrade head(与线上一致);占号用户清理逻辑依赖当前全部迁移后的表结构。 +# # 不会在 push / 部署时自动运行,仅手动 workflow_dispatch,避免每次构建误迁库。 # 远端需已用 docker compose 部署(目录约定与 docker-build-deploy 一致:STAGING_DEPLOY_PATH / PROD_DEPLOY_PATH)。 # @@ -156,6 +158,7 @@ jobs: echo "执行 Python 迁移(api 容器内)..." docker compose exec -T api uv run python scripts/migrate_legacy_to_current.py \ --legacy-url "$LEGACY_URL" \ - --target-url "$DB_URL" + --target-url "$DB_URL" \ + --phone-conflict replace_target echo "完成。" REMOTE diff --git a/api/.env.example b/api/.env.example index 9580b44..5b9c4d1 100644 --- a/api/.env.example +++ b/api/.env.example @@ -159,7 +159,7 @@ MEMOIR_IMAGE_PROVIDER=liblib MEMOIR_IMAGE_STYLE_DEFAULT=watercolor MEMOIR_IMAGE_SIZE_DEFAULT=1280x720 # Story 正文至少多少字才生成主图 intent / 调图(0=不限制) -STORY_IMAGE_MIN_BODY_CHARS=800 +STORY_IMAGE_MIN_BODY_CHARS=400 # 叙事模型输出相对口述过短则回退为口述原文 MEMOIR_NARRATIVE_FALLBACK_BODY_RATIO=0.5 MEMOIR_NARRATIVE_FALLBACK_MIN_CHARS=20 diff --git a/api/alembic/versions/0003_timeline_events_memory_source_id.py b/api/alembic/versions/0003_timeline_events_memory_source_id.py new file mode 100644 index 0000000..f237ada --- /dev/null +++ b/api/alembic/versions/0003_timeline_events_memory_source_id.py @@ -0,0 +1,45 @@ +"""timeline_events: memory_source_id for enrichment idempotency per ingest source. + +Revision ID: 0003_timeline_memory_source +Revises: 0002_segments_user_input_text +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +revision: str = "0003_timeline_memory_source" +down_revision: Union[str, None] = "0002_segments_user_input_text" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "timeline_events", + sa.Column("memory_source_id", sa.String(), nullable=True), + ) + op.create_index( + "ix_timeline_events_memory_source_id", + "timeline_events", + ["memory_source_id"], + ) + op.create_foreign_key( + "fk_timeline_events_memory_source_id_memory_sources", + "timeline_events", + "memory_sources", + ["memory_source_id"], + ["id"], + ondelete="SET NULL", + ) + + +def downgrade() -> None: + op.drop_constraint( + "fk_timeline_events_memory_source_id_memory_sources", + "timeline_events", + type_="foreignkey", + ) + op.drop_index("ix_timeline_events_memory_source_id", table_name="timeline_events") + op.drop_column("timeline_events", "memory_source_id") diff --git a/api/app/adapters/asr/tencent_asr.py b/api/app/adapters/asr/tencent_asr.py index 0adf755..d7f6690 100644 --- a/api/app/adapters/asr/tencent_asr.py +++ b/api/app/adapters/asr/tencent_asr.py @@ -38,7 +38,7 @@ class TencentASRProvider: async def transcribe(self, audio: bytes, format: str = "m4a") -> str: client = self._get_client() if not client: - return "" + return "转写失败: 腾讯云 ASR 客户端未初始化(请检查密钥与依赖)" try: from tencentcloud.asr.v20190614 import models @@ -46,12 +46,26 @@ class TencentASRProvider: req = models.SentenceRecognitionRequest() req.EngSerViceType = "16k_zh" req.SourceType = 1 - req.VoiceFormat = format + # 小写;与文档一致。iOS 常见为 m4a(AAC) 容器,与 16k 引擎匹配 + req.VoiceFormat = (format or "m4a").lower() req.Data = audio_base64 req.DataLen = len(audio) resp = client.SentenceRecognition(req) - return (resp.Result or "").strip() + text = (resp.Result or "").strip() + if text: + return text + err = getattr(resp, "Error", None) or getattr(resp, "Message", None) + logger.warning( + "Tencent ASR empty Result, audio_len={} format={} err={}", + len(audio), + req.VoiceFormat, + err, + ) + return ( + "转写失败: 腾讯云返回空文本(常见原因:采样率与 16k_zh 不匹配、" + "格式不受支持或音频无效;请确认客户端为 16kHz 单声道 m4a)" + ) except Exception as e: - logger.error("Tencent ASR transcribe failed: {}", e) - return "" + logger.error("Tencent ASR transcribe failed: {}", e, exc_info=True) + return f"转写失败: {e}"[:500] diff --git a/api/app/adapters/asr/whisper_local.py b/api/app/adapters/asr/whisper_local.py index ee62a4c..846666b 100644 --- a/api/app/adapters/asr/whisper_local.py +++ b/api/app/adapters/asr/whisper_local.py @@ -1,11 +1,42 @@ """Local faster-whisper ASR adapter — implements ASRProvider port.""" -from app.core.logging import get_logger +from __future__ import annotations + +import asyncio import os +import re import tempfile +from typing import Any, Iterable + +from app.core.logging import get_logger logger = get_logger(__name__) +_SUBTITLE_WATERMARK_RE = re.compile( + r"(字幕|听译|压制|字幕组).{0,20}(by|BY|By)|字幕\s*by", + re.UNICODE, +) + + +def _looks_like_subtitle_hallucination(text: str) -> bool: + """静音时第二遍易吐出视频字幕水印;仅丢弃此类短句。""" + t = (text or "").strip() + if len(t) > 48: + return False + if _SUBTITLE_WATERMARK_RE.search(t): + return True + if len(t) <= 12 and "字幕" in t and not re.search(r"[??!!。,、]", t): + return True + return False + + +def _join_segment_text(segments: Iterable[Any]) -> tuple[str, int]: + segs = list(segments) + return "".join(str(getattr(seg, "text", "") or "") for seg in segs).strip(), len( + segs + ) + + _DEFAULT_CACHE_DIR = os.path.normpath( os.path.join( os.path.dirname(os.path.abspath(__file__)), @@ -70,30 +101,95 @@ class WhisperASRProvider: return self._load_model() async def transcribe(self, audio: bytes, format: str = "m4a") -> str: + # 与 v1.1.0 相同的单次 transcribe;推理放线程池,避免阻塞 asyncio(tag 上为同步调用)。 self._load_model() if not self._model: return "" - tmp_path = None - try: - with tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False) as tmp: - tmp.write(audio) - tmp_path = tmp.name + model = self._model - segments, _info = self._model.transcribe( - tmp_path, - language="zh", - beam_size=5, - vad_filter=True, - vad_parameters={"min_silence_duration_ms": 500}, - ) - return "".join(seg.text for seg in segments).strip() - except Exception as e: - logger.error("Whisper transcribe failed: {}", e) - return "" - finally: - if tmp_path and os.path.exists(tmp_path): - try: - os.remove(tmp_path) - except OSError: - pass + def _sync_transcribe() -> str: + tmp_path = None + try: + with tempfile.NamedTemporaryFile( + suffix=f".{format}", delete=False + ) as tmp: + tmp.write(audio) + tmp_path = tmp.name + + segments, _info = model.transcribe( + tmp_path, + language="zh", + beam_size=5, + vad_filter=True, + vad_parameters={ + "min_silence_duration_ms": 500, + "threshold": 0.35, + "min_speech_duration_ms": 200, + }, + ) + text, pass1_seg_count = _join_segment_text(segments) + used_second_pass = False + pass2_seg_count = 0 + pass3_seg_count = 0 + + if not text: + logger.info( + "Whisper VAD pass 无文本,关闭 VAD 再试一次(短录音易被 VAD 判为静音)" + ) + segments2, _info2 = model.transcribe( + tmp_path, + language="zh", + beam_size=5, + vad_filter=False, + condition_on_previous_text=False, + # 略抬高:减少边界片段被标成 no_speech 而整段为空 + no_speech_threshold=0.85, + ) + raw2, pass2_seg_count = _join_segment_text(segments2) + used_second_pass = True + if raw2 and _looks_like_subtitle_hallucination(raw2): + logger.info( + "Whisper 丢弃疑似字幕水印幻听: {!r}", + raw2[:120], + ) + text = "" + else: + text = raw2 + + if not text and used_second_pass: + try: + from faster_whisper import decode_audio + + audio_np = decode_audio(tmp_path, sampling_rate=16000) + segments3, _info3 = model.transcribe( + audio_np, + language="zh", + beam_size=5, + vad_filter=False, + condition_on_previous_text=False, + no_speech_threshold=0.85, + ) + raw3, pass3_seg_count = _join_segment_text(segments3) + if raw3 and _looks_like_subtitle_hallucination(raw3): + logger.info( + "Whisper decode_audio 回退仍是疑似字幕水印幻听: {!r}", + raw3[:120], + ) + elif raw3: + text = raw3 + except Exception as ex: + logger.warning("Whisper decode_audio 回退失败: {}", ex) + + return text + except Exception as e: + logger.error("Whisper transcribe failed: {}", e) + return "" + finally: + if tmp_path and os.path.exists(tmp_path): + try: + os.remove(tmp_path) + except OSError: + pass + + return await asyncio.to_thread(_sync_transcribe) diff --git a/api/app/agents/chat/interview_agent.py b/api/app/agents/chat/interview_agent.py index c58b419..92311b6 100644 --- a/api/app/agents/chat/interview_agent.py +++ b/api/app/agents/chat/interview_agent.py @@ -17,7 +17,11 @@ from app.agents.chat.prompts_conversation import ( get_opening_prompt, ) from app.agents.state_schema import MemoirStateSchema -from app.agents.chat.reply_limits import truncate_chat_segments +from app.agents.chat.reply_limits import ( + nonempty_segments_or_fallback, + segments_from_llm_response, + truncate_chat_segments, +) from app.core.agent_logging import ( agent_span, log_agent_payload, @@ -135,10 +139,12 @@ class InterviewAgent: log_agent_payload( logger, "InterviewAgent.generate_response.raw_response", response_text ) - messages = [ - msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip() - ] - raw_list = messages if messages else [response_text.strip()] + raw_list = segments_from_llm_response( + response_text, + max_segments=settings.chat_interview_max_segments, + ) + if not raw_list: + raw_list = [response_text.strip()] out = truncate_chat_segments( raw_list, max_segments=settings.chat_interview_max_segments, @@ -150,6 +156,7 @@ class InterviewAgent: : settings.chat_interview_max_chars_per_segment ] ] + out = nonempty_segments_or_fallback(out, fallback=_FALLBACK_REPLY) log_agent_summary( logger, "InterviewAgent.generate_response segments={} conversation_id={}", @@ -193,10 +200,9 @@ class InterviewAgent: log_agent_payload( logger, "InterviewAgent.opening.raw_response", response_text ) - messages = [ - msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip() - ] - raw_list = messages if messages else [response_text.strip()] + raw_list = segments_from_llm_response(response_text, max_segments=2) + if not raw_list: + raw_list = [response_text.strip()] out = truncate_chat_segments( raw_list, max_segments=2, @@ -208,7 +214,7 @@ class InterviewAgent: len(out), conversation_id, ) - return ( + segments = ( out if out else [ @@ -217,6 +223,10 @@ class InterviewAgent: ] ] ) + return nonempty_segments_or_fallback( + segments, + fallback="你好呀~ 又见面了,最近有没有什么事想跟我说说?", + ) except Exception as e: logger.error("生成开场白失败: {}", e, exc_info=True) return ["你好呀~ 又见面了,最近有没有什么事想跟我说说?"] diff --git a/api/app/agents/chat/profile_agent.py b/api/app/agents/chat/profile_agent.py index 2dd2345..2f424ae 100644 --- a/api/app/agents/chat/profile_agent.py +++ b/api/app/agents/chat/profile_agent.py @@ -19,7 +19,11 @@ from app.core.langchain_llm import ainvoke_json_object from app.core.agent_logging import agent_span, log_agent_payload, log_agent_summary from app.core.config import settings from app.core.logging import get_logger -from app.agents.chat.reply_limits import truncate_chat_segments +from app.agents.chat.reply_limits import ( + nonempty_segments_or_fallback, + segments_from_llm_response, + truncate_chat_segments, +) from app.features.memoir.memoir_images.json_payload import extract_json_payload logger = get_logger(__name__) @@ -135,10 +139,9 @@ class ProfileAgent: log_agent_payload( logger, "ProfileAgent.followup.raw_response", response_text ) - messages = [ - msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip() - ] - raw_list = messages if messages else [response_text.strip()] + raw_list = segments_from_llm_response(response_text, max_segments=3) + if not raw_list: + raw_list = [response_text.strip()] out = truncate_chat_segments( raw_list, max_segments=3, @@ -150,7 +153,7 @@ class ProfileAgent: len(out), conversation_id, ) - return ( + segments = ( out if out else [ @@ -159,6 +162,10 @@ class ProfileAgent: ] ] ) + return nonempty_segments_or_fallback( + segments, + fallback="谢谢分享!能再告诉我一些吗?", + ) except Exception as e: logger.error("生成资料跟进回复失败: {}", e) return ["谢谢分享!能再告诉我一些吗?"] @@ -193,10 +200,9 @@ class ProfileAgent: log_agent_payload( logger, "ProfileAgent.greeting.raw_response", response_text ) - messages = [ - msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip() - ] - raw_list = messages if messages else [response_text.strip()] + raw_list = segments_from_llm_response(response_text, max_segments=2) + if not raw_list: + raw_list = [response_text.strip()] out = truncate_chat_segments( raw_list, max_segments=2, @@ -208,7 +214,7 @@ class ProfileAgent: len(out), conversation_id, ) - return ( + segments = ( out if out else [ @@ -217,6 +223,10 @@ class ProfileAgent: ] ] ) + return nonempty_segments_or_fallback( + segments, + fallback="你好!在开始之前,能告诉我你是哪一年出生的吗?", + ) except Exception as e: logger.error("生成资料收集开场白失败: {}", e) return [ diff --git a/api/app/agents/chat/prompts_conversation.py b/api/app/agents/chat/prompts_conversation.py index 3a41b0c..793e308 100644 --- a/api/app/agents/chat/prompts_conversation.py +++ b/api/app/agents/chat/prompts_conversation.py @@ -77,7 +77,7 @@ def get_system_prompt( prompt = f"""你是「岁月知己」,像老朋友一样陪用户聊人生。**回复要短**,像微信聊天,不要长篇、不要文学腔。 -规则:先简短接住对方一句,**最多再问一个具体问题**;禁止括号与思考过程;禁止采访腔(如「我注意到」「我想了解」)。 +规则:先简短接住对方一句,**最多再问一个具体问题**;禁止括号与思考过程;禁止采访腔(如「我注意到」「我想了解」);**不要重复确认**对方刚说过或上文已能推断的信息。 当前阶段:{stage_name_map.get(current_stage, current_stage.value)} 已聊话题:{covered_topics_str} @@ -364,7 +364,7 @@ def get_guided_conversation_prompt( style = RESPONSE_STYLES[conversation_turn % len(RESPONSE_STYLES)] style_guidance = { "empathy": "共情一两句即可", - "curious": "表现好奇,追问一个具体点", + "curious": "若还有未展开的细节可好奇问一个点;若上文已说清或可自然推断,只承接或换角度,**勿为凑问题而追问**", "reflection": "可一句简短感慨,勿讲大道理", "lighthearted": "轻松一点,别讲段子太长", "connection": "可提「我也有过类似感受」一句,勿编造具体经历细节", @@ -431,12 +431,13 @@ def get_guided_conversation_prompt( ## 任务(短) 1. 先简短回应一句,不要总结成长文。 2. 用户若跳到别的人生阶段,跟着他聊,别硬拉回。 -3. 需要追问时**只问一个**具体小问题;不必每轮都问。 -4. 可用 [SPLIT] 分成**最多 2 条**消息,每条都很短。 +3. 需要追问时**只问一个**具体小问题;**不必每轮都问**;若用户已说明或语境已能推出(如谁买的、和谁),**别再为同一件事做 yes/no 确认**。 +4. 用户只回简短肯定/否定(如「是的」「对」)时,**结合上文**理解,承接即可或问**新**角度,勿重复上一句已问过的事。 +5. 可用 [SPLIT] 分成**最多 2 条**消息,每条都很短。 {dynamic_guidance}{uncovered_hint} ## 禁止 -括号/思考过程;采访腔;重复确认用户档案里已有信息;别编用户没说的细节。 +括号/思考过程;采访腔;**重复确认**用户档案、**上文已说**或**强暗示下已可知**的事实(包括无信息量的「是不是他/她…」式追问);别编用户没说的细节。 直接输出([SPLIT] 可选,最多 2 段):""" diff --git a/api/app/agents/chat/reply_limits.py b/api/app/agents/chat/reply_limits.py index c44838f..ec949a2 100644 --- a/api/app/agents/chat/reply_limits.py +++ b/api/app/agents/chat/reply_limits.py @@ -2,6 +2,49 @@ from __future__ import annotations +import re + + +def segments_from_llm_response( + response_text: str, + *, + max_segments: int = 3, + min_paragraph_chars: int = 12, +) -> list[str]: + """ + 优先按字面 [SPLIT] 拆段;若模型只输出一段、但用空行写了多段,再按段落拆。 + 解决「两段话 + 换行」却未写 [SPLIT] 时仍要拆气泡 / 多段 TTS 的情况。 + """ + text = (response_text or "").strip() + if not text: + return [] + primary = [p.strip() for p in text.split("[SPLIT]") if p.strip()] + if len(primary) > 1: + return primary[:max_segments] + blob = primary[0] if primary else text + if "\n" not in blob: + return [blob] + paras = [p.strip() for p in re.split(r"\n\s*\n+", blob) if p.strip()] + if len(paras) < 2: + return [blob] + paras = [p for p in paras if len(p) >= min_paragraph_chars] + if len(paras) < 2: + return [blob] + return paras[:max_segments] + + +def nonempty_segments_or_fallback( + segments: list[str], + *, + fallback: str, +) -> list[str]: + """去掉空段;若全部为空白/空串则返回单条 fallback,避免 WS 下发空 text。""" + cleaned = [s for s in segments if (s or "").strip()] + if cleaned: + return cleaned + fb = (fallback or "").strip() + return [fb] if fb else ["…"] + def truncate_chat_segments( segments: list[str], diff --git a/api/app/agents/memoir/classification_agent.py b/api/app/agents/memoir/classification_agent.py index bfb21ce..b17e9dd 100644 --- a/api/app/agents/memoir/classification_agent.py +++ b/api/app/agents/memoir/classification_agent.py @@ -1,16 +1,16 @@ """ -ClassificationAgent:将内容分类到 8 个章节类别,或判定无价值返回 None。 -对应现有逻辑:_classify_chapter_category +ClassificationAgent:将内容分类到 8 个章节类别之一。 -返回 None 表示本段不进入回忆录 Story/章节流水线;与 StoryRoute 中「可独立讲述的一段人生经历」 -(见 prompts.get_story_route_prompt)在标准上对齐:零散档案点不进 Story,记忆与 slot 抽取仍由上游完成。 +原「LLM 返回 none / 零散档案启发式」不再跳过 Story:统一映射为 ``summary`` 章节, +仍走叙事流水线落库;与 StoryRoute 仍兼容(批次内 new/append 规划不变)。 +Memory ingest 由 Celery 任务在批次级先行完成,与分类结果独立。 """ from __future__ import annotations import json import re -from typing import Any, Optional +from typing import Any from app.agents.memoir.prompts import ( CHAPTER_CATEGORIES, @@ -22,6 +22,9 @@ from app.features.memoir.memoir_images.json_payload import extract_json_payload logger = get_logger(__name__) +# 模型判定 none 或启发式命中零散档案时,仍写入回忆录正文所用的兜底章节 +_SUMMARY_FALLBACK_CATEGORY = "summary" + # 与「仅档案句式」组合使用;过短但明显为叙事句的仍交 LLM 判断 _FRAGMENT_SHORT_MAX_LEN = 48 @@ -67,8 +70,8 @@ def _detect_stage(text: str, fallback_stage: str) -> str: def _looks_like_fragment_only(text: str) -> bool: """ - 保守启发式:明显为档案点/标签句,不足以作为 Story 叙事单元。 - 与 get_chapter_classification_prompt 中「应返回 none」的情形一致;误判风险通过窄正则控制。 + 保守启发式:明显为档案点/标签句。 + 命中时仍进回忆录正文,章节映射为 ``summary``(与 LLM 返回 none 一致)。 """ s = (text or "").strip() if not s: @@ -107,26 +110,30 @@ def _parse_category_from_llm_response(raw: str) -> str: class ClassificationAgent: - """将内容分类到 8 个章节类别之一,或判定无价值返回 None""" + """将内容分类到 8 个章节类别之一;none/零散档案映射为 ``summary`` 仍进 Story。""" def classify( self, text: str, fallback_stage: str, llm: Any, - ) -> Optional[str]: + *, + segment_id: str | None = None, + ) -> str: """ 分类到 8 个章节类别之一。 - 若 LLM 判定内容不足以独立成篇(none)或启发式判定为零散档案点,返回 None。 + LLM 返回 none 或启发式为零散档案时,返回 ``summary``(仍走回忆录流水线)。 llm 需支持 .invoke(prompt) 同步调用。 """ if _looks_like_fragment_only(text): - logger.debug( - "零散档案/极短标签句,跳过回忆录 Story: text_len={} text={}", + logger.info( + "event=chapter_classification_summary_fallback reason=fragment_heuristic " + "segment_id={} text_len={} category={}", + segment_id or "", len(text or ""), - text or "", + _SUMMARY_FALLBACK_CATEGORY, ) - return None + return _SUMMARY_FALLBACK_CATEGORY if llm: try: @@ -139,12 +146,14 @@ class ClassificationAgent: ) category = _parse_category_from_llm_response(raw) if category == "none": - logger.debug( - "LLM 判定内容不足以成篇,跳过: text_len={} text={}", + logger.info( + "event=chapter_classification_summary_fallback reason=llm_none " + "segment_id={} text_len={} category={}", + segment_id or "", len(text or ""), - text or "", + _SUMMARY_FALLBACK_CATEGORY, ) - return None + return _SUMMARY_FALLBACK_CATEGORY if category in CHAPTER_CATEGORIES: return category except Exception as e: diff --git a/api/app/agents/memoir/fidelity_check_agent.py b/api/app/agents/memoir/fidelity_check_agent.py index ee008c5..886c5b7 100644 --- a/api/app/agents/memoir/fidelity_check_agent.py +++ b/api/app/agents/memoir/fidelity_check_agent.py @@ -1,6 +1,7 @@ """ FidelityCheckAgent:比较「用户口述」与叙事 JSON 输出,判定是否存在明显编造或越界。 -失败时由流水线回退为口述正文(见 story_pipeline_sync)。 +续写合并(append)时传入 `existing_canonical_markdown`,将已有故事正文一并视为允许来源。 +失败时由流水线回退(见 story_pipeline_sync):续写为「已有 + 口述」,新建为口述原文。 """ from __future__ import annotations @@ -43,6 +44,7 @@ class FidelityCheckAgent: oral_text: str, narrative_json: str, llm: Any, + existing_canonical_markdown: str | None = None, ) -> bool: if not llm or not settings.memoir_fidelity_check_enabled: return True @@ -50,8 +52,32 @@ class FidelityCheckAgent: gen = (narrative_json or "").strip() if not oral or not gen: return True + existing = (existing_canonical_markdown or "").strip() _log_suspicious_years_not_in_oral(oral, gen) - prompt = f"""你是事实核对员。比较下面两段文字。 + if existing: + prompt = f"""你是事实核对员。当前为**续写合并**:模型需要把「已有故事正文」与「本轮口述」合成一篇,生成稿**允许且应当**保留已有正文中的事实(可改写语序、合并段落),并融入本轮口述中的新事实。 + +【用户本轮口述】(本段亲口补充) +{oral[:8000]} + +【已有故事正文】(已落库、允许在生成稿中出现或改写;出现于此处的内容**不算**本轮编造) +{existing[:12000]} + +【模型生成的 JSON 叙事】 +{gen[:16000]} + +判断:生成稿是否出现**既明显不在本轮口述、也明显不在已有故事正文**的具体人名、地名、时间、数字、事件经过、对话,或把摘录/档案里才有的信息写成了用户亲口经历? +若内容可归因于「已有故事」或「本轮口述」的合理整理,pass=true。 +若存在无法归因的明显编造或越界,pass=false。 + +**JSON 输出**:只输出一个合法 JSON 对象。 +{{"pass": true, "reason": null}} +或 +{{"pass": false, "reason": "一句话说明"}} + +只输出 JSON,不要其它文字。""" + else: + prompt = f"""你是事实核对员。比较下面两段文字。 【用户口述】(亲历内容) {oral[:8000]} diff --git a/api/app/agents/memoir/narrative_agent.py b/api/app/agents/memoir/narrative_agent.py index bfbbd58..dfe852a 100644 --- a/api/app/agents/memoir/narrative_agent.py +++ b/api/app/agents/memoir/narrative_agent.py @@ -1,6 +1,6 @@ """ NarrativeAgent:生成创意标题和叙事改写。 -对应现有逻辑:get_creative_title_json_prompt、get_narrative_json_prompt +叙事正文走 `get_narrative_json_prompt` / `get_narrative_merge_json_prompt`(传记作家式书面语 + 事实边界)。 """ from __future__ import annotations diff --git a/api/app/agents/memoir/orchestrator.py b/api/app/agents/memoir/orchestrator.py index bfabeef..51b05a3 100644 --- a/api/app/agents/memoir/orchestrator.py +++ b/api/app/agents/memoir/orchestrator.py @@ -91,6 +91,7 @@ class MemoirOrchestrator: text=text, fallback_stage=detected_stage, llm=llm, + segment_id=segment.id, ) if agent_summary_enabled(): logger.info( @@ -108,13 +109,6 @@ class MemoirOrchestrator: segment.id, list((result.slots or {}).keys()), ) - if chapter_category is None: - logger.debug( - "段落无回忆录价值,跳过: segment_id={} transcript={}", - segment.id, - getattr(segment, "user_input_text", None) or "", - ) - continue category_to_segments.setdefault(chapter_category, []).append(segment) return PreparedMemoirBatches( diff --git a/api/app/agents/memoir/prompts.py b/api/app/agents/memoir/prompts.py index 8e5bed8..fb772b7 100644 --- a/api/app/agents/memoir/prompts.py +++ b/api/app/agents/memoir/prompts.py @@ -130,29 +130,67 @@ def get_memoir_editor_system_prompt() -> str: """ -def get_memoir_fidelity_system_prompt() -> str: - """叙事/标题生成专用:准确性优先,禁止编造事实(与 get_memoir_editor_system_prompt 分离)。""" - return """你是回忆录编辑助手,任务是把用户口述整理为第一人称书面叙述。 - -## 事实边界(必须遵守,优先于文采) +def _memoir_fidelity_core_rules() -> str: + """事实边界 1–4 条(与文体第 5 条拆分,供 story 叙事与标题等复用)。""" + return """## 事实边界(必须遵守,优先于文采) 1. **正文只能展开「本段用户口述」区块中的内容**。若输入中有「相关记忆摘录」等参考区,其中信息**不得**写成本人本轮亲口经历的细节;最多用一两句作主题衔接,且不得引入摘录里才有的具体人名、地点、时间、对话、数字。 2. **禁止编造**:不得新增用户未提及的具体人物姓名、对话原文、地点、时间、事件经过、因果、数字;不得推断性心理描写或「典型年代场景」填充。 3. **禁止为凑字数扩写**:材料短则输出短;段落数量与长度随材料而定。 -4. 允许:去除口语赘词与寒暄、调整语序、合并重复指代、把口语改为书面语;**不得**用虚构细节「让文章更好看」。 -5. **叙述风格平实**:少用抒情、比喻与文学铺陈;像清楚记事,不要写成散文。 +4. 允许:去除口语赘词与寒暄、调整语序、合并重复指代、把口语改为书面语;**不得**用虚构细节「让文章更好看」。""" -## 用户档案与阶段信息 + +def _memoir_fidelity_user_profile_rules() -> str: + return """## 用户档案与阶段信息 - 「用户基本信息」「时间参考」仅可使用其中**已写明**的条目;不得把档案中的出生地等写进正文,除非用户在本段口述里已提及或明确关联。""" -def get_narrative_editor_system_prompt() -> str: - """叙事改写:准确性系统提示 + 可执行文体约束(不用 get_memoir_editor_system_prompt 中的「过渡句/生动细节」泛化指令)。""" - return f"""{get_memoir_fidelity_system_prompt()} +def get_memoir_fidelity_system_prompt() -> str: + """叙事/标题生成专用:准确性优先,禁止编造事实(与 get_memoir_editor_system_prompt 分离)。""" + return f"""你是回忆录编辑助手,任务是把用户口述整理为第一人称书面叙述。 -## 文体(在严守事实的前提下) -- 使用第一人称、**平实书面语**(少修辞、少感叹);不要直接引用对话原话。 -- 不使用 Markdown 标题(#、##)、不使用表格。 -- 如有「衔接上下文」,仅保持语气与时间线连贯,不重复已有段落全文。""" +{_memoir_fidelity_core_rules()} +5. **叙述风格平实**:少用抒情、比喻与文学铺陈;像清楚记事,不要写成散文。 + +{_memoir_fidelity_user_profile_rules()}""" + + +def get_memoir_fidelity_facts_only_prompt() -> str: + """与 `get_memoir_fidelity_system_prompt` 相同的事实 1–4 条,第 5 条改为允许传记作家式文采(仍禁止编造)。""" + return f"""你是回忆录编辑助手,任务是把用户口述整理为第一人称书面叙述。 + +{_memoir_fidelity_core_rules()} +5. **文体**:在遵守第 1–4 条的前提下,可将口语改写为**优雅、连贯的回忆录书面语**(适当过渡句,保留并书面化用户已提及的细节与情感);文采服务于真实内容,**不得**用虚构描写替代或填补事实。 + +{_memoir_fidelity_user_profile_rules()}""" + + +def _memoir_editor_narrative_style_block() -> str: + """与 `get_memoir_editor_system_prompt` 对齐的传记作家改写要点(用于写入 chapter 的 story 正文)。""" + return """## 传记作家文体(须同时遵守上文「事实边界」) +你是一位专业的传记作家和文字编辑,擅长将口语化的对话内容整理成优雅的书面语回忆录章节。 + +### 提炼与筛选 +对话中往往夹杂噪音,须严格筛选:保留具体事件、人物关系、时地、情感与信念、用户已提及的细节;过滤语气词、寒暄、与 AI 的交互、无关闲聊、重复冗余。 + +### 改写原则 +- 保持用户的真实情感 +- 使用优雅但不失亲切的书面语,不要直接引用对话原话 +- 适当添加过渡句,使段落连贯 +- 保留生动的细节,但将口语表达改写为书面叙述 +- 去除口语中的填充词和无意义重复 +- 保持时间顺序和逻辑清晰 + +### 输出格式约束 +- 使用第一人称 +- 不使用 Markdown 标题(#、##)、不使用表格 +- 如有「衔接上下文」,仅保持语气与时间线连贯,不重复已有段落全文""" + + +def get_narrative_editor_system_prompt() -> str: + """故事/章节叙事:传记作家式书面语 + 事实边界(chapter 直接展示 story 时使用)。""" + return f"""{get_memoir_fidelity_facts_only_prompt()} + +{_memoir_editor_narrative_style_block()}""" def _short_classification_edit_prefix() -> str: @@ -209,7 +247,9 @@ childhood, education, career_early, career_achievement, career_challenge, family **JSON 输出**:`response_format=json_object`,只输出: {{"category": "childhood|education|career_early|career_achievement|career_challenge|family|beliefs|summary|none"}} -不要其它文字。""" +不要其它文字。 + +若你返回 **none**,服务端会将本段映射到 **summary** 章节并仍写入回忆录正文(不落库丢弃)。""" def get_state_extraction_prompt( @@ -378,7 +418,7 @@ def get_narrative_prompt( ## 步骤 1. 从「本段用户口述」提炼可写事实;丢弃语气词、寒暄、与 AI 的交互。 -2. 改写为第一人称书面叙述:可调整语序与用词,**不得**新增事实。 +2. 改写为第一人称书面叙述(优雅、连贯,可适当过渡;可调整语序与用词),**不得**新增事实。 3. 若材料中无值得记录的人生经历内容,输出空字符串。 ## 格式 @@ -428,7 +468,7 @@ def get_narrative_json_prompt( 1. **只展开「本段用户口述」**;若有参考摘录区,不得把摘录中的具体事实写成本轮亲历经历(见系统说明)。 2. 过滤语气词、寒暄、与 AI 的交互;不重复已有故事全文;本批只写同一主题/事件链。 3. 段落数量与每段长度**随材料而定**,禁止为凑字数编造。 -4. 使用第一人称、**平实书面语**,少修辞;不要直接引用原话;不要用 `#`、`##`、表格。 +4. 使用第一人称、**优雅书面语**(可适当过渡与铺陈,须基于口述事实);不要直接引用原话;不要用 `#`、`##`、表格。 ## 输出格式(严格 JSON) {{ @@ -504,7 +544,7 @@ def get_narrative_merge_json_prompt( 1. 输出为**完整故事正文**(不是仅写本段):`paragraphs` 须包含重组后的**全文**。 2. **禁止编造**:不得新增用户未在「已有」或「本段」中出现的人名、地点、时间、对话、数字。 3. 若本段与旧文完全重复或无新信息,可仅输出与旧文等价重组后的正文(不得无故缩短到明显少于旧文)。 -4. 使用第一人称、平实书面语;不要用 `#`、`##`、表格。 +4. 使用第一人称、**优雅书面语**(与系统说明中的传记作家文体一致);不要用 `#`、`##`、表格。 ## 输出格式(严格 JSON) {{ @@ -527,8 +567,8 @@ def get_story_route_prompt( ) -> str: """Celery 批次:判断写入新 story 还是追加已有 story。输出严格 JSON。 - 「故事」= 可独立讲述的一段人生经历;进入本步的批次已满足 get_chapter_classification_prompt - 中章节级分类(非 none),二者语义一致。 + 「故事」= 可独立讲述的一段人生经历;进入本步的批次已归入具体 chapter category + (含模型返回 none 或零散档案启发式时映射的 summary)。 """ return f"""你是回忆录编辑助手。根据本批用户口述与候选故事列表,决定: - append_story:内容明显延续、补充某一已有故事的主题与时间线,且能对应到具体 candidate id @@ -636,12 +676,13 @@ def format_narrative_user_content(oral_text: str, evidence_text: str = "") -> st def format_evidence_chunks_for_prompt(evidence: dict) -> str: """将 retrieve_evidence / retrieve_evidence_sync 结果格式化为简短文本,供叙事 prompt 使用。 - 仅包含实际返回的 chunks、confirmed facts、timeline;不包含 relevant_summaries / relevant_stories - (当前管线多为空列表,避免模型误以为有摘要或故事全文可用)。 + 包含 chunks、摘要(若有)、confirmed facts、timeline、故事摘要(若有)。 """ chunks = evidence.get("relevant_chunks") or [] + summaries = evidence.get("relevant_summaries") or [] facts = evidence.get("relevant_facts") or [] timeline = evidence.get("timeline_hints") or [] + stories = evidence.get("relevant_stories") or [] parts: list[str] = [] for c in chunks[:10]: content = ( @@ -649,6 +690,13 @@ def format_evidence_chunks_for_prompt(evidence: dict) -> str: ) if content: parts.append(content.strip()) + for s in summaries[:3]: + if isinstance(s, dict): + st = (s.get("content") or "").strip() + stype = (s.get("summary_type") or "").strip() + if st: + label = f"[摘要:{stype}]" if stype else "[摘要]" + parts.append(f"{label} {st}") for f in facts[:5]: if isinstance(f, dict): subj = f.get("subject", "") @@ -668,6 +716,12 @@ def format_evidence_chunks_for_prompt(evidence: dict) -> str: ) if line: parts.append(line) + for st in stories[:3]: + if isinstance(st, dict): + title = (st.get("title") or "").strip() + summ = (st.get("summary") or "").strip() + if title or summ: + parts.append(" ".join(x for x in (title, summ) if x)) return "\n\n".join(parts) if parts else "" diff --git a/api/app/core/config.py b/api/app/core/config.py index 929907f..4767bc0 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -144,11 +144,18 @@ class Settings(BaseSettings): memoir_image_size_default: str = "1280x720" memoir_image_download_hosts: str = "" # Story 正文至少多少字才创建主图 intent / 调图(0 表示不限制) - story_image_min_body_chars: int = 800 + story_image_min_body_chars: int = 400 # 叙事输出相对口述过短则回退为口述原文(比例与下限) memoir_narrative_fallback_body_ratio: float = 0.5 memoir_narrative_fallback_min_chars: int = 20 + # ── Memory 检索与富化 ───────────────────────────────────── + # True:query 为空时仍返回 rolling 摘要 + 最近事实/时间线(无 chunk FTS) + memory_evidence_empty_query_include_rolling: bool = False + # False:跳过 ingest 后 LLM 富化(摘要/事实/时间线) + memory_enrichment_enabled: bool = True + memory_enrichment_max_chars: int = Field(default=12000, ge=1000, le=100_000) + # ── Liblib ─────────────────────────────────────────────── liblib_access_key: str = "" liblib_secret_key: str = "" diff --git a/api/app/features/conversation/history_store.py b/api/app/features/conversation/history_store.py index 1a9683e..99a344a 100644 --- a/api/app/features/conversation/history_store.py +++ b/api/app/features/conversation/history_store.py @@ -18,6 +18,9 @@ from app.features.conversation.session_history import ( logger = get_logger(__name__) +# 与 LLM / 客户端约定:多段助手消息用 [SPLIT] 拼接,便于拆成多条气泡与多段 TTS +AI_RESPONSE_SEGMENT_JOIN = "[SPLIT]" + def _utc_now() -> datetime: return datetime.now(timezone.utc) @@ -55,10 +58,10 @@ class ConversationHistoryStore: async def record_ai_only_turn( self, conversation_id: str, responses: list[str] - ) -> None: + ) -> str | None: if not responses: - return - combined = "\n\n".join(responses) + return None + combined = AI_RESPONSE_SEGMENT_JOIN.join(responses) created_at = _utc_now() msg = ConversationMessage( id=str(uuid.uuid4()), @@ -72,6 +75,7 @@ class ConversationHistoryStore: await self._touch_conversation(conversation_id, occurred_at=created_at) await self._db.commit() await self._sync_redis_best_effort(conversation_id) + return msg.id async def record_human_ai_turn( self, @@ -106,7 +110,7 @@ class ConversationHistoryStore: segment_id=segment_id, created_at=human_ts, ) - combined = "\n\n".join(responses) + combined = AI_RESPONSE_SEGMENT_JOIN.join(responses) ai = ConversationMessage( id=str(uuid.uuid4()), conversation_id=conversation_id, diff --git a/api/app/features/conversation/ws/message_types.py b/api/app/features/conversation/ws/message_types.py index ee8fe34..c69a13d 100644 --- a/api/app/features/conversation/ws/message_types.py +++ b/api/app/features/conversation/ws/message_types.py @@ -17,6 +17,8 @@ class MessageType(str, Enum): AGENT_RESPONSE = "agent_response" TTS_AUDIO = "tts_audio" TTS_CANCEL = "tts_cancel" + PING = "ping" + PONG = "pong" END_CONVERSATION = "end_conversation" MEMOIR_UPDATE = "memoir_update" ERROR = "error" diff --git a/api/app/features/conversation/ws/pipeline.py b/api/app/features/conversation/ws/pipeline.py index 069ce88..dbc624b 100644 --- a/api/app/features/conversation/ws/pipeline.py +++ b/api/app/features/conversation/ws/pipeline.py @@ -22,7 +22,10 @@ from app.core.config import settings from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC from app.core.db import AsyncSessionLocal from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider -from app.features.conversation.history_store import ConversationHistoryStore +from app.features.conversation.history_store import ( + AI_RESPONSE_SEGMENT_JOIN, + ConversationHistoryStore, +) from app.features.conversation.models import Conversation, Segment from app.features.conversation.ws.connection_manager import manager from app.features.conversation.ws.message_types import MessageType @@ -369,6 +372,16 @@ async def process_audio_segment( ) -> None: """分段语音的异步处理:并行 ASR + 幂等落库 + 有序聚合触发 Agent。""" state = get_or_create_segment_state(conversation_id, voice_session_id) + logger.info( + "process_audio_segment 开始: conversation_id={} voice_session_id={} " + "segment_index={} is_last={} duration_s={} audio_b64_len={}", + conversation_id, + voice_session_id, + segment_index, + is_last, + audio_duration, + len(audio_base64 or ""), + ) try: async with AsyncSessionLocal() as db: @@ -420,6 +433,12 @@ async def process_audio_segment( audio_bytes = base64.b64decode(audio_base64) except Exception: audio_bytes = b"" + if not audio_bytes: + logger.warning( + "process_audio_segment: 解码后音频为空 conversation_id={} segment_index={}", + conversation_id, + segment_index, + ) transcript_text = await get_asr_provider().transcribe( audio_bytes, format="m4a" ) @@ -440,12 +459,19 @@ async def process_audio_segment( ) if _is_transcribe_failure(transcript_text): + detail = (transcript_text or "").strip() + if detail.startswith("转写失败"): + user_msg = f"分段 {segment_index} {detail}" + elif not detail: + user_msg = f"分段 {segment_index} 转写失败:未识别到内容(请检查后端 ASR 配置)" + else: + user_msg = f"分段 {segment_index} 转写失败:{detail[:400]}" await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": { - "message": f"分段 {segment_index} 转写失败,请重试该片段", + "message": user_msg, "segment_index": segment_index, }, "timestamp": datetime.now(timezone.utc).isoformat(), @@ -553,6 +579,12 @@ async def process_user_message( store = ConversationHistoryStore(db) tts_urls: list[str] = [] try: + logger.info( + "process_user_message 开始: conversation_id={} segment_id={} user_chars={}", + conversation_id, + segment.id, + len(user_message or ""), + ) is_from_voice = bool(segment.audio_url) voice_session_id = _voice_session_id_from_audio_url(segment.audio_url) audio_dur = getattr(segment, "audio_duration_seconds", None) @@ -586,7 +618,7 @@ async def process_user_message( responses = turn.messages skip_tts = turn.skip_tts - segment.agent_response = "\n\n".join(responses) + segment.agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses) _mark_conversation_active(conversation) ai_msg_id = await store.record_human_ai_turn( conversation_id=conversation_id, @@ -600,6 +632,22 @@ async def process_user_message( segment_id=segment.id, ) if not ai_msg_id: + logger.warning( + "process_user_message: 无有效助手段落(responses 为空),conversation_id={} segment_id={}", + conversation_id, + segment.id, + ) + if conversation_id in manager.active_connections: + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": { + "message": "未生成回复,请重试或稍后再试", + }, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) return tts_epoch_start = _tts_epoch_value(conversation_id) @@ -614,6 +662,7 @@ async def process_user_message( "text": response_text, "index": i, "total": n, + "assistant_message_id": ai_msg_id, }, "timestamp": datetime.now(timezone.utc).isoformat(), }, diff --git a/api/app/features/conversation/ws/router.py b/api/app/features/conversation/ws/router.py index b151412..77c4354 100644 --- a/api/app/features/conversation/ws/router.py +++ b/api/app/features/conversation/ws/router.py @@ -85,6 +85,11 @@ async def websocket_endpoint( return await manager.connect(websocket, conversation_id) + logger.info( + "WebSocket 已连接 conversation_id={} user_id={}", + conversation_id, + user_id, + ) quota_service = QuotaService(db=db) conversation_service = ConversationService(db=db, quota_service=quota_service) @@ -156,25 +161,30 @@ async def websocket_endpoint( missing_fields=missing_profile, nickname=user.nickname or "", ) - await ConversationHistoryStore(db).record_ai_only_turn( - conversation_id, greetings - ) - for i, text in enumerate(greetings): - await manager.send_message( - conversation_id, - { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": { - "text": text, - "index": i, - "total": len(greetings), + ai_msg_id = await ConversationHistoryStore( + db + ).record_ai_only_turn(conversation_id, greetings) + if ai_msg_id: + ng = len(greetings) + for i, text in enumerate(greetings): + await manager.send_message( + conversation_id, + { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": { + "text": text, + "index": i, + "total": ng, + "assistant_message_id": ai_msg_id, + }, + "timestamp": datetime.now( + timezone.utc + ).isoformat(), }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }, - ) - if i < len(greetings) - 1: - await asyncio.sleep(0.5) + ) + if i < ng - 1: + await asyncio.sleep(0.5) except Exception as e: logger.error(f"发送资料收集开场白失败: {e}", exc_info=True) else: @@ -193,25 +203,30 @@ async def websocket_endpoint( user_profile_context=user_profile_context, ) ) - await ConversationHistoryStore(db).record_ai_only_turn( - conversation_id, opening_messages - ) - for i, text in enumerate(opening_messages): - await manager.send_message( - conversation_id, - { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": { - "text": text, - "index": i, - "total": len(opening_messages), + ai_msg_id = await ConversationHistoryStore( + db + ).record_ai_only_turn(conversation_id, opening_messages) + if ai_msg_id: + no = len(opening_messages) + for i, text in enumerate(opening_messages): + await manager.send_message( + conversation_id, + { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": { + "text": text, + "index": i, + "total": no, + "assistant_message_id": ai_msg_id, + }, + "timestamp": datetime.now( + timezone.utc + ).isoformat(), }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }, - ) - if i < len(opening_messages) - 1: - await asyncio.sleep(0.5) + ) + if i < no - 1: + await asyncio.sleep(0.5) except Exception as e: logger.error(f"发送空对话开场白失败: {e}", exc_info=True) @@ -225,6 +240,29 @@ async def websocket_endpoint( break message = await websocket.receive_json() msg_type = message.get("type") + if msg_type == MessageType.AUDIO_SEGMENT: + _d = message.get("data") or {} + logger.info( + "WebSocket 收到消息 type={} conversation_id={} " + "segment_index={} is_last={} duration_s={} audio_b64_len={}", + msg_type, + conversation_id, + _d.get("segment_index"), + bool(_d.get("is_last")), + int(_d.get("duration") or 0), + len(_d.get("audio_base64") or ""), + ) + elif msg_type is not None: + logger.info( + "WebSocket 收到消息 type={} conversation_id={}", + msg_type, + conversation_id, + ) + else: + logger.warning( + "WebSocket 收到缺少 type 的 JSON conversation_id={}", + conversation_id, + ) if msg_type == MessageType.TEXT: text_message = message.get("data", {}).get("text", "") @@ -628,6 +666,25 @@ async def websocket_endpoint( ) break + elif msg_type == MessageType.PING: + await manager.send_message( + conversation_id, + { + "type": MessageType.PONG, + "conversation_id": conversation_id, + "data": {}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) + + else: + if msg_type is not None: + logger.warning( + "WebSocket 未识别的消息 type={} conversation_id={}", + msg_type, + conversation_id, + ) + except RuntimeError as e: error_msg = str(e) if ( @@ -659,9 +716,11 @@ async def websocket_endpoint( except Exception as send_error: logger.warning(f"发送错误消息失败: {send_error}") break - except WebSocketDisconnect: - logger.debug( - "WebSocket 断开连接: conversation_id={}", conversation_id + except WebSocketDisconnect as disc: + logger.info( + "WebSocket 断开连接(收消息循环): conversation_id={} code={}", + conversation_id, + getattr(disc, "code", None), ) break except Exception as e: @@ -680,8 +739,12 @@ async def websocket_endpoint( logger.warning(f"发送错误消息失败: {send_error}") break - except WebSocketDisconnect: - logger.debug("WebSocket 断开连接: conversation_id={}", conversation_id) + except WebSocketDisconnect as disc: + logger.info( + "WebSocket 断开连接: conversation_id={} code={}", + conversation_id, + getattr(disc, "code", None), + ) await manager.disconnect(conversation_id) cleanup_segment_states(conversation_id) except Exception as e: diff --git a/api/app/features/memoir/memoir_images/settings.py b/api/app/features/memoir/memoir_images/settings.py index 1b84b0e..29094f7 100644 --- a/api/app/features/memoir/memoir_images/settings.py +++ b/api/app/features/memoir/memoir_images/settings.py @@ -21,7 +21,7 @@ class MemoirImageSettings: poll_interval_seconds: int = DEFAULT_POLL_INTERVAL_SECONDS max_attempts: int = DEFAULT_MAX_ATTEMPTS liblib_template_uuid: str = DEFAULT_LIBLIB_TEMPLATE_UUID - story_image_min_body_chars: int = 800 + story_image_min_body_chars: int = 400 @classmethod def from_settings(cls, settings: "Settings") -> "MemoirImageSettings": diff --git a/api/app/features/memoir/service.py b/api/app/features/memoir/service.py index 34bf163..cdbf108 100644 --- a/api/app/features/memoir/service.py +++ b/api/app/features/memoir/service.py @@ -62,8 +62,10 @@ class MemoirService: "relevant_summaries": [], "relevant_facts": [], "timeline_hints": [], + "relevant_stories": [], } - return await self._memory.retrieve(user_id, query, top_k=top_k) + bundle = await self._memory.retrieve(user_id, query, top_k=top_k) + return bundle.model_dump() async def _cleanup_unavailable_images(self, ch: Chapter) -> None: cleaned = False diff --git a/api/app/features/memoir/story_pipeline_sync.py b/api/app/features/memoir/story_pipeline_sync.py index c025987..ff3be8b 100644 --- a/api/app/features/memoir/story_pipeline_sync.py +++ b/api/app/features/memoir/story_pipeline_sync.py @@ -26,7 +26,6 @@ from app.agents.memoir.story_route_agent import ( from app.agents.state_schema import MemoirStateSchema from app.core.logging import get_logger from app.features.memoir.cover_eligibility import chapter_needs_cover_enqueue -from app.features.memoir.helpers import _chapter_markdown from app.features.memoir.memoir_images.settings import MemoirImageSettings from app.features.memoir.models import Chapter from app.features.memoir.narrative_to_markdown import narrative_to_markdown @@ -46,26 +45,56 @@ from app.features.story.sync_write import ( logger = get_logger(__name__) -def _gate_narrative_fidelity(oral_text: str, narrative_raw: str, llm: Any) -> str: - """叙事 JSON 忠实度检查;不通过则回退为单段口述正文。""" +def _fidelity_fallback_json(oral: str, existing_canonical: str | None) -> str: + """忠实度未通过时的安全回退:续写场景保留旧文 + 本段口述,避免只剩一句。""" + o = (oral or "").strip()[:15000] + ex = (existing_canonical or "").strip()[:15000] + if ex and o: + return json.dumps( + {"paragraphs": [{"content": ex}, {"content": o}]}, + ensure_ascii=False, + ) + if ex: + return json.dumps( + {"paragraphs": [{"content": ex}]}, + ensure_ascii=False, + ) + return json.dumps( + {"paragraphs": [{"content": o}]}, + ensure_ascii=False, + ) + + +def _gate_narrative_fidelity( + oral_text: str, + narrative_raw: str, + llm: Any, + *, + existing_canonical: str | None = None, +) -> str: + """叙事 JSON 忠实度检查;不通过则回退为口述正文(续写时保留已有故事 + 口述)。""" from app.agents.memoir.fidelity_check_agent import FidelityCheckAgent if not settings.memoir_fidelity_check_enabled or not llm: return narrative_raw agent = FidelityCheckAgent() - if agent.passes(oral_text=oral_text, narrative_json=narrative_raw, llm=llm): + ex = (existing_canonical or "").strip() or None + if agent.passes( + oral_text=oral_text, + narrative_json=narrative_raw, + llm=llm, + existing_canonical_markdown=ex, + ): return narrative_raw logger.warning( - "event=fidelity_gate_fallback oral_len={}", + "event=fidelity_gate_fallback oral_len={} merge={}", len((oral_text or "").strip()), + bool(ex), ) o = (oral_text or "").strip() - if not o: + if not o and not ex: return narrative_raw - return json.dumps( - {"paragraphs": [{"content": o[:15000]}]}, - ensure_ascii=False, - ) + return _fidelity_fallback_json(o, ex) def _should_fallback_to_transcript(md: str, oral: str) -> bool: @@ -84,6 +113,28 @@ def _should_fallback_to_transcript(md: str, oral: str) -> bool: return len(m) < threshold +def _coalesce_story_markdown( + md: str, + oral: str, + existing_for_narrative: str, +) -> str: + """落库前对齐正文:空输出或过短回退时,续写场景保留「已有故事 + 本段口述」。""" + o = (oral or "").strip() + ex = (existing_for_narrative or "").strip() + m = (md or "").strip() + if not m: + if ex and o: + return f"{ex}\n\n{o}" + if o: + return o + return ex + if o and _should_fallback_to_transcript(m, o): + if ex: + return f"{ex}\n\n{o}" + return o + return m + + def _is_json_narrative(text: str) -> bool: if not text or not text.strip(): return False @@ -102,7 +153,6 @@ def _apply_narrative_fallbacks( narrative_raw: str, combined_unit_text: str, existing_for_narrative: str, - existing_chapter_md: str, *, chapter_category: str, ) -> str: @@ -130,22 +180,22 @@ def _apply_narrative_fallbacks( ) return f"{existing_for_narrative}\n\n{combined_unit_text}" - if ( - not existing_for_narrative - and existing_chapter_md - and not _is_json_narrative(narrative_raw) - and len(narrative_raw) < len(existing_chapter_md) * 0.8 - ): - logger.warning( - "event=narrative_fallback reason=chapter_length_anomaly action=append_transcript " - "chapter_category={}", - chapter_category, - ) - return f"{existing_chapter_md}\n\n{combined_unit_text}" + # 禁止把「章节级 canonical」(多故事拼接)写进单条 Story:会把全章正文塞进一个故事, + # 且该 story 若挂多章会导致各章阅读视图串台。新建故事时宁可短,也不拼接 existing_chapter_md。 md_check = narrative_to_markdown(narrative_raw).strip() oral = (combined_unit_text or "").strip() + ex_fb = (existing_for_narrative or "").strip() if oral and _should_fallback_to_transcript(md_check, oral): + if ex_fb: + logger.warning( + "event=narrative_fallback reason=body_too_short_vs_oral_merge " + "chapter_category={} oral_len={} md_len={}", + chapter_category, + len(oral), + len(md_check), + ) + return f"{ex_fb}\n\n{oral}" logger.warning( "event=narrative_fallback reason=body_too_short_vs_oral " "chapter_category={} oral_len={} md_len={}", @@ -210,7 +260,6 @@ def _run_batch_plan_writes( chapter: Chapter, chapter_category: str, evidence_text: str, - existing_chapter_md: str, slot_snippets: dict[str, str], user_id: str, user_profile: str, @@ -240,20 +289,24 @@ def _run_batch_plan_writes( birth_year=user_birth_year, llm=llm, ) - narrative_raw = _gate_narrative_fidelity(unit_text, narrative_raw, llm) + narrative_raw = _gate_narrative_fidelity( + unit_text, + narrative_raw, + llm, + existing_canonical=existing_for_narrative or None, + ) narrative_raw = _apply_narrative_fallbacks( narrative_raw, unit_text, existing_for_narrative, - existing_chapter_md, chapter_category=chapter_category, ) - md = narrative_to_markdown(narrative_raw).strip() - if not md: - md = unit_text.strip() - elif _should_fallback_to_transcript(md, unit_text.strip()): - md = unit_text.strip() + md = _coalesce_story_markdown( + narrative_to_markdown(narrative_raw).strip(), + unit_text.strip(), + existing_for_narrative or "", + ) if target_story_id: append_story_version_sync(session, target_story_id, md) @@ -347,7 +400,6 @@ def run_story_pipeline_for_category_batch( slot_snippets[key] = snip title = chapter.title if chapter else f"{chapter_category} 回忆" - existing_chapter_md = _chapter_markdown(chapter) if chapter else "" if not chapter: title = narrative_agent.generate_title( @@ -404,7 +456,6 @@ def run_story_pipeline_for_category_batch( chapter=chapter, chapter_category=chapter_category, evidence_text=evidence_text, - existing_chapter_md=existing_chapter_md, slot_snippets=slot_snippets, user_id=user_id, user_profile=user_profile, @@ -439,21 +490,25 @@ def run_story_pipeline_for_category_batch( birth_year=user_birth_year, llm=llm, ) - narrative_raw = _gate_narrative_fidelity(combined_text, narrative_raw, llm) + narrative_raw = _gate_narrative_fidelity( + combined_text, + narrative_raw, + llm, + existing_canonical=existing_for_narrative or None, + ) narrative_raw = _apply_narrative_fallbacks( narrative_raw, combined_text, existing_for_narrative, - existing_chapter_md, chapter_category=chapter_category, ) - md = narrative_to_markdown(narrative_raw).strip() - if not md: - md = combined_text.strip() - elif _should_fallback_to_transcript(md, combined_text.strip()): - md = combined_text.strip() + md = _coalesce_story_markdown( + narrative_to_markdown(narrative_raw).strip(), + combined_text.strip(), + existing_for_narrative or "", + ) do_append = target_story_id is not None diff --git a/api/app/features/memory/curation.py b/api/app/features/memory/curation.py index 3b3eefa..2792f40 100644 --- a/api/app/features/memory/curation.py +++ b/api/app/features/memory/curation.py @@ -1,17 +1 @@ -"""Memory curation actions — exclude / restore / correct / reject / confirm (skeleton).""" - - -async def exclude_chunk(chunk_id: str, *, user_id: str, reason: str = "") -> None: - raise NotImplementedError - - -async def restore_chunk(chunk_id: str, *, user_id: str) -> None: - raise NotImplementedError - - -async def confirm_fact(fact_id: str, *, user_id: str) -> None: - raise NotImplementedError - - -async def reject_fact(fact_id: str, *, user_id: str) -> None: - raise NotImplementedError +"""Memory curation — 业务入口为 `MemoryService` 的 exclude_chunk / restore_chunk / confirm_fact / reject_fact。""" diff --git a/api/app/features/memory/enrichment.py b/api/app/features/memory/enrichment.py new file mode 100644 index 0000000..f24c55b --- /dev/null +++ b/api/app/features/memory/enrichment.py @@ -0,0 +1,277 @@ +""" +Transcript ingest 之后的记忆富化:摘要、事实、时间线。 + +由 Celery(sync)与 MemoryService.ingest(async)调用;失败仅打日志,不阻断主流程。 +""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session + +from app.core.logging import get_logger +from app.features.memory.extractor import ( + extract_facts_from_transcript_async, + extract_facts_from_transcript_sync, +) +from app.features.memory.models import MemoryChunk, MemorySummary +from app.features.memory.repo import ( + create_memory_fact, + create_memory_fact_sync, + create_memory_summary, + create_memory_summary_sync, + create_timeline_event, + create_timeline_event_sync, + delete_timeline_events_by_memory_source, + delete_timeline_events_by_memory_source_sync, + list_chunks_for_source_sync, + upsert_rolling_summary_sync, +) +from app.features.memory.summarizer import ( + generate_rolling_summary_async, + generate_rolling_summary_sync, + generate_session_summary_async, + generate_session_summary_sync, +) +from app.features.memory.enrichment_pipeline import dedupe_key, normalize_object_json +from app.features.memory.timeline import ( + build_timeline_events_from_facts_async, + build_timeline_events_from_facts_sync, +) + +logger = get_logger(__name__) + + +def _resolve_llm_sync() -> Any | None: + try: + from app.core.dependencies import get_llm_provider + + return get_llm_provider().langchain_llm + except Exception as e: + logger.warning("memory enrichment 无法获取 LLM: {}", e) + return None + + +def enrich_memory_after_ingest_sync( + session: Session, + user_id: str, + source_id: str, + llm: Any | None = None, +) -> None: + from app.core.config import settings + + if not settings.memory_enrichment_enabled: + return + if llm is None: + llm = _resolve_llm_sync() + if not llm: + return + chunks = list_chunks_for_source_sync(session, source_id) + if not chunks: + return + chunk_texts = [c.content for c in chunks] + chunk_ids = [c.id for c in chunks] + numbered = "\n\n".join( + f"[chunk_id={cid}]\n{txt}" for cid, txt in zip(chunk_ids, chunk_texts) + ) + + session_summary_text = generate_session_summary_sync(llm, chunk_texts) + if session_summary_text: + create_memory_summary_sync( + session, + user_id=user_id, + summary_type="session", + content=session_summary_text, + source_chunk_ids=chunk_ids, + ) + + existing_rolling = ( + session.execute( + select(MemorySummary) + .where( + MemorySummary.user_id == user_id, + MemorySummary.summary_type == "rolling", + ) + .order_by(MemorySummary.updated_at.desc()) + .limit(1) + ) + .unique() + .scalar_one_or_none() + ) + existing_text = existing_rolling.content if existing_rolling else None + rolling_text = generate_rolling_summary_sync(llm, existing_text, chunk_texts) + if rolling_text: + upsert_rolling_summary_sync( + session, + user_id=user_id, + content=rolling_text, + source_chunk_ids=chunk_ids, + ) + + raw_facts = extract_facts_from_transcript_sync(llm, numbered) + seen: set[tuple] = set() + inserted: list[dict] = [] + for f in raw_facts: + key = dedupe_key(f) + if key in seen: + continue + seen.add(key) + scid = f.get("source_chunk_id") + if scid and scid not in chunk_ids: + scid = chunk_ids[0] if chunk_ids else None + row = create_memory_fact_sync( + session, + user_id=user_id, + fact_type=f.get("fact_type") or "event", + subject=f.get("subject"), + predicate=f.get("predicate"), + object_json=normalize_object_json(f.get("object_json")), + confidence=float(f.get("confidence") or 0.75), + source_chunk_id=scid, + status="confirmed", + ) + inserted.append( + { + "id": row.id, + "fact_type": row.fact_type, + "subject": row.subject, + "predicate": row.predicate, + "object_json": row.object_json, + } + ) + + if inserted: + delete_timeline_events_by_memory_source_sync( + session, user_id=user_id, memory_source_id=source_id + ) + events = build_timeline_events_from_facts_sync(llm, inserted) + for ev in events: + create_timeline_event_sync( + session, + user_id=user_id, + event_year=ev.get("event_year"), + event_date=ev.get("event_date"), + title=ev["title"], + description=ev.get("description"), + source_fact_ids=ev.get("source_fact_ids") or None, + memory_source_id=source_id, + ) + + +async def enrich_memory_after_ingest_async( + db: AsyncSession, + user_id: str, + source_id: str, + llm: Any | None = None, +) -> None: + from app.core.config import settings + + if not settings.memory_enrichment_enabled: + return + if llm is None: + llm = _resolve_llm_sync() + if not llm: + return + stmt = ( + select(MemoryChunk) + .where(MemoryChunk.source_id == source_id) + .order_by(MemoryChunk.chunk_index.asc()) + ) + result = await db.execute(stmt) + chunks = list(result.unique().scalars().all()) + if not chunks: + return + chunk_texts = [c.content for c in chunks] + chunk_ids = [c.id for c in chunks] + numbered = "\n\n".join( + f"[chunk_id={cid}]\n{txt}" for cid, txt in zip(chunk_ids, chunk_texts) + ) + + session_summary_text = await generate_session_summary_async(llm, chunk_texts) + if session_summary_text: + await create_memory_summary( + db, + user_id=user_id, + summary_type="session", + content=session_summary_text, + source_chunk_ids=chunk_ids, + ) + + roll_stmt = ( + select(MemorySummary) + .where( + MemorySummary.user_id == user_id, + MemorySummary.summary_type == "rolling", + ) + .order_by(MemorySummary.updated_at.desc()) + .limit(1) + ) + r_result = await db.execute(roll_stmt) + existing_row = r_result.unique().scalar_one_or_none() + existing_text = existing_row.content if existing_row else None + + rolling_text = await generate_rolling_summary_async(llm, existing_text, chunk_texts) + if rolling_text: + if existing_row: + existing_row.content = rolling_text + existing_row.source_chunk_ids = chunk_ids + else: + await create_memory_summary( + db, + user_id=user_id, + summary_type="rolling", + content=rolling_text, + source_chunk_ids=chunk_ids, + ) + + raw_facts = await extract_facts_from_transcript_async(llm, numbered) + seen: set[tuple] = set() + inserted: list[dict] = [] + for f in raw_facts: + key = dedupe_key(f) + if key in seen: + continue + seen.add(key) + scid = f.get("source_chunk_id") + if scid and scid not in chunk_ids: + scid = chunk_ids[0] if chunk_ids else None + row = await create_memory_fact( + db, + user_id=user_id, + fact_type=f.get("fact_type") or "event", + subject=f.get("subject"), + predicate=f.get("predicate"), + object_json=normalize_object_json(f.get("object_json")), + confidence=float(f.get("confidence") or 0.75), + source_chunk_id=scid, + status="confirmed", + ) + inserted.append( + { + "id": row.id, + "fact_type": row.fact_type, + "subject": row.subject, + "predicate": row.predicate, + "object_json": row.object_json, + } + ) + + if inserted: + await delete_timeline_events_by_memory_source( + db, user_id=user_id, memory_source_id=source_id + ) + events = await build_timeline_events_from_facts_async(llm, inserted) + for ev in events: + await create_timeline_event( + db, + user_id=user_id, + event_year=ev.get("event_year"), + event_date=ev.get("event_date"), + title=ev["title"], + description=ev.get("description"), + source_fact_ids=ev.get("source_fact_ids") or None, + memory_source_id=source_id, + ) diff --git a/api/app/features/memory/enrichment_pipeline.py b/api/app/features/memory/enrichment_pipeline.py new file mode 100644 index 0000000..76a4999 --- /dev/null +++ b/api/app/features/memory/enrichment_pipeline.py @@ -0,0 +1,25 @@ +"""Enrichment 共享:去重键与 object_json 规范化(sync/async 共用)。""" + +from __future__ import annotations + +import json +from typing import Any + + +def dedupe_key(f: dict) -> tuple: + s = f.get("subject") or "" + p = f.get("predicate") or "" + o = f.get("object_json") + try: + oj = json.dumps(o, sort_keys=True, ensure_ascii=False) if o is not None else "" + except (TypeError, ValueError): + oj = str(o) + return (str(s), str(p), oj) + + +def normalize_object_json(obj: Any) -> dict | list | None: + if obj is None: + return None + if isinstance(obj, (dict, list)): + return obj + return {"value": obj} diff --git a/api/app/features/memory/evidence.py b/api/app/features/memory/evidence.py new file mode 100644 index 0000000..7a73bf6 --- /dev/null +++ b/api/app/features/memory/evidence.py @@ -0,0 +1,244 @@ +""" +证据包组装:跨 memory + story 的检索结果合并(业务层,非纯 repo)。 + +Celery 使用 sync;`HybridRetriever` 使用 async + RRF chunk 合并。 +""" + +from __future__ import annotations + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session + +from app.core.config import settings +from app.features.memory.repo import ( + list_summaries_for_evidence_async, + list_summaries_for_evidence_sync, + search_chunks_fts, + search_chunks_fts_sync, + search_facts_for_user_async, + search_facts_for_user_sync, + search_timeline_events_for_user_async, + search_timeline_events_for_user_sync, +) +from app.features.story.repo import ( + list_recent_stories_for_evidence, + list_recent_stories_for_evidence_sync, +) + +EMPTY_EVIDENCE_BUNDLE: dict = { + "relevant_chunks": [], + "relevant_summaries": [], + "relevant_facts": [], + "timeline_hints": [], + "relevant_stories": [], +} + + +def _facts_to_dicts(facts) -> list[dict]: + return [ + { + "id": f.id, + "fact_type": f.fact_type, + "subject": f.subject, + "predicate": f.predicate, + "object_json": f.object_json, + } + for f in facts + ] + + +def _timeline_to_dicts(events) -> list[dict]: + return [ + { + "id": e.id, + "event_year": e.event_year, + "event_date": e.event_date, + "title": e.title, + "description": e.description, + } + for e in events + ] + + +def _stories_to_dicts(story_rows) -> list[dict]: + return [ + { + "id": s.id, + "title": s.title, + "summary": s.summary, + "stage": s.stage, + "story_type": s.story_type, + } + for s in story_rows + ] + + +def fetch_evidence_metadata_sync( + session: Session, user_id: str, q: str, top_k: int +) -> dict: + """非 chunk 证据:摘要、事实、时间线、故事(sync)。""" + facts = search_facts_for_user_sync(session, user_id, q, top_k) + events = search_timeline_events_for_user_sync(session, user_id, q, top_k) + relevant_summaries = list_summaries_for_evidence_sync( + session, user_id=user_id, q=q, limit=top_k + ) + story_rows = list_recent_stories_for_evidence_sync( + session, user_id, query=q, limit=top_k + ) + return { + "relevant_facts": _facts_to_dicts(facts), + "timeline_hints": _timeline_to_dicts(events), + "relevant_summaries": relevant_summaries, + "relevant_stories": _stories_to_dicts(story_rows), + } + + +async def fetch_evidence_metadata_async( + db: AsyncSession, user_id: str, q: str, top_k: int +) -> dict: + """非 chunk 证据(async)。""" + facts = await search_facts_for_user_async(db, user_id, q, top_k) + events = await search_timeline_events_for_user_async(db, user_id, q, top_k) + relevant_summaries = await list_summaries_for_evidence_async( + db, user_id=user_id, q=q, limit=top_k + ) + story_rows = await list_recent_stories_for_evidence( + db, user_id=user_id, query=q, limit=top_k + ) + return { + "relevant_facts": _facts_to_dicts(facts), + "timeline_hints": _timeline_to_dicts(events), + "relevant_summaries": relevant_summaries, + "relevant_stories": _stories_to_dicts(story_rows), + } + + +def _empty_query_bundle_sync(session: Session, user_id: str, top_k: int) -> dict: + """无 FTS query 时的「浏览」降级:rolling 摘要 + 事实/时间线 fallback。""" + from app.features.memory.models import MemorySummary + from sqlalchemy import select + + from app.features.memory.repo import ( + get_facts_for_user_sync, + get_timeline_events_for_user_sync, + ) + + rolling = ( + session.execute( + select(MemorySummary) + .where( + MemorySummary.user_id == user_id, + MemorySummary.summary_type == "rolling", + ) + .order_by(MemorySummary.updated_at.desc()) + .limit(1) + ) + .unique() + .scalar_one_or_none() + ) + summaries = [] + if rolling: + summaries = [ + { + "id": rolling.id, + "summary_type": rolling.summary_type, + "content": rolling.content, + "source_chunk_ids": rolling.source_chunk_ids, + } + ] + facts = get_facts_for_user_sync(session, user_id, top_k) + events = get_timeline_events_for_user_sync(session, user_id, top_k) + return { + "relevant_chunks": [], + "relevant_summaries": summaries, + "relevant_facts": _facts_to_dicts(facts), + "timeline_hints": _timeline_to_dicts(events), + "relevant_stories": [], + } + + +async def _empty_query_bundle_async(db: AsyncSession, user_id: str, top_k: int) -> dict: + from sqlalchemy import select + + from app.features.memory.models import MemorySummary + from app.features.memory.repo import ( + get_facts_for_user, + get_timeline_events_for_user, + ) + + roll_stmt = ( + select(MemorySummary) + .where( + MemorySummary.user_id == user_id, + MemorySummary.summary_type == "rolling", + ) + .order_by(MemorySummary.updated_at.desc()) + .limit(1) + ) + r_result = await db.execute(roll_stmt) + rolling = r_result.unique().scalar_one_or_none() + summaries = [] + if rolling: + summaries = [ + { + "id": rolling.id, + "summary_type": rolling.summary_type, + "content": rolling.content, + "source_chunk_ids": rolling.source_chunk_ids, + } + ] + facts = await get_facts_for_user(db, user_id=user_id, limit=top_k) + events = await get_timeline_events_for_user(db, user_id=user_id, limit=top_k) + return { + "relevant_chunks": [], + "relevant_summaries": summaries, + "relevant_facts": _facts_to_dicts(facts), + "timeline_hints": _timeline_to_dicts(events), + "relevant_stories": [], + } + + +def retrieve_evidence_bundle_sync( + session: Session, user_id: str, query: str, *, top_k: int = 10 +) -> dict: + """Celery / 叙事流水线:FTS-only chunks + 元数据。""" + if not query or not query.strip(): + if settings.memory_evidence_empty_query_include_rolling: + return _empty_query_bundle_sync(session, user_id, top_k) + return dict(EMPTY_EVIDENCE_BUNDLE) + q = query.strip() + chunk_rows = search_chunks_fts_sync(session, user_id, q, top_k) + relevant_chunks = [ + {"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]} + for r in chunk_rows + ] + meta = fetch_evidence_metadata_sync(session, user_id, q, top_k) + return { + "relevant_chunks": relevant_chunks, + **meta, + } + + +async def retrieve_evidence_bundle_async( + db: AsyncSession, + user_id: str, + query: str, + *, + top_k: int = 10, + merged_chunk_dicts: list[dict], +) -> dict: + """ + 异步路径:chunk 已由调用方 RRF 合并;此处只拼元数据。 + + merged_chunk_dicts: [{"id","content","chunk_index"}, ...] + """ + if not query or not query.strip(): + if settings.memory_evidence_empty_query_include_rolling: + return await _empty_query_bundle_async(db, user_id, top_k) + return dict(EMPTY_EVIDENCE_BUNDLE) + q = query.strip() + meta = await fetch_evidence_metadata_async(db, user_id, q, top_k) + return { + "relevant_chunks": merged_chunk_dicts, + **meta, + } diff --git a/api/app/features/memory/extractor.py b/api/app/features/memory/extractor.py index 7d77b69..24856c6 100644 --- a/api/app/features/memory/extractor.py +++ b/api/app/features/memory/extractor.py @@ -1,6 +1,92 @@ -"""Entity / event / fact extraction from chunks (skeleton).""" +"""从 transcript 块中抽取结构化事实(LLM + JSON)。""" + +from __future__ import annotations + +from typing import Any + +from app.core.langchain_llm import ainvoke_json_object, invoke_json_object +from app.core.logging import get_logger +from app.features.memory.llm_schemas import ( + FactsExtractionPayload, + facts_payload_to_dicts, + parse_json_payload, +) + +logger = get_logger(__name__) + + +def _max_transcript_chars() -> int: + from app.core.config import settings + + return settings.memory_enrichment_max_chars + + +def extract_facts_from_transcript_sync(llm: Any, numbered_blocks: str) -> list[dict]: + """同步:带 chunk_id 标记的文本 → 事实列表。""" + if not llm or not (numbered_blocks or "").strip(): + return [] + text = numbered_blocks.strip()[: _max_transcript_chars()] + prompt = ( + "你是回忆录记忆抽取助手。阅读下列带 [chunk_id=...] 的文本块,抽取可核查的事实。\n" + "每个事实含 fact_type: person|event|relation|place|milestone;subject;predicate;" + "object_json(可为字符串或对象);confidence 0..1;source_chunk_id 必须等于某段的 chunk id。\n" + '只输出 JSON:{"facts":[...]},无事实则 {"facts":[]}。\n\n' + f"{text}" + ) + try: + raw = invoke_json_object( + llm, + prompt, + max_tokens=4096, + agent="memory.extract_facts_sync", + ) + parsed = parse_json_payload(raw, FactsExtractionPayload) + if parsed is None: + return [] + return facts_payload_to_dicts(parsed) + except (TypeError, ValueError) as e: + logger.warning("extract_facts_from_transcript_sync 解析失败: {}", e) + return [] + + +async def extract_facts_from_transcript_async( + llm: Any, numbered_blocks: str +) -> list[dict]: + """异步版。""" + if not llm or not (numbered_blocks or "").strip(): + return [] + text = numbered_blocks.strip()[: _max_transcript_chars()] + prompt = ( + "你是回忆录记忆抽取助手。阅读下列带 [chunk_id=...] 的文本块,抽取可核查的事实。\n" + "每个事实含 fact_type: person|event|relation|place|milestone;subject;predicate;" + "object_json;confidence 0..1;source_chunk_id 必须等于某段的 chunk id。\n" + '只输出 JSON:{"facts":[...]},无事实则 {"facts":[]}。\n\n' + f"{text}" + ) + try: + raw = await ainvoke_json_object( + llm, + prompt, + max_tokens=4096, + agent="memory.extract_facts_async", + ) + parsed = parse_json_payload(raw, FactsExtractionPayload) + if parsed is None: + return [] + return facts_payload_to_dicts(parsed) + except (TypeError, ValueError) as e: + logger.warning("extract_facts_from_transcript_async 解析失败: {}", e) + return [] async def extract_facts(chunk_text: str, *, user_id: str) -> list[dict]: - """Extract structured facts from a text chunk using LLM.""" - raise NotImplementedError + """兼容旧接口:单块文本(无 chunk id 时传空 source_chunk_id)。""" + from app.core.dependencies import get_llm_provider + + llm = get_llm_provider().langchain_llm + blocks = f"[chunk_id=null]\n{chunk_text}" + facts = await extract_facts_from_transcript_async(llm, blocks) + for f in facts: + if f.get("source_chunk_id") in (None, "null", ""): + f["source_chunk_id"] = None + return facts diff --git a/api/app/features/memory/llm_schemas.py b/api/app/features/memory/llm_schemas.py new file mode 100644 index 0000000..c425731 --- /dev/null +++ b/api/app/features/memory/llm_schemas.py @@ -0,0 +1,103 @@ +"""LLM JSON 输出校验(memory 富化)。""" + +from __future__ import annotations + +import json +from typing import Any, TypeVar + +from pydantic import BaseModel, Field, field_validator + +TModel = TypeVar("TModel", bound=BaseModel) + + +class ExtractedFactItem(BaseModel): + fact_type: str = "event" + subject: str | None = None + predicate: str | None = None + object_json: Any = None + confidence: float = Field(default=0.75, ge=0.0, le=1.0) + source_chunk_id: str | None = None + + @field_validator("fact_type", mode="before") + @classmethod + def _coerce_fact_type(cls, v: object) -> str: + ft = str(v or "event").strip() or "event" + if ft not in ("person", "event", "relation", "place", "milestone"): + return "event" + return ft + + +class FactsExtractionPayload(BaseModel): + facts: list[ExtractedFactItem] = Field(default_factory=list) + + +class SessionSummaryPayload(BaseModel): + summary: str = "" + + +class RollingSummaryPayload(BaseModel): + rolling_summary: str = "" + + +class TimelineEventItem(BaseModel): + event_year: int | None = None + event_date: str | None = None + title: str = "" + description: str | None = None + source_fact_ids: list[str] = Field(default_factory=list) + + @field_validator("source_fact_ids", mode="before") + @classmethod + def _coerce_sf(cls, v: object) -> list[str]: + if v is None: + return [] + if isinstance(v, str): + return [v] if v else [] + if isinstance(v, list): + return [str(x) for x in v if x] + return [] + + +class TimelineEventsPayload(BaseModel): + events: list[TimelineEventItem] = Field(default_factory=list) + + +def parse_json_payload(raw: str, model: type[TModel]) -> TModel | None: + """解析 invoke_json_object 返回的 JSON 字符串。""" + from app.features.memoir.memoir_images.json_payload import extract_json_payload + + try: + cleaned = extract_json_payload(raw) + data = json.loads(cleaned) + return model.model_validate(data) + except (json.JSONDecodeError, ValueError, TypeError): + return None + + +def facts_payload_to_dicts(payload: FactsExtractionPayload) -> list[dict]: + out: list[dict] = [] + for item in payload.facts: + d = item.model_dump() + scid = d.get("source_chunk_id") + if scid is not None and not isinstance(scid, str): + d["source_chunk_id"] = str(scid) + out.append(d) + return out + + +def timeline_payload_to_dicts(payload: TimelineEventsPayload) -> list[dict]: + out: list[dict] = [] + for ev in payload.events: + title = (ev.title or "").strip() + if not title: + continue + out.append( + { + "event_year": ev.event_year, + "event_date": ev.event_date, + "title": title, + "description": ev.description, + "source_fact_ids": ev.source_fact_ids or [], + } + ) + return out[:20] diff --git a/api/app/features/memory/models.py b/api/app/features/memory/models.py index 6593c78..7149a61 100644 --- a/api/app/features/memory/models.py +++ b/api/app/features/memory/models.py @@ -87,6 +87,12 @@ class TimelineEvent(Base): __tablename__ = "timeline_events" id = Column(String, primary_key=True) user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True) + memory_source_id = Column( + String, + ForeignKey("memory_sources.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) event_year = Column(Integer, nullable=True) event_date = Column(String, nullable=True) title = Column(String, nullable=False) diff --git a/api/app/features/memory/repo.py b/api/app/features/memory/repo.py index 228aded..4920468 100644 --- a/api/app/features/memory/repo.py +++ b/api/app/features/memory/repo.py @@ -3,14 +3,16 @@ import uuid from datetime import datetime, timezone -from sqlalchemy import select, text +from sqlalchemy import delete, or_, select, text from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app.features.memory.models import ( MemoryChunk, + MemoryCurationAction, MemoryFact, MemorySource, + MemorySummary, TimelineEvent, ) @@ -182,6 +184,153 @@ async def get_facts_for_user( return list(result.unique().scalars().all()) +def get_facts_for_user_sync( + session: Session, user_id: str, limit: int = 20 +) -> list[MemoryFact]: + stmt = ( + select(MemoryFact) + .where(MemoryFact.user_id == user_id, MemoryFact.status == "confirmed") + .order_by(MemoryFact.created_at.desc()) + .limit(limit) + ) + return list(session.execute(stmt).unique().scalars().all()) + + +def get_timeline_events_for_user_sync( + session: Session, user_id: str, limit: int = 20 +) -> list[TimelineEvent]: + stmt = ( + select(TimelineEvent) + .where(TimelineEvent.user_id == user_id) + .order_by( + TimelineEvent.event_year.desc().nullslast(), TimelineEvent.created_at.desc() + ) + .limit(limit) + ) + return list(session.execute(stmt).unique().scalars().all()) + + +def search_chunks_fts_sync( + session: Session, user_id: str, query: str, limit: int = 20 +) -> list[dict]: + """FTS on memory_chunks(sync,Celery)。""" + if not query or not query.strip(): + return [] + q = query.strip() + stmt = text(""" + SELECT id, content, chunk_index + FROM memory_chunks + WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false) + AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q) + ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC + LIMIT :lim + """) + result = session.execute(stmt, {"user_id": user_id, "q": q, "q2": q, "lim": limit}) + rows = result.mappings().all() + return [ + {"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]} + for r in rows + ] + + +def search_facts_for_user_sync( + session: Session, user_id: str, query: str, limit: int = 20 +) -> list[MemoryFact]: + q = (query or "").strip() + if not q: + return get_facts_for_user_sync(session, user_id, limit) + pat = f"%{q}%" + stmt = ( + select(MemoryFact) + .where( + MemoryFact.user_id == user_id, + MemoryFact.status == "confirmed", + or_(MemoryFact.subject.ilike(pat), MemoryFact.predicate.ilike(pat)), + ) + .order_by(MemoryFact.created_at.desc()) + .limit(limit) + ) + rows = list(session.execute(stmt).unique().scalars().all()) + if rows: + return rows + return get_facts_for_user_sync(session, user_id, limit) + + +async def search_facts_for_user_async( + db: AsyncSession, user_id: str, query: str, limit: int = 20 +) -> list[MemoryFact]: + q = (query or "").strip() + if not q: + return await get_facts_for_user(db, user_id=user_id, limit=limit) + pat = f"%{q}%" + stmt = ( + select(MemoryFact) + .where( + MemoryFact.user_id == user_id, + MemoryFact.status == "confirmed", + or_(MemoryFact.subject.ilike(pat), MemoryFact.predicate.ilike(pat)), + ) + .order_by(MemoryFact.created_at.desc()) + .limit(limit) + ) + result = await db.execute(stmt) + rows = list(result.unique().scalars().all()) + if rows: + return rows + return await get_facts_for_user(db, user_id=user_id, limit=limit) + + +def search_timeline_events_for_user_sync( + session: Session, user_id: str, query: str, limit: int = 20 +) -> list[TimelineEvent]: + q = (query or "").strip() + if not q: + return get_timeline_events_for_user_sync(session, user_id, limit) + pat = f"%{q}%" + stmt = ( + select(TimelineEvent) + .where( + TimelineEvent.user_id == user_id, + or_( + TimelineEvent.title.ilike(pat), + TimelineEvent.description.ilike(pat), + ), + ) + .order_by(TimelineEvent.event_year.desc().nullslast()) + .limit(limit) + ) + rows = list(session.execute(stmt).unique().scalars().all()) + if rows: + return rows + return get_timeline_events_for_user_sync(session, user_id, limit) + + +async def search_timeline_events_for_user_async( + db: AsyncSession, user_id: str, query: str, limit: int = 20 +) -> list[TimelineEvent]: + q = (query or "").strip() + if not q: + return await get_timeline_events_for_user(db, user_id=user_id, limit=limit) + pat = f"%{q}%" + stmt = ( + select(TimelineEvent) + .where( + TimelineEvent.user_id == user_id, + or_( + TimelineEvent.title.ilike(pat), + TimelineEvent.description.ilike(pat), + ), + ) + .order_by(TimelineEvent.event_year.desc().nullslast()) + .limit(limit) + ) + result = await db.execute(stmt) + rows = list(result.unique().scalars().all()) + if rows: + return rows + return await get_timeline_events_for_user(db, user_id=user_id, limit=limit) + + async def search_chunks_vector( db: AsyncSession, user_id: str, query_embedding: list[float], limit: int = 20 ) -> list[dict]: @@ -215,6 +364,57 @@ async def search_chunks_vector( ] +def list_summaries_for_evidence_sync( + session: Session, *, user_id: str, q: str, limit: int +) -> list[dict]: + """最新 rolling + 内容匹配 query 的摘要(ILIKE)。""" + pat = f"%{q}%" + rolling = ( + session.execute( + select(MemorySummary) + .where( + MemorySummary.user_id == user_id, + MemorySummary.summary_type == "rolling", + ) + .order_by(MemorySummary.updated_at.desc()) + .limit(1) + ) + .unique() + .scalar_one_or_none() + ) + rows: list[MemorySummary] = [] + seen: set[str] = set() + if rolling: + rows.append(rolling) + seen.add(rolling.id) + rest = limit - len(rows) + if rest > 0: + stmt = ( + select(MemorySummary) + .where( + MemorySummary.user_id == user_id, + MemorySummary.content.ilike(pat), + ) + .order_by(MemorySummary.updated_at.desc()) + .limit(rest + len(seen)) + ) + for s in session.execute(stmt).unique().scalars().all(): + if s.id not in seen: + rows.append(s) + seen.add(s.id) + if len(rows) >= limit: + break + return [ + { + "id": s.id, + "summary_type": s.summary_type, + "content": s.content, + "source_chunk_ids": s.source_chunk_ids, + } + for s in rows[:limit] + ] + + def retrieve_evidence_sync( session: Session, user_id: str, query: str, *, top_k: int = 10 ) -> dict: @@ -222,76 +422,11 @@ def retrieve_evidence_sync( Sync evidence retrieval for Celery tasks. 能力:**仅 FTS** 检索 chunks(与 `HybridRetriever` 的 FTS+向量 RRF 不同,见 - `api/docs/memory-retrieval.md`);confirmed facts;timeline。 - ingest 在 Celery 任务内先于 story 流水线执行并已 commit,故本检索可见刚写入的 chunk。 + `api/docs/memory-retrieval.md`);facts/timeline 按 query ILIKE;fallback 见 repo。 """ - if not query or not query.strip(): - return { - "relevant_chunks": [], - "relevant_summaries": [], - "relevant_facts": [], - "timeline_hints": [], - "relevant_stories": [], - } - q = query.strip() - # FTS chunks - stmt = text(""" - SELECT id, content, chunk_index - FROM memory_chunks - WHERE user_id = :user_id AND (is_excluded IS NOT TRUE OR is_excluded = false) - AND content_tsv IS NOT NULL AND content_tsv @@ plainto_tsquery('simple', :q) - ORDER BY ts_rank_cd(content_tsv, plainto_tsquery('simple', :q2)) DESC - LIMIT :lim - """) - result = session.execute(stmt, {"user_id": user_id, "q": q, "q2": q, "lim": top_k}) - rows = result.mappings().all() - relevant_chunks = [ - {"id": r["id"], "content": r["content"], "chunk_index": r["chunk_index"]} - for r in rows - ] - # Facts - facts_stmt = ( - select(MemoryFact) - .where(MemoryFact.user_id == user_id, MemoryFact.status == "confirmed") - .order_by(MemoryFact.created_at.desc()) - .limit(top_k) - ) - facts = list(session.execute(facts_stmt).unique().scalars().all()) - relevant_facts = [ - { - "id": f.id, - "fact_type": f.fact_type, - "subject": f.subject, - "predicate": f.predicate, - "object_json": f.object_json, - } - for f in facts - ] - # Timeline - events_stmt = ( - select(TimelineEvent) - .where(TimelineEvent.user_id == user_id) - .order_by(TimelineEvent.event_year.desc().nullslast()) - .limit(top_k) - ) - events = list(session.execute(events_stmt).unique().scalars().all()) - timeline_hints = [ - { - "id": e.id, - "event_year": e.event_year, - "event_date": e.event_date, - "title": e.title, - "description": e.description, - } - for e in events - ] - return { - "relevant_chunks": relevant_chunks, - "relevant_summaries": [], - "relevant_facts": relevant_facts, - "timeline_hints": timeline_hints, - "relevant_stories": [], - } + from app.features.memory.evidence import retrieve_evidence_bundle_sync + + return retrieve_evidence_bundle_sync(session, user_id, query, top_k=top_k) async def get_timeline_events_for_user( @@ -320,3 +455,348 @@ async def list_storage_keys_for_conversation( ) result = await db.execute(stmt) return sorted({r for r in result.scalars().all() if r}) + + +def list_chunks_for_source_sync(session: Session, source_id: str) -> list[MemoryChunk]: + stmt = ( + select(MemoryChunk) + .where(MemoryChunk.source_id == source_id) + .order_by(MemoryChunk.chunk_index.asc()) + ) + return list(session.execute(stmt).unique().scalars().all()) + + +def create_memory_summary_sync( + session: Session, + *, + user_id: str, + summary_type: str, + content: str, + source_chunk_ids: list[str] | None = None, +) -> MemorySummary: + row = MemorySummary( + id=_new_id(), + user_id=user_id, + summary_type=summary_type, + content=content, + source_chunk_ids=source_chunk_ids, + ) + session.add(row) + return row + + +async def create_memory_summary( + db: AsyncSession, + *, + user_id: str, + summary_type: str, + content: str, + source_chunk_ids: list[str] | None = None, +) -> MemorySummary: + row = MemorySummary( + id=_new_id(), + user_id=user_id, + summary_type=summary_type, + content=content, + source_chunk_ids=source_chunk_ids, + ) + db.add(row) + return row + + +def get_latest_rolling_summary_sync( + session: Session, user_id: str +) -> MemorySummary | None: + stmt = ( + select(MemorySummary) + .where( + MemorySummary.user_id == user_id, + MemorySummary.summary_type == "rolling", + ) + .order_by(MemorySummary.updated_at.desc()) + .limit(1) + ) + return session.execute(stmt).unique().scalar_one_or_none() + + +def upsert_rolling_summary_sync( + session: Session, + *, + user_id: str, + content: str, + source_chunk_ids: list[str] | None = None, +) -> MemorySummary: + existing = get_latest_rolling_summary_sync(session, user_id) + if existing: + existing.content = content + if source_chunk_ids is not None: + existing.source_chunk_ids = source_chunk_ids + return existing + return create_memory_summary_sync( + session, + user_id=user_id, + summary_type="rolling", + content=content, + source_chunk_ids=source_chunk_ids, + ) + + +def create_memory_fact_sync( + session: Session, + *, + user_id: str, + fact_type: str, + subject: str | None, + predicate: str | None, + object_json: dict | None, + confidence: float, + source_chunk_id: str | None, + status: str = "confirmed", +) -> MemoryFact: + row = MemoryFact( + id=_new_id(), + user_id=user_id, + fact_type=fact_type, + subject=subject, + predicate=predicate, + object_json=object_json, + confidence=confidence, + source_chunk_id=source_chunk_id, + status=status, + ) + session.add(row) + return row + + +async def create_memory_fact( + db: AsyncSession, + *, + user_id: str, + fact_type: str, + subject: str | None, + predicate: str | None, + object_json: dict | None, + confidence: float, + source_chunk_id: str | None, + status: str = "confirmed", +) -> MemoryFact: + row = MemoryFact( + id=_new_id(), + user_id=user_id, + fact_type=fact_type, + subject=subject, + predicate=predicate, + object_json=object_json, + confidence=confidence, + source_chunk_id=source_chunk_id, + status=status, + ) + db.add(row) + return row + + +async def get_memory_fact_for_user( + db: AsyncSession, fact_id: str, user_id: str +) -> MemoryFact | None: + row = await db.get(MemoryFact, fact_id) + if row is None or row.user_id != user_id: + return None + return row + + +async def set_memory_fact_status( + db: AsyncSession, fact_id: str, user_id: str, status: str +) -> bool: + row = await get_memory_fact_for_user(db, fact_id, user_id) + if row is None: + return False + row.status = status + return True + + +def delete_timeline_events_by_memory_source_sync( + session: Session, *, user_id: str, memory_source_id: str +) -> int: + stmt = delete(TimelineEvent).where( + TimelineEvent.user_id == user_id, + TimelineEvent.memory_source_id == memory_source_id, + ) + result = session.execute(stmt) + return result.rowcount or 0 + + +async def delete_timeline_events_by_memory_source( + db: AsyncSession, *, user_id: str, memory_source_id: str +) -> int: + stmt = delete(TimelineEvent).where( + TimelineEvent.user_id == user_id, + TimelineEvent.memory_source_id == memory_source_id, + ) + result = await db.execute(stmt) + return result.rowcount or 0 + + +def create_timeline_event_sync( + session: Session, + *, + user_id: str, + event_year: int | None, + event_date: str | None, + title: str, + description: str | None, + person_refs: list | None = None, + source_fact_ids: list[str] | None = None, + memory_source_id: str | None = None, +) -> TimelineEvent: + row = TimelineEvent( + id=_new_id(), + user_id=user_id, + memory_source_id=memory_source_id, + event_year=event_year, + event_date=event_date, + title=title, + description=description, + person_refs=person_refs, + source_fact_ids=source_fact_ids, + ) + session.add(row) + return row + + +async def create_timeline_event( + db: AsyncSession, + *, + user_id: str, + event_year: int | None, + event_date: str | None, + title: str, + description: str | None, + person_refs: list | None = None, + source_fact_ids: list[str] | None = None, + memory_source_id: str | None = None, +) -> TimelineEvent: + row = TimelineEvent( + id=_new_id(), + user_id=user_id, + memory_source_id=memory_source_id, + event_year=event_year, + event_date=event_date, + title=title, + description=description, + person_refs=person_refs, + source_fact_ids=source_fact_ids, + ) + db.add(row) + return row + + +def create_curation_action_sync( + session: Session, + *, + user_id: str, + action_type: str, + target_type: str, + target_id: str, + details: dict | None = None, +) -> MemoryCurationAction: + row = MemoryCurationAction( + id=_new_id(), + user_id=user_id, + action_type=action_type, + target_type=target_type, + target_id=target_id, + details=details, + ) + session.add(row) + return row + + +async def create_curation_action( + db: AsyncSession, + *, + user_id: str, + action_type: str, + target_type: str, + target_id: str, + details: dict | None = None, +) -> MemoryCurationAction: + row = MemoryCurationAction( + id=_new_id(), + user_id=user_id, + action_type=action_type, + target_type=target_type, + target_id=target_id, + details=details, + ) + db.add(row) + return row + + +async def get_memory_chunk_for_user( + db: AsyncSession, chunk_id: str, user_id: str +) -> MemoryChunk | None: + row = await db.get(MemoryChunk, chunk_id) + if row is None or row.user_id != user_id: + return None + return row + + +async def set_chunk_excluded( + db: AsyncSession, chunk_id: str, user_id: str, excluded: bool +) -> bool: + row = await get_memory_chunk_for_user(db, chunk_id, user_id) + if row is None: + return False + row.is_excluded = excluded + return True + + +async def list_summaries_for_evidence_async( + db: AsyncSession, *, user_id: str, q: str, limit: int +) -> list[dict]: + if not (q or "").strip(): + return [] + pat = f"%{q.strip()}%" + rolling_stmt = ( + select(MemorySummary) + .where( + MemorySummary.user_id == user_id, + MemorySummary.summary_type == "rolling", + ) + .order_by(MemorySummary.updated_at.desc()) + .limit(1) + ) + r_result = await db.execute(rolling_stmt) + rolling = r_result.unique().scalar_one_or_none() + rows: list[MemorySummary] = [] + seen: set[str] = set() + if rolling: + rows.append(rolling) + seen.add(rolling.id) + rest = limit - len(rows) + if rest > 0: + stmt = ( + select(MemorySummary) + .where( + MemorySummary.user_id == user_id, + MemorySummary.content.ilike(pat), + ) + .order_by(MemorySummary.updated_at.desc()) + .limit(rest + len(seen)) + ) + o_result = await db.execute(stmt) + for s in o_result.unique().scalars().all(): + if s.id not in seen: + rows.append(s) + seen.add(s.id) + if len(rows) >= limit: + break + return [ + { + "id": s.id, + "summary_type": s.summary_type, + "content": s.content, + "source_chunk_ids": s.source_chunk_ids, + } + for s in rows[:limit] + ] diff --git a/api/app/features/memory/retriever.py b/api/app/features/memory/retriever.py index eeeb249..904fe2d 100644 --- a/api/app/features/memory/retriever.py +++ b/api/app/features/memory/retriever.py @@ -2,12 +2,8 @@ from sqlalchemy.ext.asyncio import AsyncSession -from app.features.memory.repo import ( - get_facts_for_user, - get_timeline_events_for_user, - search_chunks_fts, - search_chunks_vector, -) +from app.features.memory.evidence import retrieve_evidence_bundle_async +from app.features.memory.repo import search_chunks_fts, search_chunks_vector from app.ports.embedding import EmbeddingProvider @@ -44,24 +40,31 @@ class HybridRetriever: """ Return evidence bundle: {relevant_chunks, relevant_summaries, relevant_facts, timeline_hints, relevant_stories} - - `relevant_summaries` / `relevant_stories` 当前多为占位空列表;叙事 prompt 仅应依赖 - 已实现填充的字段(见 `format_evidence_chunks_for_prompt`)。 """ + if not query.strip(): + return await retrieve_evidence_bundle_async( + self._db, + user_id, + query, + top_k=top_k, + merged_chunk_dicts=[], + ) + + q = query.strip() fts_chunks = await search_chunks_fts( self._db, user_id=user_id, query=query, limit=top_k * 2 ) vector_chunks: list[dict] = [] - if self._embedding and query.strip(): - q_emb = await self._embedding.embed_text(query.strip()) + if self._embedding and q: + q_emb = await self._embedding.embed_text(q) if q_emb: vector_chunks = await search_chunks_vector( self._db, user_id=user_id, query_embedding=q_emb, limit=top_k * 2 ) merged = _rrf_merge(fts_chunks, vector_chunks)[:top_k] - relevant_chunks = [ + merged_chunk_dicts = [ { "id": c["id"], "content": c["content"], @@ -70,36 +73,10 @@ class HybridRetriever: for c in merged ] - facts = await get_facts_for_user(self._db, user_id=user_id, limit=top_k) - relevant_facts = [ - { - "id": f.id, - "fact_type": f.fact_type, - "subject": f.subject, - "predicate": f.predicate, - "object_json": f.object_json, - } - for f in facts - ] - - events = await get_timeline_events_for_user( - self._db, user_id=user_id, limit=top_k + return await retrieve_evidence_bundle_async( + self._db, + user_id, + query, + top_k=top_k, + merged_chunk_dicts=merged_chunk_dicts, ) - timeline_hints = [ - { - "id": e.id, - "event_year": e.event_year, - "event_date": e.event_date, - "title": e.title, - "description": e.description, - } - for e in events - ] - - return { - "relevant_chunks": relevant_chunks, - "relevant_summaries": [], - "relevant_facts": relevant_facts, - "timeline_hints": timeline_hints, - "relevant_stories": [], - } diff --git a/api/app/features/memory/router.py b/api/app/features/memory/router.py index 64af6f6..1152bef 100644 --- a/api/app/features/memory/router.py +++ b/api/app/features/memory/router.py @@ -1,5 +1,71 @@ -"""Memory management API — 二期扩展位,一期以内部 MemoryService 为主。""" +"""Memory 策展与内部扩展 API。""" -from fastapi import APIRouter +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel, Field -router = APIRouter(prefix="/api/memory", tags=["memory"], include_in_schema=False) +from app.core.dependencies import get_current_user +from app.features.memory.deps import get_memory_service +from app.features.memory.service import MemoryService +from app.features.user.models import User + +router = APIRouter(prefix="/api/memory", tags=["memory"]) + + +class ExcludeBody(BaseModel): + reason: str = Field(default="", max_length=2000) + + +class RejectFactBody(BaseModel): + reason: str = Field(default="", max_length=2000) + + +@router.post("/chunks/{chunk_id}/exclude", status_code=status.HTTP_204_NO_CONTENT) +async def exclude_chunk( + chunk_id: str, + body: ExcludeBody | None = None, + current_user: User = Depends(get_current_user), + memory: MemoryService = Depends(get_memory_service), +): + reason = (body.reason if body else "") or "" + ok = await memory.exclude_chunk(current_user.id, chunk_id, reason=reason) + if not ok: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="chunk 不存在" + ) + + +@router.post("/chunks/{chunk_id}/restore", status_code=status.HTTP_204_NO_CONTENT) +async def restore_chunk( + chunk_id: str, + current_user: User = Depends(get_current_user), + memory: MemoryService = Depends(get_memory_service), +): + ok = await memory.restore_chunk(current_user.id, chunk_id) + if not ok: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="chunk 不存在" + ) + + +@router.post("/facts/{fact_id}/confirm", status_code=status.HTTP_204_NO_CONTENT) +async def confirm_fact( + fact_id: str, + current_user: User = Depends(get_current_user), + memory: MemoryService = Depends(get_memory_service), +): + ok = await memory.confirm_fact(current_user.id, fact_id) + if not ok: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="fact 不存在") + + +@router.post("/facts/{fact_id}/reject", status_code=status.HTTP_204_NO_CONTENT) +async def reject_fact( + fact_id: str, + body: RejectFactBody | None = None, + current_user: User = Depends(get_current_user), + memory: MemoryService = Depends(get_memory_service), +): + reason = (body.reason if body else "") or "" + ok = await memory.reject_fact(current_user.id, fact_id, reason=reason) + if not ok: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="fact 不存在") diff --git a/api/app/features/memory/schemas.py b/api/app/features/memory/schemas.py index 5e6e0c4..0ec894f 100644 --- a/api/app/features/memory/schemas.py +++ b/api/app/features/memory/schemas.py @@ -1,10 +1,13 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class EvidenceBundle(BaseModel): """MemoryService 产出的检索结果,供 conversation/memoir 消费。""" + model_config = ConfigDict(extra="ignore") + relevant_chunks: list[dict] = [] relevant_summaries: list[dict] = [] relevant_facts: list[dict] = [] timeline_hints: list[dict] = [] + relevant_stories: list[dict] = [] diff --git a/api/app/features/memory/service.py b/api/app/features/memory/service.py index 8eccf5d..dc78f9b 100644 --- a/api/app/features/memory/service.py +++ b/api/app/features/memory/service.py @@ -2,6 +2,7 @@ MemoryService — conversation / memoir 的统一门面。 - ingest_transcript: transcript -> memory_sources, chunks, embedding, FTS +- ingest 后可选:LLM 富化(session/rolling 摘要、事实、时间线) - retrieve: 委托 HybridRetriever 返回 evidence bundle(FTS + 可选向量 RRF) Celery 侧使用 `ingest_transcript_sync` + `retrieve_evidence_sync`,与异步路径差异见 @@ -10,15 +11,22 @@ Celery 侧使用 `ingest_transcript_sync` + `retrieve_evidence_sync`,与异步 from sqlalchemy.ext.asyncio import AsyncSession +from app.core.logging import get_logger from app.features.memory.chunker import chunk_transcript +from app.features.memory.schemas import EvidenceBundle from app.features.memory.repo import ( create_chunk, + create_curation_action, create_source, + set_chunk_excluded, + set_memory_fact_status, update_chunk_embedding, update_chunk_fts, ) from app.ports.embedding import EmbeddingProvider +logger = get_logger(__name__) + class MemoryService: def __init__( @@ -75,15 +83,97 @@ class MemoryService: if emb: await update_chunk_embedding(self._db, chunk_id, emb) + try: + from app.core.config import settings + from app.core.dependencies import get_llm_provider + from app.features.memory.enrichment import enrich_memory_after_ingest_async + + if settings.memory_enrichment_enabled: + llm = get_llm_provider().langchain_llm + await enrich_memory_after_ingest_async( + self._db, user_id, source.id, llm + ) + except Exception as e: + logger.warning( + "memory enrichment 跳过: {} exc_type={}", e, type(e).__name__ + ) + await self._db.commit() return source.id - async def retrieve(self, user_id: str, query: str, *, top_k: int = 10) -> dict: + async def retrieve( + self, user_id: str, query: str, *, top_k: int = 10 + ) -> EvidenceBundle: """Retrieve relevant evidence. 委托 HybridRetriever。""" from app.features.memory.retriever import HybridRetriever retriever = HybridRetriever(self._db, embedding_provider=self._embedding) - return await retriever.retrieve(user_id=user_id, query=query, top_k=top_k) + raw = await retriever.retrieve(user_id=user_id, query=query, top_k=top_k) + return EvidenceBundle.model_validate(raw) + + async def exclude_chunk( + self, user_id: str, chunk_id: str, *, reason: str = "" + ) -> bool: + ok = await set_chunk_excluded(self._db, chunk_id, user_id, True) + if not ok: + return False + await create_curation_action( + self._db, + user_id=user_id, + action_type="exclude", + target_type="chunk", + target_id=chunk_id, + details={"reason": reason} if reason else None, + ) + await self._db.commit() + return True + + async def restore_chunk(self, user_id: str, chunk_id: str) -> bool: + ok = await set_chunk_excluded(self._db, chunk_id, user_id, False) + if not ok: + return False + await create_curation_action( + self._db, + user_id=user_id, + action_type="restore", + target_type="chunk", + target_id=chunk_id, + details=None, + ) + await self._db.commit() + return True + + async def confirm_fact(self, user_id: str, fact_id: str) -> bool: + ok = await set_memory_fact_status(self._db, fact_id, user_id, "confirmed") + if not ok: + return False + await create_curation_action( + self._db, + user_id=user_id, + action_type="confirm", + target_type="fact", + target_id=fact_id, + details=None, + ) + await self._db.commit() + return True + + async def reject_fact( + self, user_id: str, fact_id: str, *, reason: str = "" + ) -> bool: + ok = await set_memory_fact_status(self._db, fact_id, user_id, "rejected") + if not ok: + return False + await create_curation_action( + self._db, + user_id=user_id, + action_type="reject", + target_type="fact", + target_id=fact_id, + details={"reason": reason} if reason else None, + ) + await self._db.commit() + return ok def ingest_transcript_sync( @@ -128,5 +218,16 @@ def ingest_transcript_sync( session.flush() update_chunk_fts_sync(session, chunk.id) + try: + from app.core.config import settings + from app.features.memory.enrichment import enrich_memory_after_ingest_sync + + if settings.memory_enrichment_enabled: + enrich_memory_after_ingest_sync(session, user_id, source.id, llm=None) + except Exception as e: + logger.warning( + "memory enrichment 跳过(sync): {} exc_type={}", e, type(e).__name__ + ) + session.commit() return source.id diff --git a/api/app/features/memory/summarizer.py b/api/app/features/memory/summarizer.py index 849deb8..72223fd 100644 --- a/api/app/features/memory/summarizer.py +++ b/api/app/features/memory/summarizer.py @@ -1,11 +1,131 @@ -"""Session and rolling summary generation (skeleton).""" +"""会话摘要与滚动摘要(LLM + JSON)。""" + +from __future__ import annotations + +from typing import Any + +from app.core.langchain_llm import ainvoke_json_object, invoke_json_object +from app.core.logging import get_logger +from app.features.memory.llm_schemas import ( + RollingSummaryPayload, + SessionSummaryPayload, + parse_json_payload, +) + +logger = get_logger(__name__) -async def generate_session_summary(chunks: list[str]) -> str: - """Generate a summary for a conversation session.""" - raise NotImplementedError +def _max_input_chars() -> int: + from app.core.config import settings + + return settings.memory_enrichment_max_chars -async def generate_rolling_summary(existing_summary: str, new_chunks: list[str]) -> str: - """Update a rolling summary with new chunks.""" - raise NotImplementedError +def generate_session_summary_sync(llm: Any, chunk_texts: list[str]) -> str: + """为本批块生成 session 级短摘要。""" + if not llm: + return "" + lim = _max_input_chars() + combined = "\n\n".join(t for t in chunk_texts if t).strip()[:lim] + if not combined: + return "" + prompt = ( + "用 2~8 句中文概括下列口述/对话要点,不编造、不评价。只输出 JSON:" + '{"summary":"..."}\n\n文本:\n' + f"{combined}" + ) + try: + raw = invoke_json_object( + llm, prompt, max_tokens=2048, agent="memory.session_summary_sync" + ) + parsed = parse_json_payload(raw, SessionSummaryPayload) + if parsed is None: + return "" + return str(parsed.summary or "").strip() + except (TypeError, ValueError) as e: + logger.warning("generate_session_summary_sync 失败: {}", e) + return "" + + +async def generate_session_summary_async(llm: Any, chunk_texts: list[str]) -> str: + if not llm: + return "" + lim = _max_input_chars() + combined = "\n\n".join(t for t in chunk_texts if t).strip()[:lim] + if not combined: + return "" + prompt = ( + "用 2~8 句中文概括下列口述/对话要点,不编造、不评价。只输出 JSON:" + '{"summary":"..."}\n\n文本:\n' + f"{combined}" + ) + try: + raw = await ainvoke_json_object( + llm, prompt, max_tokens=2048, agent="memory.session_summary_async" + ) + parsed = parse_json_payload(raw, SessionSummaryPayload) + if parsed is None: + return "" + return str(parsed.summary or "").strip() + except (TypeError, ValueError) as e: + logger.warning("generate_session_summary_async 失败: {}", e) + return "" + + +def generate_rolling_summary_sync( + llm: Any, existing_summary: str | None, new_chunk_texts: list[str] +) -> str: + """合并已有滚动摘要与新材料。""" + if not llm: + return (existing_summary or "").strip() + lim = _max_input_chars() + new_t = "\n\n".join(t for t in new_chunk_texts if t).strip()[:lim] + if not new_t and not (existing_summary or "").strip(): + return "" + ex = (existing_summary or "").strip()[:lim] + prompt = ( + "将「已有滚动摘要」与「新材料」合并为更新后的滚动摘要(中文,段落)。" + "保留人物与时间线索;不编造;可省略无关细节。\n" + '只输出 JSON:{"rolling_summary":"..."}\n\n' + f"【已有摘要】\n{ex}\n\n【新材料】\n{new_t}" + ) + try: + raw = invoke_json_object( + llm, prompt, max_tokens=3072, agent="memory.rolling_summary_sync" + ) + parsed = parse_json_payload(raw, RollingSummaryPayload) + if parsed is None: + return (existing_summary or "").strip() + return str(parsed.rolling_summary or "").strip() + except (TypeError, ValueError) as e: + logger.warning("generate_rolling_summary_sync 失败: {}", e) + return (existing_summary or "").strip() + + +async def generate_rolling_summary_async( + llm: Any, existing_summary: str | None, new_chunk_texts: list[str] +) -> str: + if not llm: + return (existing_summary or "").strip() + lim = _max_input_chars() + new_t = "\n\n".join(t for t in new_chunk_texts if t).strip()[:lim] + if not new_t and not (existing_summary or "").strip(): + return "" + ex = (existing_summary or "").strip()[:lim] + prompt = ( + "将「已有滚动摘要」与「新材料」合并为更新后的滚动摘要(中文,段落)。" + "保留人物与时间线索;不编造。\n" + '只输出 JSON:{"rolling_summary":"..."}\n\n' + f"【已有摘要】\n{ex}\n\n【新材料】\n{new_t}" + ) + try: + raw = await ainvoke_json_object( + llm, prompt, max_tokens=3072, agent="memory.rolling_summary_async" + ) + parsed = parse_json_payload(raw, RollingSummaryPayload) + if parsed is None: + return (existing_summary or "").strip() + return str(parsed.rolling_summary or "").strip() + except (TypeError, ValueError) as e: + logger.warning("generate_rolling_summary_async 失败: {}", e) + return (existing_summary or "").strip() diff --git a/api/app/features/memory/timeline.py b/api/app/features/memory/timeline.py index 9106607..237e376 100644 --- a/api/app/features/memory/timeline.py +++ b/api/app/features/memory/timeline.py @@ -1,6 +1,76 @@ -"""Chronology organization — build and update timeline events (skeleton).""" +"""由已抽取事实生成时间线事件(LLM + JSON)。""" + +from __future__ import annotations + +import json +from typing import Any + +from app.core.langchain_llm import ainvoke_json_object, invoke_json_object +from app.core.logging import get_logger +from app.features.memory.llm_schemas import ( + TimelineEventsPayload, + parse_json_payload, + timeline_payload_to_dicts, +) + +logger = get_logger(__name__) + +MAX_FACTS_JSON = 20000 + + +def build_timeline_events_from_facts_sync(llm: Any, facts: list[dict]) -> list[dict]: + """facts 须含 id 字段(已落库)。""" + if not llm or not facts: + return [] + payload = json.dumps(facts, ensure_ascii=False)[:MAX_FACTS_JSON] + prompt = ( + "根据下列事实(含 id)生成时间线事件,用于回忆录展示。\n" + "每条含 event_year(整数或 null)、event_date(可选)、title、description、" + "source_fact_ids(必须来自输入中的 id 列表)。\n" + '只输出 JSON:{"events":[...]},无事件则 {"events":[]}。最多 15 条。\n\n' + f"{payload}" + ) + try: + raw = invoke_json_object( + llm, prompt, max_tokens=4096, agent="memory.timeline_events_sync" + ) + parsed = parse_json_payload(raw, TimelineEventsPayload) + if parsed is None: + return [] + return timeline_payload_to_dicts(parsed) + except (TypeError, ValueError) as e: + logger.warning("build_timeline_events_from_facts_sync 失败: {}", e) + return [] + + +async def build_timeline_events_from_facts_async( + llm: Any, facts: list[dict] +) -> list[dict]: + if not llm or not facts: + return [] + payload = json.dumps(facts, ensure_ascii=False)[:MAX_FACTS_JSON] + prompt = ( + "根据下列事实(含 id)生成时间线事件。\n" + "每条含 event_year、event_date、title、description、source_fact_ids(来自输入 id)。\n" + '只输出 JSON:{"events":[...]}。\n\n' + f"{payload}" + ) + try: + raw = await ainvoke_json_object( + llm, prompt, max_tokens=4096, agent="memory.timeline_events_async" + ) + parsed = parse_json_payload(raw, TimelineEventsPayload) + if parsed is None: + return [] + return timeline_payload_to_dicts(parsed) + except (TypeError, ValueError) as e: + logger.warning("build_timeline_events_from_facts_async 失败: {}", e) + return [] async def build_timeline_events(facts: list[dict]) -> list[dict]: - """Organize facts into chronological timeline events.""" - raise NotImplementedError + """兼容旧接口。""" + from app.core.dependencies import get_llm_provider + + llm = get_llm_provider().langchain_llm + return await build_timeline_events_from_facts_async(llm, facts) diff --git a/api/app/features/story/repo.py b/api/app/features/story/repo.py index f47c4e5..ad5261a 100644 --- a/api/app/features/story/repo.py +++ b/api/app/features/story/repo.py @@ -3,8 +3,9 @@ import uuid from datetime import datetime, timezone -from sqlalchemy import delete, select +from sqlalchemy import delete, or_, select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from app.features.story.models import ( Story, @@ -183,3 +184,39 @@ async def get_stories_by_ids(db: AsyncSession, story_ids: list[str]) -> list[Sto stories = list(result.unique().scalars().all()) order = {sid: i for i, sid in enumerate(story_ids)} return sorted(stories, key=lambda s: order.get(s.id, 999)) + + +async def list_recent_stories_for_evidence( + db: AsyncSession, + user_id: str, + *, + query: str | None = None, + limit: int = 5, +) -> list[Story]: + """供 memory 检索:活跃故事,可选标题/摘要模糊匹配。""" + stmt = select(Story).where(Story.user_id == user_id).where(Story.status == "active") + q = (query or "").strip() + if q: + pat = f"%{q}%" + stmt = stmt.where(or_(Story.title.ilike(pat), Story.summary.ilike(pat))) + stmt = stmt.order_by(Story.updated_at.desc()).limit(limit) + result = await db.execute(stmt) + return list(result.unique().scalars().all()) + + +def list_recent_stories_for_evidence_sync( + session: Session, + user_id: str, + *, + query: str | None = None, + limit: int = 5, +) -> list[Story]: + """同步会话版 `list_recent_stories_for_evidence`(Celery / retrieve_evidence_sync)。""" + stmt = select(Story).where(Story.user_id == user_id).where(Story.status == "active") + q = (query or "").strip() + if q: + pat = f"%{q}%" + stmt = stmt.where(or_(Story.title.ilike(pat), Story.summary.ilike(pat))) + stmt = stmt.order_by(Story.updated_at.desc()).limit(limit) + result = session.execute(stmt) + return list(result.unique().scalars().all()) diff --git a/api/app/features/user/repo.py b/api/app/features/user/repo.py index 67732a4..362b174 100644 --- a/api/app/features/user/repo.py +++ b/api/app/features/user/repo.py @@ -1,6 +1,6 @@ """User 数据访问:查询与「清空用户业务数据」批量删除。""" -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from sqlalchemy.ext.asyncio import AsyncSession from app.core.cos_url_keys import ( @@ -9,7 +9,7 @@ from app.core.cos_url_keys import ( ) from app.features.asset.models import Asset from app.features.auth.models import RefreshToken -from app.features.conversation.models import Conversation, Segment +from app.features.conversation.models import Conversation, ConversationMessage, Segment from app.features.memoir.models import ( Book, Chapter, @@ -34,6 +34,23 @@ async def get_user_by_id(user_id: str, db: AsyncSession) -> User | None: return await db.get(User, user_id) +async def clear_user_demographics(db: AsyncSession, user_id: str) -> None: + """ + 清空 users 表上由访谈收集的档案字段(不删账号行;手机号、密码等登录字段保留)。 + 聊天助手会从这些字段注入「出生地」等上下文,清空后才会真正「忘记」。 + """ + await db.execute( + update(User) + .where(User.id == user_id) + .values( + birth_year=None, + birth_place=None, + grew_up_place=None, + occupation=None, + ) + ) + + async def collect_purge_context( db: AsyncSession, user_id: str ) -> tuple[list[str], list[str], list[str]]: @@ -119,7 +136,8 @@ async def collect_object_storage_keys_before_purge( async def purge_user_related_rows(db: AsyncSession, user_id: str) -> None: """ 物理删除当前用户除账号(users 行)外的业务数据。 - 顺序按外键依赖:memory → 资源意图关联的 assets → story/chapter/book → 对话 → 订单与 refresh token 等。 + 顺序按外键依赖:memory → 资源意图关联的 assets → story/chapter/book → + conversation_messages(引用 segments)→ segments → conversations → … """ await db.execute(delete(MemoryFact).where(MemoryFact.user_id == user_id)) await db.execute(delete(MemoryChunk).where(MemoryChunk.user_id == user_id)) @@ -138,6 +156,13 @@ async def purge_user_related_rows(db: AsyncSession, user_id: str) -> None: await db.execute(delete(Chapter).where(Chapter.user_id == user_id)) await db.execute(delete(Book).where(Book.user_id == user_id)) + await db.execute( + delete(ConversationMessage).where( + ConversationMessage.conversation_id.in_( + select(Conversation.id).where(Conversation.user_id == user_id) + ) + ) + ) await db.execute( delete(Segment).where( Segment.conversation_id.in_( diff --git a/api/app/features/user/router.py b/api/app/features/user/router.py index b83154c..2532c11 100644 --- a/api/app/features/user/router.py +++ b/api/app/features/user/router.py @@ -80,7 +80,8 @@ async def purge_user_data( 永久删除当前账号下的业务数据:对话与片段、记忆层、故事与插图意图、书籍与章节(含图片任务行)、 回忆录状态、订单记录、刷新令牌;并清理会话 Redis 历史、任务追踪与相关分布式锁 key; 对 memory_sources / memoir_images / 关联 Asset 中记录的 storage_key 尽力删除对象存储对象。 - 不删除 users 表中的账号(手机号、密码等);口令见请求体 schema 说明。 + 保留 users 表中的账号与登录字段(手机号、密码等),并清空出生年/出生地/成长地/职业等档案字段。 + 口令见请求体 schema 说明。 """ try: return await service.purge_all_user_data( diff --git a/api/app/features/user/schemas.py b/api/app/features/user/schemas.py index 0fa3312..09d719e 100644 --- a/api/app/features/user/schemas.py +++ b/api/app/features/user/schemas.py @@ -54,7 +54,7 @@ PURGE_USER_DATA_CONFIRMATION = "我确认永久删除我的全部回忆与对话 class PurgeUserDataRequest(BaseModel): - """清空账号下全部业务数据(保留登录账号与手机号等身份字段)。""" + """清空账号下全部业务数据(保留登录账号与手机号等;并清空出生年/出生地等档案字段)。""" confirmation: str = Field( ..., diff --git a/api/app/features/user/service.py b/api/app/features/user/service.py index e272bb0..a1c350f 100644 --- a/api/app/features/user/service.py +++ b/api/app/features/user/service.py @@ -90,7 +90,7 @@ class UserService: confirmation: str, object_storage: ObjectStorage | None = None, ) -> PurgeUserDataResponse: - """物理删除该用户业务数据(不含 users 账号行);提交后再清 Redis / 任务追踪 / 锁 key。""" + """物理删除该用户业务数据(保留 users 行与登录字段);并清空出生年/出生地等档案字段;提交后再清 Redis 等。""" if confirmation != PURGE_USER_DATA_CONFIRMATION: raise ValueError("确认文案不正确,请按提示完整输入口令") @@ -98,16 +98,32 @@ class UserService: if not user: raise ValueError("用户不存在") + logger.info("用户数据清空开始 user_id={}", user_id) + storage_keys = await repo.collect_object_storage_keys_before_purge( self._db, user_id ) conv_ids, chapter_ids, story_ids = await repo.collect_purge_context( self._db, user_id ) + logger.debug( + "清空前收集 user_id={} storage_keys={} conversations={} chapters={} stories={}", + user_id, + len(storage_keys), + len(conv_ids), + len(chapter_ids), + len(story_ids), + ) + await repo.purge_user_related_rows(self._db, user_id) + await repo.clear_user_demographics(self._db, user_id) await self._db.commit() + logger.info("用户数据 DB 行已删除、档案字段已清空并提交 user_id={}", user_id) if object_storage and storage_keys: + logger.debug( + "对象存储尝试删除 user_id={} key_count={}", user_id, len(storage_keys) + ) for key in storage_keys: try: object_storage.delete(key) @@ -115,6 +131,12 @@ class UserService: logger.warning( "对象存储删除失败 user_id={} key={} err={}", user_id, key, e ) + elif storage_keys and not object_storage: + logger.warning( + "用户数据清空:未注入 object_storage,跳过 {} 个对象存储 key user_id={}", + len(storage_keys), + user_id, + ) for cid in conv_ids: try: @@ -123,9 +145,16 @@ class UserService: logger.warning( "清空会话 Redis 历史失败 conversation_id={} err={}", cid, e ) + if conv_ids: + logger.debug( + "已请求清空 Redis 会话历史 user_id={} conversation_count={}", + user_id, + len(conv_ids), + ) try: await task_tracker.clear_user_tasks(user_id) + logger.debug("用户任务追踪已清空 user_id={}", user_id) except Exception as e: logger.warning("清空用户任务追踪失败 user_id={} err={}", user_id, e) @@ -141,13 +170,22 @@ class UserService: await redis_service.delete_keys_matching_pattern( f"lock:story-image:{sid}" ) + logger.debug( + "Redis 分布式锁 key 已清理 user_id={} chapter_count={} story_count={}", + user_id, + len(chapter_ids), + len(story_ids), + ) except Exception as e: logger.warning("清理 Redis 锁 key 失败 user_id={} err={}", user_id, e) + logger.info("用户数据清空完成 user_id={}", user_id) + return PurgeUserDataResponse( success=True, message=( "已清空该账号下的对话、记忆、故事、章节、订单等业务数据,并已尝试删除关联的对象存储文件;" + "个人档案中的出生年份、出生地、成长地、职业等已清空;" "所有登录会话已失效,请重新登录" ), ) diff --git a/api/app/main.py b/api/app/main.py index 5d4a750..670a998 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -21,6 +21,7 @@ from app.features.content.router import router as content_router from app.features.conversation.router import router as conversation_router from app.features.conversation.ws.router import websocket_endpoint from app.features.memoir.router import router as memoir_router +from app.features.memory.router import router as memory_router from app.features.payment.router import router as payment_router from app.features.plan.router import router as plan_router from app.features.quota.router import router as quota_router @@ -140,6 +141,7 @@ app.include_router(auth_router) app.websocket("/ws/conversation/{conversation_id}")(websocket_endpoint) app.include_router(conversation_router) app.include_router(memoir_router) +app.include_router(memory_router) app.include_router(user_router) app.include_router(user_feedback_router) app.include_router(plan_router) diff --git a/api/app/tasks/memoir_tasks.py b/api/app/tasks/memoir_tasks.py index b79c004..1c5ccdc 100644 --- a/api/app/tasks/memoir_tasks.py +++ b/api/app/tasks/memoir_tasks.py @@ -273,7 +273,6 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): try: with get_sync_db() as db: - chapters_to_enqueue: set[str] = set() # 获取段落 stmt = select(Segment).where(Segment.id.in_(segment_ids)) result = db.execute(stmt) @@ -291,9 +290,24 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): try: from app.features.memory.service import ingest_transcript_sync - ingest_transcript_sync(db, user_id, conv_id, transcript) + source_id = ingest_transcript_sync(db, user_id, conv_id, transcript) + logger.info( + "event=memory_transcript_ingested user_id={} task_id={} " + "source_id={} conversation_id={} transcript_chars={} " + "segment_count={}", + user_id, + task_id, + source_id, + conv_id, + len(transcript), + len(segments), + ) except Exception as e: - logger.warning("Memory ingest 跳过: {}", e) + logger.warning( + "Memory ingest 跳过: {} exc_type={}", + e, + type(e).__name__, + ) llm = _get_llm() image_settings = MemoirImageSettings.from_env() @@ -404,14 +418,32 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): if try_enqueue_generate_chapter_cover(chapter_id, source="pipeline"): logger.info(f"派发章节封面任务: chapter={chapter_id}") - logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}") + categories_processed = sorted(prepared.category_to_segments.keys()) + logger.info( + "回忆录处理完成: user_id={} task_id={} segment_count={} " + "categories_processed={}", + user_id, + task_id, + len(segments), + categories_processed, + ) # 更新任务状态为成功 _update_task_status_sync( - user_id, task_id, "success", {"processed": len(segments)} + user_id, + task_id, + "success", + { + "processed": len(segments), + "categories_processed": categories_processed, + }, ) - return {"status": "success", "processed": len(segments)} + return { + "status": "success", + "processed": len(segments), + "categories_processed": categories_processed, + } except Exception as e: logger.error(f"回忆录处理失败: {e}") diff --git a/api/docs/memory-retrieval.md b/api/docs/memory-retrieval.md index c3ba466..7401df4 100644 --- a/api/docs/memory-retrieval.md +++ b/api/docs/memory-retrieval.md @@ -4,8 +4,10 @@ | 路径 | 入口 | 检索能力 | |------|------|----------| -| **异步(HTTP / MemoirService) | `MemoryService.retrieve` → `HybridRetriever` | **FTS + 向量(pgvector)**,RRF 融合;facts;timeline;`relevant_summaries` / `relevant_stories` 当前多为占位空列表 | -| **同步(Celery) | `retrieve_evidence_sync`(`app/features/memory/repo.py`) | **仅 FTS** chunks;confirmed **facts**;**timeline**;无向量(worker 内 ingest 不写 embedding) | +| **异步(HTTP / MemoirService) | `MemoryService.retrieve` → `HybridRetriever` → `evidence.retrieve_evidence_bundle_async` | **FTS + 向量(pgvector)**,RRF 融合 chunks;facts / timeline 按 **query ILIKE**,无命中则 **fallback** 最近条;rolling + ILIKE **摘要**;**stories**(标题/摘要匹配) | +| **同步(Celery) | `retrieve_evidence_sync`(`repo` 薄封装 → `evidence.retrieve_evidence_bundle_sync`) | **仅 FTS** chunks;其余与上类似(无向量) | + +证据组装在 `app/features/memory/evidence.py`;`memory/repo` 仅提供原子查询(chunk FTS、facts/timeline 搜索、摘要列表等),story 合并在 evidence 层完成。 ## 为何 Celery 与 Hybrid 不完全一致 @@ -14,10 +16,23 @@ 业务上应假设:**线上章节生成任务**以 FTS 证据为主;**异步 API** 若配置了 embedding,检索语义更富。 +## 空 query + +- 默认:`relevant_*` 均为空(与历史行为一致)。 +- 若设置 `memory_evidence_empty_query_include_rolling=true`:返回**无 chunk FTS**,但含 **rolling 摘要**、最近 facts / timeline(用于「浏览」模式)。 + +## 富化(ingest 后 LLM) + +- `memory_enrichment_enabled`(默认 `true`):`ingest_transcript` / `ingest_transcript_sync` 后执行摘要、事实、时间线;`false` 时跳过。 +- `memory_enrichment_max_chars`:截断送入 LLM 的文本长度。 +- 同一 `memory_source_id` 的时间线在重跑富化前会先删后插入,避免重复事件。 + ## Celery 任务中的顺序 `process_memoir_segments`(`app/tasks/memoir_tasks.py`)在**同一任务**内先执行 `ingest_transcript_sync`(并 `commit`),再执行 `MemoirOrchestrator` 与 `run_story_pipeline_for_category_batch`。因此 `retrieve_evidence_sync` 能看到**本批刚写入**的 memory chunks(无竞态)。 +章节分类上,若模型返回 **none** 或命中零散档案启发式,Story 侧会统一落入 **`summary` 章节**并继续叙事落库,与「本批 transcript 已进 memory」一致,避免误以为内容被丢弃。 + ## Evidence 与叙事 Prompt -`format_evidence_chunks_for_prompt` 仅拼接**实际返回**的 chunks、facts、timeline;不包含未实现的 summaries/stories,避免模型误以为有额外材料。 +`format_evidence_chunks_for_prompt` 拼接 chunks、**摘要(若有)**、facts、timeline、**故事摘要(若有)**;模型应把摘录视为参考材料,非本段口述。 diff --git a/api/scripts/migrate_legacy_to_current.py b/api/scripts/migrate_legacy_to_current.py index 84735a1..d4700d5 100644 --- a/api/scripts/migrate_legacy_to_current.py +++ b/api/scripts/migrate_legacy_to_current.py @@ -12,7 +12,9 @@ refresh_tokens / segments / sms_verification_codes / users(见仓库内历史 createdb life_echo_legacy psql -d life_echo_legacy -f api/backups/life_echo_20260313_182756.sql -2) 目标库已执行 ``alembic upgrade head``(含 pgvector 与当前 ORM 表)。 +2) **目标库**须与线上一致:已跑完当前仓库全部 Alembic 迁移(``alembic upgrade head``), + 含 pgvector 与 ORM 所建全部表。线上/预发库均为该 schema;本脚本及 ``_purge_target_user`` 的 + DELETE 顺序按当前表结构编写,若将来迁移新增表或外键,需同步更新删除逻辑。 3) 运行(仓库内可用 ``uv run python scripts/...``):: @@ -34,8 +36,9 @@ refresh_tokens / segments / sms_verification_codes / users(见仓库内历史 冲突策略:默认对主键 id 做 UPSERT(旧数据覆盖目标同 id 行)。可用 --on-conflict skip 跳过已存在主键。 -若目标库已有用户且手机号与某条 legacy 用户冲突(同号不同 id),会自动跳过该 legacy 用户及其 books/chapters/ -conversations 等关联行,避免违反 ``users.phone`` 唯一约束。新生产库一般为空库,不会触发。 +若目标库已有用户且手机号与某条 legacy 用户冲突(同号不同 id),由 ``--phone-conflict`` 控制: + - ``replace_target``(默认):先按与 ``purge_user_related_rows`` 相同顺序删除目标侧占号用户及其业务数据,再迁入 legacy(前提:目标库已是 Alembic head 的完整 schema); + - ``skip``:跳过该 legacy 用户及其关联行(旧行为)。 **宿主机上跑脚本(数据库在 Docker Compose 里)**:`.env` 里常见主机名 `postgres`,在容器外无法解析。 可直接把 URL 写成 `...@127.0.0.1:5432/...`,或使用 `--db-host 127.0.0.1` 自动替换两个 URL 中的主机名(端口不变)。 @@ -58,6 +61,7 @@ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") logger = logging.getLogger(__name__) OnConflict = Literal["upsert", "skip"] +PhoneConflictMode = Literal["skip", "replace_target"] def _replace_url_host(url: str, new_host: str) -> str: @@ -78,26 +82,112 @@ def _open(url: str) -> Connection: return connect(url, autocommit=False) -def _legacy_user_ids_skipped_for_phone( - legacy: Connection, target: Connection +def _purge_target_user(conn: Connection, user_id: str) -> None: + """ + 删除目标库中某用户及其业务数据,顺序与 app.features.user.repo.purge_user_related_rows 一致, + 以便随后插入 legacy 用户行而不违反 users.phone 唯一约束。 + + 假定目标库 schema 与 ``alembic upgrade head`` 一致(线上库常态);新增迁移若引入指向 + ``users`` 的表或外键,须在此补充 DELETE。 + """ + uid = user_id + with conn.cursor() as cur: + cur.execute("DELETE FROM memory_facts WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM memory_chunks WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM memory_sources WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM memory_summaries WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM timeline_events WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM memory_curation_actions WHERE user_id = %s", (uid,)) + + cur.execute( + """ + SELECT sii.asset_id FROM story_image_intents sii + INNER JOIN stories s ON s.id = sii.story_id + WHERE s.user_id = %s AND sii.asset_id IS NOT NULL + """, + (uid,), + ) + asset_ids = [row[0] for row in cur.fetchall()] + cur.execute( + """ + SELECT cci.asset_id FROM chapter_cover_intents cci + INNER JOIN chapters c ON c.id = cci.chapter_id + WHERE c.user_id = %s AND cci.asset_id IS NOT NULL + """, + (uid,), + ) + asset_ids.extend(row[0] for row in cur.fetchall()) + seen_assets = {a for a in asset_ids if a} + if seen_assets: + ids = list(seen_assets) + ph = ",".join(["%s"] * len(ids)) + cur.execute(f"DELETE FROM assets WHERE id IN ({ph})", ids) + + cur.execute("DELETE FROM stories WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM chapters WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM books WHERE user_id = %s", (uid,)) + + cur.execute( + """ + DELETE FROM conversation_messages + WHERE conversation_id IN ( + SELECT id FROM conversations WHERE user_id = %s + ) + """, + (uid,), + ) + cur.execute( + """ + DELETE FROM segments + WHERE conversation_id IN ( + SELECT id FROM conversations WHERE user_id = %s + ) + """, + (uid,), + ) + cur.execute("DELETE FROM conversations WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM memoir_states WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM orders WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM refresh_tokens WHERE user_id = %s", (uid,)) + cur.execute("DELETE FROM users WHERE id = %s", (uid,)) + + +def _phone_conflict_legacy_skips( + legacy: Connection, + target: Connection, + mode: PhoneConflictMode, ) -> set[str]: - """目标库已占用某手机号且 id 与 legacy 不一致时,不能插入该 legacy 用户。""" + """同号不同 id:skip 则跳过 legacy 用户;replace_target 则先删目标占号用户再迁入。""" with target.cursor() as cur: cur.execute("SELECT phone, id FROM users") - phone_owner = {row[0]: row[1] for row in cur.fetchall()} + phone_owner: dict[str, str] = {row[0]: row[1] for row in cur.fetchall()} skipped: set[str] = set() with legacy.cursor(row_factory=dict_row) as cur: - cur.execute("SELECT id, phone FROM users") - for r in cur.fetchall(): - owner = phone_owner.get(r["phone"]) - if owner is not None and owner != r["id"]: - skipped.add(r["id"]) - logger.warning( - "skip legacy user %s phone=%s (target user id=%s)", - r["id"], - r["phone"], - owner, - ) + cur.execute("SELECT id, phone FROM users ORDER BY id") + rows = cur.fetchall() + for r in rows: + legacy_id = r["id"] + phone = r["phone"] + owner = phone_owner.get(phone) + if owner is None or owner == legacy_id: + continue + if mode == "skip": + skipped.add(legacy_id) + logger.warning( + "skip legacy user %s phone=%s (target user id=%s)", + legacy_id, + phone, + owner, + ) + continue + logger.warning( + "phone conflict: purging target user %s (phone=%s) then migrating legacy user %s", + owner, + phone, + legacy_id, + ) + _purge_target_user(target, owner) + phone_owner = {p: uid for p, uid in phone_owner.items() if uid != owner} return skipped @@ -690,6 +780,12 @@ def main() -> None: default="upsert", help="upsert: overwrite same id; skip: keep existing rows", ) + p.add_argument( + "--phone-conflict", + choices=("skip", "replace_target"), + default="replace_target", + help="同号不同 id:replace_target=先删目标占号用户再迁入;skip=跳过该 legacy 用户", + ) p.add_argument( "--dry-run", action="store_true", @@ -706,6 +802,7 @@ def main() -> None: ) args = p.parse_args() on_conflict: OnConflict = args.on_conflict # type: ignore[assignment] + phone_conflict: PhoneConflictMode = args.phone_conflict # type: ignore[assignment] legacy_url = args.legacy_url target_url = args.target_url @@ -736,10 +833,10 @@ def main() -> None: logger.info("dry-run done") return - skip_users = _legacy_user_ids_skipped_for_phone(legacy, target) + skip_users = _phone_conflict_legacy_skips(legacy, target, phone_conflict) if skip_users: logger.info( - "skip %d legacy users due to phone already owned in target", + "skip %d legacy users due to phone conflict mode=skip", len(skip_users), ) diff --git a/api/tests/test_classification_fragment.py b/api/tests/test_classification_fragment.py index 8e42599..1d7dc36 100644 --- a/api/tests/test_classification_fragment.py +++ b/api/tests/test_classification_fragment.py @@ -1,4 +1,4 @@ -"""ClassificationAgent:零散档案启发式与分类 none 语义(纯函数/无 LLM)。""" +"""ClassificationAgent:零散档案启发式与 none→summary 兜底(纯函数/无 LLM)。""" import pytest @@ -28,9 +28,11 @@ def test_looks_like_fragment_only(text: str, expected_fragment: bool) -> None: assert _looks_like_fragment_only(text) is expected_fragment -def test_classify_skips_story_for_birth_year_without_llm() -> None: +def test_classify_maps_birth_year_fragment_to_summary_without_llm() -> None: agent = ClassificationAgent() - assert agent.classify("1999年出生", fallback_stage="childhood", llm=None) is None + assert ( + agent.classify("1999年出生", fallback_stage="childhood", llm=None) == "summary" + ) @pytest.mark.parametrize( diff --git a/api/tests/test_memory_evidence.py b/api/tests/test_memory_evidence.py new file mode 100644 index 0000000..aa0c1a5 --- /dev/null +++ b/api/tests/test_memory_evidence.py @@ -0,0 +1,30 @@ +"""Memory evidence 组装与检索契约(纯函数 / 无 DB)。""" + +from app.features.memory.evidence import ( + EMPTY_EVIDENCE_BUNDLE, + _facts_to_dicts, + _stories_to_dicts, + _timeline_to_dicts, +) +from app.features.memory.schemas import EvidenceBundle + + +def test_empty_evidence_bundle_keys() -> None: + assert set(EMPTY_EVIDENCE_BUNDLE.keys()) == { + "relevant_chunks", + "relevant_summaries", + "relevant_facts", + "timeline_hints", + "relevant_stories", + } + + +def test_evidence_bundle_model_accepts_dict() -> None: + b = EvidenceBundle.model_validate(EMPTY_EVIDENCE_BUNDLE) + assert b.relevant_chunks == [] + + +def test_format_helpers_empty() -> None: + assert _facts_to_dicts([]) == [] + assert _timeline_to_dicts([]) == [] + assert _stories_to_dicts([]) == [] diff --git a/api/tests/test_reply_segments.py b/api/tests/test_reply_segments.py new file mode 100644 index 0000000..885ad5a --- /dev/null +++ b/api/tests/test_reply_segments.py @@ -0,0 +1,25 @@ +"""segments_from_llm_response:与客户端 split 规则对齐的单元校验。""" + +from app.agents.chat.reply_limits import ( + nonempty_segments_or_fallback, + segments_from_llm_response, +) + + +def test_split_marker(): + assert segments_from_llm_response("a[SPLIT]b", max_segments=3) == ["a", "b"] + + +def test_paragraph_fallback_when_no_marker(): + a = "太为你高兴了!在上海大剧院的舞台绽放,聚光灯下的你。" + b = "说到舞台,我忽然想起你黄浦江边的童年。从看着江水流淌,到在舞台上演绎别人的悲欢。" + assert segments_from_llm_response(f"{a}\n\n{b}", max_segments=3) == [a, b] + + +def test_short_paragraphs_not_split(): + t = "a\n\nb" + assert segments_from_llm_response(t, max_segments=3) == [t] + + +def test_nonempty_fallback_when_all_blank(): + assert nonempty_segments_or_fallback(["", " "], fallback="ok") == ["ok"] diff --git a/api/tests/test_whisper_local.py b/api/tests/test_whisper_local.py new file mode 100644 index 0000000..81016b7 --- /dev/null +++ b/api/tests/test_whisper_local.py @@ -0,0 +1,61 @@ +import asyncio +import sys +from types import SimpleNamespace + +import pytest + +from app.adapters.asr.whisper_local import ( + WhisperASRProvider, + _looks_like_subtitle_hallucination, +) + + +def test_subtitle_watermark_detection() -> None: + assert _looks_like_subtitle_hallucination("字幕by索兰娅") is True + assert _looks_like_subtitle_hallucination("今天想聊聊童年往事") is False + + +@pytest.mark.asyncio +async def test_transcribe_retries_decode_audio_after_discarded_pass2( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class DummyModel: + def __init__(self) -> None: + self.calls: list[object] = [] + + def transcribe(self, audio: object, **_: object): + self.calls.append(audio) + n = len(self.calls) + if n == 1: + return iter([]), SimpleNamespace() + if n == 2: + return iter([SimpleNamespace(text="字幕by索兰娅")]), SimpleNamespace() + if n == 3: + assert audio == "decoded-audio" + return ( + iter([SimpleNamespace(text="你好,今天想聊聊童年。")]), + SimpleNamespace(), + ) + raise AssertionError(f"unexpected transcribe call #{n}") + + async def fake_to_thread(fn): + return fn() + + def fake_decode_audio(_: str, sampling_rate: int = 16000): + assert sampling_rate == 16000 + return "decoded-audio" + + monkeypatch.setattr(asyncio, "to_thread", fake_to_thread) + monkeypatch.setitem( + sys.modules, + "faster_whisper", + SimpleNamespace(decode_audio=fake_decode_audio), + ) + + provider = WhisperASRProvider() + provider._model = DummyModel() + + text = await provider.transcribe(b"fake-audio", format="m4a") + + assert text == "你好,今天想聊聊童年。" + assert len(provider._model.calls) == 3 diff --git a/app-expo/src/app/(main)/conversation/[id].tsx b/app-expo/src/app/(main)/conversation/[id].tsx index ab79f18..af16b8c 100644 --- a/app-expo/src/app/(main)/conversation/[id].tsx +++ b/app-expo/src/app/(main)/conversation/[id].tsx @@ -13,6 +13,7 @@ import { import React, { useCallback, useEffect, + useLayoutEffect, useMemo, useRef, useState, @@ -36,7 +37,10 @@ import { TextInput, View, } from 'react-native'; -import { KeyboardAvoidingView as KeyboardControllerAvoidingView } from 'react-native-keyboard-controller'; +import { + KeyboardAvoidingView as KeyboardControllerAvoidingView, + KeyboardController, +} from 'react-native-keyboard-controller'; import { useSafeAreaInsets } from 'react-native-safe-area-context'; import { useTranslation } from 'react-i18next'; import { useQueryClient } from '@tanstack/react-query'; @@ -51,6 +55,7 @@ import { useMessages, useRealtimeSession } from '@/features/conversation/hooks'; import type { TtsSegmentPayload } from '@/features/conversation/realtime-session'; import { conversationKeys } from '@/features/conversation/query-keys'; import { + assistantSegmentMessageId, splitMessageParts, splitStreamingSegments, } from '@/features/conversation/message-split'; @@ -101,8 +106,15 @@ function flattenMessagesForList( if (msg.senderType === 'user') { result.push({ ...msg, listKey: msg.id }); } else { - const parts = splitMessageParts(msg.content); + const rawContent = String(msg.content ?? ''); + const parts = splitMessageParts(rawContent); if (parts.length === 0) { + // 空/仅空白/拆段后无可见字时仍展示一条,否则列表里会「没有 AI 气泡」 + result.push({ + ...msg, + listKey: msg.id, + content: rawContent.trim().length > 0 ? rawContent : '…', + }); continue; } if (parts.length > 1) { @@ -1075,7 +1087,14 @@ export default function ConversationScreen() { (old) => { if (!old?.length) return old; let idx = -1; - if (p.assistantMessageId) { + if (p.assistantMessageId != null && p.index != null) { + idx = old.findIndex( + (m) => + m.id === + assistantSegmentMessageId(p.assistantMessageId!, p.index!), + ); + } + if (idx < 0 && p.assistantMessageId) { idx = old.findIndex((m) => m.id === p.assistantMessageId); } if (idx < 0) { @@ -1091,9 +1110,11 @@ export default function ConversationScreen() { const target = old[idx]!; const prevUrls = target.ttsAudioUrls ?? []; if (prevUrls.includes(cosUrl)) return old; - const nextUrls = [...prevUrls, cosUrl]; + const segmentBind = p.assistantMessageId != null && p.index != null; + const nextUrls = segmentBind ? [cosUrl] : [...prevUrls, cosUrl]; const nextId = p.assistantMessageId && + p.index == null && (target.id.startsWith(`${convId}_agent_`) || target.id.startsWith('pending')) ? p.assistantMessageId @@ -1109,7 +1130,10 @@ export default function ConversationScreen() { ); } - const listKey = p.assistantMessageId ?? TTS_STREAMING_LIST_KEY; + const listKey = + p.assistantMessageId != null && p.index != null + ? assistantSegmentMessageId(p.assistantMessageId, p.index) + : (p.assistantMessageId ?? TTS_STREAMING_LIST_KEY); const shared = { kind: 'tts_auto' as const, label: 'TTS', @@ -1176,20 +1200,13 @@ export default function ConversationScreen() { void stop(); }, [sendTtsCancel, stop]); - const handleRecordingComplete = useCallback( - (uri: string, durationMs: number) => { - void sendVoiceMessage(uri, durationMs); - }, - [sendVoiceMessage], - ); - const { status: recorderStatus, durationMs: recordingDurationMs, start: startRecording, stop: stopRecording, cancel: cancelRecording, - } = useRecorder(handleRecordingComplete); + } = useRecorder(); const [input, setInput] = useState(''); const [inputResetKey, setInputResetKey] = useState(0); @@ -1214,15 +1231,28 @@ export default function ConversationScreen() { } }, []); - const flattenedData = flattenMessagesForList(messages ?? []); + const flattenedData = useMemo( + () => flattenMessagesForList(messages ?? []), + [messages], + ); const isRecording = recorderStatus === 'recording'; const recordingDuration = Math.floor(recordingDurationMs / 1000); const handleStartRecording = useCallback(async () => { - const ok = await startRecording(); - if (!ok) { - Alert.alert(t('recordingPermissionDenied')); + const result = await startRecording(); + if (!result.ok) { + const messageKey = + result.reason === 'permission_denied' + ? 'recordingPermissionDenied' + : 'recordingStartFailed'; + Alert.alert( + t(messageKey), + __DEV__ && result.errorMessage ? result.errorMessage : undefined, + ); + if (__DEV__ && result.errorMessage) { + console.warn('startRecording failed', result); + } return; } sendTtsCancel(); @@ -1230,6 +1260,12 @@ export default function ConversationScreen() { void stop(); }, [sendTtsCancel, startRecording, stop, t]); + const handleStopRecording = useCallback(async () => { + const result = await stopRecording(); + if (!result) return; + void sendVoiceMessage(result.uri, result.durationMs); + }, [stopRecording, sendVoiceMessage]); + const scrollListToEndAfterComposerLayout = useCallback(() => { InteractionManager.runAfterInteractions(() => { requestAnimationFrame(() => { @@ -1268,6 +1304,16 @@ export default function ConversationScreen() { return () => subs.forEach((s) => s.remove()); }, [scrollListToEndAfterComposerLayout]); + /** + * 切到语音时改用 KeyboardController.dismiss,与 keyboard-controller 的 Reanimated + * 键盘进度一致;仅用 RN Keyboard.dismiss 时,AvoidingView 可能仍保留键盘高度的 padding。 + */ + useLayoutEffect(() => { + if (inputMode !== 'voice') return; + setIsKeyboardVisible(false); + void KeyboardController.dismiss(); + }, [inputMode]); + const onComposerBlockLayout = useCallback( (e: LayoutChangeEvent) => { const h = e.nativeEvent.layout.height; @@ -1485,16 +1531,11 @@ export default function ConversationScreen() { textInputStyle={inputTextStyle} inputMode={inputMode} onInputModeToggle={() => { - setInputMode((m) => { - if (m === 'text') { - Keyboard.dismiss(); - } - return m === 'text' ? 'voice' : 'text'; - }); + setInputMode((m) => (m === 'text' ? 'voice' : 'text')); }} onAddPress={() => {}} onStartRecording={handleStartRecording} - onStopRecording={() => void stopRecording()} + onStopRecording={() => void handleStopRecording()} onCancelRecording={() => void cancelRecording()} isRecording={isRecording} recordingDuration={recordingDuration} @@ -1830,6 +1871,7 @@ const styles = StyleSheet.create({ voiceRecordPill: { flexDirection: 'row', alignItems: 'center', + justifyContent: 'center', gap: 6, backgroundColor: 'rgba(255, 255, 255, 0.8)', borderRadius: 999, @@ -1843,18 +1885,27 @@ const styles = StyleSheet.create({ backgroundColor: CHAT_COLORS.errorRed, }, voiceRecordDurationWrap: { - width: 32, + minWidth: 40, + height: 20, alignItems: 'center', justifyContent: 'center', }, voiceRecordDuration: { fontSize: 12, - lineHeight: 12, + /** 与容器等高,避免 Android/iOS 数字相对胶囊上下偏移 */ + height: 20, + lineHeight: 20, + paddingVertical: 0, + marginVertical: 0, color: 'rgba(27, 27, 31, 0.86)', textAlign: 'center', + ...(Platform.OS === 'ios' && { + fontVariant: ['tabular-nums'] as const, + }), }, voiceRecordDurationAndroid: { textAlignVertical: 'center', + includeFontPadding: false, } as const, textInput: { padding: 0, diff --git a/app-expo/src/app/(main)/delete-data.tsx b/app-expo/src/app/(main)/delete-data.tsx index da21aa5..a56ab0d 100644 --- a/app-expo/src/app/(main)/delete-data.tsx +++ b/app-expo/src/app/(main)/delete-data.tsx @@ -1,30 +1,24 @@ import React, { useState } from 'react'; -import { KeyboardAvoidingView, Platform, ScrollView, View } from 'react-native'; +import { ScrollView, View } from 'react-native'; import { useSafeAreaInsets } from 'react-native-safe-area-context'; import { useTranslation } from 'react-i18next'; +import { KeyboardAvoidingView as KeyboardControllerAvoidingView } from 'react-native-keyboard-controller'; import { AlertDialog, - AlertDialogAction, - AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, } from '@/components/ui/alert-dialog'; -import { Button, buttonVariants } from '@/components/ui/button'; -import { Input } from '@/components/ui/input'; +import { Button } from '@/components/ui/button'; import { Text } from '@/components/ui/text'; +import { Textarea } from '@/components/ui/textarea'; import { ScreenHeader } from '@/components/screen-header'; +import { ScreenGutter } from '@/constants/layout'; import { PURGE_USER_DATA_CONFIRMATION } from '@/features/profile/constants'; import { usePurgeUserData } from '@/features/profile/hooks'; -import { cn } from '@/lib/utils'; - -/** ScreenHeader 可视高度近似:安全区顶 + 上下 padding + minHeight 行 */ -function headerKeyboardOffset(topInset: number) { - return Platform.OS === 'ios' ? topInset + 72 : 0; -} export default function DeleteDataScreen() { const { t } = useTranslation('profile'); @@ -44,85 +38,86 @@ export default function DeleteDataScreen() { ); }; - const scrollBottomPad = Math.max(insets.bottom, 16) + 120; + const scrollBottomPad = Math.max(insets.bottom, 16) + 24; return ( - - - - - - - {t('dataPrivacy.purgeWarningTitle')} - - - {t('dataPrivacy.purgeWarningBody')} - - - - - - {t('dataPrivacy.purgePhraseHint')} - - - - {PURGE_USER_DATA_CONFIRMATION} + + + + + + + + {t('dataPrivacy.purgeWarningTitle')} + + + {t('dataPrivacy.purgeWarningBody')} + + + + {t('dataPrivacy.purgePhraseHint')} + + + + {PURGE_USER_DATA_CONFIRMATION} + + + + + + + {t('dataPrivacy.purgeInputLabel')} + +