修复一些已知问题

This commit is contained in:
Kevin
2026-03-20 17:25:42 +08:00
parent 8af37e5e8e
commit 70070216c4
16 changed files with 350 additions and 74 deletions

View File

@@ -0,0 +1,14 @@
"""一轮 AI 对话输出:分段文案 + 是否整轮跳过 TTS如失败兜底"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List
@dataclass(frozen=True)
class AgentChatTurn:
"""与 WebSocket pipeline 对齐messages 为气泡分段skip_tts 为 True 时不合成语音。"""
messages: List[str]
skip_tts: bool = False

View File

@@ -6,6 +6,7 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.agents.chat.agent_turn import AgentChatTurn
from app.agents.chat.orchestrator import ChatOrchestrator
from app.agents.chat.prompts_conversation import ConversationStage
from app.agents.state_schema import MemoirStateSchema
@@ -77,7 +78,7 @@ class ConversationAgent:
voice_session_id: str | None = None,
user_message_timestamp: datetime | None = None,
audio_duration_seconds: int | None = None,
) -> List[str]:
) -> AgentChatTurn:
"""委托 ChatOrchestrator/InterviewAgent 生成访谈回复"""
return await self._orchestrator.generate_response_with_state(
conversation_id=conversation_id,
@@ -116,13 +117,13 @@ class ConversationAgent:
state = default_state()
state.current_stage = (current_stage or ConversationStage.CHILDHOOD).value
state.covered_stages = covered_topics or []
responses = await self._orchestrator.generate_response_with_state(
turn = await self._orchestrator.generate_response_with_state(
conversation_id=conversation_id,
user_message=user_message,
memoir_state=state,
user_profile_context="",
)
return responses[0] if responses else ""
return turn.messages[0] if turn.messages else ""
def detect_stage(
self, conversation_id: str, user_message: str

View File

@@ -5,6 +5,7 @@ InterviewAgent正式访谈 Specialist
from typing import Any, List
from app.agents.chat.agent_turn import AgentChatTurn
from app.core.dependencies import get_llm_provider
from app.core.logging import get_logger
@@ -18,6 +19,9 @@ from app.agents.state_schema import MemoirStateSchema
logger = get_logger(__name__)
# LLM 不可用或调用失败时对用户展示(不暴露异常细节、不触发 TTS
_FALLBACK_REPLY = "刚才网络不太稳,没接上。你可以再说一遍,或稍后再试。"
def _get_langchain_llm():
try:
@@ -149,12 +153,11 @@ class InterviewAgent:
user_message: str,
memoir_state: MemoirStateSchema,
user_profile_context: str = "",
) -> List[str]:
) -> AgentChatTurn:
"""生成状态感知的访谈回复,不持久化(由 Orchestrator 负责)"""
if not self.llm:
return [
"抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
]
logger.warning("InterviewAgent: LLM 未配置,返回兜底文案")
return AgentChatTurn(messages=[_FALLBACK_REPLY], skip_tts=True)
try:
empty_slots = memoir_state.empty_slots_for_current_stage()
filled_slots = {
@@ -191,10 +194,11 @@ class InterviewAgent:
messages = [
msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()
]
return messages[:3] if messages else [response_text]
out = messages[:3] if messages else [response_text]
return AgentChatTurn(messages=out, skip_tts=False)
except Exception as e:
logger.error("生成回应失败: %s", e)
return [f"抱歉,生成回应时出现错误: {str(e)}"]
logger.error("生成回应失败: %s", e, exc_info=True)
return AgentChatTurn(messages=[_FALLBACK_REPLY], skip_tts=True)
async def generate_opening_message(
self,

View File

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat.agent_turn import AgentChatTurn
from app.agents.chat.helpers import save_message
from app.agents.chat.interview_agent import InterviewAgent
from app.agents.chat.profile_agent import ProfileAgent
@@ -20,6 +21,10 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
_UNAUTH_TURN = AgentChatTurn(
messages=["暂时没法继续对话,请先登录后再试。"], skip_tts=True
)
class ChatOrchestrator:
"""
@@ -45,9 +50,9 @@ class ChatOrchestrator:
get_filled_profile_fields_fn,
user_message_timestamp: Optional[datetime] = None,
audio_duration_seconds: Optional[int] = None,
) -> List[str]:
) -> AgentChatTurn:
"""
处理用户消息,返回 AI 回复列表
处理用户消息,返回 AI 回复(分段 + 是否跳过 TTS
根据 missing_fields 路由到 ProfileAgent 或 InterviewAgent
统一写入 Redis。
"""
@@ -81,14 +86,14 @@ class ChatOrchestrator:
user_message_timestamp=user_message_timestamp,
audio_duration_seconds=audio_duration_seconds,
)
return responses
return AgentChatTurn(messages=responses, skip_tts=False)
except Exception as e:
logger.error(f"资料收集处理失败: {e}", exc_info=True)
# --- 正式访谈模式 ---
user_id = user.id if user else None
if not user_id:
return ["抱歉,无法识别用户。"]
return _UNAUTH_TURN
state = await get_or_create_state(user_id, db)
if conversation and conversation.conversation_stage != state.current_stage:
@@ -106,7 +111,7 @@ class ChatOrchestrator:
occupation=user.occupation,
)
responses = await self.interview_agent.generate_response_with_state(
turn = await self.interview_agent.generate_response_with_state(
conversation_id=conversation_id,
user_message=user_message,
memoir_state=state,
@@ -115,13 +120,13 @@ class ChatOrchestrator:
await self._save_messages(
conversation_id=conversation_id,
user_message=user_message,
response_text="\n\n".join(responses),
response_text="\n\n".join(turn.messages),
is_from_voice=is_from_voice,
voice_session_id=voice_session_id,
user_message_timestamp=user_message_timestamp,
audio_duration_seconds=audio_duration_seconds,
)
return responses
return turn
async def _save_messages(
self,
@@ -222,15 +227,15 @@ class ChatOrchestrator:
voice_session_id: str | None = None,
user_message_timestamp: datetime | None = None,
audio_duration_seconds: int | None = None,
) -> List[str]:
) -> AgentChatTurn:
"""委托 InterviewAgent 生成访谈回复,并写入 Redis"""
responses = await self.interview_agent.generate_response_with_state(
turn = await self.interview_agent.generate_response_with_state(
conversation_id=conversation_id,
user_message=user_message,
memoir_state=memoir_state,
user_profile_context=user_profile_context,
)
response_text = "\n\n".join(responses)
response_text = "\n\n".join(turn.messages)
await self._save_messages(
conversation_id=conversation_id,
user_message=user_message,
@@ -240,7 +245,7 @@ class ChatOrchestrator:
user_message_timestamp=user_message_timestamp,
audio_duration_seconds=audio_duration_seconds,
)
return responses
return turn
def detect_user_stage(self, user_message: str) -> str:
"""委托 InterviewAgent 检测用户阶段"""

View File

@@ -539,7 +539,7 @@ async def process_user_message(
is_from_voice = bool(segment.audio_url)
voice_session_id = _voice_session_id_from_audio_url(segment.audio_url)
audio_dur = getattr(segment, "audio_duration_seconds", None)
responses = await chat_orchestrator.process_user_message(
turn = await chat_orchestrator.process_user_message(
conversation_id=conversation_id,
user_message=user_message,
user=user,
@@ -553,6 +553,8 @@ async def process_user_message(
user_message_timestamp=user_message_timestamp,
audio_duration_seconds=audio_dur,
)
responses = turn.messages
skip_tts = turn.skip_tts
segment.agent_response = "\n\n".join(responses)
_mark_conversation_active(conversation)
@@ -574,12 +576,14 @@ async def process_user_message(
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
url = await _send_tts_audio(
conversation_id,
response_text,
chunk_index=i,
chunk_total=n,
)
url = None
if not skip_tts:
url = await _send_tts_audio(
conversation_id,
response_text,
chunk_index=i,
chunk_total=n,
)
if url:
tts_urls.append(url)
if i < n - 1:

View File

@@ -13,6 +13,7 @@ _PLACEHOLDER_RE = re.compile(
)
_ASSET_REF_RE = re.compile(r"!\[([^\]]*)\]\(asset://([a-zA-Z0-9_-]+)\)")
_BLANK_RUN_RE = re.compile(r"\n{3,}")
def strip_legacy_image_placeholders(text: str | None) -> str:
@@ -33,6 +34,19 @@ def collect_asset_ids_from_markdown(markdown: str) -> list[str]:
return [m.group(2) for m in _ASSET_REF_RE.finditer(markdown or "") if m.group(2)]
def strip_asset_image_refs_from_markdown(markdown: str | None) -> str:
"""Remove all `![...](asset://...)` references; collapse blank lines.
Used for story single-primary policy: new versions / backfill must not
accumulate multiple inline asset images.
"""
if not markdown or not str(markdown).strip():
return ""
text = _ASSET_REF_RE.sub("", markdown or "")
text = _BLANK_RUN_RE.sub("\n\n", text)
return text.strip()
def collect_asset_ids_for_chapter(chapter) -> set[str]:
"""章节正文 canonical、收录的各 story 正文、cover_asset_id 中的 asset id。"""
ids: set[str] = set()

View File

@@ -3,6 +3,9 @@ Story 图片回填 — 将 asset:// 引用追加到 markdown 末尾。
图片生成成功后,在正文最后插入 ![alt](asset://asset_id)。
alt 使用原始 prompt 短文prompt_brief而非模板拼接后的完整出图 prompt。
单主图策略Celery 任务在调用本函数前会先 strip 正文中已有 asset:// 插图,
避免与旧版本快照叠加多条引用。
"""

View File

@@ -11,6 +11,7 @@ from datetime import datetime, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logging import get_logger
from app.features.memoir.asset_resolver import strip_asset_image_refs_from_markdown
from app.features.memoir import repo as memoir_repo
from app.features.story.image_intent_extractor import extract_primary_image_intent
from app.features.story.repo import (
@@ -105,6 +106,7 @@ class StoryService:
canonical_markdown: str | None = None,
) -> str:
"""Create story, commit, return story_id."""
md = strip_asset_image_refs_from_markdown(canonical_markdown or "")
story = await create_story(
self._db,
user_id=user_id,
@@ -112,15 +114,15 @@ class StoryService:
stage=stage,
story_type=story_type,
summary=summary,
canonical_markdown=canonical_markdown or "",
canonical_markdown=md,
)
await self._db.flush()
if canonical_markdown:
if md.strip():
version = await create_story_version(
self._db,
story_id=story.id,
version_no=1,
markdown_snapshot=canonical_markdown,
markdown_snapshot=md,
actor_type="ai",
source_type="generate",
)
@@ -130,12 +132,12 @@ class StoryService:
self._db,
story=story,
version=version,
markdown=canonical_markdown,
markdown=md,
)
if canonical_markdown:
if md.strip():
await memoir_repo.mark_chapters_dirty_for_story(self._db, story.id)
await self._db.commit()
if canonical_markdown:
if md.strip():
from app.tasks.chapter_compose_tasks import recompose_chapters_for_story
from app.tasks.story_image_tasks import generate_story_image
@@ -163,13 +165,14 @@ class StoryService:
story = await get_story_by_id(self._db, story_id)
if not story:
raise ValueError(f"Story {story_id} not found")
md = strip_asset_image_refs_from_markdown(markdown_snapshot or "")
parent_id = story.current_version_id
version_no = (await count_story_versions(self._db, story_id)) + 1
version = await create_story_version(
self._db,
story_id=story_id,
version_no=version_no,
markdown_snapshot=markdown_snapshot,
markdown_snapshot=md,
actor_type=actor_type,
source_type=source_type,
parent_version_id=parent_id,
@@ -177,12 +180,12 @@ class StoryService:
)
version.change_summary = change_summary
story.current_version_id = version.id
story.canonical_markdown = markdown_snapshot
story.canonical_markdown = md
await _extract_and_store_image_intent(
self._db,
story=story,
version=version,
markdown=markdown_snapshot,
markdown=md,
)
await memoir_repo.mark_chapters_dirty_for_story(self._db, story_id)
await self._db.commit()

View File

@@ -14,6 +14,7 @@ from sqlalchemy.orm import Session, joinedload
from app.core.db import utc_now
from app.core.logging import get_logger
from app.features.memoir.asset_resolver import strip_asset_image_refs_from_markdown
from app.features.memoir.models import ChapterStoryLink
from app.features.memoir import repo as memoir_repo
from app.features.story.image_intent_extractor import extract_primary_image_intent
@@ -114,12 +115,13 @@ def create_story_with_version_sync(
canonical_markdown: str,
stage: str | None = None,
) -> Story:
md = strip_asset_image_refs_from_markdown(canonical_markdown or "")
story = Story(
id=str(uuid.uuid4()),
user_id=user_id,
title=title,
stage=stage,
canonical_markdown=canonical_markdown or "",
canonical_markdown=md,
)
session.add(story)
session.flush()
@@ -128,16 +130,16 @@ def create_story_with_version_sync(
id=vid,
story_id=story.id,
version_no=1,
markdown_snapshot=canonical_markdown or "",
markdown_snapshot=md,
actor_type="ai",
source_type="generate",
)
session.add(version)
session.flush()
story.current_version_id = vid
if (canonical_markdown or "").strip():
if md.strip():
_extract_and_store_image_intent_sync(
session, story=story, version=version, markdown=canonical_markdown
session, story=story, version=version, markdown=md
)
memoir_repo.mark_chapters_dirty_for_story_sync(session, story.id)
return story
@@ -154,6 +156,7 @@ def append_story_version_sync(
story = session.get(Story, story_id)
if not story:
raise ValueError(f"Story {story_id} not found")
md = strip_asset_image_refs_from_markdown(markdown_snapshot or "")
parent_id = story.current_version_id
version_no = count_story_versions_sync(session, story_id) + 1
vid = str(uuid.uuid4())
@@ -161,7 +164,7 @@ def append_story_version_sync(
id=vid,
story_id=story_id,
version_no=version_no,
markdown_snapshot=markdown_snapshot,
markdown_snapshot=md,
actor_type=actor_type,
source_type=source_type,
parent_version_id=parent_id,
@@ -169,9 +172,9 @@ def append_story_version_sync(
session.add(version)
session.flush()
story.current_version_id = vid
story.canonical_markdown = markdown_snapshot
story.canonical_markdown = md
_extract_and_store_image_intent_sync(
session, story=story, version=version, markdown=markdown_snapshot
session, story=story, version=version, markdown=md
)
memoir_repo.mark_chapters_dirty_for_story_sync(session, story_id)
return version

View File

@@ -18,6 +18,7 @@ from app.core.dependencies import get_image_generator
from app.core.logging import get_logger
from app.core.redis_lock import acquire_redis_lock, release_redis_lock
from app.features.asset.models import Asset
from app.features.memoir.asset_resolver import strip_asset_image_refs_from_markdown
from app.features.memoir.memoir_images.storage import TencentCosStorageService
from app.features.story.backfill import backfill_image_into_markdown
from app.features.story.models import Story, StoryImageIntent, StoryVersion
@@ -262,7 +263,7 @@ def generate_story_image(self, story_id: str):
db.commit()
return {"status": "success_no_snapshot", "asset_id": asset_id}
base_md = ver.markdown_snapshot or ""
base_md = strip_asset_image_refs_from_markdown(ver.markdown_snapshot or "")
alt_text = (getattr(intent_db, "prompt_brief", None) or "").strip()
if not alt_text:
alt_text = (getattr(intent_db, "caption", None) or "").strip()

View File

@@ -8,6 +8,7 @@ from app.features.memoir.asset_resolver import (
collect_asset_ids_from_markdown,
resolve_asset_refs_in_markdown,
split_markdown_by_asset_refs,
strip_asset_image_refs_from_markdown,
strip_legacy_image_placeholders,
)
from app.features.memoir.models import Chapter
@@ -53,6 +54,22 @@ class AssetResolverTest(unittest.TestCase):
ids = collect_asset_ids_for_chapter(ch)
self.assertEqual(ids, {"a1", "cov1"})
def test_strip_asset_image_refs_removes_all_and_collapses_blank_lines(self):
md = (
"第一段\n\n![a](asset://old-id-1)\n\n第二段\n\n\n"
"![b](asset://old-id-2)\n\n第三段"
)
out = strip_asset_image_refs_from_markdown(md)
self.assertNotIn("asset://", out)
self.assertIn("第一段", out)
self.assertIn("第二段", out)
self.assertIn("第三段", out)
self.assertNotIn("\n\n\n", out)
def test_strip_asset_image_refs_empty(self):
self.assertEqual(strip_asset_image_refs_from_markdown(""), "")
self.assertEqual(strip_asset_image_refs_from_markdown(" "), "")
def test_collect_asset_ids_includes_linked_story_markdown(self):
ch = SimpleNamespace(
canonical_markdown="",

View File

@@ -134,6 +134,103 @@ class GenerateStoryImageTaskTest(unittest.TestCase):
acquire_lock_mock.assert_called_once()
release_lock_mock.assert_called_once()
@patch("app.tasks.story_image_tasks.release_redis_lock")
@patch(
"app.tasks.story_image_tasks.acquire_redis_lock",
return_value=SimpleNamespace(key="lock:story-image:story-1"),
)
@patch("app.tasks.story_image_tasks._claim_story_image_intent_sync")
@patch("app.tasks.story_image_tasks.get_sync_db")
@patch("app.tasks.story_image_tasks.TencentCosStorageService")
@patch("app.tasks.story_image_tasks.get_image_generator")
@patch("app.features.memoir.memoir_images.settings.MemoirImageSettings.from_env")
@patch("app.tasks.story_image_tasks.uuid.uuid4")
def test_generate_story_image_strips_existing_asset_refs_before_backfill(
self,
uuid4_mock,
settings_from_env,
get_image_generator_mock,
storage_cls,
get_sync_db_mock,
claim_intent_mock,
acquire_lock_mock,
release_lock_mock,
):
uuid4_mock.side_effect = [
_FakeUUID("claim-token"),
_FakeUUID("new-asset-uuid"),
_FakeUUID("version-uuid"),
]
settings_from_env.return_value = SimpleNamespace(
provider="liblib",
default_style="watercolor",
default_size="1024x1024",
)
intent = SimpleNamespace(
id="intent-1",
prompt_brief="院子里的藤椅",
style_profile="watercolor",
story_version_id="ver-1",
caption="主插图",
status="processing",
)
story = SimpleNamespace(
id="story-1",
user_id="user-1",
title="童年的院子",
stage="childhood",
)
db_claim = Mock()
claim_intent_mock.return_value = (intent, story)
intent_db = SimpleNamespace(
id="intent-1",
story_version_id="ver-1",
caption="主插图",
prompt_brief="院子里的藤椅",
status="processing",
style_profile="watercolor",
claim_token="claim-token",
asset_id=None,
error=None,
updated_at=None,
)
story_db = SimpleNamespace(
id="story-1",
current_version_id="ver-1",
canonical_markdown="第一段\n\n第二段",
)
version_db = SimpleNamespace(
id="ver-1",
markdown_snapshot=("第一段\n\n![旧图](asset://old-stale-id)\n\n第二段"),
)
version_max_result = Mock()
version_max_result.scalar.return_value = 1
db_persist = Mock()
db_persist.get.side_effect = [intent_db, story_db, version_db]
db_persist.execute.return_value = version_max_result
get_sync_db_mock.side_effect = [_mock_db_cm(db_claim), _mock_db_cm(db_persist)]
generator = get_image_generator_mock.return_value
generator.generate.return_value = ImageResult(
status=TaskStatus.COMPLETED,
task_id="task-1",
image_url="https://provider.example.com/story.png",
)
generator.download_image.return_value = _png_bytes()
storage_cls.from_env.return_value.upload_bytes.return_value = (
"https://cos.example.com/stories/u1/s1.png"
)
result = generate_story_image.run("story-1")
self.assertEqual(result["status"], "success")
self.assertEqual(story_db.canonical_markdown.count("asset://"), 1)
self.assertIn("asset://new-asset-uuid", story_db.canonical_markdown)
self.assertNotIn("old-stale-id", story_db.canonical_markdown)
@patch("app.tasks.story_image_tasks.acquire_redis_lock", return_value=None)
@patch("app.tasks.story_image_tasks.get_sync_db")
@patch("app.tasks.story_image_tasks.get_image_generator")