From c6e07ce5ca317c2b5c778d103a62f6adc863206e Mon Sep 17 00:00:00 2001 From: Kevin Date: Thu, 19 Mar 2026 14:36:14 +0800 Subject: [PATCH] =?UTF-8?q?chore/=20=E5=88=A0=E9=99=A4=E6=97=A0=E7=94=A8?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + .husky/pre-commit | 5 +- api/__init__.py | 4 - .../versions/0002_create_initial_schema.py | 1 + api/app/adapters/asr/whisper_local.py | 10 +- api/app/adapters/image_gen/liblib.py | 6 +- api/app/adapters/image_gen/liblib_provider.py | 59 +- api/app/adapters/llm/deepseek.py | 8 +- api/app/adapters/sms/tencent.py | 6 +- api/app/adapters/storage/tencent_cos.py | 4 +- api/app/adapters/tts/tencent_tts.py | 4 +- api/app/agents/__init__.py | 1 + api/app/agents/chat/__init__.py | 1 + api/app/agents/chat/conversation_agent.py | 5 +- api/app/agents/chat/helpers.py | 5 +- api/app/agents/chat/interview_agent.py | 105 +- api/app/agents/chat/orchestrator.py | 1 + api/app/agents/chat/profile_agent.py | 29 +- api/app/agents/chat/prompts.py | 1 + api/app/agents/chat/prompts_conversation.py | 66 +- api/app/agents/chat/prompts_profile.py | 13 +- api/app/agents/image_prompt/__init__.py | 1 + api/app/agents/image_prompt/orchestrator.py | 1 + api/app/agents/image_prompt/prompt_agent.py | 1 + api/app/agents/memoir/__init__.py | 1 + api/app/agents/memoir/classification_agent.py | 5 +- api/app/agents/memoir/extraction_agent.py | 9 +- api/app/agents/memoir/memory_agent.py | 9 +- api/app/agents/memoir/narrative_agent.py | 1 + api/app/agents/memoir/orchestrator.py | 5 +- api/app/agents/memoir/placeholder_agent.py | 1 + api/app/agents/memoir/processor.py | 5 +- api/app/agents/memoir/prompts.py | 48 +- api/app/agents/state_schema.py | 7 +- api/app/core/db.py | 1 + api/app/core/logging.py | 4 +- api/app/core/pagination.py | 1 + api/app/core/redis.py | 15 +- api/app/core/task_tracker.py | 5 +- api/app/features/auth/repo.py | 4 +- api/app/features/auth/schemas.py | 16 +- api/app/features/auth/service.py | 16 +- api/app/features/content/router.py | 1 + api/app/features/conversation/repo.py | 16 +- api/app/features/conversation/router.py | 1 + api/app/features/conversation/schemas.py | 1 - api/app/features/conversation/service.py | 45 +- .../conversation/ws/connection_manager.py | 1 + .../features/conversation/ws/message_types.py | 2 + api/app/features/conversation/ws/pipeline.py | 228 ++-- .../conversation/ws/profile_collector.py | 6 +- .../features/conversation/ws/quota_guard.py | 1 + api/app/features/conversation/ws/router.py | 416 ++++-- api/app/features/memoir/helpers.py | 21 +- .../features/memoir/memoir_images/parser.py | 4 +- .../memoir/memoir_images/prompting.py | 20 +- .../features/memoir/memoir_images/provider.py | 1 + .../features/memoir/memoir_images/schema.py | 10 +- .../memoir/memoir_images/serializers.py | 1 + .../features/memoir/memoir_images/settings.py | 8 +- .../features/memoir/memoir_images/storage.py | 10 +- api/app/features/memoir/pdf_service.py | 1 + api/app/features/memoir/repo.py | 7 +- api/app/features/memoir/router.py | 1 + api/app/features/memoir/service.py | 47 +- api/app/features/memoir/state_service.py | 22 +- api/app/features/memory/chunker.py | 4 +- api/app/features/memory/models.py | 20 +- api/app/features/memory/retriever.py | 4 +- api/app/features/memory/schemas.py | 2 +- api/app/features/memory/service.py | 5 +- api/app/features/payment/alipay_client.py | 8 +- api/app/features/payment/order_service.py | 67 +- api/app/features/payment/payment_config.py | 1 + api/app/features/payment/payment_facade.py | 5 +- api/app/features/payment/repo.py | 4 +- api/app/features/payment/router.py | 8 +- api/app/features/payment/schemas.py | 1 + api/app/features/payment/service.py | 1 + api/app/features/payment/wechat_client.py | 37 +- api/app/features/plan/router.py | 1 + api/app/features/plan/service.py | 37 +- api/app/features/quota/router.py | 1 + api/app/features/quota/service.py | 6 +- api/app/features/tasks/deps.py | 1 + api/app/features/tasks/router.py | 3 + api/app/features/tasks/service.py | 1 + api/app/features/user/models.py | 5 +- api/app/features/user/router.py | 4 +- api/app/features/user/schemas.py | 2 + api/app/main.py | 11 +- api/app/tasks/__init__.py | 1 + api/app/tasks/celery_app.py | 1 + api/app/tasks/memoir_tasks.py | 292 +++-- api/database/models.py | 255 ---- api/main.py | 1 + api/migrations_legacy/README.md | 9 - .../add_chapter_is_active.sql | 8 - .../add_chapter_sections.sql | 42 - .../add_device_info_column.sql | 19 - .../add_memoir_images_table.sql | 32 - api/migrations_legacy/add_orders_table.sql | 26 - .../add_section_image_id_fk.sql | 18 - .../add_sms_verification.sql | 49 - .../add_user_profile_fields.sql | 7 - .../add_users_subscription_columns.sql | 26 - .../fix_chapter_order_index.sql | 15 - .../fix_chapter_order_index_v2.sql | 19 - .../fix_memoir_images_order_index.sql | 9 - .../sync_schema_to_models.sql | 141 -- api/routers/chapters.py | 305 ----- api/routers/conversations.py | 327 ----- api/routers/websocket.py | 1148 ----------------- api/scripts/__init__.py | 0 api/scripts/migrate_chapters_to_sections.py | 98 -- api/scripts/reprocess_user_memoir.py | 686 ---------- api/scripts/run_chapter_sections_migration.py | 155 --- api/scripts/run_memoir_images_migration.py | 196 --- api/tests/conftest.py | 1 + api/tests/test_chapters_router_images.py | 53 +- api/tests/test_conversation.py | 195 +-- .../test_conversation_messages_history.py | 13 +- ...est_generate_chapter_images_persistence.py | 8 +- .../test_generate_chapter_images_task.py | 163 ++- api/tests/test_memoir_image_bootstrap.py | 72 +- api/tests/test_memoir_image_parser.py | 24 +- api/tests/test_memoir_image_prompting.py | 5 +- api/tests/test_memoir_image_provider.py | 24 +- api/tests/test_memoir_image_settings.py | 4 +- api/tests/test_memoir_image_storage.py | 16 +- api/tests/test_memory_prompts_inject.py | 1 + api/tests/test_pdf_service_images.py | 12 +- ...t_process_memoir_segments_image_enqueue.py | 21 +- api/tests/test_sms_verification.py | 224 ++-- api/tests/test_websocket_baseline.py | 299 ++++- 135 files changed, 2111 insertions(+), 4510 deletions(-) delete mode 100644 api/__init__.py delete mode 100644 api/database/models.py delete mode 100644 api/migrations_legacy/README.md delete mode 100644 api/migrations_legacy/add_chapter_is_active.sql delete mode 100644 api/migrations_legacy/add_chapter_sections.sql delete mode 100644 api/migrations_legacy/add_device_info_column.sql delete mode 100644 api/migrations_legacy/add_memoir_images_table.sql delete mode 100644 api/migrations_legacy/add_orders_table.sql delete mode 100644 api/migrations_legacy/add_section_image_id_fk.sql delete mode 100644 api/migrations_legacy/add_sms_verification.sql delete mode 100644 api/migrations_legacy/add_user_profile_fields.sql delete mode 100644 api/migrations_legacy/add_users_subscription_columns.sql delete mode 100644 api/migrations_legacy/fix_chapter_order_index.sql delete mode 100644 api/migrations_legacy/fix_chapter_order_index_v2.sql delete mode 100644 api/migrations_legacy/fix_memoir_images_order_index.sql delete mode 100644 api/migrations_legacy/sync_schema_to_models.sql delete mode 100644 api/routers/chapters.py delete mode 100644 api/routers/conversations.py delete mode 100644 api/routers/websocket.py delete mode 100644 api/scripts/__init__.py delete mode 100644 api/scripts/migrate_chapters_to_sections.py delete mode 100644 api/scripts/reprocess_user_memoir.py delete mode 100644 api/scripts/run_chapter_sections_migration.py delete mode 100644 api/scripts/run_memoir_images_migration.py diff --git a/.gitignore b/.gitignore index f923fe7..396cf56 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ app-expo/ # Python __pycache__/ +.pytest_cache/ *.py[cod] *$py.class *.so diff --git a/.husky/pre-commit b/.husky/pre-commit index 3932801..19afd46 100755 --- a/.husky/pre-commit +++ b/.husky/pre-commit @@ -1 +1,4 @@ -npm run format +npm run format && git add -u app-expo/ + +# Format Python files in api/ with ruff +cd api && uv run ruff format . && cd .. && git add -u api/ diff --git a/api/__init__.py b/api/__init__.py deleted file mode 100644 index b0250c3..0000000 --- a/api/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -""" -Life Echo API -""" - diff --git a/api/alembic/versions/0002_create_initial_schema.py b/api/alembic/versions/0002_create_initial_schema.py index a1f1902..58350fd 100644 --- a/api/alembic/versions/0002_create_initial_schema.py +++ b/api/alembic/versions/0002_create_initial_schema.py @@ -8,6 +8,7 @@ Revises: 0001_baseline Create Date: 2026-03-18 """ + from typing import Sequence, Union from alembic import op diff --git a/api/app/adapters/asr/whisper_local.py b/api/app/adapters/asr/whisper_local.py index 3f5397d..680ae4b 100644 --- a/api/app/adapters/asr/whisper_local.py +++ b/api/app/adapters/asr/whisper_local.py @@ -7,7 +7,14 @@ import tempfile logger = get_logger(__name__) _DEFAULT_CACHE_DIR = os.path.normpath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "models", "whisper") + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "..", + "..", + "models", + "whisper", + ) ) @@ -36,6 +43,7 @@ class WhisperASRProvider: if device == "auto": try: import torch # type: ignore[import-untyped] + device = "cuda" if torch.cuda.is_available() else "cpu" except ImportError: device = "cpu" diff --git a/api/app/adapters/image_gen/liblib.py b/api/app/adapters/image_gen/liblib.py index 774ef27..0973de1 100644 --- a/api/app/adapters/image_gen/liblib.py +++ b/api/app/adapters/image_gen/liblib.py @@ -48,7 +48,11 @@ class LiblibImageGenerator: try: job = {"job_id": task_id} result = self._provider.poll_until_complete(job, self._poll_interval, 1) - status = TaskStatus.COMPLETED if result.get("status") == "completed" else TaskStatus.PROCESSING + status = ( + TaskStatus.COMPLETED + if result.get("status") == "completed" + else TaskStatus.PROCESSING + ) return ImageResult( status=status, task_id=task_id, diff --git a/api/app/adapters/image_gen/liblib_provider.py b/api/app/adapters/image_gen/liblib_provider.py index 8a8dcfa..b5a2690 100644 --- a/api/app/adapters/image_gen/liblib_provider.py +++ b/api/app/adapters/image_gen/liblib_provider.py @@ -2,6 +2,7 @@ Liblib 图生 SDK 封装,位于 adapters 层;实现细节不暴露给 feature。 Feature 通过 port ImageGenerator 使用,本模块仅被 app.adapters.image_gen.liblib 使用。 """ + import base64 import hmac import logging @@ -61,8 +62,12 @@ class LiblibImageProvider: self.http_client = http_client or httpx.Client(timeout=120) self.access_key = access_key or (settings.liblib_access_key or "") self.secret_key = secret_key or (settings.liblib_secret_key or "") - self.base_url = (base_url or settings.liblib_base_url or "https://openapi.liblibai.cloud").rstrip("/") - self.template_uuid = template_uuid or (settings.liblib_template_uuid or DEFAULT_LIBLIB_TEMPLATE_UUID) + self.base_url = ( + base_url or settings.liblib_base_url or "https://openapi.liblibai.cloud" + ).rstrip("/") + self.template_uuid = template_uuid or ( + settings.liblib_template_uuid or DEFAULT_LIBLIB_TEMPLATE_UUID + ) self.allowed_download_hosts = _build_allowed_download_hosts( self.base_url, allowed_download_hosts=allowed_download_hosts, @@ -133,7 +138,9 @@ class LiblibImageProvider: response.raise_for_status() data = response.json() if data.get("code") != 0: - raise RuntimeError(f"Liblib status query failed: {data.get('msg', data)}") + raise RuntimeError( + f"Liblib status query failed: {data.get('msg', data)}" + ) result = data.get("data", {}) status = result.get("generateStatus") if status == 5: @@ -144,17 +151,28 @@ class LiblibImageProvider: "image_url": images[0]["imageUrl"], "job_id": job["job_id"], } - raise RuntimeError(f"Liblib returned success but no images for {job['job_id']}") + raise RuntimeError( + f"Liblib returned success but no images for {job['job_id']}" + ) if status == 6: - raise RuntimeError(f"Liblib generation failed: {result.get('generateMsg', 'unknown')}") + raise RuntimeError( + f"Liblib generation failed: {result.get('generateMsg', 'unknown')}" + ) if status == 7: - raise TimeoutError(f"Liblib returned undocumented status 7 for {job['job_id']}") + raise TimeoutError( + f"Liblib returned undocumented status 7 for {job['job_id']}" + ) logger.debug( "Liblib poll attempt %d/%d, status=%s, job=%s", - attempt + 1, max_attempts, status, job["job_id"], + attempt + 1, + max_attempts, + status, + job["job_id"], ) time.sleep(poll_interval_seconds) - raise TimeoutError(f"Liblib image generation timed out after {max_attempts} attempts for {job['job_id']}") + raise TimeoutError( + f"Liblib image generation timed out after {max_attempts} attempts for {job['job_id']}" + ) def download_image(self, job: dict) -> bytes: image_url = job["image_url"] @@ -183,9 +201,13 @@ class _LiblibAuthRedactionFilter(logging.Filter): record.msg = _redact_sensitive_query_values(record.msg) if record.args: if isinstance(record.args, dict): - record.args = {k: _redact_sensitive_query_values(v) for k, v in record.args.items()} + record.args = { + k: _redact_sensitive_query_values(v) for k, v in record.args.items() + } else: - record.args = tuple(_redact_sensitive_query_values(v) for v in record.args) + record.args = tuple( + _redact_sensitive_query_values(v) for v in record.args + ) return True @@ -196,7 +218,13 @@ def _redact_sensitive_query_values(value): def _install_http_log_redaction() -> None: - for logger_name in ("httpx", "httpcore", "httpcore.connection", "httpcore.http11", "httpcore.proxy"): + for logger_name in ( + "httpx", + "httpcore", + "httpcore.connection", + "httpcore.http11", + "httpcore.proxy", + ): target_logger = get_logger(logger_name) if getattr(target_logger, "_liblib_auth_redaction_installed", False): continue @@ -219,7 +247,10 @@ def _build_allowed_download_hosts( default_hosts: set[str] = set() if base_hostname: default_hosts.add(base_hostname) - if base_hostname.endswith(".liblibai.cloud") or base_hostname == "liblibai.cloud": + if ( + base_hostname.endswith(".liblibai.cloud") + or base_hostname == "liblibai.cloud" + ): default_hosts.add("liblibai.cloud") default_hosts.add("liblib.cloud") return tuple(sorted(default_hosts.union(configured_hosts))) @@ -230,7 +261,9 @@ def _validate_download_url(image_url: str, allowed_hosts: tuple[str, ...]) -> No hostname = (parsed.hostname or "").lower() if parsed.scheme != "https" or not hostname: raise ValueError(f"Unsupported image download URL: {image_url}") - if not any(_hostname_matches(hostname, allowed_host) for allowed_host in allowed_hosts): + if not any( + _hostname_matches(hostname, allowed_host) for allowed_host in allowed_hosts + ): raise ValueError(f"Image download host is not allowed: {hostname}") diff --git a/api/app/adapters/llm/deepseek.py b/api/app/adapters/llm/deepseek.py index 75d4658..99c133e 100644 --- a/api/app/adapters/llm/deepseek.py +++ b/api/app/adapters/llm/deepseek.py @@ -75,7 +75,13 @@ class DeepSeekLLMProvider: def _to_langchain_messages(messages: list[dict]) -> list: from langchain_core.messages import AIMessage, HumanMessage, SystemMessage - mapping = {"system": SystemMessage, "human": HumanMessage, "user": HumanMessage, "ai": AIMessage, "assistant": AIMessage} + mapping = { + "system": SystemMessage, + "human": HumanMessage, + "user": HumanMessage, + "ai": AIMessage, + "assistant": AIMessage, + } result = [] for msg in messages: cls = mapping.get(msg.get("role", ""), HumanMessage) diff --git a/api/app/adapters/sms/tencent.py b/api/app/adapters/sms/tencent.py index e1b147a..88dcd5a 100644 --- a/api/app/adapters/sms/tencent.py +++ b/api/app/adapters/sms/tencent.py @@ -63,7 +63,11 @@ class TencentSmsSender: error_code = status.Code if status else "UNKNOWN" if "TemplateParamSetNotMatchApprovedTemplate" in error_code: continue - logger.error("SMS send failed: %s - %s", error_code, status.Message if status else "") + logger.error( + "SMS send failed: %s - %s", + error_code, + status.Message if status else "", + ) return False except TencentCloudSDKException as e: diff --git a/api/app/adapters/storage/tencent_cos.py b/api/app/adapters/storage/tencent_cos.py index d937828..31c4595 100644 --- a/api/app/adapters/storage/tencent_cos.py +++ b/api/app/adapters/storage/tencent_cos.py @@ -19,7 +19,9 @@ class TencentCosStorage: token: str = "", ): self._bucket = bucket - self._base_url = (base_url or f"https://{bucket}.cos.{region}.myqcloud.com").rstrip("/") + self._base_url = ( + base_url or f"https://{bucket}.cos.{region}.myqcloud.com" + ).rstrip("/") config = CosConfig( Region=region, SecretId=secret_id, diff --git a/api/app/adapters/tts/tencent_tts.py b/api/app/adapters/tts/tencent_tts.py index 443c05a..0b0a6b5 100644 --- a/api/app/adapters/tts/tencent_tts.py +++ b/api/app/adapters/tts/tencent_tts.py @@ -134,9 +134,7 @@ class TencentTTSProvider: results: list[bytes] = [] for chunk in chunks: - audio = await asyncio.to_thread( - self._synthesize_sync, chunk, voice_type - ) + audio = await asyncio.to_thread(self._synthesize_sync, chunk, voice_type) if not audio: return b"" results.append(audio) diff --git a/api/app/agents/__init__.py b/api/app/agents/__init__.py index 839f422..37da606 100644 --- a/api/app/agents/__init__.py +++ b/api/app/agents/__init__.py @@ -1,6 +1,7 @@ """ Agent 模块(按功能拆分:chat / memoir / image_prompt) """ + from app.agents.chat import ( ChatOrchestrator, ConversationAgent, diff --git a/api/app/agents/chat/__init__.py b/api/app/agents/chat/__init__.py index 32305dd..4510af6 100644 --- a/api/app/agents/chat/__init__.py +++ b/api/app/agents/chat/__init__.py @@ -1,4 +1,5 @@ """聊天模块:AI 回复用户(ProfileAgent + InterviewAgent + ChatOrchestrator)""" + from app.agents.chat.conversation_agent import ConversationAgent from app.agents.chat.orchestrator import ChatOrchestrator from app.agents.chat.profile_agent import ProfileAgent diff --git a/api/app/agents/chat/conversation_agent.py b/api/app/agents/chat/conversation_agent.py index 0d0d121..dbf3d67 100644 --- a/api/app/agents/chat/conversation_agent.py +++ b/api/app/agents/chat/conversation_agent.py @@ -2,6 +2,7 @@ 对话 Agent:Facade,内部委托 ChatOrchestrator + ProfileAgent + InterviewAgent 保留原有对外 API,供 router 等调用方兼容使用 """ + from datetime import datetime from typing import Any, Dict, List, Optional @@ -119,7 +120,9 @@ class ConversationAgent: ) return responses[0] if responses else "" - def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage: + def detect_stage( + self, conversation_id: str, user_message: str + ) -> ConversationStage: """根据关键词检测用户阶段(兼容 API)""" detected = self._orchestrator.detect_user_stage(user_message) if detected == "childhood": diff --git a/api/app/agents/chat/helpers.py b/api/app/agents/chat/helpers.py index db2c666..ec6f6d9 100644 --- a/api/app/agents/chat/helpers.py +++ b/api/app/agents/chat/helpers.py @@ -1,4 +1,5 @@ """聊天 Agent 共享工具:历史获取、格式化、存储""" + from datetime import datetime from typing import Any, List @@ -45,5 +46,7 @@ async def save_message( content, message_type=message_type, voice_session_id=voice_session_id, - timestamp=timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp, + timestamp=timestamp.isoformat() + if isinstance(timestamp, datetime) + else timestamp, ) diff --git a/api/app/agents/chat/interview_agent.py b/api/app/agents/chat/interview_agent.py index 8d51fee..7070f7a 100644 --- a/api/app/agents/chat/interview_agent.py +++ b/api/app/agents/chat/interview_agent.py @@ -2,6 +2,7 @@ InterviewAgent:正式访谈 Specialist 负责状态感知回复、开场白,不负责 Redis 持久化(由 Orchestrator 统一处理) """ + from typing import Any, List from app.core.dependencies import get_llm_provider @@ -36,11 +37,81 @@ class InterviewAgent: """根据关键词检测用户正在谈论的人生阶段""" message = user_message.lower() stage_keywords = { - "childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"], - "education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"], - "career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"], - "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"], - "belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"], + "childhood": [ + "童年", + "小时候", + "出生", + "家乡", + "小镇", + "爸妈", + "父亲", + "母亲", + "爷爷", + "奶奶", + "外公", + "外婆", + "幼儿园", + ], + "education": [ + "上学", + "学校", + "老师", + "同学", + "教育", + "大学", + "高中", + "初中", + "小学", + "考试", + "毕业", + "读书", + "高考", + "课堂", + ], + "career": [ + "工作", + "职业", + "事业", + "公司", + "同事", + "创业", + "升职", + "跳槽", + "老板", + "行业", + "项目", + "加班", + "薪水", + "面试", + ], + "family": [ + "伴侣", + "孩子", + "家庭", + "家人", + "结婚", + "爱人", + "老婆", + "老公", + "丈夫", + "妻子", + "儿子", + "女儿", + "婚礼", + "恋爱", + ], + "belief": [ + "信念", + "价值观", + "座右铭", + "坚持", + "原则", + "信仰", + "意义", + "感悟", + "遗憾", + "骄傲", + ], } for stage, keywords in stage_keywords.items(): if any(word in message for word in keywords): @@ -81,12 +152,16 @@ class InterviewAgent: ) -> List[str]: """生成状态感知的访谈回复,不持久化(由 Orchestrator 负责)""" if not self.llm: - return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"] + return [ + "抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。" + ] try: empty_slots = memoir_state.empty_slots_for_current_stage() filled_slots = { key: value.snippet - for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items() + for key, value in memoir_state.slots.get( + memoir_state.current_stage, {} + ).items() if value.snippet } detected_user_stage = self._detect_user_stage(user_message) @@ -110,8 +185,12 @@ class InterviewAgent: history_string = format_history_string(history_messages) full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" response = await self.llm.ainvoke(full_prompt) - response_text = response.content if hasattr(response, "content") else str(response) - messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + response_text = ( + response.content if hasattr(response, "content") else str(response) + ) + messages = [ + msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip() + ] return messages[:3] if messages else [response_text] except Exception as e: logger.error("生成回应失败: %s", e) @@ -138,8 +217,12 @@ class InterviewAgent: ) full_prompt = f"{prompt}\n\nAssistant:" response = await self.llm.ainvoke(full_prompt) - response_text = response.content if hasattr(response, "content") else str(response) - messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + response_text = ( + response.content if hasattr(response, "content") else str(response) + ) + messages = [ + msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip() + ] return messages[:2] if messages else [response_text] except Exception as e: logger.error("生成开场白失败: %s", e, exc_info=True) diff --git a/api/app/agents/chat/orchestrator.py b/api/app/agents/chat/orchestrator.py index 20be068..d3b67e1 100644 --- a/api/app/agents/chat/orchestrator.py +++ b/api/app/agents/chat/orchestrator.py @@ -2,6 +2,7 @@ ChatOrchestrator:AI 回复用户模块的编排层 负责路由(Profile vs Interview)、调用 Specialist Agent、统一 Redis 持久化与错误处理 """ + from datetime import datetime from typing import TYPE_CHECKING, List, Optional diff --git a/api/app/agents/chat/profile_agent.py b/api/app/agents/chat/profile_agent.py index ca85575..2ab5046 100644 --- a/api/app/agents/chat/profile_agent.py +++ b/api/app/agents/chat/profile_agent.py @@ -2,6 +2,7 @@ ProfileAgent:用户资料收集 Specialist 负责提取资料、资料追问、资料收集开场白,不负责 Redis 持久化(由 Orchestrator 统一处理) """ + import json from typing import Any, Dict, List, Optional @@ -47,7 +48,9 @@ class ProfileAgent: recent_dialogue = "" if conversation_id: history_messages = await get_history_messages(conversation_id) - recent = history_messages[-4:] if len(history_messages) > 4 else history_messages + recent = ( + history_messages[-4:] if len(history_messages) > 4 else history_messages + ) parts = [] for msg in recent: if isinstance(msg, HumanMessage): @@ -105,10 +108,16 @@ class ProfileAgent: ) history_messages = await get_history_messages(conversation_id) history_string = format_history_string(history_messages) - full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" + full_prompt = ( + f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" + ) response = await self.llm.ainvoke(full_prompt) - response_text = response.content if hasattr(response, "content") else str(response) - messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + response_text = ( + response.content if hasattr(response, "content") else str(response) + ) + messages = [ + msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip() + ] return messages[:3] if messages else [response_text] except Exception as e: logger.error("生成资料跟进回复失败: %s", e) @@ -129,9 +138,15 @@ class ProfileAgent: history_string = format_history_string(history_messages) full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt response = await self.llm.ainvoke(full_prompt) - response_text = response.content if hasattr(response, "content") else str(response) - messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + response_text = ( + response.content if hasattr(response, "content") else str(response) + ) + messages = [ + msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip() + ] return messages[:2] if messages else [response_text] except Exception as e: logger.error("生成资料收集开场白失败: %s", e) - return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"] + return [ + "你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?" + ] diff --git a/api/app/agents/chat/prompts.py b/api/app/agents/chat/prompts.py index 1dc9467..55d2144 100644 --- a/api/app/agents/chat/prompts.py +++ b/api/app/agents/chat/prompts.py @@ -1,6 +1,7 @@ """ Chat 模块提示词:用户资料收集 + 对话访谈 """ + # Profile prompts(用户资料收集) from app.agents.chat.prompts_profile import ( PROFILE_FIELD_NAMES, diff --git a/api/app/agents/chat/prompts_conversation.py b/api/app/agents/chat/prompts_conversation.py index 7e6ff59..da37716 100644 --- a/api/app/agents/chat/prompts_conversation.py +++ b/api/app/agents/chat/prompts_conversation.py @@ -1,6 +1,7 @@ """ 对话 Agent 提示词模板和访谈问题库 """ + from enum import Enum from typing import List, Dict import random @@ -8,12 +9,13 @@ import random class ConversationStage(str, Enum): """对话阶段枚举""" - CHILDHOOD = "childhood" # 童年 - EDUCATION = "education" # 教育 - CAREER = "career" # 事业 - FAMILY = "family" # 家庭 - BELIEFS = "beliefs" # 信念 - SUMMARY = "summary" # 人生总结 + + CHILDHOOD = "childhood" # 童年 + EDUCATION = "education" # 教育 + CAREER = "career" # 事业 + FAMILY = "family" # 家庭 + BELIEFS = "beliefs" # 信念 + SUMMARY = "summary" # 人生总结 # 访谈问题库 @@ -62,7 +64,11 @@ INTERVIEW_QUESTIONS: Dict[ConversationStage, List[str]] = { } -def get_system_prompt(current_stage: ConversationStage, covered_topics: List[str], user_latest_response: str) -> str: +def get_system_prompt( + current_stage: ConversationStage, + covered_topics: List[str], + user_latest_response: str, +) -> str: """ 生成对话 Agent 的系统提示词 @@ -179,8 +185,14 @@ def get_opening_prompt( "belief": "人生信念", } stage_name = stage_name_map.get(current_stage, current_stage) - topics_str = "、".join(empty_slots_readable) if empty_slots_readable else "人生故事、童年、经历等" - profile_section = f"\n## 用户基本信息\n{user_profile_context}\n" if user_profile_context else "" + topics_str = ( + "、".join(empty_slots_readable) + if empty_slots_readable + else "人生故事、童年、经历等" + ) + profile_section = ( + f"\n## 用户基本信息\n{user_profile_context}\n" if user_profile_context else "" + ) return f"""你是「岁月知己」,用户的老朋友。用户刚通过「打个招呼」进入空对话,**还没有发任何消息**,需要你先开口。 {profile_section} ## 当前建议话题({stage_name}) @@ -287,16 +299,24 @@ def get_guided_conversation_prompt( } current_stage_name = stage_name_map.get(current_stage, current_stage) - user_stage_name = stage_name_map.get(detected_user_stage, "") if detected_user_stage else "" + user_stage_name = ( + stage_name_map.get(detected_user_stage, "") if detected_user_stage else "" + ) user_jumped = detected_user_stage and detected_user_stage != current_stage empty_slots_readable = [SLOT_NAME_MAP.get(s, s) for s in empty_slots] - empty_slots_str = "、".join(empty_slots_readable) if empty_slots_readable else "已聊得很充分" + empty_slots_str = ( + "、".join(empty_slots_readable) if empty_slots_readable else "已聊得很充分" + ) filled_info = [] for key, value in filled_slots.items(): readable_key = SLOT_NAME_MAP.get(key, key) - filled_info.append(f"{readable_key}: {value[:50]}..." if len(value) > 50 else f"{readable_key}: {value}") + filled_info.append( + f"{readable_key}: {value[:50]}..." + if len(value) > 50 + else f"{readable_key}: {value}" + ) filled_slots_str = "\n".join(filled_info) if filled_info else "刚开始聊" progress_lines = [] @@ -317,7 +337,9 @@ def get_guided_conversation_prompt( progress_str = "\n".join(progress_lines) if progress_lines else "" filled_count = len(filled_slots) - should_switch_topic = same_topic_turns >= 3 or (filled_count >= 2 and same_topic_turns >= 2) + should_switch_topic = same_topic_turns >= 3 or ( + filled_count >= 2 and same_topic_turns >= 2 + ) should_lighten_mood = conversation_turn > 0 and conversation_turn % 5 == 0 should_try_new_stage = filled_count >= 3 and len(empty_slots) <= 2 @@ -343,9 +365,13 @@ def get_guided_conversation_prompt( if should_lighten_mood: dynamic_guidance += "\n- 聊了一会儿了,可以适当轻松一下,聊点有趣的" if should_switch_topic and empty_slots_readable: - dynamic_guidance += f"\n- 这个话题聊得差不多了,可以自然转到:{empty_slots_str}" + dynamic_guidance += ( + f"\n- 这个话题聊得差不多了,可以自然转到:{empty_slots_str}" + ) if should_try_new_stage and related_stages: - dynamic_guidance += f"\n- 如果自然的话,可以尝试聊聊相关的话题,比如{related_stages_str}" + dynamic_guidance += ( + f"\n- 如果自然的话,可以尝试聊聊相关的话题,比如{related_stages_str}" + ) uncovered_hint = "" if not user_jumped and uncovered_stages and should_try_new_stage: @@ -360,7 +386,9 @@ def get_guided_conversation_prompt( if user_profile_context: profile_section = f"\n## 用户基本信息\n{user_profile_context}\n" - active_stage = detected_user_stage if user_jumped and detected_user_stage else current_stage + active_stage = ( + detected_user_stage if user_jumped and detected_user_stage else current_stage + ) era_context = _build_era_context(active_stage, user_profile_context) prompt = f"""你是「岁月知己」,用户的老朋友,正在和他/她聊人生故事。{topic_desc}。 @@ -418,6 +446,10 @@ def get_guided_conversation_prompt( return prompt -def get_conversation_prompt(current_stage: ConversationStage, covered_topics: List[str], user_latest_response: str) -> str: +def get_conversation_prompt( + current_stage: ConversationStage, + covered_topics: List[str], + user_latest_response: str, +) -> str: """向后兼容的函数""" return get_system_prompt(current_stage, covered_topics, user_latest_response) diff --git a/api/app/agents/chat/prompts_profile.py b/api/app/agents/chat/prompts_profile.py index 5939aac..cc6d757 100644 --- a/api/app/agents/chat/prompts_profile.py +++ b/api/app/agents/chat/prompts_profile.py @@ -1,6 +1,7 @@ """ 用户基础资料收集提示词 """ + from typing import Dict, List, Optional @@ -14,7 +15,9 @@ PROFILE_FIELD_NAMES = { def get_profile_greeting_prompt(missing_fields: List[str], nickname: str = "") -> str: """生成初次见面、收集基础资料的引导提示词""" - missing_names = [PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES] + missing_names = [ + PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES + ] missing_str = "、".join(missing_names) name_part = f",{nickname}" if nickname else "" @@ -54,7 +57,9 @@ def get_profile_extraction_prompt( recent_dialogue: Optional[str] = None, ) -> str: """从用户回答中提取基础资料信息(可包含最近几轮对话,避免漏提)""" - missing_names = {f: PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES} + missing_names = { + f: PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES + } dialogue_section = "" if recent_dialogue and recent_dialogue.strip(): @@ -93,7 +98,9 @@ def get_profile_followup_prompt( nickname: str = "", ) -> str: """在收集资料过程中的跟进提问""" - missing_names = [PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES] + missing_names = [ + PROFILE_FIELD_NAMES[f] for f in missing_fields if f in PROFILE_FIELD_NAMES + ] missing_str = "、".join(missing_names) if missing_names else "无" filled_info = [] diff --git a/api/app/agents/image_prompt/__init__.py b/api/app/agents/image_prompt/__init__.py index 3d26b1b..a79b0c3 100644 --- a/api/app/agents/image_prompt/__init__.py +++ b/api/app/agents/image_prompt/__init__.py @@ -1,4 +1,5 @@ """图片提示词模块:ImagePromptOrchestrator + PromptGenerationAgent""" + from app.agents.image_prompt.orchestrator import ImagePromptOrchestrator from app.agents.image_prompt.prompt_agent import PromptGenerationAgent diff --git a/api/app/agents/image_prompt/orchestrator.py b/api/app/agents/image_prompt/orchestrator.py index 6a879d3..7effb5c 100644 --- a/api/app/agents/image_prompt/orchestrator.py +++ b/api/app/agents/image_prompt/orchestrator.py @@ -3,6 +3,7 @@ ImagePromptOrchestrator:图片提示词生成编排器。 根据调用方(封面/正文)选择 build_prompt 或 build_cover_prompt; 统一异常处理和回退;内部委托 PromptGenerationAgent。 """ + from __future__ import annotations from typing import Any, Optional diff --git a/api/app/agents/image_prompt/prompt_agent.py b/api/app/agents/image_prompt/prompt_agent.py index 5983fad..0c1d509 100644 --- a/api/app/agents/image_prompt/prompt_agent.py +++ b/api/app/agents/image_prompt/prompt_agent.py @@ -4,6 +4,7 @@ PromptGenerationAgent:生成回忆录配图的 image-generation prompt。 调用 LLM 或 fallback 生成 {prompt, style, size}。 底层委托 MemoirImagePromptService,保持对外接口兼容。 """ + from __future__ import annotations from typing import Any, Optional diff --git a/api/app/agents/memoir/__init__.py b/api/app/agents/memoir/__init__.py index 1de5720..cf7e9b5 100644 --- a/api/app/agents/memoir/__init__.py +++ b/api/app/agents/memoir/__init__.py @@ -1,4 +1,5 @@ """回忆录模块:MemoryAgent、BackgroundTaskRunner、MemoirOrchestrator、各 Specialist Agent""" + from app.agents.memoir.memory_agent import MemoryAgent from app.agents.memoir.processor import ( BackgroundTaskRunner, diff --git a/api/app/agents/memoir/classification_agent.py b/api/app/agents/memoir/classification_agent.py index 2ea203b..6b17297 100644 --- a/api/app/agents/memoir/classification_agent.py +++ b/api/app/agents/memoir/classification_agent.py @@ -2,6 +2,7 @@ ClassificationAgent:将内容分类到 8 个章节类别,或判定无价值返回 None。 对应现有逻辑:_classify_chapter_category """ + from __future__ import annotations from typing import Any, Optional @@ -63,7 +64,9 @@ class ClassificationAgent: response = llm.invoke(prompt) category = (response.content or "").strip().lower() if category == "none": - logger.info("LLM 判定内容无回忆录价值,跳过: %s...", (text or "")[:80]) + logger.info( + "LLM 判定内容无回忆录价值,跳过: %s...", (text or "")[:80] + ) return None if category in CHAPTER_CATEGORIES: return category diff --git a/api/app/agents/memoir/extraction_agent.py b/api/app/agents/memoir/extraction_agent.py index 027b9e8..90c2ae7 100644 --- a/api/app/agents/memoir/extraction_agent.py +++ b/api/app/agents/memoir/extraction_agent.py @@ -2,6 +2,7 @@ ExtractionAgent:从用户消息中提取 5-stage 状态与 slots。 对应现有逻辑:get_state_extraction_prompt + JSON 解析 """ + from __future__ import annotations import json @@ -19,6 +20,7 @@ logger = get_logger(__name__) @dataclass class ExtractionResult: """状态提取结果""" + detected_stage: str slots: Dict[str, str] @@ -41,7 +43,9 @@ class ExtractionAgent: extracted_slots: Dict[str, str] = {} if not llm: - return ExtractionResult(detected_stage=detected_stage, slots=extracted_slots) + return ExtractionResult( + detected_stage=detected_stage, slots=extracted_slots + ) try: prompt = get_state_extraction_prompt( @@ -61,8 +65,7 @@ class ExtractionAgent: detected_stage = parsed.get("detected_stage", detected_stage) raw_slots = parsed.get("slots", {}) or {} extracted_slots = { - k: v if isinstance(v, str) else str(v) - for k, v in raw_slots.items() + k: v if isinstance(v, str) else str(v) for k, v in raw_slots.items() } except (json.JSONDecodeError, Exception) as e: logger.warning("ExtractionAgent LLM 解析失败: %s", e) diff --git a/api/app/agents/memoir/memory_agent.py b/api/app/agents/memoir/memory_agent.py index fc82924..ff006a2 100644 --- a/api/app/agents/memoir/memory_agent.py +++ b/api/app/agents/memoir/memory_agent.py @@ -2,6 +2,7 @@ 回忆录整理 Agent:基于传记结构,将口语改写为书面语,归类到章节 支持异步调用 """ + import json from typing import Dict, List, Optional @@ -40,7 +41,9 @@ class MemoryAgent: try: prompt = get_chapter_classification_prompt(segments_text) response = await self.llm.ainvoke(prompt) - content = response.content if hasattr(response, "content") else str(response) + content = ( + response.content if hasattr(response, "content") else str(response) + ) category = content.strip().lower() if category in CHAPTER_CATEGORIES: return category @@ -70,7 +73,9 @@ class MemoryAgent: max_tokens=4096, ) response = await json_llm.ainvoke(prompt) - content = response.content if hasattr(response, "content") else str(response) + content = ( + response.content if hasattr(response, "content") else str(response) + ) content = content.strip() result = json.loads(extract_json_payload(content)) result["content"] = inject_image_placeholder_template( diff --git a/api/app/agents/memoir/narrative_agent.py b/api/app/agents/memoir/narrative_agent.py index fd82ff1..9bd95af 100644 --- a/api/app/agents/memoir/narrative_agent.py +++ b/api/app/agents/memoir/narrative_agent.py @@ -2,6 +2,7 @@ NarrativeAgent:生成创意标题和叙事改写。 对应现有逻辑:get_creative_title_prompt、get_narrative_prompt """ + from __future__ import annotations from typing import Any, Dict, Optional diff --git a/api/app/agents/memoir/orchestrator.py b/api/app/agents/memoir/orchestrator.py index 1870714..7e1d548 100644 --- a/api/app/agents/memoir/orchestrator.py +++ b/api/app/agents/memoir/orchestrator.py @@ -3,6 +3,7 @@ MemoirOrchestrator:按 segment 编排流水线,调用各 Specialist Agent。 负责:遍历 segments、按 category 聚合、调用 Specialist、更新 state; 持久化与章节生成由 process_category 回调完成。 """ + from __future__ import annotations from typing import Any, Callable, Dict, List, Set, Tuple @@ -39,9 +40,7 @@ class MemoirOrchestrator: user_profile: str = "", user_birth_year: Any = None, get_or_create_state: Callable[[], MemoirStateSchema], - update_slot: Callable[ - [str, str, str, List[str]], MemoirStateSchema - ], + update_slot: Callable[[str, str, str, List[str]], MemoirStateSchema], acquire_lock: Callable[[str], bool], release_lock: Callable[[str], None], process_category: Callable[ diff --git a/api/app/agents/memoir/placeholder_agent.py b/api/app/agents/memoir/placeholder_agent.py index 59e0b2c..36ed1b3 100644 --- a/api/app/agents/memoir/placeholder_agent.py +++ b/api/app/agents/memoir/placeholder_agent.py @@ -3,6 +3,7 @@ PlaceholderInjectAgent:对 narrative 做占位符模板注入。 对应现有逻辑:inject_image_placeholder_template 纯函数式,无 LLM 调用。 """ + from app.agents.memoir.prompts import inject_image_placeholder_template diff --git a/api/app/agents/memoir/processor.py b/api/app/agents/memoir/processor.py index 56b9155..994601d 100644 --- a/api/app/agents/memoir/processor.py +++ b/api/app/agents/memoir/processor.py @@ -2,6 +2,7 @@ 回忆录后台处理器:分析对话、更新状态、生成章节、创意标题 使用 Celery 进行后台任务处理 """ + from __future__ import annotations import json @@ -70,9 +71,7 @@ class ContentAnalyzer: async def analyze_message( self, user_message: str, current_state: MemoirStateSchema ) -> AnalysisResult: - detected_stage = self._detect_stage( - user_message, current_state.current_stage - ) + detected_stage = self._detect_stage(user_message, current_state.current_stage) extracted_slots: Dict[str, str] = {} emotion = "neutral" is_new_chapter = False diff --git a/api/app/agents/memoir/prompts.py b/api/app/agents/memoir/prompts.py index edccea9..36aa7b8 100644 --- a/api/app/agents/memoir/prompts.py +++ b/api/app/agents/memoir/prompts.py @@ -1,6 +1,7 @@ """ 回忆录整理 Agent 提示词模板 """ + import json import re from typing import Optional @@ -67,8 +68,13 @@ def inject_image_placeholder_template(content: str) -> str: if not inner: return match.group(0) if inner.startswith(IMAGE_PLACEHOLDER_TEMPLATE): - desc = inner[len(IMAGE_PLACEHOLDER_TEMPLATE):].lstrip("。").strip() - return "{{{{IMAGE:" + IMAGE_PLACEHOLDER_TEMPLATE + ("。" + desc if desc else "") + "}}}}" + desc = inner[len(IMAGE_PLACEHOLDER_TEMPLATE) :].lstrip("。").strip() + return ( + "{{{{IMAGE:" + + IMAGE_PLACEHOLDER_TEMPLATE + + ("。" + desc if desc else "") + + "}}}}" + ) return "{{{{IMAGE:" + IMAGE_PLACEHOLDER_TEMPLATE + "。" + inner + "}}}}" content = _IMAGE_PLACEHOLDER_ANY_BRACES_RE.sub(replace_one, content) @@ -146,10 +152,14 @@ def get_chapter_classification_prompt(segments_text: str) -> str: 如果对话内容中没有任何与人生经历相关的实质内容,返回 none。""" -def get_text_rewrite_prompt(segments_text: str, chapter_category: str, existing_content: str = "") -> str: +def get_text_rewrite_prompt( + segments_text: str, chapter_category: str, existing_content: str = "" +) -> str: """获取文本改写的提示词""" chapter_name = CHAPTER_CATEGORIES.get(chapter_category, chapter_category) - existing_section = f"\n\n已有章节内容:\n{existing_content}" if existing_content else "" + existing_section = ( + f"\n\n已有章节内容:\n{existing_content}" if existing_content else "" + ) return f"""{get_system_prompt()} 请将以下口语化的对话内容改写为书面语,归类到"{chapter_name}"章节。 @@ -181,7 +191,9 @@ def get_text_rewrite_prompt(segments_text: str, chapter_category: str, existing_ {{{{IMAGE:奶奶坐在院子里的藤椅上,手里摇着蒲扇}}}}""" -def get_state_extraction_prompt(user_message: str, current_stage: str, stage_slots: dict) -> str: +def get_state_extraction_prompt( + user_message: str, current_stage: str, stage_slots: dict +) -> str: """抽取结构化信息并判断阶段""" slot_keys = list(stage_slots.keys()) all_stage_slots = { @@ -296,9 +308,19 @@ def get_narrative_prompt( """将新对话改写为叙述(只输出新内容的改写,不重复已有内容)""" context_tail = "" if existing_content: - context_tail = existing_content[-300:] if len(existing_content) > 300 else existing_content - context_section = f"\n\n【衔接上下文(已有内容的末尾,仅供参考衔接,不要重复)】:\n{context_tail}" if context_tail else "" - archived_section = f"\n\n【已删除的该类别历史章节(仅供参考,请勿直接使用或重复)】:\n{archived_summaries}" if archived_summaries else "" + context_tail = ( + existing_content[-300:] if len(existing_content) > 300 else existing_content + ) + context_section = ( + f"\n\n【衔接上下文(已有内容的末尾,仅供参考衔接,不要重复)】:\n{context_tail}" + if context_tail + else "" + ) + archived_section = ( + f"\n\n【已删除的该类别历史章节(仅供参考,请勿直接使用或重复)】:\n{archived_summaries}" + if archived_summaries + else "" + ) profile_section = f"\n\n用户基本信息:\n{user_profile}" if user_profile else "" age_hint = _build_age_hint(stage, birth_year) @@ -366,8 +388,14 @@ def get_narrative_json_prompt( """将新对话改写为叙述,输出 JSON 格式(paragraphs: [{content, image_description}])""" context_tail = "" if existing_content: - context_tail = existing_content[-300:] if len(existing_content) > 300 else existing_content - context_section = f"\n\n【衔接上下文(已有内容的末尾,仅供参考衔接,不要重复)】:\n{context_tail}" if context_tail else "" + context_tail = ( + existing_content[-300:] if len(existing_content) > 300 else existing_content + ) + context_section = ( + f"\n\n【衔接上下文(已有内容的末尾,仅供参考衔接,不要重复)】:\n{context_tail}" + if context_tail + else "" + ) profile_section = f"\n\n用户基本信息:\n{user_profile}" if user_profile else "" age_hint = _build_age_hint(stage, birth_year) time_section = f"\n时间参考:{age_hint}" if age_hint else "" diff --git a/api/app/agents/state_schema.py b/api/app/agents/state_schema.py index de2bd40..8b6a3b4 100644 --- a/api/app/agents/state_schema.py +++ b/api/app/agents/state_schema.py @@ -1,6 +1,7 @@ """ 共享状态 Schema(对话 Agent 与后台 Agent 共用) """ + from __future__ import annotations from typing import Dict, List, Optional @@ -10,12 +11,14 @@ from pydantic import BaseModel, Field class SlotData(BaseModel): """Slot 数据结构""" + snippet: Optional[str] = None segment_ids: List[str] = Field(default_factory=list) class MemoirStateSchema(BaseModel): """回忆录状态""" + stage_order: List[str] current_stage: str covered_stages: List[str] @@ -38,9 +41,7 @@ class MemoirStateSchema(BaseModel): """获取指定阶段已填充的槽位及其内容""" stage_slots = self.slots.get(stage, {}) return { - key: value.snippet - for key, value in stage_slots.items() - if value.snippet + key: value.snippet for key, value in stage_slots.items() if value.snippet } def all_stages_coverage(self) -> Dict[str, Dict]: diff --git a/api/app/core/db.py b/api/app/core/db.py index 7860a52..e2fbb13 100644 --- a/api/app/core/db.py +++ b/api/app/core/db.py @@ -33,6 +33,7 @@ def utc_now(): # ── Database URL(纯 postgresql:// 时拼接为 postgresql+psycopg://)──────────────── + def ensure_psycopg_url(url: str) -> str: """若为 postgresql://... 则改为 postgresql+psycopg://...,否则原样返回。""" if url.startswith("postgresql://") and not url.startswith("postgresql+psycopg://"): diff --git a/api/app/core/logging.py b/api/app/core/logging.py index a6487df..e0f2b66 100644 --- a/api/app/core/logging.py +++ b/api/app/core/logging.py @@ -22,7 +22,9 @@ class InterceptHandler(logging.Handler): frame = frame.f_back depth += 1 - logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) def setup_logging() -> None: diff --git a/api/app/core/pagination.py b/api/app/core/pagination.py index 2e61e5e..5c1d449 100644 --- a/api/app/core/pagination.py +++ b/api/app/core/pagination.py @@ -2,6 +2,7 @@ 通用分页:offset/limit 参数与分页结果结构。 各 feature 的 router 按需使用,成功响应直接返回 Pydantic model,不强制包装。 """ + from typing import Generic, TypeVar from pydantic import BaseModel diff --git a/api/app/core/redis.py b/api/app/core/redis.py index e6543a9..87bd957 100644 --- a/api/app/core/redis.py +++ b/api/app/core/redis.py @@ -2,6 +2,7 @@ Redis 客户端与会话/缓存能力:供应用生命周期、会话历史、任务追踪等使用。 配置从 app.core.config.settings 读取,禁止业务层散落 os.getenv。 """ + import json from app.core.logging import get_logger from datetime import datetime, timezone @@ -47,7 +48,9 @@ class RedisService: def _conversation_key(self, conversation_id: str) -> str: return f"conversation:history:{conversation_id}" - async def get_conversation_history(self, conversation_id: str) -> List[Dict[str, Any]]: + async def get_conversation_history( + self, conversation_id: str + ) -> List[Dict[str, Any]]: try: client = await self.get_client() key = self._conversation_key(conversation_id) @@ -81,7 +84,9 @@ class RedisService: if voice_session_id: item["voiceSessionId"] = voice_session_id history.append(item) - await client.setex(key, self.session_ttl, json.dumps(history, ensure_ascii=False)) + await client.setex( + key, self.session_ttl, json.dumps(history, ensure_ascii=False) + ) return True except Exception as e: logger.error("添加消息失败: %s", e) @@ -110,7 +115,11 @@ class RedisService: async def set_cache(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: try: client = await self.get_client() - data = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value + data = ( + json.dumps(value, ensure_ascii=False) + if not isinstance(value, str) + else value + ) if ttl: await client.setex(key, ttl, data) else: diff --git a/api/app/core/task_tracker.py b/api/app/core/task_tracker.py index bca4e83..73c8262 100644 --- a/api/app/core/task_tracker.py +++ b/api/app/core/task_tracker.py @@ -1,6 +1,7 @@ """ 任务追踪服务:追踪 Celery 任务状态(从 services 迁入 core) """ + import json from app.core.logging import get_logger from datetime import datetime, timezone @@ -17,7 +18,9 @@ class TaskTracker: KEY_PREFIX = "task:user:" TASK_TTL = 3600 - async def add_task(self, user_id: str, task_id: str, task_type: str = "memoir") -> bool: + async def add_task( + self, user_id: str, task_id: str, task_type: str = "memoir" + ) -> bool: try: client = await redis_service.get_client() key = f"{self.KEY_PREFIX}{user_id}:tasks" diff --git a/api/app/features/auth/repo.py b/api/app/features/auth/repo.py index cbade8f..75a3193 100644 --- a/api/app/features/auth/repo.py +++ b/api/app/features/auth/repo.py @@ -51,7 +51,9 @@ async def create_refresh_token(token: RefreshToken, db: AsyncSession) -> None: # ── SMS verification code ───────────────────────────────────── -async def create_verification_code(record: SmsVerificationCode, db: AsyncSession) -> None: +async def create_verification_code( + record: SmsVerificationCode, db: AsyncSession +) -> None: db.add(record) diff --git a/api/app/features/auth/schemas.py b/api/app/features/auth/schemas.py index 58ab680..b940b49 100644 --- a/api/app/features/auth/schemas.py +++ b/api/app/features/auth/schemas.py @@ -39,7 +39,9 @@ class UserResponse(BaseModel): class SendSmsRequest(BaseModel): phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") - purpose: str = Field(..., description="用途:register/login/reset_password/change_phone") + purpose: str = Field( + ..., description="用途:register/login/reset_password/change_phone" + ) class SendSmsResponse(BaseModel): @@ -52,7 +54,9 @@ class SmsLoginRequest(BaseModel): phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") - nickname: Optional[str] = Field(None, max_length=50, description="昵称(注册时必填,登录时可选)") + nickname: Optional[str] = Field( + None, max_length=50, description="昵称(注册时必填,登录时可选)" + ) class SmsRegisterRequest(BaseModel): @@ -76,12 +80,16 @@ class ChangePasswordRequest(BaseModel): class ChangePhoneRequest(BaseModel): - new_phone: str = Field(..., min_length=11, max_length=11, description="新手机号(11位)") + new_phone: str = Field( + ..., min_length=11, max_length=11, description="新手机号(11位)" + ) code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") class UpdateNicknameRequest(BaseModel): - nickname: str = Field(..., min_length=1, max_length=50, description="昵称(1-50个字符)") + nickname: str = Field( + ..., min_length=1, max_length=50, description="昵称(1-50个字符)" + ) class AvatarUploadResponse(BaseModel): diff --git a/api/app/features/auth/service.py b/api/app/features/auth/service.py index 32ae254..f412350 100644 --- a/api/app/features/auth/service.py +++ b/api/app/features/auth/service.py @@ -162,9 +162,7 @@ class AuthService: device_info: str = "", ) -> dict: """Refresh access token. Returns {access_token, refresh_token}.""" - token_record = await repo.get_refresh_token_by_token( - refresh_token, self._db - ) + token_record = await repo.get_refresh_token_by_token(refresh_token, self._db) if not token_record: raise AuthError("无效的刷新令牌", "INVALID_TOKEN") @@ -186,9 +184,7 @@ class AuthService: async def logout(self, refresh_token: str, user_id: str) -> None: """Revoke a refresh token owned by the given user.""" - token_record = await repo.get_refresh_token_by_token( - refresh_token, self._db - ) + token_record = await repo.get_refresh_token_by_token(refresh_token, self._db) if token_record and token_record.user_id == user_id: token_record.is_revoked = True await self._db.commit() @@ -284,9 +280,7 @@ class AuthService: new_password: str, ) -> None: """Reset password via SMS code.""" - success, message = await self._verify_sms_code( - phone, code, "reset_password" - ) + success, message = await self._verify_sms_code(phone, code, "reset_password") if not success: raise AuthError(message, "INVALID_SMS_CODE") @@ -321,9 +315,7 @@ class AuthService: code: str, ) -> User: """Change phone number via SMS code. Returns updated user.""" - success, message = await self._verify_sms_code( - new_phone, code, "change_phone" - ) + success, message = await self._verify_sms_code(new_phone, code, "change_phone") if not success: raise AuthError(message, "INVALID_SMS_CODE") diff --git a/api/app/features/content/router.py b/api/app/features/content/router.py index be973b5..7cd8aa3 100644 --- a/api/app/features/content/router.py +++ b/api/app/features/content/router.py @@ -1,6 +1,7 @@ """ 静态内容路由:FAQ、法律文档、官网主页。 """ + from pathlib import Path from typing import List diff --git a/api/app/features/conversation/repo.py b/api/app/features/conversation/repo.py index 4b19600..a2e35c8 100644 --- a/api/app/features/conversation/repo.py +++ b/api/app/features/conversation/repo.py @@ -6,7 +6,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.features.conversation.models import Conversation, Segment -async def get_conversation(conversation_id: str, db: AsyncSession) -> Conversation | None: +async def get_conversation( + conversation_id: str, db: AsyncSession +) -> Conversation | None: return await db.get(Conversation, conversation_id) @@ -14,7 +16,9 @@ async def get_user_conversations(user_id: str, db: AsyncSession) -> list[Convers stmt = ( select(Conversation) .where(Conversation.user_id == user_id) - .order_by(func.coalesce(Conversation.last_message_at, Conversation.started_at).desc()) + .order_by( + func.coalesce(Conversation.last_message_at, Conversation.started_at).desc() + ) ) result = await db.execute(stmt) return list(result.scalars().all()) @@ -24,7 +28,9 @@ def add_conversation(conv: Conversation, db: AsyncSession) -> None: db.add(conv) -async def get_segments_for_conversation(conversation_id: str, db: AsyncSession) -> list[Segment]: +async def get_segments_for_conversation( + conversation_id: str, db: AsyncSession +) -> list[Segment]: stmt = ( select(Segment) .where(Segment.conversation_id == conversation_id) @@ -34,7 +40,9 @@ async def get_segments_for_conversation(conversation_id: str, db: AsyncSession) return list(result.scalars().all()) -async def get_segments_for_organize(conversation_id: str, db: AsyncSession) -> list[Segment]: +async def get_segments_for_organize( + conversation_id: str, db: AsyncSession +) -> list[Segment]: """Unprocessed segments first; if none, all segments.""" stmt = ( select(Segment) diff --git a/api/app/features/conversation/router.py b/api/app/features/conversation/router.py index a2c7ee6..9301ea0 100644 --- a/api/app/features/conversation/router.py +++ b/api/app/features/conversation/router.py @@ -1,6 +1,7 @@ """ 对话 feature — conversations 路由 """ + from app.core.logging import get_logger from fastapi import APIRouter, Depends, HTTPException diff --git a/api/app/features/conversation/schemas.py b/api/app/features/conversation/schemas.py index 20d8659..24e3749 100644 --- a/api/app/features/conversation/schemas.py +++ b/api/app/features/conversation/schemas.py @@ -1,4 +1,3 @@ - from pydantic import BaseModel diff --git a/api/app/features/conversation/service.py b/api/app/features/conversation/service.py index 9133244..7f3c8b6 100644 --- a/api/app/features/conversation/service.py +++ b/api/app/features/conversation/service.py @@ -30,7 +30,8 @@ def _message_timestamp_ms(msg: dict, fallback: datetime | None) -> int: if isinstance(raw_timestamp, str): try: return int( - datetime.fromisoformat(raw_timestamp.replace("Z", "+00:00")).timestamp() * 1000 + datetime.fromisoformat(raw_timestamp.replace("Z", "+00:00")).timestamp() + * 1000 ) except ValueError: pass @@ -60,14 +61,16 @@ def _build_messages_from_history( if voice_session_id in seen_audio_sessions: continue seen_audio_sessions.add(voice_session_id) - messages.append({ - "id": f"{conversation_id}_msg_{idx}", - "conversationId": conversation_id, - "content": msg.get("content", ""), - "senderType": "user" if role == "human" else "assistant", - "timestamp": _message_timestamp_ms(msg, fallback_timestamp), - "messageType": message_type, - }) + messages.append( + { + "id": f"{conversation_id}_msg_{idx}", + "conversationId": conversation_id, + "content": msg.get("content", ""), + "senderType": "user" if role == "human" else "assistant", + "timestamp": _message_timestamp_ms(msg, fallback_timestamp), + "messageType": message_type, + } + ) return messages @@ -99,15 +102,17 @@ class ConversationService: except Exception: pass latest_message = history[-1].get("content", "")[:50] if history else None - result.append({ - "id": conv.id, - "title": (conv.summary or "")[:30] or "岁月知己", - "avatarUrl": None, - "latestMessagePreview": latest_message or conv.summary, - "latestMessageTime": _latest_message_time_ms(conv, history), - "unreadCount": 0, - "isDefaultAssistant": conv.summary is None, - }) + result.append( + { + "id": conv.id, + "title": (conv.summary or "")[:30] or "岁月知己", + "avatarUrl": None, + "latestMessagePreview": latest_message or conv.summary, + "latestMessageTime": _latest_message_time_ms(conv, history), + "unreadCount": 0, + "isDefaultAssistant": conv.summary is None, + } + ) return result async def create(self, user_id: str) -> dict: @@ -181,7 +186,9 @@ class ConversationService: except Exception: return [] - async def organize(self, conversation_id: str, user_id: str, subscription_type: str) -> dict: + async def organize( + self, conversation_id: str, user_id: str, subscription_type: str + ) -> dict: conv = await self.get_or_404(conversation_id, user_id) segments = await repo.get_segments_for_organize(conversation_id, self._db) if not segments: diff --git a/api/app/features/conversation/ws/connection_manager.py b/api/app/features/conversation/ws/connection_manager.py index 45fd325..7733575 100644 --- a/api/app/features/conversation/ws/connection_manager.py +++ b/api/app/features/conversation/ws/connection_manager.py @@ -1,4 +1,5 @@ """WebSocket 连接管理器:仅负责连接注册/注销和消息收发""" + from app.core.logging import get_logger from typing import Dict diff --git a/api/app/features/conversation/ws/message_types.py b/api/app/features/conversation/ws/message_types.py index 91a1cf8..8422cec 100644 --- a/api/app/features/conversation/ws/message_types.py +++ b/api/app/features/conversation/ws/message_types.py @@ -1,4 +1,5 @@ """WebSocket 消息类型定义""" + from enum import Enum LEGACY_VOICE_SESSION_ID = "legacy" @@ -6,6 +7,7 @@ LEGACY_VOICE_SESSION_ID = "legacy" class MessageType(str, Enum): """WebSocket 消息类型""" + CONNECT = "connect" RECORDING_STARTED = "recording_started" AUDIO_CHUNK = "audio_chunk" diff --git a/api/app/features/conversation/ws/pipeline.py b/api/app/features/conversation/ws/pipeline.py index 7076a7d..1622802 100644 --- a/api/app/features/conversation/ws/pipeline.py +++ b/api/app/features/conversation/ws/pipeline.py @@ -1,4 +1,5 @@ """核心消息处理管道:Agent 调用、ASR 转写、分段有序聚合""" + import asyncio import base64 from app.core.logging import get_logger @@ -19,7 +20,10 @@ from app.agents.memoir import BackgroundTaskRunner from app.core.db import AsyncSessionLocal from app.features.conversation.models import Conversation, Segment from app.features.conversation.ws.connection_manager import manager -from app.features.conversation.ws.message_types import LEGACY_VOICE_SESSION_ID, MessageType +from app.features.conversation.ws.message_types import ( + LEGACY_VOICE_SESSION_ID, + MessageType, +) from app.features.conversation.ws.profile_collector import ( apply_extracted_profile, get_filled_profile_fields, @@ -42,15 +46,18 @@ async def _send_tts_audio(conversation_id: str, text: str) -> None: "TTS skipped: synthesize returned empty. Check TTS config in .env" ) return - await manager.send_message(conversation_id, { - "type": MessageType.TTS_AUDIO, - "conversation_id": conversation_id, - "data": { - "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"), - "format": settings.tts_codec, + await manager.send_message( + conversation_id, + { + "type": MessageType.TTS_AUDIO, + "conversation_id": conversation_id, + "data": { + "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"), + "format": settings.tts_codec, + }, + "timestamp": datetime.now(timezone.utc).isoformat(), }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + ) except Exception as e: err_str = str(e) if "PkgExhausted" in err_str: @@ -61,6 +68,7 @@ async def _send_tts_audio(conversation_id: str, text: str) -> None: else: logger.error("TTS synthesize failed: %s", e) + # ── Agent 实例(从 ConnectionManager 移出) ───────────────────── conversation_agent = ConversationAgent() chat_orchestrator = ChatOrchestrator() @@ -70,6 +78,7 @@ background_runner = BackgroundTaskRunner() # ── 分段流状态 ────────────────────────────────────────────────── + @dataclass class SegmentStreamState: """会话内分段处理状态(用于并行 ASR + 有序聚合)""" @@ -136,11 +145,14 @@ def cleanup_segment_states(conversation_id: str) -> None: # ── 工具函数 ──────────────────────────────────────────────────── + def _utc_now() -> datetime: return datetime.now(timezone.utc) -def _mark_conversation_active(conversation: Conversation, at: Optional[datetime] = None) -> datetime: +def _mark_conversation_active( + conversation: Conversation, at: Optional[datetime] = None +) -> datetime: activity_time = at or _utc_now() conversation.last_message_at = activity_time return activity_time @@ -152,7 +164,9 @@ def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str: return LEGACY_VOICE_SESSION_ID -def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]: +def _voice_session_id_from_client_segment_id( + client_segment_id: Optional[str], +) -> Optional[str]: if not client_segment_id: return None session_id, separator, _ = client_segment_id.rpartition("-") @@ -171,11 +185,14 @@ def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int] prefix = "audio-segment:" if not audio_url or not audio_url.startswith(prefix): return None - payload = audio_url[len(prefix):] + payload = audio_url[len(prefix) :] voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":") try: if separator: - return (_normalize_voice_session_id(voice_session_id_raw), int(segment_index_raw)) + return ( + _normalize_voice_session_id(voice_session_id_raw), + int(segment_index_raw), + ) return (LEGACY_VOICE_SESSION_ID, int(payload)) except ValueError: return None @@ -201,14 +218,21 @@ async def _find_existing_segment_by_index( segment_index: int, ) -> Optional[Segment]: segment_audio_url = _build_segment_audio_url(voice_session_id, segment_index) - stmt = select(Segment).where( - Segment.conversation_id == conversation_id, - Segment.audio_url == segment_audio_url, - ).order_by(Segment.created_at.desc()) + stmt = ( + select(Segment) + .where( + Segment.conversation_id == conversation_id, + Segment.audio_url == segment_audio_url, + ) + .order_by(Segment.created_at.desc()) + ) result = await db.execute(stmt) candidates = result.scalars().all() for item in candidates: - if item.conversation_id == conversation_id and item.audio_url == segment_audio_url: + if ( + item.conversation_id == conversation_id + and item.audio_url == segment_audio_url + ): return item return None @@ -252,16 +276,19 @@ async def _send_segment_transition_feedback( segment_index: int, ) -> None: """发送一次「我在认真听」陪伴式过渡反馈(由延迟任务调用)。""" - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": { - "text": LISTENING_FEEDBACK_TEXT, - "transition": True, - "segment_index": segment_index, + await manager.send_message( + conversation_id, + { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": { + "text": LISTENING_FEEDBACK_TEXT, + "transition": True, + "segment_index": segment_index, + }, + "timestamp": datetime.now(timezone.utc).isoformat(), }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + ) async def _delayed_listening_feedback( @@ -281,6 +308,7 @@ async def _delayed_listening_feedback( # ── 分段语音异步处理 ──────────────────────────────────────────── + async def process_audio_segment( conversation_id: str, user_id: str, @@ -298,18 +326,24 @@ async def process_audio_segment( conversation = await db.get(Conversation, conversation_id) user = await db.get(User, user_id) if not conversation: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "对话不存在,分段处理已取消"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": "对话不存在,分段处理已取消"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) return if not user: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "用户不存在,分段处理已取消"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": "用户不存在,分段处理已取消"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) return async with state.lock: @@ -320,14 +354,18 @@ async def process_audio_segment( ) if should_prime_state: - persisted_contiguous_index = await _get_persisted_contiguous_segment_index( - db=db, - conversation_id=conversation_id, - voice_session_id=voice_session_id, + persisted_contiguous_index = ( + await _get_persisted_contiguous_segment_index( + db=db, + conversation_id=conversation_id, + voice_session_id=voice_session_id, + ) ) if persisted_contiguous_index >= 0: async with state.lock: - state.consumed_index = max(state.consumed_index, persisted_contiguous_index) + state.consumed_index = max( + state.consumed_index, persisted_contiguous_index + ) try: audio_bytes = base64.b64decode(audio_base64) @@ -336,28 +374,34 @@ async def process_audio_segment( transcript_text = await get_asr_provider().transcribe( audio_bytes, format="m4a" ) - await manager.send_message(conversation_id, { - "type": MessageType.TRANSCRIPT, - "conversation_id": conversation_id, - "data": { - "text": transcript_text or "", - "audio_duration": audio_duration, - "voice_session_id": voice_session_id, - "segment_index": segment_index, - "is_last": is_last, - }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - - if _is_transcribe_failure(transcript_text): - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, + await manager.send_message( + conversation_id, + { + "type": MessageType.TRANSCRIPT, + "conversation_id": conversation_id, "data": { - "message": f"分段 {segment_index} 转写失败,请重试该片段", + "text": transcript_text or "", + "audio_duration": audio_duration, + "voice_session_id": voice_session_id, "segment_index": segment_index, + "is_last": is_last, }, "timestamp": datetime.now(timezone.utc).isoformat(), - }) + }, + ) + + if _is_transcribe_failure(transcript_text): + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": { + "message": f"分段 {segment_index} 转写失败,请重试该片段", + "segment_index": segment_index, + }, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) return existing_segment = await _find_existing_segment_by_index( @@ -391,7 +435,10 @@ async def process_audio_segment( ready_segments: List[Tuple[int, str, Segment]] = [] async with state.lock: state.processed_indices.add(segment_index) - state.buffered_transcripts[segment_index] = (transcript_text or "", segment) + state.buffered_transcripts[segment_index] = ( + transcript_text or "", + segment, + ) next_index = state.consumed_index + 1 while next_index in state.buffered_transcripts: @@ -408,7 +455,8 @@ async def process_audio_segment( segment=ordered_segment, db=db, user=user, - user_message_timestamp=ordered_segment.created_at or user_message_timestamp, + user_message_timestamp=ordered_segment.created_at + or user_message_timestamp, ) except Exception as e: @@ -416,14 +464,17 @@ async def process_audio_segment( f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}", exc_info=True, ) - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": { - "message": f"分段处理失败: {str(e)}", - "segment_index": segment_index, + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": { + "message": f"分段处理失败: {str(e)}", + "segment_index": segment_index, + }, + "timestamp": datetime.now(timezone.utc).isoformat(), }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + ) finally: async with state.lock: state.pending_indices.discard(segment_index) @@ -431,6 +482,7 @@ async def process_audio_segment( # ── 用户消息处理 ──────────────────────────────────────────────── + async def process_user_message( conversation_id: str, user_message: str, @@ -463,12 +515,19 @@ async def process_user_message( await db.commit() for i, response_text in enumerate(responses): - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": {"text": response_text, "index": i, "total": len(responses)}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": { + "text": response_text, + "index": i, + "total": len(responses), + }, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) await _send_tts_audio(conversation_id, response_text) if i < len(responses) - 1: await asyncio.sleep(0.5) @@ -477,17 +536,21 @@ async def process_user_message( logger.error(f"处理用户消息失败: {e}", exc_info=True) if conversation_id in manager.active_connections: try: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": f"生成回应失败: {str(e)}"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": f"生成回应失败: {str(e)}"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) except Exception as send_error: logger.warning(f"发送错误消息失败: {send_error}") # ── 对话结束处理 ──────────────────────────────────────────────── + async def process_conversation_segments( conversation_id: str, db: AsyncSession, quota_service: "QuotaService" ): @@ -528,8 +591,11 @@ async def process_conversation_segments( segment_ids = [seg.id for seg in segments] try: from app.tasks.memoir_tasks import process_memoir_segments + process_memoir_segments.delay(conversation.user_id, segment_ids) - logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}") + logger.info( + f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}" + ) except Exception as e: logger.error(f"提交 Celery 任务失败: {e}") diff --git a/api/app/features/conversation/ws/profile_collector.py b/api/app/features/conversation/ws/profile_collector.py index 6b178e9..2ad2376 100644 --- a/api/app/features/conversation/ws/profile_collector.py +++ b/api/app/features/conversation/ws/profile_collector.py @@ -1,4 +1,5 @@ """用户资料收集:缺失字段检测、提取与应用""" + from sqlalchemy.ext.asyncio import AsyncSession from app.features.user.models import User @@ -6,7 +7,10 @@ from app.features.user.models import User def get_missing_profile_fields(user: User) -> list: """检查用户缺失的资料字段""" - from app.agents.chat.prompts_profile import get_missing_profile_fields as _get_missing + from app.agents.chat.prompts_profile import ( + get_missing_profile_fields as _get_missing, + ) + return _get_missing( birth_year=user.birth_year, birth_place=user.birth_place, diff --git a/api/app/features/conversation/ws/quota_guard.py b/api/app/features/conversation/ws/quota_guard.py index 47983b3..3a22661 100644 --- a/api/app/features/conversation/ws/quota_guard.py +++ b/api/app/features/conversation/ws/quota_guard.py @@ -1,4 +1,5 @@ """WebSocket 配额检查:通过注入 QuotaService,不直接 import quota 内部函数。""" + from app.features.quota.service import QuotaService diff --git a/api/app/features/conversation/ws/router.py b/api/app/features/conversation/ws/router.py index 3621184..7463b10 100644 --- a/api/app/features/conversation/ws/router.py +++ b/api/app/features/conversation/ws/router.py @@ -2,6 +2,7 @@ WebSocket 路由:实时对话通信 仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块 """ + import asyncio from app.core.logging import get_logger import uuid @@ -57,23 +58,31 @@ async def websocket_endpoint( """ token = websocket.query_params.get("token") if not token: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌" + ) return payload = verify_token(token) if not payload or payload.get("type") != "access": - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌" + ) return user_id = payload.get("sub") if not user_id: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容" + ) return async with AsyncSessionLocal() as db: user = await db.get(User, user_id) if not user: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在" + ) return await manager.connect(websocket, conversation_id) @@ -81,12 +90,15 @@ async def websocket_endpoint( quota_service = QuotaService(db=db) try: - await manager.send_message(conversation_id, { - "type": MessageType.CONNECT, - "conversation_id": conversation_id, - "data": {"status": "connected"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.CONNECT, + "conversation_id": conversation_id, + "data": {"status": "connected"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) conversation = await db.get(Conversation, conversation_id) if not conversation: @@ -101,14 +113,19 @@ async def websocket_endpoint( else: if conversation.user_id != user_id: try: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "无权访问此对话"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": "无权访问此对话"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) except Exception: pass - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话" + ) return history = await redis_service.get_conversation_history(conversation_id) @@ -122,12 +139,19 @@ async def websocket_endpoint( nickname=user.nickname or "", ) for i, text in enumerate(greetings): - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": {"text": text, "index": i, "total": len(greetings)}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": { + "text": text, + "index": i, + "total": len(greetings), + }, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) if i < len(greetings) - 1: await asyncio.sleep(0.5) except Exception as e: @@ -141,18 +165,27 @@ async def websocket_endpoint( grew_up_place=user.grew_up_place, occupation=user.occupation, ) - opening_messages = await conversation_agent.generate_opening_message( - conversation_id=conversation_id, - memoir_state=state, - user_profile_context=user_profile_context, + opening_messages = ( + await conversation_agent.generate_opening_message( + conversation_id=conversation_id, + memoir_state=state, + user_profile_context=user_profile_context, + ) ) for i, text in enumerate(opening_messages): - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": {"text": text, "index": i, "total": len(opening_messages)}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": { + "text": text, + "index": i, + "total": len(opening_messages), + }, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) if i < len(opening_messages) - 1: await asyncio.sleep(0.5) except Exception as e: @@ -161,7 +194,9 @@ async def websocket_endpoint( while True: try: if websocket.application_state != WebSocketState.CONNECTED: - logger.info(f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}") + logger.info( + f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}" + ) break message = await websocket.receive_json() msg_type = message.get("type") @@ -170,13 +205,23 @@ async def websocket_endpoint( text_message = message.get("data", {}).get("text", "") if text_message: - can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type) + can_send, quota_msg = await check_ws_quota( + quota_service, user_id, user.subscription_type + ) if not can_send: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": { + "message": quota_msg, + "code": "QUOTA_EXCEEDED", + }, + "timestamp": datetime.now( + timezone.utc + ).isoformat(), + }, + ) continue segment = Segment( @@ -186,10 +231,14 @@ async def websocket_endpoint( processed=False, ) db.add(segment) - user_message_timestamp = _mark_conversation_active(conversation) + user_message_timestamp = _mark_conversation_active( + conversation + ) await db.commit() await db.refresh(segment) - await background_runner.queue_message(conversation.user_id, segment.id) + await background_runner.queue_message( + conversation.user_id, segment.id + ) await process_user_message( conversation_id=conversation_id, @@ -198,18 +247,24 @@ async def websocket_endpoint( segment=segment, db=db, user=user, - user_message_timestamp=segment.created_at or user_message_timestamp, + user_message_timestamp=segment.created_at + or user_message_timestamp, ) elif msg_type == MessageType.RECORDING_STARTED: data = message.get("data", {}) - voice_session_id = _normalize_voice_session_id(data.get("voice_session_id")) + voice_session_id = _normalize_voice_session_id( + data.get("voice_session_id") + ) segment_state = get_or_create_segment_state( conversation_id, voice_session_id, ) async with segment_state.lock: - if segment_state.listening_feedback_task is not None and not segment_state.listening_feedback_task.done(): + if ( + segment_state.listening_feedback_task is not None + and not segment_state.listening_feedback_task.done() + ): continue if segment_state.listening_feedback_sent: continue @@ -227,52 +282,74 @@ async def websocket_endpoint( segment_index_raw = data.get("segment_index") voice_session_id = _normalize_voice_session_id( data.get("voice_session_id") - or _voice_session_id_from_client_segment_id(data.get("client_segment_id")) + or _voice_session_id_from_client_segment_id( + data.get("client_segment_id") + ) ) is_last = bool(data.get("is_last", False)) audio_duration = int(data.get("duration", 0) or 0) if not audio_base64: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "缺少 audio_base64"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": "缺少 audio_base64"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) continue if segment_index_raw is None: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "缺少 segment_index"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": "缺少 segment_index"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) continue try: segment_index = int(segment_index_raw) except (TypeError, ValueError): - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "segment_index 必须为整数"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": "segment_index 必须为整数"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) continue if segment_index < 0: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "segment_index 不能为负数"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": "segment_index 不能为负数"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) continue - can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type) + can_send, quota_msg = await check_ws_quota( + quota_service, user_id, user.subscription_type + ) if not can_send: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": { + "message": quota_msg, + "code": "QUOTA_EXCEEDED", + }, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) continue segment_state = get_or_create_segment_state( @@ -323,13 +400,23 @@ async def websocket_endpoint( audio_duration = data.get("duration", 0) if audio_base64: - can_send, quota_msg = await check_ws_quota(quota_service, user_id, user.subscription_type) + can_send, quota_msg = await check_ws_quota( + quota_service, user_id, user.subscription_type + ) if not can_send: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": { + "message": quota_msg, + "code": "QUOTA_EXCEEDED", + }, + "timestamp": datetime.now( + timezone.utc + ).isoformat(), + }, + ) continue logger.info(f"收到音频消息,时长: {audio_duration}s") @@ -337,18 +424,25 @@ async def websocket_endpoint( try: asr = get_asr_provider() audio_bytes = base64.b64decode(audio_base64) - transcript_text = await asr.transcribe(audio_bytes, "m4a") + transcript_text = await asr.transcribe( + audio_bytes, "m4a" + ) logger.info("ASR 转写结果: %s", transcript_text) - await manager.send_message(conversation_id, { - "type": MessageType.TRANSCRIPT, - "conversation_id": conversation_id, - "data": { - "text": transcript_text, - "audio_duration": audio_duration, + await manager.send_message( + conversation_id, + { + "type": MessageType.TRANSCRIPT, + "conversation_id": conversation_id, + "data": { + "text": transcript_text, + "audio_duration": audio_duration, + }, + "timestamp": datetime.now( + timezone.utc + ).isoformat(), }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + ) segment = Segment( id=str(uuid.uuid4()), @@ -358,12 +452,18 @@ async def websocket_endpoint( processed=False, ) db.add(segment) - user_message_timestamp = _mark_conversation_active(conversation) + user_message_timestamp = _mark_conversation_active( + conversation + ) await db.commit() await db.refresh(segment) - await background_runner.queue_message(conversation.user_id, segment.id) + await background_runner.queue_message( + conversation.user_id, segment.id + ) - if transcript_text and not transcript_text.startswith("转写失败"): + if transcript_text and not transcript_text.startswith( + "转写失败" + ): await process_user_message( conversation_id=conversation_id, user_message=transcript_text, @@ -371,99 +471,141 @@ async def websocket_endpoint( segment=segment, db=db, user=user, - user_message_timestamp=segment.created_at or user_message_timestamp, + user_message_timestamp=segment.created_at + or user_message_timestamp, ) else: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "语音转写失败,请重试或使用文字输入"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": { + "message": "语音转写失败,请重试或使用文字输入" + }, + "timestamp": datetime.now( + timezone.utc + ).isoformat(), + }, + ) except Exception as e: logger.error(f"处理音频消息失败: {e}", exc_info=True) - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": f"处理音频消息失败: {str(e)}"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": { + "message": f"处理音频消息失败: {str(e)}" + }, + "timestamp": datetime.now( + timezone.utc + ).isoformat(), + }, + ) elif msg_type == MessageType.TRANSCRIBE_ONLY: data = message.get("data", {}) audio_base64 = data.get("audio_base64", "") if not audio_base64: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "缺少 audio_base64"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": "缺少 audio_base64"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) continue try: asr = get_asr_provider() audio_bytes = base64.b64decode(audio_base64) transcript_text = await asr.transcribe(audio_bytes, "m4a") - await manager.send_message(conversation_id, { - "type": MessageType.TRANSCRIPT, - "conversation_id": conversation_id, - "data": {"text": transcript_text or ""}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.TRANSCRIPT, + "conversation_id": conversation_id, + "data": {"text": transcript_text or ""}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) except Exception as e: logger.error(f"仅转写失败: {e}", exc_info=True) - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": f"转写失败: {str(e)}"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": f"转写失败: {str(e)}"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) elif msg_type == MessageType.END_CONVERSATION: conversation.status = "ended" conversation.ended_at = datetime.now(timezone.utc) await db.commit() - await process_conversation_segments(conversation_id, db, quota_service) + await process_conversation_segments( + conversation_id, db, quota_service + ) - await manager.send_message(conversation_id, { - "type": MessageType.END_CONVERSATION, - "conversation_id": conversation_id, - "data": {"status": "ended"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.END_CONVERSATION, + "conversation_id": conversation_id, + "data": {"status": "ended"}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) break except RuntimeError as e: error_msg = str(e) if ( "disconnect" in error_msg.lower() - or "Cannot call \"receive\"" in error_msg - or "accept" in error_msg.lower() and "not connected" in error_msg.lower() + or 'Cannot call "receive"' in error_msg + or "accept" in error_msg.lower() + and "not connected" in error_msg.lower() ): - logger.info(f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}") + logger.info( + f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}" + ) break else: logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True) if conversation_id in manager.active_connections: try: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": str(e)}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": str(e)}, + "timestamp": datetime.now( + timezone.utc + ).isoformat(), + }, + ) except Exception as send_error: logger.warning(f"发送错误消息失败: {send_error}") break except WebSocketDisconnect: - logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}") + logger.info( + f"WebSocket 断开连接: conversation_id={conversation_id}" + ) break except Exception as e: logger.error(f"处理消息时发生错误: {e}", exc_info=True) if conversation_id in manager.active_connections: try: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": str(e)}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) + await manager.send_message( + conversation_id, + { + "type": MessageType.ERROR, + "data": {"message": str(e)}, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) except Exception as send_error: logger.warning(f"发送错误消息失败: {send_error}") break diff --git a/api/app/features/memoir/helpers.py b/api/app/features/memoir/helpers.py index e872361..3867606 100644 --- a/api/app/features/memoir/helpers.py +++ b/api/app/features/memoir/helpers.py @@ -1,6 +1,7 @@ """ 回忆录序列化与图片归一化辅助(供 MemoirService 使用)。 """ + from app.core.logging import get_logger from app.core.config import settings @@ -46,7 +47,9 @@ def normalize_image_assets_for_api(images: list[dict] | None) -> list[dict]: except CosDownloadUrlError as exc: logger.warning( "章节图片签名失败: key=%s, retryable=%s, error=%s", - storage_key, exc.retryable, exc, + storage_key, + exc.retryable, + exc, ) asset = mark_image_delivery_unavailable(asset) except Exception as exc: @@ -112,20 +115,18 @@ def chapter_to_dict(ch: Chapter) -> dict: content, images_list = sections_to_content_and_images(ch) normalized_images = normalize_image_assets_for_api(images_list) cover = chapter_cover_to_dict(ch) - cover_normalized = ( - normalize_image_assets_for_api([cover])[0] if cover else None - ) + cover_normalized = normalize_image_assets_for_api([cover])[0] if cover else None sections_data = [] if getattr(ch, "sections", None): for s in sorted(ch.sections, key=lambda x: getattr(x, "order_index", 0)): sec_img = section_image_to_dict(s) - sec_img = ( - normalize_image_assets_for_api([sec_img])[0] if sec_img else None + sec_img = normalize_image_assets_for_api([sec_img])[0] if sec_img else None + sections_data.append( + { + "content": (getattr(s, "content", None) or "").strip(), + "image": sec_img, + } ) - sections_data.append({ - "content": (getattr(s, "content", None) or "").strip(), - "image": sec_img, - }) return { "id": ch.id, "title": ch.title, diff --git a/api/app/features/memoir/memoir_images/parser.py b/api/app/features/memoir/memoir_images/parser.py index 1fa316a..35cc1f4 100644 --- a/api/app/features/memoir/memoir_images/parser.py +++ b/api/app/features/memoir/memoir_images/parser.py @@ -82,7 +82,9 @@ def split_narrative_to_sections(narrative: str) -> list[dict[str, Any]]: content = narrative[start:end] if isinstance(content, str): content = content.strip() - sections.append({"content": content or "", "placeholder_info": placeholder_info}) + sections.append( + {"content": content or "", "placeholder_info": placeholder_info} + ) return sections diff --git a/api/app/features/memoir/memoir_images/prompting.py b/api/app/features/memoir/memoir_images/prompting.py index c759bda..7a9e38e 100644 --- a/api/app/features/memoir/memoir_images/prompting.py +++ b/api/app/features/memoir/memoir_images/prompting.py @@ -42,7 +42,9 @@ class MemoirImagePromptService: description: str, context_excerpt: str, ) -> dict[str, str]: - style = self.CATEGORY_STYLE_MAP.get(chapter_category, self.settings.default_style) + style = self.CATEGORY_STYLE_MAP.get( + chapter_category, self.settings.default_style + ) prompt_context = f"{chapter_category}: {chapter_title}" llm_input = { @@ -69,7 +71,9 @@ class MemoirImagePromptService: raw_response = response.content parsed = json.loads(extract_json_payload(raw_response)) return { - "prompt": _ensure_style_in_prompt(parsed["prompt"], parsed.get("style", style)), + "prompt": _ensure_style_in_prompt( + parsed["prompt"], parsed.get("style", style) + ), "style": parsed.get("style", style), "size": parsed.get("size", self.settings.default_size), "prompt_context": prompt_context, @@ -104,7 +108,9 @@ class MemoirImagePromptService: context_excerpt: str, ) -> dict[str, str]: """生成章节封面图的 image-generation prompt。""" - style = self.CATEGORY_STYLE_MAP.get(chapter_category, self.settings.default_style) + style = self.CATEGORY_STYLE_MAP.get( + chapter_category, self.settings.default_style + ) prompt_context = f"{chapter_category}: {chapter_title}" llm_input = { @@ -189,14 +195,18 @@ class MemoirImagePromptService: context_excerpt: str, style: str, ) -> str: - subject = self.CATEGORY_FALLBACK_SUBJECT_MAP.get(chapter_category, "memoir scene") + subject = self.CATEGORY_FALLBACK_SUBJECT_MAP.get( + chapter_category, "memoir scene" + ) if _contains_cjk(description) or _contains_cjk(context_excerpt): return ( f"A {style} illustration of a {subject}, emotionally resonant, cinematic composition, " "authentic everyday details, natural lighting, expressive environment, no text overlay." ) - details = ". ".join(part.strip() for part in (description, context_excerpt) if part.strip()) + details = ". ".join( + part.strip() for part in (description, context_excerpt) if part.strip() + ) if not details: details = "A personal life story scene with authentic emotional detail" return ( diff --git a/api/app/features/memoir/memoir_images/provider.py b/api/app/features/memoir/memoir_images/provider.py index 3cd73ac..745a2f3 100644 --- a/api/app/features/memoir/memoir_images/provider.py +++ b/api/app/features/memoir/memoir_images/provider.py @@ -3,6 +3,7 @@ LiblibImageProvider 已迁至 app.adapters.image_gen.liblib_provider; 此处仅 re-export,以便 memoir 与 tasks 的既有引用不中断。 Feature 应通过 port ImageGenerator + get_image_generator() 使用图生能力。 """ + from app.adapters.image_gen.liblib_provider import LiblibImageProvider # noqa: F401 __all__ = ["LiblibImageProvider"] diff --git a/api/app/features/memoir/memoir_images/schema.py b/api/app/features/memoir/memoir_images/schema.py index 6d4ea8f..55114d5 100644 --- a/api/app/features/memoir/memoir_images/schema.py +++ b/api/app/features/memoir/memoir_images/schema.py @@ -13,7 +13,9 @@ VALID_IMAGE_STATUSES = { IMAGE_STATUS_FAILED, } -_PLACEHOLDER_DESCRIPTION_RE = re.compile(r"\{\{\{\{IMAGE:(.*?)\}\}\}\}|\{\{IMAGE:(.*?)\}\}") +_PLACEHOLDER_DESCRIPTION_RE = re.compile( + r"\{\{\{\{IMAGE:(.*?)\}\}\}\}|\{\{IMAGE:(.*?)\}\}" +) def normalize_image_asset(asset: dict[str, Any] | None) -> dict[str, Any] | None: @@ -21,9 +23,9 @@ def normalize_image_asset(asset: dict[str, Any] | None) -> dict[str, Any] | None return None placeholder = _as_non_empty_string(asset.get("placeholder")) - description = _as_non_empty_string(asset.get("description")) or _extract_description_from_placeholder( - placeholder - ) + description = _as_non_empty_string( + asset.get("description") + ) or _extract_description_from_placeholder(placeholder) if not placeholder or not description: return None diff --git a/api/app/features/memoir/memoir_images/serializers.py b/api/app/features/memoir/memoir_images/serializers.py index 6df8d7e..f339aca 100644 --- a/api/app/features/memoir/memoir_images/serializers.py +++ b/api/app/features/memoir/memoir_images/serializers.py @@ -1,6 +1,7 @@ """ MemoirImage 模型与 API 用 dict 的互转(与 schema.normalize_image_asset 字段一致)。 """ + from datetime import datetime from typing import Any diff --git a/api/app/features/memoir/memoir_images/settings.py b/api/app/features/memoir/memoir_images/settings.py index 0b8484d..751ceaa 100644 --- a/api/app/features/memoir/memoir_images/settings.py +++ b/api/app/features/memoir/memoir_images/settings.py @@ -55,7 +55,9 @@ class MemoirImageSettings: base_max = max(self.max_per_chapter, 0) effective_cap = max(self.max_images_cap, base_max) safe_length = max(content_length, 0) - extra = safe_length // self.chars_per_extra_image if self.chars_per_extra_image > 0 else 0 + extra = ( + safe_length // self.chars_per_extra_image + if self.chars_per_extra_image > 0 + else 0 + ) return min(base_max + extra, effective_cap) - - diff --git a/api/app/features/memoir/memoir_images/storage.py b/api/app/features/memoir/memoir_images/storage.py index bad234a..e5aeb40 100644 --- a/api/app/features/memoir/memoir_images/storage.py +++ b/api/app/features/memoir/memoir_images/storage.py @@ -26,7 +26,9 @@ def normalize_cos_base_url(base_url: str, bucket: str, region: str) -> str: return candidate -def normalize_cos_url(url: str | None, bucket: str, region: str, base_url: str | None = None) -> str | None: +def normalize_cos_url( + url: str | None, bucket: str, region: str, base_url: str | None = None +) -> str | None: if not url: return url @@ -43,7 +45,11 @@ def normalize_cos_url(url: str | None, bucket: str, region: str, base_url: str | return url normalized_parsed = urlparse(normalized_base) - return urlunparse(parsed._replace(scheme=normalized_parsed.scheme, netloc=normalized_parsed.netloc)) + return urlunparse( + parsed._replace( + scheme=normalized_parsed.scheme, netloc=normalized_parsed.netloc + ) + ) def resolve_image_storage_key(image: dict | None) -> str | None: diff --git a/api/app/features/memoir/pdf_service.py b/api/app/features/memoir/pdf_service.py index ef1a307..c3c10cc 100644 --- a/api/app/features/memoir/pdf_service.py +++ b/api/app/features/memoir/pdf_service.py @@ -1,6 +1,7 @@ """ PDF 生成服务(从 services 迁入 memoir feature) """ + from app.core.logging import get_logger from io import BytesIO from typing import List diff --git a/api/app/features/memoir/repo.py b/api/app/features/memoir/repo.py index acfb26d..8e66363 100644 --- a/api/app/features/memoir/repo.py +++ b/api/app/features/memoir/repo.py @@ -8,7 +8,12 @@ from app.features.memoir.models import Book, Chapter, ChapterSection, MemoirStat async def get_current_book(user_id: str, db: AsyncSession) -> Book | None: - stmt = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()).limit(1) + stmt = ( + select(Book) + .where(Book.user_id == user_id) + .order_by(Book.updated_at.desc()) + .limit(1) + ) result = await db.execute(stmt) return result.scalar_one_or_none() diff --git a/api/app/features/memoir/router.py b/api/app/features/memoir/router.py index e1e0b2b..13a205e 100644 --- a/api/app/features/memoir/router.py +++ b/api/app/features/memoir/router.py @@ -1,6 +1,7 @@ """ 回忆录 feature — books / chapters / memoir-state 合并路由 """ + from app.core.logging import get_logger from typing import List, Optional diff --git a/api/app/features/memoir/service.py b/api/app/features/memoir/service.py index 8e28788..4bbe9db 100644 --- a/api/app/features/memoir/service.py +++ b/api/app/features/memoir/service.py @@ -151,20 +151,22 @@ class MemoirService: else: if is_new is True: continue - all_chapters.append({ - "id": f"placeholder_{category}", - "title": CHAPTER_CATEGORIES[category], - "content": "", - "order_index": STAGE_TO_ORDER.get(category, 999), - "status": "empty", - "category": category, - "images": [], - "cover_image": None, - "sections": [], - "updated_at": None, - "is_new": False, - "source_segments": [], - }) + all_chapters.append( + { + "id": f"placeholder_{category}", + "title": CHAPTER_CATEGORIES[category], + "content": "", + "order_index": STAGE_TO_ORDER.get(category, 999), + "status": "empty", + "category": category, + "images": [], + "cover_image": None, + "sections": [], + "updated_at": None, + "is_new": False, + "source_segments": [], + } + ) for ch in chapter_by_category.values(): await self._cleanup_unavailable_images(ch) all_chapters.append(chapter_to_dict(ch)) @@ -231,7 +233,9 @@ class MemoirService: if not ch.category or ch.status == "empty": continue sections = getattr(ch, "sections", None) or [] - section_image_count = sum(1 for s in sections if getattr(s, "image_id", None)) + section_image_count = sum( + 1 for s in sections if getattr(s, "image_id", None) + ) images = getattr(ch, "images", None) or [] cover_rec = next( (m for m in images if getattr(m, "section_id", None) is None), @@ -239,7 +243,10 @@ class MemoirService: ) if section_image_count <= 3: continue - if cover_rec and (getattr(cover_rec, "status") or "").strip() == "completed": + if ( + cover_rec + and (getattr(cover_rec, "status") or "").strip() == "completed" + ): continue if cover_rec is None: img_settings = MemoirImageSettings.from_env() @@ -281,16 +288,12 @@ class MemoirService: return {"triggered": triggered} async def mark_memoir_read(self, user_id: str) -> dict: - stmt = select(Chapter).where( - Chapter.user_id == user_id, Chapter.is_new == True - ) + stmt = select(Chapter).where(Chapter.user_id == user_id, Chapter.is_new == True) result = await self._db.execute(stmt) for chapter in result.scalars().all(): chapter.is_new = False stmt_book = ( - select(Book) - .where(Book.user_id == user_id) - .order_by(Book.updated_at.desc()) + select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()) ) result_book = await self._db.execute(stmt_book) book = result_book.scalar_one_or_none() diff --git a/api/app/features/memoir/state_service.py b/api/app/features/memoir/state_service.py index 4989cb2..1e3f9ad 100644 --- a/api/app/features/memoir/state_service.py +++ b/api/app/features/memoir/state_service.py @@ -2,6 +2,7 @@ 回忆录状态服务:get_or_create_state、update_slot、mark_stage_complete 等。 供 memoir service、conversation ws 使用;Celery 任务内使用同步版本(见 tasks/memoir_tasks)。 """ + import uuid from typing import Dict, List @@ -18,7 +19,9 @@ def _coerce_state(model: MemoirStateModel) -> MemoirStateSchema: "stage_order": model.stage_order or default_state().stage_order, "current_stage": model.current_stage, "covered_stages": model.covered_stages or [], - "slots": model.slots if isinstance(model.slots, dict) else default_state().slots, + "slots": model.slots + if isinstance(model.slots, dict) + else default_state().slots, } ) @@ -37,7 +40,10 @@ async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSche stage_order=default.stage_order, current_stage=default.current_stage, covered_stages=default.covered_stages, - slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in default.slots.items()}, + slots={ + k: {sk: sv.model_dump() for sk, sv in v.items()} + for k, v in default.slots.items() + }, ) db.add(state) await db.commit() @@ -66,7 +72,9 @@ async def update_slot( existing = stage_slots.get(slot_name, {}) merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) - stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=merged_segment_ids).model_dump() + stage_slots[slot_name] = SlotData( + snippet=snippet, segment_ids=merged_segment_ids + ).model_dump() slots[stage] = stage_slots state.slots = slots state.current_stage = state.current_stage or stage @@ -75,7 +83,9 @@ async def update_slot( return _coerce_state(state) -async def mark_stage_complete(user_id: str, stage: str, db: AsyncSession) -> MemoirStateSchema: +async def mark_stage_complete( + user_id: str, stage: str, db: AsyncSession +) -> MemoirStateSchema: stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id) result = await db.execute(stmt) state = result.scalar_one_or_none() @@ -104,7 +114,9 @@ async def get_empty_slots(user_id: str, db: AsyncSession) -> List[str]: return state.empty_slots_for_current_stage() -async def switch_stage(user_id: str, new_stage: str, db: AsyncSession) -> MemoirStateSchema: +async def switch_stage( + user_id: str, new_stage: str, db: AsyncSession +) -> MemoirStateSchema: stmt = select(MemoirStateModel).where(MemoirStateModel.user_id == user_id) result = await db.execute(stmt) state = result.scalar_one_or_none() diff --git a/api/app/features/memory/chunker.py b/api/app/features/memory/chunker.py index 3d52214..6be02f7 100644 --- a/api/app/features/memory/chunker.py +++ b/api/app/features/memory/chunker.py @@ -1,6 +1,8 @@ """Transcript chunker — split raw text into retrieval-ready chunks (skeleton).""" -def chunk_transcript(text: str, *, max_tokens: int = 512, overlap: int = 64) -> list[str]: +def chunk_transcript( + text: str, *, max_tokens: int = 512, overlap: int = 64 +) -> list[str]: """Split transcript text into overlapping chunks.""" raise NotImplementedError diff --git a/api/app/features/memory/models.py b/api/app/features/memory/models.py index 97b1467..6593c78 100644 --- a/api/app/features/memory/models.py +++ b/api/app/features/memory/models.py @@ -30,13 +30,17 @@ class MemorySource(Base): status = Column(String, default="active") conversation_id = Column(String, ForeignKey("conversations.id"), nullable=True) created_at = Column(DateTime(timezone=True), default=utc_now) - chunks = relationship("MemoryChunk", back_populates="source", cascade="all, delete-orphan") + chunks = relationship( + "MemoryChunk", back_populates="source", cascade="all, delete-orphan" + ) class MemoryChunk(Base): __tablename__ = "memory_chunks" id = Column(String, primary_key=True) - source_id = Column(String, ForeignKey("memory_sources.id"), nullable=False, index=True) + source_id = Column( + String, ForeignKey("memory_sources.id"), nullable=False, index=True + ) user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True) content = Column(Text, nullable=False) # pgvector embedding — Alembic migration 负责 CREATE EXTENSION vector 及列类型 @@ -67,7 +71,9 @@ class MemoryFact(Base): __tablename__ = "memory_facts" id = Column(String, primary_key=True) user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True) - fact_type = Column(String, nullable=False) # person / event / relation / place / milestone + fact_type = Column( + String, nullable=False + ) # person / event / relation / place / milestone subject = Column(String, nullable=True) predicate = Column(String, nullable=True) object_json = Column(JSON, nullable=True) @@ -94,8 +100,12 @@ class MemoryCurationAction(Base): __tablename__ = "memory_curation_actions" id = Column(String, primary_key=True) user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True) - action_type = Column(String, nullable=False) # exclude / restore / correct / merge / confirm / reject - target_type = Column(String, nullable=False) # chunk / fact / summary / timeline_event + action_type = Column( + String, nullable=False + ) # exclude / restore / correct / merge / confirm / reject + target_type = Column( + String, nullable=False + ) # chunk / fact / summary / timeline_event target_id = Column(String, nullable=False) details = Column(JSON, nullable=True) created_at = Column(DateTime(timezone=True), default=utc_now) diff --git a/api/app/features/memory/retriever.py b/api/app/features/memory/retriever.py index 5799378..278bbfd 100644 --- a/api/app/features/memory/retriever.py +++ b/api/app/features/memory/retriever.py @@ -4,5 +4,7 @@ class HybridRetriever: """Phase 2+ implementation: combine FTS, vector, and metadata filter results.""" - async def retrieve(self, user_id: str, query: str, *, top_k: int = 10) -> list[dict]: + async def retrieve( + self, user_id: str, query: str, *, top_k: int = 10 + ) -> list[dict]: raise NotImplementedError diff --git a/api/app/features/memory/schemas.py b/api/app/features/memory/schemas.py index 681f72b..5e6e0c4 100644 --- a/api/app/features/memory/schemas.py +++ b/api/app/features/memory/schemas.py @@ -1,9 +1,9 @@ - from pydantic import BaseModel class EvidenceBundle(BaseModel): """MemoryService 产出的检索结果,供 conversation/memoir 消费。""" + relevant_chunks: list[dict] = [] relevant_summaries: list[dict] = [] relevant_facts: list[dict] = [] diff --git a/api/app/features/memory/service.py b/api/app/features/memory/service.py index 3acd8ad..3c451f4 100644 --- a/api/app/features/memory/service.py +++ b/api/app/features/memory/service.py @@ -2,6 +2,7 @@ MemoryService — conversation / memoir 的统一门面。 一期先实现基础接口签名,具体逻辑后续补充。 """ + from sqlalchemy.ext.asyncio import AsyncSession @@ -9,7 +10,9 @@ class MemoryService: def __init__(self, db: AsyncSession): self._db = db - async def ingest_transcript(self, user_id: str, conversation_id: str, transcript: str) -> str: + async def ingest_transcript( + self, user_id: str, conversation_id: str, transcript: str + ) -> str: """Ingest conversation transcript into memory. Returns source_id.""" raise NotImplementedError("Phase 2+ implementation") diff --git a/api/app/features/payment/alipay_client.py b/api/app/features/payment/alipay_client.py index 9678d28..6564e7e 100644 --- a/api/app/features/payment/alipay_client.py +++ b/api/app/features/payment/alipay_client.py @@ -1,6 +1,7 @@ """ 支付宝 OpenAPI 封装(从 payment 迁入 app) """ + from app.core.logging import get_logger from typing import Dict, Optional @@ -27,6 +28,7 @@ class AlipayClient: if self._client is None: try: from alipay import AliPay + self._client = AliPay( appid=self._config.app_id, app_notify_url=self._config.notify_url, @@ -114,7 +116,11 @@ class AlipayClient: trade_status=unified_status, total_amount=total_amount, ) - error_msg = result.get("sub_msg", result.get("msg", "未知错误")) if result else "空结果" + error_msg = ( + result.get("sub_msg", result.get("msg", "未知错误")) + if result + else "空结果" + ) raise PaymentQueryError(f"查询支付宝订单失败: {error_msg}") except PaymentQueryError: raise diff --git a/api/app/features/payment/order_service.py b/api/app/features/payment/order_service.py index 30a6da6..ecbeac8 100644 --- a/api/app/features/payment/order_service.py +++ b/api/app/features/payment/order_service.py @@ -1,6 +1,7 @@ """ 支付订单门面:持有 db + 底层 payment 客户端,提供 create_order / 回调 / 查询。 """ + import asyncio from app.core.logging import get_logger import time @@ -38,6 +39,7 @@ def _generate_order_no() -> str: def _get_legacy_payment_service(): from app.features.payment.deps import get_payment_service + return get_payment_service() @@ -65,13 +67,20 @@ class PaymentOrderService: if plan.price <= 0: raise HTTPException(status_code=400, detail="免费套餐无需支付") if payment_method not in ("wechat", "alipay"): - raise HTTPException(status_code=400, detail="不支持的支付方式,仅支持 wechat / alipay") + raise HTTPException( + status_code=400, detail="不支持的支付方式,仅支持 wechat / alipay" + ) client = _get_legacy_payment_service() if not client.is_method_available(payment_method): if payment_method == "alipay": - raise HTTPException(status_code=503, detail="支付宝支付接口正在开发中,暂时不可用") - raise HTTPException(status_code=503, detail=f"{payment_method} 支付暂不可用,请选择其他支付方式") + raise HTTPException( + status_code=503, detail="支付宝支付接口正在开发中,暂时不可用" + ) + raise HTTPException( + status_code=503, + detail=f"{payment_method} 支付暂不可用,请选择其他支付方式", + ) amount_fen = int(plan_price * 100) order_no = _generate_order_no() @@ -100,7 +109,9 @@ class PaymentOrderService: except asyncio.TimeoutError: order.status = "failed" await self._db.flush() - raise HTTPException(status_code=504, detail="微信支付初始化超时,请稍后重试。") + raise HTTPException( + status_code=504, detail="微信支付初始化超时,请稍后重试。" + ) except Exception as e: order.status = "failed" await self._db.flush() @@ -128,15 +139,24 @@ class PaymentOrderService: except PaymentError as e: order.status = "failed" await self._db.flush() - raise HTTPException(status_code=500, detail=f"创建支付订单失败: {e.message}") + raise HTTPException( + status_code=500, detail=f"创建支付订单失败: {e.message}" + ) except Exception as e: order.status = "failed" await self._db.flush() logger.exception("创建支付订单异常: %s", e) - raise HTTPException(status_code=500, detail=f"创建支付订单异常: {type(e).__name__}: {e!s}") + raise HTTPException( + status_code=500, detail=f"创建支付订单异常: {type(e).__name__}: {e!s}" + ) await self._db.commit() - logger.info("订单创建成功: order_no=%s, payment_method=%s, amount_fen=%s", order_no, payment_method, amount_fen) + logger.info( + "订单创建成功: order_no=%s, payment_method=%s, amount_fen=%s", + order_no, + payment_method, + amount_fen, + ) return CreateOrderResponse( order_id=order_no, payment_method=payment_method, @@ -157,18 +177,29 @@ class PaymentOrderService: order.status = "paid" order.trade_no = trade_no order.paid_at = now - user_result = await self._db.execute(select(User).where(User.id == order.user_id)) + user_result = await self._db.execute( + select(User).where(User.id == order.user_id) + ) user = user_result.scalar_one_or_none() if user: duration_days = SUBSCRIPTION_DURATION_DAYS.get(order.plan_id, 365) if user.subscription_expires_at and user.subscription_expires_at > now: - user.subscription_expires_at = user.subscription_expires_at + timedelta(days=duration_days) + user.subscription_expires_at = user.subscription_expires_at + timedelta( + days=duration_days + ) else: user.subscription_expires_at = now + timedelta(days=duration_days) user.subscription_type = order.plan_id - logger.info("用户 %s 订阅已升级为 %s,到期: %s", user.id, order.plan_id, user.subscription_expires_at) + logger.info( + "用户 %s 订阅已升级为 %s,到期: %s", + user.id, + order.plan_id, + user.subscription_expires_at, + ) await self._db.commit() - logger.info("支付成功处理完成: 订单 %s, 第三方交易号 %s", out_trade_no, trade_no) + logger.info( + "支付成功处理完成: 订单 %s, 第三方交易号 %s", out_trade_no, trade_no + ) async def handle_wechat_notify(self, headers: dict, body: str) -> dict: client = _get_legacy_payment_service() @@ -183,14 +214,20 @@ class PaymentOrderService: async def handle_alipay_notify(self, params: dict) -> str: client = _get_legacy_payment_service() notify_result = client.handle_alipay_notify(params=params) - if notify_result.success and notify_result.trade_status in ("TRADE_SUCCESS", "TRADE_FINISHED", "SUCCESS"): + if notify_result.success and notify_result.trade_status in ( + "TRADE_SUCCESS", + "TRADE_FINISHED", + "SUCCESS", + ): await self.handle_payment_success( notify_result.out_trade_no, notify_result.trade_no, ) return "success" - async def get_order_status(self, order_id: str, user_id: str) -> OrderStatusResponse: + async def get_order_status( + self, order_id: str, user_id: str + ) -> OrderStatusResponse: result = await self._db.execute( select(Order).where(Order.id == order_id, Order.user_id == user_id) ) @@ -212,7 +249,9 @@ class PaymentOrderService: async def list_orders(self, user_id: str) -> list[OrderListResponse]: result = await self._db.execute( - select(Order).where(Order.user_id == user_id).order_by(Order.created_at.desc()) + select(Order) + .where(Order.user_id == user_id) + .order_by(Order.created_at.desc()) ) orders = result.scalars().all() return [ diff --git a/api/app/features/payment/payment_config.py b/api/app/features/payment/payment_config.py index 1d68cb6..34cde0e 100644 --- a/api/app/features/payment/payment_config.py +++ b/api/app/features/payment/payment_config.py @@ -1,6 +1,7 @@ """ 支付模块配置(从 payment 迁入 app,从 app.core.config.settings 读取) """ + from app.core.logging import get_logger from dataclasses import dataclass, field diff --git a/api/app/features/payment/payment_facade.py b/api/app/features/payment/payment_facade.py index b91ebe0..a915cf7 100644 --- a/api/app/features/payment/payment_facade.py +++ b/api/app/features/payment/payment_facade.py @@ -1,6 +1,7 @@ """ 统一支付服务门面(从 payment 迁入 app) """ + from app.core.logging import get_logger from typing import Dict, Optional @@ -55,9 +56,7 @@ class PaymentService: subject=description, ) - def handle_wechat_notify( - self, headers: Dict[str, str], body: str - ) -> NotifyResult: + def handle_wechat_notify(self, headers: Dict[str, str], body: str) -> NotifyResult: return self.wechat_client.verify_notify(headers=headers, body=body) def handle_alipay_notify(self, params: Dict[str, str]) -> NotifyResult: diff --git a/api/app/features/payment/repo.py b/api/app/features/payment/repo.py index 2d0b415..2de9cda 100644 --- a/api/app/features/payment/repo.py +++ b/api/app/features/payment/repo.py @@ -11,7 +11,9 @@ async def get_order_by_id(order_id: str, db: AsyncSession) -> Order | None: async def get_orders_by_user(user_id: str, db: AsyncSession) -> list[Order]: - stmt = select(Order).where(Order.user_id == user_id).order_by(Order.created_at.desc()) + stmt = ( + select(Order).where(Order.user_id == user_id).order_by(Order.created_at.desc()) + ) result = await db.execute(stmt) return list(result.scalars().all()) diff --git a/api/app/features/payment/router.py b/api/app/features/payment/router.py index d3d5389..7296f87 100644 --- a/api/app/features/payment/router.py +++ b/api/app/features/payment/router.py @@ -37,7 +37,9 @@ async def create_order( service: PaymentOrderService = Depends(get_payment_order_service), plan_service: PlanService = Depends(get_plan_service), ): - plan = next((p for p in plan_service.get_plans_for_api() if p.id == request.plan_id), None) + plan = next( + (p for p in plan_service.get_plans_for_api() if p.id == request.plan_id), None + ) if plan is None: raise HTTPException(status_code=400, detail="无效的套餐 ID") return await service.create_order( @@ -59,7 +61,9 @@ async def wechat_notify( try: headers = dict(request.headers) body = await request.body() - return await service.handle_wechat_notify(headers=headers, body=body.decode("utf-8")) + return await service.handle_wechat_notify( + headers=headers, body=body.decode("utf-8") + ) except Exception as e: logger.exception("微信支付回调处理失败: %s", e) return {"code": "FAIL", "message": str(e)} diff --git a/api/app/features/payment/schemas.py b/api/app/features/payment/schemas.py index e1cab77..51823e4 100644 --- a/api/app/features/payment/schemas.py +++ b/api/app/features/payment/schemas.py @@ -1,4 +1,5 @@ """支付模块 Pydantic 模型定义(从 payment 迁入 app)""" + from typing import Any, Dict, Optional from pydantic import BaseModel diff --git a/api/app/features/payment/service.py b/api/app/features/payment/service.py index 95f253c..dce0c94 100644 --- a/api/app/features/payment/service.py +++ b/api/app/features/payment/service.py @@ -1,6 +1,7 @@ """ 支付 feature 对外暴露:统一门面与配置(实现已迁入 app) """ + from app.features.payment.payment_config import PaymentConfig from app.features.payment.payment_facade import PaymentService from app.features.payment.payment_exceptions import PaymentConfigError, PaymentError diff --git a/api/app/features/payment/wechat_client.py b/api/app/features/payment/wechat_client.py index 391f7fb..1589624 100644 --- a/api/app/features/payment/wechat_client.py +++ b/api/app/features/payment/wechat_client.py @@ -1,6 +1,7 @@ """ 微信支付 API v3 封装(从 payment 迁入 app) """ + import json from app.core.logging import get_logger import os @@ -40,9 +41,7 @@ def _resolve_key_path(key_path: str) -> str: try: # app/features/payment/wechat_client.py -> api/ api_dir = os.path.dirname( - os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - ) + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) abs_api = os.path.normpath(os.path.join(api_dir, key_path)) if os.path.isfile(abs_api): @@ -67,7 +66,10 @@ class WeChatPayClient: try: from wechatpayv3 import WeChatPay, WeChatPayType - if self._config.private_key_path and self._config.private_key_path.strip(): + if ( + self._config.private_key_path + and self._config.private_key_path.strip() + ): key_path = _resolve_key_path(self._config.private_key_path) with open(key_path, "r", encoding="utf-8") as f: private_key = f.read() @@ -92,11 +94,19 @@ class WeChatPayClient: notify_url=self._config.notify_url, ) if self._config.use_platform_public_key: - if self._config.platform_public_key_path and self._config.platform_public_key_path.strip(): - key_path = _resolve_key_path(self._config.platform_public_key_path) + if ( + self._config.platform_public_key_path + and self._config.platform_public_key_path.strip() + ): + key_path = _resolve_key_path( + self._config.platform_public_key_path + ) with open(key_path, "r", encoding="utf-8") as f: platform_pub_key = _normalize_pem_key(f.read()) - elif self._config.platform_public_key and self._config.platform_public_key.strip(): + elif ( + self._config.platform_public_key + and self._config.platform_public_key.strip() + ): platform_pub_key = _normalize_pem_key( self._config.platform_public_key.strip() ) @@ -105,9 +115,13 @@ class WeChatPayClient: "平台公钥模式需设置 WECHAT_PAY_PLATFORM_PUBLIC_KEY 或 WECHAT_PAY_PLATFORM_PUBLIC_KEY_PATH" ) if not platform_pub_key or "-----BEGIN" not in platform_pub_key: - raise PaymentConfigError("微信支付平台公钥格式错误,需为 PEM 格式") + raise PaymentConfigError( + "微信支付平台公钥格式错误,需为 PEM 格式" + ) kwargs["public_key"] = platform_pub_key - kwargs["public_key_id"] = self._config.platform_public_key_id.strip() + kwargs["public_key_id"] = ( + self._config.platform_public_key_id.strip() + ) else: kwargs["timeout"] = (10, 25) self._client = WeChatPay(**kwargs) @@ -118,9 +132,7 @@ class WeChatPayClient: if self._config.private_key_path else self._config.private_key_path ) - raise PaymentConfigError( - f"微信支付商户私钥文件不存在: {key_path}" - ) + raise PaymentConfigError(f"微信支付商户私钥文件不存在: {key_path}") except PaymentConfigError: raise except Exception as e: @@ -244,4 +256,5 @@ class WeChatPayClient: def _get_pay_type(self): from wechatpayv3 import WeChatPayType + return WeChatPayType.APP diff --git a/api/app/features/plan/router.py b/api/app/features/plan/router.py index d427552..0c5d8f0 100644 --- a/api/app/features/plan/router.py +++ b/api/app/features/plan/router.py @@ -1,6 +1,7 @@ """ 订阅计划路由。 """ + from typing import List from fastapi import APIRouter, Depends diff --git a/api/app/features/plan/service.py b/api/app/features/plan/service.py index 17f1fbc..2454664 100644 --- a/api/app/features/plan/service.py +++ b/api/app/features/plan/service.py @@ -11,28 +11,43 @@ ENABLE_TEST_PLAN = (settings.enable_test_plan or "").lower() in ("1", "true", "y AVAILABLE_PLANS = [ PlanResponse( - id="free", name="free", display_name="免费体验版", - price=0.0, currency="CNY", + id="free", + name="free", + display_name="免费体验版", + price=0.0, + currency="CNY", features=["500 轮对话", "无章节限制", "完整回忆录生成流程"], - max_conversations=500, is_popular=False, + max_conversations=500, + is_popular=False, ), PlanResponse( - id="pro", name="pro", display_name="Pro 版", - price=88.0, currency="CNY", + id="pro", + name="pro", + display_name="Pro 版", + price=88.0, + currency="CNY", features=["2000 轮对话", "无章节限制", "完整回忆录生成"], - max_conversations=2000, is_popular=True, + max_conversations=2000, + is_popular=True, ), PlanResponse( - id="pro_plus", name="pro_plus", display_name="Pro+ 版", - price=288.0, currency="CNY", + id="pro_plus", + name="pro_plus", + display_name="Pro+ 版", + price=288.0, + currency="CNY", features=["10000 轮对话", "无章节限制", "完整回忆录生成", "长期创作无忧"], - max_conversations=10000, is_popular=False, + max_conversations=10000, + is_popular=False, ), ] TEST_PLAN = PlanResponse( - id="test", name="test", display_name="一分钱测试版", - price=0.01, currency="CNY", + id="test", + name="test", + display_name="一分钱测试版", + price=0.01, + currency="CNY", features=["无限对话", "无限章节整理", "仅用于开发环境测试支付"], is_popular=False, ) diff --git a/api/app/features/quota/router.py b/api/app/features/quota/router.py index 3f4ce36..5485c62 100644 --- a/api/app/features/quota/router.py +++ b/api/app/features/quota/router.py @@ -1,6 +1,7 @@ """ 配额检查路由。 """ + from fastapi import APIRouter, Depends from app.core.dependencies import get_current_user diff --git a/api/app/features/quota/service.py b/api/app/features/quota/service.py index 927887b..02e6d1a 100644 --- a/api/app/features/quota/service.py +++ b/api/app/features/quota/service.py @@ -3,6 +3,7 @@ 「对话轮数」的定义:每条用户发出的消息(Segment 表的记录数)计为 1 轮。 """ + from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession @@ -72,7 +73,10 @@ def check_can_send_message( if max_conv is None: return True, "" if segment_count >= max_conv: - return False, f"对话轮数已用完({segment_count}/{max_conv}),请升级 Pro 或 Pro+ 继续使用" + return ( + False, + f"对话轮数已用完({segment_count}/{max_conv}),请升级 Pro 或 Pro+ 继续使用", + ) return True, "" diff --git a/api/app/features/tasks/deps.py b/api/app/features/tasks/deps.py index d2d7f0f..d5d3203 100644 --- a/api/app/features/tasks/deps.py +++ b/api/app/features/tasks/deps.py @@ -1,4 +1,5 @@ """Tasks feature 依赖:提供 get_tasks_service。""" + from fastapi import Depends from app.features.tasks.service import TasksService diff --git a/api/app/features/tasks/router.py b/api/app/features/tasks/router.py index 05610c9..99beb77 100644 --- a/api/app/features/tasks/router.py +++ b/api/app/features/tasks/router.py @@ -1,6 +1,7 @@ """ 任务状态 API 路由 """ + from typing import Dict, List from fastapi import APIRouter, Depends @@ -20,6 +21,7 @@ router = APIRouter( class TaskInfo(BaseModel): """任务信息""" + task_id: str task_type: str = "memoir" status: str @@ -30,6 +32,7 @@ class TaskInfo(BaseModel): class TasksStatusResponse(BaseModel): """任务状态汇总响应""" + total: int pending: int running: int diff --git a/api/app/features/tasks/service.py b/api/app/features/tasks/service.py index 400054f..3e92a94 100644 --- a/api/app/features/tasks/service.py +++ b/api/app/features/tasks/service.py @@ -1,6 +1,7 @@ """ 任务状态服务:对外提供任务状态查询与清理,委托给底层 task_tracker。 """ + from typing import Any from app.core.task_tracker import task_tracker diff --git a/api/app/features/user/models.py b/api/app/features/user/models.py index 826d8c7..88e4734 100644 --- a/api/app/features/user/models.py +++ b/api/app/features/user/models.py @@ -27,7 +27,10 @@ class User(Base): books = relationship("Book", back_populates="user") orders = relationship("Order", back_populates="user", cascade="all, delete-orphan") memoir_state = relationship( - "MemoirState", back_populates="user", uselist=False, cascade="all, delete-orphan" + "MemoirState", + back_populates="user", + uselist=False, + cascade="all, delete-orphan", ) refresh_tokens = relationship( "RefreshToken", back_populates="user", cascade="all, delete-orphan" diff --git a/api/app/features/user/router.py b/api/app/features/user/router.py index 619fc83..3c2ef1d 100644 --- a/api/app/features/user/router.py +++ b/api/app/features/user/router.py @@ -81,7 +81,9 @@ async def test_subscription( ) -@feedback_router.post("", response_model=FeedbackResponse, status_code=status.HTTP_201_CREATED) +@feedback_router.post( + "", response_model=FeedbackResponse, status_code=status.HTTP_201_CREATED +) async def submit_feedback( request: SubmitFeedbackRequest, current_user: User = Depends(get_current_user), diff --git a/api/app/features/user/schemas.py b/api/app/features/user/schemas.py index eb35f40..eb7a0e4 100644 --- a/api/app/features/user/schemas.py +++ b/api/app/features/user/schemas.py @@ -37,11 +37,13 @@ class TestSubscriptionResponse(BaseModel): class SubmitFeedbackRequest(BaseModel): """提交反馈请求""" + content: str = Field(..., min_length=1, max_length=2000, description="反馈内容") contact: Optional[str] = Field(None, max_length=100, description="联系方式(可选)") class FeedbackResponse(BaseModel): """反馈响应""" + id: str message: str diff --git a/api/app/main.py b/api/app/main.py index 01d73f0..486da20 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -1,6 +1,7 @@ """ FastAPI 应用入口(app 内主入口,符合架构计划) """ + from pathlib import Path from app.core.logging import setup_logging @@ -85,7 +86,9 @@ async def startup_event(): await asyncio.to_thread(_run_alembic_upgrade) logger.info("Alembic 迁移已就绪") except Exception as e: - logger.warning("Alembic 迁移失败(可能数据库未启动或 DATABASE_URL 未配置): %s", e) + logger.warning( + "Alembic 迁移失败(可能数据库未启动或 DATABASE_URL 未配置): %s", e + ) try: from app.core.redis import redis_service @@ -107,7 +110,11 @@ async def startup_event(): if asr_ready: from app.core.config import settings - name = "腾讯云一句话识别" if settings.asr_provider == "tencent" else "本地 Whisper" + name = ( + "腾讯云一句话识别" + if settings.asr_provider == "tencent" + else "本地 Whisper" + ) logger.info("ASR 服务已就绪(%s)", name) else: logger.warning("ASR 服务未就绪,语音转写将不可用") diff --git a/api/app/tasks/__init__.py b/api/app/tasks/__init__.py index 65dc44a..713270e 100644 --- a/api/app/tasks/__init__.py +++ b/api/app/tasks/__init__.py @@ -1,6 +1,7 @@ """ Celery 任务模块 """ + from .celery_app import celery_app from .memoir_tasks import process_memoir_segments, generate_chapter_images diff --git a/api/app/tasks/celery_app.py b/api/app/tasks/celery_app.py index 24bd810..9305783 100644 --- a/api/app/tasks/celery_app.py +++ b/api/app/tasks/celery_app.py @@ -3,6 +3,7 @@ Celery 应用配置 配置从 app.core.config.settings 读取。 Worker 启动时需聚合注册所有 feature 的 model,否则 User 等 relationship("Order", ...) 解析时会报找不到 Order。 """ + from celery import Celery from app.core.config import settings diff --git a/api/app/tasks/memoir_tasks.py b/api/app/tasks/memoir_tasks.py index f2b2a73..ac8b3cb 100644 --- a/api/app/tasks/memoir_tasks.py +++ b/api/app/tasks/memoir_tasks.py @@ -1,6 +1,7 @@ """ 回忆录处理 Celery 任务 """ + import json from app.core.logging import get_logger import uuid @@ -125,28 +126,30 @@ def _release_chapter_image_lock(chapter_id: str): r.delete(lock_key) -def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None): +def _update_task_status_sync( + user_id: str, task_id: str, status: str, result: Dict = None +): """同步更新任务状态(在 Celery 任务中使用)""" try: r = _get_redis_client(decode_responses=True) - + key = f"task:user:{user_id}:tasks" - + # 获取现有任务信息 data = r.hget(key, task_id) if data: task_info = json.loads(data) else: task_info = {"task_id": task_id} - + task_info["status"] = status task_info["updated_at"] = datetime.now(timezone.utc).isoformat() if result is not None: task_info["result"] = result - + r.hset(key, task_id, json.dumps(task_info)) r.expire(key, 3600) # 1小时过期 - + logger.info(f"任务状态已更新: task_id={task_id}, status={status}") except Exception as e: logger.error(f"更新任务状态失败: {e}") @@ -284,7 +287,16 @@ def _select_placeholders_for_effective_max( return [{**item, "index": index} for index, item in enumerate(selected)] -def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str, category: str, order_index: int, source_segments: list, user_id: str): +def _save_narrative_to_sections( + db: Session, + chapter, + narrative: str, + title: str, + category: str, + order_index: int, + source_segments: list, + user_id: str, +): """ 将带占位符的 narrative 拆成 chapter_sections 并写入;为每段占位符创建 pending 配图。 已有 section 与图片不删除,仅追加新内容。若无封面 MemoirImage 则创建 pending 封面(section_id=None)。 @@ -313,20 +325,25 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str .where(ChapterSection.chapter_id == chapter.id) .order_by(ChapterSection.order_index) ) - .scalars().all() + .scalars() + .all() ) if existing_sections: existing_content = "\n\n".join( - (s.content or "").strip() for s in existing_sections if (s.content or "").strip() + (s.content or "").strip() + for s in existing_sections + if (s.content or "").strip() ) if existing_content and narrative.startswith(existing_content): - new_part = narrative[len(existing_content):].lstrip() + new_part = narrative[len(existing_content) :].lstrip() else: new_part = (narrative or "").strip() if not new_part: chapter.title = title chapter.is_new = True - chapter.source_segments = list(set((chapter.source_segments or []) + (source_segments or []))) + chapter.source_segments = list( + set((chapter.source_segments or []) + (source_segments or [])) + ) return chapter narrative_to_parse = new_part order_base = max(s.order_index for s in existing_sections) + 1 @@ -335,7 +352,11 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str order_base = 0 img_settings = MemoirImageSettings.from_env() - prompt_service = MemoirImagePromptService(llm=None, settings=img_settings) if img_settings.enabled else None + prompt_service = ( + MemoirImagePromptService(llm=None, settings=img_settings) + if img_settings.enabled + else None + ) segments = parse_narrative_to_sections(narrative_to_parse) if not segments: @@ -349,12 +370,9 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str db.add(sec) db.flush() if img_settings.enabled: - stmt_cover = ( - select(MemoirImage) - .where( - MemoirImage.chapter_id == chapter.id, - MemoirImage.section_id.is_(None), - ) + stmt_cover = select(MemoirImage).where( + MemoirImage.chapter_id == chapter.id, + MemoirImage.section_id.is_(None), ) if not db.execute(stmt_cover).scalar_one_or_none(): cover_ph = { @@ -365,7 +383,11 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str cover_asset = build_initial_image_assets( [cover_ph], img_settings.provider, - prompt_service.CATEGORY_STYLE_MAP.get(category, img_settings.default_style) if prompt_service else img_settings.default_style, + prompt_service.CATEGORY_STYLE_MAP.get( + category, img_settings.default_style + ) + if prompt_service + else img_settings.default_style, img_settings.default_size, now_iso, )[0] @@ -374,7 +396,9 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str db.flush() chapter.title = title chapter.is_new = True - chapter.source_segments = list(set((chapter.source_segments or []) + (source_segments or []))) + chapter.source_segments = list( + set((chapter.source_segments or []) + (source_segments or [])) + ) return chapter def _should_have_image(seg: dict, order_idx: int) -> bool: @@ -393,7 +417,11 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str return ph content = (seg.get("content") or "").strip() desc = (content[:50] + "…") if len(content) > 50 else (content or "章节配图") - return {"placeholder": f"{{{{{{{{IMAGE:{desc}}}}}}}}}", "description": desc, "index": order_idx} + return { + "placeholder": f"{{{{{{{{IMAGE:{desc}}}}}}}}}", + "description": desc, + "index": order_idx, + } # 按顺序创建 section,每 3 个 section 对应 1 张配图 for i, seg in enumerate(segments): @@ -402,7 +430,13 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str image_asset = None if img_settings.enabled and _should_have_image(seg, order_idx): ph = _placeholder_for_segment(seg, order_idx) - style = prompt_service.CATEGORY_STYLE_MAP.get(category, img_settings.default_style) if prompt_service else img_settings.default_style + style = ( + prompt_service.CATEGORY_STYLE_MAP.get( + category, img_settings.default_style + ) + if prompt_service + else img_settings.default_style + ) image_asset = build_initial_image_assets( [ph], img_settings.provider, @@ -422,7 +456,9 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str db.flush() if image_asset: # 本段配图与当前 section 绑定,memoir_images.order_index = section.order_index + 1(封面 0 预留) - mi = _memoir_image_from_asset(chapter.id, sec.id, order_idx + 1, image_asset) + mi = _memoir_image_from_asset( + chapter.id, sec.id, order_idx + 1, image_asset + ) db.add(mi) db.flush() sec.image_id = mi.id @@ -430,12 +466,9 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str # 封面图:若无则创建 pending MemoirImage(section_id=None, order_index=0) if img_settings.enabled: - stmt_cover = ( - select(MemoirImage) - .where( - MemoirImage.chapter_id == chapter.id, - MemoirImage.section_id.is_(None), - ) + stmt_cover = select(MemoirImage).where( + MemoirImage.chapter_id == chapter.id, + MemoirImage.section_id.is_(None), ) existing_cover = db.execute(stmt_cover).scalar_one_or_none() if not existing_cover: @@ -447,7 +480,11 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str cover_asset = build_initial_image_assets( [cover_ph], img_settings.provider, - prompt_service.CATEGORY_STYLE_MAP.get(category, img_settings.default_style) if prompt_service else img_settings.default_style, + prompt_service.CATEGORY_STYLE_MAP.get( + category, img_settings.default_style + ) + if prompt_service + else img_settings.default_style, img_settings.default_size, now_iso, )[0] @@ -457,7 +494,9 @@ def _save_narrative_to_sections(db: Session, chapter, narrative: str, title: str chapter.title = title chapter.is_new = True - chapter.source_segments = list(set((chapter.source_segments or []) + (source_segments or []))) + chapter.source_segments = list( + set((chapter.source_segments or []) + (source_segments or [])) + ) return chapter @@ -465,7 +504,9 @@ def initialize_chapter_images(_chapter): """ 兼容旧调用:若章节已改为 sections 存储,则图片初始化已在 _save_narrative_to_sections 中完成,直接返回。 """ - logger.info("initialize_chapter_images: 已由 _save_narrative_to_sections 处理 section 配图,跳过") + logger.info( + "initialize_chapter_images: 已由 _save_narrative_to_sections 处理 section 配图,跳过" + ) return [] @@ -489,7 +530,9 @@ def _coerce_state(model: MemoirState) -> MemoirStateSchema: "stage_order": model.stage_order or default_state().stage_order, "current_stage": model.current_stage, "covered_stages": model.covered_stages or [], - "slots": model.slots if isinstance(model.slots, dict) else default_state().slots, + "slots": model.slots + if isinstance(model.slots, dict) + else default_state().slots, } ) @@ -509,7 +552,10 @@ def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema: stage_order=default.stage_order, current_stage=default.current_stage, covered_stages=default.covered_stages, - slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in default.slots.items()}, + slots={ + k: {sk: sv.model_dump() for sk, sv in v.items()} + for k, v in default.slots.items() + }, ) db.add(state) db.commit() @@ -539,7 +585,9 @@ def _update_slot_sync( existing = stage_slots.get(slot_name, {}) merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) - stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=merged_segment_ids).model_dump() + stage_slots[slot_name] = SlotData( + snippet=snippet, segment_ids=merged_segment_ids + ).model_dump() slots[stage] = stage_slots state.slots = slots state.current_stage = state.current_stage or stage @@ -552,17 +600,19 @@ def _update_slot_sync( def process_memoir_segments(self, user_id: str, segment_ids: List[str]): """ 处理回忆录段落的 Celery 任务 - + Args: user_id: 用户 ID segment_ids: 段落 ID 列表 """ task_id = self.request.id - logger.info(f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}") - + logger.info( + f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}" + ) + # 更新任务状态为 running _update_task_status_sync(user_id, task_id, "running") - + try: with get_sync_db() as db: chapters_to_enqueue: set[str] = set() @@ -570,11 +620,11 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): stmt = select(Segment).where(Segment.id.in_(segment_ids)) result = db.execute(stmt) segments = result.scalars().all() - + if not segments: logger.warning(f"未找到段落: {segment_ids}") return {"status": "no_segments"} - + # 获取用户状态和资料 state = _get_or_create_state_sync(user_id, db) llm = _get_llm() @@ -615,7 +665,9 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): Chapter.is_active == True, ) .options( - joinedload(Chapter.sections).joinedload(ChapterSection.image_record), + joinedload(Chapter.sections).joinedload( + ChapterSection.image_record + ), joinedload(Chapter.images), ) ) @@ -625,7 +677,9 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): slot_snippets = {} stage_slots = state.slots.get(chapter_category, {}) or {} for key, value in stage_slots.items(): - snip = getattr(value, "snippet", None) or (value.get("snippet") if isinstance(value, dict) else None) + snip = getattr(value, "snippet", None) or ( + value.get("snippet") if isinstance(value, dict) else None + ) if snip: slot_snippets[key] = snip @@ -633,7 +687,9 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): existing_content = "" if chapter and getattr(chapter, "sections", None): existing_content = "\n\n".join( - s.content for s in sorted(chapter.sections, key=lambda x: x.order_index) if (s.content or "").strip() + s.content + for s in sorted(chapter.sections, key=lambda x: x.order_index) + if (s.content or "").strip() ) narrative = combined_text @@ -662,7 +718,11 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): else: narrative = new_narrative - if existing_content and not _is_json_narrative(narrative) and len(narrative) < len(existing_content) * 0.8: + if ( + existing_content + and not _is_json_narrative(narrative) + and len(narrative) < len(existing_content) * 0.8 + ): logger.warning( "内容长度异常: existing=%d, new=%d, category=%s. 回退为追加模式", len(existing_content), @@ -693,7 +753,11 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): or _chapter_has_cover_to_generate(chapter) ) - stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()) + stmt_book = ( + select(Book) + .where(Book.user_id == user_id) + .order_by(Book.updated_at.desc()) + ) result_book = db.execute(stmt_book) book = result_book.scalar_one_or_none() if not book: @@ -721,19 +785,19 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): user_profile=user_profile, user_birth_year=user_birth_year, get_or_create_state=lambda: _get_or_create_state_sync(user_id, db), - update_slot=lambda stage, slot_name, snippet, seg_ids: _update_slot_sync( - user_id, stage, slot_name, snippet, seg_ids, db + update_slot=lambda stage, slot_name, snippet, seg_ids: ( + _update_slot_sync(user_id, stage, slot_name, snippet, seg_ids, db) ), acquire_lock=lambda stage: _acquire_chapter_lock(user_id, stage), release_lock=lambda stage: _release_chapter_lock(user_id, stage), process_category=_process_category, raise_retry=_raise_retry, ) - + # 标记段落为已处理 for seg in segments: seg.processed = True - + db.commit() for chapter_id in sorted(chapters_to_enqueue): @@ -741,21 +805,25 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): logger.info(f"派发章节补图任务: chapter={chapter_id}") generate_chapter_images.delay(chapter_id) except Exception as exc: - logger.warning(f"补图任务派发失败: chapter={chapter_id}, error={exc}") + logger.warning( + f"补图任务派发失败: chapter={chapter_id}, error={exc}" + ) logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}") - + # 更新任务状态为成功 - _update_task_status_sync(user_id, task_id, "success", {"processed": len(segments)}) - + _update_task_status_sync( + user_id, task_id, "success", {"processed": len(segments)} + ) + return {"status": "success", "processed": len(segments)} except Exception as e: logger.error(f"回忆录处理失败: {e}") - + # 更新任务状态为失败 _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)}) - + # 重试 raise self.retry(exc=e) @@ -764,18 +832,18 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): def generate_chapter_content(self, user_id: str, stage: str, new_content: str): """ 单独生成章节内容的任务(用于实时更新) - + Args: user_id: 用户 ID stage: 阶段 new_content: 新内容 """ logger.info(f"生成章节内容: user_id={user_id}, stage={stage}") - + try: with get_sync_db() as db: llm = _get_llm() - + # 查找 active 章节并预加载 sections stmt = ( select(Chapter) @@ -791,7 +859,9 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str): existing_content = "" if chapter and getattr(chapter, "sections", None): existing_content = "\n\n".join( - s.content for s in sorted(chapter.sections, key=lambda x: x.order_index) if (s.content or "").strip() + s.content + for s in sorted(chapter.sections, key=lambda x: x.order_index) + if (s.content or "").strip() ) if llm: @@ -814,10 +884,18 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str): else: narrative = new_narrative else: - narrative = f"{existing_content}\n\n{new_content}" if existing_content else new_content + narrative = ( + f"{existing_content}\n\n{new_content}" + if existing_content + else new_content + ) # 安全检查:新内容不应比旧内容短(仅非 JSON 格式) - if existing_content and not _is_json_narrative(narrative) and len(narrative) < len(existing_content) * 0.8: + if ( + existing_content + and not _is_json_narrative(narrative) + and len(narrative) < len(existing_content) * 0.8 + ): logger.warning( f"内容长度异常: existing={len(existing_content)}, " f"new={len(narrative)}, stage={stage}. 回退为追加模式" @@ -841,14 +919,20 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str): db.commit() db.refresh(chapter) image_settings = MemoirImageSettings.from_env() - if image_settings.enabled and chapter and ( - _chapter_has_any_section_images_to_generate(chapter) - or _chapter_has_cover_to_generate(chapter) + if ( + image_settings.enabled + and chapter + and ( + _chapter_has_any_section_images_to_generate(chapter) + or _chapter_has_cover_to_generate(chapter) + ) ): try: generate_chapter_images.delay(chapter.id) except Exception as exc: - logger.warning("补图任务派发失败: chapter=%s, error=%s", chapter.id, exc) + logger.warning( + "补图任务派发失败: chapter=%s, error=%s", chapter.id, exc + ) return {"status": "success"} except Exception as e: @@ -873,7 +957,9 @@ def generate_chapter_images(self, chapter_id: str): select(Chapter) .where(Chapter.id == chapter_id) .options( - joinedload(Chapter.sections).joinedload(ChapterSection.image_record), + joinedload(Chapter.sections).joinedload( + ChapterSection.image_record + ), joinedload(Chapter.images), ) ) @@ -883,7 +969,9 @@ def generate_chapter_images(self, chapter_id: str): return {"status": "no_chapter"} sections = getattr(chapter, "sections", None) or [] sections_with_pending = [ - (idx, s) for idx, s in enumerate(sections) if _section_has_image_to_generate(s) + (idx, s) + for idx, s in enumerate(sections) + if _section_has_image_to_generate(s) ] cover_rec = _get_cover_memoir_image(chapter) cover_to_generate = ( @@ -894,7 +982,9 @@ def generate_chapter_images(self, chapter_id: str): else None ) if not sections_with_pending and not cover_to_generate: - logger.info("章节补图跳过: chapter=%s, reason=no_pending_images", chapter_id) + logger.info( + "章节补图跳过: chapter=%s, reason=no_pending_images", chapter_id + ) return {"status": "no_images"} settings = MemoirImageSettings.from_env() @@ -943,8 +1033,14 @@ def generate_chapter_images(self, chapter_id: str): _apply_item_to_memoir_image(cover_to_generate, current_item) db.commit() try: - sections_ordered = sorted(sections, key=lambda s: getattr(s, "order_index", 0)) - first_content = (sections_ordered[0].content or "").strip() if sections_ordered else "" + sections_ordered = sorted( + sections, key=lambda s: getattr(s, "order_index", 0) + ) + first_content = ( + (sections_ordered[0].content or "").strip() + if sections_ordered + else "" + ) context_excerpt = " ".join(first_content.split("\n")[:5])[:200] prompt_data = prompt_orchestrator.build_cover_prompt( chapter_title=chapter.title, @@ -961,9 +1057,13 @@ def generate_chapter_images(self, chapter_id: str): image_bytes = _normalize_image_bytes_for_storage( image_generator.download_image(result.image_url) ) - key = build_cos_key(chapter.user_id, chapter.id, "cover", prompt_data["prompt"]) + key = build_cos_key( + chapter.user_id, chapter.id, "cover", prompt_data["prompt"] + ) current_item["storage_key"] = key - current_item["url"] = storage.upload_bytes(image_bytes, key, "image/png") + current_item["url"] = storage.upload_bytes( + image_bytes, key, "image/png" + ) current_item["prompt"] = prompt_data["prompt"] current_item["style"] = prompt_data["style"] current_item["size"] = prompt_data["size"] @@ -982,7 +1082,11 @@ def generate_chapter_images(self, chapter_id: str): failure_msg = f"cover, error={exc}" if isinstance(exc, CosUploadError) and not exc.retryable: permanent_failures.append(failure_msg) - logger.error("封面图上传不可重试,清理: chapter=%s, %s", chapter_id, failure_msg) + logger.error( + "封面图上传不可重试,清理: chapter=%s, %s", + chapter_id, + failure_msg, + ) db.delete(cover_to_generate) db.commit() else: @@ -990,14 +1094,24 @@ def generate_chapter_images(self, chapter_id: str): current_item["status"] = IMAGE_STATUS_FAILED current_item["error"] = str(exc) current_item["retryable"] = True - current_item["updated_at"] = datetime.now(timezone.utc).isoformat() + current_item["updated_at"] = datetime.now( + timezone.utc + ).isoformat() retryable_failures.append(failure_msg) - logger.warning("封面图生成失败(可重试): chapter=%s, %s", chapter_id, failure_msg) + logger.warning( + "封面图生成失败(可重试): chapter=%s, %s", + chapter_id, + failure_msg, + ) _apply_item_to_memoir_image(cover_to_generate, current_item) db.commit() for sec_index, section in sections_with_pending: - item = memoir_image_to_dict(section.image_record) if section.image_record else {} + item = ( + memoir_image_to_dict(section.image_record) + if section.image_record + else {} + ) current_item = dict(item) if item else {} current_item.setdefault("placeholder", "") current_item.setdefault("description", "") @@ -1025,9 +1139,13 @@ def generate_chapter_images(self, chapter_id: str): image_bytes = _normalize_image_bytes_for_storage( image_generator.download_image(result.image_url) ) - key = build_cos_key(chapter.user_id, chapter.id, sec_index, prompt_data["prompt"]) + key = build_cos_key( + chapter.user_id, chapter.id, sec_index, prompt_data["prompt"] + ) current_item["storage_key"] = key - current_item["url"] = storage.upload_bytes(image_bytes, key, "image/png") + current_item["url"] = storage.upload_bytes( + image_bytes, key, "image/png" + ) current_item["prompt"] = prompt_data["prompt"] current_item["style"] = prompt_data["style"] current_item["size"] = prompt_data["size"] @@ -1047,7 +1165,11 @@ def generate_chapter_images(self, chapter_id: str): failure_msg = f"section_index={sec_index}, error={exc}" if isinstance(exc, CosUploadError) and not exc.retryable: permanent_failures.append(failure_msg) - logger.error("图片上传不可重试,清理配图: chapter=%s, %s", chapter_id, failure_msg) + logger.error( + "图片上传不可重试,清理配图: chapter=%s, %s", + chapter_id, + failure_msg, + ) mi = section.image_record section.image_id = None if mi: @@ -1058,8 +1180,14 @@ def generate_chapter_images(self, chapter_id: str): current_item["error"] = str(exc) current_item["retryable"] = True retryable_failures.append(failure_msg) - logger.warning("图片生成失败(可重试): chapter=%s, %s", chapter_id, failure_msg) - current_item["updated_at"] = datetime.now(timezone.utc).isoformat() + logger.warning( + "图片生成失败(可重试): chapter=%s, %s", + chapter_id, + failure_msg, + ) + current_item["updated_at"] = datetime.now( + timezone.utc + ).isoformat() _apply_item_to_memoir_image(section.image_record, current_item) db.commit() diff --git a/api/database/models.py b/api/database/models.py deleted file mode 100644 index 145e009..0000000 --- a/api/database/models.py +++ /dev/null @@ -1,255 +0,0 @@ -""" -数据库模型定义 -""" -from datetime import datetime, timezone -from typing import Optional, List -from sqlalchemy import Column, String, Integer, DateTime, Boolean, Text, ForeignKey, JSON -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import relationship - -Base = declarative_base() - - -def utc_now(): - """返回当前 UTC 时间(带时区信息)""" - return datetime.now(timezone.utc) - - -class User(Base): - """用户表""" - __tablename__ = "users" - - id = Column(String, primary_key=True) - phone = Column(String, unique=True, nullable=False, index=True) # 手机号(唯一,必填) - password_hash = Column(String, nullable=False) # 密码哈希 - email = Column(String, unique=True, nullable=True) # 邮箱(可选) - openid = Column(String, unique=True, nullable=True) # 微信 OpenID(可选) - nickname = Column(String, nullable=False) - avatar_url = Column(String, nullable=True) - subscription_type = Column(String, default="free") # free, premium - subscription_expires_at = Column(DateTime(timezone=True), nullable=True) # 订阅到期时间 - created_at = Column(DateTime(timezone=True), default=utc_now) - - birth_year = Column(Integer, nullable=True) - birth_place = Column(String, nullable=True) - grew_up_place = Column(String, nullable=True) - occupation = Column(String, nullable=True) - - # Relationships - conversations = relationship("Conversation", back_populates="user") - chapters = relationship("Chapter", back_populates="user") - books = relationship("Book", back_populates="user") - orders = relationship("Order", back_populates="user", cascade="all, delete-orphan") - memoir_state = relationship("MemoirState", back_populates="user", uselist=False, cascade="all, delete-orphan") - refresh_tokens = relationship("RefreshToken", back_populates="user", cascade="all, delete-orphan") - - -class Conversation(Base): - """对话会话表""" - __tablename__ = "conversations" - - id = Column(String, primary_key=True) - user_id = Column(String, ForeignKey("users.id"), nullable=False) - started_at = Column(DateTime(timezone=True), default=utc_now) - last_message_at = Column(DateTime(timezone=True), nullable=True) - ended_at = Column(DateTime(timezone=True), nullable=True) - duration_seconds = Column(Integer, default=0) - summary = Column(Text, nullable=True) - status = Column(String, default="active") # active, ended, processing - current_topic = Column(String, nullable=True) - conversation_stage = Column(String, nullable=True) # childhood, education, career, family, beliefs, summary - - # Relationships - user = relationship("User", back_populates="conversations") - segments = relationship("Segment", back_populates="conversation", cascade="all, delete-orphan") - - -class Segment(Base): - """对话段落表""" - __tablename__ = "segments" - - id = Column(String, primary_key=True) - conversation_id = Column(String, ForeignKey("conversations.id"), nullable=False) - audio_url = Column(String, nullable=True) - transcript_text = Column(Text, nullable=False) - created_at = Column(DateTime(timezone=True), default=utc_now) - processed = Column(Boolean, default=False) - topic_category = Column(String, nullable=True) - agent_response = Column(Text, nullable=True) - - # Relationships - conversation = relationship("Conversation", back_populates="segments") - - -class Chapter(Base): - """章节表(正文与插图存于 chapter_sections)""" - __tablename__ = "chapters" - - id = Column(String, primary_key=True) - user_id = Column(String, ForeignKey("users.id"), nullable=False) - title = Column(String, nullable=False) - order_index = Column(Integer, nullable=False) - status = Column(String, default="draft") # draft, completed - cover_image = Column(JSON, nullable=True) # 章节封面图(单条图片元数据) - updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now) - category = Column(String, nullable=True) # 章节分类 - is_new = Column(Boolean, default=True) # 是否为新内容(未读) - is_active = Column(Boolean, default=True) # 是否启用(清除回忆后置为 False) - source_segments = Column(JSON, nullable=True) # 来源 segment IDs 列表 - - # Relationships - user = relationship("User", back_populates="chapters") - sections = relationship( - "ChapterSection", - back_populates="chapter", - order_by="ChapterSection.order_index", - cascade="all, delete-orphan", - ) - images = relationship( - "MemoirImage", - back_populates="chapter", - foreign_keys="MemoirImage.chapter_id", - cascade="all, delete-orphan", - ) - - -class ChapterSection(Base): - """章节段落表:一章多段,每段一段正文 + 可选一张图""" - __tablename__ = "chapter_sections" - - id = Column(String, primary_key=True) - chapter_id = Column(String, ForeignKey("chapters.id", ondelete="CASCADE"), nullable=False) - order_index = Column(Integer, nullable=False) - content = Column(Text, nullable=False) # 本段正文(无占位符) - image_id = Column(String, ForeignKey("memoir_images.id", ondelete="SET NULL"), nullable=True) # 关联 memoir_images.id - updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now) - - # Relationships - chapter = relationship("Chapter", back_populates="sections") - image_record = relationship( - "MemoirImage", - back_populates="section", - uselist=False, - foreign_keys="ChapterSection.image_id", - cascade="all, delete-orphan", - single_parent=True, - ) - - -class MemoirImage(Base): - """章节配图与封面:字段独立存储,section_id 为空表示章节封面,非空表示该 section 的配图""" - __tablename__ = "memoir_images" - - id = Column(String, primary_key=True) - chapter_id = Column(String, ForeignKey("chapters.id", ondelete="CASCADE"), nullable=False) - section_id = Column(String, ForeignKey("chapter_sections.id", ondelete="CASCADE"), nullable=True) - order_index = Column(Integer, nullable=False, default=0) - placeholder = Column(Text, nullable=True) - description = Column(Text, nullable=True) - status = Column(String, nullable=False, default="pending") - prompt = Column(Text, nullable=True) - url = Column(Text, nullable=True) - storage_key = Column(Text, nullable=True) - provider = Column(String, nullable=True) - style = Column(String, nullable=True) - size = Column(String, nullable=True) - error = Column(Text, nullable=True) - retryable = Column(Boolean, nullable=True) - created_at = Column(DateTime(timezone=True), nullable=True) - updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now) - - # Relationships - chapter = relationship("Chapter", back_populates="images") - section = relationship( - "ChapterSection", - back_populates="image_record", - foreign_keys="ChapterSection.image_id", - ) - - -class Book(Base): - """回忆录表""" - __tablename__ = "books" - - id = Column(String, primary_key=True) - user_id = Column(String, ForeignKey("users.id"), nullable=False) - title = Column(String, nullable=False) - total_pages = Column(Integer, default=0) - total_words = Column(Integer, default=0) - cover_image_url = Column(String, nullable=True) - updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now) - has_update = Column(Boolean, default=False) # 是否有新内容 - last_update_chapter_id = Column(String, nullable=True) # 最近更新的章节 ID - - # Relationships - user = relationship("User", back_populates="books") - - -class MemoirState(Base): - """回忆录状态表 - 对话 Agent 与后台 Agent 共享""" - __tablename__ = "memoir_states" - - id = Column(String, primary_key=True) - user_id = Column(String, ForeignKey("users.id"), unique=True, nullable=False) - stage_order = Column(JSON, default=list) # 阶段顺序 - current_stage = Column(String, default="childhood") # 当前阶段 - covered_stages = Column(JSON, default=list) # 已完成阶段列表 - slots = Column(JSON, nullable=False) # 各阶段 slot 信息 - updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now) - - # Relationships - user = relationship("User", back_populates="memoir_state") - - -class RefreshToken(Base): - """刷新令牌表""" - __tablename__ = "refresh_tokens" - - id = Column(String, primary_key=True) - user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True) - token = Column(String, unique=True, nullable=False, index=True) # 刷新令牌(唯一) - expires_at = Column(DateTime(timezone=True), nullable=False) # 过期时间(30天后) - created_at = Column(DateTime(timezone=True), default=utc_now) - is_revoked = Column(Boolean, default=False) # 是否已撤销 - device_info = Column(String, nullable=True) # 设备信息(用于全设备登出) - - # Relationships - user = relationship("User", back_populates="refresh_tokens") - - -class SmsVerificationCode(Base): - """短信验证码表""" - __tablename__ = "sms_verification_codes" - - id = Column(String, primary_key=True) - phone = Column(String, nullable=False, index=True) # 手机号 - code = Column(String, nullable=False) # 6位验证码 - purpose = Column(String, nullable=False) # register/login/reset_password/change_phone - is_used = Column(Boolean, default=False) # 是否已使用 - is_expired = Column(Boolean, default=False) # 是否已过期 - expires_at = Column(DateTime(timezone=True), nullable=False) # 过期时间(5分钟后) - created_at = Column(DateTime(timezone=True), default=utc_now) - verified_at = Column(DateTime(timezone=True), nullable=True) # 验证时间 - ip_address = Column(String, nullable=True) # 请求IP地址 - - -class Order(Base): - """支付订单表""" - __tablename__ = "orders" - - id = Column(String, primary_key=True) # 内部订单号 - user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True) - plan_id = Column(String, nullable=False) # 套餐 ID(free / premium) - plan_name = Column(String, nullable=False) # 套餐名称 - amount = Column(Integer, nullable=False) # 金额(单位:分) - currency = Column(String, default="CNY") - payment_method = Column(String, nullable=False) # wechat / alipay - status = Column(String, default="pending") # pending / paid / failed / cancelled / refunded - trade_no = Column(String, nullable=True, index=True) # 第三方交易号(微信/支付宝) - paid_at = Column(DateTime(timezone=True), nullable=True) # 支付完成时间 - created_at = Column(DateTime(timezone=True), default=utc_now) - expired_at = Column(DateTime(timezone=True), nullable=True) # 订单超时时间 - - # Relationships - user = relationship("User", back_populates="orders") - diff --git a/api/main.py b/api/main.py index 8bd992f..e1479a7 100644 --- a/api/main.py +++ b/api/main.py @@ -1,6 +1,7 @@ """ 入口模块:从 app.main 暴露 app,保证 uvicorn main:app(cwd=api)仍可用。 """ + from app.main import app __all__ = ["app"] diff --git a/api/migrations_legacy/README.md b/api/migrations_legacy/README.md deleted file mode 100644 index 604ce5f..0000000 --- a/api/migrations_legacy/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# 历史迁移(已归档) - -> **注意**:本目录下的 SQL 文件为架构重构前的历史迁移,**不再使用**。 -> 当前数据库 schema 由 **Alembic** 管理,请使用 `alembic upgrade head` 进行迁移。 - -## 说明 - -- 本目录保留仅作历史参考,部署时请勿执行这些 SQL。 -- 新部署或 schema 变更请使用:`uv run alembic upgrade head` diff --git a/api/migrations_legacy/add_chapter_is_active.sql b/api/migrations_legacy/add_chapter_is_active.sql deleted file mode 100644 index 97b878f..0000000 --- a/api/migrations_legacy/add_chapter_is_active.sql +++ /dev/null @@ -1,8 +0,0 @@ --- 为 chapters 表添加 is_active 字段 --- 用于支持"清除回忆"功能:将章节标记为 disabled 而非物理删除 --- 默认值为 TRUE,现有章节全部为 active - -ALTER TABLE chapters ADD COLUMN IF NOT EXISTS is_active BOOLEAN DEFAULT TRUE; - --- 确保现有数据全部为 active -UPDATE chapters SET is_active = TRUE WHERE is_active IS NULL; diff --git a/api/migrations_legacy/add_chapter_sections.sql b/api/migrations_legacy/add_chapter_sections.sql deleted file mode 100644 index b080c0e..0000000 --- a/api/migrations_legacy/add_chapter_sections.sql +++ /dev/null @@ -1,42 +0,0 @@ --- 章节拆分为 chapter_sections:每段正文 + 配图独立存储,chapters 只保留封面图 --- 执行顺序: 1) 本文件 2) python -m scripts.migrate_chapters_to_sections --- 执行方式: psql -U -d -f api/migrations/add_chapter_sections.sql - --- ========== 1. 新建 chapter_sections 表 ========== -CREATE TABLE IF NOT EXISTS chapter_sections ( - id VARCHAR NOT NULL PRIMARY KEY, - chapter_id VARCHAR NOT NULL REFERENCES chapters(id) ON DELETE CASCADE, - order_index INTEGER NOT NULL, - content TEXT NOT NULL, - image JSONB, - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() -); -CREATE INDEX IF NOT EXISTS ix_chapter_sections_chapter_id ON chapter_sections(chapter_id); -CREATE INDEX IF NOT EXISTS ix_chapter_sections_order ON chapter_sections(chapter_id, order_index); - --- 若 chapter_sections 已存在且已提前切换到 image_id 结构,重跑迁移时需要临时补回 image 列, --- 以便后续 Python 脚本先完成 JSON -> memoir_images 的数据搬迁。 -DO $$ -BEGIN - IF NOT EXISTS ( - SELECT 1 - FROM information_schema.columns - WHERE table_schema = 'public' - AND table_name = 'chapter_sections' - AND column_name = 'image' - ) THEN - ALTER TABLE chapter_sections ADD COLUMN image JSONB; - RAISE NOTICE '已补回 chapter_sections.image,供数据迁移使用'; - END IF; -END $$; - --- ========== 2. chapters 表增加 cover_image ========== -DO $$ -BEGIN - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'chapters' AND column_name = 'cover_image') THEN - ALTER TABLE chapters ADD COLUMN cover_image JSONB; - RAISE NOTICE '已添加 chapters.cover_image'; - END IF; -END $$; - --- ========== 3. 回填与删列由脚本 scripts.migrate_chapters_to_sections 完成 ========== diff --git a/api/migrations_legacy/add_device_info_column.sql b/api/migrations_legacy/add_device_info_column.sql deleted file mode 100644 index 4ab6e74..0000000 --- a/api/migrations_legacy/add_device_info_column.sql +++ /dev/null @@ -1,19 +0,0 @@ --- 添加 refresh_tokens.device_info 列的迁移脚本 --- 执行方式: psql -U postgres -d life_echo -f migrations/add_device_info_column.sql --- 或者在 psql 中执行: \i migrations/add_device_info_column.sql - --- 检查列是否存在,如果不存在则添加 -DO $$ -BEGIN - IF NOT EXISTS ( - SELECT 1 - FROM information_schema.columns - WHERE table_name = 'refresh_tokens' - AND column_name = 'device_info' - ) THEN - ALTER TABLE refresh_tokens ADD COLUMN device_info VARCHAR; - RAISE NOTICE '已添加 refresh_tokens.device_info 列'; - ELSE - RAISE NOTICE 'refresh_tokens.device_info 列已存在,跳过'; - END IF; -END $$; diff --git a/api/migrations_legacy/add_memoir_images_table.sql b/api/migrations_legacy/add_memoir_images_table.sql deleted file mode 100644 index e0bc2da..0000000 --- a/api/migrations_legacy/add_memoir_images_table.sql +++ /dev/null @@ -1,32 +0,0 @@ --- 图片独立表:将原先挤在 chapter.cover_image / chapter_sections.image 的 JSON 拆成独立列 --- 执行顺序: 1) 本文件 2) python -m api.scripts.run_memoir_images_migration --- 执行方式: psql -U -d -f api/migrations/add_memoir_images_table.sql - --- ========== 1. 新建 memoir_images 表 ========== -CREATE TABLE IF NOT EXISTS memoir_images ( - id VARCHAR NOT NULL PRIMARY KEY, - chapter_id VARCHAR NOT NULL REFERENCES chapters(id) ON DELETE CASCADE, - section_id VARCHAR NULL REFERENCES chapter_sections(id) ON DELETE CASCADE, - order_index INTEGER NOT NULL DEFAULT 0, - placeholder TEXT, - description TEXT, - status VARCHAR NOT NULL DEFAULT 'pending', - prompt TEXT, - url TEXT, - storage_key TEXT, - provider VARCHAR, - style VARCHAR, - size VARCHAR, - error TEXT, - retryable BOOLEAN, - created_at TIMESTAMP WITH TIME ZONE, - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() -); - -COMMENT ON TABLE memoir_images IS '章节配图与封面:section_id 为空表示章节封面,非空表示该 section 的配图'; -COMMENT ON COLUMN memoir_images.order_index IS '章节内唯一:封面=0,段落配图=1,2,3,...(对应 section.order_index+1)'; - -CREATE INDEX IF NOT EXISTS ix_memoir_images_chapter_id ON memoir_images(chapter_id); -CREATE INDEX IF NOT EXISTS ix_memoir_images_section_id ON memoir_images(section_id); -CREATE UNIQUE INDEX IF NOT EXISTS ix_memoir_images_chapter_cover ON memoir_images(chapter_id) WHERE section_id IS NULL; -CREATE UNIQUE INDEX IF NOT EXISTS ix_memoir_images_section_unique ON memoir_images(section_id) WHERE section_id IS NOT NULL; diff --git a/api/migrations_legacy/add_orders_table.sql b/api/migrations_legacy/add_orders_table.sql deleted file mode 100644 index 6d6b494..0000000 --- a/api/migrations_legacy/add_orders_table.sql +++ /dev/null @@ -1,26 +0,0 @@ --- 添加订单表和用户订阅到期时间字段 --- 执行时间: 2026-02 - --- 1. 为 users 表添加 subscription_expires_at 字段 -ALTER TABLE users ADD COLUMN IF NOT EXISTS subscription_expires_at TIMESTAMP WITH TIME ZONE DEFAULT NULL; - --- 2. 创建 orders 表 -CREATE TABLE IF NOT EXISTS orders ( - id VARCHAR NOT NULL PRIMARY KEY, - user_id VARCHAR NOT NULL REFERENCES users(id), - plan_id VARCHAR NOT NULL, - plan_name VARCHAR NOT NULL, - amount INTEGER NOT NULL, - currency VARCHAR DEFAULT 'CNY', - payment_method VARCHAR NOT NULL, - status VARCHAR DEFAULT 'pending', - trade_no VARCHAR, - paid_at TIMESTAMP WITH TIME ZONE, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - expired_at TIMESTAMP WITH TIME ZONE -); - --- 3. 创建索引 -CREATE INDEX IF NOT EXISTS ix_orders_user_id ON orders(user_id); -CREATE INDEX IF NOT EXISTS ix_orders_trade_no ON orders(trade_no); -CREATE INDEX IF NOT EXISTS ix_orders_status ON orders(status); diff --git a/api/migrations_legacy/add_section_image_id_fk.sql b/api/migrations_legacy/add_section_image_id_fk.sql deleted file mode 100644 index 114418f..0000000 --- a/api/migrations_legacy/add_section_image_id_fk.sql +++ /dev/null @@ -1,18 +0,0 @@ --- section 表用 image_id 关联 memoir_images,不再存 JSON --- 执行顺序:1) 本文件 2) 回填后执行 DROP 语句(或由脚本完成) --- 执行方式: psql -U -d -f api/migrations/add_section_image_id_fk.sql - --- 1. 添加外键列(可空,无默认) -ALTER TABLE chapter_sections -ADD COLUMN IF NOT EXISTS image_id VARCHAR REFERENCES memoir_images(id) ON DELETE SET NULL; - --- 2. 回填:已有 memoir_images 且 section_id 指向本行的,把其 id 写入本行 image_id -UPDATE chapter_sections cs -SET image_id = sub.id -FROM ( - SELECT id, section_id FROM memoir_images WHERE section_id IS NOT NULL -) sub -WHERE sub.section_id = cs.id AND cs.image_id IS NULL; - --- 3. 删除旧的 JSON 列 -ALTER TABLE chapter_sections DROP COLUMN IF EXISTS image; diff --git a/api/migrations_legacy/add_sms_verification.sql b/api/migrations_legacy/add_sms_verification.sql deleted file mode 100644 index 4311d59..0000000 --- a/api/migrations_legacy/add_sms_verification.sql +++ /dev/null @@ -1,49 +0,0 @@ --- 短信验证码功能数据库迁移脚本 --- 执行方式: psql -U postgres -d life_echo -f migrations/add_sms_verification.sql - --- 1. 创建短信验证码表 -CREATE TABLE IF NOT EXISTS sms_verification_codes ( - id VARCHAR PRIMARY KEY, - phone VARCHAR NOT NULL, - code VARCHAR NOT NULL, - purpose VARCHAR NOT NULL, - is_used BOOLEAN DEFAULT FALSE, - is_expired BOOLEAN DEFAULT FALSE, - expires_at TIMESTAMP WITH TIME ZONE NOT NULL, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - verified_at TIMESTAMP WITH TIME ZONE, - ip_address VARCHAR -); - --- 2. 创建索引以提高查询性能 -CREATE INDEX IF NOT EXISTS idx_sms_phone ON sms_verification_codes(phone); -CREATE INDEX IF NOT EXISTS idx_sms_created_at ON sms_verification_codes(created_at); -CREATE INDEX IF NOT EXISTS idx_sms_purpose ON sms_verification_codes(purpose); -CREATE INDEX IF NOT EXISTS idx_sms_phone_purpose ON sms_verification_codes(phone, purpose); - --- 3. 扩展 refresh_tokens 表,添加设备信息字段 -ALTER TABLE refresh_tokens ADD COLUMN IF NOT EXISTS device_info VARCHAR; - --- 4. 添加注释 -COMMENT ON TABLE sms_verification_codes IS '短信验证码表'; -COMMENT ON COLUMN sms_verification_codes.id IS '主键ID'; -COMMENT ON COLUMN sms_verification_codes.phone IS '手机号'; -COMMENT ON COLUMN sms_verification_codes.code IS '6位验证码'; -COMMENT ON COLUMN sms_verification_codes.purpose IS '用途:register/login/reset_password/change_phone'; -COMMENT ON COLUMN sms_verification_codes.is_used IS '是否已使用'; -COMMENT ON COLUMN sms_verification_codes.is_expired IS '是否已过期'; -COMMENT ON COLUMN sms_verification_codes.expires_at IS '过期时间(5分钟后)'; -COMMENT ON COLUMN sms_verification_codes.created_at IS '创建时间'; -COMMENT ON COLUMN sms_verification_codes.verified_at IS '验证时间'; -COMMENT ON COLUMN sms_verification_codes.ip_address IS '请求IP地址'; - -COMMENT ON COLUMN refresh_tokens.device_info IS '设备信息(用于全设备登出)'; - --- 5. 显示迁移完成信息 -DO $$ -BEGIN - RAISE NOTICE '短信验证码功能迁移完成!'; - RAISE NOTICE '- 已创建 sms_verification_codes 表'; - RAISE NOTICE '- 已创建相关索引'; - RAISE NOTICE '- 已扩展 refresh_tokens 表'; -END $$; diff --git a/api/migrations_legacy/add_user_profile_fields.sql b/api/migrations_legacy/add_user_profile_fields.sql deleted file mode 100644 index 1472dcc..0000000 --- a/api/migrations_legacy/add_user_profile_fields.sql +++ /dev/null @@ -1,7 +0,0 @@ --- 添加用户基础资料字段(出生年份、出生地、成长地、职业) -SET lock_timeout = '5s'; -ALTER TABLE users ADD COLUMN IF NOT EXISTS birth_year INTEGER; -ALTER TABLE users ADD COLUMN IF NOT EXISTS birth_place VARCHAR; -ALTER TABLE users ADD COLUMN IF NOT EXISTS grew_up_place VARCHAR; -ALTER TABLE users ADD COLUMN IF NOT EXISTS occupation VARCHAR; -RESET lock_timeout; diff --git a/api/migrations_legacy/add_users_subscription_columns.sql b/api/migrations_legacy/add_users_subscription_columns.sql deleted file mode 100644 index dcf20ac..0000000 --- a/api/migrations_legacy/add_users_subscription_columns.sql +++ /dev/null @@ -1,26 +0,0 @@ --- 为 users 表添加订阅相关列(subscription_type, subscription_expires_at) --- 若列已存在则跳过,可重复执行。 --- 执行方式: psql -U -d -f api/migrations/add_users_subscription_columns.sql - -DO $$ -BEGIN - IF NOT EXISTS ( - SELECT 1 FROM information_schema.columns - WHERE table_schema = 'public' AND table_name = 'users' AND column_name = 'subscription_type' - ) THEN - ALTER TABLE users ADD COLUMN subscription_type VARCHAR DEFAULT 'free'; - RAISE NOTICE '已添加 users.subscription_type 列'; - ELSE - RAISE NOTICE 'users.subscription_type 列已存在,跳过'; - END IF; - - IF NOT EXISTS ( - SELECT 1 FROM information_schema.columns - WHERE table_schema = 'public' AND table_name = 'users' AND column_name = 'subscription_expires_at' - ) THEN - ALTER TABLE users ADD COLUMN subscription_expires_at TIMESTAMP WITH TIME ZONE DEFAULT NULL; - RAISE NOTICE '已添加 users.subscription_expires_at 列'; - ELSE - RAISE NOTICE 'users.subscription_expires_at 列已存在,跳过'; - END IF; -END $$; diff --git a/api/migrations_legacy/fix_chapter_order_index.sql b/api/migrations_legacy/fix_chapter_order_index.sql deleted file mode 100644 index 748e1da..0000000 --- a/api/migrations_legacy/fix_chapter_order_index.sql +++ /dev/null @@ -1,15 +0,0 @@ --- 修复章节 order_index 为 999 的问题 --- 原因:STAGE_KEYWORDS 使用简化阶段名(career, belief), --- 但 CHAPTER_ORDER 使用详细分类名(career_early, beliefs),导致查找失败回退到 999 - --- 根据 category 字段修复 order_index -UPDATE chapters SET order_index = 0 WHERE order_index = 999 AND category = 'childhood'; -UPDATE chapters SET order_index = 1 WHERE order_index = 999 AND category = 'education'; -UPDATE chapters SET order_index = 2 WHERE order_index = 999 AND category = 'career'; -UPDATE chapters SET order_index = 2 WHERE order_index = 999 AND category = 'career_early'; -UPDATE chapters SET order_index = 3 WHERE order_index = 999 AND category = 'career_achievement'; -UPDATE chapters SET order_index = 4 WHERE order_index = 999 AND category = 'career_challenge'; -UPDATE chapters SET order_index = 5 WHERE order_index = 999 AND category = 'family'; -UPDATE chapters SET order_index = 6 WHERE order_index = 999 AND category = 'belief'; -UPDATE chapters SET order_index = 6 WHERE order_index = 999 AND category = 'beliefs'; -UPDATE chapters SET order_index = 7 WHERE order_index = 999 AND category = 'summary'; diff --git a/api/migrations_legacy/fix_chapter_order_index_v2.sql b/api/migrations_legacy/fix_chapter_order_index_v2.sql deleted file mode 100644 index 59ca025..0000000 --- a/api/migrations_legacy/fix_chapter_order_index_v2.sql +++ /dev/null @@ -1,19 +0,0 @@ --- 修正章节排序索引,与 8 分类体系对齐 --- childhood=0, education=1, career_early=2, career_achievement=3, --- career_challenge=4, family=5, beliefs=6, summary=7 -UPDATE chapters SET order_index = 0 WHERE category = 'childhood' AND order_index != 0; -UPDATE chapters SET order_index = 1 WHERE category = 'education' AND order_index != 1; -UPDATE chapters SET order_index = 2 WHERE category = 'career_early' AND order_index != 2; -UPDATE chapters SET order_index = 3 WHERE category = 'career_achievement' AND order_index != 3; -UPDATE chapters SET order_index = 4 WHERE category = 'career_challenge' AND order_index != 4; -UPDATE chapters SET order_index = 5 WHERE category = 'family' AND order_index != 5; -UPDATE chapters SET order_index = 6 WHERE category IN ('belief', 'beliefs') AND order_index != 6; -UPDATE chapters SET order_index = 7 WHERE category = 'summary' AND order_index != 7; - --- 旧的 5-stage "career" 章节归入 career_early -UPDATE chapters SET category = 'career_early', order_index = 2 - WHERE category = 'career'; - --- 旧的 "belief" 统一为 "beliefs" -UPDATE chapters SET category = 'beliefs' - WHERE category = 'belief'; diff --git a/api/migrations_legacy/fix_memoir_images_order_index.sql b/api/migrations_legacy/fix_memoir_images_order_index.sql deleted file mode 100644 index 20507fb..0000000 --- a/api/migrations_legacy/fix_memoir_images_order_index.sql +++ /dev/null @@ -1,9 +0,0 @@ --- 修复 memoir_images 同一章节内 order_index 重复:封面=0,段落配图=1,2,3,...(section.order_index+1) --- 仅更新有 section_id 的段落配图,封面(section_id 为空)保持 0。 --- 执行方式: psql -U -d -f api/migrations/fix_memoir_images_order_index.sql - -UPDATE memoir_images mi -SET order_index = cs.order_index + 1 -FROM chapter_sections cs -WHERE mi.section_id IS NOT NULL - AND mi.section_id = cs.id; diff --git a/api/migrations_legacy/sync_schema_to_models.sql b/api/migrations_legacy/sync_schema_to_models.sql deleted file mode 100644 index 9ff15de..0000000 --- a/api/migrations_legacy/sync_schema_to_models.sql +++ /dev/null @@ -1,141 +0,0 @@ --- 数据库结构同步迁移脚本(与 api/database/models.py 保持一致) --- 幂等:可重复执行,已存在的表/列会跳过。 --- 执行方式: psql -U -d -f api/migrations/sync_schema_to_models.sql --- 执行时间: 2026-02 - --- ========== 1. users 表缺失列 ========== -DO $$ -BEGIN - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'users' AND column_name = 'subscription_type') THEN - ALTER TABLE users ADD COLUMN subscription_type VARCHAR DEFAULT 'free'; - RAISE NOTICE '已添加 users.subscription_type'; - END IF; - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'users' AND column_name = 'subscription_expires_at') THEN - ALTER TABLE users ADD COLUMN subscription_expires_at TIMESTAMP WITH TIME ZONE DEFAULT NULL; - RAISE NOTICE '已添加 users.subscription_expires_at'; - END IF; - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'users' AND column_name = 'openid') THEN - ALTER TABLE users ADD COLUMN openid VARCHAR UNIQUE; - RAISE NOTICE '已添加 users.openid'; - END IF; -END $$; - --- ========== 2. refresh_tokens 表缺失列 ========== -DO $$ -BEGIN - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'refresh_tokens' AND column_name = 'device_info') THEN - ALTER TABLE refresh_tokens ADD COLUMN device_info VARCHAR; - RAISE NOTICE '已添加 refresh_tokens.device_info'; - END IF; -END $$; - --- ========== 3. conversations 表缺失列 ========== -DO $$ -BEGIN - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'conversations' AND column_name = 'current_topic') THEN - ALTER TABLE conversations ADD COLUMN current_topic VARCHAR; - RAISE NOTICE '已添加 conversations.current_topic'; - END IF; - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'conversations' AND column_name = 'conversation_stage') THEN - ALTER TABLE conversations ADD COLUMN conversation_stage VARCHAR; - RAISE NOTICE '已添加 conversations.conversation_stage'; - END IF; - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'conversations' AND column_name = 'last_message_at') THEN - ALTER TABLE conversations ADD COLUMN last_message_at TIMESTAMP WITH TIME ZONE; - RAISE NOTICE '已添加 conversations.last_message_at'; - END IF; -END $$; - -UPDATE conversations -SET last_message_at = started_at -WHERE last_message_at IS NULL - AND started_at IS NOT NULL; - -CREATE INDEX IF NOT EXISTS ix_conversations_last_message_at ON conversations(last_message_at); - --- ========== 4. chapters 表缺失列 ========== -DO $$ -BEGIN - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'chapters' AND column_name = 'category') THEN - ALTER TABLE chapters ADD COLUMN category VARCHAR; - RAISE NOTICE '已添加 chapters.category'; - END IF; - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'chapters' AND column_name = 'is_new') THEN - ALTER TABLE chapters ADD COLUMN is_new BOOLEAN DEFAULT TRUE; - RAISE NOTICE '已添加 chapters.is_new'; - END IF; - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'chapters' AND column_name = 'source_segments') THEN - ALTER TABLE chapters ADD COLUMN source_segments JSONB; - RAISE NOTICE '已添加 chapters.source_segments'; - END IF; - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'chapters' AND column_name = 'images') THEN - ALTER TABLE chapters ADD COLUMN images JSONB; - RAISE NOTICE '已添加 chapters.images'; - END IF; -END $$; - --- ========== 5. books 表缺失列 ========== -DO $$ -BEGIN - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'books' AND column_name = 'has_update') THEN - ALTER TABLE books ADD COLUMN has_update BOOLEAN DEFAULT FALSE; - RAISE NOTICE '已添加 books.has_update'; - END IF; - IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_schema = 'public' AND table_name = 'books' AND column_name = 'last_update_chapter_id') THEN - ALTER TABLE books ADD COLUMN last_update_chapter_id VARCHAR; - RAISE NOTICE '已添加 books.last_update_chapter_id'; - END IF; -END $$; - --- ========== 6. orders 表(若无则创建) ========== -CREATE TABLE IF NOT EXISTS orders ( - id VARCHAR NOT NULL PRIMARY KEY, - user_id VARCHAR NOT NULL REFERENCES users(id), - plan_id VARCHAR NOT NULL, - plan_name VARCHAR NOT NULL, - amount INTEGER NOT NULL, - currency VARCHAR DEFAULT 'CNY', - payment_method VARCHAR NOT NULL, - status VARCHAR DEFAULT 'pending', - trade_no VARCHAR, - paid_at TIMESTAMP WITH TIME ZONE, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - expired_at TIMESTAMP WITH TIME ZONE -); -CREATE INDEX IF NOT EXISTS ix_orders_user_id ON orders(user_id); -CREATE INDEX IF NOT EXISTS ix_orders_trade_no ON orders(trade_no); -CREATE INDEX IF NOT EXISTS ix_orders_status ON orders(status); - --- ========== 7. sms_verification_codes 表(若无则创建) ========== -CREATE TABLE IF NOT EXISTS sms_verification_codes ( - id VARCHAR PRIMARY KEY, - phone VARCHAR NOT NULL, - code VARCHAR NOT NULL, - purpose VARCHAR NOT NULL, - is_used BOOLEAN DEFAULT FALSE, - is_expired BOOLEAN DEFAULT FALSE, - expires_at TIMESTAMP WITH TIME ZONE NOT NULL, - created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - verified_at TIMESTAMP WITH TIME ZONE, - ip_address VARCHAR -); -CREATE INDEX IF NOT EXISTS idx_sms_phone ON sms_verification_codes(phone); -CREATE INDEX IF NOT EXISTS idx_sms_created_at ON sms_verification_codes(created_at); -CREATE INDEX IF NOT EXISTS idx_sms_purpose ON sms_verification_codes(purpose); -CREATE INDEX IF NOT EXISTS idx_sms_phone_purpose ON sms_verification_codes(phone, purpose); - --- ========== 8. memoir_states 表(若无则创建,供 create_all 未执行环境使用) ========== -CREATE TABLE IF NOT EXISTS memoir_states ( - id VARCHAR NOT NULL PRIMARY KEY, - user_id VARCHAR NOT NULL UNIQUE REFERENCES users(id), - stage_order JSONB DEFAULT '[]'::jsonb, - current_stage VARCHAR DEFAULT 'childhood', - covered_stages JSONB DEFAULT '[]'::jsonb, - slots JSONB NOT NULL, - updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() -); - -DO $$ -BEGIN - RAISE NOTICE 'sync_schema_to_models 迁移执行完成'; -END $$; diff --git a/api/routers/chapters.py b/api/routers/chapters.py deleted file mode 100644 index c8663d6..0000000 --- a/api/routers/chapters.py +++ /dev/null @@ -1,305 +0,0 @@ -""" -章节相关 API 路由 -""" -import logging -import os -from typing import List, Optional - -from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload - -from database import get_async_db -from database.models import Chapter as ChapterModel, ChapterSection -from database.models import User as UserModel -from middleware.auth import get_current_user -from app.agents.memoir.prompts import CHAPTER_CATEGORIES, CHAPTER_ORDER, STAGE_TO_ORDER -from services.memoir_images.schema import ( - completed_image_assets, - IMAGE_STATUS_COMPLETED, - IMAGE_STATUS_FAILED, - normalize_image_assets, -) -from services.memoir_images.serializers import memoir_image_to_dict -from services.memoir_images.settings import MemoirImageSettings -from services.memoir_images.storage import ( - CosDownloadUrlError, - TencentCosStorageService, - mark_image_delivery_unavailable, - normalize_cos_url, - resolve_image_storage_key, -) - -router = APIRouter(prefix="/api/chapters", tags=["chapters"]) -logger = logging.getLogger(__name__) - - -def _normalize_image_assets(images: list[dict] | None) -> list[dict]: - bucket = os.getenv("TENCENT_COS_BUCKET", "") - region = os.getenv("TENCENT_COS_REGION", "") - base_url = os.getenv("TENCENT_COS_BASE_URL", "") - storage = TencentCosStorageService.from_env() - settings = MemoirImageSettings.from_env() - source_assets = normalize_image_assets(images) - if not settings.enabled: - source_assets = completed_image_assets(source_assets) - normalized_assets: list[dict] = [] - - for item in source_assets: - asset = dict(item) - normalized_url = normalize_cos_url( - asset.get("url"), - bucket=bucket, - region=region, - base_url=base_url, - ) - storage_key = resolve_image_storage_key(asset) - if asset.get("status") == IMAGE_STATUS_COMPLETED and storage_key: - try: - asset["url"] = storage.get_download_url(storage_key) - except CosDownloadUrlError as exc: - logger.warning( - "章节图片签名失败: key=%s, retryable=%s, request_id=%s, error=%s", - storage_key, exc.retryable, exc.request_id, exc, - ) - asset = mark_image_delivery_unavailable(asset) - except Exception as exc: - logger.warning("章节图片签名失败: key=%s, error=%s", storage_key, exc) - asset = mark_image_delivery_unavailable(asset) - else: - asset["url"] = normalized_url - asset.pop("storage_key", None) - normalized_assets.append(asset) - - return normalized_assets - - -def _is_image_permanently_unavailable(rec) -> bool: - """配图是否应清理:失败不可恢复,或 completed 但无 url/storage_key(损坏数据)""" - if not rec: - return False - status = getattr(rec, "status", None) or "" - retryable = getattr(rec, "retryable", None) - url = getattr(rec, "url", None) - storage_key = getattr(rec, "storage_key", None) - if status == IMAGE_STATUS_FAILED and retryable is False: - return True - if status == IMAGE_STATUS_COMPLETED and not url and not storage_key: - return True - return False - - -async def _cleanup_permanently_unavailable_images(ch: ChapterModel, db: AsyncSession) -> None: - """清理章节中永久不可用的配图:section.image_id 置空,删除 memoir_images 记录""" - sections = getattr(ch, "sections", None) or [] - cleaned = False - for s in sections: - rec = getattr(s, "image_record", None) - if rec and _is_image_permanently_unavailable(rec): - logger.info("清理不可用配图: chapter=%s, section=%s", ch.id, s.id) - s.image_id = None - await db.delete(rec) - cleaned = True - if cleaned: - await db.commit() - await db.refresh(ch) - - -def _section_image_to_dict(section) -> dict | None: - """从 section.image_id 关联的 memoir_images(image_record)取配图。""" - if getattr(section, "image_record", None): - return memoir_image_to_dict(section.image_record) - return None - - -def _chapter_cover_to_dict(ch) -> dict | None: - """优先从 memoir_images 表(section_id 为空的一条)取封面,否则回退到 chapter.cover_image JSON。""" - images = getattr(ch, "images", None) or [] - for m in images: - if getattr(m, "section_id", None) is None: - return memoir_image_to_dict(m) - if getattr(ch, "cover_image", None) and isinstance(ch.cover_image, dict): - return ch.cover_image - return None - - -def _sections_to_content_and_images(ch): - """ - 从 chapter.sections 按 order_index 顺序拼出 content 与 images,保证每段文字与配图一一对应。 - 客户端依赖 content 中的占位符(与 images 中每项的 placeholder 一致)来切分正文并插入图片。 - """ - sections = getattr(ch, "sections", None) or [] - ordered = sorted(sections, key=lambda s: getattr(s, "order_index", 0)) - parts = [] - images = [] - for s in ordered: - text = (s.content or "").strip() - if text: - parts.append(text) - img = _section_image_to_dict(s) - if img: - images.append(img) - placeholder = (img.get("placeholder") or "").strip() - if placeholder: - parts.append(placeholder) - content = "\n\n".join(parts) if parts else "" - return content, images - - -def _chapter_to_dict(ch: ChapterModel) -> dict: - content, images_list = _sections_to_content_and_images(ch) - normalized_images = _normalize_image_assets(images_list) - cover = _chapter_cover_to_dict(ch) - cover_normalized = _normalize_image_assets([cover] if cover else [])[0] if cover else None - sections_data = [] - if getattr(ch, "sections", None): - for s in sorted(ch.sections, key=lambda x: getattr(x, "order_index", 0)): - sec_img = _section_image_to_dict(s) - sec_img = _normalize_image_assets([sec_img] if sec_img else [])[0] if sec_img else None - sections_data.append({"content": (s.content or "").strip(), "image": sec_img}) - return { - "id": ch.id, - "title": ch.title, - "content": content, - "order_index": ch.order_index, - "status": ch.status, - "category": ch.category, - "images": normalized_images, - "cover_image": cover_normalized, - "sections": sections_data, - "updated_at": ch.updated_at.isoformat() if ch.updated_at else None, - "is_new": ch.is_new, - "source_segments": ch.source_segments or [], - } - - -@router.get("", response_model=List[dict]) -async def get_chapters( - current_user: UserModel = Depends(get_current_user), - is_new: Optional[bool] = Query(None, description="仅返回未读章节"), - db: AsyncSession = Depends(get_async_db) -): - """ - 获取用户所有章节(需要认证,仅返回 active 章节)。 - 始终返回全部 8 个预定义类别,没有内容的类别用占位符返回。 - """ - stmt = ( - select(ChapterModel) - .where( - ChapterModel.user_id == current_user.id, - ChapterModel.is_active == True, - ) - .options( - joinedload(ChapterModel.sections), - joinedload(ChapterModel.images), - joinedload(ChapterModel.sections).joinedload(ChapterSection.image_record), - ) - .order_by(ChapterModel.order_index) - ) - if is_new is True: - stmt = stmt.where(ChapterModel.is_new == True) - result = await db.execute(stmt) - chapters = result.unique().scalars().all() - - chapter_by_category: dict[str, ChapterModel] = {} - for ch in chapters: - if ch.category and ch.category not in chapter_by_category: - chapter_by_category[ch.category] = ch - - all_chapters: List[dict] = [] - for category in CHAPTER_ORDER: - ch = chapter_by_category.pop(category, None) - if ch: - await _cleanup_permanently_unavailable_images(ch, db) - all_chapters.append(_chapter_to_dict(ch)) - else: - if is_new is True: - continue - all_chapters.append({ - "id": f"placeholder_{category}", - "title": CHAPTER_CATEGORIES[category], - "content": "", - "order_index": STAGE_TO_ORDER.get(category, 999), - "status": "empty", - "category": category, - "images": [], - "cover_image": None, - "sections": [], - "updated_at": None, - "is_new": False, - "source_segments": [], - }) - - for ch in chapter_by_category.values(): - await _cleanup_permanently_unavailable_images(ch, db) - all_chapters.append(_chapter_to_dict(ch)) - - return all_chapters - - -@router.get("/{chapter_id}", response_model=dict) -async def get_chapter( - chapter_id: str, - current_user: UserModel = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) -): - """获取章节详情(需要认证,只能访问自己的章节)""" - stmt = ( - select(ChapterModel) - .where(ChapterModel.id == chapter_id) - .options( - joinedload(ChapterModel.sections), - joinedload(ChapterModel.images), - joinedload(ChapterModel.sections).joinedload(ChapterSection.image_record), - ) - ) - result = await db.execute(stmt) - chapter = result.unique().scalar_one_or_none() - if not chapter: - raise HTTPException(status_code=404, detail="Chapter not found") - if chapter.user_id != current_user.id: - raise HTTPException(status_code=403, detail="无权访问此章节") - await _cleanup_permanently_unavailable_images(chapter, db) - return _chapter_to_dict(chapter) - - -@router.delete("/{chapter_id}") -async def disable_chapter( - chapter_id: str, - current_user: UserModel = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) -): - """清除章节(将章节标记为 disabled,需要认证,只能操作自己的章节)""" - chapter = await db.get(ChapterModel, chapter_id) - if not chapter: - raise HTTPException(status_code=404, detail="Chapter not found") - - # 验证用户权限 - if chapter.user_id != current_user.id: - raise HTTPException(status_code=403, detail="无权操作此章节") - - # 将章节标记为 disabled(不物理删除) - chapter.is_active = False - await db.commit() - - return {"status": "ok", "message": "章节已清除"} - - -@router.post("/{chapter_id}/regenerate") -async def regenerate_chapter( - chapter_id: str, - current_user: UserModel = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) -): - """重新整理章节(需要认证,只能操作自己的章节)""" - chapter = await db.get(ChapterModel, chapter_id) - if not chapter: - raise HTTPException(status_code=404, detail="Chapter not found") - - # 验证用户权限 - if chapter.user_id != current_user.id: - raise HTTPException(status_code=403, detail="无权操作此章节") - - # TODO: 实现重新整理逻辑 - return {"status": "ok", "message": "Chapter regeneration triggered"} diff --git a/api/routers/conversations.py b/api/routers/conversations.py deleted file mode 100644 index 21a49d0..0000000 --- a/api/routers/conversations.py +++ /dev/null @@ -1,327 +0,0 @@ -""" -对话相关 API 路由 -""" -from datetime import datetime, timezone -from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query, Body -from pydantic import BaseModel -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import func, select -import uuid - -from database import get_async_db, Conversation, Segment, User -from database.models import Conversation as ConversationModel, Segment as SegmentModel -from middleware.auth import get_current_user -from database.models import User as UserModel - -router = APIRouter(prefix="/api/conversations", tags=["conversations"]) - - -def _datetime_to_timestamp_ms(value: datetime | None) -> int: - if value is None: - return int(datetime.now(timezone.utc).timestamp() * 1000) - if value.tzinfo is None: - value = value.replace(tzinfo=timezone.utc) - return int(value.timestamp() * 1000) - - -def _message_timestamp_ms(msg: dict, fallback: datetime | None) -> int: - raw_timestamp = msg.get("timestamp") - if isinstance(raw_timestamp, (int, float)): - return int(raw_timestamp) - if isinstance(raw_timestamp, str): - try: - return int(datetime.fromisoformat(raw_timestamp.replace("Z", "+00:00")).timestamp() * 1000) - except ValueError: - pass - return _datetime_to_timestamp_ms(fallback) - - -def _latest_message_time_ms(conversation: ConversationModel, history: list[dict]) -> int: - if conversation.last_message_at: - return _datetime_to_timestamp_ms(conversation.last_message_at) - if history: - return _message_timestamp_ms(history[-1], conversation.started_at) - return _datetime_to_timestamp_ms(conversation.started_at) - - -def _build_messages_from_history( - conversation_id: str, - history: list[dict], - fallback_timestamp: datetime | None, -) -> list[dict]: - messages: list[dict] = [] - seen_audio_sessions: set[str] = set() - - for idx, msg in enumerate(history): - role = msg.get("role") - message_type = msg.get("messageType", "text") - voice_session_id = msg.get("voiceSessionId") - if role == "human" and message_type == "audio" and voice_session_id: - if voice_session_id in seen_audio_sessions: - continue - seen_audio_sessions.add(voice_session_id) - - messages.append( - { - "id": f"{conversation_id}_msg_{idx}", - "conversationId": conversation_id, - "content": msg.get("content", ""), - "senderType": "user" if role == "human" else "assistant", - "timestamp": _message_timestamp_ms(msg, fallback_timestamp), - "messageType": message_type, - } - ) - - return messages - - -@router.get("") -async def get_conversations( - current_user: UserModel = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) -): - """获取当前用户的所有对话列表(需要认证)""" - stmt = select(ConversationModel).where( - ConversationModel.user_id == current_user.id - ).order_by(func.coalesce(ConversationModel.last_message_at, ConversationModel.started_at).desc()) - result = await db.execute(stmt) - conversations = result.scalars().all() - - # 转换为列表项格式 - from services.redis_service import redis_service - conversation_list = [] - for conv in conversations: - # 从Redis获取最新消息预览 - latest_message = None - history: list[dict] = [] - try: - history = await redis_service.get_conversation_history(conv.id) - if history: - latest_message = history[-1].get("content", "")[:50] # 取前50个字符 - except Exception: - pass - - conversation_list.append({ - "id": conv.id, - "title": conv.summary[:30] if conv.summary else "岁月知己", # 使用summary作为标题,如果没有则使用默认标题 - "avatarUrl": None, - "latestMessagePreview": latest_message or conv.summary, - "latestMessageTime": _latest_message_time_ms(conv, history), - "unreadCount": 0, - "isDefaultAssistant": conv.summary is None # 如果没有summary,则认为是默认助手 - }) - - return conversation_list - - -@router.post("") -async def create_conversation( - current_user: UserModel = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) -): - """创建新对话(需要认证)。对话轮数在每次发送消息时校验。""" - conversation = ConversationModel( - id=str(uuid.uuid4()), - user_id=current_user.id, - started_at=datetime.now(timezone.utc), - status="active" - ) - db.add(conversation) - await db.commit() - await db.refresh(conversation) - - return { - "id": conversation.id, - "user_id": conversation.user_id, - "started_at": conversation.started_at.isoformat(), - "status": conversation.status - } - - -@router.get("/{conversation_id}") -async def get_conversation( - conversation_id: str, - current_user: UserModel = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) -): - """获取对话详情(需要认证,只能访问自己的对话)""" - conversation = await db.get(ConversationModel, conversation_id) - if not conversation: - raise HTTPException(status_code=404, detail="Conversation not found") - - # 验证用户权限 - if conversation.user_id != current_user.id: - raise HTTPException(status_code=403, detail="无权访问此对话") - - return { - "id": conversation.id, - "user_id": conversation.user_id, - "started_at": conversation.started_at.isoformat(), - "ended_at": conversation.ended_at.isoformat() if conversation.ended_at else None, - "duration_seconds": conversation.duration_seconds, - "summary": conversation.summary, - "status": conversation.status, - "current_topic": conversation.current_topic, - "conversation_stage": conversation.conversation_stage - } - - -@router.post("/{conversation_id}/end") -async def end_conversation( - conversation_id: str, - current_user: UserModel = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) -): - """结束对话(需要认证,只能结束自己的对话)""" - conversation = await db.get(ConversationModel, conversation_id) - if not conversation: - raise HTTPException(status_code=404, detail="Conversation not found") - - # 验证用户权限 - if conversation.user_id != current_user.id: - raise HTTPException(status_code=403, detail="无权操作此对话") - - conversation.status = "ended" - conversation.ended_at = datetime.now(timezone.utc) - - if conversation.started_at: - duration = (conversation.ended_at - conversation.started_at).total_seconds() - conversation.duration_seconds = int(duration) - - await db.commit() - - return { - "id": conversation.id, - "status": conversation.status, - "ended_at": conversation.ended_at.isoformat(), - "duration_seconds": conversation.duration_seconds - } - - -@router.delete("/{conversation_id}") -async def delete_conversation( - conversation_id: str, - current_user: UserModel = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) -): - """删除对话(需要认证,只能删除自己的对话)""" - conversation = await db.get(ConversationModel, conversation_id) - if not conversation: - raise HTTPException(status_code=404, detail="Conversation not found") - - # 验证用户权限 - if conversation.user_id != current_user.id: - raise HTTPException(status_code=403, detail="无权删除此对话") - - # 删除Redis中的对话历史 - from services.redis_service import redis_service - try: - await redis_service.clear_conversation_history(conversation_id) - except: - pass - - # 删除数据库中的对话(级联删除segments) - await db.delete(conversation) - await db.commit() - - return {"message": "对话已删除"} - - -@router.get("/{conversation_id}/messages") -async def get_messages( - conversation_id: str, - current_user: UserModel = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) -): - """获取对话的消息列表(需要认证,只能访问自己的对话)""" - # 验证对话存在且属于当前用户 - conversation = await db.get(ConversationModel, conversation_id) - if not conversation: - raise HTTPException(status_code=404, detail="Conversation not found") - - if conversation.user_id != current_user.id: - raise HTTPException(status_code=403, detail="无权访问此对话") - - # 从Redis获取消息历史 - from services.redis_service import redis_service - try: - history = await redis_service.get_conversation_history(conversation_id) - return _build_messages_from_history( - conversation_id=conversation_id, - history=history, - fallback_timestamp=conversation.started_at, - ) - except Exception: - # 如果Redis中没有数据,返回空列表 - return [] - - -@router.post("/{conversation_id}/organize") -async def organize_conversation( - conversation_id: str, - current_user: UserModel = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) -): - """ - 整理对话内容成章节(需要认证,只能整理自己的对话) - - 手动触发对话整理,将对话中的内容整理成回忆录章节 - """ - import logging - logger = logging.getLogger(__name__) - - # 验证对话存在且属于当前用户 - conversation = await db.get(ConversationModel, conversation_id) - if not conversation: - raise HTTPException(status_code=404, detail="Conversation not found") - - if conversation.user_id != current_user.id: - raise HTTPException(status_code=403, detail="无权操作此对话") - - # 获取所有未处理的段落 - stmt = select(SegmentModel).where( - SegmentModel.conversation_id == conversation_id, - SegmentModel.processed == False - ) - result = await db.execute(stmt) - segments = result.scalars().all() - - if not segments: - # 如果没有未处理的段落,尝试处理所有段落 - stmt = select(SegmentModel).where( - SegmentModel.conversation_id == conversation_id - ) - result = await db.execute(stmt) - segments = result.scalars().all() - - if not segments: - raise HTTPException(status_code=400, detail="该对话没有可整理的内容") - - # 免费版仅允许 1 个章节整理,Pro/Pro+ 无限制 - from routers.quota import get_chapter_count, check_can_submit_organize - chapter_count = await get_chapter_count(current_user.id, db) - can_submit, quota_message = check_can_submit_organize( - current_user.subscription_type, chapter_count - ) - if not can_submit: - raise HTTPException(status_code=403, detail=quota_message) - - # 提交到Celery任务处理 - try: - from routers.websocket import manager - from tasks.memoir_tasks import process_memoir_segments - - segment_ids = [seg.id for seg in segments] - process_memoir_segments.delay(conversation.user_id, segment_ids) - logger.info(f"手动触发对话整理: conversation_id={conversation_id}, segments={len(segment_ids)}") - - return { - "message": "对话整理任务已提交", - "conversation_id": conversation_id, - "segments_count": len(segment_ids) - } - except Exception as e: - logger.error(f"提交整理任务失败: {e}") - raise HTTPException(status_code=500, detail=f"提交整理任务失败: {str(e)}") - diff --git a/api/routers/websocket.py b/api/routers/websocket.py deleted file mode 100644 index 7b49a75..0000000 --- a/api/routers/websocket.py +++ /dev/null @@ -1,1148 +0,0 @@ -""" -WebSocket 路由:实时对话通信 -支持异步 Agent 调用和 Redis 会话存储 -""" -import asyncio -import logging -import uuid -from dataclasses import dataclass, field -from datetime import datetime, timezone -from enum import Enum -from typing import Dict, List, Optional, Set, Tuple - -from fastapi import WebSocket, WebSocketDisconnect, HTTPException, status -from starlette.websockets import WebSocketState -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents import ConversationAgent, MemoryAgent -from app.agents.memoir import BackgroundTaskRunner -from database import get_async_db -from database.models import Conversation, Segment -from database.models import User as UserModel -from services.auth_service import verify_token -from services.memoir_state_service import get_or_create_state -from services import asr_service, redis_service -from app.agents.chat.prompts_profile import format_user_profile_context - -logger = logging.getLogger(__name__) -LEGACY_VOICE_SESSION_ID = "legacy" - - -class MessageType(str, Enum): - """WebSocket 消息类型""" - CONNECT = "connect" - RECORDING_STARTED = "recording_started" # 客户端开始录音,用于服务端 5s 后发「我在认真听」 - AUDIO_CHUNK = "audio_chunk" - AUDIO_SEGMENT = "audio_segment" # 分段语音消息(长语音持续上传) - AUDIO_MESSAGE = "audio_message" # 完整音频消息(类似微信语音) - TRANSCRIBE_ONLY = "transcribe_only" # 仅转写,不落库、不触发 Agent,只返回转写结果 - TEXT = "text" # 文本消息 - TRANSCRIPT = "transcript" # 语音转文字结果 - AGENT_RESPONSE = "agent_response" - TTS_AUDIO = "tts_audio" - END_CONVERSATION = "end_conversation" - MEMOIR_UPDATE = "memoir_update" - ERROR = "error" - - -# 连接管理 -class ConnectionManager: - """WebSocket 连接管理器""" - - def __init__(self): - self.active_connections: Dict[str, WebSocket] = {} - self.segment_states: Dict[Tuple[str, str], "SegmentStreamState"] = {} - # ConversationAgent 现在是无状态的(会话存储在 Redis),可以复用 - self.conversation_agent = ConversationAgent() - self.memory_agent = MemoryAgent() - self.background_runner = BackgroundTaskRunner() - - async def connect(self, websocket: WebSocket, conversation_id: str): - """建立连接""" - await websocket.accept() - self.active_connections[conversation_id] = websocket - - async def disconnect(self, conversation_id: str): - """断开连接""" - if conversation_id in self.active_connections: - del self.active_connections[conversation_id] - stale_keys = [ - key - for key, state in self.segment_states.items() - if key[0] == conversation_id and not state.active_tasks - ] - for key in stale_keys: - self.segment_states.pop(key, None) - # 清除 Redis 中的会话记忆(可选,也可以保留用于恢复) - # await self.conversation_agent.clear_memory(conversation_id) - - def get_or_create_segment_state( - self, - conversation_id: str, - voice_session_id: str, - ) -> "SegmentStreamState": - state_key = (conversation_id, voice_session_id) - if state_key not in self.segment_states: - self.segment_states[state_key] = SegmentStreamState() - return self.segment_states[state_key] - - def register_segment_task( - self, - conversation_id: str, - voice_session_id: str, - task: asyncio.Task, - ) -> None: - state_key = (conversation_id, voice_session_id) - state = self.get_or_create_segment_state(conversation_id, voice_session_id) - state.active_tasks.add(task) - - def _cleanup(done_task: asyncio.Task) -> None: - state.active_tasks.discard(done_task) - if not state.active_tasks and conversation_id not in self.active_connections: - self.segment_states.pop(state_key, None) - if done_task.cancelled(): - return - exc = done_task.exception() - if exc: - logger.error( - "分段处理任务异常 " - f"(conversation_id={conversation_id}, voice_session_id={voice_session_id}): {exc}", - exc_info=True, - ) - - task.add_done_callback(_cleanup) - - async def send_message(self, conversation_id: str, message: dict): - """发送消息""" - if conversation_id in self.active_connections: - websocket = self.active_connections[conversation_id] - try: - # 尝试发送消息,如果连接已关闭会抛出异常 - await websocket.send_json(message) - except (RuntimeError, Exception) as e: - logger.warning(f"发送消息失败 (conversation_id={conversation_id}): {e}") - # 如果发送失败,从连接列表中移除 - if conversation_id in self.active_connections: - del self.active_connections[conversation_id] - - async def receive_message(self, conversation_id: str) -> dict: - """接收消息""" - if conversation_id in self.active_connections: - websocket = self.active_connections[conversation_id] - return await websocket.receive_json() - raise HTTPException(status_code=404, detail="Connection not found") - - -manager = ConnectionManager() - - -@dataclass -class SegmentStreamState: - """会话内分段处理状态(用于并行 ASR + 有序聚合)""" - - lock: asyncio.Lock = field(default_factory=asyncio.Lock) - pending_indices: Set[int] = field(default_factory=set) - processed_indices: Set[int] = field(default_factory=set) - buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict) - consumed_index: int = -1 - active_tasks: Set[asyncio.Task] = field(default_factory=set) - # 录音开始约 5s 后只发一次「我在认真听」;若用户提前结束录音则取消待发 - listening_feedback_sent: bool = False - listening_feedback_task: Optional[asyncio.Task] = None - - -def _utc_now() -> datetime: - return datetime.now(timezone.utc) - - -def _mark_conversation_active(conversation: Conversation, at: Optional[datetime] = None) -> datetime: - activity_time = at or _utc_now() - conversation.last_message_at = activity_time - return activity_time - - -def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str: - if voice_session_id: - return str(voice_session_id) - return LEGACY_VOICE_SESSION_ID - - -def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]: - if not client_segment_id: - return None - session_id, separator, _ = client_segment_id.rpartition("-") - if separator and session_id: - return session_id - return None - - -def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str: - """构建分段语音的幂等标识(conversation_id + voice_session_id + segment_index)。""" - return f"audio-segment:{voice_session_id}:{segment_index}" - - -def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]: - """从 audio_url 中解析 voice_session_id 与 segment_index。兼容旧格式 audio-segment:{index}。""" - prefix = "audio-segment:" - if not audio_url or not audio_url.startswith(prefix): - return None - payload = audio_url[len(prefix):] - voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":") - try: - if separator: - return (_normalize_voice_session_id(voice_session_id_raw), int(segment_index_raw)) - return (LEGACY_VOICE_SESSION_ID, int(payload)) - except ValueError: - return None - - -def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]: - scope = _extract_segment_scope(audio_url) - if scope: - return scope[0] - return None - - -def _is_transcribe_failure(transcript_text: Optional[str]) -> bool: - if not transcript_text: - return True - return transcript_text.startswith("转写失败") - - -async def _find_existing_segment_by_index( - db: AsyncSession, - conversation_id: str, - voice_session_id: str, - segment_index: int, -) -> Optional[Segment]: - """ - 按 conversation + voice_session_id + segment_index 查找已落库分段。 - 说明:测试桩的 execute() 不会真正执行 where,所以这里做一次 Python 侧过滤,兼容真实 DB 和单测桩。 - """ - segment_audio_url = _build_segment_audio_url(voice_session_id, segment_index) - stmt = select(Segment).where( - Segment.conversation_id == conversation_id, - Segment.audio_url == segment_audio_url, - ).order_by(Segment.created_at.desc()) - result = await db.execute(stmt) - candidates = result.scalars().all() - for item in candidates: - if item.conversation_id == conversation_id and item.audio_url == segment_audio_url: - return item - return None - - -async def _get_persisted_contiguous_segment_index( - db: AsyncSession, - conversation_id: str, - voice_session_id: str, -) -> int: - """读取数据库中当前 voice session 已连续落库的最大 segment_index,用于重连恢复。""" - stmt = select(Segment).where(Segment.conversation_id == conversation_id) - result = await db.execute(stmt) - candidates = result.scalars().all() - - persisted_indices: Set[int] = set() - for item in candidates: - if item.conversation_id != conversation_id: - continue - segment_scope = _extract_segment_scope(item.audio_url) - if not segment_scope: - continue - item_voice_session_id, item_index = segment_scope - if item_voice_session_id != voice_session_id: - continue - persisted_indices.add(item_index) - - contiguous_index = -1 - while contiguous_index + 1 in persisted_indices: - contiguous_index += 1 - return contiguous_index - - -LISTENING_FEEDBACK_DELAY_SEC = 5.0 -LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。" - - -async def _send_segment_transition_feedback( - conversation_id: str, - segment_index: int, - manager: ConnectionManager, -) -> None: - """发送一次「我在认真听」陪伴式过渡反馈(由延迟任务调用)。""" - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": { - "text": LISTENING_FEEDBACK_TEXT, - "transition": True, - "segment_index": segment_index, - }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - - -async def _delayed_listening_feedback( - conversation_id: str, - voice_session_id: str, - manager: ConnectionManager, -) -> None: - """录音开始后延迟 5 秒发送一次「我在认真听」,本会话内只发一次;若用户已结束录音则不再发送。""" - await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC) - state = manager.get_or_create_segment_state(conversation_id, voice_session_id) - async with state.lock: - if state.listening_feedback_sent: - return - state.listening_feedback_sent = True - state.listening_feedback_task = None - await _send_segment_transition_feedback(conversation_id, 0, manager) - - -async def _process_audio_segment_async( - conversation_id: str, - user_id: str, - voice_session_id: str, - segment_index: int, - audio_base64: str, - audio_duration: int, - is_last: bool, - manager: ConnectionManager, -) -> None: - """分段语音的异步处理:并行 ASR + 幂等落库 + 有序聚合触发 Agent。""" - state = manager.get_or_create_segment_state(conversation_id, voice_session_id) - - try: - # 每个分段任务使用独立 DB Session,避免与主循环共享同一 AsyncSession 导致并发冲突。 - async for db in get_async_db(): - conversation = await db.get(Conversation, conversation_id) - user = await db.get(UserModel, user_id) - if not conversation: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "对话不存在,分段处理已取消"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - return - if not user: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "用户不存在,分段处理已取消"}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - return - - async with state.lock: - should_prime_state = ( - state.consumed_index < 0 - and not state.processed_indices - and not state.buffered_transcripts - ) - - if should_prime_state: - persisted_contiguous_index = await _get_persisted_contiguous_segment_index( - db=db, - conversation_id=conversation_id, - voice_session_id=voice_session_id, - ) - if persisted_contiguous_index >= 0: - async with state.lock: - state.consumed_index = max(state.consumed_index, persisted_contiguous_index) - - transcript_text = await asr_service.transcribe(audio_base64) - await manager.send_message(conversation_id, { - "type": MessageType.TRANSCRIPT, - "conversation_id": conversation_id, - "data": { - "text": transcript_text or "", - "audio_duration": audio_duration, - "voice_session_id": voice_session_id, - "segment_index": segment_index, - "is_last": is_last, - }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - - if _is_transcribe_failure(transcript_text): - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": { - "message": f"分段 {segment_index} 转写失败,请重试该片段", - "segment_index": segment_index, - }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - return - - existing_segment = await _find_existing_segment_by_index( - db=db, - conversation_id=conversation_id, - voice_session_id=voice_session_id, - segment_index=segment_index, - ) - if existing_segment: - # 该分段已成功入库,视为重传:不重复入库、不重复触发 Agent。 - async with state.lock: - state.processed_indices.add(segment_index) - logger.info( - "分段已存在,按幂等处理跳过: " - f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}" - ) - return - else: - segment = Segment( - id=str(uuid.uuid4()), - conversation_id=conversation_id, - transcript_text=transcript_text or "", - audio_url=_build_segment_audio_url(voice_session_id, segment_index), - processed=False, - ) - db.add(segment) - user_message_timestamp = _mark_conversation_active(conversation) - await db.commit() - await db.refresh(segment) - await manager.background_runner.queue_message(conversation.user_id, segment.id) - - ready_segments: List[Tuple[int, str, Segment]] = [] - async with state.lock: - state.processed_indices.add(segment_index) - state.buffered_transcripts[segment_index] = (transcript_text or "", segment) - - next_index = state.consumed_index + 1 - while next_index in state.buffered_transcripts: - text, seg = state.buffered_transcripts.pop(next_index) - ready_segments.append((next_index, text, seg)) - state.consumed_index = next_index - next_index += 1 - - # 仅当前缀分段连续时才触发 Agent,保证增量上下文顺序正确。 - for _, ordered_text, ordered_segment in ready_segments: - await process_user_message( - conversation_id=conversation_id, - user_message=ordered_text, - conversation=conversation, - segment=ordered_segment, - db=db, - manager=manager, - user=user, - user_message_timestamp=ordered_segment.created_at or user_message_timestamp, - ) - - break - - except Exception as e: - logger.error( - f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}", - exc_info=True, - ) - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": { - "message": f"分段处理失败: {str(e)}", - "segment_index": segment_index, - }, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - finally: - async with state.lock: - state.pending_indices.discard(segment_index) - - -async def websocket_endpoint( - websocket: WebSocket, - conversation_id: str -): - """ - WebSocket 端点:处理实时对话 - - Args: - websocket: WebSocket 连接 - conversation_id: 对话 ID - """ - # 从查询参数获取token - token = websocket.query_params.get("token") - if not token: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌") - return - - # 验证JWT令牌 - payload = verify_token(token) - if not payload or payload.get("type") != "access": - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌") - return - - user_id = payload.get("sub") - if not user_id: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容") - return - - # 验证用户是否存在 - async for db in get_async_db(): - user = await db.get(UserModel, user_id) - if not user: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在") - return - - await manager.connect(websocket, conversation_id) - - try: - # 发送连接确认 - await manager.send_message(conversation_id, { - "type": MessageType.CONNECT, - "conversation_id": conversation_id, - "data": {"status": "connected"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - - # 从数据库获取对话信息 - conversation = await db.get(Conversation, conversation_id) - if not conversation: - # 如果对话不存在,创建新对话 - conversation = Conversation( - id=conversation_id, - user_id=user_id, - started_at=datetime.now(timezone.utc), - status="active" - ) - db.add(conversation) - await db.commit() - else: - # 验证用户权限:只能访问自己的对话 - if conversation.user_id != user_id: - try: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "无权访问此对话"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - except Exception: - pass # 如果发送失败,直接关闭连接 - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话") - return - - - # 首次连接时检查:若 Redis 已有历史(用户曾进入过此对话),不再发送开场白,避免重复/自问自答 - history = await redis_service.get_conversation_history(conversation_id) - if not history: - # 空对话:发送开场白(资料收集或正式访谈) - missing_profile = _get_missing_profile_fields(user) - if missing_profile: - try: - greetings = await manager.conversation_agent.generate_profile_greeting( - conversation_id=conversation_id, - missing_fields=missing_profile, - nickname=user.nickname or "", - ) - import asyncio as _asyncio_greet - for i, text in enumerate(greetings): - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": {"text": text, "index": i, "total": len(greetings)}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - if i < len(greetings) - 1: - await _asyncio_greet.sleep(0.5) - except Exception as e: - logger.error(f"发送资料收集开场白失败: {e}", exc_info=True) - else: - # 资料已完整:AI 先开口提问 - try: - state = await get_or_create_state(user_id, db) - user_profile_context = format_user_profile_context( - birth_year=user.birth_year, - birth_place=user.birth_place, - grew_up_place=user.grew_up_place, - occupation=user.occupation, - ) - opening_messages = await manager.conversation_agent.generate_opening_message( - conversation_id=conversation_id, - memoir_state=state, - user_profile_context=user_profile_context, - ) - import asyncio as _asyncio_open - for i, text in enumerate(opening_messages): - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": {"text": text, "index": i, "total": len(opening_messages)}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - if i < len(opening_messages) - 1: - await _asyncio_open.sleep(0.5) - except Exception as e: - logger.error(f"发送空对话开场白失败: {e}", exc_info=True) - - # 主循环:处理消息 - while True: - try: - if websocket.application_state != WebSocketState.CONNECTED: - logger.info(f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}") - break - message = await websocket.receive_json() - msg_type = message.get("type") - - if msg_type == MessageType.TEXT: - # 处理文本消息 - text_message = message.get("data", {}).get("text", "") - - if text_message: - # 校验对话轮数配额 - from routers.quota import get_segment_count, check_can_send_message - seg_count = await get_segment_count(user_id, db) - can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count) - if not can_send: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - continue - - # 保存段落到数据库 - segment = Segment( - id=str(uuid.uuid4()), - conversation_id=conversation_id, - transcript_text=text_message, - processed=False - ) - db.add(segment) - user_message_timestamp = _mark_conversation_active(conversation) - await db.commit() - await db.refresh(segment) - await manager.background_runner.queue_message(conversation.user_id, segment.id) - - # Agent 生成回应 - await process_user_message( - conversation_id=conversation_id, - user_message=text_message, - conversation=conversation, - segment=segment, - db=db, - manager=manager, - user=user, - user_message_timestamp=segment.created_at or user_message_timestamp, - ) - - elif msg_type == MessageType.RECORDING_STARTED: - # 用户点击开始录音:启动 5s 定时器,到时发一次「我在认真听」 - data = message.get("data", {}) - voice_session_id = _normalize_voice_session_id(data.get("voice_session_id")) - segment_state = manager.get_or_create_segment_state( - conversation_id, - voice_session_id, - ) - async with segment_state.lock: - if segment_state.listening_feedback_task is not None and not segment_state.listening_feedback_task.done(): - continue # 本会话已有待发任务,不重复 - if segment_state.listening_feedback_sent: - continue - delayed_task = asyncio.create_task( - _delayed_listening_feedback( - conversation_id=conversation_id, - voice_session_id=voice_session_id, - manager=manager, - ) - ) - segment_state.listening_feedback_task = delayed_task - - elif msg_type == MessageType.AUDIO_SEGMENT: - # 处理分段语音消息(长语音持续上传) - data = message.get("data", {}) - audio_base64 = data.get("audio_base64", "") - segment_index_raw = data.get("segment_index") - voice_session_id = _normalize_voice_session_id( - data.get("voice_session_id") - or _voice_session_id_from_client_segment_id(data.get("client_segment_id")) - ) - is_last = bool(data.get("is_last", False)) - audio_duration = int(data.get("duration", 0) or 0) - - if not audio_base64: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "缺少 audio_base64"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - continue - - if segment_index_raw is None: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "缺少 segment_index"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - continue - - try: - segment_index = int(segment_index_raw) - except (TypeError, ValueError): - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "segment_index 必须为整数"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - continue - - if segment_index < 0: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "segment_index 不能为负数"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - continue - - # 校验对话轮数配额(分段也计入对话轮次) - from routers.quota import get_segment_count, check_can_send_message - seg_count = await get_segment_count(user_id, db) - can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count) - if not can_send: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - continue - - segment_state = manager.get_or_create_segment_state( - conversation_id, - voice_session_id, - ) - should_process = False - async with segment_state.lock: - already_seen = ( - segment_index in segment_state.pending_indices - or segment_index in segment_state.processed_indices - or segment_index <= segment_state.consumed_index - ) - if not already_seen: - segment_state.pending_indices.add(segment_index) - should_process = True - - if not should_process: - logger.info( - "收到重复分段,跳过处理: " - f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}" - ) - continue - - # 若本段是用户结束录音的最后一段,取消尚未发出的「我在认真听」,避免结束后再说 - if is_last: - async with segment_state.lock: - t = segment_state.listening_feedback_task - segment_state.listening_feedback_task = None - if t is not None and not t.done(): - t.cancel() - - task = asyncio.create_task( - _process_audio_segment_async( - conversation_id=conversation_id, - user_id=user_id, - voice_session_id=voice_session_id, - segment_index=segment_index, - audio_base64=audio_base64, - audio_duration=audio_duration, - is_last=is_last, - manager=manager, - ) - ) - manager.register_segment_task(conversation_id, voice_session_id, task) - - elif msg_type == MessageType.AUDIO_MESSAGE: - # 处理完整音频消息(类似微信语音) - data = message.get("data", {}) - audio_base64 = data.get("audio_base64", "") - audio_duration = data.get("duration", 0) - - if audio_base64: - # 校验对话轮数配额 - from routers.quota import get_segment_count, check_can_send_message - seg_count = await get_segment_count(user_id, db) - can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count) - if not can_send: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - continue - - logger.info(f"收到音频消息,时长: {audio_duration}s") - - try: - # 1. ASR 转写 - transcript_text = await asr_service.transcribe(audio_base64) - logger.info(f"ASR 转写结果: {transcript_text}") - - # 2. 发送转写结果给客户端 - await manager.send_message(conversation_id, { - "type": MessageType.TRANSCRIPT, - "conversation_id": conversation_id, - "data": { - "text": transcript_text, - "audio_duration": audio_duration - }, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - - # 3. 保存段落到数据库(包含转写文本和音频信息) - segment = Segment( - id=str(uuid.uuid4()), - conversation_id=conversation_id, - transcript_text=transcript_text, - audio_url=f"audio:{audio_duration}s", # 简化存储,标记为音频消息 - processed=False - ) - db.add(segment) - user_message_timestamp = _mark_conversation_active(conversation) - await db.commit() - await db.refresh(segment) - await manager.background_runner.queue_message(conversation.user_id, segment.id) - - # 4. Agent 生成回应(基于转写文本) - if transcript_text and not transcript_text.startswith("转写失败"): - await process_user_message( - conversation_id=conversation_id, - user_message=transcript_text, - conversation=conversation, - segment=segment, - db=db, - manager=manager, - user=user, - user_message_timestamp=segment.created_at or user_message_timestamp, - ) - else: - # 转写失败,发送错误消息 - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "语音转写失败,请重试或使用文字输入"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - - except Exception as e: - logger.error(f"处理音频消息失败: {e}", exc_info=True) - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": f"处理音频消息失败: {str(e)}"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - - elif msg_type == MessageType.TRANSCRIBE_ONLY: - # 仅转写:不落库、不触发 Agent,只把识别结果返回给客户端 - data = message.get("data", {}) - audio_base64 = data.get("audio_base64", "") - if not audio_base64: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": "缺少 audio_base64"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - continue - try: - transcript_text = await asr_service.transcribe(audio_base64) - await manager.send_message(conversation_id, { - "type": MessageType.TRANSCRIPT, - "conversation_id": conversation_id, - "data": {"text": transcript_text or ""}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - except Exception as e: - logger.error(f"仅转写失败: {e}", exc_info=True) - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": f"转写失败: {str(e)}"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - - elif msg_type == MessageType.END_CONVERSATION: - # 结束对话 - conversation.status = "ended" - conversation.ended_at = datetime.now(timezone.utc) - await db.commit() - - # 触发整理 Agent - await process_conversation_segments(conversation_id, db) - - await manager.send_message(conversation_id, { - "type": MessageType.END_CONVERSATION, - "conversation_id": conversation_id, - "data": {"status": "ended"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - break - - except RuntimeError as e: - # 检查是否是断开连接或未连接状态(如 accept 前/后连接被关闭) - error_msg = str(e) - if ( - "disconnect" in error_msg.lower() - or "Cannot call \"receive\"" in error_msg - or "accept" in error_msg.lower() and "not connected" in error_msg.lower() - ): - logger.info(f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}") - break - else: - logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True) - # 只在连接仍然活跃时发送错误消息 - if conversation_id in manager.active_connections: - try: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": str(e)}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - except Exception as send_error: - logger.warning(f"发送错误消息失败: {send_error}") - break - except WebSocketDisconnect: - logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}") - break - except Exception as e: - logger.error(f"处理消息时发生错误: {e}", exc_info=True) - # 只在连接仍然活跃时发送错误消息 - if conversation_id in manager.active_connections: - try: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": str(e)}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - except Exception as send_error: - logger.warning(f"发送错误消息失败: {send_error}") - break - - except WebSocketDisconnect: - logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}") - await manager.disconnect(conversation_id) - except Exception as e: - logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True) - await manager.disconnect(conversation_id) - finally: - # 确保清理连接 - await manager.disconnect(conversation_id) - - -def _get_missing_profile_fields(user: UserModel) -> list: - """检查用户缺失的资料字段""" - from app.agents.chat.prompts_profile import get_missing_profile_fields - return get_missing_profile_fields( - birth_year=user.birth_year, - birth_place=user.birth_place, - grew_up_place=user.grew_up_place, - occupation=user.occupation, - ) - - -def _get_filled_profile_fields(user: UserModel) -> dict: - """获取用户已有的资料字段(中文展示)""" - from app.agents.chat.prompts_profile import PROFILE_FIELD_NAMES - filled = {} - if user.birth_year: - filled["birth_year"] = str(user.birth_year) - if user.birth_place: - filled["birth_place"] = user.birth_place - if user.grew_up_place: - filled["grew_up_place"] = user.grew_up_place - if user.occupation: - filled["occupation"] = user.occupation - return filled - - -async def _apply_extracted_profile(user: UserModel, extracted: dict, db: AsyncSession): - """将提取到的资料信息保存到用户模型""" - changed = False - if "birth_year" in extracted and not user.birth_year: - user.birth_year = extracted["birth_year"] - changed = True - if "birth_place" in extracted and not user.birth_place: - user.birth_place = extracted["birth_place"] - changed = True - if "grew_up_place" in extracted and not user.grew_up_place: - user.grew_up_place = extracted["grew_up_place"] - changed = True - if "occupation" in extracted and not user.occupation: - user.occupation = extracted["occupation"] - changed = True - if changed: - await db.commit() - await db.refresh(user) - - -async def process_user_message( - conversation_id: str, - user_message: str, - conversation: Conversation, - segment: Segment, - db: AsyncSession, - manager: ConnectionManager, - user: UserModel = None, - user_message_timestamp: Optional[datetime] = None, -) -> None: - """ - 处理用户消息,生成Agent回应(异步版本) - 支持资料收集模式和正式访谈模式 - """ - import asyncio as _asyncio - - agent = manager.conversation_agent - - # --- 资料收集模式 --- - if user: - missing = _get_missing_profile_fields(user) - if missing: - try: - extracted = await agent.extract_profile_from_message( - user_message, missing, conversation_id=conversation_id - ) - if extracted: - await _apply_extracted_profile(user, extracted, db) - - remaining = _get_missing_profile_fields(user) - filled = _get_filled_profile_fields(user) - is_from_voice = bool(segment.audio_url) - responses = await agent.generate_profile_followup( - conversation_id=conversation_id, - user_message=user_message, - missing_fields=remaining, - filled_fields=filled, - nickname=user.nickname or "", - is_from_voice=is_from_voice, - voice_session_id=_voice_session_id_from_audio_url(segment.audio_url), - user_message_timestamp=user_message_timestamp, - ) - - segment.agent_response = "\n\n".join(responses) - _mark_conversation_active(conversation) - await db.commit() - - for i, response_text in enumerate(responses): - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": {"text": response_text, "index": i, "total": len(responses)}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - if i < len(responses) - 1: - await _asyncio.sleep(0.5) - return - except Exception as e: - logger.error(f"资料收集处理失败: {e}", exc_info=True) - - # --- 正式访谈模式 --- - state = await get_or_create_state(conversation.user_id, db) - - if conversation.conversation_stage != state.current_stage: - conversation.conversation_stage = state.current_stage - await db.commit() - - stmt_segments = select(Segment).where( - Segment.conversation_id == conversation_id - ).order_by(Segment.created_at) - result_segments = await db.execute(stmt_segments) - previous_segments = result_segments.scalars().all() - covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category] - - # 构建用户资料上下文 - user_profile_context = "" - if user: - from app.agents.chat.prompts_profile import format_user_profile_context - user_profile_context = format_user_profile_context( - birth_year=user.birth_year, - birth_place=user.birth_place, - grew_up_place=user.grew_up_place, - occupation=user.occupation, - ) - - try: - is_from_voice = bool(segment.audio_url) - responses = await agent.generate_response_with_state( - conversation_id=conversation_id, - user_message=user_message, - memoir_state=state, - user_profile_context=user_profile_context, - is_from_voice=is_from_voice, - voice_session_id=_voice_session_id_from_audio_url(segment.audio_url), - user_message_timestamp=user_message_timestamp, - ) - - segment.agent_response = "\n\n".join(responses) - _mark_conversation_active(conversation) - await db.commit() - - for i, response_text in enumerate(responses): - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": {"text": response_text, "index": i, "total": len(responses)}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - if i < len(responses) - 1: - await _asyncio.sleep(0.5) - - except Exception as e: - logger.error(f"处理用户消息失败: {e}", exc_info=True) - if conversation_id in manager.active_connections: - try: - await manager.send_message(conversation_id, { - "type": MessageType.ERROR, - "data": {"message": f"生成回应失败: {str(e)}"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - except Exception as send_error: - logger.warning(f"发送错误消息失败: {send_error}") - - -async def process_conversation_segments(conversation_id: str, db: AsyncSession): - """ - 处理对话段落,生成章节(对话结束时调用) - - 注意:大部分处理已通过 Celery 任务增量完成 - 这里立即提交所有待处理的段落到 Celery - - Args: - conversation_id: 对话 ID - db: 数据库会话 - """ - # 获取对话信息 - conversation = await db.get(Conversation, conversation_id) - if not conversation: - return - - # 获取所有未处理的段落 - stmt = select(Segment).where( - Segment.conversation_id == conversation_id, - Segment.processed == False - ) - result = await db.execute(stmt) - segments = result.scalars().all() - - if not segments: - # 没有未处理的段落,直接 flush 待处理任务 - await manager.background_runner.flush_pending(conversation.user_id) - return - - # 免费版仅允许 1 个章节整理,提交前校验 - from database.models import User as UserModel - from routers.quota import get_chapter_count, check_can_submit_organize - user = await db.get(UserModel, conversation.user_id) - if user: - chapter_count = await get_chapter_count(user.id, db) - can_submit, _ = check_can_submit_organize(user.subscription_type, chapter_count) - if not can_submit: - logger.info( - f"用户 {user.id} 章节配额已用尽,跳过提交整理任务: conversation_id={conversation_id}" - ) - await manager.background_runner.flush_pending(conversation.user_id) - return - - # 将未处理的段落直接提交到 Celery(不通过去抖) - segment_ids = [seg.id for seg in segments] - try: - from tasks.memoir_tasks import process_memoir_segments - process_memoir_segments.delay(conversation.user_id, segment_ids) - logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}") - except Exception as e: - logger.error(f"提交 Celery 任务失败: {e}") - - # 同时 flush 任何待处理的任务 - await manager.background_runner.flush_pending(conversation.user_id) diff --git a/api/scripts/__init__.py b/api/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/api/scripts/migrate_chapters_to_sections.py b/api/scripts/migrate_chapters_to_sections.py deleted file mode 100644 index 8e2451e..0000000 --- a/api/scripts/migrate_chapters_to_sections.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -将 chapters 的 content + images 迁移到 chapter_sections,并删除 chapters.content / chapters.images。 - -前置:已执行 api/migrations/add_chapter_sections.sql(创建 chapter_sections 表、chapters.cover_image 列)。 - -用法(在 api 目录下): - python -m scripts.migrate_chapters_to_sections -""" -import json -import os -import sys -import uuid - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from sqlalchemy import text -from app.core.db import sync_engine as engine -from app.features.memoir.memoir_images.parser import split_narrative_to_sections -from app.core.logging import get_logger, setup_logging - -setup_logging() -logger = get_logger(__name__) - - -def run(): - with engine.connect() as conn: - # 检查是否还存在 content 列(若已删则跳过) - r = conn.execute(text(""" - SELECT column_name FROM information_schema.columns - WHERE table_schema = 'public' AND table_name = 'chapters' AND column_name = 'content' - """)) - if r.fetchone() is None: - logger.info("chapters.content 已不存在,跳过迁移") - return - - # 读取所有有 content 的章节(原始列) - rows = conn.execute(text(""" - SELECT id, content, images FROM chapters WHERE content IS NOT NULL AND trim(content) != '' - """)).fetchall() - - for row in rows: - ch_id, content, images_raw = row[0], row[1], row[2] - images = json.loads(images_raw) if isinstance(images_raw, str) else (images_raw or []) - if not isinstance(images, list): - images = [] - - sections = split_narrative_to_sections(content or "") - if not sections: - # 无占位符:整段为一条 section,无图 - section_id = str(uuid.uuid4()).replace("-", "")[:32] - conn.execute(text(""" - INSERT INTO chapter_sections (id, chapter_id, order_index, content, image, updated_at) - VALUES (:id, :ch_id, 0, :content, NULL, NOW()) - """), {"id": section_id, "ch_id": ch_id, "content": (content or "").strip()}) - conn.commit() - logger.info("章节 %s: 1 条 section(无图)", ch_id) - continue - - first_cover = None - img_index = 0 - for order_idx, seg in enumerate(sections): - section_id = str(uuid.uuid4()).replace("-", "")[:32] - seg_content = seg.get("content") or "" - ph = seg.get("placeholder_info") - image_json = None - if ph is not None and img_index < len(images): - image_json = json.dumps(images[img_index]) if isinstance(images[img_index], dict) else None - if first_cover is None and image_json: - first_cover = image_json - img_index += 1 - - conn.execute(text(""" - INSERT INTO chapter_sections (id, chapter_id, order_index, content, image, updated_at) - VALUES (:id, :ch_id, :ord, :content, :img::jsonb, NOW()) - """), { - "id": section_id, - "ch_id": ch_id, - "ord": order_idx, - "content": seg_content, - "img": image_json, - }) - if first_cover: - conn.execute( - text("UPDATE chapters SET cover_image = :img::jsonb WHERE id = :id"), - {"img": first_cover, "id": ch_id}, - ) - conn.commit() - logger.info("章节 %s: %d 条 sections", ch_id, len(sections)) - - # 删除 chapters.content 和 chapters.images - conn.execute(text("ALTER TABLE chapters DROP COLUMN IF EXISTS content")) - conn.execute(text("ALTER TABLE chapters DROP COLUMN IF EXISTS images")) - conn.commit() - logger.info("已删除 chapters.content 与 chapters.images") - - -if __name__ == "__main__": - run() diff --git a/api/scripts/reprocess_user_memoir.py b/api/scripts/reprocess_user_memoir.py deleted file mode 100644 index 6a459be..0000000 --- a/api/scripts/reprocess_user_memoir.py +++ /dev/null @@ -1,686 +0,0 @@ -""" -重新整理用户历史对话为回忆录章节(远程预览 + 确认后写入) - -用法: - cd api - - # 第一步:预览(只读远程 DB,本地生成新章节,输出对比 Markdown) - python -m scripts.reprocess_user_memoir preview --phone 13800138000 - - # 第二步:确认后写入远程 DB - python -m scripts.reprocess_user_memoir apply --phone 13800138000 - -流程: - preview: - 1. SSH 隧道连接远程 PostgreSQL - 2. 读取用户现有章节 + 所有历史对话段落 - 3. 本地调用 LLM 生成新章节(不写入远程 DB) - 4. 输出对比 Markdown 表格 + 保存结果到 JSON 文件 - - apply: - 1. 读取上次 preview 保存的 JSON 文件 - 2. SSH 隧道连接远程 PostgreSQL - 3. 旧章节 is_active=False,写入新章节 -""" -import argparse -import json -import os -import sys -import uuid -import time -from datetime import datetime, timezone -from typing import Dict, List, Optional -from dataclasses import dataclass, field, asdict - -# 确保 api/ 目录在 sys.path 中 -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -# 配置由 app.core.config.settings 统一加载 - -import socket -import subprocess -import signal - -from sqlalchemy import create_engine, select -from sqlalchemy.orm import sessionmaker, Session - -from app.core.db import Base -from app.features.conversation.models import Conversation, Segment -from app.features.memoir.models import Book, Chapter, ChapterSection, MemoirState -from app.features.user.models import User -from app.core.dependencies import get_llm_provider -from app.agents.state_schema import MemoirStateSchema, SlotData, default_state -from app.agents.memoir.prompts import ( - get_creative_title_prompt, - get_narrative_prompt, - get_state_extraction_prompt, - inject_image_placeholder_template, - STAGE_TO_ORDER, -) -from app.features.memoir.memoir_images.json_payload import extract_json_payload -from app.features.memoir.memoir_images.parser import split_narrative_to_sections -from app.core.logging import get_logger, setup_logging - -setup_logging() -logger = get_logger(__name__) - -# ── SSH / DB 配置 ────────────────────────────────────────────── - -SSH_HOST = "1.15.29.57" -SSH_PORT = 22 -SSH_USER = "root" -SSH_KEY_PATH = os.path.join( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))), - "..", "certs", "key.crt", -) - -REMOTE_PG_HOST = "127.0.0.1" -REMOTE_PG_PORT = 5432 -PG_USER = "postgres" -PG_PASSWORD = "postgres" -PG_DATABASE = "life_echo" - -OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output") - -# ── 关键字阶段检测 ──────────────────────────────────────────── - -STAGE_KEYWORDS = { - "childhood": ["童年", "小时候", "出生", "家乡", "小镇"], - "education": ["上学", "学校", "老师", "同学", "教育", "大学"], - "career": ["工作", "职业", "事业", "公司", "同事", "创业"], - "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"], - "belief": ["信念", "价值观", "座右铭", "坚持", "原则"], -} - - -def _detect_stage(text: str, fallback: str) -> str: - msg = text.lower() - for stage, keywords in STAGE_KEYWORDS.items(): - if any(w in msg for w in keywords): - return stage - return fallback - - -# ── SSH 隧道 + DB 会话 ──────────────────────────────────────── - - -class SshTunnel: - """用 ssh -L 子进程建立隧道,兼容所有 paramiko 版本""" - - def __init__(self, local_port: int = 15432): - self.local_port = local_port - self._proc: Optional[subprocess.Popen] = None - - def start(self): - key_path = os.path.normpath(SSH_KEY_PATH) - cmd = [ - "ssh", "-N", "-L", - f"{self.local_port}:{REMOTE_PG_HOST}:{REMOTE_PG_PORT}", - "-i", key_path, - "-p", str(SSH_PORT), - "-o", "StrictHostKeyChecking=no", - "-o", "ExitOnForwardFailure=yes", - "-o", "BatchMode=yes", - f"{SSH_USER}@{SSH_HOST}", - ] - logger.info(f"SSH 隧道: {SSH_USER}@{SSH_HOST}:{SSH_PORT} -> 127.0.0.1:{self.local_port}, key={key_path}") - self._proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE) - # 等待隧道端口可连接(最多 15 秒) - for attempt in range(30): - if self._proc.poll() is not None: - err = self._proc.stderr.read().decode() if self._proc.stderr else "" - raise RuntimeError(f"SSH 隧道进程已退出: {err}") - try: - sock = socket.create_connection(("127.0.0.1", self.local_port), timeout=1) - sock.close() - logger.info(f"SSH 隧道已建立, 本地端口: {self.local_port} (耗时 {attempt * 0.5:.1f}s)") - return - except (ConnectionRefusedError, OSError): - time.sleep(0.5) - # 超时 - err = "" - if self._proc.poll() is not None and self._proc.stderr: - err = self._proc.stderr.read().decode() - raise RuntimeError(f"SSH 隧道端口 {self.local_port} 超时未就绪: {err}") - - def stop(self): - if self._proc and self._proc.poll() is None: - self._proc.send_signal(signal.SIGTERM) - self._proc.wait(timeout=5) - logger.info("SSH 隧道已关闭") - - @property - def local_bind_port(self) -> int: - return self.local_port - - -def open_ssh_tunnel() -> SshTunnel: - tunnel = SshTunnel() - tunnel.start() - return tunnel - - -def make_session(tunnel: SshTunnel) -> Session: - url = ( - f"postgresql://{PG_USER}:{PG_PASSWORD}" - f"@127.0.0.1:{tunnel.local_bind_port}/{PG_DATABASE}" - ) - engine = create_engine(url, pool_size=2, max_overflow=2) - return sessionmaker(bind=engine)() - - -# ── 数据结构:保存生成结果 ──────────────────────────────────── - - -@dataclass -class GeneratedChapter: - category: str - title: str - content: str - order_index: int - source_segment_ids: List[str] = field(default_factory=list) - - -@dataclass -class PreviewResult: - user_id: str - phone: str - nickname: str - generated_at: str - old_chapters: List[dict] = field(default_factory=list) # {category, title, content_len, content_preview} - new_chapters: List[dict] = field(default_factory=list) # same shape + full content - - -# ── 核心:本地生成章节 ──────────────────────────────────────── - - -def extract_slots_with_llm(llm, text: str, current_stage: str, stage_slots: dict): - try: - prompt = get_state_extraction_prompt( - user_message=text, - current_stage=current_stage, - stage_slots=stage_slots, - ) - json_llm = llm.bind( - model_kwargs={"response_format": {"type": "json_object"}}, - max_tokens=1024, - ) - response = json_llm.invoke(prompt) - parsed = json.loads(extract_json_payload(response.content.strip())) - return parsed.get("detected_stage", current_stage), parsed.get("slots", {}) or {} - except Exception as e: - logger.warning(f"LLM slot 提取失败: {e}") - return current_stage, {} - - -def generate_chapters_in_memory( - segments: list, # list of (id, transcript_text) - llm, - batch_size: int, - skip_llm_slots: bool, -) -> List[GeneratedChapter]: - """纯内存生成章节,不写任何 DB""" - state = default_state() - - # 1. 阶段检测 & slot 提取(内存 state) - stage_to_segments: Dict[str, list] = {} - - for idx, (seg_id, text) in enumerate(segments, 1): - if not text or not text.strip(): - continue - - detected_stage = _detect_stage(text, state.current_stage) - - if not skip_llm_slots: - try: - detected_stage, extracted_slots = extract_slots_with_llm( - llm, text, state.current_stage, state.slots.get(detected_stage, {}) - ) - # 内存更新 state slots - for slot_name, snippet in extracted_slots.items(): - stage_slots = state.slots.get(detected_stage, {}) - stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=[seg_id]) - state.slots[detected_stage] = stage_slots - state.current_stage = detected_stage - except Exception as e: - logger.warning(f"段落 {idx} slot 提取失败: {e}") - - stage_to_segments.setdefault(detected_stage, []).append((seg_id, text)) - - if idx % 20 == 0: - logger.info(f"阶段检测进度: {idx}/{len(segments)}") - - for stage, segs in stage_to_segments.items(): - logger.info(f"阶段 [{stage}]: {len(segs)} 条段落") - - # 2. 按阶段分批生成 - results: List[GeneratedChapter] = [] - - for stage, seg_list in stage_to_segments.items(): - title = f"{stage} 回忆" - existing_content = "" - all_source_ids: List[str] = [] - - slot_snippets = { - key: value.snippet - for key, value in (state.slots.get(stage, {}) or {}).items() - if value.snippet - } - - for i in range(0, len(seg_list), batch_size): - batch = seg_list[i : i + batch_size] - batch_num = i // batch_size + 1 - total_batches = (len(seg_list) + batch_size - 1) // batch_size - logger.info(f"[{stage}] 处理第 {batch_num}/{total_batches} 批 ({len(batch)} 条)") - - combined_text = "\n\n".join(text for _, text in batch) - source_ids = [sid for sid, _ in batch] - all_source_ids.extend(source_ids) - narrative = combined_text # fallback - - try: - if not existing_content: - # 第一批 → 生成标题 - title_prompt = get_creative_title_prompt( - stage=stage, emotion="neutral", slots=slot_snippets - ) - title_response = llm.invoke(title_prompt) - title = title_response.content.strip().strip('"') - logger.info(f"[{stage}] 生成标题: {title}") - - narrative_prompt = get_narrative_prompt( - stage=stage, - slots=slot_snippets, - new_content=combined_text, - existing_content=existing_content, - ) - narrative_response = llm.invoke(narrative_prompt) - new_narrative = narrative_response.content.strip() - - if existing_content: - narrative = f"{existing_content}\n\n{new_narrative}" - else: - narrative = new_narrative - except Exception as e: - logger.warning(f"[{stage}] LLM 生成失败: {e}") - if existing_content: - narrative = f"{existing_content}\n\n{combined_text}" - - # 安全检查 - if existing_content and len(narrative) < len(existing_content) * 0.8: - logger.warning(f"[{stage}] 内容长度异常, 回退追加模式") - narrative = f"{existing_content}\n\n{combined_text}" - - existing_content = narrative - - logger.info(f"[{stage}] 批次 {batch_num} 完成, 累计长度: {len(existing_content)} 字") - if i + batch_size < len(seg_list): - time.sleep(1) - - # 入库前:占位符位置用正则匹配后拼上固定模板 - content_to_save = inject_image_placeholder_template(existing_content) - results.append(GeneratedChapter( - category=stage, - title=title, - content=content_to_save, - order_index=STAGE_TO_ORDER.get(stage, 999), - source_segment_ids=all_source_ids, - )) - - return results - - -# ── preview 命令 ────────────────────────────────────────────── - - -def cmd_preview(phone: str, batch_size: int, skip_llm_slots: bool): - # LLM - llm = getattr(get_llm_provider(), "langchain_llm", None) - if not llm: - logger.error("LLM 未配置,请检查 .env 中的 DEEPSEEK_API_KEY") - sys.exit(1) - logger.info("LLM 就绪") - - tunnel = open_ssh_tunnel() - try: - db = make_session(tunnel) - try: - # 找用户 - user = db.execute(select(User).where(User.phone == phone)).scalar_one_or_none() - if not user: - logger.error(f"未找到手机号 {phone} 的用户") - sys.exit(1) - user_id = user.id - nickname = user.nickname - logger.info(f"用户: {nickname} (id={user_id})") - - # 读取现有 active 章节(含 sections,正文从 sections 拼接) - from sqlalchemy.orm import joinedload - old_chapters = ( - db.execute( - select(Chapter) - .where(Chapter.user_id == user_id, Chapter.is_active == True) - .options(joinedload(Chapter.sections)) - .order_by(Chapter.order_index) - ) - .unique() - .scalars() - .all() - ) - old_chapter_data = [] - for ch in old_chapters: - content = "" - if getattr(ch, "sections", None): - content = "\n\n".join( - (s.content or "").strip() - for s in sorted(ch.sections, key=lambda x: x.order_index) - if (s.content or "").strip() - ) - content_len = len(content) - content_preview = (content[:200] + "…") if content_len > 200 else content - old_chapter_data.append({ - "category": ch.category, - "title": ch.title, - "content_len": content_len, - "content_preview": content_preview, - }) - logger.info(f"现有章节: {len(old_chapters)} 个") - - # 读取所有段落 - segments_raw = ( - db.execute( - select(Segment.id, Segment.transcript_text) - .join(Conversation, Segment.conversation_id == Conversation.id) - .where(Conversation.user_id == user_id) - .order_by(Segment.created_at.asc()) - ) - .all() - ) - logger.info(f"历史段落: {len(segments_raw)} 条") - if not segments_raw: - logger.warning("没有对话段落,无需处理") - return - - finally: - db.close() - finally: - tunnel.stop() - - # 在本地生成新章节(不需要 DB) - seg_tuples = [(row[0], row[1]) for row in segments_raw] - new_chapters = generate_chapters_in_memory(seg_tuples, llm, batch_size, skip_llm_slots) - - # 构建对比结果 - new_chapter_data = [] - for ch in new_chapters: - new_chapter_data.append({ - "category": ch.category, - "title": ch.title, - "content_len": len(ch.content), - "content_preview": (ch.content[:200] + "…") if len(ch.content) > 200 else ch.content, - "content": ch.content, - "order_index": ch.order_index, - "source_segment_ids": ch.source_segment_ids, - }) - - result = PreviewResult( - user_id=user_id, - phone=phone, - nickname=nickname, - generated_at=datetime.now(timezone.utc).isoformat(), - old_chapters=old_chapter_data, - new_chapters=new_chapter_data, - ) - - # 保存 JSON - os.makedirs(OUTPUT_DIR, exist_ok=True) - json_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.json") - with open(json_path, "w", encoding="utf-8") as f: - json.dump(asdict(result), f, ensure_ascii=False, indent=2) - logger.info(f"预览结果已保存: {json_path}") - - # 输出 Markdown - md_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.md") - md_lines = _build_comparison_markdown(result) - with open(md_path, "w", encoding="utf-8") as f: - f.write(md_lines) - logger.info(f"对比 Markdown 已保存: {md_path}") - - # 同时打印到终端 - print("\n" + md_lines) - - -def _build_comparison_markdown(result: PreviewResult) -> str: - lines = [] - lines.append(f"# 回忆录重整对比 — {result.nickname} ({result.phone})") - lines.append(f"\n生成时间: {result.generated_at}\n") - - # 总览表格 - lines.append("## 总览\n") - lines.append("| 阶段 | 旧标题 | 旧字数 | 新标题 | 新字数 | 变化 |") - lines.append("|------|--------|--------|--------|--------|------|") - - old_map = {ch["category"]: ch for ch in result.old_chapters} - new_map = {ch["category"]: ch for ch in result.new_chapters} - all_stages = list(dict.fromkeys( - [ch["category"] for ch in result.old_chapters] - + [ch["category"] for ch in result.new_chapters] - )) - - total_old = 0 - total_new = 0 - for stage in all_stages: - old = old_map.get(stage) - new = new_map.get(stage) - old_title = old["title"] if old else "—" - old_len = old["content_len"] if old else 0 - new_title = new["title"] if new else "—" - new_len = new["content_len"] if new else 0 - total_old += old_len - total_new += new_len - diff = new_len - old_len - diff_str = f"+{diff}" if diff >= 0 else str(diff) - lines.append(f"| {stage} | {old_title} | {old_len} | {new_title} | {new_len} | {diff_str} |") - - diff_total = total_new - total_old - diff_total_str = f"+{diff_total}" if diff_total >= 0 else str(diff_total) - lines.append(f"| **合计** | | **{total_old}** | | **{total_new}** | **{diff_total_str}** |") - - # 各章节详细对比 - lines.append("\n---\n") - lines.append("## 各章节详细对比\n") - - for stage in all_stages: - old = old_map.get(stage) - new = new_map.get(stage) - lines.append(f"### {stage}\n") - - lines.append("**旧内容预览:**\n") - if old: - lines.append(f"> {old['content_preview']}\n") - else: - lines.append("> (无)\n") - - lines.append("**新内容预览:**\n") - if new: - lines.append(f"> {new['content_preview']}\n") - else: - lines.append("> (无)\n") - - # 新章节完整内容 - lines.append("\n---\n") - lines.append("## 新章节完整内容\n") - for ch in result.new_chapters: - lines.append(f"### {ch['title']} ({ch['category']}, {ch['content_len']} 字)\n") - lines.append(ch["content"]) - lines.append("\n") - - return "\n".join(lines) - - -# ── apply 命令 ──────────────────────────────────────────────── - - -def cmd_apply(phone: str): - json_path = os.path.join(OUTPUT_DIR, f"preview_{phone}.json") - if not os.path.exists(json_path): - logger.error(f"未找到预览文件: {json_path}") - logger.error("请先运行 preview 命令") - sys.exit(1) - - with open(json_path, "r", encoding="utf-8") as f: - data = json.load(f) - - user_id = data["user_id"] - new_chapters = data["new_chapters"] - logger.info(f"将写入 {len(new_chapters)} 个新章节到用户 {data['nickname']} ({user_id})") - - # 确认 - answer = input("\n确认写入远程数据库? (yes/no): ").strip().lower() - if answer != "yes": - logger.info("已取消") - return - - tunnel = open_ssh_tunnel() - try: - db = make_session(tunnel) - try: - # 1. 旧章节 → inactive - old_active = ( - db.execute( - select(Chapter).where( - Chapter.user_id == user_id, Chapter.is_active == True - ) - ) - .scalars() - .all() - ) - for ch in old_active: - ch.is_active = False - logger.info(f"已将 {len(old_active)} 个旧章节标记为 inactive") - - # 2. 删除旧 MemoirState - old_state = db.execute( - select(MemoirState).where(MemoirState.user_id == user_id) - ).scalar_one_or_none() - if old_state: - db.delete(old_state) - logger.info("已删除旧 MemoirState") - - # 3. 创建新 MemoirState - ds = default_state() - db.add(MemoirState( - id=str(uuid.uuid4()), - user_id=user_id, - stage_order=ds.stage_order, - current_stage=ds.current_stage, - covered_stages=ds.covered_stages, - slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in ds.slots.items()}, - )) - - # 4. 插入新章节(无 content/images;正文与配图写入 chapter_sections) - last_chapter_id = None - for ch_data in new_chapters: - ch_id = str(uuid.uuid4()) - chapter = Chapter( - id=ch_id, - user_id=user_id, - title=ch_data["title"], - order_index=ch_data["order_index"], - status="completed", - category=ch_data["category"], - cover_image=None, - is_new=True, - source_segments=ch_data.get("source_segment_ids", []), - ) - db.add(chapter) - db.flush() - content = ch_data.get("content") or "" - sections = split_narrative_to_sections(content) - if not sections: - db.add(ChapterSection( - id=str(uuid.uuid4()).replace("-", "")[:32], - chapter_id=ch_id, - order_index=0, - content=content.strip(), - image=None, - )) - else: - for order_idx, seg in enumerate(sections): - db.add(ChapterSection( - id=str(uuid.uuid4()).replace("-", "")[:32], - chapter_id=ch_id, - order_index=order_idx, - content=(seg.get("content") or "").strip(), - image=None, - )) - last_chapter_id = ch_id - logger.info(f" 新建章节: [{ch_data['category']}] {ch_data['title']} — {ch_data['content_len']} 字") - - # 5. 更新 Book - book = db.execute( - select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()) - ).scalar_one_or_none() - if not book: - book = Book( - id=str(uuid.uuid4()), - user_id=user_id, - title="我的回忆录", - total_pages=0, - total_words=0, - cover_image_url=None, - ) - db.add(book) - book.has_update = True - if last_chapter_id: - book.last_update_chapter_id = last_chapter_id - - # 6. 标记所有段落为已处理 - segs = ( - db.execute( - select(Segment) - .join(Conversation, Segment.conversation_id == Conversation.id) - .where(Conversation.user_id == user_id) - ) - .scalars() - .all() - ) - for seg in segs: - seg.processed = True - - db.commit() - logger.info("远程数据库写入完成!") - - finally: - db.close() - finally: - tunnel.stop() - - -# ── CLI 入口 ────────────────────────────────────────────────── - - -def main(): - parser = argparse.ArgumentParser(description="重新整理用户历史对话为回忆录章节(远程预览+写入)") - sub = parser.add_subparsers(dest="command", required=True) - - # preview - p_preview = sub.add_parser("preview", help="预览:读取远程 DB,本地生成新章节,输出对比") - p_preview.add_argument("--phone", required=True, help="用户手机号") - p_preview.add_argument("--batch-size", type=int, default=5, help="每批段落数(默认 5)") - p_preview.add_argument("--skip-llm-slots", action="store_true", help="跳过 LLM slot 提取") - - # apply - p_apply = sub.add_parser("apply", help="写入:将 preview 结果写入远程 DB") - p_apply.add_argument("--phone", required=True, help="用户手机号(需与 preview 一致)") - - args = parser.parse_args() - - if args.command == "preview": - cmd_preview(phone=args.phone, batch_size=args.batch_size, skip_llm_slots=args.skip_llm_slots) - elif args.command == "apply": - cmd_apply(phone=args.phone) - - -if __name__ == "__main__": - main() diff --git a/api/scripts/run_chapter_sections_migration.py b/api/scripts/run_chapter_sections_migration.py deleted file mode 100644 index 0973d17..0000000 --- a/api/scripts/run_chapter_sections_migration.py +++ /dev/null @@ -1,155 +0,0 @@ -""" -一键执行 chapter_sections 迁移:先执行 SQL 建表/加列,再回填数据并删列。 - -依赖:.env 中 DATABASE_URL,以及 python-dotenv。 -用法(在 api 目录下): - python -m scripts.run_chapter_sections_migration -""" -import json -import os -import sys -import uuid -from pathlib import Path - -# 配置由 app.core.config 加载;DB 使用 psycopg 同步驱动 -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -os.chdir(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from sqlalchemy import create_engine, text -from sqlalchemy.engine import Engine -from app.core.config import settings -from app.core.db import ensure_psycopg_url -from app.core.logging import get_logger, setup_logging - -setup_logging() -logger = get_logger(__name__) - - -def get_engine() -> Engine: - url = (settings.migration_database_url or "").strip() or settings.database_url - return create_engine(ensure_psycopg_url(url), pool_pre_ping=True) - - -def run_sql_migration(engine: Engine): - sql_path = Path(__file__).parent.parent / "migrations" / "add_chapter_sections.sql" - sql = sql_path.read_text(encoding="utf-8") - # 按 DO $$ ... $$; 与普通 ; 拆分,避免把 PL/pgSQL 块拆碎 - stmts = [] - rest = sql - while rest: - rest = rest.lstrip() - if rest.startswith("--"): - rest = rest[rest.find("\n") + 1:] if "\n" in rest else "" - continue - if rest.upper().startswith("DO "): - # 找到 $$; 或 $$ ; - i = rest.find("$$") - if i == -1: - break - j = rest.find("$$", i + 2) - if j == -1: - break - stmts.append(rest[: j + 2].strip() + ";") - rest = rest[j + 2:].lstrip().lstrip(";").lstrip() - continue - idx = rest.find(";") - if idx == -1: - break - part = rest[: idx].strip() - rest = rest[idx + 1:] - if part and not part.startswith("--"): - stmts.append(part + ";") - with engine.begin() as conn: - for i, s in enumerate(stmts): - try: - conn.execute(text(s)) - logger.info(" SQL %s OK", i + 1) - except Exception as e: - if "already exists" in str(e).lower(): - logger.info(" SQL %s (已存在)", i + 1) - continue - raise - logger.info("1/2 SQL 迁移完成") - - -def run_data_migration(engine: Engine): - from app.features.memoir.memoir_images.parser import split_narrative_to_sections - - with engine.connect() as conn: - r = conn.execute(text(""" - SELECT column_name FROM information_schema.columns - WHERE table_schema = 'public' AND table_name = 'chapters' AND column_name = 'content' - """)) - if r.fetchone() is None: - logger.info("chapters.content 已不存在,跳过数据迁移") - return - - rows = conn.execute(text(""" - SELECT id, content, images FROM chapters WHERE content IS NOT NULL AND trim(content) != '' - """)).fetchall() - - for row in rows: - ch_id, content, images_raw = row[0], row[1], row[2] - if isinstance(images_raw, str): - try: - images = json.loads(images_raw) - except Exception: - images = [] - else: - images = images_raw if isinstance(images_raw, list) else [] - - sections = split_narrative_to_sections(content or "") - if not sections: - section_id = str(uuid.uuid4()).replace("-", "")[:32] - conn.execute(text(""" - INSERT INTO chapter_sections (id, chapter_id, order_index, content, image, updated_at) - VALUES (:id, :ch_id, 0, :content, NULL, NOW()) - """), {"id": section_id, "ch_id": ch_id, "content": (content or "").strip()}) - conn.commit() - logger.info("章节 %s: 1 条 section(无图)", ch_id) - continue - - first_cover = None - img_index = 0 - for order_idx, seg in enumerate(sections): - section_id = str(uuid.uuid4()).replace("-", "")[:32] - seg_content = seg.get("content") or "" - ph = seg.get("placeholder_info") - image_json = None - if ph is not None and img_index < len(images): - image_json = json.dumps(images[img_index]) if isinstance(images[img_index], dict) else None - if first_cover is None and image_json: - first_cover = image_json - img_index += 1 - - conn.execute(text(""" - INSERT INTO chapter_sections (id, chapter_id, order_index, content, image, updated_at) - VALUES (:id, :ch_id, :ord, :content, CAST(:img AS jsonb), NOW()) - """), { - "id": section_id, - "ch_id": ch_id, - "ord": order_idx, - "content": seg_content, - "img": image_json, - }) - if first_cover: - conn.execute( - text("UPDATE chapters SET cover_image = CAST(:img AS jsonb) WHERE id = :id"), - {"img": first_cover, "id": ch_id}, - ) - conn.commit() - logger.info("章节 %s: %d 条 sections", ch_id, len(sections)) - - conn.execute(text("ALTER TABLE chapters DROP COLUMN IF EXISTS content")) - conn.execute(text("ALTER TABLE chapters DROP COLUMN IF EXISTS images")) - conn.commit() - logger.info("已删除 chapters.content 与 chapters.images") - logger.info("2/2 数据迁移完成") - - -if __name__ == "__main__": - logger.info("开始 chapter_sections 迁移…") - engine = get_engine() - run_sql_migration(engine) - run_data_migration(engine) - logger.info("迁移全部完成") diff --git a/api/scripts/run_memoir_images_migration.py b/api/scripts/run_memoir_images_migration.py deleted file mode 100644 index cef869b..0000000 --- a/api/scripts/run_memoir_images_migration.py +++ /dev/null @@ -1,196 +0,0 @@ -""" -将 chapters.cover_image 与 chapter_sections.image 的 JSON 数据迁移到 memoir_images 表(字段独立列)。 - -前置:先执行 api/migrations/add_memoir_images_table.sql 建表。 -用法(在项目根目录或 api 目录下): - python -m api.scripts.run_memoir_images_migration -""" -import json -import os -import sys -import uuid - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -os.chdir(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from sqlalchemy import create_engine, text -from sqlalchemy.engine import Engine -from urllib.parse import urlsplit - -from app.core.config import settings -from app.core.db import ensure_psycopg_url -from app.core.logging import get_logger, setup_logging - -setup_logging() -logger = get_logger(__name__) - - -def get_engine() -> Engine: - url = (settings.migration_database_url or "").strip() or settings.database_url - return create_engine(ensure_psycopg_url(url), pool_pre_ping=True) - - -def _row_from_image_json(img: dict | None, chapter_id: str, section_id: str | None, order_index: int) -> dict | None: - if not img or not isinstance(img, dict): - return None - placeholder = (img.get("placeholder") or "").strip() - description = (img.get("description") or "").strip() - if not placeholder and not description: - return None - if not placeholder: - placeholder = f'{{{{IMAGE:{description}}}}}' - created = img.get("created_at") - updated = img.get("updated_at") - if isinstance(created, str) and created: - try: - from datetime import datetime - created = datetime.fromisoformat(created.replace("Z", "+00:00")) - except Exception: - created = None - if isinstance(updated, str) and updated: - try: - from datetime import datetime - updated = datetime.fromisoformat(updated.replace("Z", "+00:00")) - except Exception: - updated = None - return { - "id": str(uuid.uuid4()).replace("-", "")[:32], - "chapter_id": chapter_id, - "section_id": section_id, - "order_index": order_index, - "placeholder": placeholder or None, - "description": description or None, - "status": (img.get("status") or "pending").strip() or "pending", - "prompt": img.get("prompt") or None, - "url": img.get("url") or None, - "storage_key": img.get("storage_key") or None, - "provider": img.get("provider") or None, - "style": img.get("style") or None, - "size": img.get("size") or None, - "error": img.get("error") or None, - "retryable": img.get("retryable") if img.get("retryable") is not None else None, - "created_at": created, - "updated_at": updated, - } - - -def run_sql_migration(engine: Engine): - from pathlib import Path - sql_path = Path(__file__).parent.parent / "migrations" / "add_memoir_images_table.sql" - if not sql_path.exists(): - logger.warning("未找到 %s,请先执行该 SQL 建表", sql_path) - return - sql = sql_path.read_text(encoding="utf-8") - stmts = [] - rest = sql - while rest: - rest = rest.lstrip() - if rest.startswith("--"): - rest = rest[rest.find("\n") + 1:] if "\n" in rest else "" - continue - if rest.upper().startswith("DO "): - i = rest.find("$$") - if i == -1: - break - j = rest.find("$$", i + 2) - if j == -1: - break - stmts.append(rest[: j + 2].strip() + ";") - rest = rest[j + 2:].lstrip().lstrip(";").lstrip() - continue - idx = rest.find(";") - if idx == -1: - break - part = rest[: idx].strip() - rest = rest[idx + 1:] - if part and not part.startswith("--"): - stmts.append(part + ";") - with engine.begin() as conn: - for i, s in enumerate(stmts): - try: - conn.execute(text(s)) - logger.info(" SQL %s OK", i + 1) - except Exception as e: - if "already exists" in str(e).lower(): - logger.info(" SQL %s (已存在)", i + 1) - continue - raise - logger.info("1/2 SQL 迁移完成") - - -def run_data_migration(engine: Engine): - ins = text(""" - INSERT INTO memoir_images ( - id, chapter_id, section_id, order_index, - placeholder, description, status, prompt, url, storage_key, - provider, style, size, error, retryable, created_at, updated_at - ) VALUES ( - :id, :chapter_id, :section_id, :order_index, - :placeholder, :description, :status, :prompt, :url, :storage_key, - :provider, :style, :size, :error, :retryable, :created_at, :updated_at - ) - """) - with engine.connect() as conn: - r = conn.execute(text(""" - SELECT id, cover_image FROM chapters WHERE cover_image IS NOT NULL - """)) - cover_count = 0 - for row in r: - ch_id, cover = row[0], row[1] - if isinstance(cover, str): - try: - cover = json.loads(cover) - except Exception: - cover = None - if not cover or not isinstance(cover, dict): - continue - exists = conn.execute( - text("SELECT 1 FROM memoir_images WHERE chapter_id = :ch_id AND section_id IS NULL"), - {"ch_id": ch_id}, - ).fetchone() - if exists: - continue - row_data = _row_from_image_json(cover, ch_id, None, 0) - if not row_data: - continue - conn.execute(ins, {**row_data, "updated_at": row_data.get("updated_at")}) - conn.commit() - cover_count += 1 - logger.info("封面图迁移: %d 条", cover_count) - - with engine.connect() as conn: - r = conn.execute(text(""" - SELECT id, chapter_id, order_index, image FROM chapter_sections WHERE image IS NOT NULL - """)) - sec_count = 0 - for row in r: - sec_id, ch_id, ord_idx, img = row[0], row[1], row[2], row[3] - if isinstance(img, str): - try: - img = json.loads(img) - except Exception: - img = None - if not img or not isinstance(img, dict): - continue - exists = conn.execute( - text("SELECT 1 FROM memoir_images WHERE section_id = :sec_id"), - {"sec_id": sec_id}, - ).fetchone() - if exists: - continue - row_data = _row_from_image_json(img, ch_id, sec_id, ord_idx + 1) - if not row_data: - continue - conn.execute(ins, {**row_data, "updated_at": row_data.get("updated_at")}) - conn.commit() - sec_count += 1 - logger.info("段落配图迁移: %d 条", sec_count) - logger.info("2/2 数据迁移完成") - - -if __name__ == "__main__": - logger.info("开始 memoir_images 迁移…") - engine = get_engine() - run_sql_migration(engine) - run_data_migration(engine) - logger.info("迁移全部完成") diff --git a/api/tests/conftest.py b/api/tests/conftest.py index f2b754b..817d428 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -1,4 +1,5 @@ """Pytest 配置:确保 api 目录在 path 中,并预先加载所有 model 以解析 relationship 字符串引用。""" + import sys from pathlib import Path diff --git a/api/tests/test_chapters_router_images.py b/api/tests/test_chapters_router_images.py index eb738e2..a58c00d 100644 --- a/api/tests/test_chapters_router_images.py +++ b/api/tests/test_chapters_router_images.py @@ -66,7 +66,9 @@ class ChaptersRouterImagesTest(unittest.TestCase): ) def test_chapter_to_dict_returns_signed_image_urls_for_response(self, storage_cls): storage = Mock() - storage.get_download_url.return_value = "https://signed.example.com/memoirs/u1/c1/0-demo.png?sig=123" + storage.get_download_url.return_value = ( + "https://signed.example.com/memoirs/u1/c1/0-demo.png?sig=123" + ) storage_cls.from_settings.return_value = storage img0 = _image_stub( @@ -79,7 +81,11 @@ class ChaptersRouterImagesTest(unittest.TestCase): section_id=None, order_index=0, ) - sec = type("SectionStub", (), {"content": "", "order_index": 0, "image_record": img0, "image_id": None})() + sec = type( + "SectionStub", + (), + {"content": "", "order_index": 0, "image_record": img0, "image_id": None}, + )() chapter = _chapter_stub(images=[img0], sections=[sec]) payload = _chapter_to_dict(chapter) @@ -101,7 +107,9 @@ class ChaptersRouterImagesTest(unittest.TestCase): }, clear=False, ) - def test_chapter_to_dict_preserves_completed_asset_when_signing_fails(self, storage_cls): + def test_chapter_to_dict_preserves_completed_asset_when_signing_fails( + self, storage_cls + ): storage = Mock() storage.get_download_url.side_effect = CosDownloadUrlError( "cos unavailable", retryable=True, request_id="req-err" @@ -118,7 +126,11 @@ class ChaptersRouterImagesTest(unittest.TestCase): section_id=None, order_index=0, ) - sec = type("SectionStub", (), {"content": "", "order_index": 0, "image_record": img0, "image_id": None})() + sec = type( + "SectionStub", + (), + {"content": "", "order_index": 0, "image_record": img0, "image_id": None}, + )() chapter = _chapter_stub(images=[img0], sections=[sec]) payload = _chapter_to_dict(chapter) @@ -133,7 +145,9 @@ class ChaptersRouterImagesTest(unittest.TestCase): def test_chapter_to_dict_drops_malformed_image_assets(self, storage_cls): storage_cls.from_settings.return_value = Mock() # 无 sections 时 content/images 来自 _sections_to_content_and_images 得到 [];无有效封面(images 的 section_id 非空) - img = _image_stub(status="completed", placeholder="", description="", section_id="sec1") + img = _image_stub( + status="completed", placeholder="", description="", section_id="sec1" + ) chapter = _chapter_stub(images=[img], sections=[]) payload = _chapter_to_dict(chapter) @@ -142,9 +156,13 @@ class ChaptersRouterImagesTest(unittest.TestCase): @patch("app.features.memoir.router.MemoirImageSettings") @patch("app.features.memoir.router.TencentCosStorageService") - def test_chapter_to_dict_hides_non_completed_assets_when_feature_disabled(self, storage_cls, memoir_img_settings_cls): + def test_chapter_to_dict_hides_non_completed_assets_when_feature_disabled( + self, storage_cls, memoir_img_settings_cls + ): storage = Mock() - storage.get_download_url.return_value = "https://signed.example.com/0.png?sig=123" + storage.get_download_url.return_value = ( + "https://signed.example.com/0.png?sig=123" + ) storage_cls.from_settings.return_value = storage memoir_img_settings_cls.from_settings.return_value = Mock(enabled=False) @@ -158,7 +176,16 @@ class ChaptersRouterImagesTest(unittest.TestCase): order_index=0, section_id="s1", ) - sec = type("SectionStub", (), {"content": "", "order_index": 0, "image_record": img_completed, "image_id": None})() + sec = type( + "SectionStub", + (), + { + "content": "", + "order_index": 0, + "image_record": img_completed, + "image_id": None, + }, + )() chapter = _chapter_stub(images=[img_completed], sections=[sec]) payload = _chapter_to_dict(chapter) @@ -168,7 +195,9 @@ class ChaptersRouterImagesTest(unittest.TestCase): @patch("app.features.memoir.router.TencentCosStorageService") @patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False) - def test_chapter_to_dict_preserves_retryable_flag_for_failed_assets(self, storage_cls): + def test_chapter_to_dict_preserves_retryable_flag_for_failed_assets( + self, storage_cls + ): storage_cls.from_settings.return_value = Mock() img = _image_stub( @@ -180,7 +209,11 @@ class ChaptersRouterImagesTest(unittest.TestCase): retryable=False, order_index=0, ) - sec = type("SectionStub", (), {"content": "", "order_index": 0, "image_record": img, "image_id": None})() + sec = type( + "SectionStub", + (), + {"content": "", "order_index": 0, "image_record": img, "image_id": None}, + )() chapter = _chapter_stub(images=[img], sections=[sec]) payload = _chapter_to_dict(chapter) diff --git a/api/tests/test_conversation.py b/api/tests/test_conversation.py index 1ef1a09..14c748a 100644 --- a/api/tests/test_conversation.py +++ b/api/tests/test_conversation.py @@ -3,6 +3,7 @@ 多轮对话测试脚本 测试对话引导 Agent 和回忆录整理功能 """ + import asyncio import json import uuid @@ -26,7 +27,6 @@ CONVERSATION_MESSAGES = [ "奶奶家有个小院子,夏天的时候我们经常坐在院子里乘凉,她给我讲故事。", "那段时光真的很美好,我记得奶奶总是给我做红烧肉,那是我最爱吃的菜。", "小时候最开心的事就是过年,可以放鞭炮,还能收到压岁钱。", - # 教育阶段 "后来我去城里上学了,那是我第一次离开家,心里特别害怕。", "初中的时候遇到了一个很好的语文老师,她鼓励我多读书,对我影响很大。", @@ -36,12 +36,12 @@ CONVERSATION_MESSAGES = [ class ConversationTester: """对话测试器""" - + def __init__(self): self.token = None self.user_id = None self.conversation_id = str(uuid.uuid4()) - + async def register_or_login(self): """注册或登录用户""" async with httpx.AsyncClient(timeout=30.0) as client: @@ -52,10 +52,10 @@ class ConversationTester: json={ "phone": TEST_PHONE, "password": TEST_PASSWORD, - "nickname": TEST_NICKNAME - } + "nickname": TEST_NICKNAME, + }, ) - + if resp.status_code == 201: data = resp.json() self.token = data["access_token"] @@ -65,10 +65,7 @@ class ConversationTester: print(f"ℹ️ 用户已存在,尝试登录...") resp = await client.post( f"{BASE_URL}/api/auth/login", - json={ - "phone": TEST_PHONE, - "password": TEST_PASSWORD - } + json={"phone": TEST_PHONE, "password": TEST_PASSWORD}, ) if resp.status_code == 200: data = resp.json() @@ -78,98 +75,97 @@ class ConversationTester: raise Exception(f"登录失败: {resp.text}") else: raise Exception(f"注册失败: {resp.text}") - + print(f"🔑 Token: {self.token[:30]}...") - + async def get_memoir_state(self): """获取回忆录状态""" async with httpx.AsyncClient(timeout=60.0) as client: resp = await client.get( f"{BASE_URL}/api/memoir-state", - headers={"Authorization": f"Bearer {self.token}"} + headers={"Authorization": f"Bearer {self.token}"}, ) if resp.status_code != 200: print(f" ⚠️ 状态API返回 {resp.status_code}: {resp.text[:200]}") return {"current_stage": "unknown", "covered_stages": [], "slots": {}} return resp.json() - + async def get_chapters(self): """获取章节列表""" async with httpx.AsyncClient(timeout=60.0) as client: resp = await client.get( f"{BASE_URL}/api/chapters", - headers={"Authorization": f"Bearer {self.token}"} + headers={"Authorization": f"Bearer {self.token}"}, ) if resp.status_code != 200: print(f" ⚠️ 章节API返回 {resp.status_code}: {resp.text[:200]}") return [] return resp.json() - + async def get_book(self): """获取回忆录信息""" async with httpx.AsyncClient(timeout=60.0) as client: resp = await client.get( f"{BASE_URL}/api/books/current", - headers={"Authorization": f"Bearer {self.token}"} + headers={"Authorization": f"Bearer {self.token}"}, ) if resp.status_code != 200: print(f" ⚠️ 回忆录API返回 {resp.status_code}: {resp.text[:200]}") return {"message": "获取失败"} return resp.json() - + async def get_tasks_status(self): """获取任务状态""" async with httpx.AsyncClient(timeout=60.0) as client: resp = await client.get( f"{BASE_URL}/api/tasks/status", - headers={"Authorization": f"Bearer {self.token}"} + headers={"Authorization": f"Bearer {self.token}"}, ) if resp.status_code != 200: return {"total": 0, "all_completed": True, "tasks": []} return resp.json() - + async def clear_tasks(self): """清除任务记录""" async with httpx.AsyncClient(timeout=60.0) as client: await client.delete( f"{BASE_URL}/api/tasks/clear", - headers={"Authorization": f"Bearer {self.token}"} + headers={"Authorization": f"Bearer {self.token}"}, ) - + async def run_conversation(self): """运行多轮对话""" print(f"\n🔗 连接 WebSocket: {self.conversation_id}") - + ws_url = f"{WS_URL}/ws/conversation/{self.conversation_id}?token={self.token}" - + async with websockets.connect(ws_url) as ws: # 接收连接确认 msg = await ws.recv() data = json.loads(msg) print(f"✅ 连接成功: {data['type']}") - + # 多轮对话 for i, user_message in enumerate(CONVERSATION_MESSAGES, 1): - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print(f"📤 第 {i} 轮对话") - print(f"{'='*60}") + print(f"{'=' * 60}") print(f"👤 用户: {user_message}") - + # 发送消息 - await ws.send(json.dumps({ - "type": "text", - "data": {"text": user_message} - })) - + await ws.send( + json.dumps({"type": "text", "data": {"text": user_message}}) + ) + # 接收 Agent 回复(可能是多条消息) try: while True: msg = await asyncio.wait_for(ws.recv(), timeout=30) data = json.loads(msg) if data["type"] == "agent_response": - msg_data = data['data'] - total = msg_data.get('total', 1) - index = msg_data.get('index', 0) + msg_data = data["data"] + total = msg_data.get("total", 1) + index = msg_data.get("index", 0) print(f"🤖 Agent: {msg_data['text']}") # 如果是最后一条消息,退出循环 if index >= total - 1: @@ -181,55 +177,61 @@ class ConversationTester: break except asyncio.TimeoutError: print("⏰ 等待响应超时") - + # 短暂等待,模拟真实对话节奏 await asyncio.sleep(1) - + # 结束对话 - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("📭 结束对话") - print(f"{'='*60}") - - await ws.send(json.dumps({ - "type": "end_conversation", - "conversation_id": self.conversation_id - })) - + print(f"{'=' * 60}") + + await ws.send( + json.dumps( + { + "type": "end_conversation", + "conversation_id": self.conversation_id, + } + ) + ) + try: # 结束时会触发 process_conversation_segments,可能需要更长时间 msg = await asyncio.wait_for(ws.recv(), timeout=60) data = json.loads(msg) - if data['type'] == 'error': + if data["type"] == "error": print(f"❌ 结束对话错误: {data['data'].get('message', 'unknown')}") else: print(f"✅ 对话结束: {data['type']}") except asyncio.TimeoutError: print("⏰ 等待结束确认超时(但后台处理可能仍在进行)") - - async def wait_for_processing(self, max_wait_seconds: int = 300, check_interval: int = 3): + + async def wait_for_processing( + self, max_wait_seconds: int = 300, check_interval: int = 3 + ): """ 等待后台处理完成 通过查询 Celery 任务状态来判断处理是否完成 - + Args: max_wait_seconds: 最大等待时间(秒),默认 5 分钟 check_interval: 检查间隔(秒) - + Returns: 是否在超时前完成 """ print(f"\n⏳ 等待后台任务完成(最多 {max_wait_seconds} 秒)...") print(" 提示: 通过 Celery 任务状态 API 追踪任务进度") - + start_time = asyncio.get_event_loop().time() - + while True: elapsed = asyncio.get_event_loop().time() - start_time - + if elapsed >= max_wait_seconds: print(f"\n⚠️ 已等待 {max_wait_seconds} 秒,超时退出") return False - + # 检查任务状态 tasks_status = await self.get_tasks_status() total = tasks_status.get("total", 0) @@ -238,72 +240,76 @@ class ConversationTester: success = tasks_status.get("success", 0) failure = tasks_status.get("failure", 0) all_completed = tasks_status.get("all_completed", False) - + # 同时检查章节内容 chapters = await self.get_chapters() chapter_count = len(chapters) - total_content_length = sum(len(ch.get('content', '')) for ch in chapters) - + total_content_length = sum(len(ch.get("content", "")) for ch in chapters) + status_str = f"📊 总:{total} 等待:{pending} 运行:{running} 成功:{success} 失败:{failure}" content_str = f"📚 章节:{chapter_count} 内容:{total_content_length}字符" print(f" [{int(elapsed):3d}s] {status_str} | {content_str}") - + # 判断是否完成: # 1. 有任务且全部完成 # 2. 或者没有任务但有章节内容(兼容旧逻辑) if total > 0 and all_completed: print(f"\n✅ 所有任务已完成!共 {total} 个任务,等待 {int(elapsed)} 秒") return True - + # 如果没有任务记录,等待一会儿任务提交 if total == 0 and elapsed < 15: await asyncio.sleep(check_interval) continue - + # 如果长时间没有任务但有内容,也认为完成 if total == 0 and chapter_count > 0 and elapsed > 30: - print(f"\n✅ 无待处理任务,已有 {chapter_count} 个章节。等待 {int(elapsed)} 秒") + print( + f"\n✅ 无待处理任务,已有 {chapter_count} 个章节。等待 {int(elapsed)} 秒" + ) return True - + await asyncio.sleep(check_interval) - + async def check_results(self): """检查回忆录生成结果""" - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("📊 检查结果") - print(f"{'='*60}") - + print(f"{'=' * 60}") + # 等待后台处理完成(使用智能轮询) await self.wait_for_processing(max_wait_seconds=180, check_interval=5) - + # 获取回忆录状态 print("\n📋 回忆录状态:") state = await self.get_memoir_state() print(f" 当前阶段: {state.get('current_stage', 'N/A')}") print(f" 已完成阶段: {state.get('covered_stages', [])}") - + # 显示已填充的 slots - slots = state.get('slots', {}) + slots = state.get("slots", {}) for stage, stage_slots in slots.items(): - filled = [k for k, v in stage_slots.items() if v.get('snippet')] + filled = [k for k, v in stage_slots.items() if v.get("snippet")] if filled: print(f" {stage} 已填充: {filled}") for slot_name in filled: - snippet = stage_slots[slot_name].get('snippet', '') + snippet = stage_slots[slot_name].get("snippet", "") if snippet: print(f" - {slot_name}: {snippet[:50]}...") - + # 获取章节 print("\n📚 生成的章节:") chapters = await self.get_chapters() if chapters: for ch in chapters: is_new = "🆕" if ch.get("is_new") else "" - content_len = len(ch.get('content', '')) - print(f" {is_new} [{ch.get('category', 'N/A')}] {ch.get('title', 'N/A')} ({content_len} 字符)") + content_len = len(ch.get("content", "")) + print( + f" {is_new} [{ch.get('category', 'N/A')}] {ch.get('title', 'N/A')} ({content_len} 字符)" + ) else: print(" (暂无章节)") - + # 获取回忆录 print("\n📖 回忆录信息:") book = await self.get_book() @@ -313,25 +319,25 @@ class ConversationTester: print(f" 有更新: {'是' if book.get('has_update') else '否'}") else: print(f" {book.get('message', 'N/A')}") - + # 显示回忆录完整内容 if chapters: - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("📜 回忆录完整内容") - print(f"{'='*60}") + print(f"{'=' * 60}") for ch in chapters: - category = ch.get('category', 'N/A') - title = ch.get('title', '未命名章节') - content = ch.get('content', '') - - print(f"\n{'─'*60}") + category = ch.get("category", "N/A") + title = ch.get("title", "未命名章节") + content = ch.get("content", "") + + print(f"\n{'─' * 60}") print(f"【{title}】({category})") - print(f"{'─'*60}") + print(f"{'─' * 60}") if content: print(content) else: print("(暂无内容)") - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") async def main(): @@ -340,33 +346,34 @@ async def main(): print("🎭 Life Echo 多轮对话测试") print(f"⏰ 开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") print("=" * 60) - + tester = ConversationTester() - + try: # 1. 注册/登录 await tester.register_or_login() - + # 2. 清除旧的任务记录 await tester.clear_tasks() print("\n🧹 已清除旧的任务记录") - + # 3. 查看初始状态 print("\n📋 初始回忆录状态:") state = await tester.get_memoir_state() print(f" 当前阶段: {state.get('current_stage', 'N/A')}") - + # 4. 运行多轮对话 await tester.run_conversation() - + # 5. 检查结果 await tester.check_results() - + except Exception as e: print(f"\n❌ 测试失败: {e}") import traceback + traceback.print_exc() - + print("\n" + "=" * 60) print(f"⏰ 结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") print("=" * 60) diff --git a/api/tests/test_conversation_messages_history.py b/api/tests/test_conversation_messages_history.py index 5534ef3..aa9142c 100644 --- a/api/tests/test_conversation_messages_history.py +++ b/api/tests/test_conversation_messages_history.py @@ -44,7 +44,10 @@ class ConversationMessagesHistoryTest(unittest.TestCase): ) self.assertEqual( - [(msg["senderType"], msg["messageType"], msg["content"]) for msg in messages], + [ + (msg["senderType"], msg["messageType"], msg["content"]) + for msg in messages + ], [ ("user", "audio", "第一段"), ("assistant", "text", "继续说"), @@ -101,7 +104,9 @@ class ConversationMessagesHistoryTest(unittest.TestCase): } ] - latest_message_time = conversations_router._latest_message_time_ms(conversation, history) + latest_message_time = conversations_router._latest_message_time_ms( + conversation, history + ) self.assertEqual(latest_message_time, 1773489605000) @@ -112,6 +117,8 @@ class ConversationMessagesHistoryTest(unittest.TestCase): started_at=datetime(2026, 3, 14, 12, 0, 0, tzinfo=timezone.utc), ) - timestamp = conversations_router._message_timestamp_ms({}, conversation.started_at) + timestamp = conversations_router._message_timestamp_ms( + {}, conversation.started_at + ) self.assertEqual(timestamp, 1773489600000) diff --git a/api/tests/test_generate_chapter_images_persistence.py b/api/tests/test_generate_chapter_images_persistence.py index 6af5cd8..3e98163 100644 --- a/api/tests/test_generate_chapter_images_persistence.py +++ b/api/tests/test_generate_chapter_images_persistence.py @@ -77,7 +77,9 @@ class GenerateChapterImagesPersistenceTest(unittest.TestCase): ): chapter = _chapter_stub() db = Mock() - db.execute.return_value.unique.return_value.scalar_one_or_none.return_value = chapter + db.execute.return_value.unique.return_value.scalar_one_or_none.return_value = ( + chapter + ) get_sync_db_mock.return_value.__enter__.return_value = db get_sync_db_mock.return_value.__exit__.return_value = False @@ -102,5 +104,7 @@ class GenerateChapterImagesPersistenceTest(unittest.TestCase): record = chapter.sections[0].image_record self.assertEqual(record.status, "completed") - self.assertEqual(record.url, "https://cos.example.com/memoirs/user-1/chapter-1/0.png") + self.assertEqual( + record.url, "https://cos.example.com/memoirs/user-1/chapter-1/0.png" + ) self.assertEqual(record.prompt, "A serene southern China town") diff --git a/api/tests/test_generate_chapter_images_task.py b/api/tests/test_generate_chapter_images_task.py index e104abd..15a9c17 100644 --- a/api/tests/test_generate_chapter_images_task.py +++ b/api/tests/test_generate_chapter_images_task.py @@ -10,7 +10,11 @@ from app.tasks import memoir_tasks from app.tasks.memoir_tasks import generate_chapter_images -def _mock_image_generator(*, image_url: str = "https://provider.example.com/1.png", image_bytes: bytes | None = None): +def _mock_image_generator( + *, + image_url: str = "https://provider.example.com/1.png", + image_bytes: bytes | None = None, +): """构造满足 port ImageGenerator 的 mock:generate 返回 ImageResult,download_image 返回 bytes。""" if image_bytes is None: buf = BytesIO() @@ -81,7 +85,9 @@ def _chapter_with_sections(sections_data): def _bind_db_execute_to_chapter(db_mock, chapter): """让 db.execute(select(...)).unique().scalar_one_or_none() 返回 chapter。""" - db_mock.execute.return_value.unique.return_value.scalar_one_or_none.return_value = chapter + db_mock.execute.return_value.unique.return_value.scalar_one_or_none.return_value = ( + chapter + ) class GenerateChapterImagesTaskTest(unittest.TestCase): @@ -101,9 +107,20 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): get_sync_db_mock, redis_from_url, ): - chapter = _chapter_with_sections([ - {"content": "那条路我一直记得。", "image": {"index": 0, "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", "description": "南方小镇的青石板路", "status": "pending", "url": None}}, - ]) + chapter = _chapter_with_sections( + [ + { + "content": "那条路我一直记得。", + "image": { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + }, + }, + ] + ) db = Mock() _bind_db_execute_to_chapter(db, chapter) get_sync_db_mock.return_value.__enter__.return_value = db @@ -132,9 +149,20 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): storage_cls, get_sync_db_mock, ): - chapter = _chapter_with_sections([ - {"content": "那条路我一直记得。", "image": {"index": 0, "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", "description": "南方小镇的青石板路", "status": "pending", "url": None}}, - ]) + chapter = _chapter_with_sections( + [ + { + "content": "那条路我一直记得。", + "image": { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + }, + }, + ] + ) db = Mock() _bind_db_execute_to_chapter(db, chapter) get_sync_db_mock.return_value.__enter__.return_value = db @@ -145,17 +173,23 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): "size": "1024x1024", "prompt_context": "childhood: 童年的夏天", } - get_image_generator_mock.return_value.generate.side_effect = RuntimeError("transient provider error") + get_image_generator_mock.return_value.generate.side_effect = RuntimeError( + "transient provider error" + ) retry_error = RuntimeError("retry requested") - task_self = SimpleNamespace(request=SimpleNamespace(id="task-1"), retry=Mock(side_effect=retry_error)) + task_self = SimpleNamespace( + request=SimpleNamespace(id="task-1"), retry=Mock(side_effect=retry_error) + ) with self.assertRaises(RuntimeError) as ctx: generate_chapter_images.run.__func__(task_self, "chapter-1") self.assertIs(ctx.exception, retry_error) self.assertEqual(chapter.sections[0].image_record.status, "failed") - self.assertEqual(chapter.sections[0].image_record.error, "transient provider error") + self.assertEqual( + chapter.sections[0].image_record.error, "transient provider error" + ) task_self.retry.assert_called_once() storage_cls.from_env.return_value.upload_bytes.assert_not_called() @@ -174,9 +208,20 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): storage_cls, get_sync_db_mock, ): - chapter = _chapter_with_sections([ - {"content": "那条路我一直记得。", "image": {"index": 0, "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", "description": "南方小镇的青石板路", "status": "pending", "url": None}}, - ]) + chapter = _chapter_with_sections( + [ + { + "content": "那条路我一直记得。", + "image": { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + }, + }, + ] + ) db = Mock() _bind_db_execute_to_chapter(db, chapter) get_sync_db_mock.return_value.__enter__.return_value = db @@ -189,14 +234,24 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): } get_image_generator_mock.return_value = _mock_image_generator() storage_inst = storage_cls.from_env.return_value - storage_inst.upload_bytes.return_value = "https://cos.example.com/memoirs/u1/c1/0.png" + storage_inst.upload_bytes.return_value = ( + "https://cos.example.com/memoirs/u1/c1/0.png" + ) generate_chapter_images.run("chapter-1") self.assertEqual(chapter.sections[0].image_record.status, "completed") - self.assertEqual(chapter.sections[0].image_record.storage_key, "memoirs/user-1/chapter-1/0-7e1f860790.png") - self.assertEqual(chapter.sections[0].image_record.url, "https://cos.example.com/memoirs/u1/c1/0.png") - self.assertEqual(chapter.sections[0].image_record.prompt, "A serene southern China town") + self.assertEqual( + chapter.sections[0].image_record.storage_key, + "memoirs/user-1/chapter-1/0-7e1f860790.png", + ) + self.assertEqual( + chapter.sections[0].image_record.url, + "https://cos.example.com/memoirs/u1/c1/0.png", + ) + self.assertEqual( + chapter.sections[0].image_record.prompt, "A serene southern China town" + ) get_image_generator_mock.return_value.generate.assert_called_once() db.commit.assert_called() @@ -213,9 +268,20 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): storage_cls, get_sync_db_mock, ): - chapter = _chapter_with_sections([ - {"content": "那条路我一直记得。", "image": {"index": 0, "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", "description": "南方小镇的青石板路", "status": "pending", "url": None}}, - ]) + chapter = _chapter_with_sections( + [ + { + "content": "那条路我一直记得。", + "image": { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + }, + }, + ] + ) settings_from_env.return_value = SimpleNamespace( enabled=False, max_per_chapter=2, @@ -254,9 +320,20 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): storage_cls, get_sync_db_mock, ): - chapter = _chapter_with_sections([ - {"content": "那条路我一直记得。", "image": {"index": 0, "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", "description": "南方小镇的青石板路", "status": "pending", "url": None}}, - ]) + chapter = _chapter_with_sections( + [ + { + "content": "那条路我一直记得。", + "image": { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + }, + }, + ] + ) image_buffer = BytesIO() Image.new("RGB", (2, 1), color="white").save(image_buffer, format="JPEG") jpeg_bytes = image_buffer.getvalue() @@ -276,7 +353,9 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): image_bytes=jpeg_bytes, ) storage_inst = storage_cls.from_env.return_value - storage_inst.upload_bytes.return_value = "https://cos.example.com/memoirs/u1/c1/0.png" + storage_inst.upload_bytes.return_value = ( + "https://cos.example.com/memoirs/u1/c1/0.png" + ) generate_chapter_images.run("chapter-1") @@ -299,9 +378,20 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): storage_cls, get_sync_db_mock, ): - chapter = _chapter_with_sections([ - {"content": "那条路我一直记得。", "image": {"index": 0, "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", "description": "南方小镇的青石板路", "status": "pending", "url": None}}, - ]) + chapter = _chapter_with_sections( + [ + { + "content": "那条路我一直记得。", + "image": { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + }, + }, + ] + ) db = Mock() _bind_db_execute_to_chapter(db, chapter) get_sync_db_mock.return_value.__enter__.return_value = db @@ -342,9 +432,20 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): storage_cls, get_sync_db_mock, ): - chapter = _chapter_with_sections([ - {"content": "那条路我一直记得。", "image": {"index": 0, "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", "description": "南方小镇的青石板路", "status": "completed", "url": "https://cos.example.com/already-there.png"}}, - ]) + chapter = _chapter_with_sections( + [ + { + "content": "那条路我一直记得。", + "image": { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "completed", + "url": "https://cos.example.com/already-there.png", + }, + }, + ] + ) db = Mock() _bind_db_execute_to_chapter(db, chapter) get_sync_db_mock.return_value.__enter__.return_value = db diff --git a/api/tests/test_memoir_image_bootstrap.py b/api/tests/test_memoir_image_bootstrap.py index c84bca9..3d61b76 100644 --- a/api/tests/test_memoir_image_bootstrap.py +++ b/api/tests/test_memoir_image_bootstrap.py @@ -18,7 +18,9 @@ class MemoirImageBootstrapTest(unittest.TestCase): }, )() - with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "false"}, clear=False): + with unittest.mock.patch.dict( + os.environ, {"MEMOIR_IMAGE_ENABLED": "false"}, clear=False + ): assets = initialize_chapter_images(chapter) self.assertEqual(assets, []) @@ -35,41 +37,69 @@ class MemoirImageBootstrapTest(unittest.TestCase): }, )() - with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False): + with unittest.mock.patch.dict( + os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False + ): assets = initialize_chapter_images(chapter) self.assertEqual(assets, []) - def test_initialize_chapter_images_preserves_completed_assets_and_adds_only_new_placeholders(self): + def test_initialize_chapter_images_preserves_completed_assets_and_adds_only_new_placeholders( + self, + ): """图片初始化已迁移到 _save_narrative_to_sections;此处为兼容 no-op""" - chapter = type("ChapterStub", (), {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"})() + chapter = type( + "ChapterStub", + (), + {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"}, + )() - with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False): + with unittest.mock.patch.dict( + os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False + ): assets = initialize_chapter_images(chapter) self.assertEqual(assets, []) def test_initialize_chapter_images_accepts_double_brace_placeholders(self): """图片初始化已迁移到 _save_narrative_to_sections;此处为兼容 no-op""" - chapter = type("ChapterStub", (), {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"})() + chapter = type( + "ChapterStub", + (), + {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"}, + )() - with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False): + with unittest.mock.patch.dict( + os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False + ): assets = initialize_chapter_images(chapter) self.assertEqual(assets, []) def test_initialize_chapter_images_normalizes_invalid_existing_asset_status(self): """图片初始化已迁移到 _save_narrative_to_sections;此处为兼容 no-op""" - chapter = type("ChapterStub", (), {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"})() + chapter = type( + "ChapterStub", + (), + {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"}, + )() - with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False): + with unittest.mock.patch.dict( + os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False + ): assets = initialize_chapter_images(chapter) self.assertEqual(assets, []) - def test_initialize_chapter_images_preserves_existing_completed_assets_beyond_effective_max(self): + def test_initialize_chapter_images_preserves_existing_completed_assets_beyond_effective_max( + self, + ): """图片初始化已迁移到 _save_narrative_to_sections;此处为兼容 no-op""" - chapter = type("ChapterStub", (), {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"})() + chapter = type( + "ChapterStub", + (), + {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"}, + )() with unittest.mock.patch.dict( os.environ, @@ -82,18 +112,30 @@ class MemoirImageBootstrapTest(unittest.TestCase): def test_initialize_chapter_images_increases_limit_for_long_content(self): """图片初始化已迁移到 _save_narrative_to_sections;此处为兼容 no-op""" - chapter = type("ChapterStub", (), {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"})() + chapter = type( + "ChapterStub", + (), + {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"}, + )() - with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False): + with unittest.mock.patch.dict( + os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False + ): assets = initialize_chapter_images(chapter) self.assertEqual(assets, []) def test_initialize_chapter_images_caps_dynamic_limit_at_max_images_cap(self): """图片初始化已迁移到 _save_narrative_to_sections;此处为兼容 no-op""" - chapter = type("ChapterStub", (), {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"})() + chapter = type( + "ChapterStub", + (), + {"id": "chapter-1", "title": "童年的夏天", "category": "childhood"}, + )() - with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False): + with unittest.mock.patch.dict( + os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False + ): assets = initialize_chapter_images(chapter) self.assertEqual(assets, []) diff --git a/api/tests/test_memoir_image_parser.py b/api/tests/test_memoir_image_parser.py index 38ae532..92d8819 100644 --- a/api/tests/test_memoir_image_parser.py +++ b/api/tests/test_memoir_image_parser.py @@ -21,7 +21,9 @@ class MemoirImageParserTest(unittest.TestCase): self.assertEqual([item["index"] for item in items], [0, 1]) self.assertEqual(items[0]["description"], "南方小镇的青石板路") - self.assertEqual(items[1]["placeholder"], "{{{{IMAGE:奶奶坐在院子里的藤椅上}}}}") + self.assertEqual( + items[1]["placeholder"], "{{{{IMAGE:奶奶坐在院子里的藤椅上}}}}" + ) self.assertLess(items[0]["start_offset"], items[1]["start_offset"]) def test_build_initial_image_assets_marks_every_item_pending(self): @@ -52,20 +54,30 @@ class MemoirImageParserTest(unittest.TestCase): items = parse_image_placeholders(content, max_images=2) self.assertEqual(len(items), 1) - self.assertEqual(items[0]["placeholder"], "{{IMAGE:1938年初的上海弄堂口,冬日萧瑟}}") + self.assertEqual( + items[0]["placeholder"], "{{IMAGE:1938年初的上海弄堂口,冬日萧瑟}}" + ) self.assertEqual(items[0]["description"], "1938年初的上海弄堂口,冬日萧瑟") - def test_parse_narrative_json_returns_sections_with_content_and_placeholder_info(self): + def test_parse_narrative_json_returns_sections_with_content_and_placeholder_info( + self, + ): raw = '{"paragraphs": [{"content": "那年春天。", "image_description": "南方小镇的青石板路"}, {"content": "奶奶坐在藤椅上。", "image_description": "奶奶的藤椅"}]}' segments = parse_narrative_json(raw) self.assertEqual(len(segments), 2) self.assertEqual(segments[0]["content"], "那年春天。") - self.assertEqual(segments[0]["placeholder_info"]["description"], "南方小镇的青石板路") + self.assertEqual( + segments[0]["placeholder_info"]["description"], "南方小镇的青石板路" + ) self.assertEqual(segments[1]["content"], "奶奶坐在藤椅上。") self.assertEqual(segments[1]["placeholder_info"]["description"], "奶奶的藤椅") - def test_parse_narrative_to_sections_prefers_json_then_fallback_to_placeholder(self): - json_raw = '{"paragraphs": [{"content": "段落一", "image_description": "图一"}]}' + def test_parse_narrative_to_sections_prefers_json_then_fallback_to_placeholder( + self, + ): + json_raw = ( + '{"paragraphs": [{"content": "段落一", "image_description": "图一"}]}' + ) segments = parse_narrative_to_sections(json_raw) self.assertEqual(len(segments), 1) self.assertEqual(segments[0]["content"], "段落一") diff --git a/api/tests/test_memoir_image_prompting.py b/api/tests/test_memoir_image_prompting.py index b2e7141..e2774e1 100644 --- a/api/tests/test_memoir_image_prompting.py +++ b/api/tests/test_memoir_image_prompting.py @@ -60,7 +60,10 @@ class MemoirImagePromptingTest(unittest.TestCase): context_excerpt="梧桐树下很安静,夏天总有蝉鸣。", ) - self.assertEqual(result["prompt"], "A grandmother in a quiet courtyard, summer cicadas, soft watercolor") + self.assertEqual( + result["prompt"], + "A grandmother in a quiet courtyard, summer cicadas, soft watercolor", + ) self.assertEqual(result["style"], "watercolor") self.assertEqual(result["size"], "1024x1024") diff --git a/api/tests/test_memoir_image_provider.py b/api/tests/test_memoir_image_provider.py index 93796d3..72664fb 100644 --- a/api/tests/test_memoir_image_provider.py +++ b/api/tests/test_memoir_image_provider.py @@ -44,7 +44,9 @@ class SubmitGenerationTest(unittest.TestCase): url = call_kwargs.args[0] if call_kwargs.args else call_kwargs.kwargs["url"] params = call_kwargs.kwargs.get("params") - self.assertEqual(url, "https://openapi.liblibai.cloud/api/generate/webui/text2img/ultra") + self.assertEqual( + url, "https://openapi.liblibai.cloud/api/generate/webui/text2img/ultra" + ) self.assertNotIn("AccessKey=", url) self.assertEqual(params["AccessKey"], "test-ak") self.assertIn("Signature", params) @@ -62,7 +64,9 @@ class SubmitGenerationTest(unittest.TestCase): http_client.post.return_value = resp provider = _make_provider(http_client) - job = provider.submit_generation(prompt="a cat", size="1024x1024", style="watercolor") + job = provider.submit_generation( + prompt="a cat", size="1024x1024", style="watercolor" + ) self.assertEqual(job["status"], "processing") self.assertEqual(job["job_id"], "uuid-abc") @@ -86,7 +90,9 @@ class SubmitGenerationTest(unittest.TestCase): http_client.post.return_value = resp provider = _make_provider(http_client) - provider.submit_generation(prompt="a cat under the rain", size="1024x1024", style="watercolor") + provider.submit_generation( + prompt="a cat under the rain", size="1024x1024", style="watercolor" + ) call_kwargs = http_client.post.call_args body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") @@ -103,7 +109,9 @@ class SubmitGenerationTest(unittest.TestCase): provider = _make_provider(http_client) with self.assertRaises(RuntimeError): - provider.submit_generation(prompt="a cat", size="1024x1024", style="watercolor") + provider.submit_generation( + prompt="a cat", size="1024x1024", style="watercolor" + ) class PollUntilCompleteTest(unittest.TestCase): @@ -121,7 +129,9 @@ class PollUntilCompleteTest(unittest.TestCase): "code": 0, "data": { "generateStatus": 5, - "images": [{"imageUrl": "https://cdn.example.com/1.png", "auditStatus": 3}], + "images": [ + {"imageUrl": "https://cdn.example.com/1.png", "auditStatus": 3} + ], }, } success_resp.raise_for_status = Mock() @@ -208,7 +218,9 @@ class DownloadImageTest(unittest.TestCase): template_uuid="tpl-uuid", allowed_download_hosts=("cdn.example.com",), ) - payload = provider.download_image({"image_url": "https://cdn.example.com/1.png"}) + payload = provider.download_image( + {"image_url": "https://cdn.example.com/1.png"} + ) self.assertEqual(payload, b"png-bytes") diff --git a/api/tests/test_memoir_image_settings.py b/api/tests/test_memoir_image_settings.py index fef601a..abc62f5 100644 --- a/api/tests/test_memoir_image_settings.py +++ b/api/tests/test_memoir_image_settings.py @@ -37,6 +37,8 @@ class MemoirImageSettingsTest(unittest.TestCase): self.assertEqual(settings.liblib_template_uuid, DEFAULT_LIBLIB_TEMPLATE_UUID) def test_effective_max_images_never_drops_below_base_max_per_chapter(self): - settings = MemoirImageSettings(enabled=True, max_per_chapter=2, max_images_cap=1) + settings = MemoirImageSettings( + enabled=True, max_per_chapter=2, max_images_cap=1 + ) self.assertEqual(settings.effective_max_images(0), 2) diff --git a/api/tests/test_memoir_image_storage.py b/api/tests/test_memoir_image_storage.py index 4521797..2e27288 100644 --- a/api/tests/test_memoir_image_storage.py +++ b/api/tests/test_memoir_image_storage.py @@ -67,7 +67,9 @@ class MemoirImageStorageTest(unittest.TestCase): client.put_object.assert_called_once() @patch("app.features.memoir.memoir_images.storage.CosS3Client") - def test_upload_bytes_normalizes_duplicate_appid_suffix_in_base_url(self, client_cls): + def test_upload_bytes_normalizes_duplicate_appid_suffix_in_base_url( + self, client_cls + ): client = Mock() client_cls.return_value = client storage = TencentCosStorageService( @@ -105,7 +107,9 @@ class MemoirImageStorageTest(unittest.TestCase): @patch("app.features.memoir.memoir_images.storage.CosS3Client") def test_get_download_url_returns_presigned_download_url(self, client_cls): client = Mock() - client.get_presigned_download_url.return_value = "https://cos.example.com/0.png?q-sign-algorithm=sha1" + client.get_presigned_download_url.return_value = ( + "https://cos.example.com/0.png?q-sign-algorithm=sha1" + ) client_cls.return_value = client storage = TencentCosStorageService( secret_id="id", @@ -203,10 +207,14 @@ class MemoirImageStorageTest(unittest.TestCase): self.assertTrue(_is_retryable_cos_error(CosClientError("timeout"))) def test_is_retryable_returns_false_for_4xx_service_error(self): - self.assertFalse(_is_retryable_cos_error(CosServiceError("GET", "Forbidden", 403))) + self.assertFalse( + _is_retryable_cos_error(CosServiceError("GET", "Forbidden", 403)) + ) def test_is_retryable_returns_true_for_5xx_service_error(self): - self.assertTrue(_is_retryable_cos_error(CosServiceError("GET", "Internal", 500))) + self.assertTrue( + _is_retryable_cos_error(CosServiceError("GET", "Internal", 500)) + ) @patch("app.features.memoir.memoir_images.storage.CosS3Client") def test_cos_config_includes_scheme_and_token(self, client_cls): diff --git a/api/tests/test_memory_prompts_inject.py b/api/tests/test_memory_prompts_inject.py index 1d2e4f9..2f9b439 100644 --- a/api/tests/test_memory_prompts_inject.py +++ b/api/tests/test_memory_prompts_inject.py @@ -1,4 +1,5 @@ """测试 memory_prompts.inject_image_placeholder_template:占位符花括号统一为四层,避免多余花括号残留""" + import unittest from app.agents.memoir.prompts import ( diff --git a/api/tests/test_pdf_service_images.py b/api/tests/test_pdf_service_images.py index 86956e7..0ca9525 100644 --- a/api/tests/test_pdf_service_images.py +++ b/api/tests/test_pdf_service_images.py @@ -31,7 +31,9 @@ class PDFServiceImagesTest(unittest.IsolatedAsyncioTestCase): async_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) async_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) storage = MagicMock() - storage.get_download_url.return_value = "https://signed.example.com/0.png?sig=123" + storage.get_download_url.return_value = ( + "https://signed.example.com/0.png?sig=123" + ) storage_cls.from_env.return_value = storage reportlab_image_cls.return_value = MagicMock() @@ -82,7 +84,9 @@ class PDFServiceImagesTest(unittest.IsolatedAsyncioTestCase): async_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) async_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) storage = MagicMock() - storage.get_download_url.return_value = "https://signed.example.com/0.png?sig=123" + storage.get_download_url.return_value = ( + "https://signed.example.com/0.png?sig=123" + ) storage_cls.from_env.return_value = storage service = PDFService() @@ -109,7 +113,9 @@ class PDFServiceImagesTest(unittest.IsolatedAsyncioTestCase): self.assertGreater(len(pdf_bytes), 100) self.assertNotIn(b"IMAGE:", pdf_bytes) - mock_client.get.assert_called_once_with("https://signed.example.com/0.png?sig=123") + mock_client.get.assert_called_once_with( + "https://signed.example.com/0.png?sig=123" + ) @patch("app.features.memoir.pdf_service.httpx.AsyncClient") @patch("app.features.memoir.pdf_service.TencentCosStorageService") diff --git a/api/tests/test_process_memoir_segments_image_enqueue.py b/api/tests/test_process_memoir_segments_image_enqueue.py index 8d101b0..f19bc47 100644 --- a/api/tests/test_process_memoir_segments_image_enqueue.py +++ b/api/tests/test_process_memoir_segments_image_enqueue.py @@ -16,7 +16,10 @@ def _mock_get_sync_db(db): class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase): @patch("app.tasks.memoir_tasks._chapter_has_cover_to_generate", return_value=True) - @patch("app.tasks.memoir_tasks._chapter_has_any_section_images_to_generate", return_value=True) + @patch( + "app.tasks.memoir_tasks._chapter_has_any_section_images_to_generate", + return_value=True, + ) @patch("app.tasks.memoir_tasks._update_task_status_sync") @patch("app.tasks.memoir_tasks._release_chapter_lock") @patch("app.tasks.memoir_tasks._acquire_chapter_lock", return_value=True) @@ -55,8 +58,12 @@ class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase): max_attempts=20, liblib_template_uuid="tpl-uuid", ) - get_state_mock.return_value = SimpleNamespace(current_stage="childhood", slots={}) - update_slot_mock.return_value = SimpleNamespace(current_stage="childhood", slots={}) + get_state_mock.return_value = SimpleNamespace( + current_stage="childhood", slots={} + ) + update_slot_mock.return_value = SimpleNamespace( + current_stage="childhood", slots={} + ) llm = Mock() bound_llm = Mock() bound_llm.invoke.side_effect = [ @@ -70,7 +77,9 @@ class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase): } ```""" ), - SimpleNamespace(content='{"paragraphs":[{"content":"新的章节正文","image_description":"南方小镇的青石板路"}]}'), + SimpleNamespace( + content='{"paragraphs":[{"content":"新的章节正文","image_description":"南方小镇的青石板路"}]}' + ), ] llm.bind.return_value = bound_llm llm.invoke.side_effect = [ @@ -170,7 +179,9 @@ class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase): max_attempts=20, liblib_template_uuid="tpl-uuid", ) - get_state_mock.return_value = SimpleNamespace(current_stage="childhood", slots={}) + get_state_mock.return_value = SimpleNamespace( + current_stage="childhood", slots={} + ) segment = SimpleNamespace( id="segment-1", diff --git a/api/tests/test_sms_verification.py b/api/tests/test_sms_verification.py index b4eefc4..e8e4299 100755 --- a/api/tests/test_sms_verification.py +++ b/api/tests/test_sms_verification.py @@ -37,22 +37,23 @@ refresh_token: Optional[str] = None class Colors: """终端颜色""" - HEADER = '\033[95m' - OKBLUE = '\033[94m' - OKCYAN = '\033[96m' - OKGREEN = '\033[92m' - WARNING = '\033[93m' - FAIL = '\033[91m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' + + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKCYAN = "\033[96m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" def print_header(text: str): """打印测试标题""" - print(f"\n{Colors.HEADER}{Colors.BOLD}{'='*60}{Colors.ENDC}") + print(f"\n{Colors.HEADER}{Colors.BOLD}{'=' * 60}{Colors.ENDC}") print(f"{Colors.HEADER}{Colors.BOLD}{text}{Colors.ENDC}") - print(f"{Colors.HEADER}{Colors.BOLD}{'='*60}{Colors.ENDC}\n") + print(f"{Colors.HEADER}{Colors.BOLD}{'=' * 60}{Colors.ENDC}\n") def print_success(text: str): @@ -80,11 +81,11 @@ def make_request( endpoint: str, data: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, - expected_status: int = 200 + expected_status: int = 200, ) -> Optional[Dict[str, Any]]: """发送HTTP请求""" url = f"{BASE_URL}{API_PREFIX}{endpoint}" - + try: if method.upper() == "GET": response = requests.get(url, headers=headers) @@ -97,9 +98,9 @@ def make_request( else: print_error(f"不支持的HTTP方法: {method}") return None - + print_info(f"{method.upper()} {endpoint} - Status: {response.status_code}") - + if response.status_code == expected_status: print_success(f"请求成功 (状态码: {response.status_code})") try: @@ -107,14 +108,18 @@ def make_request( except: return {"status": "success"} else: - print_error(f"请求失败 (期望: {expected_status}, 实际: {response.status_code})") + print_error( + f"请求失败 (期望: {expected_status}, 实际: {response.status_code})" + ) try: error_data = response.json() - print_error(f"错误信息: {json.dumps(error_data, ensure_ascii=False, indent=2)}") + print_error( + f"错误信息: {json.dumps(error_data, ensure_ascii=False, indent=2)}" + ) except: print_error(f"响应内容: {response.text}") return None - + except requests.exceptions.ConnectionError: print_error(f"连接失败: 无法连接到 {BASE_URL}") print_warning("请确保后端服务正在运行") @@ -127,41 +132,35 @@ def make_request( def test_send_verification_code(phone: str, purpose: str) -> bool: """测试发送验证码""" print_header(f"测试发送验证码 - {purpose}") - - data = { - "phone": phone, - "purpose": purpose - } - + + data = {"phone": phone, "purpose": purpose} + result = make_request("POST", "/auth/sms/send", data=data) - + if result: print_success(f"验证码已发送: {result.get('message', '')}") print_info(f"有效期: {result.get('expires_in', 0)} 秒") return True - + return False def test_rate_limit(phone: str) -> bool: """测试频率限制""" print_header("测试频率限制") - + # 第一次发送应该成功 if not test_send_verification_code(phone, "register"): return False - + print_info("等待1秒后再次发送...") time.sleep(1) - + # 第二次发送应该被限制 - data = { - "phone": phone, - "purpose": "register" - } - + data = {"phone": phone, "purpose": "register"} + result = make_request("POST", "/auth/sms/send", data=data, expected_status=429) - + if result is None: print_success("频率限制生效") return True @@ -173,165 +172,144 @@ def test_rate_limit(phone: str) -> bool: def test_register_with_sms(phone: str, code: str) -> bool: """测试验证码注册""" print_header("测试验证码注册") - + data = { "phone": phone, "code": code, "password": TEST_PASSWORD, "nickname": TEST_NICKNAME, - "email": TEST_EMAIL + "email": TEST_EMAIL, } - + result = make_request("POST", "/auth/register/sms", data=data, expected_status=201) - + if result: global access_token, refresh_token access_token = result.get("access_token") refresh_token = result.get("refresh_token") - + print_success("注册成功") print_info(f"Access Token: {access_token[:20]}...") print_info(f"Refresh Token: {refresh_token[:20]}...") return True - + return False def test_login_with_sms(phone: str, code: str) -> bool: """测试验证码登录""" print_header("测试验证码登录") - - data = { - "phone": phone, - "code": code - } - + + data = {"phone": phone, "code": code} + result = make_request("POST", "/auth/login/sms", data=data) - + if result: global access_token, refresh_token access_token = result.get("access_token") refresh_token = result.get("refresh_token") - + print_success("登录成功") print_info(f"Access Token: {access_token[:20]}...") return True - + return False def test_reset_password(phone: str, code: str, new_password: str) -> bool: """测试重置密码""" print_header("测试重置密码") - - data = { - "phone": phone, - "code": code, - "new_password": new_password - } - + + data = {"phone": phone, "code": code, "new_password": new_password} + result = make_request("POST", "/auth/password/reset", data=data) - + if result: print_success(f"密码重置成功: {result.get('message', '')}") return True - + return False def test_change_password(old_password: str, new_password: str) -> bool: """测试修改密码(已登录)""" print_header("测试修改密码") - + if not access_token: print_error("未登录,无法测试") return False - - data = { - "old_password": old_password, - "new_password": new_password - } - - headers = { - "Authorization": f"Bearer {access_token}" - } - + + data = {"old_password": old_password, "new_password": new_password} + + headers = {"Authorization": f"Bearer {access_token}"} + result = make_request("POST", "/auth/password/change", data=data, headers=headers) - + if result: print_success(f"密码修改成功: {result.get('message', '')}") return True - + return False def test_change_phone(new_phone: str, code: str) -> bool: """测试修改手机号""" print_header("测试修改手机号") - + if not access_token: print_error("未登录,无法测试") return False - - data = { - "new_phone": new_phone, - "code": code - } - - headers = { - "Authorization": f"Bearer {access_token}" - } - + + data = {"new_phone": new_phone, "code": code} + + headers = {"Authorization": f"Bearer {access_token}"} + result = make_request("POST", "/auth/phone/change", data=data, headers=headers) - + if result: print_success(f"手机号修改成功") print_info(f"新手机号: {result.get('phone', '')}") return True - + return False def test_logout_all() -> bool: """测试登出所有设备""" print_header("测试登出所有设备") - + if not access_token: print_error("未登录,无法测试") return False - - headers = { - "Authorization": f"Bearer {access_token}" - } - + + headers = {"Authorization": f"Bearer {access_token}"} + result = make_request("POST", "/auth/logout/all", headers=headers) - + if result: print_success(f"登出成功: {result.get('message', '')}") return True - + return False def test_get_current_user() -> bool: """测试获取当前用户信息""" print_header("测试获取当前用户信息") - + if not access_token: print_error("未登录,无法测试") return False - - headers = { - "Authorization": f"Bearer {access_token}" - } - + + headers = {"Authorization": f"Bearer {access_token}"} + result = make_request("GET", "/auth/me", headers=headers) - + if result: print_success("获取用户信息成功") print_info(f"用户信息: {json.dumps(result, ensure_ascii=False, indent=2)}") return True - + return False @@ -340,41 +318,43 @@ def interactive_test(): print_header("短信验证码功能交互式测试") print_info("此模式需要您手动输入收到的验证码") print_warning("请确保已配置腾讯云短信服务") - + phone = input(f"\n请输入测试手机号 (默认: {TEST_PHONE}): ").strip() or TEST_PHONE - + # 1. 测试发送验证码 if not test_send_verification_code(phone, "register"): print_error("发送验证码失败,测试终止") return - + code = input("\n请输入收到的验证码: ").strip() - + if not code or len(code) != 6: print_error("验证码格式错误") return - + # 2. 测试注册 if test_register_with_sms(phone, code): print_success("注册测试通过") - + # 3. 测试获取用户信息 test_get_current_user() - + # 4. 测试修改密码 - if input("\n是否测试修改密码? (y/n): ").lower() == 'y': + if input("\n是否测试修改密码? (y/n): ").lower() == "y": test_change_password(TEST_PASSWORD, NEW_PASSWORD) - + # 5. 测试修改手机号 - if input("\n是否测试修改手机号? (y/n): ").lower() == 'y': - new_phone = input(f"请输入新手机号 (默认: {NEW_PHONE}): ").strip() or NEW_PHONE - + if input("\n是否测试修改手机号? (y/n): ").lower() == "y": + new_phone = ( + input(f"请输入新手机号 (默认: {NEW_PHONE}): ").strip() or NEW_PHONE + ) + if test_send_verification_code(new_phone, "change_phone"): code = input("请输入收到的验证码: ").strip() test_change_phone(new_phone, code) - + # 6. 测试登出所有设备 - if input("\n是否测试登出所有设备? (y/n): ").lower() == 'y': + if input("\n是否测试登出所有设备? (y/n): ").lower() == "y": test_logout_all() @@ -382,14 +362,14 @@ def automated_test(): """自动化测试(需要mock验证码)""" print_header("短信验证码功能自动化测试") print_warning("此模式需要后端支持测试验证码(如:123456)") - + # 测试发送验证码 test_send_verification_code(TEST_PHONE, "register") - + # 等待一段时间 print_info("等待60秒以测试频率限制...") time.sleep(60) - + # 测试频率限制 test_rate_limit(TEST_PHONE) @@ -400,14 +380,14 @@ if __name__ == "__main__": print("短信验证码功能测试脚本") print("=" * 60) print(f"{Colors.ENDC}") - + print("\n请选择测试模式:") print("1. 交互式测试(需要真实短信验证码)") print("2. 自动化测试(需要测试验证码支持)") print("3. 仅测试API连接") - + choice = input("\n请输入选项 (1/2/3): ").strip() - + if choice == "1": interactive_test() elif choice == "2": @@ -421,5 +401,5 @@ if __name__ == "__main__": print_error("API连接失败") else: print_error("无效的选项") - + print(f"\n{Colors.BOLD}{Colors.OKBLUE}测试完成{Colors.ENDC}\n") diff --git a/api/tests/test_websocket_baseline.py b/api/tests/test_websocket_baseline.py index e0cc23e..c651213 100644 --- a/api/tests/test_websocket_baseline.py +++ b/api/tests/test_websocket_baseline.py @@ -89,7 +89,9 @@ class _FakeAsyncDB: async def execute(self, stmt): stmt_str = str(stmt) if "MemoirState" in stmt_str or "memoir_state" in stmt_str: - return _ExecuteResult([self.state_result] if self.state_result is not None else []) + return _ExecuteResult( + [self.state_result] if self.state_result is not None else [] + ) return _ExecuteResult(self.segments) @@ -116,7 +118,9 @@ class _FakeManager: self.active_connections.pop(conversation_id, None) async def send_message(self, conversation_id, message): - self.sent_messages.append({"conversation_id": conversation_id, "message": message}) + self.sent_messages.append( + {"conversation_id": conversation_id, "message": message} + ) def get_or_create_segment_state(self, conversation_id, voice_session_id): state_key = (conversation_id, voice_session_id) @@ -149,13 +153,16 @@ def _make_user(): def _db_provider(db): """返回可被 patch 到 get_async_db 的异步生成器(旧用法)。""" + async def _provider(): yield db + return _provider class _FakeSessionCM: """模拟 async with AsyncSessionLocal() as db 的上下文管理器。""" + def __init__(self, db): self._db = db @@ -168,8 +175,10 @@ class _FakeSessionCM: def _session_local_factory(fake_db): """返回可 patch 到 AsyncSessionLocal 的工厂,使 async with AsyncSessionLocal() as db 得到 fake_db。""" + def _factory(): return _FakeSessionCM(fake_db) + return _factory @@ -215,7 +224,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -228,13 +239,22 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.router.check_ws_quota", new=AsyncMock(return_value=(True, ""))) + patch( + "app.features.conversation.ws.router.check_ws_quota", + new=AsyncMock(return_value=(True, "")), + ) ) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) await ws_router.websocket_endpoint(fake_websocket, "conv-1") @@ -277,7 +297,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -290,13 +312,22 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.router.check_ws_quota", new=AsyncMock(return_value=(True, ""))) + patch( + "app.features.conversation.ws.router.check_ws_quota", + new=AsyncMock(return_value=(True, "")), + ) ) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) stack.enter_context( patch( @@ -364,7 +395,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -377,10 +410,16 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) stack.enter_context( patch( @@ -406,7 +445,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): transcribe_mock.assert_awaited_once_with(b"fake-audio-b64", "m4a") process_user_message_mock.assert_not_awaited() - self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 0) + self.assertEqual( + len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 0 + ) transcript_msgs = [ item["message"] @@ -435,7 +476,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -448,7 +491,11 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch.object( @@ -491,7 +538,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -504,7 +553,11 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( patch.object( @@ -558,7 +611,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -571,13 +626,22 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.router.check_ws_quota", new=AsyncMock(return_value=(True, ""))) + patch( + "app.features.conversation.ws.router.check_ws_quota", + new=AsyncMock(return_value=(True, "")), + ) ) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) stack.enter_context( patch( @@ -608,7 +672,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): if item["message"]["type"] == ws_router.MessageType.ERROR ] self.assertEqual(len(error_msgs), 1) - self.assertEqual(error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入") + self.assertEqual( + error_msgs[0]["data"]["message"], "语音转写失败,请重试或使用文字输入" + ) async def test_audio_segment_out_of_order_is_aggregated_by_segment_index(self): user = _make_user() @@ -642,9 +708,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) process_user_message_mock = AsyncMock() - transcribe_mock = AsyncMock( - side_effect=["这是第 1 段", "这是第 0 段"] - ) + transcribe_mock = AsyncMock(side_effect=["这是第 1 段", "这是第 0 段"]) with ExitStack() as stack: stack.enter_context( @@ -655,7 +719,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -668,13 +734,22 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.router.check_ws_quota", new=AsyncMock(return_value=(True, ""))) + patch( + "app.features.conversation.ws.router.check_ws_quota", + new=AsyncMock(return_value=(True, "")), + ) ) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) stack.enter_context( patch( @@ -701,10 +776,13 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): self.assertEqual(transcribe_mock.await_count, 2) ordered_messages = [ - call.kwargs["user_message"] for call in process_user_message_mock.await_args_list + call.kwargs["user_message"] + for call in process_user_message_mock.await_args_list ] self.assertEqual(ordered_messages, ["这是第 0 段", "这是第 1 段"]) - self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2) + self.assertEqual( + len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2 + ) transcript_msgs = [ item["message"] for item in fake_manager.sent_messages @@ -756,7 +834,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -769,13 +849,22 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.router.check_ws_quota", new=AsyncMock(return_value=(True, ""))) + patch( + "app.features.conversation.ws.router.check_ws_quota", + new=AsyncMock(return_value=(True, "")), + ) ) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) stack.enter_context( patch( @@ -802,9 +891,13 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): self.assertEqual(transcribe_mock.await_count, 1) process_user_message_mock.assert_awaited_once() - self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 1) + self.assertEqual( + len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 1 + ) - async def test_audio_segment_same_index_is_allowed_for_different_voice_sessions(self): + async def test_audio_segment_same_index_is_allowed_for_different_voice_sessions( + self, + ): user = _make_user() conversation = Conversation(id="conv-1", user_id=user.id, status="active") fake_db = _FakeAsyncDB(user=user, conversation=conversation) @@ -838,9 +931,7 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) process_user_message_mock = AsyncMock() - transcribe_mock = AsyncMock( - side_effect=["第一轮第 0 段", "第二轮第 0 段"] - ) + transcribe_mock = AsyncMock(side_effect=["第一轮第 0 段", "第二轮第 0 段"]) with ExitStack() as stack: stack.enter_context( @@ -851,7 +942,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -864,13 +957,22 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.router.check_ws_quota", new=AsyncMock(return_value=(True, ""))) + patch( + "app.features.conversation.ws.router.check_ws_quota", + new=AsyncMock(return_value=(True, "")), + ) ) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) stack.enter_context( patch( @@ -896,11 +998,14 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): await asyncio.sleep(0.05) ordered_messages = [ - call.kwargs["user_message"] for call in process_user_message_mock.await_args_list + call.kwargs["user_message"] + for call in process_user_message_mock.await_args_list ] self.assertEqual(ordered_messages, ["第一轮第 0 段", "第二轮第 0 段"]) self.assertEqual(transcribe_mock.await_count, 2) - self.assertEqual(len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2) + self.assertEqual( + len([obj for obj in fake_db.added if isinstance(obj, Segment)]), 2 + ) async def test_audio_segment_sends_transition_feedback_while_processing(self): user = _make_user() @@ -938,7 +1043,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -951,13 +1058,22 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.router.check_ws_quota", new=AsyncMock(return_value=(True, ""))) + patch( + "app.features.conversation.ws.router.check_ws_quota", + new=AsyncMock(return_value=(True, "")), + ) ) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) stack.enter_context( patch( @@ -999,7 +1115,10 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): fake_manager = _FakeManager() fake_websocket = _FakeWebSocket( messages=[ - {"type": "recording_started", "data": {"voice_session_id": "session-1"}}, + { + "type": "recording_started", + "data": {"voice_session_id": "session-1"}, + }, WebSocketDisconnect(), ] ) @@ -1013,7 +1132,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -1025,10 +1146,17 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.pipeline.LISTENING_FEEDBACK_DELAY_SEC", 0.05) + patch( + "app.features.conversation.ws.pipeline.LISTENING_FEEDBACK_DELAY_SEC", + 0.05, + ) ) await ws_router.websocket_endpoint(fake_websocket, "conv-1") @@ -1077,7 +1205,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -1090,13 +1220,22 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.router.check_ws_quota", new=AsyncMock(return_value=(True, ""))) + patch( + "app.features.conversation.ws.router.check_ws_quota", + new=AsyncMock(return_value=(True, "")), + ) ) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) stack.enter_context( patch( @@ -1130,7 +1269,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ] self.assertEqual(len(transition_msgs), 0) - async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment(self): + async def test_audio_segment_continues_after_reconnect_with_existing_previous_segment( + self, + ): user = _make_user() conversation = Conversation(id="conv-1", user_id=user.id, status="active") existing_segment = Segment( @@ -1175,7 +1316,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -1188,13 +1331,22 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.router.check_ws_quota", new=AsyncMock(return_value=(True, ""))) + patch( + "app.features.conversation.ws.router.check_ws_quota", + new=AsyncMock(return_value=(True, "")), + ) ) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) stack.enter_context( patch( @@ -1279,7 +1431,9 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): ) ) stack.enter_context( - patch.object(ws_router, "AsyncSessionLocal", _session_local_factory(fake_db)) + patch.object( + ws_router, "AsyncSessionLocal", _session_local_factory(fake_db) + ) ) stack.enter_context( patch( @@ -1292,13 +1446,22 @@ class WebSocketBaselineTest(unittest.IsolatedAsyncioTestCase): stack.enter_context( patch("app.features.conversation.ws.pipeline.manager", fake_manager) ) - stack.enter_context(patch.object(ws_router, "background_runner", fake_manager.background_runner)) + stack.enter_context( + patch.object( + ws_router, "background_runner", fake_manager.background_runner + ) + ) stack.enter_context(_redis_empty_history_patch()) stack.enter_context( - patch("app.features.conversation.ws.router.check_ws_quota", new=AsyncMock(return_value=(True, ""))) + patch( + "app.features.conversation.ws.router.check_ws_quota", + new=AsyncMock(return_value=(True, "")), + ) ) stack.enter_context( - patch.object(ws_router, "process_user_message", process_user_message_mock) + patch.object( + ws_router, "process_user_message", process_user_message_mock + ) ) stack.enter_context( patch(