From ac436b87a22e9a42337a504f7c31b90668528477 Mon Sep 17 00:00:00 2001 From: Kevin Date: Thu, 30 Apr 2026 09:17:01 +0800 Subject: [PATCH] =?UTF-8?q?feat(api):=20=E6=94=B6=E6=95=9B=E5=AF=B9?= =?UTF-8?q?=E8=AF=9D=E4=B8=8E=E8=AE=B0=E5=BF=86=E6=B5=81=E7=A8=8B=E8=BE=B9?= =?UTF-8?q?=E7=95=8C=EF=BC=8C=E5=BC=95=E5=85=A5=20LLM=20=E7=BD=91=E5=85=B3?= =?UTF-8?q?=E4=B8=8E=E4=B8=93=E7=94=A8=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - MemoryService 异步路径委托 MemoryIngestService / MemoryRetrievalService;富化派发经 MemoryEnrichmentScheduler - WebSocket pipeline 经 ChatTurnService 与显式 DTO 编排单轮对话;回忆录片段入队由 MemoirIngestScheduler 封装 - 新增 LlmGateway(LlmUseCase),各 agent、任务与适配器对齐 ports - 补充 memory 提示适配、runtime 类型、memory-retrieval 文档、ai-touchpoints 说明与扫描脚本及配套测试 Made-with: Cursor --- api/README.md | 4 +- api/app/adapters/llm/deepseek.py | 17 ++- api/app/agents/chat/interview_agent.py | 7 +- api/app/agents/chat/orchestrator.py | 44 ++++-- api/app/agents/chat/profile_agent.py | 40 ++++- api/app/agents/chat/prompt_layers.py | 8 +- api/app/agents/image_prompt/orchestrator.py | 4 +- api/app/agents/memoir/orchestrator.py | 13 +- api/app/core/llm_gateway.py | 102 +++++++++++++ api/app/features/conversation/chat_turn.py | 116 ++++++++++++++ api/app/features/conversation/ws/pipeline.py | 65 +++++--- api/app/features/conversation/ws/router.py | 6 +- .../evaluation/memoir_readiness_service.py | 8 +- api/app/features/evaluation/replay_service.py | 10 +- api/app/features/memoir/ingest_scheduler.py | 77 ++++++++++ api/app/features/memory/enrichment.py | 13 +- .../features/memory/enrichment_scheduler.py | 50 ++++++ api/app/features/memory/extractor.py | 6 +- api/app/features/memory/ingest_service.py | 110 ++++++++++++++ api/app/features/memory/prompt_adapter.py | 26 ++++ api/app/features/memory/retrieval_service.py | 55 +++++++ api/app/features/memory/runtime_types.py | 24 +++ api/app/features/memory/service.py | 123 +++------------ api/app/features/memory/timeline.py | 7 +- api/app/ports/llm.py | 9 +- api/app/tasks/memoir_quality_pass_tasks.py | 4 +- api/app/tasks/memoir_tasks.py | 8 +- api/app/tasks/story_title_tasks.py | 4 +- api/docs/ai-touchpoints.md | 79 ++++++++++ api/docs/memory-retrieval.md | 14 ++ api/scripts/ai_touchpoints_scan.py | 107 +++++++++++++ api/tests/test_chat_turn_service.py | 70 +++++++++ api/tests/test_interview_prompts.py | 31 +++- api/tests/test_llm_gateway.py | 62 ++++++++ api/tests/test_memoir_ingest_scheduler.py | 65 ++++++++ api/tests/test_memory_boundaries.py | 143 ++++++++++++++++++ api/tests/test_memory_evidence.py | 68 +++++++++ 37 files changed, 1400 insertions(+), 199 deletions(-) create mode 100644 api/app/core/llm_gateway.py create mode 100644 api/app/features/conversation/chat_turn.py create mode 100644 api/app/features/memoir/ingest_scheduler.py create mode 100644 api/app/features/memory/enrichment_scheduler.py create mode 100644 api/app/features/memory/ingest_service.py create mode 100644 api/app/features/memory/prompt_adapter.py create mode 100644 api/app/features/memory/retrieval_service.py create mode 100644 api/app/features/memory/runtime_types.py create mode 100644 api/docs/ai-touchpoints.md create mode 100644 api/scripts/ai_touchpoints_scan.py create mode 100644 api/tests/test_chat_turn_service.py create mode 100644 api/tests/test_llm_gateway.py create mode 100644 api/tests/test_memoir_ingest_scheduler.py create mode 100644 api/tests/test_memory_boundaries.py diff --git a/api/README.md b/api/README.md index c152d08..a5a2474 100644 --- a/api/README.md +++ b/api/README.md @@ -15,7 +15,9 @@ Life Echo API 是一个智能对话系统,通过 WebSocket 实时连接,使 ### LLM 与记忆(约定文档) - **JSON 模式**:结构化抽取/路由/叙事 JSON 使用 `app/core/langchain_llm.py` 的 `bind_json_object_mode`(与 [DeepSeek JSON Output](https://api-docs.deepseek.com/guides/json_mode) 一致);详见 [`docs/llm-json-mode.md`](docs/llm-json-mode.md)。适配器说明见 [`app/adapters/llm/deepseek.py`](app/adapters/llm/deepseek.py)。 -- **记忆检索**:异步与 Celery 均使用 **向量(pgvector)** chunks,见 [`docs/memory-retrieval.md`](docs/memory-retrieval.md)。 +- **记忆检索**:异步与 Celery 均使用 **向量(pgvector)** chunks,见 [`docs/memory-retrieval.md`](docs/memory-retrieval.md)(含 async/sync **行为矩阵**)。 +- **AI 相关代码扫描**:`uv run python scripts/ai_touchpoints_scan.py --markdown docs/ai-touchpoints.md`(在 `api/` 目录下执行)生成带标签的触点列表,见 [`docs/ai-touchpoints.md`](docs/ai-touchpoints.md)。 +- **与 AI 强相关的配置项(摘录)**:`CHAT_MEMORY_RETRIEVAL_ENABLED` / `MEMOIR_PHASE1_BATCH_LLM_ENABLED` / `MEMORY_ENRICHMENT_ENABLED` / `MEMORY_EVIDENCE_EMPTY_QUERY_INCLUDE_ROLLING` 等见 `app/core/config.py`;调参时建议对照 [`docs/memory-retrieval.md`](docs/memory-retrieval.md) 与 [`docs/ai-touchpoints.md`](docs/ai-touchpoints.md)。 - **Memory compaction**:`.env.example` / [`.env.development`](.env.development) / [`.env.staging`](.env.staging) / [`.env.production`](.env.production) 均默认 `MEMORY_COMPACTION_ENABLED=true`。须运行 **Celery worker** 与 **celery-beat**([`docker-compose.yml`](docker-compose.yml) 已包含 `celery-beat`,用于定期 `memory_compaction_sweep`)。 - **Memory LLM enrichment(单次 LLM:会话摘要 + 事实)**:任务路由到 **`memory_idle`** 队列(`CELERY_MEMORY_ENRICHMENT_QUEUE`,默认 `memory_idle`)。本地与 compose 内 worker 已使用 `-Q celery,memory_idle`;生产可单独起低并发 worker 只消费 `memory_idle`,与主队列隔离。 diff --git a/api/app/adapters/llm/deepseek.py b/api/app/adapters/llm/deepseek.py index 2245b97..244574c 100644 --- a/api/app/adapters/llm/deepseek.py +++ b/api/app/adapters/llm/deepseek.py @@ -52,8 +52,9 @@ class DeepSeekLLMProvider: *, temperature: float | None = None, model: str | None = None, + max_tokens: int | None = None, ) -> str: - llm = self._get_llm(temperature, model) + llm = self._get_llm(temperature, model, max_tokens) lc_messages = _to_langchain_messages(messages) result = await llm.ainvoke(lc_messages) return str(result.content) @@ -64,21 +65,29 @@ class DeepSeekLLMProvider: *, temperature: float | None = None, model: str | None = None, + max_tokens: int | None = None, ) -> AsyncIterator[str]: - llm = self._get_llm(temperature, model) + llm = self._get_llm(temperature, model, max_tokens) lc_messages = _to_langchain_messages(messages) async for chunk in llm.astream(lc_messages): if chunk.content: yield str(chunk.content) - def _get_llm(self, temperature: float | None, model: str | None): - if temperature is None and model is None: + def _get_llm( + self, + temperature: float | None, + model: str | None, + max_tokens: int | None = None, + ): + if temperature is None and model is None and max_tokens is None: return self._llm kwargs: dict = {} if temperature is not None: kwargs["temperature"] = temperature if model is not None: kwargs["model"] = model + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens return self._llm.bind(**kwargs) if kwargs else self._llm diff --git a/api/app/agents/chat/interview_agent.py b/api/app/agents/chat/interview_agent.py index a62d54b..d6287d5 100644 --- a/api/app/agents/chat/interview_agent.py +++ b/api/app/agents/chat/interview_agent.py @@ -18,7 +18,6 @@ from app.agents.chat.interview_state_hints import ( update_recent_questions, ) from app.agents.chat.interview_turn_plan import plan_interview_turn -from app.agents.chat.reply_planner import maybe_refine_turn_plan_with_llm from app.agents.chat.personas import normalize_interview_persona from app.agents.chat.prompt_context import ChatPromptContext from app.agents.chat.prompts_conversation import ( @@ -30,6 +29,7 @@ from app.agents.chat.reply_limits import ( segments_from_llm_response, truncate_chat_segments, ) +from app.agents.chat.reply_planner import maybe_refine_turn_plan_with_llm from app.agents.chat.stage_detection import keyword_fallback_primary_stage from app.agents.state_schema import MemoirStateSchema from app.core.agent_logging import ( @@ -38,7 +38,7 @@ from app.core.agent_logging import ( log_agent_summary, ) from app.core.config import settings -from app.core.dependencies import get_llm_provider +from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger from app.features.conversation.input_normalize import normalize_chat_input_for_agent @@ -89,8 +89,7 @@ def _finalize_chat_segments_after_llm( def _get_langchain_llm(): try: - provider = get_llm_provider() - return getattr(provider, "langchain_llm", None) + return LlmGateway().langchain_llm_for(LlmUseCase("chat.interview")) except Exception: return None diff --git a/api/app/agents/chat/orchestrator.py b/api/app/agents/chat/orchestrator.py index 1e73397..d7f43f4 100644 --- a/api/app/agents/chat/orchestrator.py +++ b/api/app/agents/chat/orchestrator.py @@ -4,6 +4,7 @@ ChatOrchestrator:AI 回复用户模块的编排层 """ import time +from collections.abc import Callable from datetime import datetime from typing import TYPE_CHECKING, List, Optional @@ -24,7 +25,8 @@ from app.agents.chat.stage_detection import ( from app.agents.state_schema import MemoirStateSchema from app.core.agent_logging import agent_summary_enabled, log_agent_detail from app.core.config import settings -from app.core.dependencies import get_llm_provider +from app.core.dependencies import get_embedding_provider +from app.core.llm_gateway import LlmGateway from app.core.logging import get_logger from app.features.conversation.input_normalize import normalize_chat_input_for_agent from app.features.memoir.state_service import ( @@ -32,18 +34,20 @@ from app.features.memoir.state_service import ( save_interview_state_meta, switch_stage, ) +from app.features.memory.prompt_adapter import MemoryPromptAdapter def _llm_for_chat_input_normalize(): try: - p = get_llm_provider() - return getattr(p, "langchain_llm", None) + return LlmGateway().langchain_llm_for() except Exception: return None if TYPE_CHECKING: from app.features.user.models import User + from app.ports.embedding import EmbeddingProvider + from app.ports.llm import LLMProvider logger = get_logger(__name__) @@ -56,9 +60,10 @@ async def _fetch_interview_memory_bundle( db: AsyncSession, user_id: str, user_message: str, + *, + get_embedding_provider_fn: Callable[[], "EmbeddingProvider"], ) -> tuple[dict | None, object | None]: - """检索记忆 bundle(原始结构);是否进主 prompt 由 `slice_interview_memory` 再筛。""" - from app.core.dependencies import get_embedding_provider + """检索记忆 bundle(原始结构);是否进主 prompt 由 adapter 再筛。""" from app.features.memory.retrieval_trace import ( chat_memory_retrieval_trace_from_bundle, ) @@ -76,7 +81,7 @@ async def _fetch_interview_memory_bundle( ) return None, None try: - emb = get_embedding_provider() + emb = get_embedding_provider_fn() ms = MemoryService(db, embedding_provider=emb) top_k = settings.chat_memory_top_k bundle = await ms.retrieve(user_id, msg, top_k=top_k) @@ -103,11 +108,22 @@ class ChatOrchestrator: """ 聊天编排器:根据用户资料完成度路由到 ProfileAgent 或 InterviewAgent。 不直接写入 Redis/DB;由 WS pipeline / ConversationHistoryStore 落库并同步缓存。 + + ``get_embedding_provider_fn`` / ``llm_provider`` 供测试或脚本注入;默认使用全局依赖。 """ - def __init__(self): - self.profile_agent = ProfileAgent() + def __init__( + self, + *, + get_embedding_provider_fn: Callable[[], "EmbeddingProvider"] | None = None, + llm_provider: "LLMProvider | None" = None, + ): + self._get_embedding_provider_fn = ( + get_embedding_provider_fn or get_embedding_provider + ) + self.profile_agent = ProfileAgent(llm_provider=llm_provider) self.interview_agent = InterviewAgent() + self.memory_prompt_adapter = MemoryPromptAdapter() async def process_user_message( self, @@ -272,12 +288,16 @@ class ChatOrchestrator: background_voice = infer_background_voice(user.occupation) occupation = user.occupation or "" - from app.features.memory.chat_memory_injection import slice_interview_memory - memory_bundle, mem_trace = await _fetch_interview_memory_bundle( - db, user_id, normalized_user_message + db, + user_id, + normalized_user_message, + get_embedding_provider_fn=self._get_embedding_provider_fn, + ) + mem_slices = self.memory_prompt_adapter.slice_for_interview( + memory_bundle, + normalized_user_message, ) - mem_slices = slice_interview_memory(memory_bundle, normalized_user_message) # 场景关键词仅作为 focus planner 的辅助输入,不直接拼进记忆块,避免抢过用户明确的关系/身份线索 scene_cues_for_planner = extract_scene_cues(normalized_user_message) diff --git a/api/app/agents/chat/profile_agent.py b/api/app/agents/chat/profile_agent.py index 23d9263..86b4328 100644 --- a/api/app/agents/chat/profile_agent.py +++ b/api/app/agents/chat/profile_agent.py @@ -24,19 +24,36 @@ from app.core.agent_logging import agent_span, log_agent_payload, log_agent_summ from app.core.config import settings from app.core.dependencies import get_llm_provider from app.core.llm_call import allm_json_call +from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger +from app.ports.llm import LLMProvider logger = get_logger(__name__) def _get_langchain_llm(): try: - provider = get_llm_provider() - return getattr(provider, "langchain_llm", None) + return LlmGateway().langchain_llm_for(LlmUseCase("chat.profile")) except Exception: return None +def _langchain_messages_to_port(messages: List[Any]) -> list[dict]: + """LangChain message 列表 → ``LLMProvider.complete`` 的 ``role/content`` 结构。""" + out: list[dict] = [] + for m in messages: + if isinstance(m, SystemMessage): + out.append({"role": "system", "content": str(m.content)}) + elif isinstance(m, HumanMessage): + out.append({"role": "user", "content": str(m.content)}) + elif isinstance(m, AIMessage): + out.append({"role": "assistant", "content": str(m.content)}) + else: + c = getattr(m, "content", None) + out.append({"role": "user", "content": str(c) if c is not None else ""}) + return out + + def _message_contents_char_count(messages: List[Any]) -> int: n = 0 for m in messages: @@ -49,9 +66,15 @@ def _message_contents_char_count(messages: List[Any]) -> int: class ProfileAgent: """用户资料收集 Specialist Agent""" - def __init__(self): + def __init__(self, llm_provider: LLMProvider | None = None): + self._llm_provider = llm_provider self.llm = _get_langchain_llm() + def _provider(self) -> LLMProvider: + if self._llm_provider is not None: + return self._llm_provider + return get_llm_provider() + async def _invoke_chat( self, messages: List[Any], @@ -60,20 +83,21 @@ class ProfileAgent: conversation_id: Optional[str], agent_name: str, ) -> str: - chat_llm = self.llm.bind(max_tokens=max_tokens) + port_messages = _langchain_messages_to_port(messages) llm_t0 = time.perf_counter() with agent_span( logger, f"{agent_name}.llm", conversation_id=conversation_id or "" ): - response = await chat_llm.ainvoke(messages) + response_text = await self._provider().complete( + port_messages, + max_tokens=max_tokens, + ) logger.info( "event=chat_llm_done agent={} response_latency_ms={:.2f}", agent_name, (time.perf_counter() - llm_t0) * 1000, ) - return ( - response.content if hasattr(response, "content") else str(response) - ) or "" + return response_text or "" async def _segments_from_response( self, diff --git a/api/app/agents/chat/prompt_layers.py b/api/app/agents/chat/prompt_layers.py index 30fa4d9..be3e188 100644 --- a/api/app/agents/chat/prompt_layers.py +++ b/api/app/agents/chat/prompt_layers.py @@ -25,7 +25,6 @@ from app.agents.chat.background_voice import ( get_background_voice_tone_hint, ) from app.agents.chat.occupation_context import get_occupation_chat_hint -from app.agents.chat.output_rules import chat_output_rules from app.agents.chat.personas import ( get_interview_persona_tone_hint, normalize_interview_persona, @@ -35,7 +34,6 @@ from app.agents.stage_constants import CHAT_STAGES, STAGE_DISPLAY_ZH from app.agents.state_schema import KnownFact, PersonaThread from app.agents.style_profiles import ChatStyleProfile - # ============================================================================= # Context 层:状态与素材(纯数据视图,不立行为规则) # ============================================================================= @@ -213,7 +211,7 @@ def build_behavior_policy_block() -> str: "- 若用户直接追问**你的**身世、籍贯、童年、感情或家庭,必须守住这条边界:明确你没有这些真实经历,再把话题轻轻带回用户;**绝不能**把「用户信息」「已确认事实」「人物主线」或「记忆线索」里的内容拿来冒充助手自己的资料(例如不能把用户的成长地答成「我是上海人」)。但这些上下文仍可继续用来服务回答,只能以**明确归因**方式转回用户(如「你刚提到上海」「你之前说过那段童年」)。\n" "\n## 身份与语气\n" "- 你们是**平等聊天**:底色暖、有安全感;**不是**冷冰冰盘问或庭审式追问。仍须避免**晚会串联腔、播报腔**(如「那么接下来」「让我们回到」)——好的主持人**自然勾回话题**,不靠节目硬切。\n" - "- **主持人职责(与温情并存)**:你心里守着**回忆口述这条主线**。用户若只给寒暄、天气、泛泛忙累、纯近况而**几乎没有人生叙事实质**:最多**一两句**并肩承接,随后**必须**用**一条**带锚的开放式问题,把话头带回「当前阶段 / 还可聊的方向 / 已确认事实或人物主线 /(若有)一条极短记忆线索」之一;像朋友**绕着弯把话头勾回来**,**禁止**长时间停在纯日常闲聊里空转。**不要把「今天过得怎样」「最近好吗」当默认整轮主线**。\n" + "- **主持人职责(与温情并存)**:你心里守着**回忆口述这条主线**。用户若只给寒暄、天气、泛泛忙累、纯近况而**几乎没有人生叙事实质**:通常最多**一两句**并肩承接,并参考顶部「本轮编排指令」决定是否用带锚的开放式问题,把话头带回「当前阶段 / 还可聊的方向 / 已确认事实或人物主线 /(若有)一条极短记忆线索」之一;像朋友**绕着弯把话头勾回来**,避免长时间停在纯日常闲聊里空转。**不要把「今天过得怎样」「最近好吗」当默认整轮主线**。\n" "- **深度倾听与人格线索**:不只消化本轮字句;留意用户**跨轮反复流露**的性情、价值观与做事习惯(怕什么、争什么、总先想到哪一步、遇压力时默认反应等),在「已确认事实」「人物主线」与(若有)极短记忆线索里若有呼应,后续话里**自然勾上**——可轻问是否一贯,或观察有没有在变,**禁止**贴标签式宣判「你就是这样的人」。\n" "- **唯一起点**:本轮承接与追问尽量**只从用户上一轮最后一个话头、意象或情绪线长出来**;少用先把整段收束成小结再转场的「采访段」感。\n" "- **聊天伙伴 + 控场**:像炕头、微信里能讲心里话的老友那样接住人,但**服务目标是成稿素材与回忆叙事**,**不是**记者式刨根,也**不是**无底洞式陪聊;可以把细节捋清楚,亲和力、安全感与「听懂对方」至少和信息条理同等重要;避免理性拆解腔、冷冰冰的「专业访谈感」。\n" @@ -246,8 +244,8 @@ def build_reply_strategy_block() -> str: "- **先抓重点**:承接与追问优先对齐顶部「本轮承接重点」与**用户原词**(人名、关系、面子、身份、场景);若二者冲突,以顶部为准。\n" "- **追问与承接**:每轮由**你自己判断**该先接住、轻声并肩,还是带着锚往下挖;按情绪与画面自然取舍。\n" "- **情绪与大纲**:外显情绪很重或用户在溃堤式宣泄时,多承接、少搜集;**不要**把「写得长」或「带点感慨」误当成必须整轮不问。\n" - "- **追问义务回正**:若你方已连续两轮**完全无问句**(无句末问号也无隐性探询),而用户仍在展开叙事,**短承接后须带回一条**带锚的开放式问;本条与「情绪优先」冲突时,**以顶部指令为准**。\n" - "- **纯跑题**:若用户几乎只有寒暄/天气而无人生实质,短承接后仍须**勾回回忆叙事**(见「身份与语气」里的主持人职责)。\n" + "- **追问节奏校准**:若你方已连续两轮**完全无问句**(无句末问号也无隐性探询),而用户仍在展开叙事,把它视为需要校准节奏的信号;具体是否追问、问几问,仍以顶部「本轮编排指令」为准。\n" + "- **纯跑题**:若用户几乎只有寒暄/天气而无人生实质,把它视为需要回到回忆叙事主线的信号;具体回法见顶部「本轮编排指令」与「身份与语气」里的主持人职责。\n" "- **大纲**:每次只撬一个叙述槽;从大纲借问题时,把抽象词换成对方嘴里出现过的具体词。\n" "- **跟随—沉浸**:长段后可极短并肩画面或体感,须贴着对方物象;共情用泛指,**禁止**助手自传式亲历。\n" "- **承接**:钉住对方上一句里的名词、动词或比喻;少用「听起来你…」式判语。\n" diff --git a/api/app/agents/image_prompt/orchestrator.py b/api/app/agents/image_prompt/orchestrator.py index ba618e9..89ecb57 100644 --- a/api/app/agents/image_prompt/orchestrator.py +++ b/api/app/agents/image_prompt/orchestrator.py @@ -74,7 +74,7 @@ class ImagePromptOrchestrator: def get_image_prompt_orchestrator() -> ImagePromptOrchestrator: """Celery / 后台任务入口:统一装配 LLM 与 MemoirImageSettings。""" - from app.core.dependencies import get_llm_provider + from app.core.llm_gateway import LlmGateway, LlmUseCase - llm = getattr(get_llm_provider(), "langchain_llm", None) + llm = LlmGateway().langchain_llm_for(LlmUseCase("image_prompt")) return ImagePromptOrchestrator(llm=llm, settings=MemoirImageSettings.from_env()) diff --git a/api/app/agents/memoir/orchestrator.py b/api/app/agents/memoir/orchestrator.py index 9e20101..2ecbe20 100644 --- a/api/app/agents/memoir/orchestrator.py +++ b/api/app/agents/memoir/orchestrator.py @@ -49,11 +49,18 @@ class MemoirOrchestrator: 回忆录生成编排器。 遍历 segments → ExtractionAgent → ClassificationAgent → 按 category 聚合 → 调用 process_category 生成叙事并持久化。 + + 可注入 ``extraction_agent`` / ``classification_agent`` 以便测试替身。 """ - def __init__(self) -> None: - self.extraction_agent = ExtractionAgent() - self.classification_agent = ClassificationAgent() + def __init__( + self, + *, + extraction_agent: ExtractionAgent | None = None, + classification_agent: ClassificationAgent | None = None, + ) -> None: + self.extraction_agent = extraction_agent or ExtractionAgent() + self.classification_agent = classification_agent or ClassificationAgent() def prepare_batches( self, diff --git a/api/app/core/llm_gateway.py b/api/app/core/llm_gateway.py new file mode 100644 index 0000000..2b10b88 --- /dev/null +++ b/api/app/core/llm_gateway.py @@ -0,0 +1,102 @@ +"""Use-case oriented LLM gateway. + +This is a small compatibility layer over the existing provider and JSON helper +functions. It gives new code a stable place to request model capabilities while +older agents continue to use LangChain directly during the transition. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, TypeVar + +from pydantic import BaseModel + +from app.core.dependencies import get_llm_provider, get_llm_provider_fast +from app.core.llm_call import allm_json_call, llm_json_call + +T = TypeVar("T", bound=BaseModel) + + +@dataclass(frozen=True) +class LlmUseCase: + name: str + fast: bool = False + max_tokens: int | None = None + temperature: float | None = None + model: str | None = None + + +class LlmGateway: + """Facade for text and JSON LLM calls.""" + + def provider_for(self, use_case: LlmUseCase | None = None): + if use_case and use_case.fast: + return get_llm_provider_fast() + return get_llm_provider() + + def langchain_llm_for(self, use_case: LlmUseCase | None = None) -> Any | None: + provider = self.provider_for(use_case) + return getattr(provider, "langchain_llm", None) + + async def chat_text( + self, + messages: list[dict], + *, + use_case: LlmUseCase | None = None, + temperature: float | None = None, + model: str | None = None, + max_tokens: int | None = None, + ) -> str: + provider = self.provider_for(use_case) + return await provider.complete( + messages, + temperature=( + temperature + if temperature is not None + else (use_case.temperature if use_case else 0.7) + ), + model=model if model is not None else (use_case.model if use_case else None), + max_tokens=( + max_tokens + if max_tokens is not None + else (use_case.max_tokens if use_case else None) + ), + ) + + async def json_object( + self, + prompt: str, + schema: type[T], + *, + use_case: LlmUseCase, + fallback_factory: Callable[[], T] | None = None, + ) -> T: + return await allm_json_call( + self.langchain_llm_for(use_case), + prompt, + schema, + max_tokens=use_case.max_tokens or 1024, + agent=use_case.name, + fallback_factory=fallback_factory, + ) + + def sync_json_object( + self, + prompt: str, + schema: type[T], + *, + use_case: LlmUseCase, + fallback_factory: Callable[[], T] | None = None, + ) -> T: + return llm_json_call( + self.langchain_llm_for(use_case), + prompt, + schema, + max_tokens=use_case.max_tokens or 1024, + agent=use_case.name, + fallback_factory=fallback_factory, + ) + + +__all__ = ["LlmGateway", "LlmUseCase"] diff --git a/api/app/features/conversation/chat_turn.py b/api/app/features/conversation/chat_turn.py new file mode 100644 index 0000000..ce4359f --- /dev/null +++ b/api/app/features/conversation/chat_turn.py @@ -0,0 +1,116 @@ +"""Conversation chat turn boundary. + +This module gives the WebSocket pipeline a small, explicit contract for one +user turn. It deliberately keeps the existing ``ChatOrchestrator`` behavior +intact while making the runtime inputs/outputs visible and testable. +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass +from datetime import datetime +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.chat import ChatOrchestrator +from app.agents.chat.agent_turn import AgentChatTurn + + +@dataclass(frozen=True) +class ChatTurnInput: + """Transport-neutral input for a single user chat turn.""" + + conversation_id: str + user_message: str + is_from_voice: bool = False + voice_session_id: str | None = None + user_message_timestamp: datetime | None = None + audio_duration_seconds: int | None = None + force_skip_tts: bool = False + + +@dataclass(frozen=True) +class ChatTurnContext: + """Runtime dependencies needed to execute a turn.""" + + db: AsyncSession + user: Any | None + conversation: Any | None + apply_extracted_profile_fn: Callable[..., Any] + get_missing_profile_fields_fn: Callable[[Any], list[str]] + get_filled_profile_fields_fn: Callable[[Any], dict[str, Any]] + + +@dataclass(frozen=True) +class ChatTurnDecision: + """Observable decision metadata for the chat runtime boundary.""" + + engine: str = "ChatOrchestrator" + route_hint: str = "auto" + force_skip_tts: bool = False + + +@dataclass(frozen=True) +class ChatTurnResult: + """Stable result shape consumed by conversation persistence and delivery.""" + + messages: list[str] + skip_tts: bool + memory_retrieval_trace: dict[str, Any] | None = None + interview_state_meta: dict[str, Any] | None = None + decision: ChatTurnDecision = ChatTurnDecision() + + @classmethod + def from_agent_turn( + cls, + turn: AgentChatTurn, + *, + decision: ChatTurnDecision, + ) -> "ChatTurnResult": + return cls( + messages=list(turn.messages or []), + skip_tts=bool(turn.skip_tts or decision.force_skip_tts), + memory_retrieval_trace=turn.memory_retrieval_trace, + interview_state_meta=turn.interview_state_meta, + decision=decision, + ) + + +class ChatTurnService: + """Executes one chat turn behind an explicit internal contract.""" + + def __init__(self, orchestrator: ChatOrchestrator | None = None) -> None: + self._orchestrator = orchestrator or ChatOrchestrator() + + async def process_turn( + self, + turn_input: ChatTurnInput, + context: ChatTurnContext, + ) -> ChatTurnResult: + decision = ChatTurnDecision(force_skip_tts=turn_input.force_skip_tts) + turn = await self._orchestrator.process_user_message( + conversation_id=turn_input.conversation_id, + user_message=turn_input.user_message, + user=context.user, + conversation=context.conversation, + is_from_voice=turn_input.is_from_voice, + voice_session_id=turn_input.voice_session_id, + db=context.db, + apply_extracted_profile_fn=context.apply_extracted_profile_fn, + get_missing_profile_fields_fn=context.get_missing_profile_fields_fn, + get_filled_profile_fields_fn=context.get_filled_profile_fields_fn, + user_message_timestamp=turn_input.user_message_timestamp, + audio_duration_seconds=turn_input.audio_duration_seconds, + ) + return ChatTurnResult.from_agent_turn(turn, decision=decision) + + +__all__ = [ + "ChatTurnContext", + "ChatTurnDecision", + "ChatTurnInput", + "ChatTurnResult", + "ChatTurnService", +] diff --git a/api/app/features/conversation/ws/pipeline.py b/api/app/features/conversation/ws/pipeline.py index f198011..aa028db 100644 --- a/api/app/features/conversation/ws/pipeline.py +++ b/api/app/features/conversation/ws/pipeline.py @@ -23,6 +23,11 @@ from app.core.config import settings from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC from app.core.db import AsyncSessionLocal from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider +from app.features.conversation.chat_turn import ( + ChatTurnContext, + ChatTurnInput, + ChatTurnService, +) from app.features.conversation.history_store import ( AI_RESPONSE_SEGMENT_JOIN, ConversationHistoryStore, @@ -37,6 +42,7 @@ from app.features.conversation.ws.profile_collector import ( get_missing_profile_fields, ) from app.features.memoir.background_runner import BackgroundTaskRunner +from app.features.memoir.ingest_scheduler import MemoirIngestScheduler from app.features.user.models import User from app.ports.asr import ASRTranscriptionError @@ -134,7 +140,9 @@ async def _send_tts_audio( # ── Agent 实例(从 ConnectionManager 移出) ───────────────────── chat_orchestrator = ChatOrchestrator() -background_runner = BackgroundTaskRunner() +chat_turn_service = ChatTurnService(chat_orchestrator) +_background_runner = BackgroundTaskRunner() +memoir_ingest_scheduler = MemoirIngestScheduler(_background_runner) # ── 分段流状态 ────────────────────────────────────────────────── @@ -573,7 +581,7 @@ async def process_audio_segment( user_message_timestamp = _mark_conversation_active(conversation) await db.commit() await db.refresh(segment) - await background_runner.queue_message( + await memoir_ingest_scheduler.queue_segment( conversation.user_id, segment.id, text_char_count=len((transcript_text or "").strip()), @@ -655,19 +663,24 @@ async def process_user_message( voice_session_id = _voice_session_id_from_audio_url(segment.audio_url) audio_dur = getattr(segment, "audio_duration_seconds", None) t_pipeline = time.perf_counter() - turn = await chat_orchestrator.process_user_message( - conversation_id=conversation_id, - user_message=user_message, - user=user, - conversation=conversation, - is_from_voice=is_from_voice, - voice_session_id=voice_session_id, - db=db, - apply_extracted_profile_fn=apply_extracted_profile, - get_missing_profile_fields_fn=get_missing_profile_fields, - get_filled_profile_fields_fn=get_filled_profile_fields, - user_message_timestamp=user_message_timestamp, - audio_duration_seconds=audio_dur, + turn = await chat_turn_service.process_turn( + ChatTurnInput( + conversation_id=conversation_id, + user_message=user_message, + is_from_voice=is_from_voice, + voice_session_id=voice_session_id, + user_message_timestamp=user_message_timestamp, + audio_duration_seconds=audio_dur, + force_skip_tts=force_skip_tts, + ), + ChatTurnContext( + db=db, + user=user, + conversation=conversation, + apply_extracted_profile_fn=apply_extracted_profile, + get_missing_profile_fields_fn=get_missing_profile_fields, + get_filled_profile_fields_fn=get_filled_profile_fields, + ), ) if agent_summary_enabled(): logger.info( @@ -682,7 +695,7 @@ async def process_user_message( turn.skip_tts, ) responses = turn.messages - skip_tts = bool(turn.skip_tts or force_skip_tts) + skip_tts = bool(turn.skip_tts) segment.agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses) _mark_conversation_active(conversation) @@ -696,7 +709,7 @@ async def process_user_message( audio_duration_seconds=audio_dur, tts_audio_urls=None, segment_id=segment.id, - memory_retrieval_trace=getattr(turn, "memory_retrieval_trace", None), + memory_retrieval_trace=turn.memory_retrieval_trace, ) if not turn_ids: logger.warning( @@ -823,7 +836,7 @@ async def process_conversation_segments( """ 对话结束时:把本对话仍待 Phase1 的段落交给回忆录管线。 - 经 `BackgroundTaskRunner.flush_pending` 将内存防抖 batch 与当前查询到的 + 经 `MemoirIngestScheduler.flush_pending` 将内存防抖 batch 与当前查询到的 `topic_category IS NULL` 段 ID 合并、去重后**单次**提交 `process_memoir_phase1`, 并在 flush 末尾触发待叙事 Phase2 派发;避免会话结束路径与 debounce flush 双发 Phase1。 @@ -842,7 +855,10 @@ async def process_conversation_segments( segments = result.scalars().all() if not segments: - await background_runner.flush_pending(conversation.user_id) + await memoir_ingest_scheduler.flush_pending( + conversation.user_id, + trigger="conversation_end", + ) return user = await db.get(User, conversation.user_id) @@ -854,13 +870,18 @@ async def process_conversation_segments( logger.info( f"用户 {user.id} 章节配额已用尽,跳过提交整理任务: conversation_id={conversation_id}" ) - await background_runner.flush_pending(conversation.user_id) + await memoir_ingest_scheduler.flush_pending( + conversation.user_id, + trigger="conversation_end", + ) return segment_ids = [seg.id for seg in segments] try: - await background_runner.flush_pending( - conversation.user_id, extra_segment_ids=segment_ids + await memoir_ingest_scheduler.flush_pending( + conversation.user_id, + extra_segment_ids=segment_ids, + trigger="conversation_end", ) logger.info( "对话结束,合并批内 segment 与 DB 待分类段,单次提交 Phase1: " diff --git a/api/app/features/conversation/ws/router.py b/api/app/features/conversation/ws/router.py index 8253441..aa34e1e 100644 --- a/api/app/features/conversation/ws/router.py +++ b/api/app/features/conversation/ws/router.py @@ -24,11 +24,11 @@ from app.features.conversation.ws.message_types import MessageType from app.features.conversation.ws.pipeline import ( _delayed_listening_feedback, _voice_session_id_from_client_segment_id, - background_runner, bump_tts_cancel_epoch, chat_orchestrator, cleanup_segment_states, get_or_create_segment_state, + memoir_ingest_scheduler, process_audio_segment, process_conversation_segments, process_user_message, @@ -304,7 +304,7 @@ async def websocket_endpoint( text_message, ) user_message_timestamp = conversation.last_message_at - await background_runner.queue_message( + await memoir_ingest_scheduler.queue_segment( conversation.user_id, segment.id, text_char_count=len(text_message.strip()), @@ -565,7 +565,7 @@ async def websocket_endpoint( ) ) user_message_timestamp = conversation.last_message_at - await background_runner.queue_message( + await memoir_ingest_scheduler.queue_segment( conversation.user_id, segment.id, text_char_count=len((asr_text or "").strip()), diff --git a/api/app/features/evaluation/memoir_readiness_service.py b/api/app/features/evaluation/memoir_readiness_service.py index a072d53..7a87294 100644 --- a/api/app/features/evaluation/memoir_readiness_service.py +++ b/api/app/features/evaluation/memoir_readiness_service.py @@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.db import utc_now from app.features.conversation.models import Conversation, Segment -from app.features.conversation.ws.pipeline import background_runner +from app.features.conversation.ws.pipeline import memoir_ingest_scheduler from app.features.evaluation.errors import ( EvaluationBadRequestError, EvaluationNotFoundError, @@ -126,8 +126,10 @@ class MemoirReadinessService: elapsed_ms=None, ) t0 = time.perf_counter() - task_id = await background_runner.flush_pending( - uid, extra_segment_ids=segment_ids + _, task_id = await memoir_ingest_scheduler.flush_pending( + uid, + extra_segment_ids=segment_ids, + trigger="manual_flush", ) elapsed_ms = max(0, int((time.perf_counter() - t0) * 1000)) submitted_at = await record_phase1_job_submitted( diff --git a/api/app/features/evaluation/replay_service.py b/api/app/features/evaluation/replay_service.py index 7051cb0..5dc3d57 100644 --- a/api/app/features/evaluation/replay_service.py +++ b/api/app/features/evaluation/replay_service.py @@ -17,7 +17,7 @@ from app.features.auth import repo as auth_repo from app.features.conversation.models import Conversation from app.features.conversation.service import ConversationService from app.features.conversation.ws.pipeline import ( - background_runner, + memoir_ingest_scheduler, process_user_message, ) from app.features.evaluation.errors import ( @@ -160,10 +160,11 @@ class ReplayConversationService: segment_ids.append(segment.id) ts = segment.created_at or conv.last_message_at if not skip_memoir: - await background_runner.queue_message( + await memoir_ingest_scheduler.queue_segment( conv.user_id, segment.id, text_char_count=len(text), + trigger="evaluation_replay", ) await process_user_message( conversation_id=cid, @@ -178,7 +179,10 @@ class ReplayConversationService: count += 1 if flush_memoir_after and conv.user_id and (not skip_memoir): - await background_runner.flush_pending(conv.user_id) + await memoir_ingest_scheduler.flush_pending( + conv.user_id, + trigger="evaluation_replay", + ) logger.info( "eval replay done conversation_id={} turns={} flush={} skip_memoir={} skip_tts={}", diff --git a/api/app/features/memoir/ingest_scheduler.py b/api/app/features/memoir/ingest_scheduler.py new file mode 100644 index 0000000..6921eb8 --- /dev/null +++ b/api/app/features/memoir/ingest_scheduler.py @@ -0,0 +1,77 @@ +"""Memoir ingest scheduling boundary. + +The real batching logic still lives in ``BackgroundTaskRunner``. This adapter +keeps conversation code from depending on that implementation directly. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal, Sequence + +from app.features.memoir.background_runner import BackgroundTaskRunner + +MemoirTrigger = Literal[ + "turn", + "conversation_end", + "manual_flush", + "evaluation_replay", +] + + +@dataclass(frozen=True) +class MemoirPhasePlan: + """A visible plan for submitting segments into the memoir pipeline.""" + + user_id: str + segment_ids: tuple[str, ...] + trigger: MemoirTrigger + + +class MemoirIngestScheduler: + """Small facade over debounce batching and Phase2 flush dispatch.""" + + def __init__(self, runner: BackgroundTaskRunner | None = None) -> None: + self._runner = runner or BackgroundTaskRunner() + + @property + def runner(self) -> BackgroundTaskRunner: + """Compatibility escape hatch for existing tests and eval utilities.""" + + return self._runner + + async def queue_segment( + self, + user_id: str, + segment_id: str, + *, + text_char_count: int = 0, + trigger: MemoirTrigger = "turn", + ) -> MemoirPhasePlan: + await self._runner.queue_message( + user_id, + segment_id, + text_char_count=text_char_count, + ) + return MemoirPhasePlan( + user_id=user_id, + segment_ids=(segment_id,), + trigger=trigger, + ) + + async def flush_pending( + self, + user_id: str, + *, + extra_segment_ids: Sequence[str] | None = None, + trigger: MemoirTrigger = "manual_flush", + ) -> tuple[MemoirPhasePlan, str | None]: + ids = tuple(str(x) for x in (extra_segment_ids or ()) if str(x).strip()) + task_id = await self._runner.flush_pending( + user_id, + extra_segment_ids=list(ids), + ) + return MemoirPhasePlan(user_id=user_id, segment_ids=ids, trigger=trigger), task_id + + +__all__ = ["MemoirIngestScheduler", "MemoirPhasePlan", "MemoirTrigger"] diff --git a/api/app/features/memory/enrichment.py b/api/app/features/memory/enrichment.py index e639ef2..15690ea 100644 --- a/api/app/features/memory/enrichment.py +++ b/api/app/features/memory/enrichment.py @@ -14,6 +14,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app.core.langchain_llm import ainvoke_json_object, invoke_json_object +from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger from app.features.memory.enrichment_pipeline import ( dedupe_key, @@ -45,9 +46,9 @@ def _lineage_snapshot_from_source(source: MemorySource | None) -> dict | None: def _resolve_llm_sync() -> Any | None: try: - from app.core.dependencies import get_llm_provider_fast - - return get_llm_provider_fast().langchain_llm + return LlmGateway().langchain_llm_for( + LlmUseCase("memory.enrichment_sync", fast=True) + ) except Exception as e: logger.warning("memory enrichment 无法获取 LLM: {}", e) return None @@ -150,7 +151,8 @@ def enrich_memory_after_ingest_sync( chunk_ids = [c.id for c in chunks] chunk_texts = [c.content for c in chunks] numbered = "\n\n".join( - f"[chunk_id={cid}]\n{txt}" for cid, txt in zip(chunk_ids, chunk_texts) + f"[chunk_id={cid}]\n{txt}" + for cid, txt in zip(chunk_ids, chunk_texts, strict=False) ) narrator_label = (narrator_name or "").strip() or "叙述者" @@ -224,7 +226,8 @@ async def enrich_memory_after_ingest_async( chunk_ids = [c.id for c in chunks] chunk_texts = [c.content for c in chunks] numbered = "\n\n".join( - f"[chunk_id={cid}]\n{txt}" for cid, txt in zip(chunk_ids, chunk_texts) + f"[chunk_id={cid}]\n{txt}" + for cid, txt in zip(chunk_ids, chunk_texts, strict=False) ) narrator_label = (narrator_name or "").strip() or "叙述者" diff --git a/api/app/features/memory/enrichment_scheduler.py b/api/app/features/memory/enrichment_scheduler.py new file mode 100644 index 0000000..d1ef750 --- /dev/null +++ b/api/app/features/memory/enrichment_scheduler.py @@ -0,0 +1,50 @@ +"""Memory enrichment scheduling boundary.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class MemoryEnrichmentRequest: + user_id: str + source_id: str + memoir_correlation_id: str | None = None + + +class MemoryEnrichmentScheduler: + """Adapter around the Celery enrichment task name and queue policy.""" + + def schedule(self, request: MemoryEnrichmentRequest) -> str | None: + from app.tasks.memory_enrichment_tasks import schedule_memory_enrichment + + return schedule_memory_enrichment( + request.user_id, + request.source_id, + memoir_correlation_id=request.memoir_correlation_id, + ) + + def schedule_many( + self, + user_id: str, + source_ids: list[str], + *, + memoir_correlation_id: str | None = None, + ) -> list[str]: + task_ids: list[str] = [] + for source_id in source_ids: + if not source_id: + continue + task_id = self.schedule( + MemoryEnrichmentRequest( + user_id=user_id, + source_id=source_id, + memoir_correlation_id=memoir_correlation_id, + ) + ) + if task_id: + task_ids.append(task_id) + return task_ids + + +__all__ = ["MemoryEnrichmentRequest", "MemoryEnrichmentScheduler"] diff --git a/api/app/features/memory/extractor.py b/api/app/features/memory/extractor.py index 0f78b91..4845661 100644 --- a/api/app/features/memory/extractor.py +++ b/api/app/features/memory/extractor.py @@ -5,6 +5,7 @@ from __future__ import annotations from typing import Any from app.core.langchain_llm import ainvoke_json_object, invoke_json_object +from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger from app.features.memory.llm_schemas import ( FactsExtractionPayload, @@ -101,10 +102,11 @@ async def extract_facts_from_transcript_async( async def extract_facts(chunk_text: str, *, user_id: str) -> list[dict]: """兼容旧接口:单块文本(无 chunk id 时传空 source_chunk_id)。""" from app.core.db import AsyncSessionLocal - from app.core.dependencies import get_llm_provider_fast from app.features.user.models import User - llm = get_llm_provider_fast().langchain_llm + llm = LlmGateway().langchain_llm_for( + LlmUseCase("memory.extract_facts.compat", fast=True) + ) narrator_name: str | None = None try: async with AsyncSessionLocal() as db: diff --git a/api/app/features/memory/ingest_service.py b/api/app/features/memory/ingest_service.py new file mode 100644 index 0000000..1b12467 --- /dev/null +++ b/api/app/features/memory/ingest_service.py @@ -0,0 +1,110 @@ +"""Memory ingest service boundary.""" + +from __future__ import annotations + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import settings +from app.core.logging import get_logger +from app.features.conversation.lineage_schemas import ( + primary_user_message_id_from_lineage, +) +from app.features.memory.chunker import chunk_transcript +from app.features.memory.enrichment_scheduler import ( + MemoryEnrichmentRequest, + MemoryEnrichmentScheduler, +) +from app.features.memory.repo import ( + create_chunk, + create_source, + update_chunk_embedding, +) +from app.ports.embedding import EmbeddingProvider + +logger = get_logger(__name__) + + +class MemoryIngestService: + """Creates memory sources/chunks and schedules post-commit enrichment.""" + + def __init__( + self, + db: AsyncSession, + *, + embedding_provider: EmbeddingProvider | None = None, + enrichment_scheduler: MemoryEnrichmentScheduler | None = None, + ) -> None: + self._db = db + self._embedding = embedding_provider + self._enrichment_scheduler = enrichment_scheduler or MemoryEnrichmentScheduler() + + async def ingest_transcript( + self, + user_id: str, + conversation_id: str, + transcript: str, + *, + lineage_json: dict | None = None, + ) -> str: + if not transcript or not transcript.strip(): + raise ValueError("transcript cannot be empty") + + primary_mid = ( + primary_user_message_id_from_lineage(lineage_json) if lineage_json else None + ) + source = await create_source( + self._db, + user_id=user_id, + source_type="transcript", + raw_text=transcript.strip(), + conversation_id=conversation_id, + lineage_json=lineage_json, + primary_user_message_id=primary_mid, + ) + + chunk_records: list[tuple[str, str]] = [] + for i, content in enumerate(chunk_transcript(transcript.strip())): + chunk = await create_chunk( + self._db, + source_id=source.id, + user_id=user_id, + content=content, + chunk_index=i, + ) + chunk_records.append((chunk.id, content)) + + await self._db.flush() + + vectors_written = 0 + if self._embedding and chunk_records: + texts = [content for _, content in chunk_records] + embeddings = await self._embedding.embed_texts(texts) + for (chunk_id, _), emb in zip( + chunk_records, embeddings, strict=False + ): + if emb: + vectors_written += 1 + await update_chunk_embedding(self._db, chunk_id, emb) + + await self._db.commit() + emb_ok = self._embedding.is_available() if self._embedding else False + enrichment_task_id = self._enrichment_scheduler.schedule( + MemoryEnrichmentRequest(user_id=user_id, source_id=source.id) + ) + + logger.info( + "event=memory_ingest_done user_id={} conversation_id={} source_id={} " + "chunks={} vectors_written={} embedding_available={} enrichment_enabled={} enrichment_task_id={}", + user_id, + conversation_id, + source.id, + len(chunk_records), + vectors_written, + emb_ok, + settings.memory_enrichment_enabled, + enrichment_task_id, + ) + return source.id + + +__all__ = ["MemoryIngestService"] diff --git a/api/app/features/memory/prompt_adapter.py b/api/app/features/memory/prompt_adapter.py new file mode 100644 index 0000000..cd69012 --- /dev/null +++ b/api/app/features/memory/prompt_adapter.py @@ -0,0 +1,26 @@ +"""Memory-to-prompt adapter boundary.""" + +from __future__ import annotations + +from typing import Any, Mapping + +from app.features.memory.chat_memory_injection import ( + InterviewMemorySlices, + slice_interview_memory, +) +from app.features.memory.runtime_types import MemoryEvidenceBundle + + +class MemoryPromptAdapter: + """Converts retrieved evidence into prompt-specific slices.""" + + def slice_for_interview( + self, + evidence: MemoryEvidenceBundle | Mapping[str, Any] | None, + user_message: str, + ) -> InterviewMemorySlices: + raw = evidence.raw if isinstance(evidence, MemoryEvidenceBundle) else evidence + return slice_interview_memory(dict(raw or {}), user_message) + + +__all__ = ["MemoryPromptAdapter"] diff --git a/api/app/features/memory/retrieval_service.py b/api/app/features/memory/retrieval_service.py new file mode 100644 index 0000000..e78a863 --- /dev/null +++ b/api/app/features/memory/retrieval_service.py @@ -0,0 +1,55 @@ +"""Memory retrieval service boundary.""" + +from __future__ import annotations + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.logging import get_logger +from app.features.memory.retriever import HybridRetriever +from app.features.memory.schemas import EvidenceBundle +from app.ports.embedding import EmbeddingProvider + +logger = get_logger(__name__) + + +class MemoryRetrievalService: + """Retrieves typed evidence bundles for downstream consumers.""" + + def __init__( + self, + db: AsyncSession, + *, + embedding_provider: EmbeddingProvider | None = None, + ) -> None: + self._db = db + self._embedding = embedding_provider + + async def retrieve( + self, + user_id: str, + query: str, + *, + top_k: int = 10, + ) -> EvidenceBundle: + retriever = HybridRetriever(self._db, embedding_provider=self._embedding) + raw = await retriever.retrieve(user_id=user_id, query=query, top_k=top_k) + bundle = EvidenceBundle.model_validate(raw) + bd = bundle.model_dump() + vec_ok = self._embedding.is_available() if self._embedding else False + logger.info( + "event=memory_retrieve_done user_id={} query_len={} top_k={} " + "chunks={} facts={} summaries={} timeline={} stories={} vector_ok={}", + user_id, + len((query or "").strip()), + top_k, + len(bd.get("relevant_chunks") or []), + len(bd.get("relevant_facts") or []), + len(bd.get("relevant_summaries") or []), + len(bd.get("timeline_hints") or []), + len(bd.get("relevant_stories") or []), + vec_ok, + ) + return bundle + + +__all__ = ["MemoryRetrievalService"] diff --git a/api/app/features/memory/runtime_types.py b/api/app/features/memory/runtime_types.py new file mode 100644 index 0000000..28c7569 --- /dev/null +++ b/api/app/features/memory/runtime_types.py @@ -0,0 +1,24 @@ +"""Runtime DTOs for memory consumers.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Mapping + + +@dataclass(frozen=True) +class MemoryEvidenceBundle: + """Transport-neutral memory evidence payload used by chat and memoir adapters.""" + + raw: dict[str, Any] + + @classmethod + def from_mapping(cls, value: Mapping[str, Any] | None) -> "MemoryEvidenceBundle": + return cls(raw=dict(value or {})) + + @property + def has_any(self) -> bool: + return any(bool(self.raw.get(key)) for key in self.raw.keys()) + + +__all__ = ["MemoryEvidenceBundle"] diff --git a/api/app/features/memory/service.py b/api/app/features/memory/service.py index 0477d91..67f2b37 100644 --- a/api/app/features/memory/service.py +++ b/api/app/features/memory/service.py @@ -15,18 +15,14 @@ from app.core.logging import get_logger from app.features.conversation.lineage_schemas import ( primary_user_message_id_from_lineage, ) -from app.features.memory.chunker import chunk_transcript +from app.features.memory.enrichment_scheduler import MemoryEnrichmentScheduler +from app.features.memory.ingest_service import MemoryIngestService from app.features.memory.repo import ( - create_chunk, create_curation_action, - create_source, set_chunk_excluded, set_memory_fact_status, - update_chunk_embedding, -) -from app.features.conversation.lineage_schemas import ( - primary_user_message_id_from_lineage, ) +from app.features.memory.retrieval_service import MemoryRetrievalService from app.features.memory.schemas import EvidenceBundle from app.ports.embedding import EmbeddingProvider @@ -56,101 +52,20 @@ class MemoryService: Creates MemorySource, chunks, populates embedding. Returns source_id. """ - if not transcript or not transcript.strip(): - raise ValueError("transcript cannot be empty") - - primary_mid = ( - primary_user_message_id_from_lineage(lineage_json) if lineage_json else None - ) - source = await create_source( - self._db, - user_id=user_id, - source_type="transcript", - raw_text=transcript.strip(), - conversation_id=conversation_id, - lineage_json=lineage_json, - primary_user_message_id=primary_mid, - ) - - chunks_text = chunk_transcript(transcript.strip()) - chunk_records = [] - for i, content in enumerate(chunks_text): - chunk = await create_chunk( - self._db, - source_id=source.id, - user_id=user_id, - content=content, - chunk_index=i, - ) - chunk_records.append((chunk.id, content)) - - await self._db.flush() - - from app.core.config import settings - - vectors_written = 0 - # Embedding: 若有 provider 则写入 - if self._embedding and chunk_records: - texts = [c for _, c in chunk_records] - embeddings = await self._embedding.embed_texts(texts) - for (chunk_id, _), emb in zip(chunk_records, embeddings): - if emb: - vectors_written += 1 - await update_chunk_embedding(self._db, chunk_id, emb) - - await self._db.commit() - emb_ok = self._embedding.is_available() if self._embedding else False - enrichment_task_id: str | None = None - try: - from app.tasks.memory_enrichment_tasks import schedule_memory_enrichment - - enrichment_task_id = schedule_memory_enrichment( - user_id, source.id, memoir_correlation_id=None - ) - except Exception as e: - logger.warning( - "memory enrichment 派发跳过: {} exc_type={}", e, type(e).__name__ - ) - - logger.info( - "event=memory_ingest_done user_id={} conversation_id={} source_id={} " - "chunks={} vectors_written={} embedding_available={} enrichment_enabled={} enrichment_task_id={}", + service = MemoryIngestService(self._db, embedding_provider=self._embedding) + return await service.ingest_transcript( user_id, conversation_id, - source.id, - len(chunk_records), - vectors_written, - emb_ok, - settings.memory_enrichment_enabled, - enrichment_task_id, + transcript, + lineage_json=lineage_json, ) - return source.id async def retrieve( self, user_id: str, query: str, *, top_k: int = 10 ) -> EvidenceBundle: """Retrieve relevant evidence. 委托 HybridRetriever。""" - from app.features.memory.retriever import HybridRetriever - - retriever = HybridRetriever(self._db, embedding_provider=self._embedding) - raw = await retriever.retrieve(user_id=user_id, query=query, top_k=top_k) - bundle = EvidenceBundle.model_validate(raw) - bd = bundle.model_dump() - vec_ok = self._embedding.is_available() if self._embedding else False - logger.info( - "event=memory_retrieve_done user_id={} query_len={} top_k={} " - "chunks={} facts={} summaries={} timeline={} stories={} vector_ok={}", - user_id, - len((query or "").strip()), - top_k, - len(bd.get("relevant_chunks") or []), - len(bd.get("relevant_facts") or []), - len(bd.get("relevant_summaries") or []), - len(bd.get("timeline_hints") or []), - len(bd.get("relevant_stories") or []), - vec_ok, - ) - return bundle + service = MemoryRetrievalService(self._db, embedding_provider=self._embedding) + return await service.retrieve(user_id, query, top_k=top_k) async def exclude_chunk( self, user_id: str, chunk_id: str, *, reason: str = "" @@ -292,7 +207,9 @@ def ingest_transcript_sync( if chunk_records and embedding_provider is not None: texts = [content for _, content in chunk_records] embeddings = embedding_provider.embed_texts_sync(texts) - for (chunk_id, _), emb in zip(chunk_records, embeddings): + for (chunk_id, _), emb in zip( + chunk_records, embeddings, strict=False + ): if emb: vectors_written += 1 update_chunk_embedding_sync(session, chunk_id, emb) @@ -405,7 +322,9 @@ def ingest_transcripts_batch_sync( if all_chunk_records and embedding_provider is not None: texts = [content for _, content in all_chunk_records] embeddings = embedding_provider.embed_texts_sync(texts) - for (chunk_id, _), emb in zip(all_chunk_records, embeddings): + for (chunk_id, _), emb in zip( + all_chunk_records, embeddings, strict=False + ): if emb: vectors_written += 1 update_chunk_embedding_sync(session, chunk_id, emb) @@ -438,10 +357,8 @@ def schedule_enrichment_for_sources( memoir_correlation_id: str | None = None, ) -> None: """After successful ingest commit, enqueue LLM enrichment for each source (memory_idle queue).""" - from app.tasks.memory_enrichment_tasks import schedule_memory_enrichment - - for sid in source_ids: - if sid: - schedule_memory_enrichment( - user_id, sid, memoir_correlation_id=memoir_correlation_id - ) + MemoryEnrichmentScheduler().schedule_many( + user_id, + source_ids, + memoir_correlation_id=memoir_correlation_id, + ) diff --git a/api/app/features/memory/timeline.py b/api/app/features/memory/timeline.py index 42bee8e..177ad4d 100644 --- a/api/app/features/memory/timeline.py +++ b/api/app/features/memory/timeline.py @@ -6,6 +6,7 @@ import json from typing import Any from app.core.langchain_llm import ainvoke_json_object, invoke_json_object +from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger from app.features.memory.llm_schemas import ( TimelineEventsPayload, @@ -70,7 +71,7 @@ async def build_timeline_events_from_facts_async( async def build_timeline_events(facts: list[dict]) -> list[dict]: """兼容旧接口。""" - from app.core.dependencies import get_llm_provider_fast - - llm = get_llm_provider_fast().langchain_llm + llm = LlmGateway().langchain_llm_for( + LlmUseCase("memory.timeline_events.compat", fast=True) + ) return await build_timeline_events_from_facts_async(llm, facts) diff --git a/api/app/ports/llm.py b/api/app/ports/llm.py index b4f0409..8e436cc 100644 --- a/api/app/ports/llm.py +++ b/api/app/ports/llm.py @@ -12,8 +12,13 @@ class LLMProvider(Protocol): *, temperature: float = 0.7, model: str | None = None, + max_tokens: int | None = None, ) -> str: - """Single-turn completion, returns full response text.""" + """Single-turn completion, returns full response text. + + ``max_tokens`` when set is passed to the underlying chat API (adapter-specific). + """ + ... def stream( @@ -22,6 +27,8 @@ class LLMProvider(Protocol): *, temperature: float = 0.7, model: str | None = None, + max_tokens: int | None = None, ) -> AsyncIterator[str]: """Streaming completion, yields text chunks (async generator).""" + ... diff --git a/api/app/tasks/memoir_quality_pass_tasks.py b/api/app/tasks/memoir_quality_pass_tasks.py index 555174d..25955e8 100644 --- a/api/app/tasks/memoir_quality_pass_tasks.py +++ b/api/app/tasks/memoir_quality_pass_tasks.py @@ -18,7 +18,7 @@ from sqlalchemy.orm import Session from app.agents.memoir.narrative_agent import NarrativeAgent from app.core.config import settings from app.core.db import get_sync_db -from app.core.dependencies import get_llm_provider +from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger from app.core.memoir_pipeline_progress import merge_pipeline_run from app.features.memoir.models import Chapter @@ -30,7 +30,7 @@ logger = get_logger(__name__) def _get_llm(): try: - return getattr(get_llm_provider(), "langchain_llm", None) + return LlmGateway().langchain_llm_for(LlmUseCase("memoir_quality_pass")) except Exception: return None diff --git a/api/app/tasks/memoir_tasks.py b/api/app/tasks/memoir_tasks.py index a7eedc6..295e951 100644 --- a/api/app/tasks/memoir_tasks.py +++ b/api/app/tasks/memoir_tasks.py @@ -27,7 +27,7 @@ from app.core.chapter_pipeline_lock import ( ) from app.core.config import settings from app.core.db import get_sync_db -from app.core.dependencies import get_llm_provider, get_llm_provider_fast +from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger from app.core.memoir_pipeline_progress import ( init_pipeline_run_from_phase1, @@ -129,7 +129,7 @@ def _run_post_pipeline_commit( def _get_llm(): """Celery 任务内获取 LangChain LLM(通过 port)""" try: - return getattr(get_llm_provider(), "langchain_llm", None) + return LlmGateway().langchain_llm_for(LlmUseCase("memoir_tasks")) except Exception: return None @@ -137,7 +137,9 @@ def _get_llm(): def _get_llm_fast(): """分类 / 抽取等快档位任务(与叙事、路由默认模型可分离)。""" try: - return getattr(get_llm_provider_fast(), "langchain_llm", None) + return LlmGateway().langchain_llm_for( + LlmUseCase("memoir_tasks.fast", fast=True) + ) except Exception: return None diff --git a/api/app/tasks/story_title_tasks.py b/api/app/tasks/story_title_tasks.py index ec1166d..cad63b0 100644 --- a/api/app/tasks/story_title_tasks.py +++ b/api/app/tasks/story_title_tasks.py @@ -5,7 +5,7 @@ import time from celery import shared_task from app.core.db import get_sync_db -from app.core.dependencies import get_llm_provider +from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger logger = get_logger(__name__) @@ -62,7 +62,7 @@ def generate_story_title_after_create( ) return {"status": "skip_user_modified"} - llm = getattr(get_llm_provider(), "langchain_llm", None) + llm = LlmGateway().langchain_llm_for(LlmUseCase("story_title")) if not llm: ms = (time.perf_counter() - t0) * 1000 logger.info( diff --git a/api/docs/ai-touchpoints.md b/api/docs/ai-touchpoints.md new file mode 100644 index 0000000..a3002e6 --- /dev/null +++ b/api/docs/ai-touchpoints.md @@ -0,0 +1,79 @@ +# AI touchpoints (generated) + +Regenerate: `uv run python api/scripts/ai_touchpoints_scan.py --markdown api/docs/ai-touchpoints.md` + +| File | Tags | +|------|------| +| `api/app/adapters/asr/tencent_asr.py` | `ports_ai` | +| `api/app/adapters/asr/whisper_local.py` | `ports_ai` | +| `api/app/adapters/embedding/zhipu.py` | `embedding` | +| `api/app/adapters/llm/deepseek.py` | `json_llm_helpers`, `langchain`, `ports_ai` | +| `api/app/adapters/llm/deepseek_eval_judge.py` | `langchain` | +| `api/app/adapters/llm/zhipu_eval_judge.py` | `langchain` | +| `api/app/agents/__init__.py` | `agents_layer` | +| `api/app/agents/chat/__init__.py` | `agents_layer` | +| `api/app/agents/chat/helpers.py` | `langchain` | +| `api/app/agents/chat/interview_agent.py` | `agents_layer`, `langchain`, `llm_provider` | +| `api/app/agents/chat/interview_state_hints.py` | `agents_layer`, `langchain` | +| `api/app/agents/chat/interview_turn_plan.py` | `agents_layer` | +| `api/app/agents/chat/occupation_context.py` | `agents_layer` | +| `api/app/agents/chat/orchestrator.py` | `agents_layer`, `embedding`, `langchain`, `llm_provider`, `memory_ai` | +| `api/app/agents/chat/profile_agent.py` | `agents_layer`, `json_llm_helpers`, `langchain`, `llm_call_module`, `llm_provider` | +| `api/app/agents/chat/prompt_context.py` | `agents_layer` | +| `api/app/agents/chat/prompt_layers.py` | `agents_layer` | +| `api/app/agents/chat/prompts.py` | `agents_layer` | +| `api/app/agents/chat/prompts_conversation.py` | `agents_layer` | +| `api/app/agents/chat/prompts_profile.py` | `agents_layer` | +| `api/app/agents/chat/reply_planner.py` | `agents_layer`, `json_llm_helpers`, `langchain` | +| `api/app/agents/chat/slot_question_bank.py` | `agents_layer` | +| `api/app/agents/chat/stage_detection.py` | `agents_layer`, `json_llm_helpers`, `llm_call_module` | +| `api/app/agents/chat/stage_prompts.py` | `agents_layer` | +| `api/app/agents/image_prompt/__init__.py` | `agents_layer` | +| `api/app/agents/image_prompt/orchestrator.py` | `agents_layer`, `langchain`, `llm_provider` | +| `api/app/agents/memoir/__init__.py` | `agents_layer` | +| `api/app/agents/memoir/batch_phase1_prep.py` | `agents_layer`, `json_llm_helpers`, `llm_call_module` | +| `api/app/agents/memoir/classification_agent.py` | `agents_layer`, `json_llm_helpers`, `llm_call_module` | +| `api/app/agents/memoir/extraction_agent.py` | `agents_layer`, `json_llm_helpers`, `llm_call_module` | +| `api/app/agents/memoir/fidelity_check_agent.py` | `agents_layer`, `json_llm_helpers`, `llm_call_module` | +| `api/app/agents/memoir/narrative_agent.py` | `agents_layer`, `json_llm_helpers`, `langchain`, `llm_call_module` | +| `api/app/agents/memoir/orchestrator.py` | `agents_layer` | +| `api/app/agents/memoir/prompts.py` | `agents_layer`, `json_llm_helpers` | +| `api/app/agents/memoir/story_route_agent.py` | `agents_layer`, `json_llm_helpers`, `llm_call_module` | +| `api/app/agents/state_schema.py` | `agents_layer` | +| `api/app/core/config.py` | `json_llm_helpers`, `memory_ai` | +| `api/app/core/dependencies.py` | `embedding`, `llm_provider`, `ports_ai` | +| `api/app/core/langchain_llm.py` | `json_llm_helpers`, `langchain`, `llm_provider` | +| `api/app/core/llm_call.py` | `json_llm_helpers`, `langchain` | +| `api/app/core/text_normalize.py` | `json_llm_helpers`, `langchain` | +| `api/app/features/conversation/ws/pipeline.py` | `agents_layer`, `ports_ai` | +| `api/app/features/conversation/ws/profile_collector.py` | `agents_layer` | +| `api/app/features/conversation/ws/router.py` | `agents_layer` | +| `api/app/features/evaluation/judge_service.py` | `json_llm_helpers`, `llm_call_module` | +| `api/app/features/memoir/_interview_meta_store.py` | `agents_layer` | +| `api/app/features/memoir/deps.py` | `memory_ai` | +| `api/app/features/memoir/memoir_images/prompting.py` | `agents_layer`, `json_llm_helpers`, `langchain` | +| `api/app/features/memoir/service.py` | `memory_ai` | +| `api/app/features/memoir/state_service.py` | `agents_layer` | +| `api/app/features/memoir/story_pipeline_sync.py` | `agents_layer`, `embedding` | +| `api/app/features/memory/curation.py` | `memory_ai` | +| `api/app/features/memory/deps.py` | `embedding`, `memory_ai` | +| `api/app/features/memory/enrichment.py` | `json_llm_helpers`, `langchain`, `llm_provider`, `memory_ai` | +| `api/app/features/memory/evidence.py` | `embedding`, `memory_ai`, `ports_ai` | +| `api/app/features/memory/evidence_format.py` | `memory_ai` | +| `api/app/features/memory/extractor.py` | `json_llm_helpers`, `langchain`, `llm_provider` | +| `api/app/features/memory/llm_schemas.py` | `json_llm_helpers` | +| `api/app/features/memory/repo.py` | `embedding`, `memory_ai`, `ports_ai` | +| `api/app/features/memory/retriever.py` | `embedding`, `memory_ai`, `ports_ai` | +| `api/app/features/memory/router.py` | `memory_ai` | +| `api/app/features/memory/schemas.py` | `memory_ai` | +| `api/app/features/memory/service.py` | `embedding`, `memory_ai`, `ports_ai` | +| `api/app/features/memory/summarizer.py` | `json_llm_helpers`, `langchain` | +| `api/app/features/memory/timeline.py` | `json_llm_helpers`, `langchain`, `llm_provider` | +| `api/app/ports/embedding.py` | `embedding` | +| `api/app/ports/llm.py` | `ports_ai` | +| `api/app/tasks/chapter_cover_tasks.py` | `agents_layer` | +| `api/app/tasks/memoir_quality_pass_tasks.py` | `agents_layer`, `langchain`, `llm_provider` | +| `api/app/tasks/memoir_tasks.py` | `agents_layer`, `langchain`, `llm_provider` | +| `api/app/tasks/memory_enrichment_tasks.py` | `memory_ai` | +| `api/app/tasks/story_image_tasks.py` | `agents_layer` | +| `api/app/tasks/story_title_tasks.py` | `agents_layer`, `langchain`, `llm_provider` | diff --git a/api/docs/memory-retrieval.md b/api/docs/memory-retrieval.md index 99bba9e..a6cf338 100644 --- a/api/docs/memory-retrieval.md +++ b/api/docs/memory-retrieval.md @@ -14,6 +14,20 @@ - 未配置 `ZHIPU_API_KEY`(或 provider `_client` 为空)时,chunk 检索为空列表,仍会返回 facts/timeline/summaries/stories(按 query ILIKE)。 - 日志:`HybridRetriever` / `retrieve_evidence_bundle_sync` 在无 provider 或空向量时会打 warning。 +## 行为矩阵(async / sync 契约) + +以下行为应对齐;变更 `evidence.py` 时须同时检视 `HybridRetriever` + `retrieve_evidence_bundle_sync`,并跑 `tests/test_memory_evidence.py` 中的双路径用例。 + +| 条件 | 同步 `retrieve_evidence_bundle_sync` | 异步 `retrieve_evidence_bundle_async` | +|------|--------------------------------------|---------------------------------------| +| query 空白 | `memory_evidence_empty_query_include_rolling=false` → 与 `EMPTY_EVIDENCE_BUNDLE` 同键、全空列表 | 同上 | +| query 空白 | `memory_evidence_empty_query_include_rolling=true` → 无 chunks;rolling 摘要(若有)+ 最近 facts / timeline;`relevant_stories` 为空 | 同上(`_empty_query_bundle_*` 对称实现) | +| query 非空 | 本函数内 `embedding_provider.embed_text_sync` → `search_chunks_vector_sync`;再并行拉取元数据 | chunks 由调用方预计算(`HybridRetriever` 中 `search_chunks_vector`);本函数只 `fetch_evidence_metadata_async` 合并 | +| 无 embedding | warning;chunks 为空;元数据仍按 ILIKE 等返回 | async 路径若上游无向量则 `merged_chunk_dicts=[]`;元数据仍返回 | +| 输出形状 | `{"relevant_chunks", "relevant_summaries", "relevant_facts", "timeline_hints", "relevant_stories"}` chunk 项为 `id, content, chunk_index`(不含 distance) | 非空 query 下 `relevant_chunks` 等于传入的 `merged_chunk_dicts`(已由检索层剥掉 distance) | + +Facts 状态过滤(如 `confirmed` / 排除 `stale`)与 ILIKE fallback 由 `repo` 查询实现;两条路径共用同一套 sync/async repo 函数族,语义以 `evidence.py` 调用为准。 + ## 空 query - 默认:`relevant_*` 均为空(与历史行为一致)。 diff --git a/api/scripts/ai_touchpoints_scan.py b/api/scripts/ai_touchpoints_scan.py new file mode 100644 index 0000000..5203bb2 --- /dev/null +++ b/api/scripts/ai_touchpoints_scan.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +"""Scan api/app for AI-related symbols (LLM, embedding, langchain JSON helpers). + +Run from repo root: + uv run python api/scripts/ai_touchpoints_scan.py + uv run python api/scripts/ai_touchpoints_scan.py --markdown api/docs/ai-touchpoints.md + +Default prints sorted unique file paths to stdout. +""" + +from __future__ import annotations + +import argparse +import re +import sys +from pathlib import Path + +# Line must match at least one pattern to count as a touchpoint for that file. +PATTERNS: list[tuple[str, re.Pattern[str]]] = [ + ("llm_provider", re.compile(r"get_llm_provider(_fast)?\b")), + ("embedding", re.compile(r"get_embedding_provider\b|EmbeddingProvider\b")), + ("json_llm_helpers", re.compile(r"invoke_json_object|ainvoke_json_object|allm_json_call|llm_json_call\b")), + ("llm_call_module", re.compile(r"from app\.core\.llm_call|import app\.core\.llm_call")), + ("langchain", re.compile(r"\blangchain_|from langchain|import langchain")), + ("ports_ai", re.compile(r"from app\.ports\.(llm|embedding|asr|tts)\b|LLMProvider\b")), + ("agents_layer", re.compile(r"from app\.agents\.|import app\.agents\.")), + ("memory_ai", re.compile(r"MemoryService\b|HybridRetriever\b|retrieve_evidence_bundle_|schedule_memory_enrichment\b")), +] + + +def app_root_from_script() -> Path: + return Path(__file__).resolve().parent.parent + + +def iter_python_files(root: Path) -> list[Path]: + base = root / "app" + if not base.is_dir(): + raise SystemExit(f"expected package dir: {base}") + return sorted(p for p in base.rglob("*.py") if p.is_file()) + + +def scan_files(files: list[Path]) -> dict[Path, list[str]]: + hits: dict[Path, list[str]] = {} + for path in files: + try: + text = path.read_text(encoding="utf-8") + except OSError: + continue + tags: list[str] = [] + for tag, rx in PATTERNS: + if rx.search(text): + tags.append(tag) + if tags: + hits[path] = sorted(set(tags)) + return hits + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--root", + type=Path, + default=None, + help="api/ directory (default: parent of this script)", + ) + parser.add_argument( + "--markdown", + type=Path, + default=None, + help="if set, write markdown report to this path", + ) + args = parser.parse_args() + root = args.root or app_root_from_script() + files = iter_python_files(root) + hits = scan_files(files) + + lines = [f"{len(hits)} files with AI touchpoints under {root / 'app'}"] + for path in sorted(hits.keys()): + rel = path.relative_to(root.parent) if path.is_relative_to(root.parent) else path + lines.append(f"{rel}\t{','.join(hits[path])}") + + report = "\n".join(lines) + "\n" + sys.stdout.write(report) + + if args.markdown: + md_lines = [ + "# AI touchpoints (generated)", + "", + "Regenerate: `uv run python api/scripts/ai_touchpoints_scan.py --markdown api/docs/ai-touchpoints.md`", + "", + "| File | Tags |", + "|------|------|", + ] + for path in sorted(hits.keys()): + rel = path.relative_to(root.parent) + tags = ", ".join(f"`{t}`" for t in hits[path]) + md_lines.append(f"| `{rel}` | {tags} |") + md_lines.append("") + args.markdown.parent.mkdir(parents=True, exist_ok=True) + args.markdown.write_text("\n".join(md_lines), encoding="utf-8") + print(f"Wrote {args.markdown}", file=sys.stderr) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/api/tests/test_chat_turn_service.py b/api/tests/test_chat_turn_service.py new file mode 100644 index 0000000..8ff5bcb --- /dev/null +++ b/api/tests/test_chat_turn_service.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from app.agents.chat.agent_turn import AgentChatTurn +from app.features.conversation.chat_turn import ( + ChatTurnContext, + ChatTurnInput, + ChatTurnService, +) + + +class _FakeOrchestrator: + def __init__(self) -> None: + self.calls: list[dict] = [] + + async def process_user_message(self, **kwargs): + self.calls.append(kwargs) + return AgentChatTurn( + messages=["第一泡", "第二泡"], + skip_tts=False, + memory_retrieval_trace={"chunks": 1}, + interview_state_meta={"recent_questions": ["你当时在哪里?"]}, + ) + + +@pytest.mark.asyncio +async def test_chat_turn_service_exposes_one_turn_contract() -> None: + orchestrator = _FakeOrchestrator() + service = ChatTurnService(orchestrator=orchestrator) + ts = datetime(2026, 4, 29, tzinfo=timezone.utc) + + result = await service.process_turn( + ChatTurnInput( + conversation_id="conv-1", + user_message="我小时候住在河边。", + is_from_voice=True, + voice_session_id="voice-1", + user_message_timestamp=ts, + audio_duration_seconds=12, + force_skip_tts=True, + ), + ChatTurnContext( + db=object(), + user=object(), + conversation=object(), + apply_extracted_profile_fn=lambda *args, **kwargs: None, + get_missing_profile_fields_fn=lambda user: [], + get_filled_profile_fields_fn=lambda user: {}, + ), + ) + + assert result.messages == ["第一泡", "第二泡"] + assert result.skip_tts is True + assert result.memory_retrieval_trace == {"chunks": 1} + assert result.interview_state_meta == { + "recent_questions": ["你当时在哪里?"] + } + assert result.decision.force_skip_tts is True + + assert len(orchestrator.calls) == 1 + call = orchestrator.calls[0] + assert call["conversation_id"] == "conv-1" + assert call["user_message"] == "我小时候住在河边。" + assert call["is_from_voice"] is True + assert call["voice_session_id"] == "voice-1" + assert call["user_message_timestamp"] is ts + assert call["audio_duration_seconds"] == 12 diff --git a/api/tests/test_interview_prompts.py b/api/tests/test_interview_prompts.py index 221f163..e5e9101 100644 --- a/api/tests/test_interview_prompts.py +++ b/api/tests/test_interview_prompts.py @@ -2,6 +2,7 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from app.agents.chat.helpers import format_history_string from app.agents.chat.interview_state_hints import ( AUTOBIOGRAPHICAL_BOUNDARY_FALLBACK_ZH, DUPLICATE_QUESTION_GUARD_FALLBACK_ZH, @@ -10,20 +11,19 @@ from app.agents.chat.interview_state_hints import ( extract_scene_cues, segments_are_only_duplicate_guard_fallback, ) +from app.agents.chat.output_rules import chat_output_rules +from app.agents.chat.personas import normalize_interview_persona +from app.agents.chat.prompts_conversation import ( + get_guided_conversation_prompt, + get_opening_prompt, +) +from app.agents.chat.slot_question_bank import SLOT_QUESTION_OUTLINES from app.agents.state_schema import ( KnownFact, MemoirStateSchema, PersonaThread, default_slots, ) -from app.agents.chat.helpers import format_history_string -from app.agents.chat.personas import normalize_interview_persona -from app.agents.chat.output_rules import chat_output_rules -from app.agents.chat.prompts_conversation import ( - get_guided_conversation_prompt, - get_opening_prompt, -) -from app.agents.chat.slot_question_bank import SLOT_QUESTION_OUTLINES def test_guided_prompt_does_not_embed_raw_user_message_in_system_text(): @@ -132,6 +132,21 @@ def test_guided_prompt_host_tone_and_context_forward(): assert "行为" in p and "影响" in p +def test_guided_prompt_leaves_turn_level_question_contract_to_turn_plan() -> None: + p = get_guided_conversation_prompt( + current_stage="career", + empty_slots=["job"], + filled_slots={}, + detected_user_stage="career", + user_profile_context="", + persona="default", + ) + assert "随后**必须**用**一条**" not in p + assert "短承接后须带回一条" not in p + assert "仍须**勾回回忆叙事**" not in p + assert "具体问几问、是否必须追问,见顶部" in p + + def test_education_and_family_change_outlines_differ(): edu = SLOT_QUESTION_OUTLINES[("education", "change")] fam = SLOT_QUESTION_OUTLINES[("family", "change")] diff --git a/api/tests/test_llm_gateway.py b/api/tests/test_llm_gateway.py new file mode 100644 index 0000000..b47c5aa --- /dev/null +++ b/api/tests/test_llm_gateway.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import pytest + +from app.core.llm_gateway import LlmGateway, LlmUseCase + + +class _FakeProvider: + def __init__(self, name: str) -> None: + self.name = name + self.langchain_llm = f"lc-{name}" + self.complete_calls: list[dict] = [] + + async def complete(self, messages, **kwargs) -> str: + self.complete_calls.append({"messages": messages, **kwargs}) + return f"ok-{self.name}" + + +def test_llm_gateway_selects_default_or_fast_provider( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from app.core import llm_gateway as gateway_mod + + default = _FakeProvider("default") + fast = _FakeProvider("fast") + monkeypatch.setattr(gateway_mod, "get_llm_provider", lambda: default) + monkeypatch.setattr(gateway_mod, "get_llm_provider_fast", lambda: fast) + + gateway = LlmGateway() + + assert gateway.langchain_llm_for() == "lc-default" + assert gateway.langchain_llm_for(LlmUseCase("memory", fast=True)) == "lc-fast" + + +@pytest.mark.asyncio +async def test_llm_gateway_chat_text_applies_use_case_defaults( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from app.core import llm_gateway as gateway_mod + + provider = _FakeProvider("default") + monkeypatch.setattr(gateway_mod, "get_llm_provider", lambda: provider) + + text = await LlmGateway().chat_text( + [{"role": "user", "content": "hi"}], + use_case=LlmUseCase( + "chat", + max_tokens=99, + temperature=0.2, + model="model-a", + ), + ) + + assert text == "ok-default" + assert provider.complete_calls == [ + { + "messages": [{"role": "user", "content": "hi"}], + "temperature": 0.2, + "model": "model-a", + "max_tokens": 99, + } + ] diff --git a/api/tests/test_memoir_ingest_scheduler.py b/api/tests/test_memoir_ingest_scheduler.py new file mode 100644 index 0000000..4b08bab --- /dev/null +++ b/api/tests/test_memoir_ingest_scheduler.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import pytest + +from app.features.memoir.ingest_scheduler import MemoirIngestScheduler + + +class _FakeRunner: + def __init__(self) -> None: + self.queued: list[tuple[str, str, int]] = [] + self.flushed: list[tuple[str, list[str]]] = [] + + async def queue_message( + self, + user_id: str, + segment_id: str, + *, + text_char_count: int = 0, + ) -> None: + self.queued.append((user_id, segment_id, text_char_count)) + + async def flush_pending( + self, + user_id: str, + *, + extra_segment_ids: list[str] | None = None, + ) -> str: + self.flushed.append((user_id, list(extra_segment_ids or []))) + return "task-1" + + +@pytest.mark.asyncio +async def test_queue_segment_returns_visible_phase_plan() -> None: + runner = _FakeRunner() + scheduler = MemoirIngestScheduler(runner=runner) + + plan = await scheduler.queue_segment( + "user-1", + "seg-1", + text_char_count=42, + trigger="evaluation_replay", + ) + + assert runner.queued == [("user-1", "seg-1", 42)] + assert plan.user_id == "user-1" + assert plan.segment_ids == ("seg-1",) + assert plan.trigger == "evaluation_replay" + + +@pytest.mark.asyncio +async def test_flush_pending_returns_plan_and_task_id() -> None: + runner = _FakeRunner() + scheduler = MemoirIngestScheduler(runner=runner) + + plan, task_id = await scheduler.flush_pending( + "user-1", + extra_segment_ids=["seg-1", "seg-2"], + trigger="conversation_end", + ) + + assert runner.flushed == [("user-1", ["seg-1", "seg-2"])] + assert task_id == "task-1" + assert plan.user_id == "user-1" + assert plan.segment_ids == ("seg-1", "seg-2") + assert plan.trigger == "conversation_end" diff --git a/api/tests/test_memory_boundaries.py b/api/tests/test_memory_boundaries.py new file mode 100644 index 0000000..fb21fe7 --- /dev/null +++ b/api/tests/test_memory_boundaries.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + +from app.features.memory.prompt_adapter import MemoryPromptAdapter +from app.features.memory.runtime_types import MemoryEvidenceBundle + + +def test_memory_evidence_bundle_and_prompt_adapter_contract() -> None: + evidence = MemoryEvidenceBundle.from_mapping( + { + "relevant_chunks": [ + {"id": "c1", "content": "我小时候在河边长大,夏天常去玩水。"}, + ], + "relevant_summaries": [], + "relevant_facts": [], + "timeline_hints": [], + "relevant_stories": [], + } + ) + + slices = MemoryPromptAdapter().slice_for_interview( + evidence, + "那条河一到夏天就特别热闹,我现在都记得。", + ) + + assert evidence.has_any is True + assert slices.had_retrieval is True + assert "用户曾说" in slices.prompt_excerpt + assert slices.anchor_source.startswith("用户曾说") + + +@pytest.mark.asyncio +async def test_memory_retrieval_service_delegates_to_retriever( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from app.features.memory import retrieval_service as retrieval_mod + from app.features.memory.retrieval_service import MemoryRetrievalService + + calls: list[dict] = [] + + class FakeRetriever: + def __init__(self, db, *, embedding_provider=None) -> None: + calls.append({"db": db, "embedding_provider": embedding_provider}) + + async def retrieve(self, *, user_id: str, query: str, top_k: int) -> dict: + calls.append({"user_id": user_id, "query": query, "top_k": top_k}) + return { + "relevant_chunks": [{"id": "c1", "content": "chunk"}], + "relevant_summaries": [], + "relevant_facts": [], + "timeline_hints": [], + "relevant_stories": [], + } + + class FakeEmbedding: + def is_available(self) -> bool: + return True + + db = object() + embedding = FakeEmbedding() + monkeypatch.setattr(retrieval_mod, "HybridRetriever", FakeRetriever) + + bundle = await MemoryRetrievalService( + db, + embedding_provider=embedding, + ).retrieve("user-1", "hello", top_k=3) + + assert calls == [ + {"db": db, "embedding_provider": embedding}, + {"user_id": "user-1", "query": "hello", "top_k": 3}, + ] + assert bundle.relevant_chunks == [{"id": "c1", "content": "chunk"}] + + +@pytest.mark.asyncio +async def test_memory_ingest_service_commits_before_enrichment( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from app.features.memory import ingest_service as ingest_mod + from app.features.memory.ingest_service import MemoryIngestService + + events: list[tuple] = [] + + @dataclass + class FakeRow: + id: str + + class FakeDb: + async def flush(self) -> None: + events.append(("flush",)) + + async def commit(self) -> None: + events.append(("commit",)) + + class FakeEmbedding: + async def embed_texts(self, texts: list[str]) -> list[list[float]]: + events.append(("embed_texts", tuple(texts))) + return [[1.0], [2.0]] + + def is_available(self) -> bool: + return True + + class FakeScheduler: + def schedule(self, request) -> str: + events.append(("schedule", request.user_id, request.source_id)) + return "enrich-1" + + async def fake_create_source(db, **kwargs): + events.append(("create_source", kwargs["user_id"], kwargs["conversation_id"])) + return FakeRow("source-1") + + async def fake_create_chunk(db, **kwargs): + events.append(("create_chunk", kwargs["chunk_index"], kwargs["content"])) + return FakeRow(f"chunk-{kwargs['chunk_index']}") + + async def fake_update_chunk_embedding(db, chunk_id, emb): + events.append(("update_embedding", chunk_id, tuple(emb))) + + monkeypatch.setattr(ingest_mod, "chunk_transcript", lambda text: ["a", "b"]) + monkeypatch.setattr(ingest_mod, "create_source", fake_create_source) + monkeypatch.setattr(ingest_mod, "create_chunk", fake_create_chunk) + monkeypatch.setattr( + ingest_mod, + "update_chunk_embedding", + fake_update_chunk_embedding, + ) + + source_id = await MemoryIngestService( + FakeDb(), + embedding_provider=FakeEmbedding(), + enrichment_scheduler=FakeScheduler(), + ).ingest_transcript("user-1", "conv-1", "hello") + + assert source_id == "source-1" + assert events.index(("commit",)) < events.index( + ("schedule", "user-1", "source-1") + ) + assert ("embed_texts", ("a", "b")) in events + assert ("update_embedding", "chunk-0", (1.0,)) in events + assert ("update_embedding", "chunk-1", (2.0,)) in events diff --git a/api/tests/test_memory_evidence.py b/api/tests/test_memory_evidence.py index d86947e..045601c 100644 --- a/api/tests/test_memory_evidence.py +++ b/api/tests/test_memory_evidence.py @@ -2,6 +2,7 @@ import pytest +from app.core.config import settings from app.features.memory import evidence as evidence_mod from app.features.memory.evidence_format import format_evidence_chunks_for_chat_prompt from app.features.memory.evidence import ( @@ -9,6 +10,7 @@ from app.features.memory.evidence import ( _facts_to_dicts, _stories_to_dicts, _timeline_to_dicts, + retrieve_evidence_bundle_async, retrieve_evidence_bundle_sync, ) from app.features.memory.schemas import EvidenceBundle @@ -190,3 +192,69 @@ def test_slice_interview_memory_suppresses_long_new_topic(): s = slice_interview_memory(evidence, long_msg) assert s.prompt_excerpt == "" assert s.anchor_source == "" + + +async def test_retrieve_evidence_bundle_async_non_empty_merges_precomputed_chunks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """非空 query:异步路径以 merged_chunk_dicts 为主,元数据来自 fetch_evidence_metadata_async。""" + meta = { + "relevant_facts": [ + { + "id": "f1", + "fact_type": "bio", + "subject": "s", + "predicate": "p", + "object_json": {}, + } + ], + "timeline_hints": [], + "relevant_summaries": [ + { + "id": "s1", + "summary_type": "session", + "content": "sum", + "source_chunk_ids": [], + } + ], + "relevant_stories": [], + } + + async def fake_fetch_meta(db, user_id, q, top_k): + assert user_id == "u1" + assert q == "hello" + assert top_k == 7 + return meta + + monkeypatch.setattr(evidence_mod, "fetch_evidence_metadata_async", fake_fetch_meta) + merged = [{"id": "c1", "content": "chunk body", "chunk_index": 0}] + out = await retrieve_evidence_bundle_async( + object(), + "u1", + " hello ", + top_k=7, + merged_chunk_dicts=merged, + ) + assert out == {"relevant_chunks": merged, **meta} + + +async def test_empty_query_evidence_bundle_async_and_sync_aligned_when_rolling_off( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(settings, "memory_evidence_empty_query_include_rolling", False) + out_a = await retrieve_evidence_bundle_async( + object(), + "u1", + " ", + top_k=10, + merged_chunk_dicts=[], + ) + assert out_a == dict(EMPTY_EVIDENCE_BUNDLE) + out_s = retrieve_evidence_bundle_sync( + session=object(), + user_id="u1", + query="", + top_k=10, + embedding_provider=None, + ) + assert out_s == dict(EMPTY_EVIDENCE_BUNDLE)