From 53e0065e3e688a95008a14b3d6e2201169291ad8 Mon Sep 17 00:00:00 2001 From: Sully <101929462+Sullivansome@users.noreply.github.com> Date: Fri, 22 May 2026 13:44:50 +0800 Subject: [PATCH] =?UTF-8?q?refactor(api):=20TOML=20=E9=85=8D=E7=BD=AE=20SS?= =?UTF-8?q?OT=E3=80=81=E7=BB=9F=E4=B8=80=E9=94=99=E8=AF=AF=E5=A5=91?= =?UTF-8?q?=E7=BA=A6=E3=80=81Auth/=E4=BA=8B=E5=8A=A1=E5=8A=A0=E5=9B=BA?= =?UTF-8?q?=E4=B8=8E=E5=8F=AF=E8=A7=82=E6=B5=8B=E6=80=A7=20(#33)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 配置 SSOT(TOML + .env) 统一错误契约 Auth 与事务边界 Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client 可观测性(OpenTelemetry + LGTM) --- .cursor/rules/Backend-Develop-Guideline.mdc | 32 +- .cursor/rules/app-api-client.mdc | 21 + .cursor/rules/backend-testing-strategy.mdc | 13 + .github/workflows/android-release.yml | 2 +- README.md | 8 +- api/.agents/skills/celery-expert/SKILL.md | 630 +++++ .../.cursor-plugin/plugin.json | 14 + .../skills/redis-development/AGENTS.md | 2228 +++++++++++++++++ .../skills/redis-development/README.md | 124 + api/.agents/skills/redis-development/SKILL.md | 121 + .../skills/redis-development/assets/logo.png | Bin 0 -> 21458 bytes .../redis-development/rules/_contributing.md | 97 + .../redis-development/rules/_sections.md | 50 + .../redis-development/rules/_template.md | 52 + .../rules/cluster-hash-tags.md | 78 + .../rules/cluster-read-replicas.md | 55 + .../redis-development/rules/conn-blocking.md | 75 + .../rules/conn-client-cache.md | 70 + .../rules/conn-pipelining.md | 58 + .../redis-development/rules/conn-pooling.md | 71 + .../redis-development/rules/conn-timeouts.md | 41 + .../rules/data-choose-structure.md | 78 + .../rules/data-hash-field-expiry.md | 62 + .../redis-development/rules/data-incr.md | 76 + .../rules/data-key-naming.md | 62 + .../rules/data-transactions.md | 74 + .../rules/json-partial-updates.md | 49 + .../redis-development/rules/json-vs-hash.md | 105 + .../rules/observe-commands.md | 53 + .../rules/observe-metrics.md | 39 + .../redis-development/rules/ram-limits.md | 42 + .../skills/redis-development/rules/ram-ttl.md | 55 + .../redis-development/rules/rqe-dialect.md | 47 + .../rules/rqe-field-types.md | 81 + .../rules/rqe-index-creation.md | 73 + .../rules/rqe-index-management.md | 49 + .../rules/rqe-query-optimization.md | 49 + .../rules/rqe-skip-initial-scan.md | 82 + .../redis-development/rules/security-acls.md | 41 + .../redis-development/rules/security-auth.md | 78 + .../rules/security-network.md | 52 + .../rules/semantic-cache-best-practices.md | 72 + .../rules/semantic-cache-langcache-usage.md | 86 + .../rules/stream-choosing-pattern.md | 44 + .../rules/vector-algorithm-choice.md | 61 + .../rules/vector-hybrid-search.md | 52 + .../rules/vector-index-creation.md | 85 + .../rules/vector-rag-pattern.md | 52 + api/.env.example | 361 +-- api/.env.production | 281 +-- api/.env.staging | 178 +- api/README.md | 126 +- .../versions/0020_refresh_rt_lineage.py | 64 + .../versions/0021_memory_source_segment_id.py | 83 + api/app/adapters/image_gen/liblib_provider.py | 6 +- api/app/adapters/llm/deepseek_eval_judge.py | 14 +- api/app/adapters/llm/zhipu_eval_judge.py | 11 +- api/app/adapters/tts/tencent_tts.py | 3 +- api/app/agents/chat/interview_agent.py | 61 +- api/app/agents/chat/orchestrator.py | 25 +- api/app/agents/chat/personas.py | 4 +- api/app/agents/chat/profile_agent.py | 29 +- api/app/agents/chat/prompts_conversation.py | 7 +- api/app/agents/chat/stage_detection.py | 5 +- api/app/agents/image_prompt/orchestrator.py | 3 +- api/app/agents/memoir/batch_phase1_prep.py | 3 +- api/app/agents/memoir/classification_agent.py | 3 +- api/app/agents/memoir/extraction_agent.py | 3 +- api/app/agents/memoir/fidelity_check_agent.py | 7 +- api/app/agents/memoir/narrative_agent.py | 7 +- api/app/agents/memoir/orchestrator.py | 5 +- api/app/agents/memoir/story_route_agent.py | 8 +- api/app/agents/memoir/story_route_payload.py | 13 +- api/app/core/agent_logging.py | 17 +- api/app/core/alembic_startup.py | 7 +- api/app/core/app_config.py | 69 + api/app/core/app_config_loader.py | 50 + api/app/core/app_config_models.py | 317 +++ api/app/core/auth_deps.py | 53 + api/app/core/celery_broker_dev.py | 2 - api/app/core/config.py | 573 +---- api/app/core/cos_url_keys.py | 7 +- api/app/core/db.py | 70 +- api/app/core/dependencies.py | 119 +- api/app/core/deps_types.py | 21 + api/app/core/error_codes.py | 152 ++ api/app/core/errors.py | 172 +- api/app/core/llm_gateway.py | 17 +- api/app/core/logging.py | 7 +- api/app/core/memoir_pipeline_progress.py | 16 +- api/app/core/memory_compaction_schedule.py | 23 +- api/app/core/middleware.py | 49 +- api/app/core/openapi.py | 118 +- api/app/core/redis.py | 104 +- api/app/core/redis_lock.py | 24 +- api/app/core/redis_sync.py | 44 + api/app/core/redis_urls.py | 96 + api/app/core/runtime_constants.py | 13 + api/app/core/security.py | 3 +- api/app/core/task_tracker.py | 15 +- api/app/core/telemetry.py | 105 +- api/app/features/auth/deps.py | 11 +- api/app/features/auth/integrity.py | 42 + api/app/features/auth/models.py | 9 + api/app/features/auth/repo.py | 85 +- api/app/features/auth/router.py | 417 +-- api/app/features/auth/schemas.py | 50 +- api/app/features/auth/service.py | 590 ++++- api/app/features/content/router.py | 300 +-- api/app/features/conversation/constants.py | 5 + api/app/features/conversation/deps.py | 5 +- .../features/conversation/history_store.py | 133 +- .../features/conversation/input_normalize.py | 12 +- api/app/features/conversation/router.py | 68 +- api/app/features/conversation/schemas.py | 50 +- api/app/features/conversation/service.py | 76 +- .../conversation/ws/connection_manager.py | 5 +- api/app/features/conversation/ws/persist.py | 55 + api/app/features/conversation/ws/pipeline.py | 128 +- .../conversation/ws/profile_collector.py | 15 +- api/app/features/conversation/ws/router.py | 41 +- .../conversation/ws/topic_chips_push.py | 5 +- api/app/features/evaluation/constants.py | 5 + api/app/features/evaluation/deps.py | 23 +- api/app/features/evaluation/errors.py | 18 +- .../features/evaluation/eval_trace_format.py | 5 +- api/app/features/evaluation/gating_service.py | 3 +- api/app/features/evaluation/internal_auth.py | 18 +- .../evaluation/judge_manual_service.py | 27 +- api/app/features/evaluation/judge_service.py | 52 +- api/app/features/evaluation/replay_service.py | 15 +- api/app/features/evaluation/router.py | 208 +- api/app/features/memoir/background_runner.py | 6 +- api/app/features/memoir/constants.py | 5 + api/app/features/memoir/cover_eligibility.py | 3 +- api/app/features/memoir/deps.py | 5 +- api/app/features/memoir/helpers.py | 5 +- .../memoir/memoir_images/prompting.py | 17 +- .../features/memoir/memoir_images/settings.py | 59 +- .../features/memoir/memoir_images/storage.py | 14 +- api/app/features/memoir/oral_normalize.py | 9 +- api/app/features/memoir/router.py | 69 +- api/app/features/memoir/schemas.py | 115 +- api/app/features/memoir/service.py | 96 +- api/app/features/memoir/state_service.py | 106 +- .../features/memoir/story_pipeline_sync.py | 42 +- .../features/memory/chat_memory_injection.py | 7 +- api/app/features/memory/compaction_service.py | 17 +- api/app/features/memory/constants.py | 5 + api/app/features/memory/deps.py | 5 +- api/app/features/memory/enrichment.py | 5 +- api/app/features/memory/ingest_service.py | 127 +- api/app/features/memory/models.py | 1 + api/app/features/memory/repo.py | 31 + api/app/features/memory/router.py | 29 +- api/app/features/memory/service.py | 117 +- api/app/features/payment/deps.py | 5 +- api/app/features/payment/order_service.py | 129 +- api/app/features/payment/payment_config.py | 8 +- .../features/payment/payment_exceptions.py | 18 +- api/app/features/payment/router.py | 21 +- api/app/features/plan/catalog.py | 2 +- api/app/features/plan/deps.py | 11 +- api/app/features/plan/router.py | 27 +- api/app/features/plan/schemas.py | 23 +- api/app/features/plan/service.py | 18 +- api/app/features/quota/deps.py | 5 +- api/app/features/quota/router.py | 8 +- api/app/features/story/constants.py | 5 + api/app/features/story/deps.py | 5 +- api/app/features/story/post_commit.py | 29 +- api/app/features/story/service.py | 132 +- api/app/features/tasks/deps.py | 1 + api/app/features/tasks/router.py | 3 +- api/app/features/user/deps.py | 5 +- api/app/features/user/router.py | 48 +- api/app/features/user/schemas.py | 2 +- api/app/features/user/service.py | 57 +- api/app/internal_main.py | 115 +- api/app/main.py | 88 +- api/app/ports/llm.py | 4 + api/app/tasks/celery_app.py | 75 +- api/app/tasks/chapter_compose_tasks.py | 21 +- api/app/tasks/chapter_cover_enqueue.py | 6 +- api/app/tasks/chapter_cover_tasks.py | 118 +- api/app/tasks/memoir_quality_pass_tasks.py | 60 +- api/app/tasks/memoir_tasks.py | 343 ++- api/app/tasks/memory_compaction_tasks.py | 48 +- api/app/tasks/memory_enrichment_tasks.py | 14 +- api/app/tasks/story_image_tasks.py | 292 +-- api/app/tasks/story_title_tasks.py | 8 +- api/config/default.toml | 239 ++ api/config/development.toml | 21 + api/config/production.toml | 18 + api/config/staging.toml | 19 + api/deploy.sh | 11 +- api/development.sh | 105 +- api/docker-compose.dev.yml | 12 +- api/docker-compose.yml | 77 +- api/docs/configuration.md | 111 + api/docs/internal-eval.md | 6 +- api/docs/observability.md | 108 +- api/docs/本地开发环境配置.md | 16 +- api/docs/部署指南.md | 97 +- api/pyproject.toml | 4 + api/scripts/verify_observability_metrics.sh | 2 +- api/skills-lock.json | 17 + api/static/{home.html => home/index.html} | 0 api/static/legal/privacy.html | 141 ++ api/static/legal/terms.html | 139 + api/tests/conftest.py | 33 +- .../evaluation/test_internal_router_auth.py | 5 +- .../test_memoir_pipeline_run_router.py | 5 +- .../test_memoir_readiness_router.py | 5 +- api/tests/evaluation/test_replay_router.py | 5 +- api/tests/fixtures/config/merge/default.toml | 7 + api/tests/fixtures/config/merge/staging.toml | 5 + .../fixtures/config/minimal/default.toml | 7 + api/tests/support/__init__.py | 0 api/tests/support/auth_async_sqlite.py | 77 + api/tests/test_agent_logging.py | 17 +- api/tests/test_alembic_migration_policy.py | 8 +- api/tests/test_app_config_loader.py | 33 + api/tests/test_app_error_contract.py | 241 ++ api/tests/test_auth_refresh_http.py | 180 ++ api/tests/test_auth_refresh_rotation.py | 234 ++ .../test_auth_sms_login_nested_transaction.py | 103 + api/tests/test_auth_sms_rate_limit.py | 95 + .../test_auth_sms_verify_transactional.py | 371 +++ api/tests/test_avatar_preset_http.py | 29 +- api/tests/test_background_runner.py | 25 +- api/tests/test_chapter_cover_enqueue_redis.py | 76 + api/tests/test_chat_input_normalize.py | 35 +- api/tests/test_chat_stage_detection_gates.py | 3 +- api/tests/test_content_static_http.py | 37 + api/tests/test_conversation_history_list.py | 70 + .../test_conversation_history_turn_ids.py | 35 +- api/tests/test_cors_and_sms_http.py | 122 + api/tests/test_db_transactional.py | 147 ++ api/tests/test_default_toml_legacy_parity.py | 23 + .../test_dialogue_lineage_memory_ingest.py | 70 +- api/tests/test_error_code_registry.py | 120 + api/tests/test_eval_judge_llm_spec.py | 19 +- api/tests/test_fidelity_gate.py | 13 +- api/tests/test_history_store_transactional.py | 141 ++ api/tests/test_http_contract_errors.py | 24 +- api/tests/test_http_router_error_contract.py | 139 + api/tests/test_image_prompt_policy.py | 13 +- api/tests/test_infra_regressions.py | 15 +- api/tests/test_interview_turn_plan.py | 8 + api/tests/test_judge_service.py | 7 +- api/tests/test_main_app_smoke.py | 30 + .../test_memoir_phase1_ingest_idempotency.py | 250 ++ .../test_memoir_pipeline_optimization.py | 19 +- api/tests/test_memoir_pipeline_progress.py | 2 +- api/tests/test_memoir_route_defer.py | 11 +- api/tests/test_memoir_skip_story.py | 3 +- api/tests/test_memoir_two_phase.py | 7 +- api/tests/test_memory_compaction.py | 64 +- api/tests/test_memory_compaction_sweep.py | 30 + api/tests/test_memory_enrichment_baseline.py | 5 +- api/tests/test_mock_sms_login_http.py | 3 +- api/tests/test_openapi_error_response.py | 112 + api/tests/test_oral_normalize.py | 26 +- ..._pipeline_tts_cancel_emits_all_segments.py | 12 +- api/tests/test_recompose_retry_policy.py | 6 +- api/tests/test_redis_sync_client.py | 27 + api/tests/test_redis_urls.py | 46 + api/tests/test_settings_allowlist.py | 55 + ...st_sms_login_new_user_persists_language.py | 2 + .../test_state_service_batch_stage_policy.py | 19 +- api/tests/test_story_route_payload.py | 5 +- api/tests/test_task_tracker_ttl.py | 51 + api/tests/test_ws_pipeline_transactional.py | 237 ++ api/uv.lock | 62 + app-eval-web/src/api.ts | 21 +- app-eval-web/src/mainApi.ts | 26 +- app-eval-web/src/pages/LiveTesterPage.tsx | 11 +- app-eval-web/src/parseApiError.test.ts | 49 + app-eval-web/src/parseApiError.ts | 53 + app-expo/.env.example | 2 +- app-expo/src/app/(auth)/login.tsx | 28 +- app-expo/src/components/info-dialog.tsx | 30 +- app-expo/src/core/api/client.ts | 65 +- app-expo/src/core/api/parseApiError.ts | 49 + app-expo/src/core/api/types.ts | 4 +- app-expo/src/core/auth/refresh-lock.ts | 47 + app-expo/src/core/providers.tsx | 7 +- .../features/conversation/realtime-session.ts | 8 +- app-expo/src/i18n/generated/resources.ts | 10 +- app-expo/src/i18n/locales/en/auth.json | 3 +- app-expo/src/i18n/locales/en/profile.json | 2 +- app-expo/src/i18n/locales/zh/app.json | 2 +- app-expo/src/i18n/locales/zh/auth.json | 3 +- app-expo/src/i18n/locales/zh/profile.json | 2 +- app-expo/tests/core/api/parseApiError.test.ts | 55 + app-expo/tests/core/auth/refresh-lock.test.ts | 46 + assets/demo.html | 2 +- 298 files changed, 15247 insertions(+), 4344 deletions(-) create mode 100644 .cursor/rules/app-api-client.mdc create mode 100644 api/.agents/skills/celery-expert/SKILL.md create mode 100644 api/.agents/skills/redis-development/.cursor-plugin/plugin.json create mode 100644 api/.agents/skills/redis-development/AGENTS.md create mode 100644 api/.agents/skills/redis-development/README.md create mode 100644 api/.agents/skills/redis-development/SKILL.md create mode 100644 api/.agents/skills/redis-development/assets/logo.png create mode 100644 api/.agents/skills/redis-development/rules/_contributing.md create mode 100644 api/.agents/skills/redis-development/rules/_sections.md create mode 100644 api/.agents/skills/redis-development/rules/_template.md create mode 100644 api/.agents/skills/redis-development/rules/cluster-hash-tags.md create mode 100644 api/.agents/skills/redis-development/rules/cluster-read-replicas.md create mode 100644 api/.agents/skills/redis-development/rules/conn-blocking.md create mode 100644 api/.agents/skills/redis-development/rules/conn-client-cache.md create mode 100644 api/.agents/skills/redis-development/rules/conn-pipelining.md create mode 100644 api/.agents/skills/redis-development/rules/conn-pooling.md create mode 100644 api/.agents/skills/redis-development/rules/conn-timeouts.md create mode 100644 api/.agents/skills/redis-development/rules/data-choose-structure.md create mode 100644 api/.agents/skills/redis-development/rules/data-hash-field-expiry.md create mode 100644 api/.agents/skills/redis-development/rules/data-incr.md create mode 100644 api/.agents/skills/redis-development/rules/data-key-naming.md create mode 100644 api/.agents/skills/redis-development/rules/data-transactions.md create mode 100644 api/.agents/skills/redis-development/rules/json-partial-updates.md create mode 100644 api/.agents/skills/redis-development/rules/json-vs-hash.md create mode 100644 api/.agents/skills/redis-development/rules/observe-commands.md create mode 100644 api/.agents/skills/redis-development/rules/observe-metrics.md create mode 100644 api/.agents/skills/redis-development/rules/ram-limits.md create mode 100644 api/.agents/skills/redis-development/rules/ram-ttl.md create mode 100644 api/.agents/skills/redis-development/rules/rqe-dialect.md create mode 100644 api/.agents/skills/redis-development/rules/rqe-field-types.md create mode 100644 api/.agents/skills/redis-development/rules/rqe-index-creation.md create mode 100644 api/.agents/skills/redis-development/rules/rqe-index-management.md create mode 100644 api/.agents/skills/redis-development/rules/rqe-query-optimization.md create mode 100644 api/.agents/skills/redis-development/rules/rqe-skip-initial-scan.md create mode 100644 api/.agents/skills/redis-development/rules/security-acls.md create mode 100644 api/.agents/skills/redis-development/rules/security-auth.md create mode 100644 api/.agents/skills/redis-development/rules/security-network.md create mode 100644 api/.agents/skills/redis-development/rules/semantic-cache-best-practices.md create mode 100644 api/.agents/skills/redis-development/rules/semantic-cache-langcache-usage.md create mode 100644 api/.agents/skills/redis-development/rules/stream-choosing-pattern.md create mode 100644 api/.agents/skills/redis-development/rules/vector-algorithm-choice.md create mode 100644 api/.agents/skills/redis-development/rules/vector-hybrid-search.md create mode 100644 api/.agents/skills/redis-development/rules/vector-index-creation.md create mode 100644 api/.agents/skills/redis-development/rules/vector-rag-pattern.md create mode 100644 api/alembic/versions/0020_refresh_rt_lineage.py create mode 100644 api/alembic/versions/0021_memory_source_segment_id.py create mode 100644 api/app/core/app_config.py create mode 100644 api/app/core/app_config_loader.py create mode 100644 api/app/core/app_config_models.py create mode 100644 api/app/core/auth_deps.py create mode 100644 api/app/core/deps_types.py create mode 100644 api/app/core/error_codes.py create mode 100644 api/app/core/redis_sync.py create mode 100644 api/app/core/redis_urls.py create mode 100644 api/app/core/runtime_constants.py create mode 100644 api/app/features/auth/integrity.py create mode 100644 api/app/features/conversation/constants.py create mode 100644 api/app/features/conversation/ws/persist.py create mode 100644 api/app/features/evaluation/constants.py create mode 100644 api/app/features/memoir/constants.py create mode 100644 api/app/features/memory/constants.py create mode 100644 api/app/features/story/constants.py create mode 100644 api/config/default.toml create mode 100644 api/config/development.toml create mode 100644 api/config/production.toml create mode 100644 api/config/staging.toml create mode 100644 api/docs/configuration.md create mode 100644 api/skills-lock.json rename api/static/{home.html => home/index.html} (100%) create mode 100644 api/static/legal/privacy.html create mode 100644 api/static/legal/terms.html create mode 100644 api/tests/fixtures/config/merge/default.toml create mode 100644 api/tests/fixtures/config/merge/staging.toml create mode 100644 api/tests/fixtures/config/minimal/default.toml create mode 100644 api/tests/support/__init__.py create mode 100644 api/tests/support/auth_async_sqlite.py create mode 100644 api/tests/test_app_config_loader.py create mode 100644 api/tests/test_app_error_contract.py create mode 100644 api/tests/test_auth_refresh_http.py create mode 100644 api/tests/test_auth_refresh_rotation.py create mode 100644 api/tests/test_auth_sms_login_nested_transaction.py create mode 100644 api/tests/test_auth_sms_rate_limit.py create mode 100644 api/tests/test_auth_sms_verify_transactional.py create mode 100644 api/tests/test_chapter_cover_enqueue_redis.py create mode 100644 api/tests/test_content_static_http.py create mode 100644 api/tests/test_conversation_history_list.py create mode 100644 api/tests/test_cors_and_sms_http.py create mode 100644 api/tests/test_db_transactional.py create mode 100644 api/tests/test_default_toml_legacy_parity.py create mode 100644 api/tests/test_error_code_registry.py create mode 100644 api/tests/test_history_store_transactional.py create mode 100644 api/tests/test_http_router_error_contract.py create mode 100644 api/tests/test_main_app_smoke.py create mode 100644 api/tests/test_memoir_phase1_ingest_idempotency.py create mode 100644 api/tests/test_memory_compaction_sweep.py create mode 100644 api/tests/test_openapi_error_response.py create mode 100644 api/tests/test_redis_sync_client.py create mode 100644 api/tests/test_redis_urls.py create mode 100644 api/tests/test_settings_allowlist.py create mode 100644 api/tests/test_task_tracker_ttl.py create mode 100644 api/tests/test_ws_pipeline_transactional.py create mode 100644 app-eval-web/src/parseApiError.test.ts create mode 100644 app-eval-web/src/parseApiError.ts create mode 100644 app-expo/src/core/api/parseApiError.ts create mode 100644 app-expo/src/core/auth/refresh-lock.ts create mode 100644 app-expo/tests/core/api/parseApiError.test.ts create mode 100644 app-expo/tests/core/auth/refresh-lock.test.ts diff --git a/.cursor/rules/Backend-Develop-Guideline.mdc b/.cursor/rules/Backend-Develop-Guideline.mdc index 1dd3e39..6fae574 100644 --- a/.cursor/rules/Backend-Develop-Guideline.mdc +++ b/.cursor/rules/Backend-Develop-Guideline.mdc @@ -11,11 +11,39 @@ alwaysApply: true 4. **feature 间禁止 import router**:跨 feature 调用通过 service 注入 5. **所有 schema 变更走 Alembic**:禁止直接 DDL 6. **新增 provider 必须实现 port protocol**:不得在 feature 中直接调用 SDK -7. **事务边界:repo 不提交,service 管事务**:`repo.py` 只做 `add/delete/query`,`commit/rollback` 由 service 或 UoW 统一执行。`get_async_db()` 不自动 commit +7. **事务边界:repo 不提交,service 管事务**:`repo.py` 只做 `add/delete/query`,`commit/rollback` 由 service 或 Celery task 统一执行。`get_async_db()` 不自动 commit;禁止在 router/repo 中 `await db.commit()` / `session.commit()` 8. **port 边界不可打穿**:service 需要厂商增强能力时,必须扩充 port 或定义第二个窄 port,禁止直接引用 adapter 扩展方法 9. **quota 是独立 feature**:conversation、memoir、payment 如需配额检查,通过注入 `QuotaService`,不得直接 import quota 内部函数 10. **成功响应不强制包装**:`core/errors.py` 只统一错误响应格式 `{error_code, message, request_id}`,成功响应直接返回 Pydantic model / FileResponse / 原始结构 +## 配置 SSOT(TOML + .env) +- **Secrets / bootstrap 仅进 `.env`**:`DATABASE_URL`、`SECRET_KEY`、各厂商 API key、支付私钥等 → `app/core/config.py` 的 `Settings` +- **产品 / 部署调参仅进 TOML**:`api/config/default.toml` + `api/config/{APP_ENV}.toml`(`development` / `staging` / `production`) +- **禁止**在 `Settings` 新增 `chat_*` / `memoir_*` / `memory_*` 等产品字段;`tests/test_settings_allowlist.py` 会拦截 env 反弹 +- **禁止**业务代码散落 `os.getenv()`;读配置走既有入口: + - feature 常量:`from app.features..constants import chat`(等) + - 运行时默认:`from app.core.runtime_constants import llm_defaults, tts_defaults, ...` + - deploy 开关:`settings.enable_tts` 等(`SettingsFacade` 代理到 TOML `[deploy]`) +- 改默认值时 **`default.toml` 与 `app_config_models.py` 必须同步**;`tests/test_default_toml_legacy_parity.py` 锁定关键默认行为 +- 字段对照与运维说明见 `api/docs/configuration.md` + +## 错误处理 +- router / service **抛 `AppError` 子类**(`BadRequestError`、`AuthenticationError`、`NotFoundError`、feature 内 `AuthError` 等),**禁止** `HTTPException` +- 业务专用 `error_code` 登记在 `app/core/error_codes.py`,OpenAPI 通过 `ErrorResponse` 组件文档化 +- 429 语义:`QuotaExceededError` vs `RateLimitedError`;勿用裸 `HTTPException(429)`(legacy handler 无法区分 quota) +- 客户端(`app-expo` / `app-eval-web`)统一用 `parseApiError` 读 `message` / `error_code`(兼容旧 `detail`) + +## 事务 helpers(补充规则 7) +- 多步写、需原子提交 → `transactional()` / `transactional_sync()`(`app/core/db.py`) +- 外层事务 active 时、局部失败可独立回滚 → `transactional_nested()` / `transactional_nested_sync()`(savepoint) +- **同一 session 上连续两次 `transactional()` = 两次独立 commit**(WS 分段持久化等刻意为之);不要假设嵌套合并成一个事务 +- Celery sync 路径用 `transactional_sync()` + `get_sync_db()`;外部副作用(SMS、COS、LLM)放在 commit 成功之后 + +## Redis / Celery +- 业务 key:`settings.redis_url_resolved`(通常 DB/0) +- Celery broker/backend:`settings.celery_redis_url_resolved`(通常 DB/1;compose 显式 `CELERY_REDIS_URL`) +- `REDIS_URL` 使用 DB/15 时必须显式设置 `CELERY_REDIS_URL`(无法 auto +1) + ## 依赖管理(uv) 11. **依赖统一用 uv 管理**:禁止直接 `pip install`、禁止手动编辑 `pyproject.toml` 的 `[project.dependencies]` 或 `[dependency-groups]` 12. **新增依赖用 `uv add `**,dev 依赖用 `uv add --dev `,移除用 `uv remove ` @@ -23,4 +51,4 @@ alwaysApply: true 14. **安装环境统一用 `uv sync`**:开发环境 `uv sync --dev`,生产环境 `uv sync --no-dev` 15. **运行命令统一用 `uv run`**:如 `uv run pytest`、`uv run alembic upgrade head`、`uv run uvicorn ...` 16. 每次添加新代码时,一定要阅读已有部分代码,确保符合项目架构,pattern,loguru模式 -17. **日志**:业务代码使用 `app.core.logging.get_logger(__name__)`(loguru `bind`);禁止在 `app` 包内用 `import logging` 取业务 logger;仅第三方 SDK / 适配器层可对标准库 `logging` 做桥接(如 `InterceptHandler`) \ No newline at end of file +17. **日志**:业务代码使用 `app.core.logging.get_logger(__name__)`(loguru `bind`);禁止在 `app` 包内用 `import logging` 取业务 logger;仅第三方 SDK / 适配器层可对标准库 `logging` 做桥接(如 `InterceptHandler`) diff --git a/.cursor/rules/app-api-client.mdc b/.cursor/rules/app-api-client.mdc new file mode 100644 index 0000000..3eabc4a --- /dev/null +++ b/.cursor/rules/app-api-client.mdc @@ -0,0 +1,21 @@ +--- +description: Mobile and eval-web API client error handling and auth refresh +globs: app-expo/src/**/*.{ts,tsx},app-eval-web/src/**/*.{ts,tsx} +alwaysApply: false +--- + +# API Client(app-expo / app-eval-web) + +后端错误体为 `{ error_code, message, request_id }`(旧版可能仍有 `{ detail }`)。 + +## 错误解析 +- **统一使用 `parseApiError`**(`app-expo/src/core/api/parseApiError.ts`、`app-eval-web/src/parseApiError.ts`) +- 展示给用户 / toast 用返回的 `message`;分支逻辑用 `errorCode`(对应后端 `error_code`) +- **禁止**在新代码里直接读 `body.detail` 或假设 FastAPI 默认错误形状 + +## Token 刷新 +- `app-expo` 并发 401 refresh 必须走 `refresh-lock.ts`,避免双端同时 refresh 触发 `REFRESH_TOKEN_REUSE` +- refresh 失败应清 session 并引导重新登录,不要无限重试同一 refresh token + +## 类型 +- 共享错误 body 形状见各 app 的 `ApiErrorBody` / `types.ts`;与 OpenAPI `ErrorResponse` 对齐 diff --git a/.cursor/rules/backend-testing-strategy.mdc b/.cursor/rules/backend-testing-strategy.mdc index 261ed26..5d7f81c 100644 --- a/.cursor/rules/backend-testing-strategy.mdc +++ b/.cursor/rules/backend-testing-strategy.mdc @@ -19,10 +19,23 @@ alwaysApply: false - 用 `httpx.AsyncClient` + `ASGITransport` 进行异步 HTTP 测试。 - `api/tests/conftest.py` 和 `api/tests/factories.py` 只提供通用基础设施,不要把具体业务测试硬编码进去。 +## HTTP 错误契约 + +- 失败响应断言 **`error_code` + `message`**(及可选 `request_id`),不要断言 FastAPI 旧格式 `{ detail }`。 +- OpenAPI / router 级 smoke:参考 `tests/test_http_router_error_contract.py`、`tests/test_openapi_error_response.py`。 +- Auth / payment / SMS 等 feature 码:断言具体 `error_code`(如 `INVALID_SMS_CODE`、`REFRESH_TOKEN_REUSE`),不只测 status code。 + +## 配置 / TOML 测试 + +- TOML 加载测试用 `CONFIG_DIR` 指向 fixture 目录,或 `reload_app_config()`;见 `tests/test_app_config_loader.py`。 +- **不要**用 `os.environ["OTEL_ENABLED"]` 等已迁出 Settings 的 env 键驱动应用行为(运行时已读 TOML)。 +- 新增 TOML 默认值时,若影响产品行为,更新 `tests/test_default_toml_legacy_parity.py` 或明确 overlay 意图。 + ## Test Layering - HTTP 场景测试:注册、登录、刷新、登出、受保护资源访问、关键资源 CRUD、重要失败分支。 - 纯单元测试:纯函数、规则计算、序列化、适配器错误分支与格式转换。 +- Service / 事务边界测试:仅当保护明确业务承诺(如 refresh rotation、SMS 回滚、幂等 ingest)且 HTTP 层难以稳定复现时允许;参考 `tests/test_auth_refresh_rotation.py`、`tests/test_db_transactional.py`。 - 手工 / E2E:WebSocket 多轮对话、真实短信、Celery + Redis + LLM 编排、支付/对象存储/ASR/图像生成联调。 ## Decision Filter diff --git a/.github/workflows/android-release.yml b/.github/workflows/android-release.yml index c5bbed3..fb85a6f 100644 --- a/.github/workflows/android-release.yml +++ b/.github/workflows/android-release.yml @@ -14,7 +14,7 @@ on: type: string env: - APP_NAME: 岁月时书 + APP_NAME: 岁月留书 jobs: build-release-apk: diff --git a/README.md b/README.md index 2740fdb..0fa25db 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ -# 岁月时书 (Life Echo) +# 岁月留书 (Life Echo) > 一个基于实时 WebSocket 长连接的智能语音对话回忆录生成系统 ## 📖 项目简介 -**岁月时书 (Life Echo)** 是一个创新的回忆录生成平台,通过 AI 智能对话引导用户回顾人生历程,并将口语对话自动整理为结构化的回忆录章节,最终生成精美的 PDF 电子书。 +**岁月留书 (Life Echo)** 是一个创新的回忆录生成平台,通过 AI 智能对话引导用户回顾人生历程,并将口语对话自动整理为结构化的回忆录章节,最终生成精美的 PDF 电子书。 后端侧:会话轮次以 DB `conversation_messages` 为真源、Redis 为缓存;实时对话编排统一走 `ChatOrchestrator`;图像任务为 `generate_story_image`(正文)与 `generate_chapter_cover`(章节封面)。详见 [api/README.md](api/README.md)。 @@ -146,7 +146,7 @@ uv run celery -A tasks.celery_app worker --loglevel=info --pool=solo 1. **环境变量安全**:确保 `.env` 文件不被提交到版本控制 2. **SECRET_KEY 安全**:生产环境必须使用强随机字符串 -3. **CORS 配置**:生产环境应限制为特定域名 +3. **CORS 配置**:本地 `API_CORS_ORIGINS` 可留空;生产/staging 须设为前端域名(逗号分隔),否则浏览器无法携带 credentials 跨域 4. **API Key 安全**:妥善保管 LLM API Key 5. **密码安全**:密码使用 bcrypt 哈希存储 @@ -171,4 +171,4 @@ MIT License --- -**岁月时书** - 让每一段人生故事都被温柔记录 ✨ +**岁月留书** - 让每一段人生故事都被温柔记录 ✨ diff --git a/api/.agents/skills/celery-expert/SKILL.md b/api/.agents/skills/celery-expert/SKILL.md new file mode 100644 index 0000000..fb9ef9b --- /dev/null +++ b/api/.agents/skills/celery-expert/SKILL.md @@ -0,0 +1,630 @@ +--- +name: celery-expert +description: "Expert Celery distributed task queue engineer specializing in async task processing, workflow orchestration, broker configuration (Redis/RabbitMQ), Celery Beat scheduling, and production monitoring. Deep expertise in task patterns (chains, groups, chords), retries, rate limiting, Flower monitoring, and security best practices. Use when designing distributed task systems, implementing background job processing, building workflow orchestration, or optimizing task queue performance." +model: sonnet +--- + +# Celery Distributed Task Queue Expert + +## 1. Overview + +You are an elite Celery engineer with deep expertise in: + +- **Core Celery**: Task definition, async execution, result backends, task states, routing +- **Workflow Patterns**: Chains, groups, chords, canvas primitives, complex workflows +- **Brokers**: Redis vs RabbitMQ trade-offs, connection pools, broker failover +- **Result Backends**: Redis, database, memcached, result expiration, state tracking +- **Task Reliability**: Retries, exponential backoff, acks late, task rejection, idempotency +- **Scheduling**: Celery Beat, crontab schedules, interval tasks, solar schedules +- **Performance**: Prefetch multiplier, concurrency models (prefork, gevent, eventlet), autoscaling +- **Monitoring**: Flower, Prometheus metrics, task inspection, worker management +- **Security**: Task signature validation, secure serialization (no pickle), message signing +- **Error Handling**: Dead letter queues, task timeouts, exception handling, logging + +### Core Principles + +1. **TDD First** - Write tests before implementation; verify task behavior with pytest-celery +2. **Performance Aware** - Optimize for throughput with chunking, pooling, and proper prefetch +3. **Reliability** - Task retries, acknowledgment strategies, no task loss +4. **Scalability** - Distributed workers, routing, autoscaling, queue prioritization +5. **Security** - Signed tasks, safe serialization, broker authentication +6. **Observable** - Comprehensive monitoring, metrics, tracing, alerting + +**Risk Level**: MEDIUM +- Task processing failures can impact business operations +- Improper serialization (pickle) can lead to code execution vulnerabilities +- Missing retries/timeouts can cause task accumulation and system degradation +- Broker misconfigurations can lead to task loss or message exposure + +--- + +## 2. Implementation Workflow (TDD) + +### Step 1: Write Failing Test First + +```python +# tests/test_tasks.py +import pytest +from celery.contrib.testing.tasks import ping +from celery.result import EagerResult + +@pytest.fixture +def celery_config(): + return { + 'broker_url': 'memory://', + 'result_backend': 'cache+memory://', + 'task_always_eager': True, + 'task_eager_propagates': True, + } + +class TestProcessOrder: + def test_process_order_success(self, celery_app, celery_worker): + """Test order processing returns correct result""" + from myapp.tasks import process_order + + # Execute task + result = process_order.delay(order_id=123) + + # Assert expected behavior + assert result.get(timeout=10) == { + 'order_id': 123, + 'status': 'success' + } + + def test_process_order_idempotent(self, celery_app, celery_worker): + """Test task is idempotent - safe to retry""" + from myapp.tasks import process_order + + # Run twice + result1 = process_order.delay(order_id=123).get(timeout=10) + result2 = process_order.delay(order_id=123).get(timeout=10) + + # Should be safe to retry + assert result1['status'] in ['success', 'already_processed'] + assert result2['status'] in ['success', 'already_processed'] + + def test_process_order_retry_on_failure(self, celery_app, celery_worker, mocker): + """Test task retries on temporary failure""" + from myapp.tasks import process_order + + # Mock to fail first, succeed second + mock_process = mocker.patch('myapp.tasks.perform_order_processing') + mock_process.side_effect = [TemporaryError("Timeout"), {'result': 'ok'}] + + result = process_order.delay(order_id=123) + + assert result.get(timeout=10)['status'] == 'success' + assert mock_process.call_count == 2 +``` + +### Step 2: Implement Minimum to Pass + +```python +# myapp/tasks.py +from celery import Celery + +app = Celery('tasks', broker='redis://localhost:6379/0') + +@app.task(bind=True, max_retries=3) +def process_order(self, order_id: int): + try: + order = get_order(order_id) + if order.status == 'processed': + return {'order_id': order_id, 'status': 'already_processed'} + + result = perform_order_processing(order) + return {'order_id': order_id, 'status': 'success'} + except TemporaryError as exc: + raise self.retry(exc=exc, countdown=2 ** self.request.retries) +``` + +### Step 3: Refactor Following Patterns + +Add proper error handling, time limits, and observability. + +### Step 4: Run Full Verification + +```bash +# Run all Celery tests +pytest tests/test_tasks.py -v + +# Run with coverage +pytest tests/test_tasks.py --cov=myapp.tasks --cov-report=term-missing + +# Test workflow patterns +pytest tests/test_workflows.py -v + +# Integration test with real broker +pytest tests/integration/ --broker=redis://localhost:6379/0 +``` + +--- + +## 3. Performance Patterns + +### Pattern 1: Task Chunking + +```python +# Bad - Individual tasks for each item +for item_id in item_ids: # 10,000 items = 10,000 tasks + process_item.delay(item_id) + +# Good - Process in batches +@app.task +def process_batch(item_ids: list): + """Process items in chunks for efficiency""" + results = [] + for chunk in chunks(item_ids, size=100): + items = fetch_items_bulk(chunk) # Single DB query + results.extend([process(item) for item in items]) + return results + +# Dispatch in chunks +for chunk in chunks(item_ids, size=100): + process_batch.delay(chunk) # 100 tasks instead of 10,000 +``` + +### Pattern 2: Prefetch Tuning + +```python +# Bad - Default prefetch for I/O-bound tasks +app.conf.worker_prefetch_multiplier = 4 # Too many reserved + +# Good - Tune based on task type +# CPU-bound: Higher prefetch, fewer workers +app.conf.worker_prefetch_multiplier = 4 +# celery -A app worker --concurrency=4 + +# I/O-bound: Lower prefetch, more workers +app.conf.worker_prefetch_multiplier = 1 +# celery -A app worker --pool=gevent --concurrency=100 + +# Long tasks: Disable prefetch +app.conf.worker_prefetch_multiplier = 1 +app.conf.task_acks_late = True +``` + +### Pattern 3: Result Backend Optimization + +```python +# Bad - Storing results for fire-and-forget tasks +@app.task +def send_email(to, subject, body): + mailer.send(to, subject, body) + return {'sent': True} # Stored in Redis unnecessarily + +# Good - Ignore results when not needed +@app.task(ignore_result=True) +def send_email(to, subject, body): + mailer.send(to, subject, body) + +# Good - Set expiration for results you need +app.conf.result_expires = 3600 # 1 hour + +# Good - Store minimal data, reference external storage +@app.task +def process_large_file(file_id): + data = process(read_file(file_id)) + result_key = save_to_s3(data) # Store large result externally + return {'result_key': result_key} # Store only reference +``` + +### Pattern 4: Connection Pooling + +```python +# Bad - Creating new connections per task +@app.task +def query_database(query): + conn = psycopg2.connect(...) # New connection each time + result = conn.execute(query) + conn.close() + return result + +# Good - Use connection pools +from sqlalchemy import create_engine +from redis import ConnectionPool, Redis + +# Initialize once at module level +db_engine = create_engine( + 'postgresql://user:pass@localhost/db', + pool_size=20, + max_overflow=10, + pool_pre_ping=True +) +redis_pool = ConnectionPool(host='localhost', port=6379, max_connections=50) + +@app.task +def query_database(query): + with db_engine.connect() as conn: # Uses pool + return conn.execute(query).fetchall() + +@app.task +def cache_result(key, value): + redis = Redis(connection_pool=redis_pool) # Uses pool + redis.set(key, value) +``` + +### Pattern 5: Task Routing + +```python +# Bad - All tasks in single queue +@app.task +def critical_payment(): pass + +@app.task +def generate_report(): pass # Blocks payment processing + +# Good - Route to dedicated queues +from kombu import Queue, Exchange + +app.conf.task_queues = ( + Queue('critical', Exchange('critical'), routing_key='critical'), + Queue('default', Exchange('default'), routing_key='default'), + Queue('bulk', Exchange('bulk'), routing_key='bulk'), +) + +app.conf.task_routes = { + 'tasks.critical_payment': {'queue': 'critical'}, + 'tasks.generate_report': {'queue': 'bulk'}, +} + +# Run dedicated workers per queue +# celery -A app worker -Q critical --concurrency=4 +# celery -A app worker -Q bulk --concurrency=2 +``` + +--- + +## 4. Core Responsibilities + +### 1. Task Design & Workflow Orchestration +- Define tasks with proper decorators (`@app.task`, `@shared_task`) +- Implement idempotent tasks (safe to retry) +- Use chains for sequential execution, groups for parallel, chords for map-reduce +- Design task routing to specific queues/workers +- Avoid long-running tasks (break into subtasks) + +### 2. Broker Configuration & Management +- Choose Redis for simplicity, RabbitMQ for reliability +- Configure connection pools, heartbeats, and failover +- Enable broker authentication and encryption (TLS) +- Monitor broker health and connection states + +### 3. Task Reliability & Error Handling +- Implement retry logic with exponential backoff +- Use `acks_late=True` for critical tasks +- Set appropriate task time limits (soft/hard) +- Handle exceptions gracefully with error callbacks +- Implement dead letter queues for failed tasks +- Design idempotent tasks to handle retries safely + +### 4. Result Backends & State Management +- Choose appropriate result backend (Redis, database, RPC) +- Set result expiration to prevent memory leaks +- Use `ignore_result=True` for fire-and-forget tasks +- Store minimal data in results (use external storage) + +### 5. Celery Beat Scheduling +- Define crontab schedules for recurring tasks +- Use interval schedules for simple periodic tasks +- Configure Beat scheduler persistence (database backend) +- Avoid scheduling conflicts with task locks + +### 6. Monitoring & Observability +- Deploy Flower for real-time monitoring +- Export Prometheus metrics for alerting +- Track task success/failure rates and queue lengths +- Implement distributed tracing (correlation IDs) +- Log task execution with context + +--- + +## 5. Implementation Patterns + +### Pattern 1: Task Definition Best Practices + +```python +# COMPLETE TASK DEFINITION +from celery import Celery +from celery.exceptions import SoftTimeLimitExceeded +import logging + +app = Celery('tasks', broker='redis://localhost:6379/0') +logger = logging.getLogger(__name__) + +@app.task( + bind=True, + name='tasks.process_order', + max_retries=3, + default_retry_delay=60, + acks_late=True, + reject_on_worker_lost=True, + time_limit=300, + soft_time_limit=240, + rate_limit='100/m', +) +def process_order(self, order_id: int): + """Process order with proper error handling and retries""" + try: + logger.info(f"Processing order {order_id}", extra={'task_id': self.request.id}) + + order = get_order(order_id) + if order.status == 'processed': + return {'order_id': order_id, 'status': 'already_processed'} + + result = perform_order_processing(order) + return {'order_id': order_id, 'status': 'success', 'result': result} + + except SoftTimeLimitExceeded: + cleanup_processing(order_id) + raise + except TemporaryError as exc: + raise self.retry(exc=exc, countdown=2 ** self.request.retries) + except PermanentError as exc: + send_failure_notification(order_id, str(exc)) + raise +``` + +### Pattern 2: Workflow Patterns (Chains, Groups, Chords) + +```python +from celery import chain, group, chord + +# CHAIN: Sequential execution (A -> B -> C) +workflow = chain( + fetch_data.s('https://api.example.com/data'), + process_item.s(), + send_notification.s() +) + +# GROUP: Parallel execution +job = group(fetch_data.s(url) for url in urls) + +# CHORD: Map-Reduce (parallel + callback) +workflow = chord( + group(process_item.s(item) for item in items) +)(aggregate_results.s()) +``` + +### Pattern 3: Production Configuration + +```python +from kombu import Exchange, Queue + +app = Celery('myapp') +app.conf.update( + broker_url='redis://localhost:6379/0', + broker_connection_retry_on_startup=True, + broker_pool_limit=10, + + result_backend='redis://localhost:6379/1', + result_expires=3600, + + task_serializer='json', + result_serializer='json', + accept_content=['json'], + + task_acks_late=True, + task_reject_on_worker_lost=True, + task_time_limit=300, + task_soft_time_limit=240, + + worker_prefetch_multiplier=4, + worker_max_tasks_per_child=1000, +) +``` + +### Pattern 4: Retry Strategies & Error Handling + +```python +from celery.exceptions import Reject + +@app.task( + bind=True, + max_retries=5, + autoretry_for=(RequestException,), + retry_backoff=True, + retry_backoff_max=600, + retry_jitter=True, +) +def call_external_api(self, url: str): + """Auto-retry on RequestException with exponential backoff""" + response = requests.get(url, timeout=10) + response.raise_for_status() + return response.json() +``` + +### Pattern 5: Celery Beat Scheduling + +```python +from celery.schedules import crontab +from datetime import timedelta + +app.conf.beat_schedule = { + 'cleanup-temp-files': { + 'task': 'tasks.cleanup_temp_files', + 'schedule': timedelta(minutes=10), + }, + 'daily-report': { + 'task': 'tasks.generate_daily_report', + 'schedule': crontab(hour=3, minute=0), + }, +} +``` + +--- + +## 6. Security Standards + +### 6.1 Secure Serialization + +```python +# DANGEROUS: Pickle allows code execution +app.conf.task_serializer = 'pickle' # NEVER! + +# SECURE: Use JSON +app.conf.update( + task_serializer='json', + result_serializer='json', + accept_content=['json'], +) +``` + +### 6.2 Broker Authentication & TLS + +```python +# Redis with TLS +app.conf.broker_url = 'redis://:password@localhost:6379/0' +app.conf.broker_use_ssl = { + 'ssl_cert_reqs': 'required', + 'ssl_ca_certs': '/path/to/ca.pem', +} + +# RabbitMQ with TLS +app.conf.broker_url = 'amqps://user:password@localhost:5671/vhost' +``` + +### 6.3 Input Validation + +```python +from pydantic import BaseModel + +class OrderData(BaseModel): + order_id: int + amount: float + +@app.task +def process_order_validated(order_data: dict): + validated = OrderData(**order_data) + return process_order(validated.dict()) +``` + +--- + +## 7. Common Mistakes + +### Mistake 1: Using Pickle Serialization +```python +# DON'T +app.conf.task_serializer = 'pickle' +# DO +app.conf.task_serializer = 'json' +``` + +### Mistake 2: Not Making Tasks Idempotent +```python +# DON'T: Retries increment multiple times +@app.task +def increment_counter(user_id): + user.counter += 1 + user.save() + +# DO: Safe to retry +@app.task +def set_counter(user_id, value): + user.counter = value + user.save() +``` + +### Mistake 3: Missing Time Limits +```python +# DON'T +@app.task +def slow_task(): + external_api_call() + +# DO +@app.task(time_limit=30, soft_time_limit=25) +def safe_task(): + external_api_call() +``` + +### Mistake 4: Storing Large Results +```python +# DON'T +@app.task +def process_file(file_id): + return read_large_file(file_id) # Stored in Redis! + +# DO +@app.task +def process_file(file_id): + result_id = save_to_storage(read_large_file(file_id)) + return {'result_id': result_id} +``` + +--- + +## 8. Pre-Implementation Checklist + +### Phase 1: Before Writing Code + +- [ ] Write failing test for task behavior +- [ ] Define task idempotency strategy +- [ ] Choose queue routing for task priority +- [ ] Determine result storage needs (ignore_result?) +- [ ] Plan retry strategy and error handling +- [ ] Review security requirements (serialization, auth) + +### Phase 2: During Implementation + +- [ ] Task has time limits (soft and hard) +- [ ] Task uses `acks_late=True` for critical work +- [ ] Task validates inputs with Pydantic +- [ ] Task logs with correlation ID +- [ ] Connection pools configured for DB/Redis +- [ ] Results stored externally if large + +### Phase 3: Before Committing + +- [ ] All tests pass: `pytest tests/test_tasks.py -v` +- [ ] Coverage adequate: `pytest --cov=myapp.tasks` +- [ ] Serialization set to JSON (not pickle) +- [ ] Broker authentication configured +- [ ] Result expiration set +- [ ] Monitoring configured (Flower/Prometheus) +- [ ] Task routes documented +- [ ] Dead letter queue handling implemented + +--- + +## 9. Critical Reminders + +### NEVER + +- Use pickle serialization +- Run without time limits +- Store large data in results +- Create non-idempotent tasks +- Run without broker authentication +- Expose Flower without authentication + +### ALWAYS + +- Use JSON serialization +- Set time limits (soft and hard) +- Make tasks idempotent +- Use `acks_late=True` for critical tasks +- Set result expiration +- Implement retry logic with backoff +- Monitor with Flower/Prometheus +- Validate task inputs +- Log with correlation IDs + +--- + +## 10. Summary + +You are a Celery expert focused on: +1. **TDD First** - Write tests before implementation +2. **Performance** - Chunking, pooling, prefetch tuning, routing +3. **Reliability** - Retries, acks_late, idempotency +4. **Security** - JSON serialization, message signing, broker auth +5. **Observability** - Flower monitoring, Prometheus metrics, tracing + +**Key Principles**: +- Tasks must be idempotent - safe to retry without side effects +- TDD ensures task behavior is verified before deployment +- Performance tuning - prefetch, chunking, connection pooling, routing +- Security first - never use pickle, always authenticate +- Monitor everything - queue lengths, task latency, failure rates diff --git a/api/.agents/skills/redis-development/.cursor-plugin/plugin.json b/api/.agents/skills/redis-development/.cursor-plugin/plugin.json new file mode 100644 index 0000000..f09823e --- /dev/null +++ b/api/.agents/skills/redis-development/.cursor-plugin/plugin.json @@ -0,0 +1,14 @@ +{ + "name": "redis-development", + "version": "1.0.0", + "description": "Redis development best practices — data structures, query engine, vector search, caching, and performance optimization", + "author": { + "name": "Redis", + "email": "support@redis.com" + }, + "homepage": "https://redis.io", + "repository": "https://github.com/redis/agent-skills", + "license": "MIT", + "keywords": ["redis", "database", "caching", "vector-search", "performance", "best-practices"], + "logo": "assets/logo.png" +} diff --git a/api/.agents/skills/redis-development/AGENTS.md b/api/.agents/skills/redis-development/AGENTS.md new file mode 100644 index 0000000..a803302 --- /dev/null +++ b/api/.agents/skills/redis-development/AGENTS.md @@ -0,0 +1,2228 @@ +# Redis Development + +**Version 1.0.0** +Redis, Inc. +January 2026 + +> **Note:** +> This document is mainly for agents and LLMs to follow when maintaining, +> generating, or refactoring Redis applications. Humans +> may also find it useful, but guidance here is optimized for automation +> and consistency by AI-assisted workflows. + +--- + +## Abstract + +Best practices for Redis including data structures, memory management, Redis Query Engine (RQE), vector search with RedisVL, semantic caching with LangCache, and performance optimization. Optimized for AI agents and LLMs. + +--- + +## Table of Contents + +1. [Data Structures & Keys](#1-data-structures--keys) — **HIGH** + - 1.1 [Choose the Right Data Structure](#11-choose-the-right-data-structure) + - 1.2 [Use Consistent Key Naming Conventions](#12-use-consistent-key-naming-conventions) + - 1.3 [Use Hash Field Expiration for Per-Field TTL](#13-use-hash-field-expiration-for-per-field-ttl) + - 1.4 [Use INCR for Atomic Counters](#14-use-incr-for-atomic-counters) + - 1.5 [Use Transactions for Atomic Multi-Command Operations](#15-use-transactions-for-atomic-multi-command-operations) +2. [Memory & Expiration](#2-memory--expiration) — **HIGH** + - 2.1 [Configure Memory Limits and Eviction Policies](#21-configure-memory-limits-and-eviction-policies) + - 2.2 [Set TTL on Cache Keys](#22-set-ttl-on-cache-keys) +3. [Connection & Performance](#3-connection--performance) — **HIGH** + - 3.1 [Avoid Slow Commands in Production](#31-avoid-slow-commands-in-production) + - 3.2 [Configure Connection Timeouts](#32-configure-connection-timeouts) + - 3.3 [Use Client-Side Caching for Frequently Read Data](#33-use-client-side-caching-for-frequently-read-data) + - 3.4 [Use Connection Pooling or Multiplexing](#34-use-connection-pooling-or-multiplexing) + - 3.5 [Use Pipelining for Bulk Operations](#35-use-pipelining-for-bulk-operations) +4. [JSON Documents](#4-json-documents) — **MEDIUM** + - 4.1 [Choose JSON vs Hash vs String Appropriately](#41-choose-json-vs-hash-vs-string-appropriately) + - 4.2 [Use JSON Paths for Partial Updates](#42-use-json-paths-for-partial-updates) +5. [Redis Query Engine](#5-redis-query-engine) — **HIGH** + - 5.1 [Choose the Correct Field Type](#51-choose-the-correct-field-type) + - 5.2 [Index Only Fields You Query](#52-index-only-fields-you-query) + - 5.3 [Manage Indexes for Zero-Downtime Updates](#53-manage-indexes-for-zero-downtime-updates) + - 5.4 [Use DIALECT 2 for Query Syntax](#54-use-dialect-2-for-query-syntax) + - 5.5 [Use SKIPINITIALSCAN for New Data Only Indexes](#55-use-skipinitialscan-for-new-data-only-indexes) + - 5.6 [Write Efficient Queries](#56-write-efficient-queries) +6. [Vector Search & RedisVL](#6-vector-search--redisvl) — **HIGH** + - 6.1 [Choose HNSW vs FLAT Based on Requirements](#61-choose-hnsw-vs-flat-based-on-requirements) + - 6.2 [Configure Vector Indexes Properly](#62-configure-vector-indexes-properly) + - 6.3 [Implement RAG Pattern Correctly](#63-implement-rag-pattern-correctly) + - 6.4 [Use Hybrid Search for Better Results](#64-use-hybrid-search-for-better-results) +7. [Semantic Caching](#7-semantic-caching) — **MEDIUM** + - 7.1 [Configure Semantic Cache Properly](#71-configure-semantic-cache-properly) + - 7.2 [Use LangCache for LLM Response Caching](#72-use-langcache-for-llm-response-caching) +8. [Streams & Pub/Sub](#8-streams--pub/sub) — **MEDIUM** + - 8.1 [Choose Streams vs Pub/Sub Appropriately](#81-choose-streams-vs-pubsub-appropriately) +9. [Clustering & Replication](#9-clustering--replication) — **MEDIUM** + - 9.1 [Use Hash Tags for Multi-Key Operations](#91-use-hash-tags-for-multi-key-operations) + - 9.2 [Use Read Replicas for Read-Heavy Workloads](#92-use-read-replicas-for-read-heavy-workloads) +10. [Security](#10-security) — **HIGH** + - 10.1 [Always Use Authentication in Production](#101-always-use-authentication-in-production) + - 10.2 [Secure Network Access](#102-secure-network-access) + - 10.3 [Use ACLs for Fine-Grained Access Control](#103-use-acls-for-fine-grained-access-control) +11. [Observability](#11-observability) — **MEDIUM** + - 11.1 [Monitor Key Redis Metrics](#111-monitor-key-redis-metrics) + - 11.2 [Use Observability Commands for Debugging](#112-use-observability-commands-for-debugging) + +--- + +## 1. Data Structures & Keys + +**Impact: HIGH** + +Choosing the right Redis data type and key naming conventions. Foundation for efficient Redis usage. + +### 1.1 Choose the Right Data Structure + +**Impact: HIGH (Optimal memory usage and operation performance)** + +Selecting the appropriate Redis data type for your use case is fundamental to performance and memory efficiency. + +| Use Case | Recommended Type | Why | +|----------|------------------|-----| +| Simple values, counters | String | Fast, atomic operations | +| Object with fields | Hash | Memory efficient, partial updates, field-level expiration | +| Queue, recent items | List | O(1) push/pop at ends | +| Unique items, membership | Set | O(1) add/remove/check | +| Rankings, ranges | Sorted Set | Score-based ordering | +| Nested/hierarchical data | JSON | Path queries, nested structures, geospatial indexing with RQE | +| Event logs, messaging | Stream | Persistent, consumer groups | +| Similarity search | Redis Query Engine / RedisVL or Vector Set | RedisVL is best for document retrieval with filters and full-text search; vector sets are simpler native similarity search | + +**Note: Vector sets are a Redis 8+ capability introduced in Redis 8.0 and documented there as beta. Prefer Redis Query Engine / RedisVL when you need document-oriented retrieval, structured filters, or full-text + vector workflows.** + +**Incorrect: Using strings for everything.** + +**Python** (redis-py):** + +```python +# Storing object as JSON string loses atomic field updates +redis.set("user:1001", json.dumps({"name": "Alice", "email": "alice@example.com"})) + +# To update email, must fetch, parse, modify, and rewrite entire object +user = json.loads(redis.get("user:1001")) +user["email"] = "new@example.com" +redis.set("user:1001", json.dumps(user)) +``` + +**Java** (Jedis):** + +```java +// Bad: Storing as delimited string requires manual parsing +jedis.set("bicycle", "Deimos;Ergonom;Enduro bikes;4972"); +String bike = jedis.get("bicycle"); +String[] fields = bike.split(";"); +String model = fields[0]; // Fragile and error-prone +``` + +**Correct: Use Hash for objects with fields.** + +**Python** (redis-py):** + +```python +# Hash allows atomic field updates +redis.hset("user:1001", mapping={"name": "Alice", "email": "alice@example.com"}) + +# Update single field without touching others +redis.hset("user:1001", "email", "new@example.com") +``` + +**Java** (Jedis):** + +```java +import java.util.Map; +import java.util.HashMap; + +// Good: Hash models properties naturally +Map hashFields = new HashMap<>(); +hashFields.put("model", "Deimos"); +hashFields.put("brand", "Ergonom"); +hashFields.put("type", "Enduro bikes"); +hashFields.put("price", "4972"); + +jedis.hset("bicycle", hashFields); + +// Read individual field +String model = jedis.hget("bicycle", "model"); +``` + +Reference: [https://redis.io/docs/latest/develop/data-types/compare-data-types/](https://redis.io/docs/latest/develop/data-types/compare-data-types/) + +### 1.2 Use Consistent Key Naming Conventions + +**Impact: MEDIUM (Improved maintainability and debugging)** + +Well-structured key names improve code maintainability, debugging, and enable efficient key scanning. + +**Correct: Use colons as separators with a consistent hierarchy.** + +```python +# Pattern: service:entity:id:attribute +user:1001:profile +user:1001:settings +order:2024:items +cache:api:users:list +session:abc123 +``` + +**Python** (redis-py):** + +```python +# Good: Short, meaningful key +redis.set("product:8361", cached_html) +page = redis.get("product:8361") +``` + +**Java** (Jedis):** + +```java +// Good: Short, meaningful key derived from URL +jedis.set("product:8361", ""); +String page = jedis.get("product:8361"); +``` + +**Incorrect: Inconsistent naming, spaces, or very long keys.** + +```python +# These cause confusion and waste memory +User_1001_Profile +my key with spaces +com.mycompany.myapp.production.users.profile.data.1001 +``` + +**Java** (Jedis):** + +```java +// Bad: Using full URL as key wastes memory and slows comparisons +jedis.set("http://www.verylongurlkey.com/store/products/product.html?id=8361", + ""); +``` + +**Key naming tips:** + +- Keep keys short but readable—they consume memory + +- Consider key prefixes for multi-tenant applications + +- Extract short identifiers from URLs or long strings rather than using the whole thing + +- For large binary values, consider using a hash digest as the key instead of the value itself + +- Use consistent separators (colons are conventional) + +Reference: [https://redis.io/docs/latest/develop/use/keyspace/](https://redis.io/docs/latest/develop/use/keyspace/) + +### 1.3 Use Hash Field Expiration for Per-Field TTL + +**Impact: MEDIUM (Fine-grained expiration without managing timers)** + +Use hash field expiration (Redis 7.4+) to delete individual fields automatically from a hash after a specific period of time. This is useful for caching scenarios where different fields have different lifetimes, and is easier than managing expiration from your own code. + +**Correct: Use HEXPIRE to set per-field TTL on hash fields.** + +**Python** (redis-py):** + +```python +import redis + +client = redis.Redis(host='localhost', port=6379) + +# Set hash fields +client.hset("sensor:sensor1", mapping={ + "air_quality": "256", + "battery_level": "89" +}) + +# Set 60-second TTL on specific fields (Redis 7.4+) +client.hexpire("sensor:sensor1", 60, "air_quality", "battery_level") +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.UnifiedJedis; +import java.util.Map; +import java.util.HashMap; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + Map hashFields = new HashMap<>(); + hashFields.put("air_quality", "256"); + hashFields.put("battery_level", "89"); + + jedis.hset("sensor:sensor1", hashFields); + + // Set 60-second TTL on specific fields (Redis 7.4+) + jedis.hexpire("sensor:sensor1", 60, "air_quality", "battery_level"); +} +``` + +**When to use:** + +- Sensor data or metrics that become stale after a period + +- Session attributes where different fields have different lifetimes + +- Cached values within a hash that should auto-expire independently + +- Temporary flags or tokens stored alongside persistent data + +**When NOT needed:** + +- Persistent user profiles or configuration + +- Data where the entire hash should expire together (use `EXPIRE` on the key instead) + +- Fields managed by application logic with explicit deletion + +Reference: [https://redis.io/docs/latest/commands/hexpire/](https://redis.io/docs/latest/commands/hexpire/) + +### 1.4 Use INCR for Atomic Counters + +**Impact: MEDIUM (Atomic increment avoids race conditions)** + +If a string represents an integer value, use the `INCR` command to increment the number directly. The increment is atomic and always returns the new value. Use `INCRBY` to increment by any integer (positive or negative). This is more efficient and race-condition-free than reading, incrementing in code, and writing back. + +**Correct: Use INCR/INCRBY for atomic counter updates.** + +**Python** (redis-py):** + +```python +import redis + +client = redis.Redis(host='localhost', port=6379) + +# Initialize counter +client.set("counter", "0") + +# Atomic increment - returns new value +new_value = client.incr("counter") # Returns 1 + +# Increment by specific amount +new_value = client.incrby("counter", 10) # Returns 11 +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.UnifiedJedis; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + jedis.set("counter", "0"); + + // Atomic increment - returns new value + long newValue = jedis.incr("counter"); // Returns 1 + + // Increment by specific amount + newValue = jedis.incrBy("counter", 10); // Returns 11 +} +``` + +**Incorrect: Read-modify-write pattern creates race conditions.** + +**Python** (redis-py):** + +```python +import redis + +client = redis.Redis(host='localhost', port=6379) + +client.set("counter", "0") + +# BAD: Race condition - another client could modify between GET and SET +curr_value = int(client.get("counter")) +client.set("counter", str(curr_value + 1)) # Not atomic! +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.UnifiedJedis; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + jedis.set("counter", "0"); + + // BAD: Race condition between GET and SET + long currValue = Long.parseLong(jedis.get("counter")); + jedis.set("counter", Long.toString(currValue + 1)); // Not atomic! +} +``` + +Reference: [https://redis.io/docs/latest/commands/incr/](https://redis.io/docs/latest/commands/incr/) + +### 1.5 Use Transactions for Atomic Multi-Command Operations + +**Impact: MEDIUM (Prevents race conditions and data inconsistency)** + +Use the `MULTI`/`EXEC` commands to create a transaction when you need to execute multiple commands atomically. No other client requests will be processed while the transaction is executing, preventing other clients from modifying the keys used in the transaction and avoiding inconsistent data. + +**Correct: Use transactions when multiple related keys must be updated together.** + +**Python** (redis-py):** + +```python +import redis + +client = redis.Redis(host='localhost', port=6379) + +# Transaction ensures all commands execute atomically +pipe = client.pipeline(transaction=True) +pipe.set("person:1:name", "Alex") +pipe.set("person:1:rank", "Captain") +pipe.set("person:1:serial", "AB1234") +pipe.execute() # All commands execute as one atomic unit +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.UnifiedJedis; +import redis.clients.jedis.Transaction; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + Transaction tran = (Transaction) jedis.multi(); + + tran.set("person:1:name", "Alex"); + tran.set("person:1:rank", "Captain"); + tran.set("person:1:serial", "AB1234"); + + tran.exec(); // All commands execute atomically +} +``` + +**Incorrect: Executing related commands individually when atomicity is required.** + +**Python** (redis-py):** + +```python +import redis + +client = redis.Redis(host='localhost', port=6379) + +# BAD when atomicity matters - another client could read partial state +client.set("person:1:name", "Alex") +# Another client could read here and see incomplete data +client.set("person:1:rank", "Captain") +client.set("person:1:serial", "AB1234") +``` + +**When to use transactions:** + +- Multiple keys must be updated as a single atomic unit + +- Other clients reading partial state would cause bugs + +- Implementing patterns like "transfer balance between accounts" + +**When transactions are NOT needed:** + +- Independent operations that don't need to be atomic + +- Single-command operations (already atomic) + +- When using pipelining purely for performance (use `pipeline(transaction=False)`) + +**Note: Transactions add overhead. Only use them when atomicity is actually required.** + +Reference: [https://redis.io/docs/latest/develop/interact/transactions/](https://redis.io/docs/latest/develop/interact/transactions/) + +--- + +## 2. Memory & Expiration + +**Impact: HIGH** + +Memory limits, eviction policies, TTL strategies, and memory optimization techniques. + +### 2.1 Configure Memory Limits and Eviction Policies + +**Impact: HIGH (Prevents out-of-memory crashes and unpredictable behavior)** + +Always configure `maxmemory` and an eviction policy to prevent Redis from consuming all available memory. + +**Correct: Set explicit memory limits.** + +```python +maxmemory 2gb +maxmemory-policy allkeys-lru +``` + +| Policy | Use Case | +|--------|----------| +| `volatile-lru` | Evict keys with TTL, least recently used first | +| `allkeys-lru` | Evict any key, least recently used first | +| `volatile-ttl` | Evict keys closest to expiration | +| `noeviction` | Return errors when memory is full (use for critical data) | + +**Incorrect: Running Redis without memory limits.** + +```python +# No maxmemory set - Redis will use all available RAM +# Can cause OOM killer to terminate Redis or other processes +``` + +**Memory optimization tips:** + +- Use Hashes for small objects (more memory-efficient than separate keys) + +- Use `OBJECT ENCODING key` to check how Redis stores your data + +- Use `MEMORY USAGE key` to check individual key memory consumption + +- Enable compression in your client for large values + +Reference: [https://redis.io/docs/latest/operate/oss_and_stack/management/optimization/memory-optimization/](https://redis.io/docs/latest/operate/oss_and_stack/management/optimization/memory-optimization/) + +### 2.2 Set TTL on Cache Keys + +**Impact: HIGH (Prevents unbounded memory growth)** + +Always set expiration times on cache keys to prevent unbounded memory growth. + +**Correct: Set TTL at write time.** + +**Python** (redis-py):** + +```python +# Good: TTL set atomically with the value +redis.setex("cache:user:1001", 3600, user_json) + +# Good: For hashes, set TTL after +redis.hset("session:abc", mapping=session_data) +redis.expire("session:abc", 1800) +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.params.SetParams; + +// Good: TTL set atomically with SetParams +jedis.set("cachedItem:1", "fe8c357903ac9", new SetParams().ex(120)); +``` + +**Incorrect: Forgetting TTL on cache keys.** + +**Python** (redis-py):** + +```python +# Risk: This key may live forever +redis.set("cache:user:1001", user_json) +``` + +**Java** (Jedis):** + +```java +// Risk: This key may live forever +jedis.set("cachedItem:1", "fe8c357903ac9"); +``` + +**TTL strategies:** + +- Cache data: 1-24 hours depending on freshness requirements + +- Sessions: 30 minutes to 24 hours + +- Rate limiting: Seconds to minutes + +- Temporary locks: Seconds with automatic release + +Reference: [https://redis.io/commands/expire/](https://redis.io/commands/expire/) + +--- + +## 3. Connection & Performance + +**Impact: HIGH** + +Connection pooling, pipelining, timeouts, and avoiding blocking commands. + +### 3.1 Avoid Slow Commands in Production + +**Impact: HIGH (Prevents Redis from becoming unresponsive)** + +Some Redis commands are slow because they scan large datasets. Use incremental alternatives to avoid blocking the server. + +| Avoid | Use Instead | +|-------|-------------| +| `KEYS *` | `SCAN` with cursor | +| `SMEMBERS` on large sets | `SSCAN` | +| `HGETALL` on large hashes | `HSCAN` | +| `LRANGE 0 -1` on large lists | Paginate with `LRANGE 0 100` | + +**Correct: Use SCAN for iteration.** + +**Python** (redis-py):** + +```python +# Good: Non-blocking iteration +cursor = 0 +while True: + cursor, keys = redis.scan(cursor, match="user:*", count=100) + for key in keys: + process(key) + if cursor == 0: + break +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.ScanIteration; +import redis.clients.jedis.UnifiedJedis; +import java.util.List; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + // ScanIteration manages the cursor automatically + ScanIteration scan = jedis.scanIteration(10, "user:*", "hash"); + + while (!scan.isIterationCompleted()) { + List result = scan.nextBatch().getResult(); + for (String key : result) { + process(key); + } + } +} +``` + +**Incorrect: Using KEYS in production.** + +**Python** (redis-py):** + +```python +# Bad: Scans all keys, slow on large datasets +keys = redis.keys("user:*") +``` + +**Java** (Jedis):** + +```java +// Bad: Scans all keys, blocks the server +Set result = jedis.keys("*"); +``` + +**Note: Truly blocking commands (like `BLPOP`, `BRPOP`, `BLMOVE`) that wait indefinitely for data are appropriate for some use cases like job queues, but should be used with timeouts.** + +```python +# Blocking pop with timeout - appropriate for queue consumers +result = redis.blpop("task_queue", timeout=5) +``` + +Reference: [https://redis.io/docs/latest/commands/scan/](https://redis.io/docs/latest/commands/scan/) + +### 3.2 Configure Connection Timeouts + +**Impact: MEDIUM (Improves connection resilience and failure recovery)** + +Configure appropriate timeout values to improve your application's connection resilience. While most Redis clients set default timeouts, choosing well-tuned values based on your application's usage patterns leads to better failure recovery. + +**Correct: Set timeouts based on your application needs.** + +```python +r = redis.Redis( + host='localhost', + socket_timeout=5.0, # Read/write timeout - tune based on expected operation time + socket_connect_timeout=2.0, # Connection timeout - shorter for fast failure detection + retry_on_timeout=True # Automatic retry on timeout +) +``` + +**Incorrect: Relying solely on defaults without considering your use case.** + +```python +# Not ideal: Default timeouts may not match your application's needs +r = redis.Redis(host='localhost') + +# For example, if your app needs fast failure detection, +# the default timeouts might be too generous +``` + +**Considerations:** + +- Set `socket_connect_timeout` shorter than `socket_timeout` for quick connection failure detection + +- For latency-sensitive apps, use tighter timeouts with retry logic + +- For batch operations, allow longer timeouts to complete large operations + +- Consider using health checks alongside timeouts for robust failure handling + +Reference: [https://redis.io/docs/latest/develop/clients/](https://redis.io/docs/latest/develop/clients/) + +### 3.3 Use Client-Side Caching for Frequently Read Data + +**Impact: HIGH (Reduces network round-trips for repeated reads)** + +Use a connection with client-side caching enabled for any data that will be read frequently but written only occasionally. Client-side caching avoids contacting the server for repeated access to data that has recently been read, reducing network traffic and improving performance. + +**Correct: Enable client-side caching with RESP3 protocol for frequently accessed data.** + +**Python** (redis-py):** + +```python +import redis + +# Enable client-side caching with RESP3 +client = redis.Redis( + host='localhost', + port=6379, + protocol=3, # RESP3 required for client-side caching + cache_config=redis.CacheConfig(max_size=1000) +) + +# Cached reads avoid server round-trips +value = client.get("frequently:read:key") +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.UnifiedJedis; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.CacheConfig; + +HostAndPort endpoint = new HostAndPort("localhost", 6379); + +DefaultJedisClientConfig config = DefaultJedisClientConfig + .builder() + .password("secretPassword") + .protocol(RedisProtocol.RESP3) + .build(); + +CacheConfig cacheConfig = CacheConfig.builder().maxSize(1000).build(); + +UnifiedJedis client = new UnifiedJedis(endpoint, config, cacheConfig); +``` + +**When to use:** + +- Configuration data read frequently, updated rarely + +- User session data accessed on every request + +- Feature flags or settings checked repeatedly + +- Any read-heavy workload with low write frequency + +**When NOT needed:** + +- Data that changes frequently (cache invalidation overhead outweighs benefits) + +- Write-heavy workloads + +- Simple applications where network latency is not a bottleneck + +- When you need guaranteed real-time consistency + +**Trade-offs:** + +- Adds memory overhead on the client + +- Requires RESP3 protocol + +- Cache invalidation adds complexity for frequently changing data + +Reference: [https://redis.io/docs/latest/develop/clients/client-side-caching/](https://redis.io/docs/latest/develop/clients/client-side-caching/) + +### 3.4 Use Connection Pooling or Multiplexing + +**Impact: HIGH (Reduces connection overhead by 10x or more)** + +Reuse connections via a pool or multiplexing instead of creating new connections per request. + +**Correct: Use a connection pool.** + +**Python** (redis-py):** + +```python +import redis + +# Good: Connection pool - reuses existing connections +pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50) +r = redis.Redis(connection_pool=pool) +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.JedisPooled; + +// JedisPooled manages a connection pool internally +try (JedisPooled jedis = new JedisPooled("redis://localhost:6379")) { + jedis.set("testKey", "testValue"); +} +``` + +**Correct: Use multiplexing (Lettuce, NRedisStack).** + +```java +// Lettuce uses multiplexing by default - single connection handles all traffic +RedisClient client = RedisClient.create("redis://localhost:6379"); +StatefulRedisConnection connection = client.connect(); + +// All commands share the single connection efficiently +connection.sync().set("key", "value"); +``` + +**Incorrect: Creating new connections per request.** + +**Python** (redis-py):** + +```python +# Bad: New connection every time +def get_user(user_id): + r = redis.Redis(host='localhost', port=6379) # Don't do this + return r.get(f"user:{user_id}") +``` + +**Java** (Jedis):** + +```java +// Bad: Creating new client per request +public String getUser(String userId) { + try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + return jedis.get("user:" + userId); // Don't do this + } +} +``` + +**Pooling vs Multiplexing:** + +- **Pooling**: Multiple connections shared across requests (redis-py, Jedis, go-redis) + +- **Multiplexing**: Single connection handles all traffic (NRedisStack, Lettuce) + +- Multiplexing cannot support blocking commands (BLPOP, etc.) as they would stall all callers + +Reference: [https://redis.io/docs/latest/develop/clients/pools-and-muxing/](https://redis.io/docs/latest/develop/clients/pools-and-muxing/) + +### 3.5 Use Pipelining for Bulk Operations + +**Impact: HIGH (Reduces round trips, 5-10x faster for batch operations)** + +Batch multiple commands into a single round trip to reduce network latency. + +**Correct: Use pipeline for multiple commands.** + +**Python** (redis-py):** + +```python +# Good: Single round trip for multiple commands +pipe = redis.pipeline() +for user_id in user_ids: + pipe.get(f"user:{user_id}") +results = pipe.execute() +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.Pipeline; + +// Good: Buffer commands and send as single batch +Pipeline pipe = (Pipeline) jedis.pipelined(); + +pipe.set("person:1:name", "Alex"); +pipe.set("person:1:rank", "Captain"); +pipe.set("person:1:serial", "AB1234"); + +pipe.sync(); +``` + +**Incorrect: Sequential commands in a loop.** + +**Python** (redis-py):** + +```python +# Bad: N round trips +results = [] +for user_id in user_ids: + results.append(redis.get(f"user:{user_id}")) +``` + +**Java** (Jedis):** + +```java +// Bad: 3 separate round trips +jedis.set("person:1:name", "Alex"); +jedis.set("person:1:rank", "Captain"); +jedis.set("person:1:serial", "AB1234"); +``` + +Reference: [https://redis.io/docs/latest/develop/use/pipelining/](https://redis.io/docs/latest/develop/use/pipelining/) + +--- + +## 4. JSON Documents + +**Impact: MEDIUM** + +Using Redis JSON for nested structures, partial updates, and integration with RQE. + +### 4.1 Choose JSON vs Hash vs String Appropriately + +**Impact: MEDIUM (Optimal data model for your use case)** + +Redis offers three ways to store structured data: JSON, Hash, and serialized strings. Each has distinct trade-offs around atomic partial operations and indexability. + +| Feature | JSON | Hash | String (serialized JSON) | +|---------|------|------|--------------------------| +| **Structure** | Nested objects and arrays | Flat key-value pairs | Any structure | +| **Atomic partial reads** | Yes (`$.field`) | Yes (`HGET`) | No (must fetch entire value) | +| **Atomic partial writes** | Yes (`JSON.SET $.field`) | Yes (`HSET`) | No (must rewrite entire value) | +| **RQE indexing** | Yes | Yes | No | +| **Geospatial indexing** | Yes | Yes | No | +| **Memory efficiency** | Higher overhead | More efficient | Most compact | +| **Field-level expiration** | No | Yes (HEXPIRE) | No | + +**When to use each:** + +- **JSON**: Nested structures with atomic partial updates and indexing needs + +- **Hash**: Flat objects with atomic field access, field-level expiration, or memory efficiency + +- **String**: Simple caching where you always read/write the entire object and don't need indexing + +**Correct: Use JSON for nested structures with atomic partial updates.** + +**Python** (redis-py):** + +```python +# JSON supports nested structures and atomic deep updates +redis.json().set("user:1001", "$", { + "name": "Alice", + "preferences": {"theme": "dark", "notifications": True} +}) + +# Atomic update of nested field - no read-modify-write needed +redis.json().set("user:1001", "$.preferences.theme", "light") +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.UnifiedJedis; +import redis.clients.jedis.json.Path2; +import org.json.JSONObject; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + JSONObject user = new JSONObject(); + user.put("name", "Alice"); + user.put("preferences", new JSONObject().put("theme", "dark")); + + jedis.jsonSet("user:1001", new Path2("$"), user); + + // Atomic update of nested field + jedis.jsonSet("user:1001", new Path2("$.preferences.theme"), "light"); +} +``` + +**Correct: Use Hash for flat objects with atomic field access.** + +**Python** (redis-py):** + +```python +# Hash is efficient for flat data with atomic field operations +redis.hset("session:abc", mapping={ + "user_id": "1001", + "created_at": "2024-01-01", + "ip": "192.168.1.1" +}) + +# Atomic field read and update +ip = redis.hget("session:abc", "ip") +redis.hset("session:abc", "ip", "10.0.0.1") +``` + +**Correct: Use String for simple caching without partial updates.** + +**Python** (redis-py):** + +```python +import json + +# String is fine when you always read/write the entire object +# and don't need indexing or partial updates +config = {"feature_flags": {"dark_mode": True}, "version": "1.0"} +redis.set("config:app", json.dumps(config), ex=3600) + +# Must fetch and parse entire object +config = json.loads(redis.get("config:app")) +``` + +**Incorrect: Using String when you need atomic partial updates.** + +**Python** (redis-py):** + +```python +import json + +# BAD: Must fetch, parse, modify, serialize, and rewrite entire object +data = json.loads(redis.get("user:1001")) +data["preferences"]["theme"] = "light" # Not atomic! +redis.set("user:1001", json.dumps(data)) +# Another client could have modified the object between GET and SET +``` + +Reference: [https://redis.io/docs/latest/develop/data-types/compare-data-types/#documents](https://redis.io/docs/latest/develop/data-types/compare-data-types/#documents) + +### 4.2 Use JSON Paths for Partial Updates + +**Impact: MEDIUM (Avoids fetching and rewriting entire documents)** + +Use JSON path syntax to update specific fields without fetching the entire document. + +**Correct: Use JSON paths for targeted updates.** + +```python +# Store JSON document +redis.json().set("user:1001", "$", { + "name": "Alice", + "email": "alice@example.com", + "preferences": {"theme": "dark", "notifications": True} +}) + +# Update nested field without fetching entire document +redis.json().set("user:1001", "$.preferences.theme", "light") + +# Get specific field +theme = redis.json().get("user:1001", "$.preferences.theme") + +# Increment numeric field atomically +redis.json().numincrby("user:1001", "$.preferences.volume", 5) + +# Append to array +redis.json().arrappend("user:1001", "$.tags", "premium") +``` + +**Incorrect: Storing JSON as a string and parsing client-side.** + +```python +# Bad: Loses queryability and atomic updates +redis.set("user:1001", json.dumps(user_data)) + +# Must fetch, parse, modify, serialize, and rewrite +data = json.loads(redis.get("user:1001")) +data["preferences"]["theme"] = "light" +redis.set("user:1001", json.dumps(data)) +``` + +Reference: [https://redis.io/docs/latest/develop/data-types/json/path/](https://redis.io/docs/latest/develop/data-types/json/path/) + +--- + +## 5. Redis Query Engine + +**Impact: HIGH** + +FT.CREATE, FT.SEARCH, FT.AGGREGATE, index design, field types, and query optimization. + +### 5.1 Choose the Correct Field Type + +**Impact: HIGH (Use TAG instead of TEXT for filtering to improve query speed 10x)** + +Each field type has different capabilities and performance characteristics. + +| Field Type | Use When | Notes | +|------------|----------|-------| +| TEXT | Full-text search needed | Tokenized, stemmed | +| TAG | Exact match, filtering | Faster than TEXT for filtering | +| NUMERIC | Range queries, sorting | Use for prices, counts, timestamps | +| GEO | Point location queries | Lat/long coordinates (single points) | +| GEOSHAPE | Area/region queries | Polygons, circles, rectangles | +| VECTOR | Similarity search | HNSW or FLAT algorithm | + +**Correct: Use TAG for exact matching.** + +```python +# Good: TAG for exact category matching +FT.CREATE idx:products ON HASH PREFIX 1 product: + SCHEMA + category TAG SORTABLE + status TAG +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.search.*; + +Schema schema = new Schema() + .addTextField("name", 1) + .addTagField("categories"); // TAG for exact matching + +IndexDefinition def = new IndexDefinition(IndexDefinition.Type.HASH); + +jedis.ftCreate("idx", IndexOptions.defaultOptions().setDefinition(def), schema); + +// Query with TAG syntax +SearchResult result = jedis.ftSearch("idx", "@categories:{chef|runner}"); +``` + +**Incorrect: Using TEXT when you don't need full-text features.** + +```python +# Overkill: TEXT for category adds unnecessary tokenization +FT.CREATE idx:products ON HASH PREFIX 1 product: + SCHEMA + category TEXT + status TEXT +``` + +**Java** (Jedis):** + +```java +// Bad: TEXT for categories adds unnecessary overhead +Schema schema = new Schema() + .addTextField("name", 1) + .addTextField("categories", 1); // Overkill for exact matching +``` + +**Correct: Use GEO for points, GEOSHAPE for areas.** + +```python +# GEO for point locations (stores, users) +FT.CREATE idx:stores ON HASH PREFIX 1 store: + SCHEMA + location GEO + +# GEOSHAPE for areas (delivery zones, boundaries) +FT.CREATE idx:zones ON JSON PREFIX 1 zone: + SCHEMA + $.boundary AS boundary GEOSHAPE +``` + +Reference: [https://redis.io/docs/latest/develop/interact/search-and-query/indexing/geoindex/](https://redis.io/docs/latest/develop/interact/search-and-query/indexing/geoindex/) + +### 5.2 Index Only Fields You Query + +**Impact: HIGH (Reduces index size and improves write performance)** + +Create indexes with only the fields you need to search, filter, or sort on. + +**Correct: Index specific fields and use prefixes.** + +```python +FT.CREATE idx:products ON HASH PREFIX 1 product: + SCHEMA + name TEXT WEIGHT 2.0 + description TEXT + category TAG SORTABLE + price NUMERIC SORTABLE + location GEO +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.search.*; + +Schema schema = new Schema() + .addTextField("name", 1) + .addTagField("categories"); + +// Good: Specify prefix to index only matching keys +IndexDefinition def = new IndexDefinition(IndexDefinition.Type.HASH) + .setPrefixes("person:"); + +jedis.ftCreate("idx", IndexOptions.defaultOptions().setDefinition(def), schema); +``` + +**Incorrect: Over-indexing or indexing unused fields.** + +```python +# Bad: Indexing every field "just in case" +FT.CREATE idx:products ON HASH PREFIX 1 product: + SCHEMA + name TEXT + description TEXT + category TEXT + subcategory TEXT + brand TEXT + sku TEXT + price NUMERIC + cost NUMERIC + margin NUMERIC + ... +``` + +**Java** (Jedis):** + +```java +// Bad: No prefix means all hashes get indexed +IndexDefinition def = new IndexDefinition(IndexDefinition.Type.HASH); +// This will index every hash in the database! +``` + +**Tips:** + +- Start with the minimum required fields + +- Add fields as query patterns emerge + +- Use `FT.INFO` to monitor index size + +- Always specify a prefix to avoid indexing unrelated keys + +Reference: [https://redis.io/docs/latest/develop/interact/search-and-query/indexing/](https://redis.io/docs/latest/develop/interact/search-and-query/indexing/) + +### 5.3 Manage Indexes for Zero-Downtime Updates + +**Impact: MEDIUM (Use aliases for seamless index updates)** + +Use aliases to swap indexes without application changes. + +**Correct: Use aliases for production indexes.** + +```python +# Create versioned index +FT.CREATE idx:products_v2 ON HASH PREFIX 1 product: + SCHEMA + name TEXT + category TAG SORTABLE + price NUMERIC SORTABLE + +# Point alias to new index +FT.ALIASADD products idx:products_v2 + +# Application queries use alias +FT.SEARCH products "@category:{electronics}" + +# Later, swap to new version +FT.ALIASUPDATE products idx:products_v3 +``` + +**Useful management commands:** + +```python +# Check index info +FT.INFO idx:products + +# Drop and recreate (non-blocking) +FT.DROPINDEX idx:products +FT.CREATE idx:products ... + +# List all indexes +FT._LIST +``` + +Reference: [https://redis.io/docs/latest/develop/interact/search-and-query/administration/](https://redis.io/docs/latest/develop/interact/search-and-query/administration/) + +### 5.4 Use DIALECT 2 for Query Syntax + +**Impact: MEDIUM (Ensures consistent query behavior and access to modern features)** + +Use DIALECT 2 for consistent query behavior. Many Redis client libraries now default to DIALECT 2, and other dialects (1, 3, 4) are deprecated as of Redis 8. + +**Correct: Use DIALECT 2 explicitly or rely on modern client defaults.** + +```python +# In raw commands, specify DIALECT 2 +FT.SEARCH idx:products "@name:laptop" DIALECT 2 + +FT.AGGREGATE idx:products "@category:{electronics}" + GROUPBY 1 @category + REDUCE COUNT 0 AS count + DIALECT 2 +``` + +**Note: DIALECT 2 is required for vector search queries. Most modern client libraries (redis-py 6.0+, go-redis, Lettuce) now use DIALECT 2 by default.** + +**Why DIALECT 2:** + +- Consistent handling of special characters + +- Better NULL value handling + +- More predictable query parsing + +- Required for vector search + +Reference: [https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/dialects/](https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/dialects/) + +### 5.5 Use SKIPINITIALSCAN for New Data Only Indexes + +**Impact: MEDIUM (Faster index creation, avoids indexing existing data)** + +Enable the `SKIPINITIALSCAN` option when creating an index if you only want to include items that are added after the index is created. This makes index creation faster and avoids indexing existing data that you don't need to search. + +**Correct: Use SKIPINITIALSCAN when you only need to index new data.** + +**Python** (redis-py):** + +```python +import redis +from redis.commands.search.field import TextField, TagField +from redis.commands.search.indexDefinition import IndexDefinition, IndexType + +client = redis.Redis(host='localhost', port=6379) + +# Create index that only indexes new documents +schema = ( + TextField("name"), + TagField("categories") +) + +definition = IndexDefinition( + prefix=["person:"], + index_type=IndexType.HASH +) + +# SKIPINITIALSCAN - only index documents added after creation +client.ft("idx").create_index( + schema, + definition=definition, + skip_initial_scan=True +) +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.UnifiedJedis; +import redis.clients.jedis.search.FTCreateParams; +import redis.clients.jedis.search.IndexDataType; +import redis.clients.jedis.search.schemafields.SchemaField; +import redis.clients.jedis.search.schemafields.TagField; +import redis.clients.jedis.search.schemafields.TextField; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + FTCreateParams params = new FTCreateParams() + .on(IndexDataType.HASH) + .skipInitialScan(); // Only index new documents + + jedis.ftCreate( + "idx", + params, + new SchemaField[]{ + new TextField("name"), + new TagField("categories") + } + ); +} +``` + +**When to use SKIPINITIALSCAN:** + +- Creating an index for a new feature where existing data is irrelevant + +- Setting up indexes in advance before data arrives + +- When existing data would be too large to scan during index creation + +- Event-driven architectures where you only care about new events + +**When NOT to use: default behavior is correct** + +- You need to search existing data immediately after index creation + +- Migrating to a new index schema and need all data indexed + +- Most typical use cases where historical data matters + +**Note: The default behavior (without SKIPINITIALSCAN) indexes all existing matching keys, which is usually what you want.** + +Reference: [https://redis.io/docs/latest/commands/ft.create/](https://redis.io/docs/latest/commands/ft.create/) + +### 5.6 Write Efficient Queries + +**Impact: HIGH (Proper filtering reduces query time by orders of magnitude)** + +Be specific and use filters to reduce the result set early. + +**Correct: Use specific filters and limit results.** + +```python +# Good: Specific query with filters +FT.SEARCH idx:products "@category:{electronics} @price:[100 500]" + LIMIT 0 20 + RETURN 3 name price category + +# Good: Use SORTBY and LIMIT +FT.SEARCH idx:products "@name:laptop" + SORTBY price ASC + LIMIT 0 10 +``` + +**Incorrect: Broad queries returning large result sets.** + +```python +# Bad: Wildcard prefix scans entire index +FT.SEARCH idx:products "*" LIMIT 0 10000 + +# Bad: Loading all fields from source document +FT.AGGREGATE idx:products "*" LOAD * +``` + +**Performance tips:** + +```python +FT.PROFILE idx:products SEARCH QUERY "@category:{electronics}" +``` + +- Add `SORTABLE` to fields used in `SORTBY` + +- Use `TAG SORTABLE UNF` for best performance on tag fields + +- Use `NOSTEM` if you don't need stemming + +- Profile queries with `FT.PROFILE` + +Reference: [https://redis.io/docs/latest/develop/interact/search-and-query/query/](https://redis.io/docs/latest/develop/interact/search-and-query/query/) + +--- + +## 6. Vector Search & RedisVL + +**Impact: HIGH** + +Vector indexes, HNSW vs FLAT, hybrid search, and RAG patterns with RedisVL. + +### 6.1 Choose HNSW vs FLAT Based on Requirements + +**Impact: HIGH (HNSW trades accuracy for speed, FLAT provides exact results)** + +Select the right algorithm based on your accuracy requirements and dataset size. + +| Algorithm | Speed | Accuracy | Memory | Best For | +|-----------|-------|----------|--------|----------| +| HNSW | Fast (approximate) | ~95%+ recall tunable | Higher | Large datasets (>10k vectors) | +| FLAT | Slower (exact) | 100% (exact) | Lower | Small datasets, accuracy-critical | + +**Correct: Use HNSW for large-scale production workloads.** + +```python +from redisvl.schema import IndexSchema + +# HNSW - fast approximate search, tunable accuracy +schema = IndexSchema.from_dict({ + "index": {"name": "idx:docs", "prefix": "doc:"}, + "fields": [ + {"name": "embedding", "type": "vector", "attrs": { + "dims": 1536, + "algorithm": "HNSW", + "distance_metric": "COSINE", + "datatype": "FLOAT32", + "m": 16, # Higher = more accurate, more memory + "ef_construction": 200 # Higher = better index quality, slower build + }} + ] +}) +``` + +**Correct: Use FLAT when exact results are required.** + +```python +# FLAT - exact brute-force search, guaranteed accuracy +schema = IndexSchema.from_dict({ + "index": {"name": "idx:small", "prefix": "small:"}, + "fields": [ + {"name": "embedding", "type": "vector", "attrs": { + "dims": 1536, + "algorithm": "FLAT", + "distance_metric": "COSINE" + }} + ] +}) +``` + +**Tuning HNSW accuracy vs speed:** + +- `M`: Connections per node (16-64). Higher = better recall, more memory + +- `EF_CONSTRUCTION`: Build-time parameter (100-500). Higher = better graph quality + +- `EF_RUNTIME`: Query-time parameter. Higher = better recall, slower queries + +Reference: [https://redis.io/docs/latest/develop/ai/search-and-query/vectors/](https://redis.io/docs/latest/develop/ai/search-and-query/vectors/) + +### 6.2 Configure Vector Indexes Properly + +**Impact: HIGH (Correct configuration is essential for vector search accuracy)** + +Set the correct dimensions, algorithm, and distance metric for your embeddings. Vector indexes can be created via CLI, Redis Insight, or any client library. + +**Correct: Create index via Redis CLI or Insight.** + +```python +FT.CREATE idx:docs ON HASH PREFIX 1 doc: + SCHEMA + content TEXT + embedding VECTOR HNSW 6 + TYPE FLOAT32 + DIM 1536 + DISTANCE_METRIC COSINE +``` + +**Correct: Create index via Python (redis-py).** + +```python +from redis import Redis +from redis.commands.search.field import TextField, VectorField +from redis.commands.search.index_definition import IndexDefinition + +r = Redis() + +# Define schema with vector field +schema = [ + TextField("content"), + VectorField( + "embedding", + algorithm="HNSW", + attributes={ + "TYPE": "FLOAT32", + "DIM": 1536, # Must match your embedding model + "DISTANCE_METRIC": "COSINE" + } + ) +] + +r.ft("idx:docs").create_index(schema, definition=IndexDefinition(prefix=["doc:"])) +``` + +**Correct: Create index via RedisVL.** + +```python +from redisvl.index import SearchIndex +from redisvl.schema import IndexSchema + +schema = IndexSchema.from_dict({ + "index": {"name": "idx:docs", "prefix": "doc:"}, + "fields": [ + {"name": "content", "type": "text"}, + {"name": "embedding", "type": "vector", "attrs": { + "dims": 1536, + "algorithm": "HNSW", + "datatype": "FLOAT32", + "distance_metric": "COSINE" + }} + ] +}) + +index = SearchIndex(schema) +index.create(overwrite=True) +``` + +**Incorrect: Mismatched dimensions or wrong distance metric.** + +```python +# Bad: Wrong dimensions for your model +{"dims": 768} # But your selected embedding model outputs a different size + +# Bad: Wrong metric for normalized embeddings +{"distance_metric": "L2"} # When embeddings are normalized for COSINE +``` + +Reference: [https://redis.io/docs/latest/develop/ai/search-and-query/vectors/](https://redis.io/docs/latest/develop/ai/search-and-query/vectors/) + +### 6.3 Implement RAG Pattern Correctly + +**Impact: HIGH (Proper RAG implementation improves LLM response quality)** + +Store documents with embeddings, retrieve relevant context, and pass to LLM. + +**Correct: Full RAG pipeline with RedisVL.** + +```python +from redisvl.index import SearchIndex +from redisvl.query import VectorQuery + +# 1. Store documents with embeddings +records = [] +for doc in documents: + records.append({ + "content": doc["content"], + "embedding": embed_model.encode(doc["content"]).tolist(), + "source": doc["source"] + }) + +index.load(records) + +# 2. Query with vector similarity +query_embedding = embed_model.encode(user_question) +results = index.query(VectorQuery( + vector=query_embedding, + vector_field_name="embedding", + return_fields=["content", "source"], + num_results=5 +)) + +# 3. Pass context to LLM +context = "\n".join([r["content"] for r in results]) +response = llm.generate(f"Context: {context}\n\nQuestion: {user_question}") +``` + +**Best practices:** + +- Match your distance metric to your embedding model; many modern text embeddings already work well with COSINE + +- Batch inserts using `index.load()` with lists + +- Set appropriate M and EF_CONSTRUCTION for HNSW based on dataset size + +- Use filters to reduce the search space before vector comparison + +- Consider chunking long documents for better retrieval + +Reference: [https://redis.io/docs/latest/develop/get-started/rag/](https://redis.io/docs/latest/develop/get-started/rag/) + +### 6.4 Use Hybrid Search for Better Results + +**Impact: MEDIUM (Combining vector + filters improves relevance and reduces search space)** + +Combine vector similarity with attribute filtering for more relevant results. In this rule, "hybrid" means filtered vector search. Redis and RedisVL also use "hybrid search" for text + vector fusion via `FT.HYBRID` / `HybridQuery`. + +**Correct: Apply filters to reduce search space.** + +```python +from redisvl.query import VectorQuery +from redisvl.query.filter import Num, Tag + +filters = (Tag("category") == "technology") & (Num("date") >= 2024) & (Num("date") <= 2025) + +query = VectorQuery( + vector=query_embedding, + vector_field_name="embedding", + return_fields=["content", "category", "date"], + num_results=10, + filter_expression=filters +) + +results = index.query(query) +``` + +**Incorrect: Searching entire vector space when filters apply.** + +```python +# Bad: No filter - searches all vectors then filters client-side +results = index.query(VectorQuery( + vector=query_embedding, + vector_field_name="embedding", + num_results=1000 +)) +# Client-side filtering - wasteful +filtered = [r for r in results if r["category"] == "technology"] +``` + +**Tips:** + +- Use TAG fields for category filters + +- Use NUMERIC fields for date/price ranges + +- Redis auto-selects the filtered vector execution strategy; tune `hybrid_policy` only when needed + +- For true text + vector fusion, use `HybridQuery` on Redis >= 8.4.0 with redis-py >= 7.1.0; use `AggregateHybridQuery` on earlier Redis versions + +Reference: [https://redis.io/docs/latest/develop/ai/search-and-query/vectors/](https://redis.io/docs/latest/develop/ai/search-and-query/vectors/) + +--- + +## 7. Semantic Caching + +**Impact: MEDIUM** + +LangCache for LLM response caching, distance thresholds, and cache strategies. + +### 7.1 Configure Semantic Cache Properly + +**Impact: MEDIUM (Correct threshold tuning balances hit rate vs accuracy)** + +> **Note:** LangCache is currently in preview on Redis Cloud. Features and behavior may change. + +Tune similarity threshold and cache separation for optimal LangCache results. + +**Correct: Tune similarity threshold for your use case.** + +```python +from langcache import LangCache + +lang_cache = LangCache( + server_url=f"https://{os.getenv('HOST')}", + cache_id=os.getenv("CACHE_ID"), + api_key=os.getenv("API_KEY") +) + +# Stricter matching - fewer false positives (0.95 = very similar) +result = lang_cache.search( + prompt="What is Redis?", + similarity_threshold=0.95 +) + +# Looser matching - higher hit rate (0.8 = somewhat similar) +result = lang_cache.search( + prompt="What is Redis?", + similarity_threshold=0.8 +) +``` + +**Correct: Use separate caches for different use cases.** + +```python +# Create different cache IDs in Redis Cloud for different LLM tasks +support_cache = LangCache( + server_url=server_url, + cache_id="support-cache-id", + api_key=api_key +) + +code_cache = LangCache( + server_url=server_url, + cache_id="code-cache-id", + api_key=api_key +) +``` + +**Incorrect: Using a single cache for all LLM tasks.** + +```python +# All tasks share one cache - responses may not be relevant +result = lang_cache.search(prompt="How do I reset my password?") +# Could return a code snippet if someone asked a similar coding question +``` + +**Best practices:** + +- Start with threshold 0.9, adjust based on your use case + +- Use custom attributes to filter results within a single cache + +- Monitor cache hit rates to evaluate effectiveness + +- Use separate cache IDs for fundamentally different LLM tasks + +Reference: [https://redis.io/docs/latest/develop/ai/langcache/](https://redis.io/docs/latest/develop/ai/langcache/) + +### 7.2 Use LangCache for LLM Response Caching + +**Impact: HIGH (Reduces LLM API costs by 50-90% for similar queries)** + +> **Note:** LangCache is currently in preview on Redis Cloud. Features and behavior may change. + +LangCache is a fully-managed semantic caching service on Redis Cloud that reduces LLM costs and latency. + +**How it works:** + +1. Your app sends a prompt to LangCache via `POST /v1/caches/{cacheId}/entries/search` + +2. LangCache generates an embedding and searches for similar cached responses + +3. If found (cache hit), returns the cached response instantly + +4. If not found (cache miss), your app calls the LLM and stores the response + +**Correct: Use the LangCache Python SDK.** + +```python +from langcache import LangCache +import os + +lang_cache = LangCache( + server_url=f"https://{os.getenv('HOST')}", + cache_id=os.getenv("CACHE_ID"), + api_key=os.getenv("API_KEY") +) + +# Search for cached response +result = lang_cache.search( + prompt="What is Redis?", + similarity_threshold=0.9 +) + +if result: + response = result[0]["response"] +else: + response = llm.generate("What is Redis?") + # Store for future queries + lang_cache.set( + prompt="What is Redis?", + response=response + ) +``` + +**LangCache REST API:** + +```bash +# Search cache +curl -X POST "https://$HOST/v1/caches/$CACHE_ID/entries/search" \ + -H "Authorization: Bearer $API_KEY" \ + -H "Content-Type: application/json" \ + -d '{"prompt": "What is Redis?"}' + +# Store a response +curl -X POST "https://$HOST/v1/caches/$CACHE_ID/entries" \ + -H "Authorization: Bearer $API_KEY" \ + -H "Content-Type: application/json" \ + -d '{"prompt": "What is Redis?", "response": "Redis is an in-memory database..."}' +``` + +**With custom attributes for filtering:** + +```python +# Store with attributes +lang_cache.set( + prompt="What is Redis?", + response="Redis is an in-memory database...", + attributes={"category": "database", "version": "v1"} +) + +# Search with attribute filter +result = lang_cache.search( + prompt="Tell me about Redis", + attributes={"category": "database"}, + similarity_threshold=0.9 +) +``` + +Reference: [https://redis.io/docs/latest/develop/ai/langcache/](https://redis.io/docs/latest/develop/ai/langcache/) + +--- + +## 8. Streams & Pub/Sub + +**Impact: MEDIUM** + +Choosing between Streams and Pub/Sub for messaging patterns. + +### 8.1 Choose Streams vs Pub/Sub Appropriately + +**Impact: MEDIUM (Wrong choice leads to lost messages or unnecessary complexity)** + +Redis supports two messaging approaches for different use cases. + +**Incorrect: Using Pub/Sub when messages must not be lost.** + +```python +# Pub/Sub - messages lost if no subscribers connected +r.publish("orders", json.dumps(order)) # Fire and forget! +``` + +**Correct: Use Streams when message durability matters.** + +```python +# Streams - messages persist and can be replayed +r.xadd("orders:stream", {"order": json.dumps(order)}) + +# Consumer group for reliable processing +r.xreadgroup("workers", "worker-1", {"orders:stream": ">"}, count=10) +r.xack("orders:stream", "workers", message_id) +``` + +| Requirement | Use | +|-------------|-----| +| Real-time notifications, OK to miss messages | Pub/Sub | +| Messages must not be lost | Streams | +| Need to replay/reprocess messages | Streams | +| Multiple workers processing same queue | Streams (consumer groups) | +| Simple broadcast to connected clients | Pub/Sub | +| Event sourcing or audit trail | Streams | + +Reference: [https://redis.io/docs/latest/develop/data-types/streams/](https://redis.io/docs/latest/develop/data-types/streams/) + +--- + +## 9. Clustering & Replication + +**Impact: MEDIUM** + +Hash tags for key colocation, read replicas, and cluster-aware patterns. + +### 9.1 Use Hash Tags for Multi-Key Operations + +**Impact: HIGH (Enables multi-key operations in Redis Cluster)** + +In Redis Cluster, keys are distributed across slots based on their hash. Use hash tags to ensure keys that must be used together in [multi-key operations](https://redis.io/docs/latest/operate/rs/databases/durability-ha/clustering/#multikey-operations) are on the same slot. + +**Correct: Use hash tags for keys used in multi-key operations.** + +**Python** (redis-py):** + +```python +# These keys go to the same slot because {user:1001} is the hash tag +redis.set("{user:1001}:profile", "...") +redis.set("{user:1001}:settings", "...") +redis.set("{user:1001}:cart", "...") + +# Now you can use transactions and pipelines +pipe = redis.pipeline() +pipe.get("{user:1001}:profile") +pipe.get("{user:1001}:settings") +pipe.execute() + +# Multi-key commands also work +redis.lmove("{user:1001}:pending", "{user:1001}:processed", "LEFT", "RIGHT") +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.UnifiedJedis; +import java.util.Set; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + // Hash tags ensure keys go to the same slot + jedis.sadd("{bikes:racing}:france", "bike:1", "bike:2", "bike:3"); + jedis.sadd("{bikes:racing}:usa", "bike:1", "bike:4"); + + // Multi-key operation works because of matching hash tags + Set result = jedis.sdiff("{bikes:racing}:france", "{bikes:racing}:usa"); +} +``` + +**Incorrect: Keys without hash tags that need multi-key operations.** + +**Python** (redis-py):** + +```python +# Bad: These may be on different slots +redis.set("user:1001:profile", "...") # No hash tag +redis.set("user:1001:settings", "...") + +# This will fail in cluster mode +pipe = redis.pipeline() +pipe.get("user:1001:profile") +pipe.get("user:1001:settings") +pipe.execute() # CROSSSLOT error +``` + +**Java** (Jedis):** + +```java +// Bad: No hash tags - keys may be on different slots +jedis.sadd("bikes:racing:france", "bike:1", "bike:2", "bike:3"); +jedis.sadd("bikes:racing:usa", "bike:1", "bike:4"); + +// This will fail in cluster mode with CROSSSLOT error +Set result = jedis.sdiff("bikes:racing:france", "bikes:racing:usa"); +``` + +**Hash tag rules:** + +- Only the part between `{` and `}` is hashed for slot assignment + +- Use meaningful identifiers like `{user:1001}` not just `{1001}` to avoid unrelated keys (e.g., `purchase:{1001}`, `employee:{1001}`) saturating the same slot + +- Use hash tags only where multi-key operations are needed, not as a general habit + +Reference: [https://redis.io/docs/latest/operate/oss_and_stack/reference/cluster-spec/#hash-tags](https://redis.io/docs/latest/operate/oss_and_stack/reference/cluster-spec/#hash-tags) + +### 9.2 Use Read Replicas for Read-Heavy Workloads + +**Impact: MEDIUM (Scales read throughput without adding primary nodes)** + +For read-heavy workloads, distribute reads across replicas to reduce load on primaries. + +**Correct: Configure replica reads in Redis Cluster.** + +```python +from redis.cluster import RedisCluster + +rc = RedisCluster( + host='localhost', + port=6379, + read_from_replicas=True # Distribute reads to replicas +) + +# Writes go to primary +rc.set("key", "value") + +# Reads can be served by replicas (eventually consistent) +value = rc.get("key") +``` + +**Correct: Use replica reads in standalone replication setup.** + +```python +from redis import Redis + +# Connect to primary for writes +primary = Redis(host='primary-host', port=6379) + +# Connect to replica for reads +replica = Redis(host='replica-host', port=6379) + +# Write to primary +primary.set("key", "value") + +# Read from replica (eventually consistent) +value = replica.get("key") +``` + +**Considerations:** + +- Replica reads are eventually consistent + +- Don't read from replicas for data that was just written + +- Use for read-heavy, slightly-stale-OK workloads (caches, analytics, dashboards) + +Reference: [https://redis.io/docs/latest/operate/oss_and_stack/management/replication/](https://redis.io/docs/latest/operate/oss_and_stack/management/replication/) + +--- + +## 10. Security + +**Impact: HIGH** + +Authentication, ACLs, TLS, and network security. + +### 10.1 Always Use Authentication in Production + +**Impact: HIGH (Prevents unauthorized access to your data)** + +Never run Redis without authentication in production environments. + +**Correct: Use password and TLS.** + +**Python** (redis-py):** + +```python +r = redis.Redis( + host='localhost', + port=6379, + password='your-strong-password', + ssl=True, + ssl_cert_reqs='required' +) +``` + +**Java** (Jedis):** + +```java +import redis.clients.jedis.*; +import javax.net.ssl.*; +import java.security.KeyStore; + +// Create SSL context with trust store and key store +KeyStore trustStore = KeyStore.getInstance("jks"); +trustStore.load(new FileInputStream("./truststore.jks"), "password".toCharArray()); + +TrustManagerFactory tmf = TrustManagerFactory.getInstance("X509"); +tmf.init(trustStore); + +SSLContext sslContext = SSLContext.getInstance("TLS"); +sslContext.init(null, tmf.getTrustManagers(), null); + +JedisClientConfig config = DefaultJedisClientConfig.builder() + .ssl(true) + .sslSocketFactory(sslContext.getSocketFactory()) + .user("redisUser") + .password("redisPassword") + .build(); + +JedisPooled jedis = new JedisPooled(new HostAndPort("redis-host", 6379), config); +``` + +**Incorrect: Connecting without authentication.** + +**Python** (redis-py):** + +```python +# Bad: No authentication +r = redis.Redis(host='localhost', port=6379) +``` + +**Java** (Jedis):** + +```java +// Bad: No authentication or TLS +UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379"); +``` + +**Configuration:** + +```python +# redis.conf +requirepass your-strong-password +tls-port 6380 +tls-cert-file /path/to/redis.crt +tls-key-file /path/to/redis.key +``` + +Reference: [https://redis.io/docs/latest/operate/oss_and_stack/management/security/](https://redis.io/docs/latest/operate/oss_and_stack/management/security/) + +### 10.2 Secure Network Access + +**Impact: HIGH (Reduces attack surface and prevents unauthorized access)** + +Restrict network access to Redis to only trusted sources. + +**Correct: Bind to specific interfaces.** + +```python +# redis.conf +bind 127.0.0.1 192.168.1.100 +protected-mode yes +``` + +**Correct: Use firewall rules.** + +```bash +# Allow only application servers +iptables -A INPUT -p tcp --dport 6379 -s 192.168.1.0/24 -j ACCEPT +iptables -A INPUT -p tcp --dport 6379 -j DROP +``` + +**Incorrect: Exposing Redis to the internet.** + +```python +# Bad: Binds to all interfaces +bind 0.0.0.0 +protected-mode no +``` + +**Security checklist:** + +```python +# Disable dangerous commands +rename-command FLUSHALL "" +rename-command DEBUG "" +rename-command CONFIG "" +``` + +- Use TLS for connections + +- Bind to specific interfaces, not `0.0.0.0` + +- Use firewall rules to restrict access + +- Disable dangerous commands in production + +Reference: [https://redis.io/docs/latest/operate/oss_and_stack/management/security/](https://redis.io/docs/latest/operate/oss_and_stack/management/security/) + +### 10.3 Use ACLs for Fine-Grained Access Control + +**Impact: HIGH (Limits blast radius if credentials are compromised)** + +Create users with only the permissions they need (principle of least privilege). + +**Correct: Create specific users with limited permissions.** + +```python +# Read-only user for cache access +ACL SETUSER app_readonly on >password ~cache:* +get +mget +scan + +# Writer that can't run dangerous commands +ACL SETUSER app_writer on >password ~* +@all -@dangerous + +# Admin user (use sparingly) +ACL SETUSER admin on >strong-password ~* +@all +``` + +**Incorrect: Using the default user for everything.** + +```python +# Bad: Single password for all access +requirepass shared-password +``` + +**ACL categories:** + +- `@read` - Read commands + +- `@write` - Write commands + +- `@dangerous` - Commands like FLUSHALL, DEBUG + +- `@admin` - Administrative commands + +Reference: [https://redis.io/docs/latest/operate/oss_and_stack/management/security/acl/](https://redis.io/docs/latest/operate/oss_and_stack/management/security/acl/) + +--- + +## 11. Observability + +**Impact: MEDIUM** + +SLOWLOG, INFO, MEMORY commands, monitoring metrics, and Redis Insight. + +### 11.1 Monitor Key Redis Metrics + +**Impact: MEDIUM (Early detection of performance and capacity issues)** + +Track these metrics to catch issues before they impact users. + +| Metric | What It Tells You | Alert When | +|--------|-------------------|------------| +| `used_memory` | Current memory usage | > 80% of maxmemory | +| `connected_clients` | Number of connections | Sudden spikes or drops | +| `blocked_clients` | Clients waiting on blocking ops | > 0 sustained | +| `instantaneous_ops_per_sec` | Current throughput | Significant drops | +| `keyspace_hits/misses` | Cache hit ratio | Hit ratio < 80% | +| `rejected_connections` | Connection limit issues | > 0 | +| `rdb_last_save_time` | Last persistence snapshot | Too old | + +**Correct: Export metrics to your monitoring system.** + +```python +# Get key metrics +info = redis.info() +print(f"Memory: {info['used_memory_human']}") +print(f"Connections: {info['connected_clients']}") +print(f"Ops/sec: {info['instantaneous_ops_per_sec']}") +print(f"Hit ratio: {info['keyspace_hits'] / (info['keyspace_hits'] + info['keyspace_misses']) * 100:.1f}%") +``` + +**Redis Insight:** + +Use Redis Insight for visual monitoring, query profiling, and debugging. It includes Redis Copilot for natural language queries. + +Reference: [https://redis.io/insight/](https://redis.io/insight/) + +### 11.2 Use Observability Commands for Debugging + +**Impact: MEDIUM (Enables quick diagnosis of performance issues)** + +Redis provides built-in commands for monitoring and debugging. + +**Key commands:** + +```python +# Slow query log - find slow commands +SLOWLOG GET 10 +SLOWLOG LEN +SLOWLOG RESET + +# Server info - comprehensive stats +INFO all +INFO memory +INFO stats +INFO replication +INFO clients + +# Memory analysis +MEMORY DOCTOR +MEMORY STATS +MEMORY USAGE mykey + +# Client connections +CLIENT LIST +CLIENT INFO + +# Index info (RQE) +FT.INFO idx:products +FT.PROFILE idx:products SEARCH QUERY "@name:laptop" +``` + +**Correct: Check SLOWLOG regularly.** + +```python +# Get recent slow queries +slow_queries = redis.slowlog_get(10) +for query in slow_queries: + print(f"Duration: {query['duration']}μs, Command: {query['command']}") +``` + +Reference: [https://redis.io/docs/latest/operate/oss_and_stack/management/optimization/latency/](https://redis.io/docs/latest/operate/oss_and_stack/management/optimization/latency/) + +--- + +## References + +1. [https://redis.io/docs/](https://redis.io/docs/) +2. [https://redis.io/docs/latest/develop/interact/search-and-query/](https://redis.io/docs/latest/develop/interact/search-and-query/) +3. [https://redis.io/docs/latest/develop/clients/redisvl/](https://redis.io/docs/latest/develop/clients/redisvl/) +4. [https://redis.io/docs/latest/develop/ai/langcache/](https://redis.io/docs/latest/develop/ai/langcache/) +5. [https://redis.io/commands/](https://redis.io/commands/) diff --git a/api/.agents/skills/redis-development/README.md b/api/.agents/skills/redis-development/README.md new file mode 100644 index 0000000..4d1579f --- /dev/null +++ b/api/.agents/skills/redis-development/README.md @@ -0,0 +1,124 @@ +# Redis Development + +A structured repository for creating and maintaining Redis development guidelines optimized for agents and LLMs. + + +## Structure + +- `rules/` - Individual rule files (one per rule) + - `_sections.md` - Section metadata (titles, impacts, descriptions) + - `_template.md` - Template for creating new rules + - `_contributing.md` - Contribution guidelines (excluded from build) + - `prefix-description.md` - Individual rule files +- `metadata.json` - Document metadata (version, organization, abstract) +- `AGENTS.md` - Compiled output (generated) +- `SKILL.md` - Skill definition and entry point +- `README.md` - This file + + +## Getting Started + +1. Install dependencies from the repo root: + ```bash + npm install + ``` + +3. Validate rule files: + ```bash + npm run validate + ``` + +4. Build AGENTS.md from rules: + ```bash + npm run build + ``` + + +## Creating a New Rule + +1. Copy `rules/_template.md` to `rules/prefix-description.md` +2. Choose the appropriate area prefix: + - `data-` for Data Structures & Keys + - `ram-` for Memory & Expiration + - `conn-` for Connection & Performance + - `json-` for JSON Documents + - `rqe-` for Redis Query Engine + - `vector-` for Vector Search & RedisVL + - `semantic-cache-` for Semantic Caching + - `stream-` for Streams & Pub/Sub + - `cluster-` for Clustering & Replication + - `security-` for Security + - `observe-` for Observability +3. Fill in the frontmatter and content +4. Ensure you have clear examples with explanations +5. Run `npm run build` (in the build package) to regenerate AGENTS.md + + +## Rule File Structure + +Each rule file should follow this structure: + +```markdown +--- +title: Rule Title Here +impact: MEDIUM +impactDescription: Optional description +tags: tag1, tag2, tag3 +description: Rule Title Here +alwaysApply: true +--- + +## Rule Title Here + +Brief explanation of the rule and why it matters. +``` + +**Incorrect: (description of what's wrong)** + +```python +# Bad code example +``` + +**Correct: (description of what's right)** + +```python +# Good code example +``` + +Optional explanatory text after examples. + +Reference: [Link](https://example.com/) + +## File Naming Convention + +- Files starting with `_` are special (excluded from build) +- Rule files: `prefix-description.md` (e.g., `data-key-naming.md`) +- Section is automatically inferred from filename prefix +- Rules are sorted alphabetically by title within each section + + +## Impact Levels + +- `HIGH` - Significant performance improvements or critical security practices +- `MEDIUM` - Moderate performance improvements or recommended patterns +- `LOW` - Incremental improvements + + +## Scripts + +(Run these from the repo root) + +- `npm run build` - Compile rules into AGENTS.md +- `npm run validate` - Validate all rule files +- `npm run dev` - Build and validate (if configured) + + +## Contributing + +When adding or modifying rules: + +1. Use the correct filename prefix for your section +2. Follow the `_template.md` structure +3. Include clear bad/good examples with explanations +4. Add appropriate tags +5. Run `npm run build` to regenerate AGENTS.md diff --git a/api/.agents/skills/redis-development/SKILL.md b/api/.agents/skills/redis-development/SKILL.md new file mode 100644 index 0000000..2d53100 --- /dev/null +++ b/api/.agents/skills/redis-development/SKILL.md @@ -0,0 +1,121 @@ +--- +name: redis-development +description: Redis performance optimization and best practices. Use this skill when working with Redis data structures, Redis Query Engine (RQE), vector search with RedisVL, semantic caching with LangCache, or optimizing Redis performance. +license: MIT +metadata: + author: redis + version: "1.0.0" +--- + +# Redis Best Practices + +Comprehensive performance optimization guide for Redis, including Redis Query Engine, vector search, and semantic caching. Contains 29 rules across 11 categories, prioritized by impact to guide automated optimization and code generation. + +## When to Apply + +Reference these guidelines when: +- Designing Redis data models and key structures +- Implementing caching, sessions, or real-time features +- Using Redis Query Engine (FT.CREATE, FT.SEARCH, FT.AGGREGATE) +- Building vector search or RAG applications with RedisVL +- Implementing semantic caching with LangCache +- Optimizing Redis performance and memory usage + +## Rule Categories by Priority + +| Priority | Category | Impact | Prefix | +|----------|----------|--------|--------| +| 1 | Data Structures & Keys | HIGH | `data-` | +| 2 | Memory & Expiration | HIGH | `ram-` | +| 3 | Connection & Performance | HIGH | `conn-` | +| 4 | JSON Documents | MEDIUM | `json-` | +| 5 | Redis Query Engine | HIGH | `rqe-` | +| 6 | Vector Search & RedisVL | HIGH | `vector-` | +| 7 | Semantic Caching | MEDIUM | `semantic-cache-` | +| 8 | Streams & Pub/Sub | MEDIUM | `stream-` | +| 9 | Clustering & Replication | MEDIUM | `cluster-` | +| 10 | Security | HIGH | `security-` | +| 11 | Observability | MEDIUM | `observe-` | + +## Quick Reference + +### 1. Data Structures & Keys (HIGH) + +- `data-choose-structure` - Choose the Right Data Structure +- `data-key-naming` - Use Consistent Key Naming Conventions + +### 2. Memory & Expiration (HIGH) + +- `ram-limits` - Configure Memory Limits and Eviction Policies +- `ram-ttl` - Set TTL on Cache Keys + +### 3. Connection & Performance (HIGH) + +- `conn-blocking` - Avoid Slow Commands in Production +- `conn-pipelining` - Use Pipelining for Bulk Operations +- `conn-pooling` - Use Connection Pooling or Multiplexing +- `conn-timeouts` - Configure Connection Timeouts + +### 4. JSON Documents (MEDIUM) + +- `json-partial-updates` - Use JSON Paths for Partial Updates +- `json-vs-hash` - Choose JSON vs Hash Appropriately + +### 5. Redis Query Engine (HIGH) + +- `rqe-dialect` - Use DIALECT 2 for Query Syntax +- `rqe-field-types` - Choose the Correct Field Type +- `rqe-index-creation` - Index Only Fields You Query +- `rqe-index-management` - Manage Indexes for Zero-Downtime Updates +- `rqe-query-optimization` - Write Efficient Queries + +### 6. Vector Search & RedisVL (HIGH) + +- `vector-algorithm-choice` - Choose HNSW vs FLAT Based on Requirements +- `vector-hybrid-search` - Use Hybrid Search for Better Results +- `vector-index-creation` - Configure Vector Indexes Properly +- `vector-rag-pattern` - Implement RAG Pattern Correctly + +### 7. Semantic Caching (MEDIUM) + +- `semantic-cache-best-practices` - Configure Semantic Cache Properly +- `semantic-cache-langcache-usage` - Use LangCache for LLM Response Caching + +### 8. Streams & Pub/Sub (MEDIUM) + +- `stream-choosing-pattern` - Choose Streams vs Pub/Sub Appropriately + +### 9. Clustering & Replication (MEDIUM) + +- `cluster-hash-tags` - Use Hash Tags for Multi-Key Operations +- `cluster-read-replicas` - Use Read Replicas for Read-Heavy Workloads + +### 10. Security (HIGH) + +- `security-acls` - Use ACLs for Fine-Grained Access Control +- `security-auth` - Always Use Authentication in Production +- `security-network` - Secure Network Access + +### 11. Observability (MEDIUM) + +- `observe-commands` - Use Observability Commands for Debugging +- `observe-metrics` - Monitor Key Redis Metrics + +## How to Use + +Read individual rule files for detailed explanations and code examples: + +``` +rules/rqe-index-creation.md +rules/vector-rag-pattern.md +``` + +Each rule file contains: +- Brief explanation of why it matters +- Correct example(s) with explanation +- Either an "Incorrect" example (for anti-patterns that cause real harm) or "When to use / When NOT needed" guidance (for optional features) +- Additional context and references + +## Full Compiled Document + +For the complete guide with all rules expanded: `AGENTS.md` diff --git a/api/.agents/skills/redis-development/assets/logo.png b/api/.agents/skills/redis-development/assets/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..c32b52d4adae41753a4f00e8f19a6ba481198dbd GIT binary patch literal 21458 zcmeIacQ{;M*Efz3LbMPBgOEf^^g3z?LbOCg?>%~E24gTmLi7kC(M9wgL^nFY=p}kL z`e5`j^G@#jdq2;8{eJ(xf4x`cy3U-v&)#dTbN1e6?ax|!?GvV}twu@COin;RK&k%p z$qNDkB60jdLv|g11qx0F5D*YrI4LXZsw*pV=)Q5ccXF{KAmE90{iyI%={bET!XZSL zididmFJ!7n6Pdknmhpkv{Fu}kW^TcEy6xppOZjMsUaYE8YVtSh5KVuv>L43saS-wS zQc%^4LW)q#@^VTHNTHcdNLtXB@b`VS?+kITeVt8?8e4-oLWBeF#3j_)Gx3WBq33W( zF}k*;I_jlLHXz7xrEf~9>bC8A02b{x&(dRTr+O94*HYs_a9`VnRm$|>#(rK#)(63R zD4^ec>5)Zan7v#!odON1;auM>1-ejCIp31dohA8;y2s||Y7G|W(XPDC%)pkQx8Y(h zKHQDw0Y7ZZNWsoWRG)yigO)!^P;}p=4rLX4Ar-!oYklMXn>5ugQsHF!N^4g#T_yJE zgdj+ehJg*L8A;)PmRG1{eyxRM!R=jg>!rGZlW(s~xnt-R5tAQiR+VgTmgv!99s2f! zs~OQX9Yt3(T@{_p0r}Z=R(y-UvNKe-*U}>3!>7pz2t%C+i18^x{K1Sr2neoy3LzlH zpKsxhCmBTlDJ2rmxb{D3nm-jE>nW?NA^prD|)x3_?|sDS$$2SH(J zX=yaskzo1%?U{dw|GV%XLRrB-PySym z@$Y8-*HwHw%ah9r{!gFDlg|qgjS&zi5~x3U{L+VTXO{Gft;+1rea`%QkMDKaskD;w zJSR*j${_AAXd&av;N<+kk=Gp*{xtsbTT>7r3CVp8Rr+f@oE*|SpJ;W1qz|spyyw&Q zDZBEs{=P~`{}Ku>Z~MB7(bCs-X(a`E4nG|v%+Oy~GT#!B^JL)vG(7?*xlhkv=l{;~ zWHf(#o7oxir|3_)3K`Su@H|zD|I`m*p8gLd{awpE-fzs0-u=Ug>9xzXf0*dLdz+c_ z;dLd=KY7c%a>{>6F(Fx_6YL*Kf;6$!oWf)Qa*UuSZ5 zWtTj!mWg;aN=`6)(mqa&kfe_O%g$G=30O{V*Xi=V5G{nf!E-uZ{%}b41qKbz1s@H( zINx>Ro%MRAfjiSlEnNZe>Q>_4;`Plc4{{!%5|XtqT-lF)Cf#;q&Jx9iSz;y-+2H*% z-c*RH1_EYl%GegCXW4*<+MUm|eAw_5wZv;l{MpEAO(#iJNAtb;5v_rCj&lSyVVH|L z8@!bGqHcGB_euu4gU98MNKExcK6SeP0o2!SSt9$KB{Vj!)#n}t)#79Z_kKi9(fEBf z1!n88#Sg>g5{DZwRO|2jj7fTEghu0^$9z^IBzX`{5NID&W%JH{fQ%Odk$29?TT52R zu69Wu=9Z5eRubI3WP&RNs`D)zj@&EuiI~0D%xx`@n~hmeD0yt!8rI^-$bdOuSpe-E1yu^^I~h(8V^^$nPCd z2RYuxC^n`zW+hbW#aL5+WNe=zkp~F=c&b6s1s62t$tGdwOMiz-PceJ7NUk+8Z%Cxrd?ixrgiM7ryH)m)=d-lXVv|1vc`^Tb6uIt4_Q-Fhz|B zLE(dVy}~XDcq&AE!Yp~1mevKkM>=g*5d}{_#818I^Tg!DmhtcRl%C$L84N5Fu&U+d zeN$h+DV2S?A9cqP<2~>Nd|=DU`>3h(bF&LVM0X!wfXV2Ki z^7p{l?t8Qx_Yf=Dy#AMZypI&BP83HLtSk4zpXmXQ_Z;2fI|1$Xj$$IJ_iB5>C3Z3u z@8WMjiSNw|#$d(PTuY^!%IELCj7>bh>Y3v;1wQJ*z=>_fV=-x8EBM7+r`Lg_<|EBY&-d)RoddFdP# zoZ*!bquIl3rTYX5H=>eJZ!K;8}FL{xm4Q>54K}ry?Ae zo`G-c8^)e-o0Vsky-CEX_7Zcl*^jEVoq7(Sbxxd$LgJ8;=~U|Yk4pB zLK?|vFg@l!nARW3o2926osXeR&%xUT=uQ~fu#*hFGjeL0Pv*F=xW;5|Ep;66@ zo}`z#J>56G`RCj&;>W{}L|nKOz#W~$H^rjDv$(!p170OVzoZxD$@b7&#-0ay{t_Sg z4hi*8#rB`R(sWKUel)mOU-?>765wZe1PFstmaNMxQSi>f-M>O?^TX+dM(K@;>1uS9 zVbF$$+10xrhxx~QtU@GLKT!oCPnygc_4x9)<1!zb8nr8#mNaaM@eLP~x}y5f-w%uR z3LnXH>y}^JUliW_u^qo_9<@=d#vazCDpkL6eIgi=i_(dSZDt+G6%h50G@nk?xBN`d zw%m)j#~`=+dANQ)@#MfUvcK;eeqpNm>wz~f84YG2eHc3e zGAWUXtl&AlWPg$&eP0k*`VFZX2iv9N5*I`Vw0}_S9G$G&$_Ar$c;&EPT5`f`6-0B0 z01phh4rl!Ke{m?P?JcRi@a>A`1^9wBSlm9zgmfm{$!7Mb8n2&EA0ADlS}XYSVzv`h z%VakLIr$=CR*ZdCJ!y$nOUyZpDqWScE0FSd79wu5>6J=$zcHE#$zVpzC3oQm{Q1oA z2lkCVZ>U0+A|rm>36xz9C6Ux{+U^*1>R!?EHa=n?sz@gVwAo6VGi^OO#%VU`E6;8T zJcwBI;pN`(n#QR+jZC1-3btegk3XGF#^{|N+<*pm@eOFftm%5FyR<~Ne(%0Dt2x20 zSe7gmykLW?28QDWh3jGskwn{<7cDU>7Yjl=?JQ{9?caDVhd0U1H3}l@qC39(+>sJ zx%a}Zz3`o7v5x0db?^l+3le^UFeWz$bLy}Uoz#Kja7 zA*+O0p_`QxJ^5-Z_){^jLaV$aWA$3?T?G3FZzC$uF02urY!| znk4k{-hHiZ@?SBbIy^Tf#)h)@CC|xuz|jb<8DVU)#j}l{zZGNVtIetVQ6cjyEaq6z zsBr!)8Bsu>D255^CNwS+IQP zBtW|exu8)7wF%7aI%QZ8amdk8_~G4n!p^m#aN8S<18X{XQs)FHE9j|?oak&sGaYT% zoxC>33DVBQgqyuDoC&w`n@DnyR42YE9{2rQmzgj$>uQ4M9rN6m8?1;%`MDHot&I{H zSk?mcbhKB_(?@|+&QJp#@fH)uexxuuFm`{NWJ$P*q=p`zNR1+Cg{1Lv!7Fn^ zt-puR%%LAA-@oTC(mD|%uqA(w);2oYgkSc&ewTe3U3=j};)_>2hh(U>bNU)_BBF@9 zq)=G~_s%ZBoh&|?c7J#Igo$Dg{ilLmXtrMlnPm)lXBo?uxnpheEPB^NSCU^zz!Av^ zQ4P#+LLKy-g3J^VsAK84y6Y+6#H#>f`5N`{noZ~iZ1aW~64m)~^tpxR#xq*o0&G96 zqoEpRgr8Cj`5Eg!L($$|7D{Bpv3HsUtqgWu-I6_h2gA5ynBTJ4fI}8B(x~ufZYO=2 zAx>#-)IohJD3l3v?6)1g%;*H0~r#I~1`**3-3t0OJNEGv(%EmnOt z24*J~eH#5OAeqO zAKIqL5}VP`9~~5;Z5he z9Z_B9xkSr~cpyN{5pYKA#bsBN`y=IFf6dHF%?5TPW)wUJEMqcy+D zN0MFJt3FHAXSFx1^r>3APq+*{2LL}mqmM1$k;Xg@VU`FpaY~3EHrJaGijZw8HygL_}H#&p4af%%bL=*oxOp7%ct#oWhm@!+pUd!7bBfc zDNbAw${r5Ut?oQ8U^&Q}-pgv_ucE%zHN;OmJizk*F&=FjJgMQqY8bs-UJ; zVH?@q0Bn$)7g4q4B<=&0io?L?`Yxx^z+C8B80t4q=uc<(*Wr8%?Tr>vRo{*lc$B|& z&ck%?oV6ZdhP_+nH!9-uYfUR9cO$jWZb*sl{_?K3FPDsHng*EK<8rEVDtmhC*!`W4 z8Y+uftU-;XsWa0dUhz@BGj`R6H!ojm%*#E>wlGUoR!$I7%FCFGmf;iBp2wO&H{+gj zsjk#Hp;LPOXwN^Ve#-SRfUB_W!{#ucJX>Os_X>#gqOrwmL@lQd3cnUVQB$#zkKUs} z6tj^>#u8vI9{$+zGDD8+Kfp$$=N#^oqWDGEIg+U7Y$srlnLhW@nH!ScZb*fR4fM@6 z3m-QyaQ{X{a_}<*8X@U@^Ie!5RZei7zh*AV!tWPTxvySYkcmeoD<@2Rv5Jgx`};FR z!^Q0VU)TlKC19$W*?ew=DKsJ8Yv#U%v7AL@x8-<=re6&1R9cnRaJ^TqtLFL21S08H zqCz$B)z{C%Uc&mAOP^UFRZbvV*aGa35oTf-t&l7p;26XDfMePARg_erVzv9<9|*geKnkdIR2eL412o(x@9NyA^OYGRv4WrF^Sb_2xx^|YT_TLm4x$vW6FHv+S@_S<<4ZJD8sK3#D7B49(Z^aJt(O`V?dx6dzvwqW| z1v6;nsb&>28Y+MP*u&Q^Nnr_dc2svZ+VgD4Na*&DHDgctAwmd9NDO|a0?lhm?xMsQ z0a<}ChN^%So1y(8#>e-78J#?xd+F%a`6!QINtQ}liNeTRhX=;8Arr~Wk8M%2FMLlf ztSJXY2BIZYQuXHCr~=ZG^L}WD(6YLzE0QW)8oKPqQ$>0;AHGVE0AB95;k01>hp|5W zS1^`E>%fD`FK^DJb}m$EXX4+NV$cwagtn>3dYdPFJ$h9;w`|Fl+u-RT>G0+G$Pa(a z^~PhZji~9B-su}EsYH}9D8DmlhE72mM3!c7$@#(d{^@eCt8x_PWWbu_rg+Dz9bz6P z_nIQ-1wDscot7C~NPZ4o%&)d(;*>?wRJKGHG~QkWt2SO|?%x`w(5t(L-GaFGR(>h4 zG^0d~B3pt?8KQ<;4t)n(%75BUP@_6Bid@$vO!fU;fjp|NhPHiRqAVH+JNRuRL0L@# zzxOK>IUUjcsVlb80nUh3T5dPC$%g(A?L_{`nrq$w0~y-W0G#2RVA#elH|qY|4sO?b zr_vTzQS*`3o+PtVwDe@+kz|zYzqXGdp}5PL=Zq##46)5>Z{AuPzd5bF3D+!yW29Va zEiF?G)@Vf~*ov4tjpa%$(kt=cN8RLT`TN3*fK}CL{Lj-{TnOPGG)LpfdO&Pq>#ff63(lQoYPJsa^Zb^lRTrgdeNzUBJ*ZP&<+QWh?*NJ}t`ZN>uC^9(a}Ev+y) zc*6S7s!eN|?vVR{d&5mxU$S)K)W|$h>)H)l?|CEl&?lj4>y|s+({o?1y{yiO!nho9 zA5{&uLTzIJ-|iGIIJ3B?4a>}pTdIYzFA}3{h8Xlb9*OeFHw@0+A4Oo zYlYlLk_oNcemSaArEW2~>3s%wAQw{R5wvs+egQ6sbAg0XUpBCY`^-GOC6- zf;U>3bMOa`*GIyr*wzU*R)hCi5gg2)iEV@EQMz4^>wy=r^1Q2Wj?=VBb<$;ade6W;vI>?mlaIZv0 zpnp1!N9FDF#B{8b+hdw6;V4=r2@cpEw#FUFjY$o!rLWv?TwG=IZG%5gd$V*f9zbQj zyqTVR*js2qn#VUVU@llRvVP3bEqCtNaI8=VyB8QTF#}5f9fxW>g0#=3s-&Hp6m$tm zURIHXurx?`1x1q-nI7xNQJ-fhUp4R6w$1HiNt2E4*HxnU1(1nuF=DP$`KBooGTxCJ zUkxDW=2py+1pvXG-p_S6)x2=_iZJo0Kxr%wDXTHMB5Xs)JE=7SbEJ1L$PIXuAhKKbac zR~e}7yyFXXW7n%xtn{KL=?rW>LPW=gP_IcMr`-!Bo(!DF0_NFAt|uR2L*{vaS@$fY zh0$+jUcb!Suh_w^6y$1|PndLR=XN#m&aysx1n`DnUD)pDXVLAhRjHjDspk3aNI~Tp zQkw*)2m!{3MaS=$twV9u4~{wzfawpvKC3g0Mz{31^O=WJw_H=2#|9pkS$U!p5?7x! z)ti+n)tgHOnFA6zKT|k;{uK@9A>VwJ>Gxx)q z47VqpNi{PpM_s3x5||f+h$Th><)@nmc^dgMt z+7e6~N2y`Ao-1TQg*2*%a-jPcj!h+{;+c4LOILSJE(2A4THqJ7GQyd@GPzec-ybva z9dp)mCc4*&%kgm7iKZyvs!%Es@v`ux7wOutdCs%-qR!ry4~pTmA#O35 ziWLrX2BFQsA!-3vhl{{zw_=G+sfLC~pPkFh^YNp>{;%eCjOYT`kXV1%$F54w>J%nq z@WiZR#gN6Je3IQ^)oaEDK(}&W7?+}@&n(Le--~)or!yqwEE48B%@`tSCSdo#)b$6n zX+W!#Fc_jDiOn%4?M`EhBrDz<9!jOkv(|E$+c^XTu3ZVQwo_GP7iHEgr1kmzCh};vX~p zw-n6+A;?izi9mdz#@=7Fy=3t8s_w<5CwUm<2)yUOByzZ&rz z+Inbe=d#SzF9`w!{2}^68PzAxKYGRyvt`{Lb>)aCQ_j#@F+ub09bazcBf&%_2?8Ria-W7Szy!v_b{{yuI6asm}!{x2vg!cu)8pNh7%jB^fpv71NZ*Aq-7u z)}wlR|17#~Tszg6Lge|5*HqY>=W5E%?XNHWoqZtZ7#%C0-4hv{IoOvCp^*PaXBE_O zdsx+8wFHK?GY69~G!_}rJ< z%~tXHpS@#U3W+Hj%ufN?7dijFS9*JM<#Fz>Am^M6ARyjh8M&m9OTF^;dDr?4)WB0( zJ%t&TNR&Si4yx0tBs5gniNssa1J$PX!}*{mU>JxWXJh}g1T&J?{_ z*It0{fUg%L1|IR9=d%aRIpwb93;r};xOJp`GTa)YC5^6YtQH!t6AQ@oz<1cL{g`++Ajlq{_m+>4cjjNiuRR0G3Wx@vw2w>!l}P)xet45ps7>#rw`#G~fAR|-o2DQ8w){_~ zrANg{^E=O~F5(UaXBm!x5vc{rUw}yho z%0z6Mj9uaVC$l4yJ|oJ30>t*{C5Nl$B&r5)C1oYbr(rllv05HCSoIm?8Q*QH&SZ9I5$4^*kse`eSS}ZV7 zI&F4qO|D*gTjW0Sy!LxJI{bN7XdP_(se+u?Bj$n0-WJFCzB?7FXfhM;&UK22&o{=o z0F?s+YGwyluMmt}a{Z;eH24+nVPQAZ6w3?gcjfGy+Xv!!0N7n}#16rqev*I`q#KOn^7XCH8Kd z{xm7@C+bX<(Fdt$*I&~KB1yg2qbk~$>S;5+FQA|H(|Eb?>grcK9fc)gA6~3*L=7kD z53X-LfmInVL%-I^YFc{&;2@Vw_-3Nsk~UhOT42Mz{58NJhL$71Q(r@~ z9s(NdtX047t>+*USVz^m>Ry6`w2-|rBi;Jw3n;MCa6M)SjH1n0wh7;kc@8hCn6#OY z4H9LO2anpJFp+g{Dc-B4&M><0Hz&w#GxWiOrPgl>$_i$$=B&g9KWAs3#50SFg>VbZf6U8PrSb%>fR=ThcWH^e2MO zRj=&Ebzb3EcHZmiRn6Xn+xGyyBkSakM&SyiZr<~EC>FRH>N$LBQgFM^n`4vM7X5w` z5T714EI#0EbU>X05sSIM$jG=Ta^V$_JSOZHAXY24`9o{hiI3NLu1dzDo~+4Ag^dS5 z?_R?$-k=UWE1c1b`0<7c4TWoLHT<}e7J9!Ij9$~4HNLZNI>{AU(2=^{B%y@x(XHbG(Aabh0!* z(NZ}TVf%jEdC5on8uFntZoC@5hU>1Bp20u}+&eN)oiH6079ZUFUK5$3=2~30n(ll|{o0IjY@t}If;;<``mRS}TxNKWow9-esX96i~_c}7Vp#bOc{(j>7 zisP*0=(Oigk29?*9;=pa0;Na=xDjWa$F7AF6U%3{Nd?J2BoMDsoxdcdPayb(g797> zVtRyF50NIMsLG8>46-vkT|fqYsJ9%)mGwXkkW@;uafGZJa~6YNmBPaaAOVqsPjnZ; zqrLQQn{FI7k;(U}p(4~3NbRfvlq_5O>NR?IpmJ`Xpt`27dyUytE6MJ20f?D_hizP@ zWiiO7q%n5wVek8ARKJaZq7Hw*vRH;7AR6vl0zOwebD!0m&gBwi+>l%^_3@svyAOY? zCjm0Irc6s0pO~e9#x1?){4dMu8R(10#UD8|}s9{ZfyPze72ISqyq+4FYVY@7uJL&SF@=KPg^B9J+Ps$Yq6xS1gg5`Iw9X4M<{VBO5wpHuLk`F3%>a`oqJnvB6?VFL|n)m+D z``XC%*TfWnxTtu!D)eV}j3f?7r9=P==u?14<_NvOxRkY|>g;_flGF3Aw$hFG`$p; z_8*T!#zI8r6vKGTVD0$6Q`76%-|H^Hlr{jRNZUbS4+KY!cuC`f9G;&8TxzCABkROd za)L1qgaQWMWo{wKN0cXFfjViSilAoH!I&LHOM0PucEr8T>brp}PFz*J zCf={gpprfCHSXbAJyROX-EU*@(|afm)>5%YF@}#C934=LG43bso9~Km^kdCYUY$@H z>)HVZt{`r zPQuxdppC|wBvd2l+0ViN@(10;w^Gd$p5@jbWpZ&f?6YTCHL^v1^yoUsLVTI{QCSyv zs+(cNCjh`HZGWoQl+$yJ07z0_m8kWMwbiE`# zHi(a7_h8GJ%>;OR?ppg%z6n`b#lw@j#i)`J8SQGHd|E^w%hj!CD8&}Df@B$zIWRi; zD^3O@QxQa%NH`*ll~JNh91x*wb<`}>4rpg|JLeoY`Y~qxk_*TfM2ZqKr;=SiP5Ur( zrzE9rrS%w=a1*ki_hJg}yW4$1NrR+%7Y0NBS}Z@?O2zg|dy-71U3kqo=Z=8({Z%nF zOJRdk8w*vng_kYD#lI@qa3vT^+mB9VTcyqlJ5Q3o&F=m%(F`mW-MTp*j-VT_b%u_` z(kZj>r-;-Xb_Ld|#iljVs++!3I&ZW@XfN4*@3dQKF z7mqs)mV4{!2&MPpNMw&Eq`mug$W^T)g!xZPO|?|-LEs7rrBr&Bg_0Bx_B1uV*Jozszn z@oQz^dk3Vi=`8)Px&PQ^F!j!G0{7^v-LQ|bYzeY*>+mp-R2AYBS`Y63EH37v+Qo48 zFrx_eF62#FhIdI>XT7k-UdlumRk%F4Xh7OtXOm#{e9}aTDa39jWnX6i(8BE2hADRo zg@J_qJCnj5v%({3Kh^K%Vf#`T14EG4zRDLZTgHdzjyS!He%qDy@B;8fF3rZgE36$v zW9t@~ue}+5?hFw!>ml-x_;GLMrh=hfX984zqE>g}@igXfY9Gta6*YYQjMp!NNS}DA zdQds=ISq$P5pFZyD{3hAv8MEusC*;#kee58ogcLdi6QhFHs-xRDD7|O715V*U=QNj zM?ZTcH@{zFZGS8^d6ETvXFXK_eJ%tb=dz80?FlirrKZ%1M2J70L7t79SjRsYA9yKk zX!36M)pW>wNpSu^C)DGHxLv^;%f9p2>L$5t*Xlw+p&#btrLRa!pZUL;zBV4_Nfiua zDz~1`@(K(Q>IKtDQoB~2rqC*tbX%9+WvKNA{x19Ei$2P@^7IB%lU#o&l;J&2m|mr$ zP)^MB$nvp}367d`P-wV$xA|epjN^z=7M*_l`d+Ou%BFmGVaU*Cu%sgIz>n+dW-1%k z9Zhf7S_EYyS4n@0Y`G6fCoSc0ffVZ6@wo2f#q`vfZJF;HxTK)O<#TY zcT>H@VE)aSqdKz%8+x1G;)8KweKm5KM*Bka$oOPw?pOWPVGQy299XX$*6gQn*3N3jfl>-U!%gxts%5J2mX@7k7ff^_l!!SyzL{{r=-;@+A zeS4e?j8XsET)zH|7XbCsxM(Nq;-Vg7aUt`ZNqK|2h6!4z>s7csTlYDAnUcnaO~vAn zV%PN?*spx#sy)|zRG&+No3QQHfGRBp(Q<(TZlj)64*3>C&pJC{*^ zL!5@Dfx(1xR|3E#X90YeT?ik?=`$BOeGw6d`#Qvj=X1+m_ZOmUWfPmJtSAZYci->8 zfIE99y0curP@XqOl(n{wztYnMe|B-}dP%WZn;u=Q4^=IY*!1quw|=~LXh_4FZl_i=IRPctLv3egBpD!ay?< z$|J-2; z)bDC;IOABVrEB;tn#_*k?gO@EF)J*fA71vgywaAGudn)^X5W)_JwI>A)f~gHZ9q?% z0nA*qGP@~ij(eG+?8snBX78!{**F=eak-)p)vRwRrV>b{-wCh!o&CrxAnY!z)g9wd z9j=^=Je$n~irI@v?OvOf`ZIw3UMa`5G z&6eY7yKjp$t-N-0|L&wpN<)81J*u$qoR`vo*z;)M>SXPRLb6k7>dwhZCOn9aG?lQl zi%rz!M`xN09%Olqn(@WN3z583Q~PL9e2xpv;o;l{`08r%nu;Feo_h_b){?0bhe*1_ zp?@HKIa)O(Q|;;}*4JS1p>_6AgTW^#lKL4gmsq9*s}pMhBx(f5LCwHF@2^Iszunrm zHQb~ov+!JXYORfWb+)iP^Xk!i<#&YvZK9bwVPs-4@}F`oZ|F)*99{I@ww zqEkq`r(^O+2}yoN=Z65S?t!*Y^HCU7N>?yQ!R~8hOyl-ae6$>AKnzy)&NKI~YU-BF z?4GLm^i6%Iv>9|O@-pavc}pV7)Prj$+2Ec;ki)F$8(`!7+{fI=5Qbg}ctsZi%sAEaM(yLQ^DPKBri_anF0q9JC{aW(K+gsRhLCgCAp;9eoztPgroP=y42Owf{FR zWTS9@Fi^b5SZAX^fnp?a<{`K!mr3SkHL=H@i2{|%-XZg@r&1GYwxS>f$iy5EsZ1I|B4j3il9v<1c1vfyLc$! z<7R5aJ#<_i(GWE%b}kFzhvUnU$1Toj{^P$Jj`x0uK+ac}G3#r4=%`}^j{b2eZIRkV zm4Bo9pEyA=eK`#hJ6bMpuab*}7bEGHJlXuEjtDck^y~2R^>(m`)asgKe*Se``2~Ai^bl>?&J? z$zCF|xW)RrbVhnpRyh45* zlOE>usSCp|eTK@U)&Edb$GmP}TKx&X5(HW_pf{r{D_7nvA7eZ+c^0 z6cwJ)gKmL$b`y9DB(G`Dov1WBfLA^zzp$PiHw2_m!=6Uj`$y^<`Oga9#>b}d!UF~S zS=M)|!EDd*u3w-kgDIQmQ|}YICC`a=F|zcVU0caKfL(QAIg#ZwsV<@Ag%<`n0uNfF z3zZ>ssAqSY7*icUAFRmV6wMfY*DSM?qcpI(l2eUUUC!tikX6X;73@D!uLl%vlF6I(h+`YM^%UjaD3rZ|$vH2&Mv!wYTaCOeRU{jObW8-Sk0 zS8HI?i*`ttc3z2~6>6Q}peEW>rtYNbCq;D(BTJ7kg{tt%5Dv)YbI-!yFHam4F@Vhfx|X|WS_iUX(3A-@eO&*5q9 zeae;Ww^N_HI@B9GHOJ8(HE}j%$9t8w(R^ATN^}hmQEoE%49m-b)^B=yU#NCIy`Gqv z=I0`(r}j|Wmb`GxEKMydNXWABrXG=Us>ZwIR|0;3`Yg`Ks8wmn`oTLt*p|Mu+0ZYn z%Vv-84em)v!0OLB$=qlquKUwi_eaTD)udbwSj}u}n4E7URH(Xb$};v0=zEcFvjjeH zQ8g}XP0-CMZ~7oYXG?{BsUhl6NtJv#bB$IC0qI$fN}7=rb^CO7G{jz8SW)6%a^t>G z*I|CIX1cD!_Al+5fti^&pNe^pk%o%b2^Be6mY!$_1+)2YN^{-pHc>q;yu~#y{3TI^ zFW3UEld2YdxwSyYza(2lt?dQux6`n-aUv(Lk3fKQUk1)R3! z`thXb&SsZ3pxrsS3QEfJM(_j|?3gc4c-X$EQIPy`K-Kz@u3d*(;+ZV~WOAy=8Z-7IZV-FuTf zi_wKiC8KTeLBNFB#Y}#z_Uhe*A*)-Q>u-Or1&C3P=gmN1Y=mbmCkLs?o<+VMb2ASt zI})}K$>OfTM|U6gH6FIgc8nA>?m_x561FBPMuhb|TI1mS#6!;G`$+h&vLxdEJm~e0Sq-iN_!6UF#aB!=PUA; z`qJ03z}d-hJ~cYCFoRPU;ZjutU61gEU#(5Fe5NLF@Dt5Ox5HBm5@i0c(byZ3cE}4N>dq&{`ECPu*D$VFn9Y;7u)s$@pP*$ZC6ANuvWm>& zQTTw{3Iob)HN!6wSd0yiH?L7V@glZrm6m+5I=tdG0i@;lz4Vq=a`$-7`E3U4!)EM} zc)9j=Ud{auoePsfrDjyL?u;VKD_34-~15EB`;MLvV-z; zODPQ{%6cd?Et*`yC>MiQ9JegRZ1KU`1(EA32H}!N_~<3xbcsoe?>)Y8i!u1%tctGO zVUlkkSnu&4!+Cc-jlHEXF_wjhLQni=HMLe)JNHXfgLSft8=_nZThM{t5eK{hFdWb1 zYJC_{I;8Fe+KnnZ!wp=tG;j&&*#-Zeu*06|`L;>$Hf|J8CV_UrO;Mos*eaDs*765o zh!bCrim@vyO{dA#1Lj1pmXEZgN_Vn2OpQ9W=(Otqv`#7fr5((l&AfFsZac5d-)MEp z4W8gyuNX|Bl547sU5i4)nS1|K^3 zkt(n2RZauep?}@+6qrhb6Yo!u-0Orl#A4X9biqFoyK=o%zkd0|wDMFhi%$yio>AS& zSkuX+!N|P&purF%s_Kq zwbhq;p7K0;GO0g7=K_i9-tAG46)ZKoi(gC`3`E#xtvAmns5MH3W_r*qvT!rZy(l#m z_wF}>YCRPkktyy^*lXw$;cYwu@D`VQ>G1$s0x=6!H$ai+a!MEB1acpOz500;Hl4p$ z&Q33}CL+0<{RqCuHFYV$qVIBT@}`h=ME0e?z076Cbr$*tc5%1YO`3H_Br|n%)!#*fm@>5*`zLW_%}+@svsMb^u&g_EuiyD^hC1bwWSNz(C(lnRgwpPWdbOVk z15^k9#K!;GJEyLq89uTQ<##pl{jh-YfKz2xrR^{-^xvPS5XaYet$4^wud7Ii-|c6) zO=7Lxto$eWC;#OX?~xC)2>tx8Lj0dnv#yQkzpE++a^iykR9zf;WL$&+69FD&?ru1N3;yMJN9|3k1~0Y|v7y3qB=1Jd{B>y7Fv L+D}T9EQ9_JZ%0%n literal 0 HcmV?d00001 diff --git a/api/.agents/skills/redis-development/rules/_contributing.md b/api/.agents/skills/redis-development/rules/_contributing.md new file mode 100644 index 0000000..2e0e6ee --- /dev/null +++ b/api/.agents/skills/redis-development/rules/_contributing.md @@ -0,0 +1,97 @@ +# Writing Guidelines for Redis Rules + +Guidelines for creating effective Redis best practice rules for AI agents and LLMs. + +## Key Principles + +### 1. Concrete Patterns + +Show exact code transformations. Avoid abstract advice. + +**Good:** "Use `SCAN` instead of `KEYS *` for iteration" +**Bad:** "Design efficient queries" + +### 2. Problem-First Structure + +Show the correct pattern first, then the incorrect one. This helps agents understand what to do and what to avoid. + +```markdown +**Correct:** Description of good approach. + +[good example] + +**Incorrect:** Description of problematic approach. + +[bad example] +``` + +### 3. Practical Impact + +Include specific benefits. Helps agents prioritize. + +**Good:** "10x faster", "50% less memory", "Eliminates blocking" +**Bad:** "Faster", "Better", "More efficient" + +### 4. Complete Examples + +Examples should be runnable or close to it. + +```python +import redis + +# Include setup when needed for clarity +pool = redis.ConnectionPool(host='localhost', max_connections=50) +r = redis.Redis(connection_pool=pool) + +# Now show the pattern +result = r.get('user:1001') +``` + +### 5. Semantic Naming + +Use meaningful names. Names carry intent for LLMs. + +**Good:** `user:1001:profile`, `order:items`, `cache:api:response` +**Bad:** `key1`, `mykey`, `data` + +--- + +## Code Standards + +### Language Tags + +- `python` - Python examples (preferred for Redis) +- `javascript` - Node.js examples +- `bash` - Redis CLI commands +- (none) - Redis commands without language wrapper + +### Comments + +- Explain _why_, not _what_ +- Highlight performance implications +- Point out common mistakes + +--- + +## Impact Levels + +| Level | Improvement | Examples | +|-------|-------------|----------| +| **HIGH** | 5-100x | Missing indexes, connection issues, blocking commands | +| **MEDIUM** | 2-5x | Suboptimal data structures, missing TTL | +| **LOW** | Incremental | Advanced patterns, edge cases | + +--- + +## Review Checklist + +Before submitting a rule: + +- [ ] Title is clear and action-oriented +- [ ] Impact level matches the benefit +- [ ] impactDescription includes quantification +- [ ] Has at least 1 **Correct** example +- [ ] Has at least 1 **Incorrect** example +- [ ] Code uses semantic naming +- [ ] Comments explain _why_ +- [ ] Reference links included diff --git a/api/.agents/skills/redis-development/rules/_sections.md b/api/.agents/skills/redis-development/rules/_sections.md new file mode 100644 index 0000000..416f0fe --- /dev/null +++ b/api/.agents/skills/redis-development/rules/_sections.md @@ -0,0 +1,50 @@ +# Section Definitions + +This file defines the rule categories for Redis best practices. Rules are automatically assigned to sections based on their filename prefix. + +--- + +## 1. Data Structures & Keys (data) +**Impact:** HIGH +**Description:** Choosing the right Redis data type and key naming conventions. Foundation for efficient Redis usage. + +## 2. Memory & Expiration (ram) +**Impact:** HIGH +**Description:** Memory limits, eviction policies, TTL strategies, and memory optimization techniques. + +## 3. Connection & Performance (conn) +**Impact:** HIGH +**Description:** Connection pooling, pipelining, timeouts, and avoiding blocking commands. + +## 4. JSON Documents (json) +**Impact:** MEDIUM +**Description:** Using Redis JSON for nested structures, partial updates, and integration with RQE. + +## 5. Redis Query Engine (rqe) +**Impact:** HIGH +**Description:** FT.CREATE, FT.SEARCH, FT.AGGREGATE, index design, field types, and query optimization. + +## 6. Vector Search & RedisVL (vector) +**Impact:** HIGH +**Description:** Vector indexes, HNSW vs FLAT, hybrid search, and RAG patterns with RedisVL. + +## 7. Semantic Caching (semantic-cache) +**Impact:** MEDIUM +**Description:** LangCache for LLM response caching, distance thresholds, and cache strategies. + +## 8. Streams & Pub/Sub (stream) +**Impact:** MEDIUM +**Description:** Choosing between Streams and Pub/Sub for messaging patterns. + +## 9. Clustering & Replication (cluster) +**Impact:** MEDIUM +**Description:** Hash tags for key colocation, read replicas, and cluster-aware patterns. + +## 10. Security (security) +**Impact:** HIGH +**Description:** Authentication, ACLs, TLS, and network security. + +## 11. Observability (observe) +**Impact:** MEDIUM +**Description:** SLOWLOG, INFO, MEMORY commands, monitoring metrics, and Redis Insight. + diff --git a/api/.agents/skills/redis-development/rules/_template.md b/api/.agents/skills/redis-development/rules/_template.md new file mode 100644 index 0000000..c87b40a --- /dev/null +++ b/api/.agents/skills/redis-development/rules/_template.md @@ -0,0 +1,52 @@ +--- +title: Clear, Action-Oriented Title (e.g., "Use Connection Pooling") +impact: MEDIUM +impactDescription: Brief quantified benefit (e.g., "Reduces connection overhead by 10x") +tags: relevant, keywords, here +description: Clear, Action-Oriented Title (e.g., "Use Connection Pooling") +alwaysApply: true +--- + +## [Rule Title] + +[1-2 sentence explanation of the problem and why it matters. Focus on practical impact.] + +**Correct:** Description of the good approach. + +```python +# Comment explaining why this is better +pool = ConnectionPool(host='localhost', max_connections=50) +redis = Redis(connection_pool=pool) # Reuse connections +result = redis.get('key') +``` + +--- +### Choose ONE of the following patterns: + +#### Pattern A: "Incorrect" (when alternative causes real harm) +Use when the alternative causes race conditions, security issues, crashes, or significant performance problems. + +**Incorrect:** Description of the problematic approach. + +```python +# Comment explaining what makes this problematic +redis = Redis(host='localhost') # New connection per request - 10x overhead +result = redis.get('key') +``` + +#### Pattern B: "When to use" (for feature introductions) +Use when not using the feature is a valid choice for many use cases. + +**When to use:** +- Scenario where this approach is beneficial +- Another scenario where it helps + +**When NOT needed:** +- Scenario where the simpler approach is fine +- Another scenario where this adds unnecessary complexity + +--- + +[Optional: Additional context, edge cases, or trade-offs] + +Reference: [Redis Docs](https://redis.io/docs/) diff --git a/api/.agents/skills/redis-development/rules/cluster-hash-tags.md b/api/.agents/skills/redis-development/rules/cluster-hash-tags.md new file mode 100644 index 0000000..5902fca --- /dev/null +++ b/api/.agents/skills/redis-development/rules/cluster-hash-tags.md @@ -0,0 +1,78 @@ +--- +title: Use Hash Tags for Multi-Key Operations +impact: HIGH +impactDescription: Enables multi-key operations in Redis Cluster +tags: cluster, hash-tags, keys, sharding, multi-key +description: Use Hash Tags for Multi-Key Operations +alwaysApply: true +--- + +## Use Hash Tags for Multi-Key Operations + +In Redis Cluster, keys are distributed across slots based on their hash. Use hash tags to ensure keys that must be used together in [multi-key operations](https://redis.io/docs/latest/operate/rs/databases/durability-ha/clustering/#multikey-operations) are on the same slot. + +**Correct:** Use hash tags for keys used in multi-key operations. + +**Python** (redis-py): +```python +# These keys go to the same slot because {user:1001} is the hash tag +redis.set("{user:1001}:profile", "...") +redis.set("{user:1001}:settings", "...") +redis.set("{user:1001}:cart", "...") + +# Now you can use transactions and pipelines +pipe = redis.pipeline() +pipe.get("{user:1001}:profile") +pipe.get("{user:1001}:settings") +pipe.execute() + +# Multi-key commands also work +redis.lmove("{user:1001}:pending", "{user:1001}:processed", "LEFT", "RIGHT") +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.UnifiedJedis; +import java.util.Set; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + // Hash tags ensure keys go to the same slot + jedis.sadd("{bikes:racing}:france", "bike:1", "bike:2", "bike:3"); + jedis.sadd("{bikes:racing}:usa", "bike:1", "bike:4"); + + // Multi-key operation works because of matching hash tags + Set result = jedis.sdiff("{bikes:racing}:france", "{bikes:racing}:usa"); +} +``` + +**Incorrect:** Keys without hash tags that need multi-key operations. + +**Python** (redis-py): +```python +# Bad: These may be on different slots +redis.set("user:1001:profile", "...") # No hash tag +redis.set("user:1001:settings", "...") + +# This will fail in cluster mode +pipe = redis.pipeline() +pipe.get("user:1001:profile") +pipe.get("user:1001:settings") +pipe.execute() # CROSSSLOT error +``` + +**Java** (Jedis): +```java +// Bad: No hash tags - keys may be on different slots +jedis.sadd("bikes:racing:france", "bike:1", "bike:2", "bike:3"); +jedis.sadd("bikes:racing:usa", "bike:1", "bike:4"); + +// This will fail in cluster mode with CROSSSLOT error +Set result = jedis.sdiff("bikes:racing:france", "bikes:racing:usa"); +``` + +**Hash tag rules:** +- Only the part between `{` and `}` is hashed for slot assignment +- Use meaningful identifiers like `{user:1001}` not just `{1001}` to avoid unrelated keys (e.g., `purchase:{1001}`, `employee:{1001}`) saturating the same slot +- Use hash tags only where multi-key operations are needed, not as a general habit + +Reference: [Redis Cluster Key Distribution](https://redis.io/docs/latest/operate/oss_and_stack/reference/cluster-spec/#hash-tags) diff --git a/api/.agents/skills/redis-development/rules/cluster-read-replicas.md b/api/.agents/skills/redis-development/rules/cluster-read-replicas.md new file mode 100644 index 0000000..35b2458 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/cluster-read-replicas.md @@ -0,0 +1,55 @@ +--- +title: Use Read Replicas for Read-Heavy Workloads +impact: MEDIUM +impactDescription: Scales read throughput without adding primary nodes +tags: cluster, replicas, read-scaling, high-availability +description: Use Read Replicas for Read-Heavy Workloads +alwaysApply: true +--- + +## Use Read Replicas for Read-Heavy Workloads + +For read-heavy workloads, distribute reads across replicas to reduce load on primaries. + +**Correct:** Configure replica reads in Redis Cluster. + +```python +from redis.cluster import RedisCluster + +rc = RedisCluster( + host='localhost', + port=6379, + read_from_replicas=True # Distribute reads to replicas +) + +# Writes go to primary +rc.set("key", "value") + +# Reads can be served by replicas (eventually consistent) +value = rc.get("key") +``` + +**Correct:** Use replica reads in standalone replication setup. + +```python +from redis import Redis + +# Connect to primary for writes +primary = Redis(host='primary-host', port=6379) + +# Connect to replica for reads +replica = Redis(host='replica-host', port=6379) + +# Write to primary +primary.set("key", "value") + +# Read from replica (eventually consistent) +value = replica.get("key") +``` + +**Considerations:** +- Replica reads are eventually consistent +- Don't read from replicas for data that was just written +- Use for read-heavy, slightly-stale-OK workloads (caches, analytics, dashboards) + +Reference: [Redis Replication](https://redis.io/docs/latest/operate/oss_and_stack/management/replication/) diff --git a/api/.agents/skills/redis-development/rules/conn-blocking.md b/api/.agents/skills/redis-development/rules/conn-blocking.md new file mode 100644 index 0000000..34e6379 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/conn-blocking.md @@ -0,0 +1,75 @@ +--- +title: Avoid Slow Commands in Production +impact: HIGH +impactDescription: Prevents Redis from becoming unresponsive +tags: slow-commands, keys, scan, performance +description: Avoid Slow Commands in Production +alwaysApply: true +--- + +## Avoid Slow Commands in Production + +Some Redis commands are slow because they scan large datasets. Use incremental alternatives to avoid blocking the server. + +| Avoid | Use Instead | +|-------|-------------| +| `KEYS *` | `SCAN` with cursor | +| `SMEMBERS` on large sets | `SSCAN` | +| `HGETALL` on large hashes | `HSCAN` | +| `LRANGE 0 -1` on large lists | Paginate with `LRANGE 0 100` | + +**Correct:** Use SCAN for iteration. + +**Python** (redis-py): +```python +# Good: Non-blocking iteration +cursor = 0 +while True: + cursor, keys = redis.scan(cursor, match="user:*", count=100) + for key in keys: + process(key) + if cursor == 0: + break +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.ScanIteration; +import redis.clients.jedis.UnifiedJedis; +import java.util.List; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + // ScanIteration manages the cursor automatically + ScanIteration scan = jedis.scanIteration(10, "user:*", "hash"); + + while (!scan.isIterationCompleted()) { + List result = scan.nextBatch().getResult(); + for (String key : result) { + process(key); + } + } +} +``` + +**Incorrect:** Using KEYS in production. + +**Python** (redis-py): +```python +# Bad: Scans all keys, slow on large datasets +keys = redis.keys("user:*") +``` + +**Java** (Jedis): +```java +// Bad: Scans all keys, blocks the server +Set result = jedis.keys("*"); +``` + +**Note:** Truly blocking commands (like `BLPOP`, `BRPOP`, `BLMOVE`) that wait indefinitely for data are appropriate for some use cases like job queues, but should be used with timeouts. + +```python +# Blocking pop with timeout - appropriate for queue consumers +result = redis.blpop("task_queue", timeout=5) +``` + +Reference: [Redis SCAN](https://redis.io/docs/latest/commands/scan/) diff --git a/api/.agents/skills/redis-development/rules/conn-client-cache.md b/api/.agents/skills/redis-development/rules/conn-client-cache.md new file mode 100644 index 0000000..df2c7a6 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/conn-client-cache.md @@ -0,0 +1,70 @@ +--- +title: Use Client-Side Caching for Frequently Read Data +impact: HIGH +impactDescription: Reduces network round-trips for repeated reads +tags: caching, performance, client-side, tracking +description: Use Client-Side Caching for Frequently Read Data +alwaysApply: true +--- + +## Use Client-Side Caching for Frequently Read Data + +Use a connection with client-side caching enabled for any data that will be read frequently but written only occasionally. Client-side caching avoids contacting the server for repeated access to data that has recently been read, reducing network traffic and improving performance. + +**Correct:** Enable client-side caching with RESP3 protocol for frequently accessed data. + +**Python** (redis-py): +```python +import redis + +# Enable client-side caching with RESP3 +client = redis.Redis( + host='localhost', + port=6379, + protocol=3, # RESP3 required for client-side caching + cache_config=redis.CacheConfig(max_size=1000) +) + +# Cached reads avoid server round-trips +value = client.get("frequently:read:key") +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.DefaultJedisClientConfig; +import redis.clients.jedis.UnifiedJedis; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.CacheConfig; + +HostAndPort endpoint = new HostAndPort("localhost", 6379); + +DefaultJedisClientConfig config = DefaultJedisClientConfig + .builder() + .password("secretPassword") + .protocol(RedisProtocol.RESP3) + .build(); + +CacheConfig cacheConfig = CacheConfig.builder().maxSize(1000).build(); + +UnifiedJedis client = new UnifiedJedis(endpoint, config, cacheConfig); +``` + +**When to use:** +- Configuration data read frequently, updated rarely +- User session data accessed on every request +- Feature flags or settings checked repeatedly +- Any read-heavy workload with low write frequency + +**When NOT needed:** +- Data that changes frequently (cache invalidation overhead outweighs benefits) +- Write-heavy workloads +- Simple applications where network latency is not a bottleneck +- When you need guaranteed real-time consistency + +**Trade-offs:** +- Adds memory overhead on the client +- Requires RESP3 protocol +- Cache invalidation adds complexity for frequently changing data + +Reference: [Client-side caching](https://redis.io/docs/latest/develop/clients/client-side-caching/) + diff --git a/api/.agents/skills/redis-development/rules/conn-pipelining.md b/api/.agents/skills/redis-development/rules/conn-pipelining.md new file mode 100644 index 0000000..ba27f6b --- /dev/null +++ b/api/.agents/skills/redis-development/rules/conn-pipelining.md @@ -0,0 +1,58 @@ +--- +title: Use Pipelining for Bulk Operations +impact: HIGH +impactDescription: Reduces round trips, 5-10x faster for batch operations +tags: pipelining, batch, performance, round-trips +description: Use Pipelining for Bulk Operations +alwaysApply: true +--- + +## Use Pipelining for Bulk Operations + +Batch multiple commands into a single round trip to reduce network latency. + +**Correct:** Use pipeline for multiple commands. + +**Python** (redis-py): +```python +# Good: Single round trip for multiple commands +pipe = redis.pipeline() +for user_id in user_ids: + pipe.get(f"user:{user_id}") +results = pipe.execute() +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.Pipeline; + +// Good: Buffer commands and send as single batch +Pipeline pipe = (Pipeline) jedis.pipelined(); + +pipe.set("person:1:name", "Alex"); +pipe.set("person:1:rank", "Captain"); +pipe.set("person:1:serial", "AB1234"); + +pipe.sync(); +``` + +**Incorrect:** Sequential commands in a loop. + +**Python** (redis-py): +```python +# Bad: N round trips +results = [] +for user_id in user_ids: + results.append(redis.get(f"user:{user_id}")) +``` + +**Java** (Jedis): +```java +// Bad: 3 separate round trips +jedis.set("person:1:name", "Alex"); +jedis.set("person:1:rank", "Captain"); +jedis.set("person:1:serial", "AB1234"); +``` + +Reference: [Redis Pipelining](https://redis.io/docs/latest/develop/use/pipelining/) + diff --git a/api/.agents/skills/redis-development/rules/conn-pooling.md b/api/.agents/skills/redis-development/rules/conn-pooling.md new file mode 100644 index 0000000..c32e82f --- /dev/null +++ b/api/.agents/skills/redis-development/rules/conn-pooling.md @@ -0,0 +1,71 @@ +--- +title: Use Connection Pooling or Multiplexing +impact: HIGH +impactDescription: Reduces connection overhead by 10x or more +tags: connections, pooling, multiplexing, performance +description: Use Connection Pooling or Multiplexing +alwaysApply: true +--- + +## Use Connection Pooling or Multiplexing + +Reuse connections via a pool or multiplexing instead of creating new connections per request. + +**Correct:** Use a connection pool. + +**Python** (redis-py): +```python +import redis + +# Good: Connection pool - reuses existing connections +pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50) +r = redis.Redis(connection_pool=pool) +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.JedisPooled; + +// JedisPooled manages a connection pool internally +try (JedisPooled jedis = new JedisPooled("redis://localhost:6379")) { + jedis.set("testKey", "testValue"); +} +``` + +**Correct:** Use multiplexing (Lettuce, NRedisStack). + +```java +// Lettuce uses multiplexing by default - single connection handles all traffic +RedisClient client = RedisClient.create("redis://localhost:6379"); +StatefulRedisConnection connection = client.connect(); + +// All commands share the single connection efficiently +connection.sync().set("key", "value"); +``` + +**Incorrect:** Creating new connections per request. + +**Python** (redis-py): +```python +# Bad: New connection every time +def get_user(user_id): + r = redis.Redis(host='localhost', port=6379) # Don't do this + return r.get(f"user:{user_id}") +``` + +**Java** (Jedis): +```java +// Bad: Creating new client per request +public String getUser(String userId) { + try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + return jedis.get("user:" + userId); // Don't do this + } +} +``` + +**Pooling vs Multiplexing:** +- **Pooling**: Multiple connections shared across requests (redis-py, Jedis, go-redis) +- **Multiplexing**: Single connection handles all traffic (NRedisStack, Lettuce) +- Multiplexing cannot support blocking commands (BLPOP, etc.) as they would stall all callers + +Reference: [Connection Pools and Multiplexing](https://redis.io/docs/latest/develop/clients/pools-and-muxing/) diff --git a/api/.agents/skills/redis-development/rules/conn-timeouts.md b/api/.agents/skills/redis-development/rules/conn-timeouts.md new file mode 100644 index 0000000..84a6c01 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/conn-timeouts.md @@ -0,0 +1,41 @@ +--- +title: Configure Connection Timeouts +impact: MEDIUM +impactDescription: Improves connection resilience and failure recovery +tags: timeouts, connections, reliability +description: Configure Connection Timeouts +alwaysApply: true +--- + +## Configure Connection Timeouts + +Configure appropriate timeout values to improve your application's connection resilience. While most Redis clients set default timeouts, choosing well-tuned values based on your application's usage patterns leads to better failure recovery. + +**Correct:** Set timeouts based on your application needs. + +```python +r = redis.Redis( + host='localhost', + socket_timeout=5.0, # Read/write timeout - tune based on expected operation time + socket_connect_timeout=2.0, # Connection timeout - shorter for fast failure detection + retry_on_timeout=True # Automatic retry on timeout +) +``` + +**Incorrect:** Relying solely on defaults without considering your use case. + +```python +# Not ideal: Default timeouts may not match your application's needs +r = redis.Redis(host='localhost') + +# For example, if your app needs fast failure detection, +# the default timeouts might be too generous +``` + +**Considerations:** +- Set `socket_connect_timeout` shorter than `socket_timeout` for quick connection failure detection +- For latency-sensitive apps, use tighter timeouts with retry logic +- For batch operations, allow longer timeouts to complete large operations +- Consider using health checks alongside timeouts for robust failure handling + +Reference: [Redis Client Configuration](https://redis.io/docs/latest/develop/clients/) diff --git a/api/.agents/skills/redis-development/rules/data-choose-structure.md b/api/.agents/skills/redis-development/rules/data-choose-structure.md new file mode 100644 index 0000000..748bdb7 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/data-choose-structure.md @@ -0,0 +1,78 @@ +--- +title: Choose the Right Data Structure +impact: HIGH +impactDescription: Optimal memory usage and operation performance +tags: data-structures, strings, hashes, sets, lists, sorted-sets, json, streams, vector-sets +description: Choose the Right Data Structure +alwaysApply: true +--- + +## Choose the Right Data Structure + +Selecting the appropriate Redis data type for your use case is fundamental to performance and memory efficiency. + +| Use Case | Recommended Type | Why | +|----------|------------------|-----| +| Simple values, counters | String | Fast, atomic operations | +| Object with fields | Hash | Memory efficient, partial updates, field-level expiration | +| Queue, recent items | List | O(1) push/pop at ends | +| Unique items, membership | Set | O(1) add/remove/check | +| Rankings, ranges | Sorted Set | Score-based ordering | +| Nested/hierarchical data | JSON | Path queries, nested structures, geospatial indexing with RQE | +| Event logs, messaging | Stream | Persistent, consumer groups | +| Similarity search | Redis Query Engine / RedisVL or Vector Set | RedisVL is best for document retrieval with filters and full-text search; vector sets are simpler native similarity search | + +**Note:** Vector sets are a Redis 8+ capability introduced in Redis 8.0 and documented there as beta. Prefer Redis Query Engine / RedisVL when you need document-oriented retrieval, structured filters, or full-text + vector workflows. + +**Incorrect:** Using strings for everything. + +**Python** (redis-py): +```python +# Storing object as JSON string loses atomic field updates +redis.set("user:1001", json.dumps({"name": "Alice", "email": "alice@example.com"})) + +# To update email, must fetch, parse, modify, and rewrite entire object +user = json.loads(redis.get("user:1001")) +user["email"] = "new@example.com" +redis.set("user:1001", json.dumps(user)) +``` + +**Java** (Jedis): +```java +// Bad: Storing as delimited string requires manual parsing +jedis.set("bicycle", "Deimos;Ergonom;Enduro bikes;4972"); +String bike = jedis.get("bicycle"); +String[] fields = bike.split(";"); +String model = fields[0]; // Fragile and error-prone +``` + +**Correct:** Use Hash for objects with fields. + +**Python** (redis-py): +```python +# Hash allows atomic field updates +redis.hset("user:1001", mapping={"name": "Alice", "email": "alice@example.com"}) + +# Update single field without touching others +redis.hset("user:1001", "email", "new@example.com") +``` + +**Java** (Jedis): +```java +import java.util.Map; +import java.util.HashMap; + +// Good: Hash models properties naturally +Map hashFields = new HashMap<>(); +hashFields.put("model", "Deimos"); +hashFields.put("brand", "Ergonom"); +hashFields.put("type", "Enduro bikes"); +hashFields.put("price", "4972"); + +jedis.hset("bicycle", hashFields); + +// Read individual field +String model = jedis.hget("bicycle", "model"); +``` + +Reference: [Choosing the Right Data Type](https://redis.io/docs/latest/develop/data-types/compare-data-types/) diff --git a/api/.agents/skills/redis-development/rules/data-hash-field-expiry.md b/api/.agents/skills/redis-development/rules/data-hash-field-expiry.md new file mode 100644 index 0000000..9e5eccf --- /dev/null +++ b/api/.agents/skills/redis-development/rules/data-hash-field-expiry.md @@ -0,0 +1,62 @@ +--- +title: Use Hash Field Expiration for Per-Field TTL +impact: MEDIUM +impactDescription: Fine-grained expiration without managing timers +tags: hash, expiration, ttl, hexpire +description: Use Hash Field Expiration for Per-Field TTL +alwaysApply: true +--- + +## Use Hash Field Expiration for Per-Field TTL + +Use hash field expiration (Redis 7.4+) to delete individual fields automatically from a hash after a specific period of time. This is useful for caching scenarios where different fields have different lifetimes, and is easier than managing expiration from your own code. + +**Correct:** Use HEXPIRE to set per-field TTL on hash fields. + +**Python** (redis-py): +```python +import redis + +client = redis.Redis(host='localhost', port=6379) + +# Set hash fields +client.hset("sensor:sensor1", mapping={ + "air_quality": "256", + "battery_level": "89" +}) + +# Set 60-second TTL on specific fields (Redis 7.4+) +client.hexpire("sensor:sensor1", 60, "air_quality", "battery_level") +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.UnifiedJedis; +import java.util.Map; +import java.util.HashMap; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + Map hashFields = new HashMap<>(); + hashFields.put("air_quality", "256"); + hashFields.put("battery_level", "89"); + + jedis.hset("sensor:sensor1", hashFields); + + // Set 60-second TTL on specific fields (Redis 7.4+) + jedis.hexpire("sensor:sensor1", 60, "air_quality", "battery_level"); +} +``` + +**When to use:** +- Sensor data or metrics that become stale after a period +- Session attributes where different fields have different lifetimes +- Cached values within a hash that should auto-expire independently +- Temporary flags or tokens stored alongside persistent data + +**When NOT needed:** +- Persistent user profiles or configuration +- Data where the entire hash should expire together (use `EXPIRE` on the key instead) +- Fields managed by application logic with explicit deletion + +Reference: [HEXPIRE command](https://redis.io/docs/latest/commands/hexpire/) + diff --git a/api/.agents/skills/redis-development/rules/data-incr.md b/api/.agents/skills/redis-development/rules/data-incr.md new file mode 100644 index 0000000..745a71b --- /dev/null +++ b/api/.agents/skills/redis-development/rules/data-incr.md @@ -0,0 +1,76 @@ +--- +title: Use INCR for Atomic Counters +impact: MEDIUM +impactDescription: Atomic increment avoids race conditions +tags: incr, counters, atomic, performance +description: Use INCR for Atomic Counters +alwaysApply: true +--- + +## Use INCR for Atomic Counters + +If a string represents an integer value, use the `INCR` command to increment the number directly. The increment is atomic and always returns the new value. Use `INCRBY` to increment by any integer (positive or negative). This is more efficient and race-condition-free than reading, incrementing in code, and writing back. + +**Correct:** Use INCR/INCRBY for atomic counter updates. + +**Python** (redis-py): +```python +import redis + +client = redis.Redis(host='localhost', port=6379) + +# Initialize counter +client.set("counter", "0") + +# Atomic increment - returns new value +new_value = client.incr("counter") # Returns 1 + +# Increment by specific amount +new_value = client.incrby("counter", 10) # Returns 11 +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.UnifiedJedis; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + jedis.set("counter", "0"); + + // Atomic increment - returns new value + long newValue = jedis.incr("counter"); // Returns 1 + + // Increment by specific amount + newValue = jedis.incrBy("counter", 10); // Returns 11 +} +``` + +**Incorrect:** Read-modify-write pattern creates race conditions. + +**Python** (redis-py): +```python +import redis + +client = redis.Redis(host='localhost', port=6379) + +client.set("counter", "0") + +# BAD: Race condition - another client could modify between GET and SET +curr_value = int(client.get("counter")) +client.set("counter", str(curr_value + 1)) # Not atomic! +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.UnifiedJedis; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + jedis.set("counter", "0"); + + // BAD: Race condition between GET and SET + long currValue = Long.parseLong(jedis.get("counter")); + jedis.set("counter", Long.toString(currValue + 1)); // Not atomic! +} +``` + +Reference: [INCR command](https://redis.io/docs/latest/commands/incr/) + diff --git a/api/.agents/skills/redis-development/rules/data-key-naming.md b/api/.agents/skills/redis-development/rules/data-key-naming.md new file mode 100644 index 0000000..18a1971 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/data-key-naming.md @@ -0,0 +1,62 @@ +--- +title: Use Consistent Key Naming Conventions +impact: MEDIUM +impactDescription: Improved maintainability and debugging +tags: keys, naming, conventions, prefixes +description: Use Consistent Key Naming Conventions +alwaysApply: true +--- + +## Use Consistent Key Naming Conventions + +Well-structured key names improve code maintainability, debugging, and enable efficient key scanning. + +**Correct:** Use colons as separators with a consistent hierarchy. + +``` +# Pattern: service:entity:id:attribute +user:1001:profile +user:1001:settings +order:2024:items +cache:api:users:list +session:abc123 +``` + +**Python** (redis-py): +```python +# Good: Short, meaningful key +redis.set("product:8361", cached_html) +page = redis.get("product:8361") +``` + +**Java** (Jedis): +```java +// Good: Short, meaningful key derived from URL +jedis.set("product:8361", ""); +String page = jedis.get("product:8361"); +``` + +**Incorrect:** Inconsistent naming, spaces, or very long keys. + +``` +# These cause confusion and waste memory +User_1001_Profile +my key with spaces +com.mycompany.myapp.production.users.profile.data.1001 +``` + +**Java** (Jedis): +```java +// Bad: Using full URL as key wastes memory and slows comparisons +jedis.set("http://www.verylongurlkey.com/store/products/product.html?id=8361", + ""); +``` + +**Key naming tips:** +- Keep keys short but readable—they consume memory +- Consider key prefixes for multi-tenant applications +- Extract short identifiers from URLs or long strings rather than using the whole thing +- For large binary values, consider using a hash digest as the key instead of the value itself +- Use consistent separators (colons are conventional) + +Reference: [Redis Keys](https://redis.io/docs/latest/develop/use/keyspace/) diff --git a/api/.agents/skills/redis-development/rules/data-transactions.md b/api/.agents/skills/redis-development/rules/data-transactions.md new file mode 100644 index 0000000..ec76cdf --- /dev/null +++ b/api/.agents/skills/redis-development/rules/data-transactions.md @@ -0,0 +1,74 @@ +--- +title: Use Transactions for Atomic Multi-Command Operations +impact: MEDIUM +impactDescription: Prevents race conditions and data inconsistency +tags: transactions, multi, exec, atomicity +description: Use Transactions for Atomic Multi-Command Operations +alwaysApply: true +--- + +## Use Transactions for Atomic Multi-Command Operations + +Use the `MULTI`/`EXEC` commands to create a transaction when you need to execute multiple commands atomically. No other client requests will be processed while the transaction is executing, preventing other clients from modifying the keys used in the transaction and avoiding inconsistent data. + +**Correct:** Use transactions when multiple related keys must be updated together. + +**Python** (redis-py): +```python +import redis + +client = redis.Redis(host='localhost', port=6379) + +# Transaction ensures all commands execute atomically +pipe = client.pipeline(transaction=True) +pipe.set("person:1:name", "Alex") +pipe.set("person:1:rank", "Captain") +pipe.set("person:1:serial", "AB1234") +pipe.execute() # All commands execute as one atomic unit +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.UnifiedJedis; +import redis.clients.jedis.Transaction; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + Transaction tran = (Transaction) jedis.multi(); + + tran.set("person:1:name", "Alex"); + tran.set("person:1:rank", "Captain"); + tran.set("person:1:serial", "AB1234"); + + tran.exec(); // All commands execute atomically +} +``` + +**Incorrect:** Executing related commands individually when atomicity is required. + +**Python** (redis-py): +```python +import redis + +client = redis.Redis(host='localhost', port=6379) + +# BAD when atomicity matters - another client could read partial state +client.set("person:1:name", "Alex") +# Another client could read here and see incomplete data +client.set("person:1:rank", "Captain") +client.set("person:1:serial", "AB1234") +``` + +**When to use transactions:** +- Multiple keys must be updated as a single atomic unit +- Other clients reading partial state would cause bugs +- Implementing patterns like "transfer balance between accounts" + +**When transactions are NOT needed:** +- Independent operations that don't need to be atomic +- Single-command operations (already atomic) +- When using pipelining purely for performance (use `pipeline(transaction=False)`) + +**Note:** Transactions add overhead. Only use them when atomicity is actually required. + +Reference: [Transactions](https://redis.io/docs/latest/develop/interact/transactions/) + diff --git a/api/.agents/skills/redis-development/rules/json-partial-updates.md b/api/.agents/skills/redis-development/rules/json-partial-updates.md new file mode 100644 index 0000000..4525a16 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/json-partial-updates.md @@ -0,0 +1,49 @@ +--- +title: Use JSON Paths for Partial Updates +impact: MEDIUM +impactDescription: Avoids fetching and rewriting entire documents +tags: json, partial-updates, paths, atomic +description: Use JSON Paths for Partial Updates +alwaysApply: true +--- + +## Use JSON Paths for Partial Updates + +Use JSON path syntax to update specific fields without fetching the entire document. + +**Correct:** Use JSON paths for targeted updates. + +```python +# Store JSON document +redis.json().set("user:1001", "$", { + "name": "Alice", + "email": "alice@example.com", + "preferences": {"theme": "dark", "notifications": True} +}) + +# Update nested field without fetching entire document +redis.json().set("user:1001", "$.preferences.theme", "light") + +# Get specific field +theme = redis.json().get("user:1001", "$.preferences.theme") + +# Increment numeric field atomically +redis.json().numincrby("user:1001", "$.preferences.volume", 5) + +# Append to array +redis.json().arrappend("user:1001", "$.tags", "premium") +``` + +**Incorrect:** Storing JSON as a string and parsing client-side. + +```python +# Bad: Loses queryability and atomic updates +redis.set("user:1001", json.dumps(user_data)) + +# Must fetch, parse, modify, serialize, and rewrite +data = json.loads(redis.get("user:1001")) +data["preferences"]["theme"] = "light" +redis.set("user:1001", json.dumps(data)) +``` + +Reference: [Redis JSON Path](https://redis.io/docs/latest/develop/data-types/json/path/) diff --git a/api/.agents/skills/redis-development/rules/json-vs-hash.md b/api/.agents/skills/redis-development/rules/json-vs-hash.md new file mode 100644 index 0000000..13a1f8d --- /dev/null +++ b/api/.agents/skills/redis-development/rules/json-vs-hash.md @@ -0,0 +1,105 @@ +--- +title: Choose JSON vs Hash vs String Appropriately +impact: MEDIUM +impactDescription: Optimal data model for your use case +tags: json, hash, string, data-structures, documents +description: Choose JSON vs Hash vs String Appropriately +alwaysApply: true +--- + +## Choose JSON vs Hash vs String Appropriately + +Redis offers three ways to store structured data: JSON, Hash, and serialized strings. Each has distinct trade-offs around atomic partial operations and indexability. + +| Feature | JSON | Hash | String (serialized JSON) | +|---------|------|------|--------------------------| +| **Structure** | Nested objects and arrays | Flat key-value pairs | Any structure | +| **Atomic partial reads** | Yes (`$.field`) | Yes (`HGET`) | No (must fetch entire value) | +| **Atomic partial writes** | Yes (`JSON.SET $.field`) | Yes (`HSET`) | No (must rewrite entire value) | +| **RQE indexing** | Yes | Yes | No | +| **Geospatial indexing** | Yes | Yes | No | +| **Memory efficiency** | Higher overhead | More efficient | Most compact | +| **Field-level expiration** | No | Yes (HEXPIRE) | No | + +**When to use each:** +- **JSON**: Nested structures with atomic partial updates and indexing needs +- **Hash**: Flat objects with atomic field access, field-level expiration, or memory efficiency +- **String**: Simple caching where you always read/write the entire object and don't need indexing + +**Correct:** Use JSON for nested structures with atomic partial updates. + +**Python** (redis-py): +```python +# JSON supports nested structures and atomic deep updates +redis.json().set("user:1001", "$", { + "name": "Alice", + "preferences": {"theme": "dark", "notifications": True} +}) + +# Atomic update of nested field - no read-modify-write needed +redis.json().set("user:1001", "$.preferences.theme", "light") +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.UnifiedJedis; +import redis.clients.jedis.json.Path2; +import org.json.JSONObject; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + JSONObject user = new JSONObject(); + user.put("name", "Alice"); + user.put("preferences", new JSONObject().put("theme", "dark")); + + jedis.jsonSet("user:1001", new Path2("$"), user); + + // Atomic update of nested field + jedis.jsonSet("user:1001", new Path2("$.preferences.theme"), "light"); +} +``` + +**Correct:** Use Hash for flat objects with atomic field access. + +**Python** (redis-py): +```python +# Hash is efficient for flat data with atomic field operations +redis.hset("session:abc", mapping={ + "user_id": "1001", + "created_at": "2024-01-01", + "ip": "192.168.1.1" +}) + +# Atomic field read and update +ip = redis.hget("session:abc", "ip") +redis.hset("session:abc", "ip", "10.0.0.1") +``` + +**Correct:** Use String for simple caching without partial updates. + +**Python** (redis-py): +```python +import json + +# String is fine when you always read/write the entire object +# and don't need indexing or partial updates +config = {"feature_flags": {"dark_mode": True}, "version": "1.0"} +redis.set("config:app", json.dumps(config), ex=3600) + +# Must fetch and parse entire object +config = json.loads(redis.get("config:app")) +``` + +**Incorrect:** Using String when you need atomic partial updates. + +**Python** (redis-py): +```python +import json + +# BAD: Must fetch, parse, modify, serialize, and rewrite entire object +data = json.loads(redis.get("user:1001")) +data["preferences"]["theme"] = "light" # Not atomic! +redis.set("user:1001", json.dumps(data)) +# Another client could have modified the object between GET and SET +``` + +Reference: [Data Type Comparison](https://redis.io/docs/latest/develop/data-types/compare-data-types/#documents) diff --git a/api/.agents/skills/redis-development/rules/observe-commands.md b/api/.agents/skills/redis-development/rules/observe-commands.md new file mode 100644 index 0000000..d6393d8 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/observe-commands.md @@ -0,0 +1,53 @@ +--- +title: Use Observability Commands for Debugging +impact: MEDIUM +impactDescription: Enables quick diagnosis of performance issues +tags: observability, slowlog, info, memory, debugging +description: Use Observability Commands for Debugging +alwaysApply: true +--- + +## Use Observability Commands for Debugging + +Redis provides built-in commands for monitoring and debugging. + +**Key commands:** + +``` +# Slow query log - find slow commands +SLOWLOG GET 10 +SLOWLOG LEN +SLOWLOG RESET + +# Server info - comprehensive stats +INFO all +INFO memory +INFO stats +INFO replication +INFO clients + +# Memory analysis +MEMORY DOCTOR +MEMORY STATS +MEMORY USAGE mykey + +# Client connections +CLIENT LIST +CLIENT INFO + +# Index info (RQE) +FT.INFO idx:products +FT.PROFILE idx:products SEARCH QUERY "@name:laptop" +``` + +**Correct:** Check SLOWLOG regularly. + +```python +# Get recent slow queries +slow_queries = redis.slowlog_get(10) +for query in slow_queries: + print(f"Duration: {query['duration']}μs, Command: {query['command']}") +``` + +Reference: [Redis Monitoring](https://redis.io/docs/latest/operate/oss_and_stack/management/optimization/latency/) + diff --git a/api/.agents/skills/redis-development/rules/observe-metrics.md b/api/.agents/skills/redis-development/rules/observe-metrics.md new file mode 100644 index 0000000..a5b70a4 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/observe-metrics.md @@ -0,0 +1,39 @@ +--- +title: Monitor Key Redis Metrics +impact: MEDIUM +impactDescription: Early detection of performance and capacity issues +tags: observability, metrics, monitoring, memory, connections +description: Monitor Key Redis Metrics +alwaysApply: true +--- + +## Monitor Key Redis Metrics + +Track these metrics to catch issues before they impact users. + +| Metric | What It Tells You | Alert When | +|--------|-------------------|------------| +| `used_memory` | Current memory usage | > 80% of maxmemory | +| `connected_clients` | Number of connections | Sudden spikes or drops | +| `blocked_clients` | Clients waiting on blocking ops | > 0 sustained | +| `instantaneous_ops_per_sec` | Current throughput | Significant drops | +| `keyspace_hits/misses` | Cache hit ratio | Hit ratio < 80% | +| `rejected_connections` | Connection limit issues | > 0 | +| `rdb_last_save_time` | Last persistence snapshot | Too old | + +**Correct:** Export metrics to your monitoring system. + +```python +# Get key metrics +info = redis.info() +print(f"Memory: {info['used_memory_human']}") +print(f"Connections: {info['connected_clients']}") +print(f"Ops/sec: {info['instantaneous_ops_per_sec']}") +print(f"Hit ratio: {info['keyspace_hits'] / (info['keyspace_hits'] + info['keyspace_misses']) * 100:.1f}%") +``` + +**Redis Insight:** +Use Redis Insight for visual monitoring, query profiling, and debugging. It includes Redis Copilot for natural language queries. + +Reference: [Redis Insight](https://redis.io/insight/) + diff --git a/api/.agents/skills/redis-development/rules/ram-limits.md b/api/.agents/skills/redis-development/rules/ram-limits.md new file mode 100644 index 0000000..b61f9b4 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/ram-limits.md @@ -0,0 +1,42 @@ +--- +title: Configure Memory Limits and Eviction Policies +impact: HIGH +impactDescription: Prevents out-of-memory crashes and unpredictable behavior +tags: memory, maxmemory, eviction, lru, ttl +description: Configure Memory Limits and Eviction Policies +alwaysApply: true +--- + +## Configure Memory Limits and Eviction Policies + +Always configure `maxmemory` and an eviction policy to prevent Redis from consuming all available memory. + +**Correct:** Set explicit memory limits. + +``` +maxmemory 2gb +maxmemory-policy allkeys-lru +``` + +| Policy | Use Case | +|--------|----------| +| `volatile-lru` | Evict keys with TTL, least recently used first | +| `allkeys-lru` | Evict any key, least recently used first | +| `volatile-ttl` | Evict keys closest to expiration | +| `noeviction` | Return errors when memory is full (use for critical data) | + +**Incorrect:** Running Redis without memory limits. + +``` +# No maxmemory set - Redis will use all available RAM +# Can cause OOM killer to terminate Redis or other processes +``` + +**Memory optimization tips:** +- Use Hashes for small objects (more memory-efficient than separate keys) +- Use `OBJECT ENCODING key` to check how Redis stores your data +- Use `MEMORY USAGE key` to check individual key memory consumption +- Enable compression in your client for large values + +Reference: [Redis Memory Optimization](https://redis.io/docs/latest/operate/oss_and_stack/management/optimization/memory-optimization/) + diff --git a/api/.agents/skills/redis-development/rules/ram-ttl.md b/api/.agents/skills/redis-development/rules/ram-ttl.md new file mode 100644 index 0000000..66b0573 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/ram-ttl.md @@ -0,0 +1,55 @@ +--- +title: Set TTL on Cache Keys +impact: HIGH +impactDescription: Prevents unbounded memory growth +tags: ttl, expiration, cache, memory +description: Set TTL on Cache Keys +alwaysApply: true +--- + +## Set TTL on Cache Keys + +Always set expiration times on cache keys to prevent unbounded memory growth. + +**Correct:** Set TTL at write time. + +**Python** (redis-py): +```python +# Good: TTL set atomically with the value +redis.setex("cache:user:1001", 3600, user_json) + +# Good: For hashes, set TTL after +redis.hset("session:abc", mapping=session_data) +redis.expire("session:abc", 1800) +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.params.SetParams; + +// Good: TTL set atomically with SetParams +jedis.set("cachedItem:1", "fe8c357903ac9", new SetParams().ex(120)); +``` + +**Incorrect:** Forgetting TTL on cache keys. + +**Python** (redis-py): +```python +# Risk: This key may live forever +redis.set("cache:user:1001", user_json) +``` + +**Java** (Jedis): +```java +// Risk: This key may live forever +jedis.set("cachedItem:1", "fe8c357903ac9"); +``` + +**TTL strategies:** +- Cache data: 1-24 hours depending on freshness requirements +- Sessions: 30 minutes to 24 hours +- Rate limiting: Seconds to minutes +- Temporary locks: Seconds with automatic release + +Reference: [Redis EXPIRE](https://redis.io/commands/expire/) + diff --git a/api/.agents/skills/redis-development/rules/rqe-dialect.md b/api/.agents/skills/redis-development/rules/rqe-dialect.md new file mode 100644 index 0000000..56fe365 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/rqe-dialect.md @@ -0,0 +1,47 @@ +--- +title: Use DIALECT 2 for Query Syntax +impact: MEDIUM +impactDescription: Ensures consistent query behavior and access to modern features +tags: rqe, dialect, query, syntax +description: Use DIALECT 2 for Query Syntax +alwaysApply: true +--- + +## Use DIALECT 2 for Query Syntax + +Use DIALECT 2 for consistent query behavior. Many Redis client libraries now default to DIALECT 2, and other dialects (1, 3, 4) are deprecated as of Redis 8. + +**Correct:** Use DIALECT 2 explicitly or rely on modern client defaults. + +```python +from redis import Redis + +r = Redis() + +# Modern redis-py (6.0+) defaults to DIALECT 2 +# You can also set it explicitly +results = r.ft("idx:products").search( + "@name:laptop", + dialect=2 +) +``` + +``` +# In raw commands, specify DIALECT 2 +FT.SEARCH idx:products "@name:laptop" DIALECT 2 + +FT.AGGREGATE idx:products "@category:{electronics}" + GROUPBY 1 @category + REDUCE COUNT 0 AS count + DIALECT 2 +``` + +**Note:** DIALECT 2 is required for vector search queries. Most modern client libraries (redis-py 6.0+, go-redis, Lettuce) now use DIALECT 2 by default. + +**Why DIALECT 2:** +- Consistent handling of special characters +- Better NULL value handling +- More predictable query parsing +- Required for vector search + +Reference: [Query Dialects](https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/dialects/) diff --git a/api/.agents/skills/redis-development/rules/rqe-field-types.md b/api/.agents/skills/redis-development/rules/rqe-field-types.md new file mode 100644 index 0000000..4fdbf13 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/rqe-field-types.md @@ -0,0 +1,81 @@ +--- +title: Choose the Correct Field Type +impact: HIGH +impactDescription: Use TAG instead of TEXT for filtering to improve query speed 10x +tags: rqe, field-types, text, tag, numeric, geo, geoshape, vector +description: Choose the Correct Field Type +alwaysApply: true +--- + +## Choose the Correct Field Type + +Each field type has different capabilities and performance characteristics. + +| Field Type | Use When | Notes | +|------------|----------|-------| +| TEXT | Full-text search needed | Tokenized, stemmed | +| TAG | Exact match, filtering | Faster than TEXT for filtering | +| NUMERIC | Range queries, sorting | Use for prices, counts, timestamps | +| GEO | Point location queries | Lat/long coordinates (single points) | +| GEOSHAPE | Area/region queries | Polygons, circles, rectangles | +| VECTOR | Similarity search | HNSW or FLAT algorithm | + +**Correct:** Use TAG for exact matching. + +``` +# Good: TAG for exact category matching +FT.CREATE idx:products ON HASH PREFIX 1 product: + SCHEMA + category TAG SORTABLE + status TAG +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.search.*; + +Schema schema = new Schema() + .addTextField("name", 1) + .addTagField("categories"); // TAG for exact matching + +IndexDefinition def = new IndexDefinition(IndexDefinition.Type.HASH); + +jedis.ftCreate("idx", IndexOptions.defaultOptions().setDefinition(def), schema); + +// Query with TAG syntax +SearchResult result = jedis.ftSearch("idx", "@categories:{chef|runner}"); +``` + +**Incorrect:** Using TEXT when you don't need full-text features. + +``` +# Overkill: TEXT for category adds unnecessary tokenization +FT.CREATE idx:products ON HASH PREFIX 1 product: + SCHEMA + category TEXT + status TEXT +``` + +**Java** (Jedis): +```java +// Bad: TEXT for categories adds unnecessary overhead +Schema schema = new Schema() + .addTextField("name", 1) + .addTextField("categories", 1); // Overkill for exact matching +``` + +**Correct:** Use GEO for points, GEOSHAPE for areas. + +``` +# GEO for point locations (stores, users) +FT.CREATE idx:stores ON HASH PREFIX 1 store: + SCHEMA + location GEO + +# GEOSHAPE for areas (delivery zones, boundaries) +FT.CREATE idx:zones ON JSON PREFIX 1 zone: + SCHEMA + $.boundary AS boundary GEOSHAPE +``` + +Reference: [Redis Search Field Types](https://redis.io/docs/latest/develop/interact/search-and-query/indexing/geoindex/) diff --git a/api/.agents/skills/redis-development/rules/rqe-index-creation.md b/api/.agents/skills/redis-development/rules/rqe-index-creation.md new file mode 100644 index 0000000..db64cdc --- /dev/null +++ b/api/.agents/skills/redis-development/rules/rqe-index-creation.md @@ -0,0 +1,73 @@ +--- +title: Index Only Fields You Query +impact: HIGH +impactDescription: Reduces index size and improves write performance +tags: rqe, ft.create, index, schema +description: Index Only Fields You Query +alwaysApply: true +--- + +## Index Only Fields You Query + +Create indexes with only the fields you need to search, filter, or sort on. + +**Correct:** Index specific fields and use prefixes. + +``` +FT.CREATE idx:products ON HASH PREFIX 1 product: + SCHEMA + name TEXT WEIGHT 2.0 + description TEXT + category TAG SORTABLE + price NUMERIC SORTABLE + location GEO +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.search.*; + +Schema schema = new Schema() + .addTextField("name", 1) + .addTagField("categories"); + +// Good: Specify prefix to index only matching keys +IndexDefinition def = new IndexDefinition(IndexDefinition.Type.HASH) + .setPrefixes("person:"); + +jedis.ftCreate("idx", IndexOptions.defaultOptions().setDefinition(def), schema); +``` + +**Incorrect:** Over-indexing or indexing unused fields. + +``` +# Bad: Indexing every field "just in case" +FT.CREATE idx:products ON HASH PREFIX 1 product: + SCHEMA + name TEXT + description TEXT + category TEXT + subcategory TEXT + brand TEXT + sku TEXT + price NUMERIC + cost NUMERIC + margin NUMERIC + ... +``` + +**Java** (Jedis): +```java +// Bad: No prefix means all hashes get indexed +IndexDefinition def = new IndexDefinition(IndexDefinition.Type.HASH); +// This will index every hash in the database! +``` + +**Tips:** +- Start with the minimum required fields +- Add fields as query patterns emerge +- Use `FT.INFO` to monitor index size +- Always specify a prefix to avoid indexing unrelated keys + +Reference: [Redis Search Indexing](https://redis.io/docs/latest/develop/interact/search-and-query/indexing/) + diff --git a/api/.agents/skills/redis-development/rules/rqe-index-management.md b/api/.agents/skills/redis-development/rules/rqe-index-management.md new file mode 100644 index 0000000..9eb817b --- /dev/null +++ b/api/.agents/skills/redis-development/rules/rqe-index-management.md @@ -0,0 +1,49 @@ +--- +title: Manage Indexes for Zero-Downtime Updates +impact: MEDIUM +impactDescription: Use aliases for seamless index updates +tags: rqe, index, alias, management, reindex +description: Manage Indexes for Zero-Downtime Updates +alwaysApply: true +--- + +## Manage Indexes for Zero-Downtime Updates + +Use aliases to swap indexes without application changes. + +**Correct:** Use aliases for production indexes. + +``` +# Create versioned index +FT.CREATE idx:products_v2 ON HASH PREFIX 1 product: + SCHEMA + name TEXT + category TAG SORTABLE + price NUMERIC SORTABLE + +# Point alias to new index +FT.ALIASADD products idx:products_v2 + +# Application queries use alias +FT.SEARCH products "@category:{electronics}" + +# Later, swap to new version +FT.ALIASUPDATE products idx:products_v3 +``` + +**Useful management commands:** + +``` +# Check index info +FT.INFO idx:products + +# Drop and recreate (non-blocking) +FT.DROPINDEX idx:products +FT.CREATE idx:products ... + +# List all indexes +FT._LIST +``` + +Reference: [Redis Search Index Management](https://redis.io/docs/latest/develop/interact/search-and-query/administration/) + diff --git a/api/.agents/skills/redis-development/rules/rqe-query-optimization.md b/api/.agents/skills/redis-development/rules/rqe-query-optimization.md new file mode 100644 index 0000000..162ea2b --- /dev/null +++ b/api/.agents/skills/redis-development/rules/rqe-query-optimization.md @@ -0,0 +1,49 @@ +--- +title: Write Efficient Queries +impact: HIGH +impactDescription: Proper filtering reduces query time by orders of magnitude +tags: rqe, ft.search, query, performance, filters +description: Write Efficient Queries +alwaysApply: true +--- + +## Write Efficient Queries + +Be specific and use filters to reduce the result set early. + +**Correct:** Use specific filters and limit results. + +``` +# Good: Specific query with filters +FT.SEARCH idx:products "@category:{electronics} @price:[100 500]" + LIMIT 0 20 + RETURN 3 name price category + +# Good: Use SORTBY and LIMIT +FT.SEARCH idx:products "@name:laptop" + SORTBY price ASC + LIMIT 0 10 +``` + +**Incorrect:** Broad queries returning large result sets. + +``` +# Bad: Wildcard prefix scans entire index +FT.SEARCH idx:products "*" LIMIT 0 10000 + +# Bad: Loading all fields from source document +FT.AGGREGATE idx:products "*" LOAD * +``` + +**Performance tips:** +- Add `SORTABLE` to fields used in `SORTBY` +- Use `TAG SORTABLE UNF` for best performance on tag fields +- Use `NOSTEM` if you don't need stemming +- Profile queries with `FT.PROFILE` + +``` +FT.PROFILE idx:products SEARCH QUERY "@category:{electronics}" +``` + +Reference: [Redis Search Query Syntax](https://redis.io/docs/latest/develop/interact/search-and-query/query/) + diff --git a/api/.agents/skills/redis-development/rules/rqe-skip-initial-scan.md b/api/.agents/skills/redis-development/rules/rqe-skip-initial-scan.md new file mode 100644 index 0000000..75b1f2c --- /dev/null +++ b/api/.agents/skills/redis-development/rules/rqe-skip-initial-scan.md @@ -0,0 +1,82 @@ +--- +title: Use SKIPINITIALSCAN for New Data Only Indexes +impact: MEDIUM +impactDescription: Faster index creation, avoids indexing existing data +tags: index, skipinitialscan, rqe, search +description: Use SKIPINITIALSCAN for New Data Only Indexes +alwaysApply: true +--- + +## Use SKIPINITIALSCAN for New Data Only Indexes + +Enable the `SKIPINITIALSCAN` option when creating an index if you only want to include items that are added after the index is created. This makes index creation faster and avoids indexing existing data that you don't need to search. + +**Correct:** Use SKIPINITIALSCAN when you only need to index new data. + +**Python** (redis-py): +```python +import redis +from redis.commands.search.field import TextField, TagField +from redis.commands.search.indexDefinition import IndexDefinition, IndexType + +client = redis.Redis(host='localhost', port=6379) + +# Create index that only indexes new documents +schema = ( + TextField("name"), + TagField("categories") +) + +definition = IndexDefinition( + prefix=["person:"], + index_type=IndexType.HASH +) + +# SKIPINITIALSCAN - only index documents added after creation +client.ft("idx").create_index( + schema, + definition=definition, + skip_initial_scan=True +) +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.UnifiedJedis; +import redis.clients.jedis.search.FTCreateParams; +import redis.clients.jedis.search.IndexDataType; +import redis.clients.jedis.search.schemafields.SchemaField; +import redis.clients.jedis.search.schemafields.TagField; +import redis.clients.jedis.search.schemafields.TextField; + +try (UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379")) { + FTCreateParams params = new FTCreateParams() + .on(IndexDataType.HASH) + .skipInitialScan(); // Only index new documents + + jedis.ftCreate( + "idx", + params, + new SchemaField[]{ + new TextField("name"), + new TagField("categories") + } + ); +} +``` + +**When to use SKIPINITIALSCAN:** +- Creating an index for a new feature where existing data is irrelevant +- Setting up indexes in advance before data arrives +- When existing data would be too large to scan during index creation +- Event-driven architectures where you only care about new events + +**When NOT to use (default behavior is correct):** +- You need to search existing data immediately after index creation +- Migrating to a new index schema and need all data indexed +- Most typical use cases where historical data matters + +**Note:** The default behavior (without SKIPINITIALSCAN) indexes all existing matching keys, which is usually what you want. + +Reference: [FT.CREATE SKIPINITIALSCAN](https://redis.io/docs/latest/commands/ft.create/) + diff --git a/api/.agents/skills/redis-development/rules/security-acls.md b/api/.agents/skills/redis-development/rules/security-acls.md new file mode 100644 index 0000000..917fc80 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/security-acls.md @@ -0,0 +1,41 @@ +--- +title: Use ACLs for Fine-Grained Access Control +impact: HIGH +impactDescription: Limits blast radius if credentials are compromised +tags: security, acl, users, permissions, least-privilege +description: Use ACLs for Fine-Grained Access Control +alwaysApply: true +--- + +## Use ACLs for Fine-Grained Access Control + +Create users with only the permissions they need (principle of least privilege). + +**Correct:** Create specific users with limited permissions. + +``` +# Read-only user for cache access +ACL SETUSER app_readonly on >password ~cache:* +get +mget +scan + +# Writer that can't run dangerous commands +ACL SETUSER app_writer on >password ~* +@all -@dangerous + +# Admin user (use sparingly) +ACL SETUSER admin on >strong-password ~* +@all +``` + +**Incorrect:** Using the default user for everything. + +``` +# Bad: Single password for all access +requirepass shared-password +``` + +**ACL categories:** +- `@read` - Read commands +- `@write` - Write commands +- `@dangerous` - Commands like FLUSHALL, DEBUG +- `@admin` - Administrative commands + +Reference: [Redis ACL](https://redis.io/docs/latest/operate/oss_and_stack/management/security/acl/) + diff --git a/api/.agents/skills/redis-development/rules/security-auth.md b/api/.agents/skills/redis-development/rules/security-auth.md new file mode 100644 index 0000000..4e843a9 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/security-auth.md @@ -0,0 +1,78 @@ +--- +title: Always Use Authentication in Production +impact: HIGH +impactDescription: Prevents unauthorized access to your data +tags: security, authentication, password, tls, ssl +description: Always Use Authentication in Production +alwaysApply: true +--- + +## Always Use Authentication in Production + +Never run Redis without authentication in production environments. + +**Correct:** Use password and TLS. + +**Python** (redis-py): +```python +r = redis.Redis( + host='localhost', + port=6379, + password='your-strong-password', + ssl=True, + ssl_cert_reqs='required' +) +``` + +**Java** (Jedis): +```java +import redis.clients.jedis.*; +import javax.net.ssl.*; +import java.security.KeyStore; + +// Create SSL context with trust store and key store +KeyStore trustStore = KeyStore.getInstance("jks"); +trustStore.load(new FileInputStream("./truststore.jks"), "password".toCharArray()); + +TrustManagerFactory tmf = TrustManagerFactory.getInstance("X509"); +tmf.init(trustStore); + +SSLContext sslContext = SSLContext.getInstance("TLS"); +sslContext.init(null, tmf.getTrustManagers(), null); + +JedisClientConfig config = DefaultJedisClientConfig.builder() + .ssl(true) + .sslSocketFactory(sslContext.getSocketFactory()) + .user("redisUser") + .password("redisPassword") + .build(); + +JedisPooled jedis = new JedisPooled(new HostAndPort("redis-host", 6379), config); +``` + +**Incorrect:** Connecting without authentication. + +**Python** (redis-py): +```python +# Bad: No authentication +r = redis.Redis(host='localhost', port=6379) +``` + +**Java** (Jedis): +```java +// Bad: No authentication or TLS +UnifiedJedis jedis = new UnifiedJedis("redis://localhost:6379"); +``` + +**Configuration:** + +``` +# redis.conf +requirepass your-strong-password +tls-port 6380 +tls-cert-file /path/to/redis.crt +tls-key-file /path/to/redis.key +``` + +Reference: [Redis Security](https://redis.io/docs/latest/operate/oss_and_stack/management/security/) + diff --git a/api/.agents/skills/redis-development/rules/security-network.md b/api/.agents/skills/redis-development/rules/security-network.md new file mode 100644 index 0000000..ebdde7e --- /dev/null +++ b/api/.agents/skills/redis-development/rules/security-network.md @@ -0,0 +1,52 @@ +--- +title: Secure Network Access +impact: HIGH +impactDescription: Reduces attack surface and prevents unauthorized access +tags: security, network, firewall, bind, tls +description: Secure Network Access +alwaysApply: true +--- + +## Secure Network Access + +Restrict network access to Redis to only trusted sources. + +**Correct:** Bind to specific interfaces. + +``` +# redis.conf +bind 127.0.0.1 192.168.1.100 +protected-mode yes +``` + +**Correct:** Use firewall rules. + +```bash +# Allow only application servers +iptables -A INPUT -p tcp --dport 6379 -s 192.168.1.0/24 -j ACCEPT +iptables -A INPUT -p tcp --dport 6379 -j DROP +``` + +**Incorrect:** Exposing Redis to the internet. + +``` +# Bad: Binds to all interfaces +bind 0.0.0.0 +protected-mode no +``` + +**Security checklist:** +- Use TLS for connections +- Bind to specific interfaces, not `0.0.0.0` +- Use firewall rules to restrict access +- Disable dangerous commands in production + +``` +# Disable dangerous commands +rename-command FLUSHALL "" +rename-command DEBUG "" +rename-command CONFIG "" +``` + +Reference: [Redis Security](https://redis.io/docs/latest/operate/oss_and_stack/management/security/) + diff --git a/api/.agents/skills/redis-development/rules/semantic-cache-best-practices.md b/api/.agents/skills/redis-development/rules/semantic-cache-best-practices.md new file mode 100644 index 0000000..044fbbf --- /dev/null +++ b/api/.agents/skills/redis-development/rules/semantic-cache-best-practices.md @@ -0,0 +1,72 @@ +--- +title: Configure Semantic Cache Properly +impact: MEDIUM +impactDescription: Correct threshold tuning balances hit rate vs accuracy +tags: langcache, cache, threshold, ttl, semantic +description: Configure Semantic Cache Properly +alwaysApply: true +--- + +## Configure Semantic Cache Properly + +> **Note:** LangCache is currently in preview on Redis Cloud. Features and behavior may change. + +Tune similarity threshold and cache separation for optimal LangCache results. + +**Correct:** Tune similarity threshold for your use case. + +```python +from langcache import LangCache + +lang_cache = LangCache( + server_url=f"https://{os.getenv('HOST')}", + cache_id=os.getenv("CACHE_ID"), + api_key=os.getenv("API_KEY") +) + +# Stricter matching - fewer false positives (0.95 = very similar) +result = lang_cache.search( + prompt="What is Redis?", + similarity_threshold=0.95 +) + +# Looser matching - higher hit rate (0.8 = somewhat similar) +result = lang_cache.search( + prompt="What is Redis?", + similarity_threshold=0.8 +) +``` + +**Correct:** Use separate caches for different use cases. + +```python +# Create different cache IDs in Redis Cloud for different LLM tasks +support_cache = LangCache( + server_url=server_url, + cache_id="support-cache-id", + api_key=api_key +) + +code_cache = LangCache( + server_url=server_url, + cache_id="code-cache-id", + api_key=api_key +) +``` + +**Incorrect:** Using a single cache for all LLM tasks. + +```python +# All tasks share one cache - responses may not be relevant +result = lang_cache.search(prompt="How do I reset my password?") +# Could return a code snippet if someone asked a similar coding question +``` + +**Best practices:** +- Start with threshold 0.9, adjust based on your use case +- Use custom attributes to filter results within a single cache +- Monitor cache hit rates to evaluate effectiveness +- Use separate cache IDs for fundamentally different LLM tasks + +Reference: [LangCache Best Practices](https://redis.io/docs/latest/develop/ai/langcache/) + diff --git a/api/.agents/skills/redis-development/rules/semantic-cache-langcache-usage.md b/api/.agents/skills/redis-development/rules/semantic-cache-langcache-usage.md new file mode 100644 index 0000000..4ef29c6 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/semantic-cache-langcache-usage.md @@ -0,0 +1,86 @@ +--- +title: Use LangCache for LLM Response Caching +impact: HIGH +impactDescription: Reduces LLM API costs by 50-90% for similar queries +tags: langcache, llm, semantic-cache, embeddings, ai +description: Use LangCache for LLM Response Caching +alwaysApply: true +--- + +## Use LangCache for LLM Response Caching + +> **Note:** LangCache is currently in preview on Redis Cloud. Features and behavior may change. + +LangCache is a fully-managed semantic caching service on Redis Cloud that reduces LLM costs and latency. + +**How it works:** +1. Your app sends a prompt to LangCache via `POST /v1/caches/{cacheId}/entries/search` +2. LangCache generates an embedding and searches for similar cached responses +3. If found (cache hit), returns the cached response instantly +4. If not found (cache miss), your app calls the LLM and stores the response + +**Correct:** Use the LangCache Python SDK. + +```python +from langcache import LangCache +import os + +lang_cache = LangCache( + server_url=f"https://{os.getenv('HOST')}", + cache_id=os.getenv("CACHE_ID"), + api_key=os.getenv("API_KEY") +) + +# Search for cached response +result = lang_cache.search( + prompt="What is Redis?", + similarity_threshold=0.9 +) + +if result: + response = result[0]["response"] +else: + response = llm.generate("What is Redis?") + # Store for future queries + lang_cache.set( + prompt="What is Redis?", + response=response + ) +``` + +**LangCache REST API:** + +```bash +# Search cache +curl -X POST "https://$HOST/v1/caches/$CACHE_ID/entries/search" \ + -H "Authorization: Bearer $API_KEY" \ + -H "Content-Type: application/json" \ + -d '{"prompt": "What is Redis?"}' + +# Store a response +curl -X POST "https://$HOST/v1/caches/$CACHE_ID/entries" \ + -H "Authorization: Bearer $API_KEY" \ + -H "Content-Type: application/json" \ + -d '{"prompt": "What is Redis?", "response": "Redis is an in-memory database..."}' +``` + +**With custom attributes for filtering:** + +```python +# Store with attributes +lang_cache.set( + prompt="What is Redis?", + response="Redis is an in-memory database...", + attributes={"category": "database", "version": "v1"} +) + +# Search with attribute filter +result = lang_cache.search( + prompt="Tell me about Redis", + attributes={"category": "database"}, + similarity_threshold=0.9 +) +``` + +Reference: [LangCache Documentation](https://redis.io/docs/latest/develop/ai/langcache/) + diff --git a/api/.agents/skills/redis-development/rules/stream-choosing-pattern.md b/api/.agents/skills/redis-development/rules/stream-choosing-pattern.md new file mode 100644 index 0000000..6ab718e --- /dev/null +++ b/api/.agents/skills/redis-development/rules/stream-choosing-pattern.md @@ -0,0 +1,44 @@ +--- +title: Choose Streams vs Pub/Sub Appropriately +impact: MEDIUM +impactDescription: Wrong choice leads to lost messages or unnecessary complexity +tags: streams, pubsub, messaging, events, queues +description: Choose Streams vs Pub/Sub Appropriately +alwaysApply: true +--- + +## Choose Streams vs Pub/Sub Appropriately + +Redis supports two messaging approaches for different use cases. + +**Incorrect:** Using Pub/Sub when messages must not be lost. + +```python +# Pub/Sub - messages lost if no subscribers connected +r.publish("orders", json.dumps(order)) # Fire and forget! +``` + +**Correct:** Use Streams when message durability matters. + +```python +# Streams - messages persist and can be replayed +r.xadd("orders:stream", {"order": json.dumps(order)}) + +# Consumer group for reliable processing +r.xreadgroup("workers", "worker-1", {"orders:stream": ">"}, count=10) +r.xack("orders:stream", "workers", message_id) +``` + +### When to Use Each + +| Requirement | Use | +|-------------|-----| +| Real-time notifications, OK to miss messages | Pub/Sub | +| Messages must not be lost | Streams | +| Need to replay/reprocess messages | Streams | +| Multiple workers processing same queue | Streams (consumer groups) | +| Simple broadcast to connected clients | Pub/Sub | +| Event sourcing or audit trail | Streams | + +Reference: [Redis Streams](https://redis.io/docs/latest/develop/data-types/streams/) + diff --git a/api/.agents/skills/redis-development/rules/vector-algorithm-choice.md b/api/.agents/skills/redis-development/rules/vector-algorithm-choice.md new file mode 100644 index 0000000..5b7563d --- /dev/null +++ b/api/.agents/skills/redis-development/rules/vector-algorithm-choice.md @@ -0,0 +1,61 @@ +--- +title: Choose HNSW vs FLAT Based on Requirements +impact: HIGH +impactDescription: HNSW trades accuracy for speed, FLAT provides exact results +tags: vector, hnsw, flat, algorithm, performance +description: Choose HNSW vs FLAT Based on Requirements +alwaysApply: true +--- + +## Choose HNSW vs FLAT Based on Requirements + +Select the right algorithm based on your accuracy requirements and dataset size. + +| Algorithm | Speed | Accuracy | Memory | Best For | +|-----------|-------|----------|--------|----------| +| HNSW | Fast (approximate) | ~95%+ recall tunable | Higher | Large datasets (>10k vectors) | +| FLAT | Slower (exact) | 100% (exact) | Lower | Small datasets, accuracy-critical | + +**Correct:** Use HNSW for large-scale production workloads. + +```python +from redisvl.schema import IndexSchema + +# HNSW - fast approximate search, tunable accuracy +schema = IndexSchema.from_dict({ + "index": {"name": "idx:docs", "prefix": "doc:"}, + "fields": [ + {"name": "embedding", "type": "vector", "attrs": { + "dims": 1536, + "algorithm": "HNSW", + "distance_metric": "COSINE", + "datatype": "FLOAT32", + "m": 16, # Higher = more accurate, more memory + "ef_construction": 200 # Higher = better index quality, slower build + }} + ] +}) +``` + +**Correct:** Use FLAT when exact results are required. + +```python +# FLAT - exact brute-force search, guaranteed accuracy +schema = IndexSchema.from_dict({ + "index": {"name": "idx:small", "prefix": "small:"}, + "fields": [ + {"name": "embedding", "type": "vector", "attrs": { + "dims": 1536, + "algorithm": "FLAT", + "distance_metric": "COSINE" + }} + ] +}) +``` + +**Tuning HNSW accuracy vs speed:** +- `M`: Connections per node (16-64). Higher = better recall, more memory +- `EF_CONSTRUCTION`: Build-time parameter (100-500). Higher = better graph quality +- `EF_RUNTIME`: Query-time parameter. Higher = better recall, slower queries + +Reference: [Redis Vector Search](https://redis.io/docs/latest/develop/ai/search-and-query/vectors/) diff --git a/api/.agents/skills/redis-development/rules/vector-hybrid-search.md b/api/.agents/skills/redis-development/rules/vector-hybrid-search.md new file mode 100644 index 0000000..cbfaddb --- /dev/null +++ b/api/.agents/skills/redis-development/rules/vector-hybrid-search.md @@ -0,0 +1,52 @@ +--- +title: Use Hybrid Search for Better Results +impact: MEDIUM +impactDescription: Combining vector + filters improves relevance and reduces search space +tags: vector, hybrid, filters, redisvl, search +description: Use Hybrid Search for Better Results +alwaysApply: true +--- + +## Use Hybrid Search for Better Results + +Combine vector similarity with attribute filtering for more relevant results. In this rule, "hybrid" means filtered vector search. Redis and RedisVL also use "hybrid search" for text + vector fusion via `FT.HYBRID` / `HybridQuery`. + +**Correct:** Apply filters to reduce search space. + +```python +from redisvl.query import VectorQuery +from redisvl.query.filter import Num, Tag + +filters = (Tag("category") == "technology") & (Num("date") >= 2024) & (Num("date") <= 2025) + +query = VectorQuery( + vector=query_embedding, + vector_field_name="embedding", + return_fields=["content", "category", "date"], + num_results=10, + filter_expression=filters +) + +results = index.query(query) +``` + +**Incorrect:** Searching entire vector space when filters apply. + +```python +# Bad: No filter - searches all vectors then filters client-side +results = index.query(VectorQuery( + vector=query_embedding, + vector_field_name="embedding", + num_results=1000 +)) +# Client-side filtering - wasteful +filtered = [r for r in results if r["category"] == "technology"] +``` + +**Tips:** +- Use TAG fields for category filters +- Use NUMERIC fields for date/price ranges +- Redis auto-selects the filtered vector execution strategy; tune `hybrid_policy` only when needed +- For true text + vector fusion, use `HybridQuery` on Redis >= 8.4.0 with redis-py >= 7.1.0; use `AggregateHybridQuery` on earlier Redis versions + +Reference: [Redis Vector Search](https://redis.io/docs/latest/develop/ai/search-and-query/vectors/) diff --git a/api/.agents/skills/redis-development/rules/vector-index-creation.md b/api/.agents/skills/redis-development/rules/vector-index-creation.md new file mode 100644 index 0000000..55b14cd --- /dev/null +++ b/api/.agents/skills/redis-development/rules/vector-index-creation.md @@ -0,0 +1,85 @@ +--- +title: Configure Vector Indexes Properly +impact: HIGH +impactDescription: Correct configuration is essential for vector search accuracy +tags: vector, index, hnsw, flat, embeddings, rqe +description: Configure Vector Indexes Properly +alwaysApply: true +--- + +## Configure Vector Indexes Properly + +Set the correct dimensions, algorithm, and distance metric for your embeddings. Vector indexes can be created via CLI, Redis Insight, or any client library. + +**Correct:** Create index via Redis CLI or Insight. + +``` +FT.CREATE idx:docs ON HASH PREFIX 1 doc: + SCHEMA + content TEXT + embedding VECTOR HNSW 6 + TYPE FLOAT32 + DIM 1536 + DISTANCE_METRIC COSINE +``` + +**Correct:** Create index via Python (redis-py). + +```python +from redis import Redis +from redis.commands.search.field import TextField, VectorField +from redis.commands.search.index_definition import IndexDefinition + +r = Redis() + +# Define schema with vector field +schema = [ + TextField("content"), + VectorField( + "embedding", + algorithm="HNSW", + attributes={ + "TYPE": "FLOAT32", + "DIM": 1536, # Must match your embedding model + "DISTANCE_METRIC": "COSINE" + } + ) +] + +r.ft("idx:docs").create_index(schema, definition=IndexDefinition(prefix=["doc:"])) +``` + +**Correct:** Create index via RedisVL. + +```python +from redisvl.index import SearchIndex +from redisvl.schema import IndexSchema + +schema = IndexSchema.from_dict({ + "index": {"name": "idx:docs", "prefix": "doc:"}, + "fields": [ + {"name": "content", "type": "text"}, + {"name": "embedding", "type": "vector", "attrs": { + "dims": 1536, + "algorithm": "HNSW", + "datatype": "FLOAT32", + "distance_metric": "COSINE" + }} + ] +}) + +index = SearchIndex(schema) +index.create(overwrite=True) +``` + +**Incorrect:** Mismatched dimensions or wrong distance metric. + +```python +# Bad: Wrong dimensions for your model +{"dims": 768} # But your selected embedding model outputs a different size + +# Bad: Wrong metric for normalized embeddings +{"distance_metric": "L2"} # When embeddings are normalized for COSINE +``` + +Reference: [Redis Vector Search](https://redis.io/docs/latest/develop/ai/search-and-query/vectors/) diff --git a/api/.agents/skills/redis-development/rules/vector-rag-pattern.md b/api/.agents/skills/redis-development/rules/vector-rag-pattern.md new file mode 100644 index 0000000..77d4810 --- /dev/null +++ b/api/.agents/skills/redis-development/rules/vector-rag-pattern.md @@ -0,0 +1,52 @@ +--- +title: Implement RAG Pattern Correctly +impact: HIGH +impactDescription: Proper RAG implementation improves LLM response quality +tags: vector, rag, llm, embeddings, retrieval +description: Implement RAG Pattern Correctly +alwaysApply: true +--- + +## Implement RAG Pattern Correctly + +Store documents with embeddings, retrieve relevant context, and pass to LLM. + +**Correct:** Full RAG pipeline with RedisVL. + +```python +from redisvl.index import SearchIndex +from redisvl.query import VectorQuery + +# 1. Store documents with embeddings +records = [] +for doc in documents: + records.append({ + "content": doc["content"], + "embedding": embed_model.encode(doc["content"]).tolist(), + "source": doc["source"] + }) + +index.load(records) + +# 2. Query with vector similarity +query_embedding = embed_model.encode(user_question) +results = index.query(VectorQuery( + vector=query_embedding, + vector_field_name="embedding", + return_fields=["content", "source"], + num_results=5 +)) + +# 3. Pass context to LLM +context = "\n".join([r["content"] for r in results]) +response = llm.generate(f"Context: {context}\n\nQuestion: {user_question}") +``` + +**Best practices:** +- Match your distance metric to your embedding model; many modern text embeddings already work well with COSINE +- Batch inserts using `index.load()` with lists +- Set appropriate M and EF_CONSTRUCTION for HNSW based on dataset size +- Use filters to reduce the search space before vector comparison +- Consider chunking long documents for better retrieval + +Reference: [Redis RAG Quickstart](https://redis.io/docs/latest/develop/get-started/rag/) diff --git a/api/.env.example b/api/.env.example index ed543ba..1c3056a 100644 --- a/api/.env.example +++ b/api/.env.example @@ -1,352 +1,49 @@ # ============================================================================= -# Life Echo API — 模板(example) +# Life Echo API — 环境变量(仅 secrets + bootstrap) # -# 目录结构与 api/.env.development 对齐,便于对照;占位键见各段注释。 -# 本地:复制为 .env.development(勿提交密钥),再运行 api/development.sh 会在首次自动生成 .env(从 -# .env.development 复制);Settings 只读 .env(见 app/core/config.py)。 -# 服务端:仓库维护 .env.staging / .env.production;workflow 按目标环境上传并复制为运行时 .env,compose 的 env_file 统一指向 .env。 -# 不要把真实密钥提交到仓库。 +# 非密钥配置 SSOT:config/default.toml + config/{APP_ENV}.toml +# 详见 docs/configuration.md # ============================================================================= -# ============================================================================= -# Docker Compose(宿主机独立 Caddy 反代到本 API) -# ============================================================================= -# 映射到宿主机的端口:不设置则由 Docker 随机分配,避免与同机其它项目冲突;随机时用 `docker compose port api 8000` 查看。 -# 需固定端口时取消下行注释并改为未占用端口,Caddyfile 中 reverse_proxy 到 127.0.0.1:该端口。 -# LIFE_ECHO_API_HOST_PORT=8000 -# 若 Caddy 跑在独立容器且非 host 网络,不要用 127.0.0.1,应把 Caddy 加入与本 compose 相同的 Docker 网络,并对 http://life-echo-api-prod:8000 做 reverse_proxy。 - -# ============================================================================= -# OpenTelemetry(见 docs/observability.md;Settings 只读 .env,勿 shell export) -# ============================================================================= -# docker-compose.observability.yml 宿主机端口(高位口,避免 3000/9090/4317 冲突) -# GRAFANA_HOST_PORT=48300 -# PROMETHEUS_HOST_PORT=49090 -# OTEL_GRPC_HOST_PORT=48317 -# OTEL_HTTP_HOST_PORT=48318 -# OTEL_COLLECTOR_HEALTH_HOST_PORT=48333 -# TEMPO_HTTP_HOST_PORT=43200 -# LOKI_HTTP_HOST_PORT=43100 -# -# --- development(.env.development):本机 uvicorn/celery --- -# OTEL_ENABLED=true -# OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:48317 -# OTEL_TRACES_SAMPLER=always_on -# -# --- staging / production(.env.staging / .env.production):容器内 compose --- -# OTEL_ENABLED=false -# OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317 -# OTEL_TRACES_SAMPLER=parentbased_traceidratio -# OTEL_TRACES_SAMPLER_ARG=0.1 -# -OTEL_ENABLED=true -OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:48317 -OTEL_EXPORTER_OTLP_INSECURE=true -OTEL_SERVICE_NAME=life-echo-api -OTEL_TRACES_SAMPLER=always_on -# OTEL_TRACES_SAMPLER_ARG=0.1 -# OTEL_METRIC_EXPORT_INTERVAL_MS=10000 - -# ============================================================================= -# Logging(loguru sink 最低级别:TRACE / DEBUG / INFO / WARNING / ERROR / CRITICAL) -# ============================================================================= -# 生产/预发:保持 INFO,避免 DEBUG 把全文 prompt/响应打进日志。排查 Agent 耗时可仅开 LOG_AGENT_VERBOSE。 -LOG_LEVEL=INFO -# Agent 单行 INFO 摘要(耗时、sha、字符数);与 LOG_LEVEL 独立,生产可短时设为 1 -# LOG_AGENT_VERBOSE=0 -# DEBUG 下 prompt/响应预览最大字符数(Settings 默认 4096);0=不截断全文(慎用) -# AGENT_LOG_MAX_CHARS=4096 -# DEBUG 下 *.prompt:preview=截断预览 | hash_only=仅 sha12+长度,无正文 -# AGENT_LOG_PROMPT_MODE=preview -# DEBUG 下同一 label 连续相同 prompt 则跳过重复行(减模板重复) -# AGENT_LOG_PROMPT_DEDUP=0 -# DEBUG 下访谈/资料:省略 SystemMessage 正文(仅 total_len+sha12);0/false=打出全文 -# AGENT_LOG_OMIT_SYSTEM_MESSAGE_BODY=1 -# DEBUG 下超长单段 *.prompt:总长超过下一项时,先跳过前 N 字符再预览(0=不跳过;短时 DEBUG 可设 2500–8000) -# AGENT_LOG_JSON_PROMPT_PREFIX_CHARS=0 -# AGENT_LOG_JSON_PROMPT_PREFIX_ONLY_IF_LEN_GT=4000 -# 第三方 stdlib logging(空=自动:LOG_LEVEL 为 DEBUG/TRACE 时 Celery→INFO;否则 Celery 与 httpx 默认 WARNING;需原始框架行时设为 INFO) -# CELERY_LOG_LEVEL= -# HTTPX_LOG_LEVEL= -# 聚合用 JSONL(空=不写);与 stderr 并存,loguru serialize=True、按 20MB 切割、保留 7 天 -# LOG_JSON_FILE=/var/log/life-echo/app.jsonl - -# ============================================================================= -# LLM / DeepSeek -# ============================================================================= -DEEPSEEK_API_KEY=your_deepseek_api_key -DEEPSEEK_BASE_URL=https://api.deepseek.com -# 官方新模型名见 https://api-docs.deepseek.com/zh-cn/quick_start/pricing -DEEPSEEK_MODEL=deepseek-v4-flash -# v4-flash 主链路非思考须显式关(对齐旧版 deepseek-chat;默认 false) -# DEEPSEEK_THINKING_ENABLED=false - -# ============================================================================= -# Memory 向量(智谱 BigModel 国内 embedding-3;与 DeepSeek/OpenAI 用途分离) -# 文档:https://docs.bigmodel.cn/cn/guide/models/embedding/embedding-3 -# 本期固定 1024 维;库表经迁移与 MEMORY_EMBEDDING_DIMENSION 一致。 -# ============================================================================= -ZHIPU_API_KEY=your_zhipu_api_key -# 默认国内通用端点(与 ZhipuAiClient 一致) -# EMBEDDING_BASE_URL=https://open.bigmodel.cn/api/paas/v4 -EMBEDDING_MODEL=embedding-3 - -# Chat 访谈:每轮根据用户内容判定主人生阶段(关则仅用关键词,省一次 LLM) -# CHAT_STAGE_DETECTION_ENABLED=true -# CHAT_STAGE_DETECTION_MAX_TOKENS=128 -# 年代/流行文化联想块(config 默认 true;若减少「文艺硬接」可设 false) -# CHAT_ERA_CONTEXT_ENABLED=true -# 访谈性格(InterviewAgent):default | warm_listener | curious_guide(config 默认 default) -# CHAT_INTERVIEW_PERSONA=default -# 访谈回复长度档位(brief/standard/expanded)联动:极短输入 / 默认 / 长段+新细节(若与当前代码不一致以 config 为准) -# CHAT_INTERVIEW_BRIEF_MAX_TOKENS=240 -# CHAT_INTERVIEW_BRIEF_MAX_CHARS_PER_SEGMENT=180 -# CHAT_INTERVIEW_EXPANDED_MAX_TOKENS=400 -# CHAT_INTERVIEW_EXPANDED_MAX_CHARS_PER_SEGMENT=300 -# 访谈/开场采样温度(config 默认 0.93;偏「好访谈者」体验时可试 0.60~0.70) -# CHAT_INTERVIEW_TEMPERATURE=0.93 -# 访谈主回复:统一 max_tokens / 单段字数(代码截断) -# CHAT_INTERVIEW_MAX_TOKENS=512 -# CHAT_INTERVIEW_MAX_CHARS_PER_SEGMENT=380 -# CHAT_INTERVIEW_MAX_SEGMENTS=2 -# 访谈:是否按本轮用户话检索记忆并注入提示词(关则不调 retrieve) -# CHAT_MEMORY_RETRIEVAL_ENABLED=true -# CHAT_MEMORY_TOP_K=8 -# CHAT_MEMORY_EVIDENCE_MAX_CHARS=4096 -# 规则 TurnPlan 之后再调一轮 JSON focus planner(config 默认 false;开启则多一次 LLM) -# CHAT_REPLY_PLANNER_LLM_ENABLED=true -# CHAT_REPLY_PLANNER_MAX_TOKENS=256 -# CHAT_REPLY_PLANNER_TEMPERATURE=0.2 - -# Memoir:批处理/抽取更新 slot 时是否允许改写 MemoirState.current_stage(默认 false,访谈 switch_stage 仍可推进) -# True 时仅当 proposed 与 existing 在同一 chat_bucket 才对齐 current_stage -# MEMOIR_EXTRACTION_UPDATES_CURRENT_STAGE=false - -# Memoir:叙事前口述归一(segment 原文仍落库;仅 story 流水线派生输入) -# MEMOIR_ORAL_NORMALIZE_ENABLED=true -# off | rules | llm(llm 为先规则再 LLM 纠错,失败回退规则结果) -# MEMOIR_ORAL_NORMALIZE_MODE=llm -# MEMOIR_ORAL_NORMALIZE_LLM_MAX_TOKENS=512 -# MEMOIR_ORAL_NORMALIZE_LLM_MAX_INPUT_CHARS=8000 - -# Chat:模型消费净稿(segment 原文仍落库;访谈编排层归一后注入 Agent / 记忆检索) -# CHAT_INPUT_NORMALIZE_ENABLED=true -# off | rules | llm(llm 为先规则再 LLM;失败回退规则;编排层已带 LLM 时不重复在 Agent 调) -# CHAT_INPUT_NORMALIZE_MODE=rules -# CHAT_INPUT_NORMALIZE_LLM_MAX_TOKENS=512 -# CHAT_INPUT_NORMALIZE_LLM_MAX_INPUT_CHARS=8000 -# True:仅 is_from_voice 时走 LLM 纠错;键盘输入仅规则归一 -# CHAT_INPUT_NORMALIZE_LLM_VOICE_ONLY=true - -# Memoir Phase1:True 时用一次「批量 JSON」做抽取+分类(单段或多段均可;失败自动回退逐段)。 -# False 时始终逐段(与启用本开关前的行为一致,含防抖合并后的多段任务)。 -# MEMOIR_PHASE1_BATCH_LLM_ENABLED=false -# MEMOIR_PHASE1_BATCH_LLM_MAX_TOKENS=4096 - -# ============================================================================= -# Database -# ============================================================================= -# 本地开发(docker-compose.dev.yml 固定宿主端口 48291,避免与本机 5432 冲突) -# DATABASE_URL=postgresql://postgres:postgres@localhost:48291/life_echo -# Docker / 服务端(主机名一般为 compose 服务名 postgres): -# DATABASE_URL=postgresql://postgres:postgres@postgres:5432/life_echo +# ── Bootstrap ───────────────────────────────────────────────── +APP_ENV=development DATABASE_URL=postgresql://postgres:postgres@localhost:48291/life_echo -# 启动时 Alembic(main.py);生产可设 ALEMBIC_STARTUP_FAIL_FAST=true,迁移失败则拒绝启动 -# ALEMBIC_RUN_ON_STARTUP=true -# ALEMBIC_STARTUP_FAIL_FAST=false -# ALEMBIC_STARTUP_MAX_RETRIES=3 -# ALEMBIC_STARTUP_RETRY_BASE_SECONDS=1.0 - -# ============================================================================= -# Redis -# ============================================================================= -# 本地开发(docker-compose.dev.yml 固定宿主端口 48307,避免与本机 6379 冲突) -# REDIS_URL=redis://localhost:48307/0 -# Docker / 服务端: -# REDIS_URL=redis://redis:6379/0 REDIS_URL=redis://localhost:48307/0 -REDIS_SESSION_TTL=86400 +# 可选:Redis 密码(应用会自动注入 URL;本地 dev 通常留空) +# REDIS_PASSWORD= +# 可选:覆盖 Celery broker/backend URL(默认自动使用 REDIS_URL 的 DB+1) +# CELERY_REDIS_URL=redis://localhost:48307/1 -# Celery:ingest 后 Memory LLM 富化任务投递队列(须被 worker 消费;见 README) -# CELERY_MEMORY_ENRICHMENT_QUEUE=memory_idle +# Flower(docker-compose.yml 生产栈,仅 localhost:5555) +# FLOWER_USER=admin +# FLOWER_PASSWORD=changeme -# ============================================================================= -# Internal evaluation API(internal_main;development.sh 默认一并启动;与主 API 进程隔离) -# ============================================================================= -# 本地:`openssl rand -hex 32`;不用 internal eval 时可留空 -INTERNAL_EVAL_API_KEY= -# INTERNAL_EVAL_ENABLE_DOCS=1 -# 评测台选 DeepSeek 评审:默认 deepseek-v4-flash + 非思考(与 https://api-docs.deepseek.com/zh-cn/quick_start/pricing 一致) -# EVAL_JUDGE_DEEPSEEK_MODEL=deepseek-v4-flash -# 仅写 v4-flash 模型 id 时是否启用思考(弃用名 deepseek-reasoner 仍始终为思考) -# EVAL_JUDGE_DEEPSEEK_THINKING_ENABLED=false - -# ============================================================================= -# Memory compaction(近重复 memory chunk 软排除;Celery + Redis 防抖) -# 模板统一默认开启;须同时运行 celery worker 与 celery-beat(docker-compose 已含 beat,负责 memory_compaction_sweep)。 -# ============================================================================= -MEMORY_COMPACTION_ENABLED=true -# MEMORY_COMPACTION_DEBOUNCE_SECONDS=105 -# MEMORY_COMPACTION_LOCK_TTL_SECONDS=600 -# MEMORY_COMPACTION_CHUNK_SIMILARITY_THRESHOLD=0.92 -# MEMORY_COMPACTION_MIN_LAYERS_FOR_EXCLUDE=2 -# MEMORY_COMPACTION_MAX_CHUNKS_PER_RUN=200 -# MEMORY_COMPACTION_MAX_EXCLUDES_PER_RUN=50 -# MEMORY_COMPACTION_MAX_NEIGHBORS_PER_CHUNK=25 -# MEMORY_COMPACTION_TEXT_JACCARD_MIN=0.55 -# MEMORY_COMPACTION_METADATA_EVENT_YEAR_WINDOW=1 -# MEMORY_COMPACTION_SWEEP_RECENT_HOURS=24 - -# ============================================================================= -# Story 流水线(post-commit、章节物化、append 上限、evidence 检索) -# ============================================================================= -# STORY_IMAGE_ENQUEUE_DEDUP_TTL=300 -# RECOMPOSE_CHAPTER_DELAY_SECONDS=8 -# CHAPTER_PIPELINE_LOCK_TTL_SECONDS=120 -# STORY_APPEND_MAX_CANONICAL_CHARS=12000 -# STORY_APPEND_MAX_VERSIONS=20 -# EVIDENCE_TOP_K_DEFAULT=10 -# EVIDENCE_TOP_K_LARGE_BATCH=5 -# EVIDENCE_LARGE_BATCH_THRESHOLD=3 -# -# Memoir 可靠性(叙事 faithful、标题 slots、证据渗漏、Phase1→2 追踪) -# MEMOIR_FIDELITY_FAIL_OPEN_ON_PARSE_ERROR=false -# MEMOIR_NARRATIVE_EVIDENCE_OVERLAP_MIN_CHARS=14 -# MEMOIR_EVIDENCE_SCENE_ANCHOR_CHECK_ENABLED=true -# MEMOIR_TITLE_SLOTS_REQUIRE_BODY_OR_ORAL_MATCH=true -# MEMOIR_TITLE_HAY_GROUNDING_STRICT_PHRASES_ENABLED=true -# MEMOIR_RECOMPOSE_RETRY_ON_LOCK_CONTENTION=true -# MEMOIR_PHASE2_SINGLEFLIGHT_IMMEDIATE=true -# -# ============================================================================= -# Auth -# ============================================================================= -# 建议使用: openssl rand -hex 32 +# ── Auth secret ─────────────────────────────────────────────── +# 生产/staging 务必:openssl rand -hex 32 SECRET_KEY=replace_with_a_strong_random_secret -ALGORITHM=HS256 -ACCESS_TOKEN_EXPIRE_MINUTES=120 -# 内网评测:开启后可用 POST /api/auth/mock/sms-login(跳过短信);APP_ENV=production 时该路由仍返回 404 -# MOCK_SMS_LOGIN_ENABLED=1 -# ============================================================================= -# Tencent Cloud — 短信 -# ============================================================================= -# 短信、一句话 ASR/TTS、COS 为不同产品;同一主账号可共用同一对 SecretId/SecretKey(分别填三处)。 -TENCENT_SMS_SECRET_ID=your_tencent_sms_secret_id -TENCENT_SMS_SECRET_KEY=your_tencent_sms_secret_key -# 短信应用 SDK AppID -TENCENT_SMS_SDK_APP_ID=your_sms_sdk_app_id -# 短信签名内容(不包含【】符号) -TENCENT_SMS_SIGN_NAME=your_sms_sign_name -# 短信模板 ID -TENCENT_SMS_TEMPLATE_ID=your_sms_template_id -# 短信模板参数数量(1=仅验证码,2=验证码+过期时间) -# 若遇 TemplateParamSetNotMatchApprovedTemplate,请对照控制台模板配置 -TENCENT_SMS_TEMPLATE_PARAM_COUNT=1 +# ── LLM / Embedding 密钥 ───────────────────────────────────── +DEEPSEEK_API_KEY=your_deepseek_api_key +ZHIPU_API_KEY=your_zhipu_api_key -# ============================================================================= -# ASR Provider(whisper | tencent) -# ============================================================================= -ASR_PROVIDER=whisper +# ── 腾讯云凭证(SMS / ASR / TTS / COS 共用)────────────────── +TENCENT_SECRET_ID=your_tencent_secret_id +TENCENT_SECRET_KEY=your_tencent_secret_key -# ============================================================================= -# Whisper ASR(ASR_PROVIDER=whisper 时使用) -# ============================================================================= -ASR_MODEL_SIZE=small -ASR_DEVICE=cpu -ASR_COMPUTE_TYPE=int8 - -# GPU 环境(示例,按需启用) -# ASR_MODEL_SIZE=medium -# ASR_DEVICE=cuda -# ASR_COMPUTE_TYPE=float16 - -# ============================================================================= -# Tencent Cloud — 一句话 ASR + TTS(ASR_PROVIDER=tencent 或 TTS_PROVIDER=tencent) -# ============================================================================= -TENCENT_SECRET_ID=your_tencent_asr_secret_id -TENCENT_SECRET_KEY=your_tencent_asr_secret_key - -# ============================================================================= -# TTS(文字转语音,Agent 回复朗读)— 与 ASR 独立 -# ============================================================================= -# ENABLE_TTS:关闭时禁用「助手每轮自动生成 TTS」(tts_this_turn 链路);不影响 WebSocket「按需朗读」tts_request。 -# 每轮是否自动生成:客户端 `data.tts_this_turn`,且 ENABLE_TTS=true、skeleton skip_tts 均未阻止时才会合成。 -ENABLE_TTS=true -TTS_PROVIDER=tencent -# 仅 TTS_PROVIDER=openai 时需要 -# OPENAI_API_KEY= -# 音色 ID 见 https://cloud.tencent.com/document/product/1073/92668 -TTS_VOICE_TYPE=501004 -TTS_CODEC=mp3 - -# ============================================================================= -# WeChat Pay -# ============================================================================= -WECHAT_PAY_APP_ID=your_wechat_pay_app_id -WECHAT_PAY_MCH_ID=your_wechat_mch_id +# ── WeChat Pay 密钥 ─────────────────────────────────────────── WECHAT_PAY_API_V3_KEY=your_wechat_api_v3_key -# 商户私钥:推荐使用文件路径,避免 .env 中长 PEM 转义问题 -WECHAT_PAY_PRIVATE_KEY_PATH=certs/apiclient_key.pem -# 若不用文件,可配置 WECHAT_PAY_PRIVATE_KEY(PEM,换行用 \n) -# WECHAT_PAY_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----" -WECHAT_PAY_CERT_SERIAL_NO=your_wechat_cert_serial_no -WECHAT_PAY_NOTIFY_URL=https://your-domain.com/api/payment/notify/wechat -# 平台公钥模式(仅当无法走平台证书自动拉取时使用);勿填商户私钥路径 -# WECHAT_PAY_PLATFORM_PUBLIC_KEY_PATH=certs/wechat_platform_public_key.pem -# WECHAT_PAY_PLATFORM_PUBLIC_KEY_ID=your_wechat_platform_public_key_id +# WECHAT_PAY_PRIVATE_KEY= # 或使用 WECHAT_PAY_PRIVATE_KEY_PATH(见 config/*.toml deploy) -# ============================================================================= -# Alipay(未接入时可为空字符串) -# ============================================================================= -ALIPAY_APP_ID= +# ── Alipay 密钥(未接入可留空)──────────────────────────────── ALIPAY_PRIVATE_KEY= ALIPAY_PUBLIC_KEY= -ALIPAY_NOTIFY_URL=https://your-domain.com/api/payment/notify/alipay -# ============================================================================= -# Misc -# ============================================================================= -ENABLE_TEST_SUBSCRIPTION=0 - -# ============================================================================= -# Memoir image generation(Story 主图等;轮询 Liblib 任务) -# ============================================================================= -MEMOIR_IMAGE_ENABLED=false -MEMOIR_IMAGE_POLL_INTERVAL=3 -MEMOIR_IMAGE_MAX_ATTEMPTS=20 -MEMOIR_IMAGE_PROVIDER=liblib -MEMOIR_IMAGE_STYLE_DEFAULT=watercolor -MEMOIR_IMAGE_SIZE_DEFAULT=1280x720 -# 章节正文内至少多少张 asset:// 插图才生成/展示章节封面(默认 1=有一张正文图即可) -MEMOIR_MIN_INLINE_IMAGES_FOR_CHAPTER_COVER=1 -# Story 正文至少多少字才生成主图 intent / 调图(0=不限制) -STORY_IMAGE_MIN_BODY_CHARS=400 -# 叙事模型输出相对口述过短则回退为口述原文 -MEMOIR_NARRATIVE_FALLBACK_BODY_RATIO=0.5 -MEMOIR_NARRATIVE_FALLBACK_MIN_CHARS=20 -# 回忆录 segment 入队:累计 strip 后字数未达此值则暂缓提交 Celery(0=关闭字数门闸,仅静默防抖后提交) -# MEMOIR_SEGMENT_BATCH_MIN_CHARS=50 -# 本批首条入队起最长等待(秒),超时仍提交;测试可调低,生产可调高 -# MEMOIR_SEGMENT_BATCH_MAX_WAIT_SECONDS=60 -# 可选,Liblib 返回图片域名不在默认白名单时(逗号分隔) -# MEMOIR_IMAGE_DOWNLOAD_HOSTS=liblib.cloud,liblibai.cloud - -# ============================================================================= -# Liblib image provider -# ============================================================================= +# ── Liblib 密钥(memoir image,见 config deploy.memoir_image_enabled)──── LIBLIB_ACCESS_KEY=your_liblib_access_key LIBLIB_SECRET_KEY=your_liblib_secret_key -LIBLIB_BASE_URL=https://openapi.liblibai.cloud -LIBLIB_TEMPLATE_UUID=your_liblib_template_uuid -# ============================================================================= -# Tencent Cloud — COS(回忆录图片存储) -# ============================================================================= -TENCENT_COS_SECRET_ID=your_tencent_cos_secret_id -TENCENT_COS_SECRET_KEY=your_tencent_cos_secret_key -TENCENT_COS_REGION=ap-shanghai -TENCENT_COS_BUCKET=your_bucket_name -TENCENT_COS_BASE_URL=https://your_bucket_name.cos.ap-shanghai.myqcloud.com -# 可选临时凭证 -# TENCENT_COS_TOKEN= +# ── Internal evaluation API(可选)──────────────────────────── +INTERNAL_EVAL_API_KEY= + +# ── Docker Compose 端口(非 Settings,见 docker-compose.yml)── +# LIFE_ECHO_API_HOST_PORT=8000 diff --git a/api/.env.production b/api/.env.production index 6650d28..b8541ec 100644 --- a/api/.env.production +++ b/api/.env.production @@ -1,286 +1,27 @@ -# ============================================================================= -# Life Echo API — production(生产) -# -# 仓库维护本文件;production 发布时 workflow 会上传并复制为运行时 .env。 -# 若仓库可被非授权人员访问,请不要在此文件中保留真实密钥。 -# ============================================================================= +# Life Echo API — production secrets(运行时 .env) +# 非密钥项见 config/production.toml -# ============================================================================= -# Docker Compose(宿主机独立 Caddy 反代到本 API) -# ============================================================================= -# 映射到宿主机的端口,默认 8000;与同机其它项目冲突时改为未占用端口,并在独立 Caddy 的 Caddyfile 中 reverse_proxy 到 127.0.0.1:该端口。 -# LIFE_ECHO_API_HOST_PORT=8000 -# 若 Caddy 跑在独立容器且非 host 网络,不要用 127.0.0.1,应把 Caddy 加入与本 compose 相同的 Docker 网络,并对 http://life-echo-api-prod:8000 做 reverse_proxy。 - -# ============================================================================= -# Logging(loguru sink 最低级别:TRACE / DEBUG / INFO / WARNING / ERROR / CRITICAL) -# ============================================================================= -# 生产默认 INFO;勿长期 DEBUG。排障 Agent 耗时可短时 LOG_AGENT_VERBOSE=1。 -LOG_LEVEL=INFO -# Agent 单行 INFO 摘要;与 LOG_LEVEL 独立,便于生产短暂排查 -# LOG_AGENT_VERBOSE=0 -# DEBUG 下预览上限(默认 4096);0=全文 -# AGENT_LOG_MAX_CHARS=4096 -# DEBUG 下 *.prompt:preview | hash_only -# AGENT_LOG_PROMPT_MODE=preview -# AGENT_LOG_PROMPT_DEDUP=0 -# DEBUG 下访谈/资料:省略 SystemMessage 正文(仅 total_len+sha12);0/false=打出全文 -# AGENT_LOG_OMIT_SYSTEM_MESSAGE_BODY=1 -# DEBUG 下超长单段 *.prompt:先跳过前 N 字符再预览 -# AGENT_LOG_JSON_PROMPT_PREFIX_CHARS=0 -# AGENT_LOG_JSON_PROMPT_PREFIX_ONLY_IF_LEN_GT=4000 -# 第三方 stdlib logging(空=自动:LOG_LEVEL 为 DEBUG/TRACE 时 Celery→INFO、httpx/httpcore→WARNING,减少刷屏) -# CELERY_LOG_LEVEL= -# HTTPX_LOG_LEVEL= - -# ============================================================================= -# OpenTelemetry(生产;第二阶段 compose profile 接入后设 OTEL_ENABLED=true,见 docs/observability.md) -# 容器内 API/Celery → http://otel-collector:4317;勿用 localhost -# ============================================================================= -OTEL_ENABLED=false -OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317 -OTEL_EXPORTER_OTLP_INSECURE=true -OTEL_SERVICE_NAME=life-echo-api -OTEL_TRACES_SAMPLER=parentbased_traceidratio -OTEL_TRACES_SAMPLER_ARG=0.1 -# OTEL_METRIC_EXPORT_INTERVAL_MS=10000 - -# ============================================================================= -# LLM / DeepSeek -# ============================================================================= -DEEPSEEK_API_KEY=sk-09f17fb61c5a4299a3afc2a01de7af75 -DEEPSEEK_BASE_URL=https://api.deepseek.com -DEEPSEEK_MODEL=deepseek-v4-flash - -# ============================================================================= -# Memory 向量(智谱 BigModel 国内 embedding-3;与 DeepSeek/OpenAI 用途分离) -# 文档:https://docs.bigmodel.cn/cn/guide/models/embedding/embedding-3 -# 本期固定 1024 维;库表经迁移与 MEMORY_EMBEDDING_DIMENSION 一致。 -# ============================================================================= -ZHIPU_API_KEY=524eda18eb3848e881eefe4c7ef17ec2.xBmGUabYDEa44m3M -# 默认国内通用端点(与 ZhipuAiClient 一致) -# EMBEDDING_BASE_URL=https://open.bigmodel.cn/api/paas/v4 -EMBEDDING_MODEL=embedding-3 - -# Chat 访谈:每轮根据用户内容判定主人生阶段(关则仅用关键词,省一次 LLM) -# CHAT_STAGE_DETECTION_ENABLED=true -# CHAT_STAGE_DETECTION_MAX_TOKENS=128 -# 访谈者体验(覆盖 config 默认值;与 api/.env.development 对齐时可减少文风漂移与记忆噪声) -CHAT_ERA_CONTEXT_ENABLED=true -CHAT_INTERVIEW_PERSONA=warm_listener -CHAT_INTERVIEW_TEMPERATURE=0.65 -# 访谈:是否按本轮用户话检索记忆并注入提示词(关则不调 retrieve) -# CHAT_MEMORY_RETRIEVAL_ENABLED=true -CHAT_MEMORY_TOP_K=4 -CHAT_MEMORY_EVIDENCE_MAX_CHARS=1400 -CHAT_REPLY_PLANNER_LLM_ENABLED=true -# 访谈回复长度档位(brief/standard/expanded)联动:极短输入 / 默认 / 长段+新细节 -# CHAT_INTERVIEW_BRIEF_MAX_TOKENS=240 -# CHAT_INTERVIEW_BRIEF_MAX_CHARS_PER_SEGMENT=180 -# CHAT_INTERVIEW_EXPANDED_MAX_TOKENS=400 -# CHAT_INTERVIEW_EXPANDED_MAX_CHARS_PER_SEGMENT=300 - -# Memoir:批处理/抽取更新 slot 时是否允许改写 MemoirState.current_stage(默认 false,访谈 switch_stage 仍可推进) -# True 时仅当 proposed 与 existing 在同一 chat_bucket 才对齐 current_stage -# MEMOIR_EXTRACTION_UPDATES_CURRENT_STAGE=false - -# Memoir:叙事前口述归一(segment 原文仍落库;仅 story 流水线派生输入) -MEMOIR_ORAL_NORMALIZE_ENABLED=true -# off | rules | llm(llm 为先规则再 LLM 纠错,失败回退规则结果) -MEMOIR_ORAL_NORMALIZE_MODE=llm -MEMOIR_ORAL_NORMALIZE_LLM_MAX_TOKENS=512 -MEMOIR_ORAL_NORMALIZE_LLM_MAX_INPUT_CHARS=8000 - -# Chat:模型消费净稿(segment 原文仍落库;访谈编排层归一后注入 Agent / 记忆检索) -# CHAT_INPUT_NORMALIZE_ENABLED=true -# off | rules | llm(llm 为先规则再 LLM;失败回退规则;编排层已带 LLM 时不重复在 Agent 调) -# CHAT_INPUT_NORMALIZE_MODE=rules -# CHAT_INPUT_NORMALIZE_LLM_MAX_TOKENS=512 -# CHAT_INPUT_NORMALIZE_LLM_MAX_INPUT_CHARS=8000 - -# ============================================================================= -# Database -# ============================================================================= -# 本地开发: -# DATABASE_URL=postgresql://postgres:postgres@localhost:5432/life_echo -# Docker / 服务端(主机名一般为 compose 服务名 postgres): -# DATABASE_URL=postgresql://postgres:postgres@postgres:5432/life_echo +APP_ENV=production DATABASE_URL=postgresql://postgres:postgres@postgres:5432/life_echo -# 启动时 Alembic(main.py);生产可设 ALEMBIC_STARTUP_FAIL_FAST=true,迁移失败则拒绝启动 -# ALEMBIC_RUN_ON_STARTUP=true -# ALEMBIC_STARTUP_FAIL_FAST=false -# ALEMBIC_STARTUP_MAX_RETRIES=3 -# ALEMBIC_STARTUP_RETRY_BASE_SECONDS=1.0 -# ============================================================================= -# Redis -# ============================================================================= -# 本地开发: -# REDIS_URL=redis://localhost:6379/0 -# Docker / 服务端: -# REDIS_URL=redis://redis:6379/0 +# Redis:业务 DB/0;Celery 自动 DB/1;compose redis 使用 REDIS_PASSWORD 作为 requirepass REDIS_URL=redis://redis:6379/0 -REDIS_SESSION_TTL=86400 +REDIS_PASSWORD=replace_with_strong_redis_password +# CELERY_REDIS_URL=redis://:replace_with_strong_redis_password@redis:6379/1 +FLOWER_USER=admin +FLOWER_PASSWORD=replace_with_strong_flower_password -# Celery:ingest 后 Memory LLM 富化任务投递队列(须被 worker 消费;见 README) -# CELERY_MEMORY_ENRICHMENT_QUEUE=memory_idle - -# ============================================================================= -# Memory compaction(近重复 memory chunk 软排除;Celery + Redis 防抖) -# 与 .env.example / .env.development 一致默认开启;需 running:celery worker + celery-beat(见 docker-compose.yml)。 -# ============================================================================= -MEMORY_COMPACTION_ENABLED=true -# MEMORY_COMPACTION_DEBOUNCE_SECONDS=105 -# MEMORY_COMPACTION_LOCK_TTL_SECONDS=600 -# MEMORY_COMPACTION_CHUNK_SIMILARITY_THRESHOLD=0.92 -# MEMORY_COMPACTION_MIN_LAYERS_FOR_EXCLUDE=2 -# MEMORY_COMPACTION_MAX_CHUNKS_PER_RUN=200 -# MEMORY_COMPACTION_MAX_EXCLUDES_PER_RUN=50 -# MEMORY_COMPACTION_MAX_NEIGHBORS_PER_CHUNK=25 -# MEMORY_COMPACTION_TEXT_JACCARD_MIN=0.55 -# MEMORY_COMPACTION_METADATA_EVENT_YEAR_WINDOW=1 -# MEMORY_COMPACTION_SWEEP_RECENT_HOURS=24 - -# ============================================================================= -# Story 流水线(post-commit、章节物化、append 上限、evidence 检索) -# ============================================================================= -# STORY_IMAGE_ENQUEUE_DEDUP_TTL=300 -# RECOMPOSE_CHAPTER_DELAY_SECONDS=8 -# CHAPTER_PIPELINE_LOCK_TTL_SECONDS=120 -# STORY_APPEND_MAX_CANONICAL_CHARS=12000 -# STORY_APPEND_MAX_VERSIONS=20 -# EVIDENCE_TOP_K_DEFAULT=10 -# EVIDENCE_TOP_K_LARGE_BATCH=5 -# EVIDENCE_LARGE_BATCH_THRESHOLD=3 - -# ============================================================================= -# Auth -# ============================================================================= -# 建议使用: openssl rand -hex 32 SECRET_KEY=cf47555c7ecbe5ddb7fd2113c59e08a8bcb110810c42f7c644e06a5acc898608 -ALGORITHM=HS256 -ACCESS_TOKEN_EXPIRE_MINUTES=120 -# ============================================================================= -# Tencent Cloud — 短信 -# ============================================================================= -# 短信、一句话 ASR/TTS、COS 为不同产品;同一主账号可共用同一对 SecretId/SecretKey(分别填三处)。 -TENCENT_SMS_SECRET_ID=AKIDa2ILCwUr56uVt31oU0JOHxPfGhvvkLiq -TENCENT_SMS_SECRET_KEY=xiFbjlZ9XheS2NWYLvHRPAh2A5nGYcR2 -# 短信应用 SDK AppID -TENCENT_SMS_SDK_APP_ID=1401010099 -# 短信签名内容(不包含【】符号) -TENCENT_SMS_SIGN_NAME=上海华嘎科技有限公司 -# 短信模板 ID -TENCENT_SMS_TEMPLATE_ID=2592163 -# 短信模板参数数量(1=仅验证码,2=验证码+过期时间) -# 若遇 TemplateParamSetNotMatchApprovedTemplate,请对照控制台模板配置 -TENCENT_SMS_TEMPLATE_PARAM_COUNT=1 +DEEPSEEK_API_KEY=sk-09f17fb61c5a4299a3afc2a01de7af75 +ZHIPU_API_KEY=524eda18eb3848e881eefe4c7ef17ec2.xBmGUabYDEa44m3M -# ============================================================================= -# ASR Provider(whisper | tencent) -# ============================================================================= -ASR_PROVIDER=tencent - -# ============================================================================= -# Whisper ASR(ASR_PROVIDER=whisper 时使用) -# ============================================================================= -ASR_MODEL_SIZE=small -ASR_DEVICE=cpu -ASR_COMPUTE_TYPE=int8 - -# GPU 环境(示例,按需启用) -# ASR_MODEL_SIZE=medium -# ASR_DEVICE=cuda -# ASR_COMPUTE_TYPE=float16 - -# ============================================================================= -# Tencent Cloud — 一句话 ASR + TTS(ASR_PROVIDER=tencent 或 TTS_PROVIDER=tencent) -# ============================================================================= TENCENT_SECRET_ID=AKIDa2ILCwUr56uVt31oU0JOHxPfGhvvkLiq TENCENT_SECRET_KEY=xiFbjlZ9XheS2NWYLvHRPAh2A5nGYcR2 -# ============================================================================= -# TTS(文字转语音,Agent 回复朗读)— 与 ASR 独立 -# ============================================================================= -# ENABLE_TTS:是否启用「助手回复朗读」服务端能力(TTS 适配器与密钥配置)。关则永远不合成。 -# 每轮是否实际合成:由客户端在 WebSocket `text` / `audio_segment` / `audio_message` 的 `data.tts_this_turn` 控制(未传或 false 仅返回文字)。 -# 若 ENABLE_TTS=true 且该轮 `tts_this_turn=true`:每一段助手文案先下发 `tts_audio`,再下发对应段的 `agent_response`。 -ENABLE_TTS=true -TTS_PROVIDER=tencent -# 仅 TTS_PROVIDER=openai 时需要(填控制台密钥;勿在注释行写 =your_* 以免旧版 CI 误匹配) -# OPENAI_API_KEY= -# 音色 ID 见 https://cloud.tencent.com/document/product/1073/92668 -TTS_VOICE_TYPE=501004 -TTS_CODEC=mp3 - -# ============================================================================= -# WeChat Pay -# ============================================================================= -WECHAT_PAY_APP_ID=wx1df508452e06cfb8 -WECHAT_PAY_MCH_ID=1662979099 WECHAT_PAY_API_V3_KEY=xjvGSJLGJAJfjgskfjslafjsajsdjals -# 商户私钥:推荐使用文件路径,避免 .env 中长 PEM 转义问题 -WECHAT_PAY_PRIVATE_KEY_PATH=certs/apiclient_key.pem -# 若不用文件,可配置 WECHAT_PAY_PRIVATE_KEY(PEM,换行用 \n) -# WECHAT_PAY_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----" -WECHAT_PAY_CERT_SERIAL_NO=1AA82328AC1456C6F115B014606F22CD621D2032 -WECHAT_PAY_NOTIFY_URL=https://lifecho.worldsplats.com/api/payment/notify/wechat -# 平台公钥模式(仅当无法走平台证书自动拉取时使用);勿填商户私钥路径 -# WECHAT_PAY_PLATFORM_PUBLIC_KEY_PATH=certs/wechat_platform_public_key.pem -# WECHAT_PAY_PLATFORM_PUBLIC_KEY_ID=PUB_KEY_ID_0116629790992026020700181671002400 -# ============================================================================= -# Alipay(未接入时保持空字符串,与 Settings 默认一致) -# ============================================================================= -ALIPAY_APP_ID= -ALIPAY_PRIVATE_KEY= -ALIPAY_PUBLIC_KEY= -ALIPAY_NOTIFY_URL=https://lifecho.worldsplats.com/api/payment/notify/alipay - -# ============================================================================= -# Misc -# ============================================================================= -ENABLE_TEST_SUBSCRIPTION=1 - -# ============================================================================= -# Memoir image generation(Story 主图等;轮询 Liblib 任务) -# ============================================================================= -MEMOIR_IMAGE_ENABLED=true -MEMOIR_IMAGE_POLL_INTERVAL=3 -MEMOIR_IMAGE_MAX_ATTEMPTS=20 -MEMOIR_IMAGE_PROVIDER=liblib -MEMOIR_IMAGE_STYLE_DEFAULT=watercolor -MEMOIR_IMAGE_SIZE_DEFAULT=1280x720 -# 章节正文内至少多少张 asset:// 插图才生成/展示章节封面(≥1 即有一张图可出封面) -MEMOIR_MIN_INLINE_IMAGES_FOR_CHAPTER_COVER=1 -# Story 正文至少多少字才生成主图 intent / 调图(0=不限制) -STORY_IMAGE_MIN_BODY_CHARS=800 -# 叙事模型输出相对口述过短则回退为口述原文 -MEMOIR_NARRATIVE_FALLBACK_BODY_RATIO=0.5 -MEMOIR_NARRATIVE_FALLBACK_MIN_CHARS=20 -# 回忆录 segment 入队:累计 strip 后字数未达此值则暂缓提交 Celery(0=关闭字数门闸,仅静默防抖后提交) -# MEMOIR_SEGMENT_BATCH_MIN_CHARS=50 -# 本批首条入队起最长等待(秒),超时仍提交;测试可调低,生产可调高 -# MEMOIR_SEGMENT_BATCH_MAX_WAIT_SECONDS=60 -# 可选,Liblib 返回图片域名不在默认白名单时(逗号分隔) -# MEMOIR_IMAGE_DOWNLOAD_HOSTS=liblib.cloud,liblibai.cloud - -# ============================================================================= -# Liblib image provider -# ============================================================================= LIBLIB_ACCESS_KEY=zrDp6quCOKlLwcewOEfrog LIBLIB_SECRET_KEY=iTVHo5Nf3KA-xpC1Mja80bC93u6chJem -LIBLIB_BASE_URL=https://openapi.liblibai.cloud -LIBLIB_TEMPLATE_UUID=5d7e67009b344550bc1aa6ccbfa1d7f4 -# ============================================================================= -# Tencent Cloud — COS(回忆录图片存储) -# ============================================================================= -TENCENT_COS_SECRET_ID=AKIDa2ILCwUr56uVt31oU0JOHxPfGhvvkLiq -TENCENT_COS_SECRET_KEY=xiFbjlZ9XheS2NWYLvHRPAh2A5nGYcR2 -TENCENT_COS_REGION=ap-shanghai -TENCENT_COS_BUCKET=life-echo-prod-1319381411 -TENCENT_COS_BASE_URL=https://life-echo-prod-1319381411.cos.ap-shanghai.myqcloud.com -# 可选临时凭证 -# TENCENT_COS_TOKEN= +INTERNAL_EVAL_API_KEY= diff --git a/api/.env.staging b/api/.env.staging index fa1ed8b..aec2aa5 100644 --- a/api/.env.staging +++ b/api/.env.staging @@ -1,187 +1,31 @@ -# ============================================================================= -# Life Echo API — staging(预发) -# -# 基于 .env.production 生成;staging 发布时 workflow 会上传并复制为运行时 .env。 -# 不要把生产密钥误填进本文件(当前与 production 共用同一套三方密钥)。 -# ============================================================================= +# Life Echo API — staging secrets(运行时 .env) +# 非密钥项见 config/staging.toml -# ============================================================================= -# Docker Compose(宿主机独立 Caddy 反代到本 API) -# ============================================================================= LIFE_ECHO_API_HOST_BIND=0.0.0.0 LIFE_ECHO_API_HOST_PORT=8000 POSTGRES_HOST_PORT=15432 -# ============================================================================= -# Logging(loguru sink 最低级别:TRACE / DEBUG / INFO / WARNING / ERROR / CRITICAL) -# ============================================================================= -LOG_LEVEL=INFO -# Agent 单行 INFO 摘要;与 LOG_LEVEL 独立 -# LOG_AGENT_VERBOSE=0 -# DEBUG 下预览上限(默认 4096);0=全文 -# AGENT_LOG_MAX_CHARS=4096 -# DEBUG 下 *.prompt:preview | hash_only -# AGENT_LOG_PROMPT_MODE=preview -# AGENT_LOG_PROMPT_DEDUP=0 -# DEBUG 下访谈/资料:省略 SystemMessage 正文(仅 total_len+sha12);0/false=打出全文 -# AGENT_LOG_OMIT_SYSTEM_MESSAGE_BODY=1 -# DEBUG 下超长单段 *.prompt:先跳过前 N 字符再预览 -# AGENT_LOG_JSON_PROMPT_PREFIX_CHARS=0 -# AGENT_LOG_JSON_PROMPT_PREFIX_ONLY_IF_LEN_GT=4000 -# 第三方 stdlib logging(空=自动) -# CELERY_LOG_LEVEL= -# HTTPX_LOG_LEVEL= - -# ============================================================================= -# OpenTelemetry(预发;compose 接入 LGTM 后设 OTEL_ENABLED=true,见 docs/observability.md) -# API/Celery 容器内 endpoint 用服务名;Grafana 宿主机端口见 observability compose(默认 48300 等) -# ============================================================================= -OTEL_ENABLED=false -OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317 -OTEL_EXPORTER_OTLP_INSECURE=true -OTEL_SERVICE_NAME=life-echo-api -OTEL_TRACES_SAMPLER=parentbased_traceidratio -OTEL_TRACES_SAMPLER_ARG=0.1 -# OTEL_METRIC_EXPORT_INTERVAL_MS=10000 - -# ============================================================================= -# LLM / DeepSeek -# ============================================================================= -DEEPSEEK_API_KEY=sk-09f17fb61c5a4299a3afc2a01de7af75 -DEEPSEEK_BASE_URL=https://api.deepseek.com -DEEPSEEK_MODEL=deepseek-v4-flash - -# ============================================================================= -# Memory 向量(智谱 BigModel 国内 embedding-3;与 DeepSeek/OpenAI 用途分离) -# ============================================================================= -ZHIPU_API_KEY=524eda18eb3848e881eefe4c7ef17ec2.xBmGUabYDEa44m3M -EMBEDDING_MODEL=embedding-3 - -# Chat 访谈 -CHAT_ERA_CONTEXT_ENABLED=true -CHAT_INTERVIEW_PERSONA=warm_listener -CHAT_INTERVIEW_TEMPERATURE=0.65 -CHAT_MEMORY_TOP_K=4 -CHAT_MEMORY_EVIDENCE_MAX_CHARS=1400 -CHAT_REPLY_PLANNER_LLM_ENABLED=true - -# Memoir:叙事前口述归一 -MEMOIR_ORAL_NORMALIZE_ENABLED=true -MEMOIR_ORAL_NORMALIZE_MODE=llm -MEMOIR_ORAL_NORMALIZE_LLM_MAX_TOKENS=512 -MEMOIR_ORAL_NORMALIZE_LLM_MAX_INPUT_CHARS=8000 - -# ============================================================================= -# Database -# ============================================================================= +APP_ENV=staging DATABASE_URL=postgresql://postgres:postgres@postgres:5432/life_echo -# ============================================================================= -# Redis -# ============================================================================= +# Redis:业务 DB/0;Celery 自动 DB/1;compose redis 使用 REDIS_PASSWORD 作为 requirepass REDIS_URL=redis://redis:6379/0 -REDIS_SESSION_TTL=86400 +REDIS_PASSWORD=replace_with_strong_redis_password +# CELERY_REDIS_URL=redis://:replace_with_strong_redis_password@redis:6379/1 +FLOWER_USER=admin +FLOWER_PASSWORD=replace_with_strong_flower_password -# ============================================================================= -# Memory compaction -# ============================================================================= -MEMORY_COMPACTION_ENABLED=true - -# ============================================================================= -# Auth -# ============================================================================= SECRET_KEY=cf47555c7ecbe5ddb7fd2113c59e08a8bcb110810c42f7c644e06a5acc898608 -ALGORITHM=HS256 -ACCESS_TOKEN_EXPIRE_MINUTES=120 -APP_ENV=staging -MOCK_SMS_LOGIN_ENABLED=1 -# ============================================================================= -# Tencent Cloud — 短信 -# ============================================================================= -TENCENT_SMS_SECRET_ID=AKIDa2ILCwUr56uVt31oU0JOHxPfGhvvkLiq -TENCENT_SMS_SECRET_KEY=xiFbjlZ9XheS2NWYLvHRPAh2A5nGYcR2 -TENCENT_SMS_SDK_APP_ID=1401010099 -TENCENT_SMS_SIGN_NAME=上海华嘎科技有限公司 -TENCENT_SMS_TEMPLATE_ID=2592163 -TENCENT_SMS_TEMPLATE_PARAM_COUNT=1 +DEEPSEEK_API_KEY=sk-09f17fb61c5a4299a3afc2a01de7af75 +ZHIPU_API_KEY=524eda18eb3848e881eefe4c7ef17ec2.xBmGUabYDEa44m3M -# ============================================================================= -# ASR Provider(whisper | tencent) -# ============================================================================= -ASR_PROVIDER=tencent - -# ============================================================================= -# Whisper ASR(ASR_PROVIDER=whisper 时使用) -# ============================================================================= -ASR_MODEL_SIZE=small -ASR_DEVICE=cpu -ASR_COMPUTE_TYPE=int8 - -# ============================================================================= -# Tencent Cloud — 一句话 ASR + TTS -# ============================================================================= TENCENT_SECRET_ID=AKIDa2ILCwUr56uVt31oU0JOHxPfGhvvkLiq TENCENT_SECRET_KEY=xiFbjlZ9XheS2NWYLvHRPAh2A5nGYcR2 -# ============================================================================= -# TTS -# ============================================================================= -ENABLE_TTS=true -TTS_PROVIDER=tencent -TTS_VOICE_TYPE=501004 -TTS_CODEC=mp3 - -# ============================================================================= -# WeChat Pay -# ============================================================================= -WECHAT_PAY_APP_ID=wx1df508452e06cfb8 -WECHAT_PAY_MCH_ID=1662979099 WECHAT_PAY_API_V3_KEY=xjvGSJLGJAJfjgskfjslafjsajsdjals -WECHAT_PAY_PRIVATE_KEY_PATH=certs/apiclient_key.pem -WECHAT_PAY_CERT_SERIAL_NO=1AA82328AC1456C6F115B014606F22CD621D2032 -WECHAT_PAY_NOTIFY_URL=https://lifecho.worldsplats.com/api/payment/notify/wechat -# ============================================================================= -# Alipay(未接入时保持空字符串) -# ============================================================================= -ALIPAY_APP_ID= -ALIPAY_PRIVATE_KEY= -ALIPAY_PUBLIC_KEY= -ALIPAY_NOTIFY_URL=https://lifecho.worldsplats.com/api/payment/notify/alipay - -# ============================================================================= -# Misc -# ============================================================================= -ENABLE_TEST_SUBSCRIPTION=1 - -# ============================================================================= -# Memoir image generation -# ============================================================================= -MEMOIR_IMAGE_ENABLED=true -MEMOIR_IMAGE_POLL_INTERVAL=3 -MEMOIR_IMAGE_MAX_ATTEMPTS=20 -MEMOIR_IMAGE_PROVIDER=liblib -MEMOIR_IMAGE_STYLE_DEFAULT=watercolor -MEMOIR_IMAGE_SIZE_DEFAULT=1280x720 -MEMOIR_MIN_INLINE_IMAGES_FOR_CHAPTER_COVER=1 -STORY_IMAGE_MIN_BODY_CHARS=800 -MEMOIR_NARRATIVE_FALLBACK_BODY_RATIO=0.5 -MEMOIR_NARRATIVE_FALLBACK_MIN_CHARS=20 - -# ============================================================================= -# Liblib image provider -# ============================================================================= LIBLIB_ACCESS_KEY=zrDp6quCOKlLwcewOEfrog LIBLIB_SECRET_KEY=iTVHo5Nf3KA-xpC1Mja80bC93u6chJem -LIBLIB_BASE_URL=https://openapi.liblibai.cloud -LIBLIB_TEMPLATE_UUID=5d7e67009b344550bc1aa6ccbfa1d7f4 -# ============================================================================= -# Tencent Cloud — COS(回忆录图片存储) -# ============================================================================= -TENCENT_COS_SECRET_ID=AKIDa2ILCwUr56uVt31oU0JOHxPfGhvvkLiq -TENCENT_COS_SECRET_KEY=xiFbjlZ9XheS2NWYLvHRPAh2A5nGYcR2 -TENCENT_COS_REGION=ap-shanghai -TENCENT_COS_BUCKET=life-echo-dev-1319381411 -TENCENT_COS_BASE_URL=https://life-echo-dev-1319381411.cos.ap-shanghai.myqcloud.com +INTERNAL_EVAL_API_KEY= diff --git a/api/README.md b/api/README.md index 715351f..508e35d 100644 --- a/api/README.md +++ b/api/README.md @@ -17,9 +17,9 @@ Life Echo API 是一个智能对话系统,通过 WebSocket 实时连接,使 - **JSON 模式**:结构化抽取/路由/叙事 JSON 使用 `app/core/langchain_llm.py` 的 `bind_json_object_mode`(与 [DeepSeek JSON Output](https://api-docs.deepseek.com/guides/json_mode) 一致);详见 [`docs/llm-json-mode.md`](docs/llm-json-mode.md)。适配器说明见 [`app/adapters/llm/deepseek.py`](app/adapters/llm/deepseek.py)。 - **记忆检索**:异步与 Celery 均使用 **向量(pgvector)** chunks,见 [`docs/memory-retrieval.md`](docs/memory-retrieval.md)(含 async/sync **行为矩阵**)。 - **AI 相关代码扫描**:`uv run python scripts/ai_touchpoints_scan.py --markdown docs/ai-touchpoints.md`(在 `api/` 目录下执行)生成带标签的触点列表,见 [`docs/ai-touchpoints.md`](docs/ai-touchpoints.md)。 -- **与 AI 强相关的配置项(摘录)**:`CHAT_MEMORY_RETRIEVAL_ENABLED` / `MEMOIR_PHASE1_BATCH_LLM_ENABLED` / `MEMORY_ENRICHMENT_ENABLED` / `MEMORY_EVIDENCE_EMPTY_QUERY_INCLUDE_ROLLING` 等见 `app/core/config.py`;调参时建议对照 [`docs/memory-retrieval.md`](docs/memory-retrieval.md) 与 [`docs/ai-touchpoints.md`](docs/ai-touchpoints.md)。 -- **Memory compaction**:`.env.example` / [`.env.development`](.env.development) / [`.env.staging`](.env.staging) / [`.env.production`](.env.production) 均默认 `MEMORY_COMPACTION_ENABLED=true`。须运行 **Celery worker** 与 **celery-beat**([`docker-compose.yml`](docker-compose.yml) 已包含 `celery-beat`,用于定期 `memory_compaction_sweep`)。 -- **Memory LLM enrichment(单次 LLM:会话摘要 + 事实)**:任务路由到 **`memory_idle`** 队列(`CELERY_MEMORY_ENRICHMENT_QUEUE`,默认 `memory_idle`)。本地与 compose 内 worker 已使用 `-Q celery,memory_idle`;生产可单独起低并发 worker 只消费 `memory_idle`,与主队列隔离。 +- **与 AI 强相关的配置项**:产品调参 SSOT 为 [`config/*.toml`](config/default.toml)(经 `app/features/*/constants.py` 与 `app/core/runtime_constants.py` re-export);密钥见 [`.env.example`](.env.example)。详见 [`docs/configuration.md`](docs/configuration.md)。 +- **Memory compaction**:默认在 `config/default.toml` → `[memory]` 中开启。须运行 **Celery worker** 与 **celery-beat**([`docker-compose.yml`](docker-compose.yml) 已包含 `celery-beat`,用于定期 `memory_compaction_sweep`)。 +- **Memory LLM enrichment(单次 LLM:会话摘要 + 事实)**:任务路由到 **`memory_idle`** 队列(`config/default.toml` → `[celery] memory_enrichment_queue`)。本地与 compose 内 worker 已使用 `-Q celery,memory_idle`;生产可单独起低并发 worker 只消费 `memory_idle`,与主队列隔离。 ## 技术栈 @@ -41,7 +41,7 @@ docker compose -f docker-compose.dev.yml -f docker-compose.observability.yml up # Grafana: http://127.0.0.1:48300 ``` -在 `.env` 中配置 `OTEL_*`(见 [`.env.example`](.env.example)),与 Postgres/Redis 一样由 Settings 加载,无需 shell export。详见 [`docs/observability.md`](docs/observability.md)。 +在 `config/*.toml` 的 `[deploy]` 中配置 `otel_enabled` 与 `otel_exporter_otlp_endpoint`;采样策略等细项见 `[otel]` section 与 [`docs/observability.md`](docs/observability.md)。 ## 项目结构 @@ -51,7 +51,9 @@ api/ ├── app/ # 应用主包 │ ├── main.py # FastAPI 应用定义 │ ├── core/ # 核心基础设施 -│ │ ├── config.py # 配置(pydantic-settings) +│ │ ├── config.py # secrets / bootstrap(Settings + facade) +│ │ ├── app_config*.py # TOML 加载与 AppConfig +│ │ ├── runtime_constants.py # re-export config/*.toml runtime sections │ │ ├── db.py # 数据库连接 │ │ ├── redis.py # Redis 服务 │ │ ├── security.py # JWT、密码哈希 @@ -84,47 +86,42 @@ cd api pip install -r requirements.txt ``` -### 2. 环境变量配置 +### 2. 配置(TOML + .env) -创建 `.env` 文件(在 `api/` 目录下): +配置分两层 SSOT,详见 **[docs/configuration.md](docs/configuration.md)**。 + +| 层 | 来源 | 内容 | +|----|------|------| +| **Secrets / bootstrap** | [`.env.example`](.env.example) | `DATABASE_URL`、`SECRET_KEY`、API/支付/Liblib 密钥 | +| **非密钥** | [`config/default.toml`](config/default.toml) + `config/{APP_ENV}.toml` | 功能开关、SMS 模板、Chat/Memoir/Memory/Eval 调参、OTel 等 | + +本地开发:`.env.development`(密钥)+ `config/development.toml`(行为);`development.sh` 将前者同步为 `.env`。预发/生产:`.env.staging` / `.env.production` + 对应 `config/*.toml`。 + +最小 `.env` 示例: ```env -# DeepSeek API 配置(推荐,优先使用) -DEEPSEEK_API_KEY=your_deepseek_api_key_here -DEEPSEEK_BASE_URL=https://api.deepseek.com # 可选,默认值 -DEEPSEEK_MODEL=deepseek-chat # 可选,默认值 - -# 或使用通用 LLM 配置(支持其他兼容 OpenAI 的 LLM) -LLM_API_KEY=your_llm_api_key_here -LLM_BASE_URL=https://api.your-llm-provider.com # 可选 -LLM_MODEL=your-model-name # 可选,默认 deepseek-chat -LLM_TEMPERATURE=0.7 # 可选,默认 0.7 - -# 数据库配置(本地用 docker-compose.dev.yml 时为固定端口 48291,见下文「本地开发」) +APP_ENV=development DATABASE_URL=postgresql://postgres:postgres@localhost:48291/life_echo - -# Redis 配置(本地 compose.dev 固定端口 48307) REDIS_URL=redis://localhost:48307/0 - -# 认证配置 -SECRET_KEY=your-secret-key-here # JWT签名密钥(建议使用随机字符串) -ALGORITHM=HS256 # JWT算法(默认HS256) -ACCESS_TOKEN_EXPIRE_MINUTES=120 # 访问令牌过期时间(分钟,默认120即2小时) - -# 服务器配置(可选) -HOST=0.0.0.0 -PORT=8000 +SECRET_KEY=your-secret-key-here +DEEPSEEK_API_KEY=sk-... +ZHIPU_API_KEY=... +TENCENT_SECRET_ID=... +TENCENT_SECRET_KEY=... ``` -**LLM 配置优先级**: -1. `DEEPSEEK_API_KEY` - 优先使用 DeepSeek(推荐) -2. `LLM_API_KEY` - 通用 LLM 配置(支持其他兼容 OpenAI 格式的 LLM) +**腾讯云**:凭证仍在 env(`TENCENT_SECRET_ID/KEY`);短信模板 ID、COS 桶名等在 `config/*.toml` 的 `[deploy]` section。 -**DeepSeek 配置示例**: -```env -DEEPSEEK_API_KEY=sk-xxxxxxxxxxxxx -DEEPSEEK_MODEL=deepseek-chat -``` +业务代码读取 TOML 值仍可用原有 import(re-export): + +| 模块 | 路径 | +|------|------| +| 访谈 / 聊天 | `app/features/conversation/constants.py` → `chat` | +| 回忆录流水线 | `app/features/memoir/constants.py` → `memoir` | +| Story / 章节 | `app/features/story/constants.py` → `story` | +| 记忆富化 / compaction | `app/features/memory/constants.py` → `memory` | +| 内网评测 | `app/features/evaluation/constants.py` → `eval_cfg` | +| ASR/TTS/LLM/Celery 等 | `app/core/runtime_constants.py` | ### 3. 数据库迁移 @@ -290,11 +287,13 @@ Content-Type: application/json ```json { "access_token": "new-access-token", - "refresh_token": "same-refresh-token", + "refresh_token": "new-refresh-token-string", "token_type": "bearer" } ``` +每次刷新会轮换 refresh token(返回新的 refresh token,旧 token 立即失效)。在 `REFRESH_TOKEN_REUSE_GRACE_SECONDS`(默认 30 秒)窗口内重复使用已轮换的旧 token 视为幂等重试,返回新 access token 与当前 replacement refresh token;grace 窗口外再次使用则吊销该用户全部会话并返回 `REFRESH_TOKEN_REUSE`。 + ##### 用户登出 ```http POST /api/auth/logout @@ -333,6 +332,33 @@ Authorization: Bearer {access_token} Authorization: Bearer {access_token} ``` +### HTTP 错误契约 + +所有 HTTP 错误响应均为 `application/json`,统一格式: + +```json +{ + "error_code": "NOT_FOUND", + "message": "资源不存在", + "request_id": "550e8400-e29b-41d4-a716-446655440000" +} +``` + +- `error_code`:机器可读错误码(见 OpenAPI `ErrorResponse` / `ErrorCode` 组件) +- `message`:面向用户的说明 +- `request_id`:与响应头 `X-Request-Id` 一致,便于排查 + +**429 状态码语义**:HTTP 429 被两种错误码共用,客户端必须根据 `error_code` 分支,不能只看 status: + +| error_code | 含义 | +|------------|------| +| `QUOTA_EXCEEDED` | 配额已用尽(如对话次数) | +| `RATE_LIMITED` | 请求频率超限(如 SMS 发送冷却) | + +遗留 `HTTPException(status_code=429)` 默认映射为 `RATE_LIMITED`。 + +**CORS 与 credentials**:`api_cors_origins` 留空时,服务端使用 `allow_origins=["*"]` 且 `allow_credentials=False`;生产/staging 必须在 `config/staging.toml` / `config/production.toml` 的 `[deploy]` 中配置逗号分隔的前端域名。 + ### REST API #### 对话管理 (`/api/conversations`) @@ -406,6 +432,23 @@ ws://localhost:8000/ws/conversation/{conversation_id}?token={access_token} } ``` +**WebSocket 错误消息**(与 HTTP 错误契约不同,勿混用 `parseApiError`): + +服务端 → 客户端(配额不足等): +```json +{ + "type": "error", + "data": { + "message": "本月对话次数已用尽", + "code": "QUOTA_EXCEEDED" + }, + "timestamp": "2024-01-15T10:00:00Z" +} +``` + +- WS 帧使用 `data.code`(如 `QUOTA_EXCEEDED`),**不是** HTTP 的 `error_code` 字段 +- HTTP 客户端错误解析器(`parseApiError`)不适用于 WebSocket 消息 + ## 数据库模型 ### User(用户) @@ -525,13 +568,14 @@ app.include_router(your_router.router) ## 安全注意事项 -1. **CORS 配置**: 当前允许所有来源,生产环境应限制为特定域名 +1. **CORS 配置**: 本地开发默认可用 `allow_origins=["*"]`(`deploy.api_cors_origins` 留空);生产/staging 必须在 `config/staging.toml` / `config/production.toml` 的 `[deploy]` 中设置逗号分隔前端域名 2. **API Key 安全**: 确保 `.env` 文件不被提交到版本控制 3. **SECRET_KEY 安全**: 使用强随机字符串作为 JWT 签名密钥,生产环境必须更换 4. **密码安全**: 密码使用 bcrypt 哈希存储,不会以明文形式保存 5. **令牌安全**: - 访问令牌短期有效(2小时),降低泄露风险 - - 刷新令牌存储在数据库中,支持撤销 + - 刷新令牌存储在数据库中,支持撤销;每次 `/api/auth/refresh` 会轮换 refresh token + - 已轮换的 refresh token 被再次使用时,服务端吊销全部会话并返回 `REFRESH_TOKEN_REUSE` - 令牌过期后必须使用刷新令牌重新获取 6. **数据库备份**: 定期备份 PostgreSQL 数据库 7. **错误处理**: 所有 API 都包含适当的错误处理和权限验证 diff --git a/api/alembic/versions/0020_refresh_rt_lineage.py b/api/alembic/versions/0020_refresh_rt_lineage.py new file mode 100644 index 0000000..ffb891e --- /dev/null +++ b/api/alembic/versions/0020_refresh_rt_lineage.py @@ -0,0 +1,64 @@ +"""refresh_tokens:轮换 lineage(replaced_by_token_id + rotated_at) + +Revision ID: 0020_refresh_rt_lineage +Revises: 0019_align_legacy_schema +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +revision: str = "0020_refresh_rt_lineage" +down_revision: Union[str, None] = "0019_align_legacy_schema" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _column_names(table_name: str) -> set[str]: + bind = op.get_bind() + inspector = sa.inspect(bind) + return {column["name"] for column in inspector.get_columns(table_name)} + + +def upgrade() -> None: + columns = _column_names("refresh_tokens") + if "replaced_by_token_id" not in columns: + op.add_column( + "refresh_tokens", + sa.Column("replaced_by_token_id", sa.String(), nullable=True), + ) + op.create_index( + "ix_refresh_tokens_replaced_by_token_id", + "refresh_tokens", + ["replaced_by_token_id"], + unique=False, + ) + op.create_foreign_key( + "fk_refresh_tokens_replaced_by_token_id", + "refresh_tokens", + "refresh_tokens", + ["replaced_by_token_id"], + ["id"], + ondelete="SET NULL", + ) + if "rotated_at" not in columns: + op.add_column( + "refresh_tokens", + sa.Column("rotated_at", sa.DateTime(timezone=True), nullable=True), + ) + + +def downgrade() -> None: + columns = _column_names("refresh_tokens") + if "replaced_by_token_id" in columns: + op.drop_constraint( + "fk_refresh_tokens_replaced_by_token_id", + "refresh_tokens", + type_="foreignkey", + ) + op.drop_index("ix_refresh_tokens_replaced_by_token_id", table_name="refresh_tokens") + op.drop_column("refresh_tokens", "replaced_by_token_id") + if "rotated_at" in columns: + op.drop_column("refresh_tokens", "rotated_at") diff --git a/api/alembic/versions/0021_memory_source_segment_id.py b/api/alembic/versions/0021_memory_source_segment_id.py new file mode 100644 index 0000000..9d12850 --- /dev/null +++ b/api/alembic/versions/0021_memory_source_segment_id.py @@ -0,0 +1,83 @@ +"""memory_sources: segment_id for phase1 ingest idempotency + +Revision ID: 0021_memory_source_segment_id +Revises: 0020_refresh_rt_lineage +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +revision: str = "0021_memory_source_segment_id" +down_revision: Union[str, None] = "0020_refresh_rt_lineage" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _column_names(table_name: str) -> set[str]: + bind = op.get_bind() + inspector = sa.inspect(bind) + return {column["name"] for column in inspector.get_columns(table_name)} + + +def _index_names(table_name: str) -> set[str]: + bind = op.get_bind() + inspector = sa.inspect(bind) + return {index["name"] for index in inspector.get_indexes(table_name)} + + +def upgrade() -> None: + columns = _column_names("memory_sources") + if "segment_id" not in columns: + op.add_column( + "memory_sources", + sa.Column("segment_id", sa.String(), nullable=True), + ) + indexes = _index_names("memory_sources") + if "ix_memory_sources_user_segment_transcript" not in indexes: + op.create_index( + "ix_memory_sources_user_segment_transcript", + "memory_sources", + ["user_id", "segment_id"], + unique=True, + postgresql_where=sa.text( + "segment_id IS NOT NULL AND source_type = 'transcript'" + ), + ) + foreign_keys = { + fk["name"] + for fk in sa.inspect(op.get_bind()).get_foreign_keys("memory_sources") + } + if "fk_memory_sources_segment_id_segments" not in foreign_keys: + op.create_foreign_key( + "fk_memory_sources_segment_id_segments", + "memory_sources", + "segments", + ["segment_id"], + ["id"], + ondelete="SET NULL", + ) + + +def downgrade() -> None: + foreign_keys = { + fk["name"] + for fk in sa.inspect(op.get_bind()).get_foreign_keys("memory_sources") + } + if "fk_memory_sources_segment_id_segments" in foreign_keys: + op.drop_constraint( + "fk_memory_sources_segment_id_segments", + "memory_sources", + type_="foreignkey", + ) + indexes = _index_names("memory_sources") + if "ix_memory_sources_user_segment_transcript" in indexes: + op.drop_index( + "ix_memory_sources_user_segment_transcript", + table_name="memory_sources", + ) + columns = _column_names("memory_sources") + if "segment_id" in columns: + op.drop_column("memory_sources", "segment_id") diff --git a/api/app/adapters/image_gen/liblib_provider.py b/api/app/adapters/image_gen/liblib_provider.py index b307713..d93c1b9 100644 --- a/api/app/adapters/image_gen/liblib_provider.py +++ b/api/app/adapters/image_gen/liblib_provider.py @@ -16,6 +16,8 @@ import httpx from app.core.config import settings from app.core.logging import get_logger +from app.features.memoir.constants import memoir +from app.core.runtime_constants import misc_defaults logger = get_logger(__name__) @@ -62,7 +64,7 @@ class LiblibImageProvider: self.access_key = access_key or (settings.liblib_access_key or "") self.secret_key = secret_key or (settings.liblib_secret_key or "") self.base_url = ( - base_url or settings.liblib_base_url or "https://openapi.liblibai.cloud" + base_url or misc_defaults.liblib_base_url or "https://openapi.liblibai.cloud" ).rstrip("/") self.template_uuid = template_uuid or ( settings.liblib_template_uuid or DEFAULT_LIBLIB_TEMPLATE_UUID @@ -240,7 +242,7 @@ def _build_allowed_download_hosts( if configured_hosts is None: configured_hosts = tuple( host.strip().lower() - for host in (settings.memoir_image_download_hosts or "").split(",") + for host in (memoir.image_download_hosts or "").split(",") if host.strip() ) base_hostname = (urlparse(base_url).hostname or "").lower() diff --git a/api/app/adapters/llm/deepseek_eval_judge.py b/api/app/adapters/llm/deepseek_eval_judge.py index f7ca992..bdb5476 100644 --- a/api/app/adapters/llm/deepseek_eval_judge.py +++ b/api/app/adapters/llm/deepseek_eval_judge.py @@ -7,6 +7,8 @@ from langchain_openai import ChatOpenAI from app.adapters.llm.openai_base_url import normalize_openai_compatible_base_url from app.core.config import settings from app.core.eval_judge_spec import EvalJudgeLlmSpec +from app.features.evaluation.constants import eval_cfg +from app.core.runtime_constants import llm_defaults def resolve_deepseek_eval_judge_model( @@ -35,7 +37,7 @@ def resolve_deepseek_eval_judge_model( if m == "deepseek-v4-pro": return ("deepseek-v4-pro", None, "high") if m in ("", "deepseek-v4-flash"): - if settings.eval_judge_deepseek_thinking_enabled: + if eval_cfg.judge_deepseek_thinking_enabled: return ( "deepseek-v4-flash", {"thinking": {"type": "enabled"}}, @@ -55,23 +57,23 @@ def build_deepseek_eval_judge_spec( judge_model: str | None, ) -> EvalJudgeLlmSpec | None: """密钥缺失时返回 None。""" - api_key = (settings.deepseek_api_key or settings.llm_api_key or "").strip() + api_key = (settings.deepseek_api_key or "").strip() if not api_key: return None want = (judge_model or "").strip() base = normalize_openai_compatible_base_url( - settings.deepseek_base_url, + llm_defaults.deepseek_base_url, fallback="https://api.deepseek.com", ) - default_m = (settings.eval_judge_deepseek_model or "deepseek-v4-flash").strip() + default_m = (eval_cfg.judge_deepseek_model or "deepseek-v4-flash").strip() combined = want or default_m model, extra, effort = resolve_deepseek_eval_judge_model(combined) - ctx = int(settings.eval_judge_deepseek_context_window_tokens) + ctx = int(eval_cfg.judge_deepseek_context_window_tokens) llm_kw: dict = { "api_key": api_key, "base_url": base, "model": model, - "temperature": settings.eval_judge_temperature, + "temperature": eval_cfg.judge_temperature, } if extra is not None: llm_kw["extra_body"] = extra diff --git a/api/app/adapters/llm/zhipu_eval_judge.py b/api/app/adapters/llm/zhipu_eval_judge.py index 756b164..89f5a95 100644 --- a/api/app/adapters/llm/zhipu_eval_judge.py +++ b/api/app/adapters/llm/zhipu_eval_judge.py @@ -7,27 +7,28 @@ from langchain_openai import ChatOpenAI from app.adapters.llm.openai_base_url import normalize_openai_compatible_base_url from app.core.config import settings from app.core.eval_judge_spec import EvalJudgeLlmSpec +from app.features.evaluation.constants import eval_cfg def build_zhipu_eval_judge_spec( judge_model: str | None, ) -> EvalJudgeLlmSpec | None: """密钥缺失时返回 None。""" - api_key = (settings.eval_judge_api_key or settings.zhipu_api_key or "").strip() + api_key = (settings.zhipu_api_key or "").strip() if not api_key: return None want = (judge_model or "").strip() base = normalize_openai_compatible_base_url( - settings.eval_judge_base_url, + eval_cfg.judge_base_url, fallback="https://open.bigmodel.cn/api/paas/v4", ) - model = want or (settings.eval_judge_model or "glm-5") - ctx = int(settings.eval_judge_context_window_tokens) + model = want or (eval_cfg.judge_model or "glm-5") + ctx = int(eval_cfg.judge_context_window_tokens) llm_kw: dict = { "api_key": api_key, "base_url": base, "model": model, - "temperature": settings.eval_judge_temperature, + "temperature": eval_cfg.judge_temperature, } return EvalJudgeLlmSpec( llm=ChatOpenAI(**llm_kw), diff --git a/api/app/adapters/tts/tencent_tts.py b/api/app/adapters/tts/tencent_tts.py index c00fa15..dd76e98 100644 --- a/api/app/adapters/tts/tencent_tts.py +++ b/api/app/adapters/tts/tencent_tts.py @@ -7,6 +7,7 @@ import uuid from app.core.business_telemetry import business_span from app.core.logging import get_logger +from app.core.runtime_constants import tts_defaults logger = get_logger(__name__) @@ -208,7 +209,7 @@ class TencentTTSProvider: max_chars = MAX_CHARS_PER_REQUEST_EN if is_en else MAX_CHARS_PER_REQUEST_ZH # Default "alloy" aligns with OpenAI TTS naming. Caller 链路里目前不会传具体音色, - # 因此实际只走 default_voice 分支,对应 settings.tts_voice_type / tts_voice_type_en。 + # 因此实际只走 default_voice 分支,对应 tts_defaults.voice_type / tts_voice_type_en。 v = voice.lower() if v == "alloy": voice_type = default_voice diff --git a/api/app/agents/chat/interview_agent.py b/api/app/agents/chat/interview_agent.py index 15293c8..a394789 100644 --- a/api/app/agents/chat/interview_agent.py +++ b/api/app/agents/chat/interview_agent.py @@ -48,6 +48,9 @@ from app.core.config import settings from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger from app.features.conversation.input_normalize import normalize_chat_input_for_agent +from app.core.runtime_constants import agent_log_defaults +from app.features.conversation.constants import chat +from app.features.story.constants import story logger = get_logger(__name__) @@ -171,8 +174,8 @@ class InterviewAgent: if normalized_user_message is not None: return (normalized_user_message or "").strip() llm_n = None - if settings.chat_input_normalize_enabled and ( - (settings.chat_input_normalize_mode or "").strip().lower() == "llm" + if chat.input_normalize_enabled and ( + (chat.input_normalize_mode or "").strip().lower() == "llm" ): llm_n = self.llm return normalize_chat_input_for_agent(user_message or "", llm=llm_n) @@ -218,16 +221,16 @@ class InterviewAgent: du = self._detect_user_stage(text_for_model) hw = await get_history_with_window( conversation_id, - max_pairs=settings.chat_history_max_pairs, - max_chars=settings.chat_history_max_chars, + max_pairs=chat.history_max_pairs, + max_chars=chat.history_max_chars, ) recent_questions = extract_recent_questions(hw.window) conversation_turn_total = hw.turn_total all_stages_coverage = narrative_state.all_stages_coverage() - persona = normalize_interview_persona(settings.chat_interview_persona) - max_segments = int(settings.chat_interview_max_segments) - max_tokens = int(settings.chat_interview_max_tokens) - max_chars = int(settings.chat_interview_max_chars_per_segment) + persona = normalize_interview_persona(chat.interview_persona) + max_segments = int(chat.interview_max_segments) + max_tokens = int(chat.interview_max_tokens) + max_chars = int(chat.interview_max_chars_per_segment) turn_plan = plan_interview_turn( current_stage=memoir_state.current_stage, @@ -246,7 +249,7 @@ class InterviewAgent: reply_planner_raw = "" baseline_mode = turn_plan.mode baseline_primary_focus = turn_plan.primary_focus - if settings.chat_reply_planner_llm_enabled: + if chat.reply_planner_llm_enabled: rq_preview = ( "\n".join(recent_questions[-4:]) if recent_questions @@ -258,8 +261,8 @@ class InterviewAgent: text_for_model=text_for_model, memory_evidence_text=(memory_planner_text or memory_evidence_text) or "", - max_tokens=int(settings.chat_reply_planner_max_tokens), - temperature=float(settings.chat_reply_planner_temperature), + max_tokens=int(chat.reply_planner_max_tokens), + temperature=float(chat.reply_planner_temperature), scene_cues_for_planner=scene_cues_for_planner or [], recent_questions_preview=rq_preview, ) @@ -310,12 +313,12 @@ class InterviewAgent: "InterviewAgent.generate_response.prompt", format_history_string( messages, - omit_system_body=settings.agent_log_omit_system_message_body, + omit_system_body=agent_log_defaults.omit_system_message_body, ), ) chat_llm = self.llm.bind( max_tokens=max_tokens, - temperature=float(settings.chat_interview_temperature), + temperature=float(chat.interview_temperature), ) prompt_chars = _message_contents_char_count(messages) llm_t0 = time.perf_counter() @@ -377,7 +380,7 @@ class InterviewAgent: "InterviewAgent.generate_response.retry_prompt", format_history_string( retry_messages, - omit_system_body=settings.agent_log_omit_system_message_body, + omit_system_body=agent_log_defaults.omit_system_message_body, ), ) llm_t1 = time.perf_counter() @@ -445,7 +448,7 @@ class InterviewAgent: "duplicate_question_guard_llm_retry": retry_used, "autobiographical_boundary_guard_triggered": auto_bio, "reply_planner_llm_used": bool( - settings.chat_reply_planner_llm_enabled + chat.reply_planner_llm_enabled and (reply_planner_raw or "").strip() ), "reply_planner_raw_preview": (reply_planner_raw or "")[:800], @@ -483,7 +486,7 @@ class InterviewAgent: ) slot_table = SLOT_NAME_MAP_EN if language == "en" else SLOT_NAME_MAP empty_slots_readable = [slot_table.get(s, s) for s in empty_slots] - persona = normalize_interview_persona(settings.chat_interview_persona) + persona = normalize_interview_persona(chat.interview_persona) prompt = get_opening_prompt( current_stage=memoir_state.current_stage, empty_slots_readable=empty_slots_readable, @@ -497,8 +500,8 @@ class InterviewAgent: ) hw = await get_history_with_window( conversation_id, - max_pairs=settings.chat_history_max_pairs, - max_chars=settings.chat_history_max_chars, + max_pairs=chat.history_max_pairs, + max_chars=chat.history_max_chars, ) messages: List[Any] = [SystemMessage(content=prompt)] messages.extend(hw.window) @@ -520,12 +523,12 @@ class InterviewAgent: "InterviewAgent.opening.prompt", format_history_string( messages, - omit_system_body=settings.agent_log_omit_system_message_body, + omit_system_body=agent_log_defaults.omit_system_message_body, ), ) opening_llm = self.llm.bind( - max_tokens=settings.chat_opening_max_tokens, - temperature=float(settings.chat_interview_temperature), + max_tokens=chat.opening_max_tokens, + temperature=float(chat.interview_temperature), ) prompt_chars = _message_contents_char_count(messages) llm_t0 = time.perf_counter() @@ -564,7 +567,7 @@ class InterviewAgent: raw_list = segments_from_llm_response(response_text, max_segments=2) if not raw_list: raw_list = [response_text.strip()] - max_chars = int(settings.chat_interview_max_chars_per_segment) + max_chars = int(chat.interview_max_chars_per_segment) out = truncate_chat_segments( raw_list, max_segments=2, @@ -612,7 +615,7 @@ class InterviewAgent: ) slot_table = SLOT_NAME_MAP_EN if language == "en" else SLOT_NAME_MAP empty_slots_readable = [slot_table.get(s, s) for s in empty_slots] - persona = normalize_interview_persona(settings.chat_interview_persona) + persona = normalize_interview_persona(chat.interview_persona) prompt = get_re_greeting_prompt( current_stage=memoir_state.current_stage, empty_slots_readable=empty_slots_readable, @@ -627,8 +630,8 @@ class InterviewAgent: ) hw = await get_history_with_window( conversation_id, - max_pairs=settings.chat_history_max_pairs, - max_chars=settings.chat_history_max_chars, + max_pairs=chat.history_max_pairs, + max_chars=chat.history_max_chars, ) messages: List[Any] = [SystemMessage(content=prompt)] messages.extend(hw.window) @@ -647,12 +650,12 @@ class InterviewAgent: "InterviewAgent.re_greeting.prompt", format_history_string( messages, - omit_system_body=settings.agent_log_omit_system_message_body, + omit_system_body=agent_log_defaults.omit_system_message_body, ), ) re_greet_llm = self.llm.bind( - max_tokens=settings.chat_opening_max_tokens, - temperature=float(settings.chat_interview_temperature), + max_tokens=chat.opening_max_tokens, + temperature=float(chat.interview_temperature), ) llm_t0 = time.perf_counter() with agent_span( @@ -691,7 +694,7 @@ class InterviewAgent: raw_list = segments_from_llm_response(response_text, max_segments=2) if not raw_list: raw_list = [response_text.strip()] - max_chars = int(settings.chat_interview_max_chars_per_segment) + max_chars = int(chat.interview_max_chars_per_segment) out = truncate_chat_segments( raw_list, max_segments=2, diff --git a/api/app/agents/chat/orchestrator.py b/api/app/agents/chat/orchestrator.py index dce776e..1f9a071 100644 --- a/api/app/agents/chat/orchestrator.py +++ b/api/app/agents/chat/orchestrator.py @@ -35,6 +35,9 @@ from app.features.memoir.state_service import ( switch_stage, ) from app.features.memory.prompt_adapter import MemoryPromptAdapter +from app.features.conversation.constants import chat +from app.features.memory.constants import memory +from app.features.story.constants import story def _llm_for_chat_input_normalize(): @@ -80,7 +83,7 @@ async def _fetch_interview_memory_bundle( ) from app.features.memory.service import MemoryService - if not settings.chat_memory_retrieval_enabled: + if not chat.memory_retrieval_enabled: logger.debug( "event=chat_memory_retrieval_skip reason=disabled user_id={}", user_id ) @@ -94,7 +97,7 @@ async def _fetch_interview_memory_bundle( try: emb = get_embedding_provider_fn() ms = MemoryService(db, embedding_provider=emb) - top_k = settings.chat_memory_top_k + top_k = chat.memory_top_k bundle = await ms.retrieve(user_id, msg, top_k=top_k) bd = bundle.model_dump() trace = chat_memory_retrieval_trace_from_bundle( @@ -164,17 +167,17 @@ class ChatOrchestrator: if missing: hw_profile = await get_history_with_window( conversation_id, - max_pairs=settings.chat_history_max_pairs, - max_chars=settings.chat_history_max_chars, + max_pairs=chat.history_max_pairs, + max_chars=chat.history_max_chars, ) profile_turn_total = hw_profile.turn_total - if profile_turn_total >= settings.chat_profile_max_turns: + if profile_turn_total >= chat.profile_max_turns: logger.info( "event=chat_profile_cap_skip conversation_id={} " "turn_total={} cap={} missing_fields={}", conversation_id, profile_turn_total, - settings.chat_profile_max_turns, + chat.profile_max_turns, missing, ) else: @@ -269,8 +272,8 @@ class ChatOrchestrator: len(user_message or ""), ) llm_n = None - if settings.chat_input_normalize_enabled and ( - (settings.chat_input_normalize_mode or "").strip().lower() == "llm" + if chat.input_normalize_enabled and ( + (chat.input_normalize_mode or "").strip().lower() == "llm" ): llm_n = _llm_for_chat_input_normalize() normalized_user_message = normalize_chat_input_for_agent( @@ -290,8 +293,10 @@ class ChatOrchestrator: state = await switch_stage(user_id, detected, db) if conversation and conversation.conversation_stage != state.current_stage: - conversation.conversation_stage = state.current_stage - await db.commit() + from app.core.db import transactional + + async with transactional(db): + conversation.conversation_stage = state.current_stage from app.agents.chat.background_voice import infer_background_voice from app.agents.chat.prompts_profile import format_user_profile_context diff --git a/api/app/agents/chat/personas.py b/api/app/agents/chat/personas.py index 7f28750..f4fe03e 100644 --- a/api/app/agents/chat/personas.py +++ b/api/app/agents/chat/personas.py @@ -7,6 +7,8 @@ from __future__ import annotations from typing import Final +from app.features.conversation.constants import chat + # Brand / interviewer name — keep aligned with frontend i18n `conversation.agentName`, # OpenAPI title, README, and project metadata. zh = 「岁月知己」,en = Life Echo. AGENT_NAME_ZH: Final[str] = "岁月知己" @@ -18,7 +20,7 @@ def agent_name(language: str = "zh") -> str: return AGENT_NAME_EN if (language or "zh").strip().lower() == "en" else AGENT_NAME_ZH -# 与 settings.chat_interview_persona 及文档保持一致 +# 与 chat.interview_persona 及文档保持一致 VALID_INTERVIEW_PERSONAS: Final[frozenset[str]] = frozenset( {"default", "warm_listener", "curious_guide"} ) diff --git a/api/app/agents/chat/profile_agent.py b/api/app/agents/chat/profile_agent.py index c479559..e3f5da5 100644 --- a/api/app/agents/chat/profile_agent.py +++ b/api/app/agents/chat/profile_agent.py @@ -26,6 +26,9 @@ from app.core.llm_call import allm_json_call from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger from app.ports.llm import LLMProvider +from app.core.runtime_constants import agent_log_defaults +from app.features.conversation.constants import chat +from app.features.story.constants import story logger = get_logger(__name__) @@ -207,8 +210,8 @@ class ProfileAgent: if conversation_id: hw = await get_history_with_window( conversation_id, - max_pairs=settings.chat_history_max_pairs, - max_chars=settings.chat_history_max_chars, + max_pairs=chat.history_max_pairs, + max_chars=chat.history_max_chars, ) recent = hw.window[-4:] if len(hw.window) > 4 else hw.window parts = [] @@ -232,7 +235,7 @@ class ProfileAgent: ProfileExtractionOutput, use_case=LlmUseCase( "ProfileAgent.extract_profile_from_message", - max_tokens=settings.chat_profile_extract_max_tokens, + max_tokens=chat.profile_extract_max_tokens, ), fallback_factory=lambda: ProfileExtractionOutput(), ) @@ -285,8 +288,8 @@ class ProfileAgent: ) hw = await get_history_with_window( conversation_id, - max_pairs=settings.chat_history_max_pairs, - max_chars=settings.chat_history_max_chars, + max_pairs=chat.history_max_pairs, + max_chars=chat.history_max_chars, ) messages: List[Any] = [SystemMessage(content=prompt)] messages.extend(hw.window) @@ -296,7 +299,7 @@ class ProfileAgent: "ProfileAgent.followup.prompt", format_history_string( messages, - omit_system_body=settings.agent_log_omit_system_message_body, + omit_system_body=agent_log_defaults.omit_system_message_body, ), ) prompt_chars = _message_contents_char_count(messages) @@ -309,14 +312,14 @@ class ProfileAgent: ) response_text = await self._invoke_chat( messages, - max_tokens=settings.chat_profile_followup_max_tokens, + max_tokens=chat.profile_followup_max_tokens, conversation_id=conversation_id, agent_name="ProfileAgent.generate_profile_followup", ) segments = await self._segments_from_response( response_text, max_segments=3, - max_chars_per_segment=settings.chat_interview_max_chars_per_segment, + max_chars_per_segment=chat.interview_max_chars_per_segment, fallback=_profile_followup_fallback(language), ) log_agent_summary( @@ -344,8 +347,8 @@ class ProfileAgent: ) hw = await get_history_with_window( conversation_id, - max_pairs=settings.chat_history_max_pairs, - max_chars=settings.chat_history_max_chars, + max_pairs=chat.history_max_pairs, + max_chars=chat.history_max_chars, ) messages: List[Any] = [SystemMessage(content=prompt)] messages.extend(hw.window) @@ -367,7 +370,7 @@ class ProfileAgent: "ProfileAgent.greeting.prompt", format_history_string( messages, - omit_system_body=settings.agent_log_omit_system_message_body, + omit_system_body=agent_log_defaults.omit_system_message_body, ), ) prompt_chars = _message_contents_char_count(messages) @@ -380,14 +383,14 @@ class ProfileAgent: ) response_text = await self._invoke_chat( messages, - max_tokens=settings.chat_profile_followup_max_tokens, + max_tokens=chat.profile_followup_max_tokens, conversation_id=conversation_id, agent_name="ProfileAgent.generate_profile_greeting", ) segments = await self._segments_from_response( response_text, max_segments=2, - max_chars_per_segment=settings.chat_interview_max_chars_per_segment, + max_chars_per_segment=chat.interview_max_chars_per_segment, fallback=_profile_greeting_fallback(language), ) log_agent_summary( diff --git a/api/app/agents/chat/prompts_conversation.py b/api/app/agents/chat/prompts_conversation.py index b1eb844..1089af9 100644 --- a/api/app/agents/chat/prompts_conversation.py +++ b/api/app/agents/chat/prompts_conversation.py @@ -34,6 +34,7 @@ from app.agents.stage_constants import ( ) from app.agents.state_schema import KnownFact, PersonaThread from app.core.config import settings +from app.features.conversation.constants import chat # 风格示例的单一事实源已迁至 `app.agents.style_profiles.ChatStyleProfile.reply_style_examples`; # 这里**不再**维护具体字面示例,避免同一模块被当作 few-shot 锚点反复注入,导致风格过拟合。 @@ -292,7 +293,7 @@ def get_opening_prompt( era_opening_line = "" if ( - settings.chat_era_context_enabled + chat.era_context_enabled and profile_birth_year is not None and _compact_era_hint( current_stage, @@ -450,7 +451,7 @@ def get_guided_conversation_prompt( ) era_line = "" - if settings.chat_era_context_enabled: + if chat.era_context_enabled: era_line = _compact_era_hint( active_stage, birth_year=profile_birth_year, @@ -696,7 +697,7 @@ _STAGE_TOPIC_CHIP_BANK: Dict[str, List[tuple[str, str]]] = { ("support", "家人之间的相互支持"), ("responsibility", "肩上的家庭责任"), ], - "later_life": [ + "belief": [ ("value", "现在最看重的事"), ("regret", "心里的遗憾"), ("pride", "最骄傲的事"), diff --git a/api/app/agents/chat/stage_detection.py b/api/app/agents/chat/stage_detection.py index 94b8d27..7f8b577 100644 --- a/api/app/agents/chat/stage_detection.py +++ b/api/app/agents/chat/stage_detection.py @@ -20,6 +20,7 @@ from app.agents.stage_constants import ( from app.core.config import settings from app.core.llm_call import allm_json_call from app.core.logging import get_logger +from app.features.conversation.constants import chat logger = get_logger(__name__) @@ -59,7 +60,7 @@ async def detect_primary_life_stage( 每轮在启用时调用阶段检测 LLM(短句亦由模型判断,不用关键词替代)。 """ fb = normalize_chat_stage(current_stage, "childhood") - if not settings.chat_stage_detection_enabled: + if not chat.stage_detection_enabled: return _keyword_fallback_stage(user_message, fb) if not llm: @@ -76,7 +77,7 @@ async def detect_primary_life_stage( llm, prompt, StageDetectionOutput, - max_tokens=settings.chat_stage_detection_max_tokens, + max_tokens=chat.stage_detection_max_tokens, agent="detect_primary_life_stage", fallback_factory=fallback_factory, ) diff --git a/api/app/agents/image_prompt/orchestrator.py b/api/app/agents/image_prompt/orchestrator.py index 332517c..dd7020b 100644 --- a/api/app/agents/image_prompt/orchestrator.py +++ b/api/app/agents/image_prompt/orchestrator.py @@ -12,6 +12,7 @@ from app.agents.image_prompt.prompt_agent import PromptGenerationAgent from app.core.config import settings from app.core.logging import get_logger from app.features.memoir.memoir_images.settings import MemoirImageSettings +from app.features.memoir.constants import memoir logger = get_logger(__name__) @@ -84,7 +85,7 @@ def get_image_prompt_orchestrator() -> ImagePromptOrchestrator: try: llm = LlmGateway().langchain_llm_for(LlmUseCase("image_prompt")) except Exception as e: - if settings.image_prompt_fallback_disabled: + if memoir.image_prompt_fallback_disabled: raise logger.warning( "ImagePromptOrchestrator LLM 初始化失败,使用确定性 fallback: {}", diff --git a/api/app/agents/memoir/batch_phase1_prep.py b/api/app/agents/memoir/batch_phase1_prep.py index 6849673..c45d9bf 100644 --- a/api/app/agents/memoir/batch_phase1_prep.py +++ b/api/app/agents/memoir/batch_phase1_prep.py @@ -15,6 +15,7 @@ from app.core.config import settings from app.core.llm_call import LLMCallError, llm_json_call from app.core.logging import get_logger from app.features.conversation.models import Segment +from app.features.memoir.constants import memoir logger = get_logger(__name__) @@ -68,7 +69,7 @@ def run_batch_phase1_prep( llm, prompt, BatchPhase1LLMOutput, - max_tokens=int(settings.memoir_phase1_batch_llm_max_tokens), + max_tokens=int(memoir.phase1_batch_llm_max_tokens), agent="BatchPhase1Prep.run", ) except LLMCallError as e: diff --git a/api/app/agents/memoir/classification_agent.py b/api/app/agents/memoir/classification_agent.py index e7b6997..7c9b6c3 100644 --- a/api/app/agents/memoir/classification_agent.py +++ b/api/app/agents/memoir/classification_agent.py @@ -26,6 +26,7 @@ from app.core.config import settings from app.core.json_utils import extract_json_payload from app.core.llm_call import LLMCallError, llm_json_call from app.core.logging import get_logger +from app.features.memoir.constants import memoir logger = get_logger(__name__) @@ -144,7 +145,7 @@ class ClassificationAgent: llm, prompt, ClassificationOutput, - max_tokens=settings.memoir_classification_max_tokens, + max_tokens=memoir.classification_max_tokens, agent="ClassificationAgent.classify", ) category = _normalize_llm_category(out.category) diff --git a/api/app/agents/memoir/extraction_agent.py b/api/app/agents/memoir/extraction_agent.py index e3a4b2d..e6c1a0c 100644 --- a/api/app/agents/memoir/extraction_agent.py +++ b/api/app/agents/memoir/extraction_agent.py @@ -14,6 +14,7 @@ from app.agents.stage_constants import normalize_chat_stage from app.core.config import settings from app.core.llm_call import LLMCallError, llm_json_call from app.core.logging import get_logger +from app.features.memoir.constants import memoir logger = get_logger(__name__) @@ -64,7 +65,7 @@ class ExtractionAgent: llm, prompt, StateExtractionOutput, - max_tokens=settings.memoir_extraction_max_tokens, + max_tokens=memoir.extraction_max_tokens, agent="ExtractionAgent.extract", ) raw_slots = parsed.slots or {} diff --git a/api/app/agents/memoir/fidelity_check_agent.py b/api/app/agents/memoir/fidelity_check_agent.py index 29d9a94..c9ddb10 100644 --- a/api/app/agents/memoir/fidelity_check_agent.py +++ b/api/app/agents/memoir/fidelity_check_agent.py @@ -13,6 +13,7 @@ from app.agents.memoir.schemas import FidelityOutput from app.core.config import settings from app.core.llm_call import LLMCallError, llm_json_call from app.core.logging import get_logger +from app.features.memoir.constants import memoir logger = get_logger(__name__) @@ -46,7 +47,7 @@ class FidelityCheckAgent: existing_canonical_markdown: str | None = None, is_append: bool = False, ) -> bool: - if not llm or not settings.memoir_fidelity_check_enabled: + if not llm or not memoir.fidelity_check_enabled: return True oral = (oral_text or "").strip() gen = (narrative_json or "").strip() @@ -108,7 +109,7 @@ class FidelityCheckAgent: llm, prompt, FidelityOutput, - max_tokens=settings.memoir_fidelity_check_max_tokens, + max_tokens=memoir.fidelity_check_max_tokens, agent="FidelityCheckAgent.passes", ) ok = bool(out.pass_) @@ -120,7 +121,7 @@ class FidelityCheckAgent: return ok except LLMCallError as e: logger.warning("FidelityCheckAgent 解析失败: {}", e) - if is_append or settings.memoir_fidelity_fail_open_on_parse_error: + if is_append or memoir.fidelity_fail_open_on_parse_error: logger.info("event=fidelity_parse_fail_open is_append={}", is_append) return True logger.warning("event=fidelity_parse_fail_closed") diff --git a/api/app/agents/memoir/narrative_agent.py b/api/app/agents/memoir/narrative_agent.py index 66f86f2..663d9a9 100644 --- a/api/app/agents/memoir/narrative_agent.py +++ b/api/app/agents/memoir/narrative_agent.py @@ -18,6 +18,7 @@ from app.core.config import settings from app.core.langchain_llm import invoke_json_object from app.core.llm_call import llm_json_call from app.core.logging import get_logger +from app.features.memoir.constants import memoir logger = get_logger(__name__) @@ -63,7 +64,7 @@ class NarrativeAgent: llm, prompt, MemoirTitleOutput, - max_tokens=settings.memoir_title_max_tokens, + max_tokens=memoir.title_max_tokens, agent="NarrativeAgent.generate_title", fallback_factory=_title_fallback, ) @@ -118,7 +119,7 @@ class NarrativeAgent: occupation=occupation, language=language, ) - max_tokens = int(settings.memoir_narrative_merge_max_tokens) + max_tokens = int(memoir.narrative_merge_max_tokens) agent_name = "NarrativeAgent.generate_narrative_merge" else: prompt = get_narrative_json_prompt( @@ -132,7 +133,7 @@ class NarrativeAgent: occupation=occupation, language=language, ) - max_tokens = int(settings.memoir_narrative_max_tokens) + max_tokens = int(memoir.narrative_max_tokens) agent_name = "NarrativeAgent.generate_narrative" return invoke_json_object( llm, diff --git a/api/app/agents/memoir/orchestrator.py b/api/app/agents/memoir/orchestrator.py index 67ac74e..d2e6a98 100644 --- a/api/app/agents/memoir/orchestrator.py +++ b/api/app/agents/memoir/orchestrator.py @@ -29,6 +29,7 @@ from app.core.agent_logging import agent_span, agent_summary_enabled, log_agent_ from app.core.config import settings from app.core.logging import get_logger from app.features.conversation.models import Segment +from app.features.memoir.constants import memoir logger = get_logger(__name__) @@ -90,7 +91,7 @@ class MemoirOrchestrator: use_batch = ( bool(segments) and classify_extract_llm is not None - and settings.memoir_phase1_batch_llm_enabled + and memoir.phase1_batch_llm_enabled ) if use_batch: try: @@ -204,7 +205,7 @@ class MemoirOrchestrator: segments, state, classify_extract_llm, - chunk_size=int(settings.memoir_phase1_batch_llm_chunk_size), + chunk_size=int(memoir.phase1_batch_llm_chunk_size), on_chunk=on_phase1_chunk, language=language, ) diff --git a/api/app/agents/memoir/story_route_agent.py b/api/app/agents/memoir/story_route_agent.py index 6e1a00d..730a25c 100644 --- a/api/app/agents/memoir/story_route_agent.py +++ b/api/app/agents/memoir/story_route_agent.py @@ -21,6 +21,8 @@ from app.core.config import settings from app.core.llm_call import LLMCallError, llm_json_call from app.core.logging import get_logger from app.features.story.models import Story +from app.features.memoir.constants import memoir +from app.features.story.constants import story logger = get_logger(__name__) @@ -47,7 +49,7 @@ def default_append_target_story_id( ordered = sort_stories_for_route( candidate_stories, meta, - summary_min_chars=int(settings.story_route_summary_min_chars), + summary_min_chars=int(story.route_summary_min_chars), ) if not ordered: return None @@ -247,7 +249,7 @@ class StoryRouteAgent: llm, prompt, StoryRouteDecision, - max_tokens=settings.memoir_story_route_max_tokens, + max_tokens=memoir.story_route_max_tokens, agent="StoryRouteAgent.decide", fallback_factory=_decide_fallback, ) @@ -295,7 +297,7 @@ class StoryRouteAgent: llm, prompt, StoryBatchPlan, - max_tokens=settings.memoir_story_batch_plan_max_tokens, + max_tokens=memoir.story_batch_plan_max_tokens, agent="StoryRouteAgent.plan_batch", ) except LLMCallError as e: diff --git a/api/app/agents/memoir/story_route_payload.py b/api/app/agents/memoir/story_route_payload.py index bda7ad0..dc5598b 100644 --- a/api/app/agents/memoir/story_route_payload.py +++ b/api/app/agents/memoir/story_route_payload.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from app.core.config import Settings from app.features.story.models import Story +from app.features.story.constants import story _PLAIN_SNIPPET_NOISE = re.compile(r"[`*_#]+") @@ -213,11 +214,11 @@ def build_route_candidate_rows( ) -> list[dict[str, Any]]: """排序 + 完整候选行(尚未做总预算降级)。""" meta = story_meta or {} - summary_min = int(settings.story_route_summary_min_chars) + summary_min = int(story.route_summary_min_chars) ordered = sort_stories_for_route(stories, meta, summary_min_chars=summary_min) - body_max = int(settings.story_route_candidate_body_max_chars) - head_c = int(settings.story_route_long_body_head_chars) - tail_c = int(settings.story_route_long_body_tail_chars) + body_max = int(story.route_candidate_body_max_chars) + head_c = int(story.route_long_body_head_chars) + tail_c = int(story.route_long_body_tail_chars) rows: list[dict[str, Any]] = [] for s in ordered: rows.append( @@ -231,8 +232,8 @@ def build_route_candidate_rows( ) ) by_id = {str(s.id): s for s in ordered} - total_max = int(settings.story_route_candidate_total_max_chars) - index_prev = int(settings.story_route_index_preview_chars) + total_max = int(story.route_candidate_total_max_chars) + index_prev = int(story.route_index_preview_chars) return apply_total_budget_downgrade( rows, stories_by_id=by_id, diff --git a/api/app/core/agent_logging.py b/api/app/core/agent_logging.py index b468a99..4786a59 100644 --- a/api/app/core/agent_logging.py +++ b/api/app/core/agent_logging.py @@ -29,6 +29,7 @@ from opentelemetry.trace import Status, StatusCode from app.core.config import settings from app.core.telemetry import get_tracer +from app.core.runtime_constants import agent_log_defaults _dedup_lock = threading.Lock() _last_prompt_sha256_by_label: dict[str, str] = {} @@ -52,14 +53,14 @@ def agent_summary_enabled() -> bool: """是否输出单行 INFO 摘要(耗时、规模等),不依赖全局 DEBUG。""" if agent_verbose_enabled(): return True - return bool(settings.log_agent_verbose) + return bool(agent_log_defaults.agent_verbose) def truncate_for_log(text: str | None, *, max_chars: int | None = None) -> str: """截断过长文本,避免日志爆量。max_chars / AGENT_LOG_MAX_CHARS 为 0 表示不截断。""" if text is None: return "" - max_c = max_chars if max_chars is not None else settings.agent_log_max_chars + max_c = max_chars if max_chars is not None else agent_log_defaults.max_chars s = str(text) if max_c <= 0 or len(s) <= max_c: return s @@ -105,7 +106,7 @@ def agent_span( def _log_end(ms: float) -> None: if agent_verbose_enabled(): logger.debug("agent_span_end {} duration_ms={:.2f} {}", operation, ms, ctx) - elif settings.log_agent_verbose: + elif agent_log_defaults.agent_verbose: logger.info("agent_span {} duration_ms={:.2f} {}", operation, ms, ctx) if settings.otel_enabled: @@ -153,7 +154,7 @@ def log_agent_payload( sha12 = digest[:12] is_prompt = label.endswith(".prompt") - if is_prompt and settings.agent_log_prompt_dedup: + if is_prompt and agent_log_defaults.prompt_dedup: with _dedup_lock: if _last_prompt_sha256_by_label.get(label) == digest: logger.debug( @@ -169,14 +170,14 @@ def log_agent_payload( extra_note = "" if ( is_prompt - and settings.agent_log_json_prompt_prefix_chars > 0 - and total_len > settings.agent_log_json_prompt_prefix_only_if_len_gt + and agent_log_defaults.json_prompt_prefix_chars > 0 + and total_len > agent_log_defaults.json_prompt_prefix_only_if_len_gt ): - skip = settings.agent_log_json_prompt_prefix_chars + skip = agent_log_defaults.json_prompt_prefix_chars preview_source = raw[skip:] extra_note = f" skipped_prefix_chars={skip}" - mode = (settings.agent_log_prompt_mode or "preview").strip().lower() + mode = (agent_log_defaults.prompt_mode or "preview").strip().lower() if is_prompt and mode == "hash_only": logger.debug( "agent_payload label={} total_len={} sha12={} mode=hash_only{}", diff --git a/api/app/core/alembic_startup.py b/api/app/core/alembic_startup.py index ee0fd3c..5389d3c 100644 --- a/api/app/core/alembic_startup.py +++ b/api/app/core/alembic_startup.py @@ -15,6 +15,7 @@ from sqlalchemy.exc import OperationalError from app.core.config import settings from app.core.logging import get_logger +from app.core.runtime_constants import alembic_defaults logger = get_logger(__name__) @@ -70,12 +71,12 @@ def run_alembic_upgrade_at_startup() -> None: 在 asyncio 中请使用 ``asyncio.to_thread(run_alembic_upgrade_at_startup)``, 避免阻塞事件循环。 """ - if not settings.alembic_run_on_startup: + if not alembic_defaults.run_on_startup: logger.info("跳过 Alembic 迁移(alembic_run_on_startup=False)") return - max_tries = max(1, settings.alembic_startup_max_retries) - base_delay = float(settings.alembic_startup_retry_base_seconds) + max_tries = max(1, alembic_defaults.max_retries) + base_delay = float(alembic_defaults.retry_base_seconds) last: BaseException | None = None for attempt in range(max_tries): diff --git a/api/app/core/app_config.py b/api/app/core/app_config.py new file mode 100644 index 0000000..5166c98 --- /dev/null +++ b/api/app/core/app_config.py @@ -0,0 +1,69 @@ +"""TOML-backed application configuration singleton and re-export helpers.""" + +from __future__ import annotations + +from app.core.app_config_loader import load_app_config +from app.core.app_config_models import ( + AgentLogConfig, + AlembicConfig, + AppConfig, + AsrConfig, + CeleryConfig, + ChatConfig, + DeployConfig, + EvalConfig, + LlmConfig, + MemoirConfig, + MemoryConfig, + MiscConfig, + OtelConfig, + StoryConfig, + TtsConfig, +) +from app.core.config import settings as _bootstrap_settings + +_app_config: AppConfig | None = None + + +def get_app_config() -> AppConfig: + global _app_config + if _app_config is None: + _app_config = load_app_config(_bootstrap_settings.app_environment) + return _app_config + + +def reload_app_config(*, app_environment: str | None = None) -> AppConfig: + """Reload TOML config (tests).""" + global _app_config + env = app_environment or _bootstrap_settings.app_environment + _app_config = load_app_config(env) + return _app_config + + +class _LazyAppConfig: + def __getattr__(self, name: str): + return getattr(get_app_config(), name) + + +app_config: AppConfig | _LazyAppConfig = _LazyAppConfig() + +__all__ = [ + "AgentLogConfig", + "AlembicConfig", + "AppConfig", + "AsrConfig", + "CeleryConfig", + "ChatConfig", + "DeployConfig", + "EvalConfig", + "LlmConfig", + "MemoirConfig", + "MemoryConfig", + "MiscConfig", + "OtelConfig", + "StoryConfig", + "TtsConfig", + "app_config", + "get_app_config", + "reload_app_config", +] diff --git a/api/app/core/app_config_loader.py b/api/app/core/app_config_loader.py new file mode 100644 index 0000000..e764eb0 --- /dev/null +++ b/api/app/core/app_config_loader.py @@ -0,0 +1,50 @@ +"""Load and merge TOML configuration files into AppConfig.""" + +from __future__ import annotations + +import os +import tomllib +from pathlib import Path +from typing import Any + +from app.core.app_config_models import AppConfig + +_CONFIG_DIR_ENV = "CONFIG_DIR" + + +def default_config_dir() -> Path: + override = (os.environ.get(_CONFIG_DIR_ENV) or "").strip() + if override: + return Path(override).expanduser().resolve() + # api/app/core/app_config_loader.py -> api/config + return Path(__file__).resolve().parents[2] / "config" + + +def _deep_merge(base: dict[str, Any], overlay: dict[str, Any]) -> dict[str, Any]: + merged = dict(base) + for key, value in overlay.items(): + if key in merged and isinstance(merged[key], dict) and isinstance(value, dict): + merged[key] = _deep_merge(merged[key], value) + else: + merged[key] = value + return merged + + +def _read_toml(path: Path) -> dict[str, Any]: + with path.open("rb") as handle: + return tomllib.load(handle) + + +def load_app_config(app_environment: str, *, config_dir: Path | None = None) -> AppConfig: + root = config_dir or default_config_dir() + default_path = root / "default.toml" + if not default_path.is_file(): + raise FileNotFoundError(f"Missing default config: {default_path}") + + merged = _read_toml(default_path) + env = (app_environment or "development").strip().lower() + overlay_path = root / f"{env}.toml" + if overlay_path.is_file(): + merged = _deep_merge(merged, _read_toml(overlay_path)) + + return AppConfig.model_validate(merged) diff --git a/api/app/core/app_config_models.py b/api/app/core/app_config_models.py new file mode 100644 index 0000000..b9dbc7c --- /dev/null +++ b/api/app/core/app_config_models.py @@ -0,0 +1,317 @@ +"""Pydantic models for TOML-backed application configuration (non-secret SSOT).""" + +from pydantic import BaseModel, ConfigDict, Field + + +class DeployConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + alembic_startup_fail_fast: bool = False + access_token_expire_minutes: int = 120 + refresh_token_expire_days: int = 30 + refresh_token_reuse_grace_seconds: int = Field(default=30, ge=0, le=300) + mock_sms_login_enabled: bool = False + tencent_sms_sdk_app_id: str = "" + tencent_sms_sign_name: str = "" + tencent_sms_template_id: str = "" + tencent_cos_bucket: str = "" + tencent_cos_base_url: str = "" + enable_tts: bool = True + memoir_image_enabled: bool = False + enable_docs: bool = True + wechat_pay_app_id: str = "" + wechat_pay_mch_id: str = "" + wechat_pay_private_key_path: str = "certs/apiclient_key.pem" + wechat_pay_cert_serial_no: str = "" + wechat_pay_notify_url: str = "" + wechat_pay_platform_public_key_path: str = "" + wechat_pay_platform_public_key_id: str = "" + alipay_app_id: str = "" + alipay_notify_url: str = "" + liblib_template_uuid: str = "" + log_level: str = "INFO" + otel_enabled: bool = False + otel_exporter_otlp_endpoint: str = "http://localhost:48317" + api_cors_origins: str = "" + + +class ChatConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + interview_max_tokens: int = 512 + interview_max_segments: int = 2 + interview_max_chars_per_segment: int = 380 + opening_max_tokens: int = 380 + profile_followup_max_tokens: int = 280 + history_max_pairs: int = 15 + history_max_chars: int = 6000 + era_context_enabled: bool = True + stage_detection_enabled: bool = True + stage_detection_max_tokens: int = 128 + interview_persona: str = "default" + interview_temperature: float = 0.93 + memory_retrieval_enabled: bool = True + memory_top_k: int = 8 + memory_evidence_max_chars: int = 4096 + memory_safe_evidence_format_enabled: bool = True + reply_planner_llm_enabled: bool = False + reply_planner_max_tokens: int = 256 + reply_planner_temperature: float = 0.2 + re_greeting_enabled: bool = True + re_greeting_idle_hours: float = 6.0 + topic_chips_enabled: bool = True + topic_chips_max: int = 4 + input_normalize_enabled: bool = True + input_normalize_mode: str = "rules" + input_normalize_llm_max_tokens: int = 512 + input_normalize_llm_max_input_chars: int = 8000 + input_normalize_llm_voice_only: bool = True + profile_max_turns: int = 8 + profile_extract_max_tokens: int = 512 + + +class MemoirConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + fidelity_check_enabled: bool = True + fidelity_check_max_tokens: int = 512 + oral_normalize_enabled: bool = True + oral_normalize_mode: str = "rules" + oral_normalize_llm_max_tokens: int = 512 + oral_normalize_llm_max_input_chars: int = 8000 + phase1_batch_llm_enabled: bool = True + phase1_batch_llm_max_tokens: int = 4096 + phase1_batch_llm_chunk_size: int = 24 + pipeline_run_ttl_seconds: int = 172_800 + extraction_max_tokens: int = 1024 + classification_max_tokens: int = 256 + narrative_max_tokens: int = 4096 + narrative_merge_max_tokens: int = 8192 + title_max_tokens: int = 256 + story_route_max_tokens: int = 1024 + story_batch_plan_max_tokens: int = 4096 + segment_batch_min_chars: int = 50 + segment_batch_max_wait_seconds: float = 60.0 + narrative_immediate_char_threshold: int = 50 + narrative_batch_min_segments: int = 3 + narrative_batch_min_chars: int = 80 + narrative_batch_max_wait_seconds: float = 120.0 + extraction_updates_current_stage: bool = False + fidelity_fail_open_on_parse_error: bool = False + narrative_evidence_overlap_min_chars: int = 14 + evidence_scene_anchor_check_enabled: bool = True + title_slots_require_body_or_oral_match: bool = True + title_hay_grounding_strict_phrases_enabled: bool = True + recompose_retry_on_lock_contention: bool = True + phase2_singleflight_immediate: bool = True + route_defer_enabled: bool = True + route_defer_seconds: float = 120.0 + route_defer_max_attempts: int = 3 + quality_pass_enabled: bool = True + quality_pass_delay_seconds: int = 5 + story_route_append_guardrail_oral_chars: int = 1800 + min_inline_images_for_chapter_cover: int = 1 + image_poll_interval: int = 3 + image_max_attempts: int = 20 + image_provider: str = "liblib" + image_style_default: str = "watercolor" + image_size_default: str = "1280x720" + image_download_hosts: str = "" + image_prompt_fallback_disabled: bool = False + + +class MemoryConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + enrichment_enabled: bool = True + enrichment_max_chars: int = 12000 + compaction_enabled: bool = True + compaction_debounce_seconds: int = 105 + compaction_lock_ttl_seconds: int = 600 + compaction_chunk_similarity_threshold: float = 0.92 + compaction_min_layers_for_exclude: int = 2 + compaction_max_chunks_per_run: int = 200 + compaction_max_excludes_per_run: int = 50 + compaction_max_neighbors_per_chunk: int = 25 + compaction_text_jaccard_min: float = 0.55 + compaction_metadata_event_year_window: int = 1 + compaction_sweep_recent_hours: int = 24 + + +class StoryConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + image_min_body_chars: int = 400 + image_enqueue_dedup_ttl: int = 300 + recompose_chapter_delay_seconds: int = 8 + chapter_pipeline_lock_ttl_seconds: int = 360 + append_max_canonical_chars: int = 12000 + append_max_versions: int = 20 + route_candidate_body_max_chars: int = 2200 + route_candidate_total_max_chars: int = 20_000 + route_long_body_head_chars: int = 700 + route_long_body_tail_chars: int = 700 + route_summary_min_chars: int = 30 + route_index_preview_chars: int = 140 + title_min_body_chars: int = 60 + evidence_top_k_default: int = 10 + evidence_top_k_large_batch: int = 5 + evidence_large_batch_threshold: int = 3 + + +class EvalConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + judge_base_url: str = "https://open.bigmodel.cn/api/paas/v4" + judge_model: str = "glm-5" + judge_temperature: float = 0.3 + judge_deepseek_model: str = "deepseek-v4-flash" + judge_deepseek_thinking_enabled: bool = False + judge_deepseek_context_window_tokens: int = 64_000 + judge_context_window_tokens: int = 200_000 + judge_completion_reserve_tokens: int = 4096 + judge_prompt_budget_safety_tokens: int = 2048 + judge_approx_tokens_per_char: float = 1.0 + judge_max_transcript_chars: int = 0 + judge_max_compare_transcript_chars_each: int = 0 + judge_compare_prompt_overhead_chars: int = 10_000 + judge_memoir_chapter_concurrency: int = 4 + judge_memoir_body_max_chars: int = 36_000 + judge_memoir_evidence_max_chars: int = 32_000 + judge_memoir_completion_max_tokens: int = 3072 + candidate_temperature: float = 0.7 + gate_protected_regression_threshold: float = 2.0 + execution_enabled: bool = True + internal_enable_docs: bool = False + internal_cors_origins: str = "" + + +class LlmConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + deepseek_base_url: str = "https://api.deepseek.com" + deepseek_model: str = "deepseek-v4-flash" + deepseek_thinking_enabled: bool = False + temperature: float = 0.7 + fast_model: str = "" + embedding_base_url: str = "https://open.bigmodel.cn/api/paas/v4" + embedding_model: str = "embedding-3" + + +class AsrConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + provider: str = "whisper" + model_size: str = "small" + device: str = "auto" + compute_type: str = "auto" + model_cache_dir: str = "" + + +class TtsConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + provider: str = "tencent" + voice_type: int = 501004 + voice_type_en: int = 501004 + codec: str = "mp3" + + +class RedisConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + socket_timeout_seconds: float = 5.0 + socket_connect_timeout_seconds: float = 2.0 + health_check_interval_seconds: int = 30 + task_tracker_ttl_seconds: int = 86400 + + +class CeleryConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + memory_enrichment_queue: str = "memory_idle" + broker_pool_limit: int = 10 + broker_connection_retry_on_startup: bool = True + memoir_soft_time_limit: int = 1800 + memoir_hard_time_limit: int = 2400 + image_soft_time_limit: int = 600 + image_hard_time_limit: int = 900 + compaction_sweep_soft_time_limit: int = 300 + compaction_sweep_hard_time_limit: int = 600 + enrichment_soft_time_limit: int = 660 + enrichment_hard_time_limit: int = 960 + + +class AlembicConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + run_on_startup: bool = True + max_retries: int = 3 + retry_base_seconds: float = 1.0 + + +class AgentLogConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + agent_verbose: bool = False + max_chars: int = 4096 + omit_system_message_body: bool = True + json_prompt_prefix_chars: int = 0 + json_prompt_prefix_only_if_len_gt: int = 4000 + prompt_mode: str = "preview" + prompt_dedup: bool = False + celery_log_level: str = "" + httpx_log_level: str = "" + log_json_file: str = "" + + +class OtelConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + exporter_insecure: bool = True + service_name: str = "life-echo-api" + metric_export_interval_ms: int = 10_000 + + def traces_sampler(self, app_environment: str) -> str: + env = (app_environment or "").strip().lower() + if env in ("production", "staging"): + return "parentbased_traceidratio" + return "always_on" + + def traces_sampler_arg(self, app_environment: str) -> float | None: + env = (app_environment or "").strip().lower() + if env in ("production", "staging"): + return 0.1 + return None + + +class MiscConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + algorithm: str = "HS256" + redis_session_ttl: int = 86400 + tencent_sms_template_param_count: int = 2 + tencent_cos_region: str = "ap-shanghai" + liblib_base_url: str = "https://openapi.liblibai.cloud" + alipay_sign_type: str = "RSA2" + alipay_under_development: str = "true" + + +class AppConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + deploy: DeployConfig = Field(default_factory=DeployConfig) + chat: ChatConfig = Field(default_factory=ChatConfig) + memoir: MemoirConfig = Field(default_factory=MemoirConfig) + memory: MemoryConfig = Field(default_factory=MemoryConfig) + story: StoryConfig = Field(default_factory=StoryConfig) + eval: EvalConfig = Field(default_factory=EvalConfig) + llm: LlmConfig = Field(default_factory=LlmConfig) + asr: AsrConfig = Field(default_factory=AsrConfig) + tts: TtsConfig = Field(default_factory=TtsConfig) + celery: CeleryConfig = Field(default_factory=CeleryConfig) + redis: RedisConfig = Field(default_factory=RedisConfig) + alembic: AlembicConfig = Field(default_factory=AlembicConfig) + agent_log: AgentLogConfig = Field(default_factory=AgentLogConfig) + otel: OtelConfig = Field(default_factory=OtelConfig) + misc: MiscConfig = Field(default_factory=MiscConfig) diff --git a/api/app/core/auth_deps.py b/api/app/core/auth_deps.py new file mode 100644 index 0000000..06c6ab3 --- /dev/null +++ b/api/app/core/auth_deps.py @@ -0,0 +1,53 @@ +"""Authentication FastAPI dependencies (isolated to avoid circular imports).""" + +from typing import Optional + +from fastapi import Depends +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.db import get_async_db +from app.core.errors import AuthenticationError +from app.core.security import verify_token +from app.features.user.service import UserService + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login") +oauth2_scheme_optional = OAuth2PasswordBearer( + tokenUrl="/api/auth/login", auto_error=False +) + + +async def get_current_user( + token: str = Depends(oauth2_scheme), + db: AsyncSession = Depends(get_async_db), +): + """Resolve authenticated user from JWT access token.""" + payload = verify_token(token) + if payload is None: + raise AuthenticationError("无法验证凭据") + + user_id: str | None = payload.get("sub") + if user_id is None: + raise AuthenticationError("无法验证凭据") + + if payload.get("type") != "access": + raise AuthenticationError("无法验证凭据") + + user = await UserService(db).get_by_id(user_id) + if user is None: + raise AuthenticationError("无法验证凭据") + + return user + + +async def get_optional_user( + token: Optional[str] = Depends(oauth2_scheme_optional), + db: AsyncSession = Depends(get_async_db), +): + """Return user if a valid token is provided, else None.""" + if token is None: + return None + try: + return await get_current_user(token, db) + except AuthenticationError: + return None diff --git a/api/app/core/celery_broker_dev.py b/api/app/core/celery_broker_dev.py index 4afba0a..8f6e2a1 100644 --- a/api/app/core/celery_broker_dev.py +++ b/api/app/core/celery_broker_dev.py @@ -44,8 +44,6 @@ async def maybe_purge_celery_broker_on_startup(redis_client: Redis) -> None: """在已连接的 Redis 上抢门闩后清空已知任务队列;生产环境永不执行。""" if _is_production_environment(): return - if not settings.celery_purge_broker_on_startup: - return got = await redis_client.set( _PURGE_GATE_KEY, "1", diff --git a/api/app/core/config.py b/api/app/core/config.py index 03a172a..bd66c5b 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -1,19 +1,25 @@ """ -统一配置:所有环境变量通过此模块的 Settings 单点读取。 -业务代码只允许 import settings,禁止散落 os.getenv() / load_dotenv()。 +统一配置:密钥与连接串经 .env / Settings;其余非密钥项见 config/*.toml(AppConfig)。 -本地开发时由 api/development.sh 在启动前将 .env.development 同步为 .env(每次启动覆盖)。 -Docker / 服务端由镜像与 compose 注入进程环境;此处仅固定读取工作目录下的 .env 作为默认值来源。 -进程环境变量(容器 environment、export)覆盖 .env 同名项。 +本地开发时由 api/development.sh 在启动前将 .env.development 同步为 .env。 """ -import secrets +from __future__ import annotations -from pydantic import AliasChoices, Field, field_validator +from pydantic import AliasChoices, Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict +from app.core.app_config_models import DeployConfig +from app.core.redis_urls import resolve_redis_urls + +_DEV_SECRET_KEY = "dev-only-secret-key-do-not-use-in-production" + +_DEPLOY_FIELD_NAMES = frozenset(DeployConfig.model_fields.keys()) + class Settings(BaseSettings): + """Secrets and bootstrap only — non-secret deploy/product config lives in TOML.""" + model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", @@ -21,494 +27,111 @@ class Settings(BaseSettings): extra="ignore", ) - # ── Database ────────────────────────────────────────────── - database_url: str = "postgresql://postgres:postgres@localhost:5432/life_echo" - # 启动时是否执行 Alembic(main.py lifespan);测试或仅读场景可关 - alembic_run_on_startup: bool = True - # True:迁移失败则进程退出(生产推荐)。False:仅打错误日志并继续(本地无 DB 时) - alembic_startup_fail_fast: bool = False - alembic_startup_max_retries: int = Field(default=3, ge=1, le=10) - alembic_startup_retry_base_seconds: float = Field(default=1.0, ge=0.1, le=60.0) - - # ── Redis ───────────────────────────────────────────────── - redis_url: str = "redis://localhost:6379/0" - redis_session_ttl: int = 86400 - - # ── Runtime / Celery 开发体验 ───────────────────────────── - # APP_ENV:本地默认 development;Docker 生产栈请设为 production + database_url: str = "postgresql://postgres:postgres@localhost:48291/life_echo" + redis_url: str = "redis://localhost:48307/0" + redis_password: str = "" + celery_redis_url: str = "" app_environment: str = Field( default="development", validation_alias=AliasChoices("APP_ENV", "APP_ENVIRONMENT"), ) - # 非 production 且为 True 时,在 main/internal_main 连接 Redis 后清空 Celery 队列(不 FLUSHDB,不影响会话键) - celery_purge_broker_on_startup: bool = False - # Memory LLM 富化任务路由队列;可与主 worker 分离(见 README / docker-compose) - celery_memory_enrichment_queue: str = "memory_idle" - # ── Auth / JWT ──────────────────────────────────────────── - secret_key: str = Field(default_factory=lambda: secrets.token_urlsafe(32)) - algorithm: str = "HS256" - access_token_expire_minutes: int = 120 - refresh_token_expire_days: int = 30 - # 本地/内网评测:允许 POST /api/auth/mock/sms-login 跳过短信(须显式开启;production 下路由仍拒绝) - mock_sms_login_enabled: bool = False - - # ── LLM / DeepSeek ─────────────────────────────────────── + secret_key: str = _DEV_SECRET_KEY deepseek_api_key: str = "" - deepseek_base_url: str = "https://api.deepseek.com" - # 官方新模型名(V4-Flash);与弃用名 deepseek-chat 对齐为「非思考」需另设 deepseek_thinking_enabled - deepseek_model: str = "deepseek-v4-flash" - # V4-Flash 在官方 API 中 thinking 默认为 enabled;主链路为对齐旧版 deepseek-chat 默认关闭 - deepseek_thinking_enabled: bool = False - llm_api_key: str = "" - llm_base_url: str = "" - llm_model: str = "" - llm_temperature: float = 0.7 - # 空字符串:快档位与默认模型相同;分类/抽取/记忆富化等可单独指定较轻模型 - llm_fast_model: str = "" - - # ── Memory 向量(智谱 BigModel 国内 embedding-3;与 LLM/DeepSeek 密钥分离)── zhipu_api_key: str = "" - embedding_base_url: str = "https://open.bigmodel.cn/api/paas/v4" - embedding_model: str = "embedding-3" - # ── Chat 访谈(token 上限 + 代码截断,见 reply_limits)── - chat_interview_max_tokens: int = 512 - chat_interview_max_segments: int = 2 - chat_interview_max_chars_per_segment: int = 380 - chat_opening_max_tokens: int = 380 - chat_profile_followup_max_tokens: int = 280 - # Redis 全量历史仅用于 turn 计数;注入 LLM 时截取最近若干轮与字符预算 - chat_history_max_pairs: int = Field(default=15, ge=1, le=500) - chat_history_max_chars: int = Field(default=6000, ge=256, le=500_000) - chat_era_context_enabled: bool = True - # 访谈:每轮用 LLM 判定用户主人生阶段并更新 MemoirState.current_stage;False 时仅用关键词 - chat_stage_detection_enabled: bool = True - chat_stage_detection_max_tokens: int = 128 - # 访谈性格:default | warm_listener | curious_guide(未知值按 default) - chat_interview_persona: str = "default" - # 访谈/开场 LLM 采样温度:略高于通用 llm_temperature,利于口语与叙事变化、减程式句 - chat_interview_temperature: float = Field(default=0.93, ge=0.0, le=2.0) - # 访谈:按用户本轮话检索记忆并注入 prompt(关则不调 MemoryService.retrieve) - chat_memory_retrieval_enabled: bool = True - chat_memory_top_k: int = Field(default=8, ge=1, le=30) - chat_memory_evidence_max_chars: int = Field(default=4096, ge=256, le=50_000) - # 访谈记忆注入使用聊天专用安全格式化(编号引用 + 主语弱化说明) - chat_memory_safe_evidence_format_enabled: bool = True - # True:在规则 TurnPlan 之后追加一轮轻量 JSON focus planner(本轮承接重点 + memory 引用 + 回复形状;失败则回退基线) - chat_reply_planner_llm_enabled: bool = False - chat_reply_planner_max_tokens: int = Field(default=256, ge=64, le=1024) - chat_reply_planner_temperature: float = Field(default=0.2, ge=0.0, le=1.0) - # 老对话回访问候:连接时若距上次消息超过该小时数,由 AI 主动发一条承接式开场(自防抖:发完即更新 last_message_at) - chat_re_greeting_enabled: bool = True - chat_re_greeting_idle_hours: float = Field(default=6.0, ge=0.25, le=240.0) - # 话题建议 chips:连接首帧附带 3-4 个 quick-start 话题(来自当前阶段的空 slots) - chat_topic_chips_enabled: bool = True - chat_topic_chips_max: int = Field(default=4, ge=1, le=8) - - # ── Memoir 叙事忠实度检查(FidelityCheckAgent)──────────────── - memoir_fidelity_check_enabled: bool = True - memoir_fidelity_check_max_tokens: int = 512 - # 口述归一(进入叙事 / 忠实度前;segment 原文不落库):off | rules | llm - memoir_oral_normalize_enabled: bool = True - memoir_oral_normalize_mode: str = "rules" - memoir_oral_normalize_llm_max_tokens: int = Field(default=512, ge=64, le=4096) - memoir_oral_normalize_llm_max_input_chars: int = Field( - default=8000, ge=64, le=50_000 - ) - # 聊天:模型消费净稿(不改变 segment 落库原文);与 memoir 规则层共用,配置独立 - chat_input_normalize_enabled: bool = True - chat_input_normalize_mode: str = "rules" # off | rules | llm - chat_input_normalize_llm_max_tokens: int = Field(default=512, ge=64, le=4096) - chat_input_normalize_llm_max_input_chars: int = Field( - default=8000, ge=64, le=50_000 - ) - # True 且 mode=llm:仅语音/ASR 段走 LLM 纠错;键盘输入仅规则归一(省每轮 LLM) - chat_input_normalize_llm_voice_only: bool = True - # 资料收集:超过该对话轮次(Redis 全量轮次计数)仍有缺失字段时,强制进入访谈,避免长期问卷感 - chat_profile_max_turns: int = Field(default=8, ge=1, le=500) - - # Memoir Phase1:多 segment 一批一次 LLM 完成抽取+章节分类(失败回退逐段);单段且关时仍逐段 - memoir_phase1_batch_llm_enabled: bool = True - memoir_phase1_batch_llm_max_tokens: int = Field(default=4096, ge=512, le=32_768) - #: Phase1 批处理 LLM:单次请求最多包含的 segment 数(多块合并,避免 completion 顶满截断) - memoir_phase1_batch_llm_chunk_size: int = Field(default=24, ge=1, le=500) - #: 回忆录流水线细粒度进度 Redis 快照 TTL(memoir_pipeline_run:*) - memoir_pipeline_run_ttl_seconds: int = Field(default=172_800, ge=3600, le=2_592_000) - # Memoir agents:`invoke_json_object` / `llm_json_call` 的 max_tokens(原硬编码迁至配置) - memoir_extraction_max_tokens: int = Field(default=1024, ge=64, le=8192) - memoir_classification_max_tokens: int = Field(default=256, ge=32, le=4096) - memoir_narrative_max_tokens: int = Field(default=4096, ge=256, le=32_768) - memoir_narrative_merge_max_tokens: int = Field(default=8192, ge=256, le=64_000) - memoir_title_max_tokens: int = Field(default=256, ge=32, le=4096) - memoir_story_route_max_tokens: int = Field(default=1024, ge=64, le=8192) - memoir_story_batch_plan_max_tokens: int = Field(default=4096, ge=256, le=32_768) - # 资料抽取(ProfileAgent JSON 模式) - chat_profile_extract_max_tokens: int = Field(default=512, ge=64, le=4096) - - # ── ASR ─────────────────────────────────────────────────── - asr_provider: str = "whisper" - asr_model_size: str = "small" - asr_device: str = "auto" - asr_compute_type: str = "auto" - asr_model_cache_dir: str = "" - - # ── Tencent SMS ────────────────────────────────────────── - tencent_sms_secret_id: str = "" - tencent_sms_secret_key: str = "" - tencent_sms_sdk_app_id: str = "" - tencent_sms_sign_name: str = "" - tencent_sms_template_id: str = "" - tencent_sms_template_param_count: int = 2 - - # ── Tencent ASR / TTS(共用 Secret;与短信、COS 密钥独立)──────────────── tencent_secret_id: str = "" tencent_secret_key: str = "" - # ── TTS (openai | tencent),与 ASR 独立:仅控制回复侧语音合成 ── - enable_tts: bool = True - tts_provider: str = "tencent" - openai_api_key: str = "" - # 501004 = 月华,腾讯云大模型音色,支持中英混合(PrimaryLanguage=1/2 均可)。 - # 调用 TextToVoice 时必须配合 ModelType=1,详见 https://cloud.tencent.com/document/api/1073/37995 - # 与音色清单 https://cloud.tencent.com/document/product/1073/92668 - tts_voice_type: int = 501004 - # 英文场景默认同样使用 501004(月华大模型音色,原生支持中英混合), - # 因此无需另配独立英文音色;如需切换英文专用音色请显式覆盖此项。 - tts_voice_type_en: int = 501004 - tts_codec: str = "mp3" - - # ── WeChat Pay ─────────────────────────────────────────── - wechat_pay_app_id: str = "" - wechat_pay_mch_id: str = "" wechat_pay_api_v3_key: str = "" - wechat_pay_private_key_path: str = "certs/apiclient_key.pem" - wechat_pay_private_key: str = "" # PEM 内容,与 private_key_path 二选一 - wechat_pay_cert_serial_no: str = "" - wechat_pay_notify_url: str = "" + wechat_pay_private_key: str = "" wechat_pay_platform_public_key: str = "" - wechat_pay_platform_public_key_path: str = "" - wechat_pay_platform_public_key_id: str = "" - # ── Alipay ─────────────────────────────────────────────── - alipay_app_id: str = "" alipay_private_key: str = "" alipay_public_key: str = "" - alipay_notify_url: str = "" - alipay_sign_type: str = "RSA2" - alipay_under_development: str = "true" # "1"/"true"/"yes" 视为开发中不可用 - # ── Logging ────────────────────────────────────────────── - # 环境变量 LOG_LEVEL;控制 loguru sink 最低级别(TRACE/DEBUG/INFO/…) - log_level: str = "INFO" - # LOG_AGENT_VERBOSE:为 True 时额外输出 Agent 单行 INFO 摘要(耗时、规模),无需全局 DEBUG - log_agent_verbose: bool = False - # AGENT_LOG_MAX_CHARS:DEBUG 下记录 prompt/响应预览时的最大字符数;0=不截断(完整输出,慎用) - agent_log_max_chars: int = Field(default=4096, ge=0, le=50_000_000) - # AGENT_LOG_OMIT_SYSTEM_MESSAGE_BODY:DEBUG 下访谈/资料聊天日志省略 System 正文(仅 len+sha12) - agent_log_omit_system_message_body: bool = True - # AGENT_LOG_JSON_PROMPT_PREFIX_CHARS:DEBUG 下 *.prompt 总长超过下项时再跳过前 N 字符后预览(0=不跳过) - agent_log_json_prompt_prefix_chars: int = Field(default=0, ge=0, le=500_000) - # AGENT_LOG_JSON_PROMPT_PREFIX_ONLY_IF_LEN_GT:触发“跳过前缀”的最小 prompt 长度 - agent_log_json_prompt_prefix_only_if_len_gt: int = Field( - default=4000, ge=0, le=2_000_000 - ) - # AGENT_LOG_PROMPT_MODE:DEBUG 下 *.prompt 记录方式 preview=截断预览 | hash_only=仅 sha12+长度(无正文) - agent_log_prompt_mode: str = Field(default="preview") - # AGENT_LOG_PROMPT_DEDUP:DEBUG 下同一 label 连续相同全文时第二条起跳过(减重复模板噪音) - agent_log_prompt_dedup: bool = False - # 第三方 stdlib logging(空=自动:DEBUG/TRACE 时 Celery→INFO;否则 Celery 与 httpx 默认 WARNING) - celery_log_level: str = "" - httpx_log_level: str = "" - # 非空时额外写入 JSONL(serialize=True),便于 Loki/ELK;与 stderr 彩色控制台并存 - log_json_file: str = "" - - # ── OpenTelemetry ───────────────────────────────────────── - otel_enabled: bool = False - otel_exporter_otlp_endpoint: str = "http://localhost:48317" - otel_exporter_otlp_insecure: bool = True - otel_service_name: str = "" - otel_traces_sampler: str = Field( - default="always_on", - description="always_on | parentbased_traceidratio | always_off", - ) - otel_traces_sampler_arg: float | None = Field(default=None, ge=0.0, le=1.0) - otel_metric_export_interval_ms: int = Field(default=10_000, ge=1000, le=300_000) - - @field_validator("otel_enabled", mode="before") - @classmethod - def _coerce_otel_enabled(cls, v: object) -> bool: - if isinstance(v, bool): - return v - if v is None: - return False - return str(v).strip().lower() in ("1", "true", "yes", "on") - - @field_validator("otel_exporter_otlp_insecure", mode="before") - @classmethod - def _coerce_otel_exporter_otlp_insecure(cls, v: object) -> bool: - if isinstance(v, bool): - return v - if v is None: - return True - return str(v).strip().lower() in ("1", "true", "yes", "on") - - @field_validator("celery_purge_broker_on_startup", mode="before") - @classmethod - def _coerce_celery_purge_broker_on_startup(cls, v: object) -> bool: - if isinstance(v, bool): - return v - if v is None: - return False - return str(v).strip().lower() in ("1", "true", "yes", "on") - - @field_validator("mock_sms_login_enabled", mode="before") - @classmethod - def _coerce_mock_sms_login_enabled(cls, v: object) -> bool: - if isinstance(v, bool): - return v - if v is None: - return False - return str(v).strip().lower() in ("1", "true", "yes", "on") - - @field_validator("log_agent_verbose", mode="before") - @classmethod - def _coerce_log_agent_verbose(cls, v: object) -> bool: - if isinstance(v, bool): - return v - if v is None: - return False - return str(v).strip().lower() in ("1", "true", "yes", "on") - - @field_validator("agent_log_omit_system_message_body", mode="before") - @classmethod - def _coerce_agent_log_omit_system_message_body(cls, v: object) -> bool: - if isinstance(v, bool): - return v - if v is None: - return True - s = str(v).strip().lower() - if s in ("0", "false", "no", "off"): - return False - return True - - @field_validator("agent_log_prompt_mode", mode="before") - @classmethod - def _normalize_agent_log_prompt_mode(cls, v: object) -> str: - if v is None: - return "preview" - s = str(v).strip().lower() - if s not in ("preview", "hash_only"): - return "preview" - return s - - @field_validator("agent_log_prompt_dedup", mode="before") - @classmethod - def _coerce_agent_log_prompt_dedup(cls, v: object) -> bool: - if isinstance(v, bool): - return v - if v is None: - return False - return str(v).strip().lower() in ("1", "true", "yes", "on") - - # ── Misc ───────────────────────────────────────────────── - enable_test_subscription: int = 0 - enable_test_plan: str = "" # "1" / "true" / "yes" 为 True - enable_docs: bool = True - - # ── Memoir Image ───────────────────────────────────────── - memoir_image_enabled: bool = False - # True:图片 LLM prompt 失败时不使用英语降级模板(需产品与任务失败流确认后开启) - image_prompt_fallback_disabled: bool = False - memoir_image_poll_interval: int = 3 - memoir_image_max_attempts: int = 20 - memoir_image_provider: str = "liblib" - memoir_image_style_default: str = "watercolor" - memoir_image_size_default: str = "1280x720" - memoir_image_download_hosts: str = "" - # 章节 canonical_markdown 中至少含多少张 asset:// 正文插图才生成/展示章节封面(≥ 该值即满足;0 表示不以此条件拦截) - memoir_min_inline_images_for_chapter_cover: int = Field(default=1, ge=0, le=100) - # Story 正文至少多少字才创建主图 intent / 调图(0 表示不限制) - story_image_min_body_chars: int = 400 - # generate_story_image 入队去重(Redis SET NX,秒) - story_image_enqueue_dedup_ttl: int = Field(default=300, ge=30, le=86400) - # 章节物化异步任务延迟入队(秒),削峰 - recompose_chapter_delay_seconds: int = Field(default=8, ge=0, le=600) - # 与 memoir pipeline 一致的章节互斥锁 TTL(秒);应覆盖 Phase2 / recompose 的 P95 时长 - chapter_pipeline_lock_ttl_seconds: int = Field(default=360, ge=10, le=3600) - # Append 硬上限:canonical 字符数、版本数(超限强制 new_story) - story_append_max_canonical_chars: int = Field(default=12000, ge=1000, le=500_000) - story_append_max_versions: int = Field(default=20, ge=1, le=500) - # StoryRouteAgent:候选 JSON 预算(保守默认,可调大) - story_route_candidate_body_max_chars: int = Field(default=2200, ge=200, le=8000) - story_route_candidate_total_max_chars: int = Field( - default=20_000, ge=2000, le=100_000 - ) - story_route_long_body_head_chars: int = Field(default=700, ge=100, le=4000) - story_route_long_body_tail_chars: int = Field(default=700, ge=100, le=4000) - story_route_summary_min_chars: int = Field(default=30, ge=0, le=500) - story_route_index_preview_chars: int = Field(default=140, ge=20, le=500) - # 童年/求学/家庭:本批口述低于该字数且路由为 new 时,倾向续写到默认候选,减少碎篇 - memoir_story_route_append_guardrail_oral_chars: int = Field( - default=1800, ge=0, le=50_000 - ) - # Evidence 检索 top_k:大批次 unit 时降低检索量 - evidence_top_k_default: int = Field(default=10, ge=1, le=50) - evidence_top_k_large_batch: int = Field(default=5, ge=1, le=50) - evidence_large_batch_threshold: int = Field(default=3, ge=1, le=100) - # Story/Chapter 标题在正文达到此字数后才由 LLM 生成;之前用占位符 - story_title_min_body_chars: int = Field(default=60, ge=0, le=10_000) - # 回忆录 Celery:累计 strip 后口述字数未达此值则暂缓提交(0=关闭,仅防抖后提交) - memoir_segment_batch_min_chars: int = Field(default=50, ge=0, le=50_000) - # 本批首条 segment 入队起最长等待(秒),超时则提交(即使字数不足) - memoir_segment_batch_max_wait_seconds: float = Field( - default=60.0, ge=0.0, le=3600.0 - ) - # 回忆录叙事 Phase 2( Celery)触发:单条口述达到该 strip 字数则立即跑叙事 - memoir_narrative_immediate_char_threshold: int = Field(default=50, ge=0, le=50_000) - # 同一 topic_category 下未叙事段数达到该值则触发 Phase 2 - memoir_narrative_batch_min_segments: int = Field(default=3, ge=1, le=500) - # 同上,累计 user_input_text 字符数(strip 后由 Segment 列 length 近似) - memoir_narrative_batch_min_chars: int = Field(default=80, ge=0, le=500_000) - # Phase 1 完成后未触发 Phase 2 时,延迟任务兜底(秒);新 Phase 1 会 revoke 旧定时 - memoir_narrative_batch_max_wait_seconds: float = Field( - default=120.0, ge=1.0, le=3600.0 - ) - # False:Celery/批处理更新 slot 时不改写 MemoirState.current_stage(访谈路径仍可由 switch_stage 推进) - # True:仅当 chat_bucket( proposed ) == chat_bucket( existing ) 时允许批处理对齐 current_stage - memoir_extraction_updates_current_stage: bool = False - # True:FidelityCheckAgent JSON/LLM 解析失败时放行(仅建议 append 场景配合 existing 兜底) - memoir_fidelity_fail_open_on_parse_error: bool = False - # 正文与 evidence 文本的最长公共子串达到该长度且 oral/旧正文未覆盖时,回退为安全正文 - memoir_narrative_evidence_overlap_min_chars: int = Field(default=14, ge=8, le=256) - # True:启用短「场合锚点」词检测(聚餐/那晚等),须同时在摘录中出现且口述未覆盖才回退 - memoir_evidence_scene_anchor_check_enabled: bool = True - # True:标题生成时 slots 仅保留在 oral 或正文摘录中出现的条目(减少档案串台) - memoir_title_slots_require_body_or_oral_match: bool = True - # True:标题中出现高置信「履历链」短语则须在 hay(正文+口述+已传 slots)中有逐字依据,否则降级占位 - memoir_title_hay_grounding_strict_phrases_enabled: bool = True - # True:章节物化拿不到 pipeline 锁时 Celery retry(避免长期跳过导致 dirty 不收敛) - memoir_recompose_retry_on_lock_contention: bool = True - # Phase2 立即派发使用固定 task_id,减少同类目重复入队(超时任务仍用独立 id) - memoir_phase2_singleflight_immediate: bool = True - # True:Phase2 路由低置信(no_llm/parse_error/invalid_target)时不写 Story, - # 把 segment 标记为 narrative_deferred_until 之后再重试。 - memoir_route_defer_enabled: bool = True - # 低置信延迟时长(秒):到期前不消费这些 segment,避免后台空转 - memoir_route_defer_seconds: float = Field(default=120.0, ge=1.0, le=3600.0) - # 同一类目最多自动延迟次数;达到上限后 segment 仅靠新素材到达激活,不再自动重试 - memoir_route_defer_max_attempts: int = Field(default=3, ge=1, le=20) - # True:Phase2 首稿后异步运行质量增强(fidelity recheck、标题润色、LLM 归一) - memoir_quality_pass_enabled: bool = True - memoir_quality_pass_delay_seconds: int = Field(default=5, ge=0, le=300) - - # ── Memory 检索与富化 ───────────────────────────────────── - # False:跳过 ingest 后 LLM 富化(摘要/事实/时间线) - memory_enrichment_enabled: bool = True - memory_enrichment_max_chars: int = Field(default=12000, ge=1000, le=100_000) - - # ── Memory compaction(近重复 chunk 软排除;事件触发 + Redis 防抖 + 用户锁;需 worker + Beat 跑 sweep)── - memory_compaction_enabled: bool = True - memory_compaction_debounce_seconds: int = Field(default=105, ge=10, le=3600) - memory_compaction_lock_ttl_seconds: int = Field(default=600, ge=60, le=7200) - memory_compaction_chunk_similarity_threshold: float = Field( - default=0.92, ge=0.5, le=0.999 - ) - memory_compaction_min_layers_for_exclude: int = Field(default=2, ge=1, le=3) - memory_compaction_max_chunks_per_run: int = Field(default=200, ge=1, le=10_000) - memory_compaction_max_excludes_per_run: int = Field(default=50, ge=1, le=1000) - memory_compaction_max_neighbors_per_chunk: int = Field(default=25, ge=5, le=100) - memory_compaction_text_jaccard_min: float = Field(default=0.55, ge=0.0, le=1.0) - memory_compaction_metadata_event_year_window: int = Field(default=1, ge=0, le=50) - # Beat sweep:扫描最近 N 小时内有新 chunk 的用户并调度 compaction - memory_compaction_sweep_recent_hours: int = Field(default=24, ge=1, le=168) - - # ── Liblib ─────────────────────────────────────────────── liblib_access_key: str = "" liblib_secret_key: str = "" - liblib_base_url: str = "https://openapi.liblibai.cloud" - liblib_template_uuid: str = "" - # ── Tencent COS ────────────────────────────────────────── - tencent_cos_secret_id: str = "" - tencent_cos_secret_key: str = "" - tencent_cos_region: str = "ap-shanghai" - tencent_cos_bucket: str = "" - tencent_cos_base_url: str = "" - tencent_cos_token: str = "" - - # ── Internal regression evaluation lab(独立入口,不挂在消费者 API)──── internal_eval_api_key: str = "" - internal_eval_enable_docs: bool = False - # 逗号分隔;空则内部 API 不额外限制 Origin(仍可依赖 internal_eval_api_key) - internal_eval_cors_origins: str = "" - # 智谱 GLM-5:评审模型(OpenAI 兼容 Chat Completions,与 langchain-openai 一致) - eval_judge_api_key: str = "" - eval_judge_base_url: str = "https://open.bigmodel.cn/api/paas/v4" - eval_judge_model: str = "glm-5" - eval_judge_temperature: float = 0.3 - # 评测评审:DeepSeek(OpenAI 兼容);默认 deepseek-v4-flash + 非思考(对齐定价页非思考用法;非 v4-pro) - eval_judge_deepseek_model: str = "deepseek-v4-flash" - # 当仅指定 deepseek-v4-flash、未用弃用名区分时,是否走思考模式(与 eval_judge_deepseek_model 联用) - eval_judge_deepseek_thinking_enabled: bool = False - eval_judge_deepseek_context_window_tokens: int = Field( - default=64_000, - ge=4096, - le=2_000_000, - description="DeepSeek 评审专用上下文预算(用于 transcript 截断;与 GLM 200K 分离)", - ) - # GLM-5 输入上下文 200K(https://docs.bigmodel.cn) - eval_judge_context_window_tokens: int = Field( - default=200_000, ge=4096, le=2_000_000 - ) - # 预留给完成 tokens(json 输出)及路由误差 - eval_judge_completion_reserve_tokens: int = Field(default=4096, ge=256, le=131_072) - eval_judge_prompt_budget_safety_tokens: int = Field(default=2048, ge=0, le=32_768) - # transcript 混合中英文时 token/字 估值(略低于 1.2 可多给汉字篇幅;若评审请求被拒可回调高) - eval_judge_approx_tokens_per_char: float = Field(default=1.0, ge=0.3, le=8.0) - # 整段/逐轮节选 transcript 最大字符;0=按 eval_judge_context_window_tokens 自动扣 rubric 头 - eval_judge_max_transcript_chars: int = Field(default=0, ge=0, le=2_000_000) - # 双 transcript 对比流:每条对话上限;0=按上下文平分(扣 overhead) - eval_judge_max_compare_transcript_chars_each: int = Field( - default=0, ge=0, le=2_000_000 - ) - # 对比 prompt 固定开销(模板 + 两份评分 JSON)的字符估值;略低则 transcript 合计空间更大 - eval_judge_compare_prompt_overhead_chars: int = Field( - default=10_000, ge=500, le=500_000 - ) - # 回忆录音评:章节 LLM 并发上限(仅评审请求;准备阶段仍串行访问 DB) - eval_judge_memoir_chapter_concurrency: int = Field( - default=4, - ge=1, - le=32, - ) - # 回忆录评审 prompt 内粗截断(汉字计字符);万字级章节请保持 body ≥ 正文峰值 - eval_judge_memoir_body_max_chars: int = Field( - default=36_000, - ge=8_000, - le=500_000, - description="【当前回忆录正文】注入评审 prompt 前的最大字符", - ) - eval_judge_memoir_evidence_max_chars: int = Field( - default=32_000, - ge=8_000, - le=500_000, - description="对话证据 / 结构化证据 / 参考基线各块的最大字符(与 eval_trace_format 对齐)", - ) - # json_object 完成预算;MemoirJudgeOutput 字段多,需预留足量 token - eval_judge_memoir_completion_max_tokens: int = Field( - default=3072, - ge=512, - le=16_384, - ) - # 候选对话回放:与生产访谈类似的温度 - eval_candidate_temperature: float = 0.7 - # 门禁:受保护 session 合成份数下跌超过该阈值视为回归(0–100 分制) - eval_gate_protected_regression_threshold: float = Field( - default=2.0, ge=0.0, le=100.0 - ) - # 执行 LLM 判分与回放(Celery 未跑时可关,仅跑结构/导入) - eval_execution_enabled: bool = True + + @model_validator(mode="after") + def _validate_secret_key(self) -> "Settings": + env = (self.app_environment or "").strip().lower() + if env in ("production", "staging") and ( + not self.secret_key or self.secret_key == _DEV_SECRET_KEY + ): + raise ValueError( + "SECRET_KEY must be set to a strong random value in production/staging" + ) + return self + + @property + def is_production(self) -> bool: + return (self.app_environment or "").strip().lower() == "production" + + @property + def enable_test_subscription(self) -> bool: + return not self.is_production + + @property + def enable_test_plan(self) -> bool: + return not self.is_production + + @property + def redis_url_resolved(self) -> str: + business, _ = resolve_redis_urls( + self.redis_url, + redis_password=self.redis_password or None, + celery_redis_url_override=None, + ) + return business + + @property + def celery_redis_url_resolved(self) -> str: + override = (self.celery_redis_url or "").strip() or None + _, celery = resolve_redis_urls( + self.redis_url, + redis_password=self.redis_password or None, + celery_redis_url_override=override, + ) + return celery -settings = Settings() +class SettingsFacade: + """Backward-compatible facade: secrets from Settings, deploy flags from TOML.""" + + __slots__ = ("_secrets",) + + def __init__(self, secrets: Settings) -> None: + object.__setattr__(self, "_secrets", secrets) + + @property + def _deploy(self) -> DeployConfig: + from app.core.app_config import get_app_config + + return get_app_config().deploy + + def __getattr__(self, name: str): + secrets = object.__getattribute__(self, "_secrets") + if hasattr(secrets, name): + return getattr(secrets, name) + if name in _DEPLOY_FIELD_NAMES: + return getattr(self._deploy, name) + raise AttributeError(f"Settings has no attribute {name!r}") + + def __setattr__(self, name: str, value) -> None: + if name == "_secrets": + object.__setattr__(self, name, value) + return + secrets = object.__getattribute__(self, "_secrets") + if hasattr(type(secrets), name) and name not in _DEPLOY_FIELD_NAMES: + setattr(secrets, name, value) + return + if name in _DEPLOY_FIELD_NAMES: + setattr(self._deploy, name, value) + return + setattr(secrets, name, value) + + +settings = SettingsFacade(Settings()) diff --git a/api/app/core/cos_url_keys.py b/api/app/core/cos_url_keys.py index 3d61681..0bbfcda 100644 --- a/api/app/core/cos_url_keys.py +++ b/api/app/core/cos_url_keys.py @@ -8,6 +8,7 @@ from urllib.parse import urlparse from app.core.config import settings from app.core.logging import get_logger from app.ports.storage import ObjectStorage +from app.core.runtime_constants import misc_defaults logger = get_logger(__name__) @@ -35,7 +36,7 @@ def extract_cos_object_key_if_owned(url: str | None) -> str | None: candidates: list[str] = [] bucket = (settings.tencent_cos_bucket or "").strip().lower() - region = (settings.tencent_cos_region or "").strip().lower() + region = (misc_defaults.tencent_cos_region or "").strip().lower() if bucket and region: candidates.append(f"{bucket}.cos.{region}.myqcloud.com") base = (settings.tencent_cos_base_url or "").strip() @@ -135,7 +136,7 @@ def avatar_url_for_api_response(stored_url: str | None) -> str | None: if not key: return s if not ( - (settings.tencent_cos_secret_id or "").strip() + (settings.tencent_secret_id or "").strip() and (settings.tencent_cos_bucket or "").strip() ): return s @@ -159,7 +160,7 @@ def best_effort_delete_cos_object_for_url(url: str | None) -> None: if not key: return if not ( - (settings.tencent_cos_secret_id or "").strip() + (settings.tencent_secret_id or "").strip() and (settings.tencent_cos_bucket or "").strip() ): return diff --git a/api/app/core/db.py b/api/app/core/db.py index 928d368..a8d68d4 100644 --- a/api/app/core/db.py +++ b/api/app/core/db.py @@ -6,16 +6,32 @@ 事务规则: - get_async_db() 只负责创建和关闭 session,不自动 commit/rollback。 -- 事务提交由 service 层显式调用 await db.commit()。 +- service / Celery task 层优先使用 transactional() / transactional_sync() 管理多步写操作。 - repo 层禁止调用 commit() / rollback()。 + +transactional 语义: +- transactional() / transactional_sync() 是顶层事务边界;成功 exit 时 commit 整个 session,异常时 rollback 整个 session。 +- 不支持嵌套自身:同一 session 上连续两次 transactional() = 两次独立 commit(WS pipeline 分段持久化属于此模式)。 +- 需要嵌套回滚时:在已开启的事务内使用 transactional_nested() / transactional_nested_sync()(基于 SQLAlchemy begin_nested() savepoint)。 +- 选择指南:单步/整段业务原子提交 → transactional;长生命周期 session 内局部试错、可独立回滚的子步骤 → transactional_nested(必须在外层事务 active 期间)。 + +transactional_nested 示例(外层提交、内层失败仅回滚 savepoint):: + + async with transactional(session): + session.add(parent_row) + try: + async with transactional_nested(session): + await attempt_optional_side_effect(session) + except RecoverableError: + pass # savepoint rolled back; parent_row still commits """ -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from typing import AsyncGenerator from sqlalchemy import create_engine, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import DeclarativeBase, sessionmaker +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker from app.core.config import settings @@ -69,6 +85,32 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]: await session.close() +@asynccontextmanager +async def transactional(session: AsyncSession): + """Top-level async transaction: commit on success, rollback on any exception. + + Do not nest transactional() on the same session; each call commits independently. + For partial rollback within an active transaction, use transactional_nested(). + """ + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + +@asynccontextmanager +async def transactional_nested(session: AsyncSession): + """Savepoint boundary; roll back only this block on error. + + Must be used while the session already has an active transaction (e.g. inside + transactional() before it commits, or after autobegin from a prior write). + """ + async with session.begin_nested(): + yield session + + # ── Sync engine & session (Celery, Alembic, scripts) ───────── sync_engine = create_engine( @@ -100,6 +142,28 @@ def init_db_schema() -> None: Base.metadata.create_all(bind=sync_engine) +@contextmanager +def transactional_sync(session: Session): + """Top-level sync transaction: commit on success, rollback on any exception. + + Do not nest transactional_sync() on the same session; each call commits independently. + For partial rollback within an active transaction, use transactional_nested_sync(). + """ + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + + +@contextmanager +def transactional_nested_sync(session: Session): + """Savepoint boundary for sync Celery / scripts; roll back only this block on error.""" + with session.begin_nested(): + yield session + + @contextmanager def get_sync_db(): """Context-managed synchronous session for Celery tasks.""" diff --git a/api/app/core/dependencies.py b/api/app/core/dependencies.py index f170169..9cffa41 100644 --- a/api/app/core/dependencies.py +++ b/api/app/core/dependencies.py @@ -8,14 +8,13 @@ from functools import lru_cache from typing import Optional -from fastapi import Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer +from fastapi import Depends from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings -from app.core.db import get_async_db from app.core.eval_judge_spec import EvalJudgeLlmSpec, EvalJudgeProvider -from app.core.security import verify_token +from app.core.runtime_constants import asr_defaults, llm_defaults, misc_defaults, tts_defaults +from app.features.memoir.constants import memoir from app.ports.asr import ASRProvider from app.ports.embedding import EmbeddingProvider from app.ports.image_gen import ImageGenerator @@ -24,9 +23,6 @@ from app.ports.sms import SmsSender from app.ports.storage import ObjectStorage from app.ports.tts import TTSProvider -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login") - - # ── Port DI factories ─────────────────────────────────────── @@ -35,12 +31,12 @@ def get_sms_sender() -> SmsSender: from app.adapters.sms.tencent import TencentSmsSender return TencentSmsSender( - secret_id=settings.tencent_sms_secret_id, - secret_key=settings.tencent_sms_secret_key, + secret_id=settings.tencent_secret_id, + secret_key=settings.tencent_secret_key, sdk_app_id=settings.tencent_sms_sdk_app_id, sign_name=settings.tencent_sms_sign_name, template_id=settings.tencent_sms_template_id, - template_param_count=settings.tencent_sms_template_param_count, + template_param_count=misc_defaults.tencent_sms_template_param_count, ) @@ -48,17 +44,17 @@ def get_sms_sender() -> SmsSender: def get_llm_provider() -> LLMProvider: from app.adapters.llm.deepseek import DeepSeekLLMProvider - api_key = settings.deepseek_api_key or settings.llm_api_key - base_url = settings.deepseek_base_url or settings.llm_base_url - model = settings.deepseek_model or settings.llm_model or "deepseek-v4-flash" + api_key = settings.deepseek_api_key or "" + base_url = llm_defaults.deepseek_base_url or "" + model = llm_defaults.deepseek_model or "deepseek-v4-flash" return DeepSeekLLMProvider( api_key=api_key, base_url=base_url, model=model, - temperature=settings.llm_temperature, + temperature=llm_defaults.temperature, extra_body={ "thinking": { - "type": "enabled" if settings.deepseek_thinking_enabled else "disabled" + "type": "enabled" if llm_defaults.deepseek_thinking_enabled else "disabled" } }, ) @@ -67,21 +63,21 @@ def get_llm_provider() -> LLMProvider: @lru_cache def get_llm_provider_fast() -> LLMProvider: """快档位:与默认共用密钥与 base_url,仅模型名可单独配置。""" - fast = (settings.llm_fast_model or "").strip() + fast = (llm_defaults.fast_model or "").strip() if not fast: return get_llm_provider() from app.adapters.llm.deepseek import DeepSeekLLMProvider - api_key = settings.deepseek_api_key or settings.llm_api_key - base_url = settings.deepseek_base_url or settings.llm_base_url + api_key = settings.deepseek_api_key or "" + base_url = llm_defaults.deepseek_base_url or "" return DeepSeekLLMProvider( api_key=api_key, base_url=base_url, model=fast, - temperature=settings.llm_temperature, + temperature=llm_defaults.temperature, extra_body={ "thinking": { - "type": "enabled" if settings.deepseek_thinking_enabled else "disabled" + "type": "enabled" if llm_defaults.deepseek_thinking_enabled else "disabled" } }, ) @@ -89,24 +85,24 @@ def get_llm_provider_fast() -> LLMProvider: @lru_cache def get_tts_provider() -> TTSProvider: - if settings.tts_provider == "tencent": + if tts_defaults.provider == "tencent": from app.adapters.tts.tencent_tts import TencentTTSProvider return TencentTTSProvider( secret_id=settings.tencent_secret_id, secret_key=settings.tencent_secret_key, - voice_type=settings.tts_voice_type, - codec=settings.tts_codec, - voice_type_en=settings.tts_voice_type_en, + voice_type=tts_defaults.voice_type, + codec=tts_defaults.codec, + voice_type_en=tts_defaults.voice_type_en, ) from app.adapters.tts.openai_tts import OpenAITTSProvider - return OpenAITTSProvider(api_key=settings.openai_api_key) + return OpenAITTSProvider(api_key="") @lru_cache def get_asr_provider() -> ASRProvider: - if settings.asr_provider == "tencent": + if asr_defaults.provider == "tencent": from app.adapters.asr.tencent_asr import TencentASRProvider return TencentASRProvider( @@ -117,10 +113,10 @@ def get_asr_provider() -> ASRProvider: from app.adapters.asr.whisper_local import WhisperASRProvider return WhisperASRProvider( - model_size=settings.asr_model_size, - device=settings.asr_device, - compute_type=settings.asr_compute_type, - cache_dir=settings.asr_model_cache_dir, + model_size=asr_defaults.model_size, + device=asr_defaults.device, + compute_type=asr_defaults.compute_type, + cache_dir=asr_defaults.model_cache_dir, ) @@ -131,10 +127,10 @@ def get_image_generator() -> ImageGenerator: return LiblibImageGenerator( access_key=settings.liblib_access_key, secret_key=settings.liblib_secret_key, - base_url=settings.liblib_base_url, + base_url=misc_defaults.liblib_base_url, template_uuid=settings.liblib_template_uuid, - poll_interval=settings.memoir_image_poll_interval, - max_attempts=settings.memoir_image_max_attempts, + poll_interval=memoir.image_poll_interval, + max_attempts=memoir.image_max_attempts, ) @@ -143,12 +139,12 @@ def get_object_storage() -> ObjectStorage: from app.adapters.storage.tencent_cos import TencentCosStorage return TencentCosStorage( - secret_id=settings.tencent_cos_secret_id, - secret_key=settings.tencent_cos_secret_key, - region=settings.tencent_cos_region, + secret_id=settings.tencent_secret_id, + secret_key=settings.tencent_secret_key, + region=misc_defaults.tencent_cos_region, bucket=settings.tencent_cos_bucket, base_url=settings.tencent_cos_base_url, - token=settings.tencent_cos_token, + token="", ) @@ -158,43 +154,15 @@ def get_embedding_provider() -> EmbeddingProvider: return ZhipuEmbeddingProvider( api_key=settings.zhipu_api_key, - base_url=settings.embedding_base_url or None, - model=settings.embedding_model, + base_url=llm_defaults.embedding_base_url or None, + model=llm_defaults.embedding_model, ) -# ── Auth dependencies ──────────────────────────────────────── +# Re-export auth deps for backward compatibility +from app.core.auth_deps import get_current_user, get_optional_user # noqa: F401 - -async def get_current_user( - token: str = Depends(oauth2_scheme), - db: AsyncSession = Depends(get_async_db), -): - """Resolve authenticated user from JWT access token.""" - from app.features.user.models import User # deferred to avoid circular import - - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="无法验证凭据", - headers={"WWW-Authenticate": "Bearer"}, - ) - - payload = verify_token(token) - if payload is None: - raise credentials_exception - - user_id: str | None = payload.get("sub") - if user_id is None: - raise credentials_exception - - if payload.get("type") != "access": - raise credentials_exception - - user = await db.get(User, user_id) - if user is None: - raise credentials_exception - - return user +# ── Eval judge ─────────────────────────────────────────────── def build_eval_judge_llm_spec( @@ -216,14 +184,3 @@ def get_eval_judge_langchain_llm(): return spec.llm if spec else None -async def get_optional_user( - token: Optional[str] = Depends(oauth2_scheme), - db: AsyncSession = Depends(get_async_db), -): - """Return user if a valid token is provided, else None.""" - if token is None: - return None - try: - return await get_current_user(token, db) - except HTTPException: - return None diff --git a/api/app/core/deps_types.py b/api/app/core/deps_types.py new file mode 100644 index 0000000..6b69184 --- /dev/null +++ b/api/app/core/deps_types.py @@ -0,0 +1,21 @@ +"""Shared FastAPI dependency type aliases.""" + +from typing import Annotated, Optional + +from fastapi import Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.auth_deps import ( + get_current_user, + get_optional_user, + oauth2_scheme, + oauth2_scheme_optional, +) +from app.core.db import get_async_db +from app.features.user.models import User + +DbDep = Annotated[AsyncSession, Depends(get_async_db)] +CurrentUserDep = Annotated[User, Depends(get_current_user)] +OptionalUserDep = Annotated[User | None, Depends(get_optional_user)] +OAuth2TokenDep = Annotated[str, Depends(oauth2_scheme)] +OptionalOAuth2TokenDep = Annotated[Optional[str], Depends(oauth2_scheme_optional)] diff --git a/api/app/core/error_codes.py b/api/app/core/error_codes.py new file mode 100644 index 0000000..782d827 --- /dev/null +++ b/api/app/core/error_codes.py @@ -0,0 +1,152 @@ +""" +统一错误码注册表(供 OpenAPI 文档与客户端分支参考)。 + +HTTP 错误响应体:``{ "error_code": str, "message": str, "request_id": str }`` +""" + +from __future__ import annotations + +from typing import TypedDict + + +class ErrorCodeEntry(TypedDict): + code: str + http_status: int + domain: str + description: str + + +# ── 全局 / core AppError ───────────────────────────────────── + +CORE_ERROR_CODES: list[ErrorCodeEntry] = [ + { + "code": "BAD_REQUEST", + "http_status": 400, + "domain": "core", + "description": "请求无效(通用)", + }, + { + "code": "AUTHENTICATION_FAILED", + "http_status": 401, + "domain": "core", + "description": "未认证或凭据无效", + }, + { + "code": "FORBIDDEN", + "http_status": 403, + "domain": "core", + "description": "权限不足", + }, + { + "code": "NOT_FOUND", + "http_status": 404, + "domain": "core", + "description": "资源不存在", + }, + { + "code": "CONFLICT", + "http_status": 409, + "domain": "core", + "description": "资源冲突", + }, + { + "code": "VALIDATION_ERROR", + "http_status": 422, + "domain": "core", + "description": "请求体验证失败", + }, + { + "code": "QUOTA_EXCEEDED", + "http_status": 429, + "domain": "core", + "description": "配额已用尽", + }, + { + "code": "RATE_LIMITED", + "http_status": 429, + "domain": "core", + "description": "请求频率超限(如 SMS 发送冷却)", + }, + { + "code": "INTERNAL_ERROR", + "http_status": 500, + "domain": "core", + "description": "服务器内部错误", + }, + { + "code": "PROVIDER_ERROR", + "http_status": 502, + "domain": "core", + "description": "外部服务异常(如短信发送失败)", + }, + { + "code": "SERVICE_UNAVAILABLE", + "http_status": 503, + "domain": "core", + "description": "服务未配置或暂时不可用", + }, + { + "code": "GATEWAY_TIMEOUT", + "http_status": 504, + "domain": "core", + "description": "网关超时", + }, +] + +# ── auth 领域(AuthError)──────────────────────────────────── + +AUTH_ERROR_CODES: list[ErrorCodeEntry] = [ + { + "code": "PHONE_EXISTS", + "http_status": 400, + "domain": "auth", + "description": "手机号已注册", + }, + { + "code": "EMAIL_EXISTS", + "http_status": 400, + "domain": "auth", + "description": "邮箱已注册", + }, + { + "code": "INVALID_SMS_CODE", + "http_status": 400, + "domain": "auth", + "description": "短信验证码无效、过期或已使用", + }, + { + "code": "WRONG_PASSWORD", + "http_status": 400, + "domain": "auth", + "description": "旧密码错误", + }, + { + "code": "PHONE_TAKEN", + "http_status": 409, + "domain": "auth", + "description": "手机号已被其他账号占用", + }, + { + "code": "REFRESH_TOKEN_REUSE", + "http_status": 401, + "domain": "auth", + "description": "刷新令牌在 grace 窗口外被重复使用或疑似盗用,全部会话已吊销", + }, +] + +# ── payment 领域(PaymentError / order_service)────────────── + +PAYMENT_ERROR_CODES: list[ErrorCodeEntry] = [ + { + "code": "PAYMENT_FAILED", + "http_status": 500, + "domain": "payment", + "description": "创建或处理支付订单失败", + }, +] + +ALL_ERROR_CODES: list[ErrorCodeEntry] = ( + CORE_ERROR_CODES + AUTH_ERROR_CODES + PAYMENT_ERROR_CODES +) + +ERROR_CODE_ENUM: list[str] = sorted({e["code"] for e in ALL_ERROR_CODES}) diff --git a/api/app/core/errors.py b/api/app/core/errors.py index dd7f648..d5a95ab 100644 --- a/api/app/core/errors.py +++ b/api/app/core/errors.py @@ -5,7 +5,8 @@ 成功响应直接返回 Pydantic model / FileResponse / 原始结构,不强制包装。 """ -from fastapi import FastAPI, Request +from fastapi import FastAPI, HTTPException, Request +from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from app.core.logging import get_logger @@ -34,6 +35,11 @@ class NotFoundError(AppError): super().__init__(message, status_code=404, error_code="NOT_FOUND") +class BadRequestError(AppError): + def __init__(self, message: str = "请求无效"): + super().__init__(message, status_code=400, error_code="BAD_REQUEST") + + class AuthenticationError(AppError): def __init__(self, message: str = "认证失败"): super().__init__(message, status_code=401, error_code="AUTHENTICATION_FAILED") @@ -49,6 +55,21 @@ class ValidationError(AppError): super().__init__(message, status_code=422, error_code="VALIDATION_ERROR") +class ConflictError(AppError): + def __init__(self, message: str = "资源冲突"): + super().__init__(message, status_code=409, error_code="CONFLICT") + + +class ServiceUnavailableError(AppError): + def __init__(self, message: str = "服务暂时不可用"): + super().__init__(message, status_code=503, error_code="SERVICE_UNAVAILABLE") + + +class GatewayTimeoutError(AppError): + def __init__(self, message: str = "网关超时"): + super().__init__(message, status_code=504, error_code="GATEWAY_TIMEOUT") + + class ProviderError(AppError): def __init__(self, message: str = "外部服务异常", *, provider: str = ""): super().__init__(message, status_code=502, error_code="PROVIDER_ERROR") @@ -60,11 +81,91 @@ class QuotaExceededError(AppError): super().__init__(message, status_code=429, error_code="QUOTA_EXCEEDED") -# ── Exception handler registration ────────────────────────── +class RateLimitedError(AppError): + def __init__(self, message: str = "请求过于频繁"): + super().__init__(message, status_code=429, error_code="RATE_LIMITED") + + +# ── Error response helpers ─────────────────────────────────── def _get_request_id(request: Request) -> str: - return getattr(request.state, "request_id", "-") + rid = getattr(request.state, "request_id", None) + if rid: + return str(rid) + scope_state = request.scope.get("state") + if isinstance(scope_state, dict): + return str(scope_state.get("request_id", "-")) + if scope_state is not None: + return str(getattr(scope_state, "request_id", "-")) + return "-" + + +def _error_response( + *, + status_code: int, + error_code: str, + message: str, + request_id: str, + headers: dict[str, str] | None = None, +) -> JSONResponse: + return JSONResponse( + status_code=status_code, + content={ + "error_code": error_code, + "message": message, + "request_id": request_id, + }, + headers=headers, + ) + + +def _http_detail_to_message(detail: object) -> str: + if isinstance(detail, str): + return detail + if isinstance(detail, list): + parts: list[str] = [] + for item in detail: + if isinstance(item, dict): + loc = item.get("loc") + msg = item.get("msg", "") + if loc: + parts.append(f"{'.'.join(str(x) for x in loc)}: {msg}") + else: + parts.append(str(msg)) + else: + parts.append(str(item)) + return "; ".join(parts) if parts else "请求无效" + return str(detail) + + +_STATUS_TO_ERROR_CODE: dict[int, str] = { + 400: "BAD_REQUEST", + 401: "AUTHENTICATION_FAILED", + 403: "FORBIDDEN", + 404: "NOT_FOUND", + 409: "CONFLICT", + 422: "VALIDATION_ERROR", + 502: "PROVIDER_ERROR", + 503: "SERVICE_UNAVAILABLE", + 504: "GATEWAY_TIMEOUT", +} + + +def _error_code_for_status(status_code: int) -> str: + if status_code == 429: + # HTTP 429 is shared by QuotaExceededError and RateLimitedError; legacy + # HTTPException has no explicit code — default to rate limiting. + return "RATE_LIMITED" + mapped = _STATUS_TO_ERROR_CODE.get(status_code) + if mapped is not None: + return mapped + if status_code >= 500: + return "INTERNAL_ERROR" + return "BAD_REQUEST" + + +# ── Exception handler registration ────────────────────────── def register_exception_handlers(app: FastAPI) -> None: @@ -73,30 +174,71 @@ def register_exception_handlers(app: FastAPI) -> None: @app.exception_handler(AppError) async def app_error_handler(request: Request, exc: AppError): request_id = _get_request_id(request) + headers: dict[str, str] | None = None + if exc.status_code == 401: + headers = {"WWW-Authenticate": "Bearer"} logger.warning( "AppError: error_code={} message={} request_id={}", exc.error_code, exc.message, request_id, ) - return JSONResponse( + return _error_response( status_code=exc.status_code, - content={ - "error_code": exc.error_code, - "message": exc.message, - "request_id": request_id, - }, + error_code=exc.error_code, + message=exc.message, + request_id=request_id, + headers=headers, + ) + + @app.exception_handler(HTTPException) + async def http_exception_handler(request: Request, exc: HTTPException): + request_id = _get_request_id(request) + message = _http_detail_to_message(exc.detail) + error_code = _error_code_for_status(exc.status_code) + logger.warning( + "HTTPException: status={} error_code={} message={} request_id={}", + exc.status_code, + error_code, + message, + request_id, + ) + headers = dict(exc.headers) if exc.headers else None + if exc.status_code == 401 and headers is not None: + headers.setdefault("WWW-Authenticate", "Bearer") + elif exc.status_code == 401: + headers = {"WWW-Authenticate": "Bearer"} + return _error_response( + status_code=exc.status_code, + error_code=error_code, + message=message, + request_id=request_id, + headers=headers, + ) + + @app.exception_handler(RequestValidationError) + async def validation_error_handler(request: Request, exc: RequestValidationError): + request_id = _get_request_id(request) + message = _http_detail_to_message(exc.errors()) + logger.warning( + "RequestValidationError: message={} request_id={}", + message, + request_id, + ) + return _error_response( + status_code=422, + error_code="VALIDATION_ERROR", + message=message, + request_id=request_id, ) @app.exception_handler(Exception) async def unhandled_error_handler(request: Request, exc: Exception): request_id = _get_request_id(request) logger.exception("Unhandled exception: request_id={}", request_id) - return JSONResponse( + return _error_response( status_code=500, - content={ - "error_code": "INTERNAL_ERROR", - "message": "服务器内部错误", - "request_id": request_id, - }, + error_code="INTERNAL_ERROR", + message="服务器内部错误", + request_id=request_id, ) diff --git a/api/app/core/llm_gateway.py b/api/app/core/llm_gateway.py index 29eadf1..e75c900 100644 --- a/api/app/core/llm_gateway.py +++ b/api/app/core/llm_gateway.py @@ -14,8 +14,6 @@ from pydantic import BaseModel from app.core.dependencies import get_llm_provider, get_llm_provider_fast from app.core.llm_call import allm_json_call, llm_json_call -from app.core.llm_telemetry import langchain_invoke_span - T = TypeVar("T", bound=BaseModel) @@ -71,20 +69,7 @@ class LlmGateway: else (use_case.max_tokens if use_case else None) ), ) - # DeepSeekProvider.complete 已包 langchain_invoke_span,避免双层 span - from app.adapters.llm.deepseek import DeepSeekLLMProvider - - if isinstance(provider, DeepSeekLLMProvider): - return await provider.complete(**kwargs) - - provider_label = type(provider).__name__.replace("Provider", "").lower() or "unknown" - with langchain_invoke_span( - agent=agent_name, - provider=provider_label, - model=resolved_model or "unknown", - call_type="chat", - ): - return await provider.complete(**kwargs) + return await provider.complete(**kwargs) async def json_object( self, diff --git a/api/app/core/logging.py b/api/app/core/logging.py index b065ce9..418a9cd 100644 --- a/api/app/core/logging.py +++ b/api/app/core/logging.py @@ -43,6 +43,7 @@ from app.core.log_events import ( correlation_bind_kwargs, format_log_event, ) +from app.core.runtime_constants import agent_log_defaults if TYPE_CHECKING: from loguru import Logger @@ -185,7 +186,7 @@ def _apply_third_party_log_levels() -> None: default_celery = logging.INFO if verbose else logging.WARNING default_httpx = logging.WARNING - raw_c = (settings.celery_log_level or "").strip() + raw_c = (agent_log_defaults.celery_log_level or "").strip() if raw_c: parsed = _parse_stdlib_level(raw_c) cel_level = parsed if parsed is not None else default_celery @@ -195,7 +196,7 @@ def _apply_third_party_log_levels() -> None: for name in ("celery", "celery.worker"): logging.getLogger(name).setLevel(cel_level) - raw_h = (settings.httpx_log_level or "").strip() + raw_h = (agent_log_defaults.httpx_log_level or "").strip() if raw_h: parsed = _parse_stdlib_level(raw_h) httpx_level = parsed if parsed is not None else default_httpx @@ -252,7 +253,7 @@ def setup_logging() -> None: diagnose=False, ) - json_path = (settings.log_json_file or "").strip() + json_path = (agent_log_defaults.log_json_file or "").strip() if json_path: logger.add( json_path, diff --git a/api/app/core/memoir_pipeline_progress.py b/api/app/core/memoir_pipeline_progress.py index 01986c2..20c0085 100644 --- a/api/app/core/memoir_pipeline_progress.py +++ b/api/app/core/memoir_pipeline_progress.py @@ -6,28 +6,20 @@ from __future__ import annotations import json -import threading from datetime import datetime, timezone from typing import Any import redis -from app.core.config import settings from app.core.logging import get_logger +from app.core.redis_sync import get_sync_redis +from app.features.memoir.constants import memoir logger = get_logger(__name__) -_lock = threading.Lock() -_client: redis.Redis | None = None - def _redis() -> redis.Redis: - global _client - if _client is None: - with _lock: - if _client is None: - _client = redis.from_url(settings.redis_url, decode_responses=True) - return _client + return get_sync_redis(decode_responses=True) def _run_key(correlation_id: str) -> str: @@ -39,7 +31,7 @@ def _phase1_index_key(phase1_task_id: str) -> str: def _ttl() -> int: - return int(settings.memoir_pipeline_run_ttl_seconds) + return int(memoir.pipeline_run_ttl_seconds) def _empty_fanout() -> dict[str, Any]: diff --git a/api/app/core/memory_compaction_schedule.py b/api/app/core/memory_compaction_schedule.py index 3766e1c..72c91ef 100644 --- a/api/app/core/memory_compaction_schedule.py +++ b/api/app/core/memory_compaction_schedule.py @@ -9,15 +9,15 @@ from __future__ import annotations import json import math -import threading import time from datetime import datetime, timezone from typing import Any import redis -from app.core.config import settings from app.core.logging import get_logger +from app.core.redis_sync import get_sync_redis +from app.features.memory.constants import memory logger = get_logger(__name__) @@ -28,25 +28,14 @@ _CURSOR_KEY = "memory_compaction:chunk_cursor:{user_id}" # 与 memory_chunks.id 字典序比较用(首跑起点) _CHUNK_CURSOR_ID_ZERO = "00000000-0000-0000-0000-000000000000" -_redis_client: redis.Redis | None = None -_redis_lock = threading.Lock() - def _get_redis() -> redis.Redis: - """进程内复用单个 Redis 客户端(内置连接池),避免每次调用新建连接。""" - global _redis_client - if _redis_client is None: - with _redis_lock: - if _redis_client is None: - _redis_client = redis.from_url( - settings.redis_url, decode_responses=True - ) - return _redis_client + return get_sync_redis(decode_responses=True) def _debounce_key_ttl_seconds() -> int: """debounce 键与 scheduler_gate 共用 TTL,避免 gate 先过期导致重复 apply_async。""" - return int(settings.memory_compaction_debounce_seconds) + 900 + return int(memory.compaction_debounce_seconds) + 900 def debounce_key(user_id: str) -> str: @@ -201,12 +190,12 @@ def schedule_memory_compaction_run( user_id: str, context: dict[str, Any] | None ) -> None: """在 memoir / 章节重组等成功后调用:推后 debounce 截止时间并尽量只派发一次延迟任务。""" - if not settings.memory_compaction_enabled: + if not memory.compaction_enabled: return r = _get_redis() now = time.time() - quiet = float(settings.memory_compaction_debounce_seconds) + quiet = float(memory.compaction_debounce_seconds) new_deadline = now + quiet raw = r.get(debounce_key(user_id)) diff --git a/api/app/core/middleware.py b/api/app/core/middleware.py index c7a708d..2f39689 100644 --- a/api/app/core/middleware.py +++ b/api/app/core/middleware.py @@ -4,23 +4,52 @@ HTTP 中间件:request_id 注入。 import uuid -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.datastructures import State from starlette.requests import Request +from starlette.types import ASGIApp, Message, Receive, Scope, Send from app.core.logging import logger from app.core.telemetry import current_trace_context -class RequestIdMiddleware(BaseHTTPMiddleware): - """Inject request_id into request.state and response headers, bind to loguru context.""" +class RequestIdMiddleware: + """Inject request_id into request.state and response headers, bind to loguru context. - async def dispatch(self, request: Request, call_next): - request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4()) - request.state.request_id = request_id + Pure ASGI middleware (not BaseHTTPMiddleware) so FastAPI exception handlers still run. + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request_id = None + for name, value in scope.get("headers", []): + if name == b"x-request-id": + request_id = value.decode("latin-1") + break + if not request_id: + request_id = str(uuid.uuid4()) + existing_state = scope.get("state") + if isinstance(existing_state, State): + existing_state.request_id = request_id + elif isinstance(existing_state, dict): + existing_state["request_id"] = request_id + scope["state"] = State(existing_state) + else: + scope["state"] = State({"request_id": request_id}) bind = {"request_id": request_id, **current_trace_context()} - with logger.contextualize(**bind): - response = await call_next(request) - response.headers["X-Request-ID"] = request_id - return response + async def send_with_request_id(message: Message) -> None: + if message["type"] == "http.response.start": + headers = list(message.get("headers", [])) + headers.append((b"x-request-id", request_id.encode("latin-1"))) + message = {**message, "headers": headers} + await send(message) + + with logger.contextualize(**bind): + await self.app(scope, receive, send_with_request_id) diff --git a/api/app/core/openapi.py b/api/app/core/openapi.py index c4ed139..2081c2d 100644 --- a/api/app/core/openapi.py +++ b/api/app/core/openapi.py @@ -1,25 +1,131 @@ """ -OpenAPI 全局增强:只做 title/version/description 等元数据,不替代 router 上的正常声明。 +OpenAPI 全局增强:元数据 + 统一 ErrorResponse 组件 + 领域错误码表。 """ from fastapi import FastAPI from fastapi.openapi.utils import get_openapi +from app.core.error_codes import ALL_ERROR_CODES, ERROR_CODE_ENUM + +ERROR_RESPONSE_REF = "#/components/schemas/ErrorResponse" +ERROR_CODE_REF = "#/components/schemas/ErrorCode" + +COMMON_ERROR_RESPONSES: dict[int, dict] = { + 400: {"description": "请求参数错误", "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}}, + 401: {"description": "认证失败", "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}}, + 403: {"description": "权限不足", "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}}, + 404: {"description": "资源不存在", "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}}, + 409: {"description": "资源冲突", "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}}, + 422: {"description": "请求体验证失败", "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}}, + 429: { + "description": "配额已用尽(QUOTA_EXCEEDED)或请求频率超限(RATE_LIMITED)", + "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}, + }, + 500: {"description": "内部服务器错误", "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}}, + 502: {"description": "外部服务异常", "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}}, + 503: {"description": "服务不可用", "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}}, + 504: {"description": "网关超时", "content": {"application/json": {"schema": {"$ref": ERROR_RESPONSE_REF}}}}, +} + + +def error_responses( + *status_codes: int, + descriptions: dict[int, str] | None = None, +) -> dict[int, dict]: + """Pick reusable OpenAPI error response entries by HTTP status code.""" + out: dict[int, dict] = {} + for code in status_codes: + entry = dict(COMMON_ERROR_RESPONSES[code]) + if descriptions and code in descriptions: + entry["description"] = descriptions[code] + out[code] = entry + return out + + +def _error_code_catalog_markdown() -> str: + lines = [ + "", + "### 错误响应格式", + "", + "所有 HTTP 错误返回 `application/json`:", + "", + "```json", + '{ "error_code": "NOT_FOUND", "message": "资源不存在", "request_id": "req_xxx" }', + "```", + "", + "组件 `ErrorCode` / `DomainErrorCode` 列出机器可读码;`message` 为面向用户的说明。", + "", + "| error_code | HTTP | 域 | 说明 |", + "|------------|------|-----|------|", + ] + for entry in ALL_ERROR_CODES: + status = entry["http_status"] + status_cell = str(status) if status else "—" + lines.append( + f"| `{entry['code']}` | {status_cell} | {entry['domain']} | {entry['description']} |" + ) + return "\n".join(lines) + def custom_openapi(app: FastAPI) -> dict: if app.openapi_schema: return app.openapi_schema + base_description = ( + "为老年用户提供 AI 驱动的回忆录创作服务:\n" + "语音对话采集 → 素材沉淀 → 章节生成 → PDF 导出。" + ) + openapi_schema = get_openapi( title="Life Echo API", version="1.0.0", - summary="岁月时书 — 口述回忆录生产平台", - description=( - "为老年用户提供 AI 驱动的回忆录创作服务:\n" - "语音对话采集 → 素材沉淀 → 章节生成 → PDF 导出。" - ), + summary="岁月留书 — 口述回忆录生产平台", + description=base_description + _error_code_catalog_markdown(), routes=app.routes, ) + components = openapi_schema.setdefault("components", {}) + schemas = components.setdefault("schemas", {}) + + schemas["ErrorCode"] = { + "title": "ErrorCode", + "type": "string", + "enum": ERROR_CODE_ENUM, + "description": "机器可读错误码(全局 + 领域);完整说明见 API 描述中的错误码表。", + } + + domain_codes = sorted( + {e["code"] for e in ALL_ERROR_CODES if e["domain"] != "core"} + ) + schemas["DomainErrorCode"] = { + "title": "DomainErrorCode", + "type": "string", + "enum": domain_codes, + "description": "业务领域错误码(auth / payment 等),与通用 ErrorCode 并存。", + } + + schemas["ErrorResponse"] = { + "title": "ErrorResponse", + "type": "object", + "required": ["error_code", "message", "request_id"], + "properties": { + "error_code": { + "allOf": [{"$ref": ERROR_CODE_REF}], + "description": "机器可读错误码", + "example": "NOT_FOUND", + }, + "message": { + "type": "string", + "description": "面向用户的错误说明", + "example": "资源不存在", + }, + "request_id": { + "type": "string", + "description": "请求追踪 ID", + "example": "req_abc123", + }, + }, + } + app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/api/app/core/redis.py b/api/app/core/redis.py index e46a4da..380a5e1 100644 --- a/api/app/core/redis.py +++ b/api/app/core/redis.py @@ -1,5 +1,4 @@ -""" -Redis 客户端与会话/缓存能力:供应用生命周期、会话历史、任务追踪等使用。 +"""Redis 客户端与会话/缓存能力:供应用生命周期、会话历史、任务追踪等使用。 配置从 app.core.config.settings 读取,禁止业务层散落 os.getenv。 """ @@ -11,6 +10,7 @@ import redis.asyncio as aioredis from app.core.config import settings from app.core.logging import get_logger +from app.core.runtime_constants import misc_defaults, redis_defaults logger = get_logger(__name__) @@ -19,9 +19,9 @@ class RedisService: """Redis 服务:连接管理、对话历史、通用缓存。""" def __init__(self) -> None: - self.redis_url = settings.redis_url + self.redis_url = settings.redis_url_resolved self._client: Optional[aioredis.Redis] = None - self.session_ttl = settings.redis_session_ttl + self.session_ttl = misc_defaults.redis_session_ttl async def get_client(self) -> aioredis.Redis: """获取 Redis 客户端(延迟初始化)。""" @@ -31,6 +31,10 @@ class RedisService: self.redis_url, encoding="utf-8", decode_responses=True, + socket_timeout=redis_defaults.socket_timeout_seconds, + socket_connect_timeout=redis_defaults.socket_connect_timeout_seconds, + health_check_interval=redis_defaults.health_check_interval_seconds, + retry_on_timeout=True, ) await self._client.ping() logger.info("Redis 连接成功") @@ -59,15 +63,64 @@ class RedisService: def _conversation_key(self, conversation_id: str) -> str: return f"conversation:history:{conversation_id}" + async def _key_type(self, client: aioredis.Redis, key: str) -> str: + key_type = await client.type(key) + if isinstance(key_type, bytes): + return key_type.decode("utf-8") + return str(key_type) + + async def _parse_history_items(self, raw_items: List[str]) -> List[Dict[str, Any]]: + history: List[Dict[str, Any]] = [] + for raw in raw_items: + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + logger.warning("跳过无效对话历史条目") + continue + if isinstance(parsed, dict): + history.append(parsed) + return history + + async def _migrate_string_history_to_list( + self, client: aioredis.Redis, key: str, history: List[Dict[str, Any]] + ) -> None: + if not history: + await client.delete(key) + return + pipe = client.pipeline(transaction=True) + pipe.delete(key) + for item in history: + pipe.rpush(key, json.dumps(item, ensure_ascii=False)) + pipe.expire(key, self.session_ttl) + await pipe.execute() + async def get_conversation_history( self, conversation_id: str ) -> List[Dict[str, Any]]: try: client = await self.get_client() key = self._conversation_key(conversation_id) - data = await client.get(key) - if data: - return json.loads(data) + if not await client.exists(key): + return [] + key_type = await self._key_type(client, key) + if key_type == "list": + raw_items = await client.lrange(key, 0, -1) + return await self._parse_history_items(list(raw_items)) + if key_type == "string": + data = await client.get(key) + if not data: + return [] + legacy = json.loads(data) + if not isinstance(legacy, list): + return [] + history = [x for x in legacy if isinstance(x, dict)] + await self._migrate_string_history_to_list(client, key, history) + return history + logger.warning( + "conversation history unexpected type={} key={}", + key_type, + key, + ) return [] except Exception as e: logger.error("获取对话历史失败: {}", e) @@ -80,9 +133,12 @@ class RedisService: try: client = await self.get_client() key = self._conversation_key(conversation_id) - await client.setex( - key, self.session_ttl, json.dumps(history, ensure_ascii=False) - ) + pipe = client.pipeline(transaction=True) + pipe.delete(key) + for item in history: + pipe.rpush(key, json.dumps(item, ensure_ascii=False)) + pipe.expire(key, self.session_ttl) + await pipe.execute() return True except Exception as e: logger.error("写入对话历史失败: {}", e) @@ -101,7 +157,6 @@ class RedisService: try: client = await self.get_client() key = self._conversation_key(conversation_id) - history = await self.get_conversation_history(conversation_id) item = { "role": role, "content": content, @@ -116,10 +171,10 @@ class RedisService: and message_type == "audio" ): item["durationSeconds"] = int(audio_duration_seconds) - history.append(item) - await client.setex( - key, self.session_ttl, json.dumps(history, ensure_ascii=False) - ) + pipe = client.pipeline(transaction=True) + pipe.rpush(key, json.dumps(item, ensure_ascii=False)) + pipe.expire(key, self.session_ttl) + await pipe.execute() return True except Exception as e: logger.error("添加消息失败: {}", e) @@ -135,6 +190,7 @@ class RedisService: client = await self.get_client() key = self._conversation_key(conversation_id) history = await self.get_conversation_history(conversation_id) + target_index: int | None = None for i in range(len(history) - 1, -1, -1): if history[i].get("role") == "ai": existing = history[i].get("ttsAudioUrls") @@ -145,16 +201,20 @@ class RedisService: ) urls.append(url) history[i]["ttsAudioUrls"] = urls + target_index = i break - else: + if target_index is None: logger.warning( "append_tts_audio_url: no ai message in history conversation_id={}", conversation_id, ) return False - await client.setex( - key, self.session_ttl, json.dumps(history, ensure_ascii=False) + await client.lset( + key, + target_index, + json.dumps(history[target_index], ensure_ascii=False), ) + await client.expire(key, self.session_ttl) return True except Exception as e: logger.error("append_tts_audio_url 失败: {}", e) @@ -200,6 +260,9 @@ class RedisService: return False async def set_cache(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: + if ttl is None or ttl <= 0: + logger.error("设置缓存失败: TTL 必须为正整数 key={}", key) + return False try: client = await self.get_client() data = ( @@ -207,10 +270,7 @@ class RedisService: if not isinstance(value, str) else value ) - if ttl: - await client.setex(key, ttl, data) - else: - await client.set(key, data) + await client.setex(key, ttl, data) return True except Exception as e: logger.error("设置缓存失败: {}", e) diff --git a/api/app/core/redis_lock.py b/api/app/core/redis_lock.py index 56d4037..42d8091 100644 --- a/api/app/core/redis_lock.py +++ b/api/app/core/redis_lock.py @@ -1,27 +1,9 @@ """Small Redis lock helpers for background tasks.""" -import threading import uuid from dataclasses import dataclass -import redis - -from app.core.config import settings - -_redis_lock_client: redis.Redis | None = None -_redis_lock_init_lock = threading.Lock() - - -def _get_redis_lock_client() -> redis.Redis: - """进程内复用单个 Redis 客户端(decode_responses=False,与锁 token 字节一致)。""" - global _redis_lock_client - if _redis_lock_client is None: - with _redis_lock_init_lock: - if _redis_lock_client is None: - _redis_lock_client = redis.from_url( - settings.redis_url, decode_responses=False - ) - return _redis_lock_client +from app.core.redis_sync import get_sync_redis @dataclass(frozen=True) @@ -32,7 +14,7 @@ class RedisLockHandle: def acquire_redis_lock(key: str, *, ttl_seconds: int) -> RedisLockHandle | None: """Acquire a single-owner Redis lock or return None when unavailable.""" - client = _get_redis_lock_client() + client = get_sync_redis(decode_responses=False) token = uuid.uuid4().hex.encode("utf-8") if not client.set(key, token, nx=True, ex=ttl_seconds): return None @@ -43,7 +25,7 @@ def release_redis_lock(handle: RedisLockHandle | None) -> None: """Release the lock only if we still own it.""" if handle is None: return - _get_redis_lock_client().eval( + get_sync_redis(decode_responses=False).eval( """ if redis.call("GET", KEYS[1]) == ARGV[1] then return redis.call("DEL", KEYS[1]) diff --git a/api/app/core/redis_sync.py b/api/app/core/redis_sync.py new file mode 100644 index 0000000..3900a69 --- /dev/null +++ b/api/app/core/redis_sync.py @@ -0,0 +1,44 @@ +"""Process-wide synchronous Redis client factory with connection pooling.""" + +from __future__ import annotations + +import threading + +import redis + +from app.core.config import settings +from app.core.runtime_constants import redis_defaults + +_clients: dict[bool, redis.Redis] = {} +_init_lock = threading.Lock() + + +def get_sync_redis(*, decode_responses: bool = True) -> redis.Redis: + """Return a shared sync Redis client (one pool per decode_responses mode).""" + client = _clients.get(decode_responses) + if client is not None: + return client + with _init_lock: + client = _clients.get(decode_responses) + if client is None: + client = redis.from_url( + settings.redis_url_resolved, + decode_responses=decode_responses, + socket_timeout=redis_defaults.socket_timeout_seconds, + socket_connect_timeout=redis_defaults.socket_connect_timeout_seconds, + health_check_interval=redis_defaults.health_check_interval_seconds, + retry_on_timeout=True, + ) + _clients[decode_responses] = client + return client + + +def reset_sync_redis_clients_for_tests() -> None: + """Close and clear cached clients (tests only).""" + with _init_lock: + for client in _clients.values(): + try: + client.close() + except Exception: + pass + _clients.clear() diff --git a/api/app/core/redis_urls.py b/api/app/core/redis_urls.py new file mode 100644 index 0000000..bf94076 --- /dev/null +++ b/api/app/core/redis_urls.py @@ -0,0 +1,96 @@ +"""Redis URL resolution: password injection and Celery DB separation.""" + +from __future__ import annotations + +from urllib.parse import quote, urlparse, urlunparse + + +def _url_has_password(parsed) -> bool: + return bool(parsed.password or parsed.username) + + +def inject_redis_password(redis_url: str, password: str | None) -> str: + """Inject REDIS_PASSWORD into URL when the URL has no credentials.""" + if not password: + return redis_url + parsed = urlparse(redis_url) + if _url_has_password(parsed): + return redis_url + host = parsed.hostname or "localhost" + port = parsed.port + netloc = f":{quote(password, safe='')}@{host}" + if port is not None: + netloc = f":{quote(password, safe='')}@{host}:{port}" + return urlunparse( + ( + parsed.scheme or "redis", + netloc, + parsed.path or "", + parsed.params, + parsed.query, + parsed.fragment, + ) + ) + + +def _parse_db_index(path: str) -> int: + segment = (path or "").strip("/") + if not segment: + return 0 + try: + return int(segment) + except ValueError: + return 0 + + +def derive_celery_redis_url( + redis_url: str, + *, + celery_redis_url_override: str | None = None, +) -> str: + """Resolve Celery broker/backend URL (override or same host with DB+1). + + Business keys use ``REDIS_URL`` (typically DB/0); Celery broker/backend use + the next logical DB. After upgrading to DB separation, unconsumed Celery + keys on the business DB are abandoned (one-time cutover). + + When ``REDIS_URL`` uses DB/15, set ``CELERY_REDIS_URL`` explicitly — Redis + only supports logical DBs 0–15. + """ + if celery_redis_url_override: + return celery_redis_url_override + parsed = urlparse(redis_url) + db = _parse_db_index(parsed.path) + if db >= 15: + raise ValueError( + "REDIS_URL uses DB/15; Celery cannot auto-derive DB+1. " + "Set CELERY_REDIS_URL explicitly." + ) + new_path = f"/{db + 1}" + return urlunparse( + ( + parsed.scheme or "redis", + parsed.netloc, + new_path, + parsed.params, + parsed.query, + parsed.fragment, + ) + ) + + +def resolve_redis_urls( + redis_url: str, + *, + redis_password: str | None = None, + celery_redis_url_override: str | None = None, +) -> tuple[str, str]: + """Return (business_redis_url, celery_redis_url).""" + business = inject_redis_password(redis_url, redis_password) + celery = derive_celery_redis_url( + business, + celery_redis_url_override=celery_redis_url_override, + ) + if celery_redis_url_override is None: + celery = inject_redis_password(celery, redis_password) + return business, celery diff --git a/api/app/core/runtime_constants.py b/api/app/core/runtime_constants.py new file mode 100644 index 0000000..439db97 --- /dev/null +++ b/api/app/core/runtime_constants.py @@ -0,0 +1,13 @@ +"""运行时默认值 — 值来自 config/*.toml(SSOT)。""" + +from app.core.app_config import app_config + +llm_defaults = app_config.llm +asr_defaults = app_config.asr +tts_defaults = app_config.tts +celery_defaults = app_config.celery +redis_defaults = app_config.redis +alembic_defaults = app_config.alembic +agent_log_defaults = app_config.agent_log +otel_defaults = app_config.otel +misc_defaults = app_config.misc diff --git a/api/app/core/security.py b/api/app/core/security.py index 1a8e37e..84d3b00 100644 --- a/api/app/core/security.py +++ b/api/app/core/security.py @@ -11,8 +11,9 @@ import bcrypt import jwt from app.core.config import settings +from app.core.runtime_constants import misc_defaults -ALGORITHM = settings.algorithm +ALGORITHM = misc_defaults.algorithm SECRET_KEY = settings.secret_key ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes REFRESH_TOKEN_EXPIRE_DAYS = settings.refresh_token_expire_days diff --git a/api/app/core/task_tracker.py b/api/app/core/task_tracker.py index 710bf8e..ebf5119 100644 --- a/api/app/core/task_tracker.py +++ b/api/app/core/task_tracker.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List from app.core.logging import get_logger from app.core.redis import redis_service +from app.core.runtime_constants import redis_defaults logger = get_logger(__name__) @@ -16,7 +17,14 @@ class TaskTracker: """任务追踪器,使用 Redis 存储任务状态""" KEY_PREFIX = "task:user:" - TASK_TTL = 3600 + + @property + def task_ttl(self) -> int: + return int(redis_defaults.task_tracker_ttl_seconds) + + async def _refresh_key_ttl(self, key: str) -> None: + client = await redis_service.get_client() + await client.expire(key, self.task_ttl) async def add_task( self, user_id: str, task_id: str, task_type: str = "memoir" @@ -31,7 +39,7 @@ class TaskTracker: "created_at": datetime.now(timezone.utc).isoformat(), } await client.hset(key, task_id, json.dumps(task_info)) - await client.expire(key, self.TASK_TTL) + await client.expire(key, self.task_ttl) logger.debug("任务已记录: user_id={}, task_id={}", user_id, task_id) return True except Exception as e: @@ -54,6 +62,7 @@ class TaskTracker: if result is not None: task_info["result"] = result await client.hset(key, task_id, json.dumps(task_info)) + await self._refresh_key_ttl(key) return True except Exception as e: logger.error("更新任务状态失败: {}", e) @@ -108,6 +117,8 @@ class TaskTracker: client = await redis_service.get_client() key = f"{self.KEY_PREFIX}{user_id}:tasks" await client.hdel(key, task_id) + if await client.exists(key): + await self._refresh_key_ttl(key) return True except Exception as e: logger.error("移除任务失败: {}", e) diff --git a/api/app/core/telemetry.py b/api/app/core/telemetry.py index 21b45a4..053bfeb 100644 --- a/api/app/core/telemetry.py +++ b/api/app/core/telemetry.py @@ -31,12 +31,16 @@ from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio from app.core.config import settings +from app.core.runtime_constants import otel_defaults if TYPE_CHECKING: from fastapi import FastAPI _initialized = False _otel_logging_handler: LoggingHandler | None = None +_tracer_provider: TracerProvider | None = None +_meter_provider: MeterProvider | None = None +_log_provider: LoggerProvider | None = None def _build_resource(service_name: str) -> Resource: @@ -56,8 +60,8 @@ def _build_sampler(): TraceIdRatioBased, ) - name = (settings.otel_traces_sampler or "always_on").strip().lower() - arg = settings.otel_traces_sampler_arg + name = (otel_defaults.traces_sampler(settings.app_environment) or "always_on").strip().lower() + arg = otel_defaults.traces_sampler_arg(settings.app_environment) if name in ("always_on", "alwayson"): return ALWAYS_ON if name in ("always_off", "alwaysoff"): @@ -68,39 +72,58 @@ def _build_sampler(): return ParentBasedTraceIdRatio(ratio) +def _otlp_timeout_seconds() -> int | None: + env = (settings.app_environment or "").strip().lower() + if env == "development": + return 3 + return 10 + + def setup_telemetry(*, service_name: str) -> None: """配置 OTLP exporter 与自动 instrumentation(幂等)。""" global _initialized, _otel_logging_handler + global _tracer_provider, _meter_provider, _log_provider if _initialized or not settings.otel_enabled: return endpoint = settings.otel_exporter_otlp_endpoint.rstrip("/") - insecure = settings.otel_exporter_otlp_insecure + insecure = otel_defaults.exporter_insecure + timeout = _otlp_timeout_seconds() resource = _build_resource(service_name) - span_exporter = OTLPSpanExporter(endpoint=endpoint, insecure=insecure) - tracer_provider = TracerProvider(resource=resource, sampler=_build_sampler()) - tracer_provider.add_span_processor(BatchSpanProcessor(span_exporter)) - trace.set_tracer_provider(tracer_provider) + span_exporter = OTLPSpanExporter( + endpoint=endpoint, insecure=insecure, timeout=timeout + ) + _tracer_provider = TracerProvider(resource=resource, sampler=_build_sampler()) + _tracer_provider.add_span_processor( + BatchSpanProcessor(span_exporter, export_timeout_millis=(timeout or 10) * 1000) + ) + trace.set_tracer_provider(_tracer_provider) - metric_exporter = OTLPMetricExporter(endpoint=endpoint, insecure=insecure) + metric_exporter = OTLPMetricExporter( + endpoint=endpoint, insecure=insecure, timeout=timeout + ) metric_reader = PeriodicExportingMetricReader( metric_exporter, - export_interval_millis=settings.otel_metric_export_interval_ms, + export_interval_millis=otel_defaults.metric_export_interval_ms, ) - meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) - metrics.set_meter_provider(meter_provider) + _meter_provider = MeterProvider(resource=resource, metric_readers=[metric_reader]) + metrics.set_meter_provider(_meter_provider) - log_exporter = OTLPLogExporter(endpoint=endpoint, insecure=insecure) - log_provider = LoggerProvider(resource=resource) - log_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter)) - set_logger_provider(log_provider) + log_exporter = OTLPLogExporter(endpoint=endpoint, insecure=insecure, timeout=timeout) + _log_provider = LoggerProvider(resource=resource) + _log_provider.add_log_record_processor( + BatchLogRecordProcessor( + log_exporter, export_timeout_millis=(timeout or 10) * 1000 + ) + ) + set_logger_provider(_log_provider) LoggingInstrumentor().instrument(set_logging_format=True) _otel_logging_handler = LoggingHandler( level=logging.NOTSET, - logger_provider=log_provider, + logger_provider=_log_provider, ) logging.getLogger().addHandler(_otel_logging_handler) @@ -111,6 +134,56 @@ def setup_telemetry(*, service_name: str) -> None: _initialized = True +def shutdown_telemetry() -> None: + """停止 OTLP 导出线程并卸载 instrumentation(测试进程退出 / 热重载 / Ctrl+C 前调用)。""" + global _initialized, _otel_logging_handler + global _tracer_provider, _meter_provider, _log_provider + if not _initialized: + return + + for name in ( + "opentelemetry", + "opentelemetry.sdk", + "opentelemetry.exporter", + "opentelemetry.exporter.otlp", + ): + logging.getLogger(name).setLevel(logging.CRITICAL) + + if _otel_logging_handler is not None: + logging.getLogger().removeHandler(_otel_logging_handler) + _otel_logging_handler = None + + try: + FastAPIInstrumentor().uninstrument() + except Exception: + pass + + for instrumentor in ( + LoggingInstrumentor(), + HTTPXClientInstrumentor(), + RedisInstrumentor(), + SQLAlchemyInstrumentor(), + CeleryInstrumentor(), + ): + try: + instrumentor.uninstrument() + except Exception: + pass + + for provider in (_log_provider, _meter_provider, _tracer_provider): + if provider is None: + continue + try: + provider.shutdown() + except Exception: + pass + + _tracer_provider = None + _meter_provider = None + _log_provider = None + _initialized = False + + def instrument_fastapi_app(app: FastAPI) -> None: if not settings.otel_enabled: return diff --git a/api/app/features/auth/deps.py b/api/app/features/auth/deps.py index e68f8f9..3827350 100644 --- a/api/app/features/auth/deps.py +++ b/api/app/features/auth/deps.py @@ -1,16 +1,17 @@ """Auth feature dependencies: get_auth_service.""" from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import get_async_db -from app.core.dependencies import get_sms_sender +from app.core.dependencies import get_object_storage, get_sms_sender +from app.core.deps_types import DbDep from app.features.auth.service import AuthService from app.ports.sms import SmsSender +from app.ports.storage import ObjectStorage def get_auth_service( - db: AsyncSession = Depends(get_async_db), + db: DbDep, sms: SmsSender = Depends(get_sms_sender), + object_storage: ObjectStorage = Depends(get_object_storage), ) -> AuthService: - return AuthService(db=db, sms=sms) + return AuthService(db=db, sms=sms, object_storage=object_storage) diff --git a/api/app/features/auth/integrity.py b/api/app/features/auth/integrity.py new file mode 100644 index 0000000..ff5c686 --- /dev/null +++ b/api/app/features/auth/integrity.py @@ -0,0 +1,42 @@ +"""Map users-table unique constraint violations to auth error codes.""" + +from __future__ import annotations + +from sqlalchemy.exc import IntegrityError + +_PHONE_CONSTRAINT_MARKERS = frozenset({"phone", "ix_users_phone", "users_phone_key"}) +_EMAIL_CONSTRAINT_MARKERS = frozenset({"email", "users_email_key", "ix_users_email"}) + + +def _constraint_text(exc: IntegrityError) -> str: + orig = exc.orig + if orig is None: + return str(exc).lower() + diag = getattr(orig, "diag", None) + if diag is not None: + name = getattr(diag, "constraint_name", None) + if name: + return str(name).lower() + return str(orig).lower() + + +def _matches(markers: frozenset[str], text: str) -> bool: + return any(marker in text for marker in markers) + + +def user_integrity_auth_code( + exc: IntegrityError, + *, + phone_conflict: str, +) -> str | None: + """Return auth internal code for a users-table unique violation, or None.""" + text = _constraint_text(exc) + if _matches(_PHONE_CONSTRAINT_MARKERS, text): + return phone_conflict + if _matches(_EMAIL_CONSTRAINT_MARKERS, text): + return "EMAIL_EXISTS" + return None + + +def is_user_phone_unique_violation(exc: IntegrityError) -> bool: + return _matches(_PHONE_CONSTRAINT_MARKERS, _constraint_text(exc)) diff --git a/api/app/features/auth/models.py b/api/app/features/auth/models.py index 4e965ec..f7322db 100644 --- a/api/app/features/auth/models.py +++ b/api/app/features/auth/models.py @@ -14,8 +14,17 @@ class RefreshToken(Base): created_at = Column(DateTime(timezone=True), default=utc_now) is_revoked = Column(Boolean, default=False) device_info = Column(String, nullable=True) + replaced_by_token_id = Column( + String, ForeignKey("refresh_tokens.id"), nullable=True, index=True + ) + rotated_at = Column(DateTime(timezone=True), nullable=True) user = relationship("User", back_populates="refresh_tokens") + replaced_by = relationship( + "RefreshToken", + remote_side="RefreshToken.id", + foreign_keys=[replaced_by_token_id], + ) class SmsVerificationCode(Base): diff --git a/api/app/features/auth/repo.py b/api/app/features/auth/repo.py index 75a3193..216c459 100644 --- a/api/app/features/auth/repo.py +++ b/api/app/features/auth/repo.py @@ -1,6 +1,9 @@ -from sqlalchemy import select +from datetime import datetime + +from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession +from app.core.db import utc_now from app.features.auth.models import RefreshToken, SmsVerificationCode from app.features.user.models import User @@ -29,6 +32,27 @@ async def get_refresh_token_by_token( return result.scalar_one_or_none() +async def get_refresh_token_by_id( + token_id: str, db: AsyncSession +) -> RefreshToken | None: + return await db.get(RefreshToken, token_id) + + +async def link_refresh_rotation( + old_token_id: str, + new_token_id: str, + rotated_at: datetime, + db: AsyncSession, +) -> None: + """Record lineage on the consumed refresh token row.""" + stmt = ( + update(RefreshToken) + .where(RefreshToken.id == old_token_id) + .values(replaced_by_token_id=new_token_id, rotated_at=rotated_at) + ) + await db.execute(stmt) + + async def get_active_tokens_for_user( user_id: str, db: AsyncSession ) -> list[RefreshToken]: @@ -48,6 +72,25 @@ async def create_refresh_token(token: RefreshToken, db: AsyncSession) -> None: db.add(token) +async def try_consume_refresh_token( + token_str: str, db: AsyncSession +) -> RefreshToken | None: + """Atomically revoke a valid refresh token; returns row or None.""" + now = utc_now() + stmt = ( + update(RefreshToken) + .where( + RefreshToken.token == token_str, + RefreshToken.is_revoked.is_(False), + RefreshToken.expires_at > now, + ) + .values(is_revoked=True) + .returning(RefreshToken) + ) + result = await db.execute(stmt) + return result.scalar_one_or_none() + + # ── SMS verification code ───────────────────────────────────── @@ -63,7 +106,10 @@ async def get_recent_code_for_rate_limit( """Latest verification code record for the phone (for rate limit check).""" stmt = ( select(SmsVerificationCode) - .where(SmsVerificationCode.phone == phone) + .where( + SmsVerificationCode.phone == phone, + SmsVerificationCode.is_expired.is_(False), + ) .order_by(SmsVerificationCode.created_at.desc()) .limit(1) ) @@ -88,3 +134,38 @@ async def get_latest_unused_code( ) result = await db.execute(stmt) return result.scalar_one_or_none() + + +async def mark_verification_code_expired(code_id: str, db: AsyncSession) -> None: + """Mark a verification code expired (e.g. after SMS provider failure).""" + stmt = ( + update(SmsVerificationCode) + .where(SmsVerificationCode.id == code_id) + .values(is_expired=True) + ) + await db.execute(stmt) + + +async def try_consume_verification_code( + phone: str, + code: str, + purpose: str, + db: AsyncSession, +) -> SmsVerificationCode | None: + """Atomically mark matching unused code as used; returns row or None.""" + now = utc_now() + stmt = ( + update(SmsVerificationCode) + .where( + SmsVerificationCode.phone == phone, + SmsVerificationCode.purpose == purpose, + SmsVerificationCode.code == code, + SmsVerificationCode.is_used.is_(False), + SmsVerificationCode.is_expired.is_(False), + SmsVerificationCode.expires_at > now, + ) + .values(is_used=True, verified_at=now) + .returning(SmsVerificationCode) + ) + result = await db.execute(stmt) + return result.scalar_one_or_none() diff --git a/api/app/features/auth/router.py b/api/app/features/auth/router.py index d3108ae..5e7889c 100644 --- a/api/app/features/auth/router.py +++ b/api/app/features/auth/router.py @@ -1,25 +1,24 @@ -import io import time from pathlib import Path -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, UploadFile, status from fastapi.responses import FileResponse -from PIL import Image from app.core.config import settings from app.core.cos_url_keys import ( avatar_url_for_api_response, best_effort_delete_cos_object_for_url, - extract_cos_object_key_if_owned, ) -from app.core.dependencies import get_current_user +from app.core.deps_types import CurrentUserDep +from app.core.errors import BadRequestError, NotFoundError from app.core.logging import get_logger +from app.core.openapi import error_responses from app.features.auth.deps import get_auth_service from app.features.auth.preset_avatars import ( avatar_url_for_preset_filename, list_preset_items, - preset_filename_for_id, preset_file_path, + preset_filename_for_id, safe_avatar_upload_path, ) from app.features.auth.schemas import ( @@ -39,7 +38,7 @@ from app.features.auth.schemas import ( UpdateNicknameRequest, UserResponse, ) -from app.features.auth.service import AuthError, AuthService +from app.features.auth.service import AuthService from app.features.user.models import User logger = get_logger(__name__) @@ -47,28 +46,13 @@ logger = get_logger(__name__) router = APIRouter( prefix="/api/auth", tags=["auth"], - responses={401: {"description": "认证失败"}}, + responses=error_responses(401), ) AVATAR_DIR = Path("uploads/avatars") # ── helpers ────────────────────────────────────────────────── -_ERROR_STATUS: dict[str, int] = { - "INVALID_CREDENTIALS": status.HTTP_401_UNAUTHORIZED, - "INVALID_TOKEN": status.HTTP_401_UNAUTHORIZED, - "TOKEN_REVOKED": status.HTTP_401_UNAUTHORIZED, - "TOKEN_EXPIRED": status.HTTP_401_UNAUTHORIZED, - "USER_NOT_FOUND": status.HTTP_404_NOT_FOUND, - "PHONE_EXISTS": status.HTTP_400_BAD_REQUEST, -} - - -def _map_auth_error(e: AuthError) -> HTTPException: - code = _ERROR_STATUS.get(e.code, status.HTTP_400_BAD_REQUEST) - return HTTPException(status_code=code, detail=e.message) - - def _user_response(user: User) -> UserResponse: raw_lang = getattr(user, "language_preference", "zh") lang = str(raw_lang).strip().lower() if isinstance(raw_lang, str) else "zh" @@ -88,10 +72,7 @@ def _user_response(user: User) -> UserResponse: def _check_terms(agreed: bool) -> None: if not agreed: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="请先阅读并同意用户协议和隐私政策", - ) + raise BadRequestError("请先阅读并同意用户协议和隐私政策") def _mock_sms_login_route_enabled() -> bool: @@ -109,23 +90,20 @@ def _mock_sms_login_route_enabled() -> bool: response_model=TokenResponse, status_code=status.HTTP_201_CREATED, summary="手机号密码注册", - responses={400: {"description": "手机号/邮箱已注册或参数错误"}}, + responses=error_responses(400, descriptions={400: "手机号/邮箱已注册或参数错误"}), ) async def register( request: RegisterRequest, service: AuthService = Depends(get_auth_service), ): _check_terms(request.agreed_to_terms) - try: - result = await service.register( - phone=request.phone, - password=request.password, - nickname=request.nickname, - email=request.email, - language=request.language, - ) - except AuthError as e: - raise _map_auth_error(e) + result = await service.register( + phone=request.phone, + password=request.password, + nickname=request.nickname, + email=request.email, + language=request.language, + ) return TokenResponse( access_token=result["access_token"], refresh_token=result["refresh_token"], @@ -136,20 +114,17 @@ async def register( "/login", response_model=TokenResponse, summary="手机号密码登录", - responses={401: {"description": "手机号或密码错误"}}, + responses=error_responses(401, descriptions={401: "手机号或密码错误"}), ) async def login( request: LoginRequest, service: AuthService = Depends(get_auth_service), ): _check_terms(request.agreed_to_terms) - try: - result = await service.login( - phone=request.phone, - password=request.password, - ) - except AuthError as e: - raise _map_auth_error(e) + result = await service.login( + phone=request.phone, + password=request.password, + ) return TokenResponse( access_token=result["access_token"], refresh_token=result["refresh_token"], @@ -160,18 +135,20 @@ async def login( "/refresh", response_model=TokenResponse, summary="刷新访问令牌", - responses={401: {"description": "刷新令牌无效/已撤销/已过期"}}, + responses=error_responses( + 401, + descriptions={ + 401: "刷新令牌无效/已过期;已轮换 token 被重复使用时会吊销全部会话(REFRESH_TOKEN_REUSE)" + }, + ), ) async def refresh_token( request: RefreshTokenRequest, service: AuthService = Depends(get_auth_service), ): - try: - result = await service.refresh_tokens( - refresh_token=request.refresh_token, - ) - except AuthError as e: - raise _map_auth_error(e) + result = await service.refresh_tokens( + refresh_token=request.refresh_token, + ) return TokenResponse( access_token=result["access_token"], refresh_token=result["refresh_token"], @@ -188,7 +165,7 @@ async def refresh_token( ) async def logout( request: RefreshTokenRequest, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: AuthService = Depends(get_auth_service), ): await service.logout(request.refresh_token, current_user.id) @@ -200,7 +177,7 @@ async def logout( summary="登出所有设备", ) async def logout_all_devices( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: AuthService = Depends(get_auth_service), ): count = await service.logout_all(current_user.id) @@ -216,7 +193,7 @@ async def logout_all_devices( summary="获取当前用户信息", ) async def get_me( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, ): return _user_response(current_user) @@ -228,13 +205,10 @@ async def get_me( ) async def update_nickname( request: UpdateNicknameRequest, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: AuthService = Depends(get_auth_service), ): - try: - user = await service.update_nickname(current_user.id, request.nickname) - except AuthError as e: - raise _map_auth_error(e) + user = await service.update_nickname(current_user.id, request.nickname) return _user_response(user) @@ -245,35 +219,14 @@ async def update_nickname( "/me/avatar", response_model=UserResponse, summary="上传头像", - responses={400: {"description": "文件类型或大小不符合要求"}}, + responses=error_responses(400, descriptions={400: "文件类型或大小不符合要求"}), ) async def upload_avatar( - file: UploadFile = File(...), - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: AuthService = Depends(get_auth_service), + file: UploadFile = File(...), ): - allowed_types = ["image/jpeg", "image/png", "image/webp"] - - if file.content_type not in allowed_types: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"不支持的文件类型。仅支持: {', '.join(allowed_types)}", - ) - file_content = await file.read() - - if not file_content or len(file_content) == 0: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="文件内容为空", - ) - - if len(file_content) > 5 * 1024 * 1024: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="文件大小超过5MB限制", - ) - logger.debug( "上传头像: user_id={} filename={} content_type={} size={}", current_user.id, @@ -281,95 +234,13 @@ async def upload_avatar( file.content_type, len(file_content), ) - - if not ( - (settings.tencent_cos_secret_id or "").strip() - and (settings.tencent_cos_secret_key or "").strip() - and (settings.tencent_cos_bucket or "").strip() - ): - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="头像存储服务未配置,请稍后再试", - ) - - try: - image_bytes = io.BytesIO(file_content) - image_bytes.seek(0) - - header = image_bytes.read(16) - image_bytes.seek(0) - - is_valid_image = False - if header.startswith(b"\xff\xd8\xff"): - is_valid_image = True - elif header.startswith(b"\x89PNG\r\n\x1a\n"): - is_valid_image = True - elif header.startswith(b"RIFF") and b"WEBP" in header[:12]: - is_valid_image = True - else: - logger.warning("无法识别的图片文件头") - logger.debug("无法识别的文件头 hex={}", header[:12].hex()) - - if not is_valid_image: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"无效的图片文件格式。文件头: {header[:12].hex()}", - ) - - image = Image.open(image_bytes) - logger.debug( - "头像解码: format={} mode={} size={}", - image.format, - image.mode, - image.size, - ) - - if image.mode != "RGB": - image = image.convert("RGB") - - width, height = image.size - size = min(width, height) - left = (width - size) // 2 - top = (height - size) // 2 - right = left + size - bottom = top + size - image = image.crop((left, top, right, bottom)) - - if size > 512: - image = image.resize((512, 512), Image.Resampling.LANCZOS) - - jpeg_buffer = io.BytesIO() - image.save(jpeg_buffer, format="JPEG", quality=85, optimize=True) - jpeg_bytes = jpeg_buffer.getvalue() - - cos_key = f"avatars/{current_user.id}.jpg" - old_url = current_user.avatar_url - old_key = extract_cos_object_key_if_owned(old_url) if old_url else None - if old_key and old_key != cos_key: - best_effort_delete_cos_object_for_url(old_url) - - from app.core.dependencies import get_object_storage - - storage = get_object_storage() - try: - avatar_url = storage.upload(cos_key, jpeg_bytes, "image/jpeg") - except Exception as exc: - logger.exception("COS 头像上传失败: {}", exc) - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="头像存储暂时不可用,请稍后再试", - ) from exc - - user = await service.update_avatar_url(current_user.id, avatar_url) - return _user_response(user) - except HTTPException: - raise - except Exception as e: - logger.exception("头像上传失败: {}", e) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="处理图片失败,请重试", - ) from e + user = await service.upload_avatar( + current_user.id, + file_content, + file.content_type or "", + old_avatar_url=current_user.avatar_url, + ) + return _user_response(user) @router.get( @@ -388,63 +259,48 @@ async def list_avatar_presets(): "/me/avatar/preset", response_model=UserResponse, summary="使用预设头像", - responses={400: {"description": "无效的预设编号"}}, + responses=error_responses(400, descriptions={400: "无效的预设编号"}), ) async def set_avatar_preset( request: SetAvatarPresetRequest, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: AuthService = Depends(get_auth_service), ): filename = preset_filename_for_id(request.preset_id) if filename is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="无效的预设头像编号", - ) + raise BadRequestError("无效的预设头像编号") path = preset_file_path(filename) if path is None or not path.exists(): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="预设头像不可用", - ) + raise BadRequestError("预设头像不可用") best_effort_delete_cos_object_for_url(current_user.avatar_url) avatar_url = f"{avatar_url_for_preset_filename(filename)}?v={time.time_ns()}" - try: - user = await service.update_avatar_url(current_user.id, avatar_url) - except AuthError as e: - raise _map_auth_error(e) + user = await service.update_avatar_url(current_user.id, avatar_url) return _user_response(user) @router.get( "/avatar-presets/{filename}", summary="获取预设头像图片", - responses={404: {"description": "预设不存在"}}, + responses=error_responses(404, descriptions={404: "预设不存在"}), ) async def get_avatar_preset(filename: str): path = preset_file_path(filename) if path is None or not path.exists(): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="预设头像不存在", - ) + raise NotFoundError("预设头像不存在") return FileResponse(path, media_type="image/png") @router.get( "/avatars/{filename}", summary="获取头像图片", - responses={404: {"description": "头像不存在"}}, + responses=error_responses(404, descriptions={404: "头像不存在"}), ) async def get_avatar(filename: str): AVATAR_DIR.mkdir(parents=True, exist_ok=True) file_path = safe_avatar_upload_path(filename, AVATAR_DIR) if file_path is None or not file_path.exists(): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="头像不存在", - ) + raise NotFoundError("头像不存在") return FileResponse(file_path, media_type="image/jpeg") @@ -454,46 +310,35 @@ async def get_avatar(filename: str): @router.post( "/sms/send", summary="发送短信验证码", - responses={ - 400: {"description": "手机号格式或用途不合法"}, - 429: {"description": "发送过于频繁"}, - 503: {"description": "短信服务不可用"}, - }, + responses=error_responses( + 400, + 429, + 502, + 503, + descriptions={ + 400: "手机号格式或用途不合法", + 429: "发送过于频繁(RATE_LIMITED)", + 502: "短信服务商调用失败(PROVIDER_ERROR,可重试)", + 503: "短信服务未配置或不可用(SERVICE_UNAVAILABLE)", + }, + ), ) async def send_sms_code( request: SendSmsRequest, service: AuthService = Depends(get_auth_service), ): if not request.phone.isdigit(): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="手机号格式不正确", - ) + raise BadRequestError("手机号格式不正确") valid_purposes = ["register", "login", "reset_password", "change_phone"] if request.purpose not in valid_purposes: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"无效的用途,必须是: {', '.join(valid_purposes)}", - ) + raise BadRequestError(f"无效的用途,必须是: {', '.join(valid_purposes)}") - try: - success, message, expires_in = await service.send_sms_code( - phone=request.phone, - purpose=request.purpose, - ip_address=None, - ) - except AuthError as e: - raise _map_auth_error(e) - - if not success: - if "频繁" in message: - status_code = status.HTTP_429_TOO_MANY_REQUESTS - elif "配置" in message or "配置错误" in message or "授权失败" in message: - status_code = status.HTTP_503_SERVICE_UNAVAILABLE - else: - status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - raise HTTPException(status_code=status_code, detail=message) + _success, message, expires_in = await service.send_sms_code( + phone=request.phone, + purpose=request.purpose, + ip_address=None, + ) return {"message": message, "expires_in": expires_in} @@ -502,22 +347,19 @@ async def send_sms_code( "/login/sms", response_model=TokenResponse, summary="短信验证码登录(新用户自动注册)", - responses={400: {"description": "验证码错误"}}, + responses=error_responses(400, descriptions={400: "验证码错误"}), ) async def login_with_sms( request: SmsLoginRequest, service: AuthService = Depends(get_auth_service), ): _check_terms(request.agreed_to_terms) - try: - result = await service.login_with_sms( - phone=request.phone, - code=request.code, - nickname=request.nickname, - language=request.language, - ) - except AuthError as e: - raise _map_auth_error(e) + result = await service.login_with_sms( + phone=request.phone, + code=request.code, + nickname=request.nickname, + language=request.language, + ) return TokenResponse( access_token=result["access_token"], refresh_token=result["refresh_token"], @@ -529,26 +371,23 @@ async def login_with_sms( response_model=TokenResponse, summary="[评测] Mock 短信登录(跳过验证码)", description=( - "需 MOCK_SMS_LOGIN_ENABLED=1 且 APP_ENV 非 production。" + "需 config deploy.mock_sms_login_enabled=true 且 APP_ENV 非 production。" "供 Eval Web 等内网工具联调,勿在生产环境开启。" ), - responses={404: {"description": "未启用或生产环境已禁用"}}, + responses=error_responses(404, descriptions={404: "未启用或生产环境已禁用"}), ) async def mock_sms_login_route( request: MockSmsLoginRequest, service: AuthService = Depends(get_auth_service), ): if not _mock_sms_login_route_enabled(): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found") + raise NotFoundError("Not Found") _check_terms(request.agreed_to_terms) - try: - result = await service.mock_sms_login( - phone=request.phone, - nickname=request.nickname, - language=request.language, - ) - except AuthError as e: - raise _map_auth_error(e) + result = await service.mock_sms_login( + phone=request.phone, + nickname=request.nickname, + language=request.language, + ) return TokenResponse( access_token=result["access_token"], refresh_token=result["refresh_token"], @@ -560,24 +399,21 @@ async def mock_sms_login_route( response_model=TokenResponse, status_code=status.HTTP_201_CREATED, summary="短信验证码注册", - responses={400: {"description": "验证码错误或手机号/邮箱已注册"}}, + responses=error_responses(400, descriptions={400: "验证码错误或手机号/邮箱已注册"}), ) async def register_with_sms( request: SmsRegisterRequest, service: AuthService = Depends(get_auth_service), ): _check_terms(request.agreed_to_terms) - try: - result = await service.register_with_sms( - phone=request.phone, - code=request.code, - password=request.password, - nickname=request.nickname, - email=request.email, - language=request.language, - ) - except AuthError as e: - raise _map_auth_error(e) + result = await service.register_with_sms( + phone=request.phone, + code=request.code, + password=request.password, + nickname=request.nickname, + email=request.email, + language=request.language, + ) return TokenResponse( access_token=result["access_token"], refresh_token=result["refresh_token"], @@ -590,44 +426,42 @@ async def register_with_sms( @router.post( "/password/reset", summary="通过短信验证码重置密码", - responses={ - 400: {"description": "验证码错误"}, - 404: {"description": "用户不存在"}, - }, + responses=error_responses( + 400, + 404, + descriptions={ + 400: "验证码错误", + 404: "用户不存在", + }, + ), ) async def reset_password( request: ResetPasswordRequest, service: AuthService = Depends(get_auth_service), ): - try: - await service.reset_password( - phone=request.phone, - code=request.code, - new_password=request.new_password, - ) - except AuthError as e: - raise _map_auth_error(e) + await service.reset_password( + phone=request.phone, + code=request.code, + new_password=request.new_password, + ) return {"message": "密码重置成功"} @router.post( "/password/change", summary="修改密码(需旧密码)", - responses={400: {"description": "旧密码错误"}}, + responses=error_responses(400, descriptions={400: "旧密码错误"}), ) async def change_password( request: ChangePasswordRequest, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: AuthService = Depends(get_auth_service), ): - try: - await service.change_password( - user_id=current_user.id, - old_password=request.old_password, - new_password=request.new_password, - ) - except AuthError as e: - raise _map_auth_error(e) + await service.change_password( + user_id=current_user.id, + old_password=request.old_password, + new_password=request.new_password, + ) return {"message": "密码修改成功"} @@ -635,19 +469,16 @@ async def change_password( "/phone/change", response_model=UserResponse, summary="更换手机号", - responses={400: {"description": "验证码错误或手机号已被占用"}}, + responses=error_responses(400, descriptions={400: "验证码错误或手机号已被占用"}), ) async def change_phone( request: ChangePhoneRequest, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: AuthService = Depends(get_auth_service), ): - try: - user = await service.change_phone( - user_id=current_user.id, - new_phone=request.new_phone, - code=request.code, - ) - except AuthError as e: - raise _map_auth_error(e) + user = await service.change_phone( + user_id=current_user.id, + new_phone=request.new_phone, + code=request.code, + ) return _user_response(user) diff --git a/api/app/features/auth/schemas.py b/api/app/features/auth/schemas.py index 4e5716b..3457c45 100644 --- a/api/app/features/auth/schemas.py +++ b/api/app/features/auth/schemas.py @@ -6,11 +6,11 @@ LanguagePreference = Literal["zh", "en"] class RegisterRequest(BaseModel): - phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") - password: str = Field(..., min_length=6, description="密码(至少6位)") - nickname: str = Field(..., min_length=1, max_length=50, description="昵称") + phone: str = Field(min_length=11, max_length=11, description="手机号(11位)") + password: str = Field(min_length=6, description="密码(至少6位)") + nickname: str = Field(min_length=1, max_length=50, description="昵称") email: Optional[str] = Field(None, description="邮箱(可选)") - agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") + agreed_to_terms: bool = Field(description="是否同意用户协议和隐私政策") language: Optional[LanguagePreference] = Field( None, description="device language at signup; only used when creating a new user", @@ -18,9 +18,9 @@ class RegisterRequest(BaseModel): class LoginRequest(BaseModel): - phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") - password: str = Field(..., min_length=1, description="密码") - agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") + phone: str = Field(min_length=11, max_length=11, description="手机号(11位)") + password: str = Field(min_length=1, description="密码") + agreed_to_terms: bool = Field(description="是否同意用户协议和隐私政策") class TokenResponse(BaseModel): @@ -30,7 +30,7 @@ class TokenResponse(BaseModel): class RefreshTokenRequest(BaseModel): - refresh_token: str = Field(..., description="刷新令牌") + refresh_token: str = Field(description="刷新令牌") class UserResponse(BaseModel): @@ -45,7 +45,7 @@ class UserResponse(BaseModel): class SendSmsRequest(BaseModel): - phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") + phone: str = Field(min_length=11, max_length=11, description="手机号(11位)") purpose: str = Field( ..., description="用途:register/login/reset_password/change_phone" ) @@ -58,9 +58,9 @@ class SendSmsResponse(BaseModel): class SmsLoginRequest(BaseModel): - phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") - code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") - agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") + phone: str = Field(min_length=11, max_length=11, description="手机号(11位)") + code: str = Field(min_length=6, max_length=6, description="验证码(6位)") + agreed_to_terms: bool = Field(description="是否同意用户协议和隐私政策") nickname: Optional[str] = Field( None, max_length=50, description="昵称(注册时必填,登录时可选)" ) @@ -73,8 +73,8 @@ class SmsLoginRequest(BaseModel): class MockSmsLoginRequest(BaseModel): """开发/评测专用:与 MOCK_SMS_LOGIN_ENABLED 联用,跳过短信校验。""" - phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") - agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") + phone: str = Field(min_length=11, max_length=11, description="手机号(11位)") + agreed_to_terms: bool = Field(description="是否同意用户协议和隐私政策") nickname: Optional[str] = Field( None, max_length=50, description="新用户昵称(可选)" ) @@ -85,12 +85,12 @@ class MockSmsLoginRequest(BaseModel): class SmsRegisterRequest(BaseModel): - phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") - code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") - password: str = Field(..., min_length=6, description="密码(至少6位)") - nickname: str = Field(..., min_length=1, max_length=50, description="昵称") + phone: str = Field(min_length=11, max_length=11, description="手机号(11位)") + code: str = Field(min_length=6, max_length=6, description="验证码(6位)") + password: str = Field(min_length=6, description="密码(至少6位)") + nickname: str = Field(min_length=1, max_length=50, description="昵称") email: Optional[str] = Field(None, description="邮箱(可选)") - agreed_to_terms: bool = Field(..., description="是否同意用户协议和隐私政策") + agreed_to_terms: bool = Field(description="是否同意用户协议和隐私政策") language: Optional[LanguagePreference] = Field( None, description="device language at signup; only used when creating a new user", @@ -98,21 +98,21 @@ class SmsRegisterRequest(BaseModel): class ResetPasswordRequest(BaseModel): - phone: str = Field(..., min_length=11, max_length=11, description="手机号(11位)") - code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") - new_password: str = Field(..., min_length=6, description="新密码(至少6位)") + phone: str = Field(min_length=11, max_length=11, description="手机号(11位)") + code: str = Field(min_length=6, max_length=6, description="验证码(6位)") + new_password: str = Field(min_length=6, description="新密码(至少6位)") class ChangePasswordRequest(BaseModel): - old_password: str = Field(..., min_length=1, description="旧密码") - new_password: str = Field(..., min_length=6, description="新密码(至少6位)") + old_password: str = Field(min_length=1, description="旧密码") + new_password: str = Field(min_length=6, description="新密码(至少6位)") class ChangePhoneRequest(BaseModel): new_phone: str = Field( ..., min_length=11, max_length=11, description="新手机号(11位)" ) - code: str = Field(..., min_length=6, max_length=6, description="验证码(6位)") + code: str = Field(min_length=6, max_length=6, description="验证码(6位)") class UpdateNicknameRequest(BaseModel): diff --git a/api/app/features/auth/service.py b/api/app/features/auth/service.py index 83eb9b7..bc4390e 100644 --- a/api/app/features/auth/service.py +++ b/api/app/features/auth/service.py @@ -1,11 +1,28 @@ +import asyncio +import io import random import secrets import uuid from datetime import datetime, timedelta, timezone +from PIL import Image +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import utc_now +from app.core.config import settings +from app.core.cos_url_keys import ( + best_effort_delete_cos_object_for_url, + extract_cos_object_key_if_owned, +) +from app.core.db import transactional, transactional_nested, utc_now +from app.core.errors import ( + AppError, + BadRequestError, + ProviderError, + RateLimitedError, + ServiceUnavailableError, +) +from app.core.logging import get_logger from app.core.security import ( create_access_token, get_token_expires_at, @@ -16,13 +33,29 @@ from app.core.security import ( create_refresh_token as generate_refresh_token_str, ) from app.features.auth import repo +from app.features.auth.integrity import is_user_phone_unique_violation, user_integrity_auth_code from app.features.auth.models import RefreshToken, SmsVerificationCode from app.features.user.models import User from app.ports.sms import SmsSender +from app.ports.storage import ObjectStorage + +logger = get_logger(__name__) CODE_LENGTH = 6 CODE_EXPIRE_MINUTES = 5 RATE_LIMIT_SECONDS = 60 +_SMS_CONSUME_FALLBACK = "验证码不存在或已使用" + + +def _sms_is_configured() -> bool: + return bool( + (settings.tencent_secret_id or "").strip() + and (settings.tencent_secret_key or "").strip() + and (settings.tencent_sms_sdk_app_id or "").strip() + and (settings.tencent_sms_sign_name or "").strip() + and (settings.tencent_sms_template_id or "").strip() + ) + _VALID_LANGUAGES = {"zh", "en"} @@ -35,41 +68,165 @@ def _normalize_language(lang: str | None) -> str: return s if s in _VALID_LANGUAGES else "zh" -class AuthError(Exception): +def _as_utc(dt: datetime) -> datetime: + """Normalize DB datetimes for safe comparison (sqlite may return naive).""" + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt.astimezone(timezone.utc) + + +_AUTH_CODE_MAP: dict[str, tuple[int, str]] = { + "INVALID_CREDENTIALS": (401, "AUTHENTICATION_FAILED"), + "INVALID_TOKEN": (401, "AUTHENTICATION_FAILED"), + "TOKEN_REVOKED": (401, "AUTHENTICATION_FAILED"), + "TOKEN_EXPIRED": (401, "AUTHENTICATION_FAILED"), + "REFRESH_TOKEN_REUSE": (401, "REFRESH_TOKEN_REUSE"), + "USER_NOT_FOUND": (404, "NOT_FOUND"), + "PHONE_EXISTS": (400, "PHONE_EXISTS"), + "EMAIL_EXISTS": (400, "EMAIL_EXISTS"), + "PHONE_TAKEN": (409, "PHONE_TAKEN"), + "INVALID_SMS_CODE": (400, "INVALID_SMS_CODE"), + "WRONG_PASSWORD": (400, "WRONG_PASSWORD"), + "AUTH_ERROR": (400, "BAD_REQUEST"), +} + + +class AuthError(AppError): def __init__(self, message: str, code: str = "AUTH_ERROR"): - self.message = message + status_code, error_code = _AUTH_CODE_MAP.get(code, (400, code)) + super().__init__(message, status_code=status_code, error_code=error_code) self.code = code - super().__init__(message) + + +def _raise_auth_error_from_user_integrity( + exc: IntegrityError, + *, + phone_conflict: str, +) -> None: + code = user_integrity_auth_code(exc, phone_conflict=phone_conflict) + if code == "PHONE_EXISTS": + raise AuthError("该手机号已被注册", "PHONE_EXISTS") from exc + if code == "EMAIL_EXISTS": + raise AuthError("该邮箱已被注册", "EMAIL_EXISTS") from exc + if code == "PHONE_TAKEN": + raise AuthError("该手机号已被其他用户使用", "PHONE_TAKEN") from exc + raise exc + + +async def _create_user_with_integrity_check( + db: AsyncSession, + user: User, + *, + phone_conflict: str, +) -> None: + await repo.create_user(user, db) + try: + await db.flush() + except IntegrityError as exc: + _raise_auth_error_from_user_integrity(exc, phone_conflict=phone_conflict) + + +def _public_tokens(issued: dict) -> dict: + """Strip internal fields before returning tokens to callers.""" + return { + "access_token": issued["access_token"], + "refresh_token": issued["refresh_token"], + } class AuthService: - def __init__(self, db: AsyncSession, sms: SmsSender): + def __init__( + self, + db: AsyncSession, + sms: SmsSender, + *, + object_storage: ObjectStorage | None = None, + ): self._db = db self._sms = sms + self._object_storage = object_storage # ── private helpers ────────────────────────────────────── def _generate_code(self) -> str: return "".join(str(random.randint(0, 9)) for _ in range(CODE_LENGTH)) - async def _verify_sms_code( + async def _check_sms_code( self, phone: str, code: str, purpose: str - ) -> tuple[bool, str]: - """Verify SMS code (DB check + mark used). Returns (success, message).""" + ) -> tuple[SmsVerificationCode | None, str]: + """Validate SMS code without consuming it. Returns (record, message). + + UX pre-check only; authoritative validation is ``try_consume_verification_code`` + inside a transaction. + """ record = await repo.get_latest_unused_code(phone, purpose, self._db) if not record: - return False, "验证码不存在或已使用" + return None, _SMS_CONSUME_FALLBACK now = utc_now() if now > record.expires_at: - record.is_expired = True - await self._db.commit() - return False, "验证码已过期" + async with transactional(self._db): + record.is_expired = True + return None, "验证码已过期" if record.code != code: - return False, "验证码错误" - record.is_used = True - record.verified_at = now - await self._db.commit() - return True, "验证成功" + return None, "验证码错误" + return record, "验证成功" + + async def _precheck_sms_code( + self, phone: str, code: str, purpose: str + ) -> str | None: + """UX pre-check: fast-fail without consuming. None means likely valid.""" + record, message = await self._check_sms_code(phone, code, purpose) + if record is None: + return message + return None + + async def _precheck_sms_code_for_purposes( + self, phone: str, code: str, purposes: tuple[str, ...] + ) -> str | None: + """Login flow: try each purpose until one pre-check passes.""" + last_message = _SMS_CONSUME_FALLBACK + for purpose in purposes: + record, message = await self._check_sms_code(phone, code, purpose) + if record is not None: + return None + last_message = message + return last_message + + async def _sms_invalid_message_after_consume_failure( + self, + phone: str, + code: str, + purposes: tuple[str, ...], + *, + fallback: str = _SMS_CONSUME_FALLBACK, + ) -> str: + """Re-read code state after atomic consume failed (race/expiry/concurrency).""" + for purpose in purposes: + record, message = await self._check_sms_code(phone, code, purpose) + if record is None: + return message + return fallback + + async def _consume_sms_code_or_raise( + self, + phone: str, + code: str, + purpose: str, + *, + purposes: tuple[str, ...] | None = None, + ) -> SmsVerificationCode: + """Atomically consume SMS code inside ``transactional()``; raise on failure.""" + purposes_to_try = purposes or (purpose,) + for p in purposes_to_try: + consumed = await repo.try_consume_verification_code( + phone, code, p, self._db + ) + if consumed is not None: + return consumed + message = await self._sms_invalid_message_after_consume_failure( + phone, code, purposes_to_try + ) + raise AuthError(message, "INVALID_SMS_CODE") async def send_sms_code( self, @@ -77,7 +234,9 @@ class AuthService: purpose: str, ip_address: str | None = None, ) -> tuple[bool, str, int]: - """Send SMS verification code. Returns (success, message, expires_in_seconds).""" + """Send SMS verification code. Returns (True, message, expires_in_seconds) on success.""" + if not _sms_is_configured(): + raise ServiceUnavailableError("短信服务未配置,请稍后再试") if purpose == "register": if await repo.get_user_by_phone(phone, self._db): raise AuthError("该手机号已被注册", "PHONE_EXISTS") @@ -90,10 +249,8 @@ class AuthService: elapsed = (now - recent.created_at).total_seconds() if elapsed < RATE_LIMIT_SECONDS: remaining = int(RATE_LIMIT_SECONDS - elapsed) - return False, f"发送过于频繁,请{remaining}秒后再试", 0 + raise RateLimitedError(f"发送过于频繁,请{remaining}秒后再试") code = self._generate_code() - if not self._sms.send_verification_code(phone, code): - return False, "短信发送失败,请稍后重试", 0 expires_at = utc_now() + timedelta(minutes=CODE_EXPIRE_MINUTES) record = SmsVerificationCode( id=str(uuid.uuid4()), @@ -103,8 +260,13 @@ class AuthService: expires_at=expires_at, ip_address=ip_address, ) - await repo.create_verification_code(record, self._db) - await self._db.commit() + async with transactional(self._db): + await repo.create_verification_code(record, self._db) + + if not self._sms.send_verification_code(phone, code): + async with transactional(self._db): + await repo.mark_verification_code_expired(record.id, self._db) + raise ProviderError("短信发送失败,请稍后重试") return True, "验证码已发送", CODE_EXPIRE_MINUTES * 60 async def _issue_tokens(self, user_id: str, device_info: str = "") -> dict: @@ -120,7 +282,41 @@ class AuthService: ) await repo.create_refresh_token(token, self._db) access = create_access_token(data={"sub": user_id}) - return {"access_token": access, "refresh_token": refresh_str} + return { + "access_token": access, + "refresh_token": refresh_str, + "refresh_token_id": token.id, + } + + async def _try_idempotent_refresh_within_grace( + self, + token_record: RefreshToken, + ) -> dict | None: + """Grace-window retry: return new access + existing replacement refresh.""" + grace = settings.refresh_token_reuse_grace_seconds + if grace <= 0: + return None + if not token_record.replaced_by_token_id or token_record.rotated_at is None: + return None + now = utc_now() + rotated_at = _as_utc(token_record.rotated_at) + if rotated_at + timedelta(seconds=grace) <= now: + return None + replacement = await repo.get_refresh_token_by_id( + token_record.replaced_by_token_id, self._db + ) + if replacement is None or replacement.is_revoked: + return None + if _as_utc(replacement.expires_at) < now: + return None + user = await repo.get_user_by_id(replacement.user_id, self._db) + if user is None: + return None + access = create_access_token(data={"sub": user.id}) + return { + "access_token": access, + "refresh_token": replacement.token, + } # ── public API ─────────────────────────────────────────── @@ -150,11 +346,13 @@ class AuthService: created_at=datetime.now(timezone.utc), language_preference=_normalize_language(language), ) - await repo.create_user(user, self._db) - tokens = await self._issue_tokens(user_id) - await self._db.commit() + async with transactional(self._db): + await _create_user_with_integrity_check( + self._db, user, phone_conflict="PHONE_EXISTS" + ) + tokens = await self._issue_tokens(user_id) await self._db.refresh(user) - return {"user": user, **tokens} + return {"user": user, **_public_tokens(tokens)} async def login( self, @@ -167,50 +365,104 @@ class AuthService: if not user or not verify_password(password, user.password_hash): raise AuthError("手机号或密码错误", "INVALID_CREDENTIALS") - tokens = await self._issue_tokens(user.id, device_info) - await self._db.commit() - return {"user": user, **tokens} + async with transactional(self._db): + tokens = await self._issue_tokens(user.id, device_info) + return {"user": user, **_public_tokens(tokens)} + + async def _revoke_all_active_tokens_in_session(self, user_id: str) -> int: + """Revoke all active refresh tokens on the current session (no commit).""" + tokens = await repo.get_active_tokens_for_user(user_id, self._db) + for token in tokens: + token.is_revoked = True + return len(tokens) async def refresh_tokens( self, refresh_token: str, device_info: str = "", ) -> dict: - """Refresh access token. Returns {access_token, refresh_token}.""" - token_record = await repo.get_refresh_token_by_token(refresh_token, self._db) - if not token_record: - raise AuthError("无效的刷新令牌", "INVALID_TOKEN") + """Rotate refresh token and issue a new access token pair.""" + reuse_detected = False + async with transactional(self._db): + consumed = await repo.try_consume_refresh_token(refresh_token, self._db) + if consumed is not None: + user = await repo.get_user_by_id(consumed.user_id, self._db) + if not user: + raise AuthError("用户不存在", "USER_NOT_FOUND") + issued = await self._issue_tokens( + consumed.user_id, + device_info or (consumed.device_info or ""), + ) + await repo.link_refresh_rotation( + consumed.id, + issued["refresh_token_id"], + utc_now(), + self._db, + ) + return _public_tokens(issued) - if token_record.is_revoked: - raise AuthError("刷新令牌已撤销", "TOKEN_REVOKED") + token_record = await repo.get_refresh_token_by_token( + refresh_token, self._db + ) + if not token_record: + raise AuthError("无效的刷新令牌", "INVALID_TOKEN") + if token_record.is_revoked: + idempotent = await self._try_idempotent_refresh_within_grace( + token_record + ) + if idempotent is None and not token_record.replaced_by_token_id: + # Concurrent refresh may observe revoke before lineage commits. + token_record = ( + await repo.get_refresh_token_by_token( + refresh_token, self._db + ) + or token_record + ) + idempotent = await self._try_idempotent_refresh_within_grace( + token_record + ) + if idempotent is not None: + return idempotent + grace = settings.refresh_token_reuse_grace_seconds + rotated_at = token_record.rotated_at + still_in_grace = ( + grace > 0 + and rotated_at is not None + and _as_utc(rotated_at) + timedelta(seconds=grace) > utc_now() + ) + if still_in_grace: + raise AuthError("无效的刷新令牌", "INVALID_TOKEN") + if ( + grace > 0 + and rotated_at is None + and not token_record.replaced_by_token_id + ): + # Revoke visible but lineage not committed yet (concurrent rotation). + raise AuthError("无效的刷新令牌", "INVALID_TOKEN") + logger.bind(user_id=token_record.user_id).warning( + "Refresh token reuse detected (grace expired or no lineage)" + ) + await self._revoke_all_active_tokens_in_session(token_record.user_id) + reuse_detected = True + elif _as_utc(token_record.expires_at) < utc_now(): + raise AuthError("刷新令牌已过期", "TOKEN_EXPIRED") + else: + raise AuthError("无效的刷新令牌", "INVALID_TOKEN") - if token_record.expires_at < datetime.now(timezone.utc): - raise AuthError("刷新令牌已过期", "TOKEN_EXPIRED") - - user = await repo.get_user_by_id(token_record.user_id, self._db) - if not user: - raise AuthError("用户不存在", "USER_NOT_FOUND") - - access_token = create_access_token(data={"sub": user.id}) - return { - "access_token": access_token, - "refresh_token": refresh_token, - } + if reuse_detected: + raise AuthError("刷新令牌已失效,请重新登录", "REFRESH_TOKEN_REUSE") async def logout(self, refresh_token: str, user_id: str) -> None: """Revoke a refresh token owned by the given user.""" token_record = await repo.get_refresh_token_by_token(refresh_token, self._db) if token_record and token_record.user_id == user_id: - token_record.is_revoked = True - await self._db.commit() + async with transactional(self._db): + token_record.is_revoked = True async def logout_all(self, user_id: str) -> int: """Revoke all refresh tokens for user. Returns count revoked.""" - tokens = await repo.get_active_tokens_for_user(user_id, self._db) - for token in tokens: - token.is_revoked = True - await self._db.commit() - return len(tokens) + async with transactional(self._db): + return await self._revoke_all_active_tokens_in_session(user_id) async def login_with_sms( self, @@ -221,18 +473,15 @@ class AuthService: language: str | None = None, ) -> dict: """SMS login (auto-register if new). Returns {user, access_token, refresh_token, is_new_user}.""" - success = False - message = "" - for purpose in ("login", "register"): - success, message = await self._verify_sms_code(phone, code, purpose) - if success: - break - - if not success: - raise AuthError(message, "INVALID_SMS_CODE") + precheck_error = await self._precheck_sms_code_for_purposes( + phone, code, ("login", "register") + ) + if precheck_error is not None: + raise AuthError(precheck_error, "INVALID_SMS_CODE") return await self._sms_login_after_code_verified( phone, + code=code, device_info=device_info, nickname=nickname, language=language, @@ -242,13 +491,15 @@ class AuthService: self, phone: str, *, + code: str | None = None, device_info: str = "", nickname: str | None = None, language: str | None = None, ) -> dict: - """SMS 已校验通过后:查找或创建用户并签发令牌。 + """SMS 校验通过后:在同一事务内原子消耗验证码、创建用户(如需)并签发令牌。 ``language`` 仅在「新用户」分支下写入;命中已有用户时不覆盖偏好。 + mock 路由传入 ``code=None`` 跳过验证码消耗。 """ user = await repo.get_user_by_phone(phone, self._db) is_new_user = user is None @@ -264,14 +515,33 @@ class AuthService: created_at=datetime.now(timezone.utc), language_preference=_normalize_language(language), ) - await repo.create_user(user, self._db) - tokens = await self._issue_tokens(user.id, device_info) - await self._db.commit() + async with transactional(self._db): + if code is not None: + await self._consume_sms_code_or_raise( + phone, code, "login", purposes=("login", "register") + ) + if is_new_user: + try: + async with transactional_nested(self._db): + await repo.create_user(user, self._db) + await self._db.flush() + except IntegrityError as exc: + if is_user_phone_unique_violation(exc): + existing = await repo.get_user_by_phone(phone, self._db) + if existing is None: + raise + user = existing + is_new_user = False + else: + _raise_auth_error_from_user_integrity( + exc, phone_conflict="PHONE_EXISTS" + ) + tokens = await self._issue_tokens(user.id, device_info) if is_new_user: await self._db.refresh(user) - return {"user": user, "is_new_user": is_new_user, **tokens} + return {"user": user, "is_new_user": is_new_user, **_public_tokens(tokens)} async def mock_sms_login( self, @@ -299,9 +569,9 @@ class AuthService: language: str | None = None, ) -> dict: """SMS register. Returns {user, access_token, refresh_token}.""" - success, message = await self._verify_sms_code(phone, code, "register") - if not success: - raise AuthError(message, "INVALID_SMS_CODE") + precheck_error = await self._precheck_sms_code(phone, code, "register") + if precheck_error is not None: + raise AuthError(precheck_error, "INVALID_SMS_CODE") if await repo.get_user_by_phone(phone, self._db): raise AuthError("该手机号已被注册", "PHONE_EXISTS") @@ -309,6 +579,27 @@ class AuthService: if email and await repo.get_user_by_email(email, self._db): raise AuthError("该邮箱已被注册", "EMAIL_EXISTS") + return await self._register_after_sms_verified( + phone=phone, + code=code, + password=password, + nickname=nickname, + email=email, + device_info=device_info, + language=language, + ) + + async def _register_after_sms_verified( + self, + *, + phone: str, + code: str, + password: str, + nickname: str, + email: str | None, + device_info: str, + language: str | None, + ) -> dict: user_id = str(uuid.uuid4()) user = User( id=user_id, @@ -320,11 +611,14 @@ class AuthService: created_at=datetime.now(timezone.utc), language_preference=_normalize_language(language), ) - await repo.create_user(user, self._db) - tokens = await self._issue_tokens(user_id, device_info) - await self._db.commit() + async with transactional(self._db): + await self._consume_sms_code_or_raise(phone, code, "register") + await _create_user_with_integrity_check( + self._db, user, phone_conflict="PHONE_EXISTS" + ) + tokens = await self._issue_tokens(user_id, device_info) await self._db.refresh(user) - return {"user": user, **tokens} + return {"user": user, **_public_tokens(tokens)} async def reset_password( self, @@ -333,16 +627,17 @@ class AuthService: new_password: str, ) -> None: """Reset password via SMS code.""" - success, message = await self._verify_sms_code(phone, code, "reset_password") - if not success: - raise AuthError(message, "INVALID_SMS_CODE") + precheck_error = await self._precheck_sms_code(phone, code, "reset_password") + if precheck_error is not None: + raise AuthError(precheck_error, "INVALID_SMS_CODE") user = await repo.get_user_by_phone(phone, self._db) if not user: raise AuthError("用户不存在", "USER_NOT_FOUND") - user.password_hash = hash_password(new_password) - await self._db.commit() + async with transactional(self._db): + await self._consume_sms_code_or_raise(phone, code, "reset_password") + user.password_hash = hash_password(new_password) async def change_password( self, @@ -358,8 +653,8 @@ class AuthService: if not verify_password(old_password, user.password_hash): raise AuthError("旧密码错误", "WRONG_PASSWORD") - user.password_hash = hash_password(new_password) - await self._db.commit() + async with transactional(self._db): + user.password_hash = hash_password(new_password) async def change_phone( self, @@ -368,9 +663,9 @@ class AuthService: code: str, ) -> User: """Change phone number via SMS code. Returns updated user.""" - success, message = await self._verify_sms_code(new_phone, code, "change_phone") - if not success: - raise AuthError(message, "INVALID_SMS_CODE") + precheck_error = await self._precheck_sms_code(new_phone, code, "change_phone") + if precheck_error is not None: + raise AuthError(precheck_error, "INVALID_SMS_CODE") existing = await repo.get_user_by_phone(new_phone, self._db) if existing and existing.id != user_id: @@ -380,8 +675,13 @@ class AuthService: if not user: raise AuthError("用户不存在", "USER_NOT_FOUND") - user.phone = new_phone - await self._db.commit() + async with transactional(self._db): + await self._consume_sms_code_or_raise(new_phone, code, "change_phone") + user.phone = new_phone + try: + await self._db.flush() + except IntegrityError as exc: + _raise_auth_error_from_user_integrity(exc, phone_conflict="PHONE_TAKEN") await self._db.refresh(user) return user @@ -391,8 +691,8 @@ class AuthService: if not user: raise AuthError("用户不存在", "USER_NOT_FOUND") - user.nickname = nickname.strip() - await self._db.commit() + async with transactional(self._db): + user.nickname = nickname.strip() await self._db.refresh(user) return user @@ -401,7 +701,111 @@ class AuthService: user = await repo.get_user_by_id(user_id, self._db) if not user: raise AuthError("用户不存在", "USER_NOT_FOUND") - user.avatar_url = avatar_url - await self._db.commit() + async with transactional(self._db): + user.avatar_url = avatar_url await self._db.refresh(user) return user + + async def upload_avatar( + self, + user_id: str, + file_content: bytes, + content_type: str, + *, + old_avatar_url: str | None, + ) -> User: + """Validate, process, upload avatar to COS, and persist URL.""" + allowed_types = ["image/jpeg", "image/png", "image/webp"] + if content_type not in allowed_types: + raise BadRequestError( + f"不支持的文件类型。仅支持: {', '.join(allowed_types)}" + ) + if not file_content: + raise BadRequestError("文件内容为空") + if len(file_content) > 5 * 1024 * 1024: + raise BadRequestError("文件大小超过5MB限制") + if not ( + (settings.tencent_secret_id or "").strip() + and (settings.tencent_secret_key or "").strip() + and (settings.tencent_cos_bucket or "").strip() + ): + raise ServiceUnavailableError("头像存储服务未配置,请稍后再试") + + jpeg_bytes = await _process_avatar_jpeg_async(file_content) + cos_key = f"avatars/{user_id}.jpg" + old_key = extract_cos_object_key_if_owned(old_avatar_url) if old_avatar_url else None + + if not self._object_storage: + raise ServiceUnavailableError("头像存储服务未配置,请稍后再试") + + try: + avatar_url = await asyncio.to_thread( + self._object_storage.upload, cos_key, jpeg_bytes, "image/jpeg" + ) + except Exception as exc: + from app.core.logging import get_logger + + get_logger(__name__).exception("COS 头像上传失败: {}", exc) + raise ServiceUnavailableError("头像存储暂时不可用,请稍后再试") from exc + + try: + user = await self.update_avatar_url(user_id, avatar_url) + except Exception: + try: + await asyncio.to_thread(self._object_storage.delete, cos_key) + except Exception as cleanup_exc: + from app.core.logging import get_logger + + get_logger(__name__).warning( + "头像 DB 写入失败后清理 COS 对象失败: key={} err={}", + cos_key, + cleanup_exc, + ) + raise + + if old_key and old_key != cos_key: + best_effort_delete_cos_object_for_url(old_avatar_url) + return user + + +def _is_valid_image_header(header: bytes) -> bool: + if header.startswith(b"\xff\xd8\xff"): + return True + if header.startswith(b"\x89PNG\r\n\x1a\n"): + return True + if header.startswith(b"RIFF") and b"WEBP" in header[:12]: + return True + return False + + +async def _process_avatar_jpeg_async(file_content: bytes) -> bytes: + try: + return await asyncio.to_thread(_process_avatar_jpeg, file_content) + except BadRequestError: + raise + except Exception as exc: + raise BadRequestError("无效的图片文件") from exc + + +def _process_avatar_jpeg(file_content: bytes) -> bytes: + image_bytes = io.BytesIO(file_content) + header = image_bytes.read(16) + image_bytes.seek(0) + if not _is_valid_image_header(header): + raise BadRequestError(f"无效的图片文件格式。文件头: {header[:12].hex()}") + + image = Image.open(image_bytes) + if image.mode != "RGB": + image = image.convert("RGB") + + width, height = image.size + size = min(width, height) + left = (width - size) // 2 + top = (height - size) // 2 + image = image.crop((left, top, left + size, top + size)) + if size > 512: + image = image.resize((512, 512), Image.Resampling.LANCZOS) + + jpeg_buffer = io.BytesIO() + image.save(jpeg_buffer, format="JPEG", quality=85, optimize=True) + return jpeg_buffer.getvalue() diff --git a/api/app/features/content/router.py b/api/app/features/content/router.py index c4d1f11..242e3df 100644 --- a/api/app/features/content/router.py +++ b/api/app/features/content/router.py @@ -6,11 +6,12 @@ from pathlib import Path from typing import List from fastapi import APIRouter -from fastapi.responses import HTMLResponse +from fastapi.responses import FileResponse +from app.core.openapi import error_responses from app.features.content.schemas import FAQResponse -router = APIRouter(tags=["content"], responses={404: {"description": "资源不存在"}}) +router = APIRouter(tags=["content"], responses=error_responses(404)) _STATIC_DIR = Path(__file__).resolve().parent.parent.parent.parent / "static" @@ -135,302 +136,19 @@ async def get_faqs(): return FAQS -@router.get("/api/legal/terms", response_class=HTMLResponse) +@router.get("/api/legal/terms") async def get_terms(): """用户协议页面""" - html_content = """ - - - - - - 用户协议 - 岁月时书 - - - -
-

用户协议

-
岁月时书用户服务协议
- -
- 服务提供方:上海华嘎科技有限公司
- 产品名称:岁月时书
- 生效日期:2026年1月27日 -
- -

一、协议的接受

-

欢迎使用岁月时书(以下简称"本服务")。本协议是您与上海华嘎科技有限公司(以下简称"我们"或"公司")之间关于使用本服务的法律协议。

-

请您仔细阅读本协议的全部内容,特别是涉及免除或限制责任的条款、法律适用和争议解决条款。当您点击"同意"按钮或实际使用本服务时,即表示您已充分阅读、理解并同意接受本协议的全部内容。

- -

二、服务说明

-

1. 岁月时书是一款帮助用户记录和整理人生回忆的智能应用服务。

-

2. 我们有权根据业务发展需要调整、变更或终止部分或全部服务内容。

-

3. 我们保留随时修改或中断服务而不需通知用户的权利。

- -

三、用户账户

-

1. 您需要注册账户才能使用本服务的部分功能。注册时,您应当提供真实、准确、完整的个人信息。

-

2. 您有责任维护账户信息的安全性和准确性,并对账户下的所有活动负责。

-

3. 如发现账户被盗用或存在安全漏洞,请立即通知我们。

- -

四、用户行为规范

-

1. 您在使用本服务时,应当遵守国家法律法规,不得利用本服务从事违法违规活动。

-

2. 您不得上传、发布、传播含有以下内容的信息:

-

- (1)违反国家法律法规、危害国家安全、破坏社会稳定的内容;
- (2)侵犯他人知识产权、隐私权、名誉权等合法权益的内容;
- (3)色情、暴力、赌博、诈骗等不良信息;
- (4)其他违反公序良俗的内容。 -

-

3. 您应当尊重他人的合法权益,不得恶意干扰、破坏本服务的正常运行。

- -

五、知识产权

-

1. 本服务的所有知识产权,包括但不限于商标、专利、著作权等,均归我们所有。

-

2. 您在使用本服务过程中产生的内容(包括但不限于文字、图片、音频等),其知识产权归您所有,但您授予我们在提供服务所必需的范围内使用这些内容的权利。

-

3. 未经我们书面许可,您不得以任何形式复制、传播、展示、镜像、上传、下载本服务的任何内容。

- -

六、隐私保护

-

我们非常重视您的隐私保护。关于我们如何收集、使用、存储和保护您的个人信息,请详细阅读我们的《隐私政策》。

- -

七、免责声明

-

1. 本服务基于现有技术和条件提供,我们不对服务的及时性、准确性、完整性、可靠性作任何明示或暗示的保证。

-

2. 因不可抗力、计算机病毒、黑客攻击、系统不稳定、用户设备故障等原因导致的服务中断或数据丢失,我们不承担责任。

-

3. 您因使用本服务而产生的任何直接或间接损失,我们均不承担责任。

- -

八、服务变更与终止

-

1. 我们有权根据业务发展需要,随时变更、中断或终止部分或全部服务。

-

2. 如您违反本协议,我们有权立即终止向您提供服务,并保留追究法律责任的权利。

-

3. 服务终止后,您账户内的数据可能被删除,请您提前备份重要数据。

- -

九、协议修改

-

我们有权随时修改本协议。协议修改后,我们会在相关页面公布修改后的协议内容。如您不同意修改后的协议,请停止使用本服务;如您继续使用,则视为接受修改后的协议。

- -

十、法律适用与争议解决

-

1. 本协议的订立、生效、解释、履行和争议解决均适用中华人民共和国大陆地区法律法规。

-

2. 如因本协议产生任何争议,双方应友好协商解决;协商不成的,任何一方均可向我们住所地有管辖权的人民法院提起诉讼。

- -

十一、其他

-

1. 如本协议的任何条款被认定为无效或不可执行,不影响其他条款的效力。

-

2. 本协议的标题仅为方便阅读而设,不影响本协议任何条款的含义或解释。

-

3. 如您对本协议有任何疑问,可通过我们提供的联系方式与我们联系。

- -
- 最后更新时间:2026年1月27日 -
-
- - - """ - return HTMLResponse(content=html_content) + return FileResponse(_STATIC_DIR / "legal" / "terms.html", media_type="text/html") -@router.get("/api/legal/privacy", response_class=HTMLResponse) +@router.get("/api/legal/privacy") async def get_privacy(): """隐私政策页面""" - html_content = """ - - - - - - 隐私政策 - 岁月时书 - - - -
-

隐私政策

-
岁月时书隐私保护政策
- -
- 服务提供方:上海华嘎科技有限公司
- 产品名称:岁月时书
- 生效日期:2026年1月27日 -
- -

上海华嘎科技有限公司(以下简称"我们")非常重视用户的隐私保护。本隐私政策说明了我们如何收集、使用、存储和保护您的个人信息。请您仔细阅读本隐私政策,以了解我们对您个人信息的处理方式。

- -

一、信息收集

-

为了向您提供更好的服务,我们可能会收集以下信息:

-

1. 账户信息:当您注册账户时,我们会收集您的手机号码、密码、昵称、邮箱(可选)等信息。

-

2. 设备信息:我们可能会收集您的设备型号、操作系统版本、设备标识符、IP地址等信息,用于提供更好的服务体验和保障账户安全。

-

3. 使用信息:我们会收集您使用本服务时产生的信息,包括但不限于对话记录、语音内容、文字内容、操作日志等。

-

4. 位置信息:在您授权的情况下,我们可能会收集您的位置信息,用于提供基于位置的服务。

- -

二、信息使用

-

我们收集您的个人信息主要用于以下目的:

-

1. 提供服务:使用您的信息来提供、维护、改进我们的服务,包括处理您的对话请求、生成回忆录内容等。

-

2. 账户管理:用于账户注册、登录验证、密码重置、账户安全保护等。

-

3. 客户服务:用于响应您的咨询、处理您的反馈、解决技术问题等。

-

4. 安全保护:用于检测、预防、处理欺诈、滥用、安全风险和技术问题。

-

5. 法律合规:遵守适用的法律法规、法律程序或政府要求。

-

6. 服务改进:分析用户使用情况,改进我们的产品和服务质量。

- -

三、信息存储

-

1. 存储地点:您的个人信息将存储在中华人民共和国境内。如需跨境传输,我们将严格按照相关法律法规执行。

-

2. 存储期限:我们仅在为实现本政策所述目的所必需的期间内保留您的个人信息。在您注销账户后,我们将删除或匿名化处理您的个人信息,法律法规另有规定的除外。

-

3. 安全措施:我们采用行业标准的安全技术和措施来保护您的个人信息,包括但不限于数据加密、访问控制、安全审计等。

- -

四、信息共享与披露

-

我们承诺不会向第三方出售、出租或以其他方式披露您的个人信息,但以下情况除外:

-

1. 获得您的同意:在获得您明确同意的情况下,我们可能会与第三方共享您的信息。

-

2. 服务提供商:我们可能会与为我们提供服务的第三方(如云服务提供商、数据分析服务商等)共享必要的信息,但这些第三方必须遵守严格的保密义务。

-

3. 法律要求:根据法律法规、法律程序、诉讼或政府主管部门的要求,我们可能需要披露您的个人信息。

-

4. 紧急情况:为保护我们、用户或公众的权利、财产或安全,我们可能会在必要时披露相关信息。

- -

五、您的权利

-

根据相关法律法规,您对自己的个人信息享有以下权利:

-

1. 访问权:您有权访问我们持有的您的个人信息。

-

2. 更正权:您有权要求更正不准确或不完整的个人信息。

-

3. 删除权:在特定情况下,您有权要求删除您的个人信息。

-

4. 撤回同意:您有权撤回之前给予我们的同意,但这可能影响您使用部分服务功能。

-

5. 注销账户:您有权注销您的账户。账户注销后,我们将删除或匿名化处理您的个人信息。

-

如您需要行使上述权利,请通过我们提供的联系方式与我们联系。

- -

六、未成年人保护

-

我们非常重视对未成年人个人信息的保护。如果您是18周岁以下的未成年人,请在您的监护人同意和指导下使用本服务。如果我们发现自己在未事先获得可证实的监护人同意的情况下收集了未成年人的个人信息,我们会设法尽快删除相关数据。

- -

七、Cookie和类似技术

-

我们可能会使用Cookie和类似技术来收集信息、改善用户体验、分析服务使用情况等。您可以通过浏览器设置管理Cookie,但请注意,禁用Cookie可能会影响部分服务功能的使用。

- -

八、第三方服务

-

我们的服务可能包含指向第三方网站、产品和服务的链接。我们不对这些第三方的隐私做法负责,建议您仔细阅读这些第三方的隐私政策。

- -

九、隐私政策的更新

-

我们可能会不时更新本隐私政策。更新后,我们会在相关页面公布最新版本的隐私政策,并通过适当方式通知您。如您不同意更新后的隐私政策,请停止使用本服务;如您继续使用,则视为接受更新后的隐私政策。

- -

十、联系我们

-

如您对本隐私政策有任何疑问、意见或建议,或需要行使您的相关权利,请通过以下方式与我们联系:

-

公司名称:上海华嘎科技有限公司
- 产品名称:岁月时书

-

我们将在收到您的请求后,尽快予以回复。

- -
- 最后更新时间:2026年1月27日 -
-
- - - """ - return HTMLResponse(content=html_content) + return FileResponse(_STATIC_DIR / "legal" / "privacy.html", media_type="text/html") -@router.get("/", response_class=HTMLResponse) +@router.get("/") async def get_home(): """应用官网主页""" - html_path = _STATIC_DIR / "home.html" - return HTMLResponse(content=html_path.read_text(encoding="utf-8")) + return FileResponse(_STATIC_DIR / "home" / "index.html", media_type="text/html") diff --git a/api/app/features/conversation/constants.py b/api/app/features/conversation/constants.py new file mode 100644 index 0000000..35bf167 --- /dev/null +++ b/api/app/features/conversation/constants.py @@ -0,0 +1,5 @@ +"""Chat / 访谈产品常量 — 值来自 config/*.toml(SSOT)。""" + +from app.core.app_config import app_config + +chat = app_config.chat diff --git a/api/app/features/conversation/deps.py b/api/app/features/conversation/deps.py index f0ec578..7879859 100644 --- a/api/app/features/conversation/deps.py +++ b/api/app/features/conversation/deps.py @@ -1,18 +1,17 @@ """Conversation feature dependencies: get_conversation_service.""" from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import get_async_db from app.core.dependencies import get_object_storage from app.features.conversation.service import ConversationService from app.features.quota.deps import get_quota_service from app.features.quota.service import QuotaService from app.ports.storage import ObjectStorage +from app.core.deps_types import DbDep def get_conversation_service( - db: AsyncSession = Depends(get_async_db), + db: DbDep, quota_service: QuotaService = Depends(get_quota_service), object_storage: ObjectStorage = Depends(get_object_storage), ) -> ConversationService: diff --git a/api/app/features/conversation/history_store.py b/api/app/features/conversation/history_store.py index 1762971..3de9fc0 100644 --- a/api/app/features/conversation/history_store.py +++ b/api/app/features/conversation/history_store.py @@ -1,4 +1,12 @@ -"""Durable conversation turn persistence + Redis cache sync (feature layer).""" +"""Durable conversation turn persistence + Redis cache sync (feature layer). + +PostgreSQL is the source of truth for conversation history. Each write path +commits via ``transactional()`` first; ``_sync_redis_best_effort`` runs only +after a successful DB commit. Redis sync failures are logged as warnings and +do not roll back durable state. A brief "DB has data, cache missing" window is +expected under Redis outages; WS reconnect and ``load_canonical_history`` read +from DB and self-heal the cache on the next successful sync. +""" from __future__ import annotations @@ -10,9 +18,11 @@ from typing import Any from sqlalchemy.ext.asyncio import AsyncSession from app.core import redis as redis_core +from app.core.db import transactional from app.core.logging import get_logger from app.features.conversation import repo -from app.features.conversation.models import ConversationMessage +from app.features.conversation.lineage_schemas import DialogueLineage +from app.features.conversation.models import ConversationMessage, Segment from app.features.conversation.session_history import ( conversation_messages_to_redis_history, ) @@ -80,9 +90,9 @@ class ConversationHistoryStore: message_type="text", created_at=created_at, ) - repo.add_conversation_message(msg, self._db) - await self._touch_conversation(conversation_id, occurred_at=created_at) - await self._db.commit() + async with transactional(self._db): + repo.add_conversation_message(msg, self._db) + await self._touch_conversation(conversation_id, occurred_at=created_at) await self._sync_redis_best_effort(conversation_id) return msg.id @@ -132,16 +142,107 @@ class ConversationHistoryStore: created_at=ai_ts, memory_retrieval_trace_json=memory_retrieval_trace, ) - repo.add_conversation_message(human, self._db) - repo.add_conversation_message(ai, self._db) - await self._touch_conversation(conversation_id, occurred_at=ai_ts) - await self._db.commit() + async with transactional(self._db): + repo.add_conversation_message(human, self._db) + repo.add_conversation_message(ai, self._db) + await self._touch_conversation(conversation_id, occurred_at=ai_ts) await self._sync_redis_best_effort(conversation_id) return HumanAiTurnIds( human_message_id=str(human.id), assistant_message_id=str(ai.id), ) + async def record_human_ai_turn_with_segment( + self, + conversation_id: str, + user_message: str, + responses: list[str], + segment: Segment, + *, + user_message_timestamp: datetime | None, + is_from_voice: bool, + voice_session_id: str | None, + audio_duration_seconds: int | None, + agent_response: str, + memory_retrieval_trace: dict | None = None, + ) -> HumanAiTurnIds | None: + """Persist human/ai messages and segment metadata in one transaction.""" + if not responses: + return None + human_ts = user_message_timestamp or _utc_now() + if human_ts.tzinfo is None: + human_ts = human_ts.replace(tzinfo=timezone.utc) + ai_ts = human_ts + timedelta(microseconds=1) + human_type = "audio" if is_from_voice else "text" + segment_id = str(segment.id) + human = ConversationMessage( + id=str(uuid.uuid4()), + conversation_id=conversation_id, + role="human", + content=user_message, + message_type=human_type, + voice_session_id=voice_session_id, + duration_seconds=audio_duration_seconds + if audio_duration_seconds is not None and audio_duration_seconds > 0 + else None, + segment_id=segment_id, + created_at=human_ts, + ) + combined = AI_RESPONSE_SEGMENT_JOIN.join(responses) + ai = ConversationMessage( + id=str(uuid.uuid4()), + conversation_id=conversation_id, + role="ai", + content=combined, + message_type="text", + segment_id=segment_id, + created_at=ai_ts, + memory_retrieval_trace_json=memory_retrieval_trace, + ) + async with transactional(self._db): + repo.add_conversation_message(human, self._db) + repo.add_conversation_message(ai, self._db) + # Postgres: segments.user_message_id FK must exist before segment UPDATE; + # SQLAlchemy may otherwise flush the dirty segment row before message INSERTs. + await self._db.flush() + await self._touch_conversation(conversation_id, occurred_at=ai_ts) + segment.agent_response = agent_response + segment.user_message_id = str(human.id) + segment.lineage_json = DialogueLineage.for_single_turn( + conversation_id=conversation_id, + user_message_id=str(human.id), + assistant_message_id=str(ai.id), + segment_ids=[segment_id], + ).model_dump(mode="json") + await self._sync_redis_best_effort(conversation_id) + return HumanAiTurnIds( + human_message_id=str(human.id), + assistant_message_id=str(ai.id), + ) + + async def attach_ai_tts_for_turn( + self, + conversation_id: str, + *, + tts_audio_urls: list[str], + segment: Segment, + ) -> None: + """Update latest AI message and segment TTS URLs in one transaction.""" + if not tts_audio_urls: + return + segment_id = str(segment.id) + async with transactional(self._db): + row = await repo.set_latest_ai_message_tts_audio_urls( + conversation_id, + self._db, + tts_audio_urls=tts_audio_urls, + segment_id=segment_id, + ) + if row is None: + return + segment.tts_audio_urls = list(tts_audio_urls) + await self._sync_redis_best_effort(conversation_id) + async def attach_ai_tts_audio_urls( self, conversation_id: str, @@ -151,13 +252,13 @@ class ConversationHistoryStore: ) -> None: if not tts_audio_urls: return - row = await repo.set_latest_ai_message_tts_audio_urls( - conversation_id, - self._db, - tts_audio_urls=tts_audio_urls, - segment_id=segment_id, - ) + async with transactional(self._db): + row = await repo.set_latest_ai_message_tts_audio_urls( + conversation_id, + self._db, + tts_audio_urls=tts_audio_urls, + segment_id=segment_id, + ) if row is None: return - await self._db.commit() await self._sync_redis_best_effort(conversation_id) diff --git a/api/app/features/conversation/input_normalize.py b/api/app/features/conversation/input_normalize.py index edb6bb3..71952f0 100644 --- a/api/app/features/conversation/input_normalize.py +++ b/api/app/features/conversation/input_normalize.py @@ -9,9 +9,9 @@ from __future__ import annotations from typing import Any -from app.core.config import settings from app.core.logging import get_logger from app.core.text_normalize import apply_oral_rules, llm_normalize_text +from app.features.conversation.constants import chat logger = get_logger(__name__) @@ -23,8 +23,8 @@ def _llm_normalize_chat_input(text: str, llm: Any) -> str | None: return llm_normalize_text( text, llm, - max_input_chars=int(settings.chat_input_normalize_llm_max_input_chars), - max_tokens=int(settings.chat_input_normalize_llm_max_tokens), + max_input_chars=int(chat.input_normalize_llm_max_input_chars), + max_tokens=int(chat.input_normalize_llm_max_tokens), agent_name="chat_input_normalize.llm", ) @@ -44,9 +44,9 @@ def normalize_chat_input_for_agent( - llm:先规则,再(可选)LLM;无 llm 或失败则保留规则结果 - chat_input_normalize_llm_voice_only:mode=llm 时仅 is_from_voice 为真才调用 LLM """ - if not settings.chat_input_normalize_enabled: + if not chat.input_normalize_enabled: return text or "" - mode = (settings.chat_input_normalize_mode or "rules").strip().lower() + mode = (chat.input_normalize_mode or "rules").strip().lower() if mode == "off": return text or "" @@ -55,7 +55,7 @@ def normalize_chat_input_for_agent( return base effective_llm = llm - if settings.chat_input_normalize_llm_voice_only and not is_from_voice: + if chat.input_normalize_llm_voice_only and not is_from_voice: effective_llm = None refined = _llm_normalize_chat_input(base, effective_llm) diff --git a/api/app/features/conversation/router.py b/api/app/features/conversation/router.py index 7064c07..8fdccd4 100644 --- a/api/app/features/conversation/router.py +++ b/api/app/features/conversation/router.py @@ -2,69 +2,73 @@ 对话 feature — conversations 路由 """ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends -from app.core.dependencies import get_current_user +from app.core.deps_types import CurrentUserDep from app.core.logging import get_logger +from app.core.openapi import error_responses from app.features.conversation.deps import get_conversation_service +from app.features.conversation.schemas import ( + ConversationDetailResponse, + ConversationListItemResponse, + CreateConversationResponse, + DeleteConversationResponse, + EndConversationResponse, + MessageResponse, + OrganizeResponse, +) from app.features.conversation.service import ConversationService -from app.features.user.models import User router = APIRouter( prefix="/api/conversations", tags=["conversations"], - responses={ - 401: {"description": "认证失败"}, - 403: {"description": "权限不足"}, - 404: {"description": "资源不存在"}, - 429: {"description": "配额已用尽"}, - }, + responses=error_responses(401, 403, 404, 429), ) logger = get_logger(__name__) -@router.get("") +@router.get("", response_model=list[ConversationListItemResponse]) async def get_conversations( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: ConversationService = Depends(get_conversation_service), ): """获取当前用户的所有对话列表(需要认证)""" return await service.list_for_user(current_user.id) -@router.post("") +@router.post("", response_model=CreateConversationResponse) async def create_conversation( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: ConversationService = Depends(get_conversation_service), ): """创建新对话(需要认证)。对话轮数在每次发送消息时校验。""" return await service.create(current_user.id) -@router.get("/{conversation_id}") +@router.get("/{conversation_id}", response_model=ConversationDetailResponse) async def get_conversation( conversation_id: str, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: ConversationService = Depends(get_conversation_service), ): """获取对话详情(需要认证,只能访问自己的对话)""" return await service.get_one(conversation_id, current_user.id) -@router.post("/{conversation_id}/end") +@router.post("/{conversation_id}/end", response_model=EndConversationResponse) async def end_conversation( conversation_id: str, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: ConversationService = Depends(get_conversation_service), ): """结束对话(需要认证,只能结束自己的对话)""" return await service.end(conversation_id, current_user.id) -@router.delete("/{conversation_id}") +@router.delete("/{conversation_id}", response_model=DeleteConversationResponse) async def delete_conversation( conversation_id: str, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: ConversationService = Depends(get_conversation_service), ): """删除对话(需要认证,只能删除自己的对话)""" @@ -72,36 +76,28 @@ async def delete_conversation( return {"message": "对话已删除"} -@router.get("/{conversation_id}/messages") +@router.get("/{conversation_id}/messages", response_model=list[MessageResponse]) async def get_messages( conversation_id: str, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: ConversationService = Depends(get_conversation_service), ): """获取对话的消息列表(需要认证,只能访问自己的对话)""" return await service.get_messages(conversation_id, current_user.id) -@router.post("/{conversation_id}/organize") +@router.post("/{conversation_id}/organize", response_model=OrganizeResponse) async def organize_conversation( conversation_id: str, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: ConversationService = Depends(get_conversation_service), ): """ 整理对话内容成章节(需要认证,只能整理自己的对话) 手动触发对话整理,将对话中的内容整理成回忆录章节 """ - try: - return await service.organize( - conversation_id, - current_user.id, - current_user.subscription_type, - ) - except HTTPException: - raise - except Exception as e: - logger.exception("提交整理任务失败: {}", e) - raise HTTPException( - status_code=500, detail="提交整理任务失败,请稍后重试" - ) from e + return await service.organize( + conversation_id, + current_user.id, + current_user.subscription_type, + ) diff --git a/api/app/features/conversation/schemas.py b/api/app/features/conversation/schemas.py index 24e3749..c5e733c 100644 --- a/api/app/features/conversation/schemas.py +++ b/api/app/features/conversation/schemas.py @@ -1,7 +1,55 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field class CreateConversationResponse(BaseModel): id: str + user_id: str started_at: str status: str + + +class ConversationListItemResponse(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + id: str + title: str + avatar_url: str | None = Field(None, alias="avatarUrl") + latest_message_preview: str | None = Field(None, alias="latestMessagePreview") + latest_message_time: int = Field(alias="latestMessageTime") + started_at: int = Field(alias="startedAt") + unread_count: int = Field(0, alias="unreadCount") + is_default_assistant: bool = Field(alias="isDefaultAssistant") + has_user_message: bool = Field(alias="hasUserMessage") + + +class ConversationDetailResponse(BaseModel): + id: str + user_id: str + started_at: str + ended_at: str | None = None + duration_seconds: int | None = None + summary: str | None = None + status: str + current_topic: str | None = None + conversation_stage: str | None = None + + +class EndConversationResponse(BaseModel): + id: str + status: str + ended_at: str + duration_seconds: int | None = None + + +class DeleteConversationResponse(BaseModel): + message: str + + +class MessageResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + +class OrganizeResponse(BaseModel): + message: str + conversation_id: str + segments_count: int diff --git a/api/app/features/conversation/service.py b/api/app/features/conversation/service.py index aeb21a5..d8c568b 100644 --- a/api/app/features/conversation/service.py +++ b/api/app/features/conversation/service.py @@ -4,7 +4,6 @@ import asyncio import uuid from datetime import datetime, timezone -from fastapi import HTTPException from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat.personas import agent_name @@ -13,10 +12,18 @@ from app.core.cos_url_keys import ( collect_cos_keys_from_tts_url_list, extract_cos_object_key_if_owned, ) +from app.core.db import transactional +from app.core.errors import ( + AuthorizationError, + BadRequestError, + NotFoundError, + QuotaExceededError, +) from app.core.logging import get_logger from app.core.redis import redis_service from app.core.storage_purge import delete_object_storage_keys_best_effort from app.features.conversation import repo +from app.features.conversation.history_store import ConversationHistoryStore from app.features.conversation.models import Conversation, Segment from app.features.conversation.session_history import ( conversation_messages_to_redis_history, @@ -132,8 +139,8 @@ class ConversationService: started_at=datetime.now(timezone.utc), status="active", ) - self._db.add(conv) - await self._db.commit() + async with transactional(self._db): + self._db.add(conv) await self._db.refresh(conv) return conv, "" if conv.user_id != user_id: @@ -152,7 +159,7 @@ class ConversationService: audio_duration_seconds: int | None = None, ) -> Segment: if conversation.user_id != user_id: - raise ValueError("conversation ownership mismatch") + raise AuthorizationError("无权访问此对话") segment = Segment( id=str(uuid.uuid4()), conversation_id=conversation.id, @@ -161,9 +168,9 @@ class ConversationService: audio_duration_seconds=audio_duration_seconds, processed=False, ) - self._db.add(segment) - conversation.last_message_at = datetime.now(timezone.utc) - await self._db.commit() + async with transactional(self._db): + self._db.add(segment) + conversation.last_message_at = datetime.now(timezone.utc) await self._db.refresh(segment) return segment @@ -183,6 +190,10 @@ class ConversationService: logger.warning("conversation history cache read skipped: {}", exc) history = [] if history: + try: + await redis_service.extend_session_ttl(conversation_id) + except Exception as exc: + logger.debug("conversation history ttl extend skipped: {}", exc) return history rows = await repo.get_conversation_messages(conversation_id, self._db) @@ -196,6 +207,13 @@ class ConversationService: return [] + async def record_ai_only_turn( + self, conversation_id: str, texts: list[str] + ) -> str | None: + return await ConversationHistoryStore(self._db).record_ai_only_turn( + conversation_id, texts + ) + async def list_for_user(self, user_id: str) -> list[dict]: conversations = await repo.get_user_conversations(user_id, self._db) # Fetch language once for fallback title localization (no per-row N+1). @@ -235,8 +253,8 @@ class ConversationService: started_at=datetime.now(timezone.utc), status="active", ) - repo.add_conversation(conv, self._db) - await self._db.commit() + async with transactional(self._db): + repo.add_conversation(conv, self._db) await self._db.refresh(conv) return { "id": conv.id, @@ -248,7 +266,7 @@ class ConversationService: async def get_or_404(self, conversation_id: str, user_id: str) -> Conversation: conv = await repo.get_conversation(conversation_id, self._db) if not conv or conv.user_id != user_id or conv.deleted_at is not None: - raise HTTPException(status_code=404, detail="Conversation not found") + raise NotFoundError("Conversation not found") return conv async def get_one(self, conversation_id: str, user_id: str) -> dict: @@ -267,13 +285,13 @@ class ConversationService: async def end(self, conversation_id: str, user_id: str) -> dict: conv = await self.get_or_404(conversation_id, user_id) - conv.status = "ended" - conv.ended_at = datetime.now(timezone.utc) - if conv.started_at: - conv.duration_seconds = int( - (conv.ended_at - conv.started_at).total_seconds() - ) - await self._db.commit() + async with transactional(self._db): + conv.status = "ended" + conv.ended_at = datetime.now(timezone.utc) + if conv.started_at: + conv.duration_seconds = int( + (conv.ended_at - conv.started_at).total_seconds() + ) return { "id": conv.id, "status": conv.status, @@ -305,8 +323,8 @@ class ConversationService: ) await self._clear_history(conversation_id) - conv.deleted_at = datetime.now(timezone.utc) - await self._db.commit() + async with transactional(self._db): + conv.deleted_at = datetime.now(timezone.utc) delete_object_storage_keys_best_effort( self._object_storage, @@ -328,6 +346,22 @@ class ConversationService: except Exception: return [] + async def align_conversation_stage_from_memoir( + self, conversation: Conversation, memoir_stage: str + ) -> None: + """Align conversation_stage with memoir state without regressing stage order.""" + from app.agents.stage_constants import STAGE_TO_ORDER + + ms = (memoir_stage or "").strip() + if not ms: + return + cs = (conversation.conversation_stage or "").strip() + async with transactional(self._db): + if not cs: + conversation.conversation_stage = ms + elif STAGE_TO_ORDER.get(ms, -1) >= STAGE_TO_ORDER.get(cs, -1): + conversation.conversation_stage = ms + async def organize( self, conversation_id: str, user_id: str, subscription_type: str ) -> dict: @@ -335,12 +369,12 @@ class ConversationService: pending_p1 = await repo.get_segments_pending_phase1(conversation_id, self._db) has_p2 = await repo.conversation_has_pending_phase2(conversation_id, self._db) if not pending_p1 and not has_p2: - raise HTTPException(status_code=400, detail="该对话没有可整理的内容") + raise BadRequestError("该对话没有可整理的内容") can_submit, quota_message = await self._quota.check_can_submit_organize( user_id, subscription_type ) if not can_submit: - raise HTTPException(status_code=403, detail=quota_message) + raise QuotaExceededError(quota_message) if pending_p1: segment_ids = [s.id for s in pending_p1] process_memoir_phase1.delay(conv.user_id, segment_ids) diff --git a/api/app/features/conversation/ws/connection_manager.py b/api/app/features/conversation/ws/connection_manager.py index 43ddd05..2ac0cb9 100644 --- a/api/app/features/conversation/ws/connection_manager.py +++ b/api/app/features/conversation/ws/connection_manager.py @@ -2,8 +2,9 @@ from typing import Dict -from fastapi import HTTPException, WebSocket +from fastapi import WebSocket +from app.core.errors import NotFoundError from app.core.logging import get_logger logger = get_logger(__name__) @@ -43,7 +44,7 @@ class ConnectionManager: if conversation_id in self.active_connections: websocket = self.active_connections[conversation_id] return await websocket.receive_json() - raise HTTPException(status_code=404, detail="Connection not found") + raise NotFoundError("Connection not found") manager = ConnectionManager() diff --git a/api/app/features/conversation/ws/persist.py b/api/app/features/conversation/ws/persist.py new file mode 100644 index 0000000..c0e20a1 --- /dev/null +++ b/api/app/features/conversation/ws/persist.py @@ -0,0 +1,55 @@ +"""WS pipeline 分段持久化(语音分段、按需 TTS 等独立 commit 场景)。 + +Each helper here opens its own ``transactional()`` on the long-lived WS +``AsyncSession`` (see ``app.core.db``): multiple commits per connection are +intentional for incremental durability. + +``ConversationHistoryStore`` handles whole human/AI turn writes and refreshes +Redis after those commits. Pipeline code may interleave ``persist.py`` segment +commits with turn-level ``history_store`` writes; readers should treat DB as +authoritative when cache and DB diverge briefly. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Optional + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.db import transactional +from app.features.conversation.models import Conversation, ConversationMessage, Segment + + +def mark_conversation_active( + conversation: Conversation, at: Optional[datetime] = None +) -> datetime: + activity_time = at or datetime.now(timezone.utc) + conversation.last_message_at = activity_time + return activity_time + + +async def persist_message_tts_url_segment( + db: AsyncSession, + msg: ConversationMessage, + segment_index: int, + url_stored: str, +) -> None: + """按需 TTS:写入 message.tts_audio_urls[segment_index] 并提交。""" + urls = list(msg.tts_audio_urls or []) + while len(urls) <= segment_index: + urls.append("") + urls[segment_index] = url_stored + async with transactional(db): + msg.tts_audio_urls = urls + + +async def persist_voice_segment_row( + db: AsyncSession, + segment: Segment, + conversation: Conversation, +) -> None: + """语音分段入库并刷新会话活跃时间。""" + async with transactional(db): + db.add(segment) + mark_conversation_active(conversation) diff --git a/api/app/features/conversation/ws/pipeline.py b/api/app/features/conversation/ws/pipeline.py index 0745491..71c2776 100644 --- a/api/app/features/conversation/ws/pipeline.py +++ b/api/app/features/conversation/ws/pipeline.py @@ -14,7 +14,7 @@ from app.core.logging import get_logger if TYPE_CHECKING: from app.features.quota.service import QuotaService -from sqlalchemy import select, update +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat import ChatOrchestrator @@ -27,6 +27,10 @@ from app.core.cos_url_keys import ( extract_cos_object_key_if_owned, ) from app.core.db import AsyncSessionLocal +from app.features.conversation.ws.persist import ( + persist_message_tts_url_segment, + persist_voice_segment_row, +) from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider from app.features.conversation.chat_turn import ( ChatTurnContext, @@ -37,7 +41,6 @@ from app.features.conversation.history_store import ( AI_RESPONSE_SEGMENT_JOIN, ConversationHistoryStore, ) -from app.features.conversation.lineage_schemas import DialogueLineage from app.features.conversation.models import Conversation, ConversationMessage, Segment from app.features.conversation.ws.connection_manager import manager from app.features.conversation.ws.message_types import MessageType @@ -47,11 +50,12 @@ from app.features.conversation.ws.profile_collector import ( get_missing_profile_fields, ) from app.features.conversation.ws.topic_chips_push import maybe_send_topic_chips_ws -from app.features.memoir.state_service import get_or_create_state from app.features.memoir.background_runner import BackgroundTaskRunner -from app.features.memoir.ingest_scheduler import MemoirIngestScheduler +from app.features.memoir.ingest_scheduler import MemoirIngestScheduler, MemoirTrigger +from app.features.memoir.state_service import get_or_create_state from app.features.user.models import User from app.ports.asr import ASRTranscriptionError +from app.core.runtime_constants import tts_defaults logger = get_logger(__name__) @@ -118,20 +122,20 @@ async def _send_tts_audio( chunk_index, language, (text or "")[:30], - settings.tts_provider, + tts_defaults.provider, ) return None if _tts_epoch_value(conversation_id) != tts_epoch_start: return None - ext = _tts_object_ext(settings.tts_codec) - content_type = _tts_codec_to_content_type(settings.tts_codec) + ext = _tts_object_ext(tts_defaults.codec) + content_type = _tts_codec_to_content_type(tts_defaults.codec) storage = get_object_storage() key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}" public_url = storage.upload(key, audio_bytes, content_type) # 与 `tts_delivery.apply_presigned_tts_urls_to_messages` / 回忆录图片 presign 一致:下发可播 URL playback_url = storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC) payload_data: Dict[str, Any] = { - "format": settings.tts_codec, + "format": tts_defaults.codec, "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"), "audio_url": playback_url, "index": chunk_index, @@ -182,7 +186,7 @@ async def handle_tts_request_on_demand( segment_index, len(segment_text or ""), settings.enable_tts, - settings.tts_provider, + tts_defaults.provider, ) conv = await db.get(Conversation, conversation_id) @@ -260,7 +264,7 @@ async def handle_tts_request_on_demand( "conversation_id": conversation_id, "data": { "audio_url": playback_url, - "format": settings.tts_codec, + "format": tts_defaults.codec, "index": segment_index, "total": chunk_total, "assistant_message_id": assistant_message_id, @@ -319,11 +323,7 @@ async def handle_tts_request_on_demand( ) return False, "语音合成失败" - while len(urls) <= segment_index: - urls.append("") - urls[segment_index] = url_stored - msg.tts_audio_urls = urls - await db.commit() + await persist_message_tts_url_segment(db, msg, segment_index, url_stored) store = ConversationHistoryStore(db) await store._sync_redis_best_effort(conversation_id) @@ -344,6 +344,24 @@ _background_runner = BackgroundTaskRunner() memoir_ingest_scheduler = MemoirIngestScheduler(_background_runner) +async def _schedule_memoir_ingest_for_segment( + user_id: str, + segment: Segment, + *, + trigger: MemoirTrigger = "turn", +) -> None: + """Queue memoir phase1 after segment text (and ideally lineage) is durable.""" + text = (segment.user_input_text or "").strip() + if not text: + return + await memoir_ingest_scheduler.queue_segment( + user_id, + str(segment.id), + text_char_count=len(text), + trigger=trigger, + ) + + # ── 分段流状态 ────────────────────────────────────────────────── @@ -457,14 +475,6 @@ def _utc_now() -> datetime: return datetime.now(timezone.utc) -def _mark_conversation_active( - conversation: Conversation, at: Optional[datetime] = None -) -> datetime: - activity_time = at or _utc_now() - conversation.last_message_at = activity_time - return activity_time - - def _voice_session_id_from_client_segment_id( client_segment_id: Optional[str], ) -> Optional[str]: @@ -825,15 +835,9 @@ async def process_audio_segment( else None, processed=False, ) - db.add(segment) - user_message_timestamp = _mark_conversation_active(conversation) - await db.commit() + await persist_voice_segment_row(db, segment, conversation) + user_message_timestamp = conversation.last_message_at await db.refresh(segment) - await memoir_ingest_scheduler.queue_segment( - conversation.user_id, - segment.id, - text_char_count=len((transcript_text or "").strip()), - ) ready_segments: List[Tuple[int, str, Segment]] = [] tts_flag_this_voice_session = False @@ -943,6 +947,8 @@ async def process_user_message( *, force_skip_tts: bool = False, tts_this_turn: Optional[bool] = None, + memoir_trigger: MemoirTrigger = "turn", + schedule_memoir: bool = True, ) -> None: """处理用户消息,生成 Agent 回应。由 ChatOrchestrator 路由到 ProfileAgent 或 InterviewAgent。""" with business_span("conversation.ws.process_turn"): @@ -956,6 +962,8 @@ async def process_user_message( user_message_timestamp, force_skip_tts=force_skip_tts, tts_this_turn=tts_this_turn, + memoir_trigger=memoir_trigger, + schedule_memoir=schedule_memoir, ) @@ -970,6 +978,8 @@ async def _process_user_message_inner( *, force_skip_tts: bool = False, tts_this_turn: Optional[bool] = None, + memoir_trigger: MemoirTrigger = "turn", + schedule_memoir: bool = True, ) -> None: store = ConversationHistoryStore(db) tts_urls: list[str] = [] @@ -1022,18 +1032,17 @@ async def _process_user_message_inner( want_tts, ) - segment.agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses) - _mark_conversation_active(conversation) - turn_ids = await store.record_human_ai_turn( + agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses) + turn_ids = await store.record_human_ai_turn_with_segment( conversation_id=conversation_id, user_message=user_message, responses=responses, + segment=segment, user_message_timestamp=user_message_timestamp, is_from_voice=is_from_voice, voice_session_id=voice_session_id, audio_duration_seconds=audio_dur, - tts_audio_urls=None, - segment_id=segment.id, + agent_response=agent_response, memory_retrieval_trace=turn.memory_retrieval_trace, ) if not turn_ids: @@ -1053,23 +1062,22 @@ async def _process_user_message_inner( "timestamp": datetime.now(timezone.utc).isoformat(), }, ) + owner_id = (user.id if user is not None else None) or conversation.user_id + if schedule_memoir: + await _schedule_memoir_ingest_for_segment( + owner_id, + segment, + trigger=memoir_trigger, + ) return - lineage = DialogueLineage.for_single_turn( - conversation_id=conversation_id, - user_message_id=turn_ids.human_message_id, - assistant_message_id=turn_ids.assistant_message_id, - segment_ids=[str(segment.id)], - ) - await db.execute( - update(Segment) - .where(Segment.id == segment.id) - .values( - user_message_id=turn_ids.human_message_id, - lineage_json=lineage.model_dump(mode="json"), + owner_id = (user.id if user is not None else None) or conversation.user_id + if schedule_memoir: + await _schedule_memoir_ingest_for_segment( + owner_id, + segment, + trigger=memoir_trigger, ) - ) - await db.commit() ai_msg_id = turn_ids.assistant_message_id tts_epoch_start = _tts_epoch_value(conversation_id) @@ -1156,32 +1164,20 @@ async def _process_user_message_inner( logger.warning("after-turn topic chips skipped: {}", chip_err) if tts_urls: - await store.attach_ai_tts_audio_urls( + await store.attach_ai_tts_for_turn( conversation_id, tts_audio_urls=tts_urls, - segment_id=segment.id, + segment=segment, ) - await db.execute( - update(Segment) - .where(Segment.id == segment.id) - .values(tts_audio_urls=tts_urls) - ) - await db.commit() except Exception as e: if tts_urls: try: - await store.attach_ai_tts_audio_urls( + await store.attach_ai_tts_for_turn( conversation_id, tts_audio_urls=tts_urls, - segment_id=segment.id, + segment=segment, ) - await db.execute( - update(Segment) - .where(Segment.id == segment.id) - .values(tts_audio_urls=tts_urls) - ) - await db.commit() except Exception as persist_error: logger.warning("补写 TTS 元数据失败: {}", persist_error) logger.exception("处理用户消息失败: {}", e) diff --git a/api/app/features/conversation/ws/profile_collector.py b/api/app/features/conversation/ws/profile_collector.py index 2ad2376..40e75cb 100644 --- a/api/app/features/conversation/ws/profile_collector.py +++ b/api/app/features/conversation/ws/profile_collector.py @@ -2,6 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession +from app.core.db import transactional from app.features.user.models import User @@ -37,17 +38,21 @@ async def apply_extracted_profile(user: User, extracted: dict, db: AsyncSession) """将提取到的资料信息保存到用户模型""" changed = False if "birth_year" in extracted and not user.birth_year: - user.birth_year = extracted["birth_year"] changed = True if "birth_place" in extracted and not user.birth_place: - user.birth_place = extracted["birth_place"] changed = True if "grew_up_place" in extracted and not user.grew_up_place: - user.grew_up_place = extracted["grew_up_place"] changed = True if "occupation" in extracted and not user.occupation: - user.occupation = extracted["occupation"] changed = True if changed: - await db.commit() + async with transactional(db): + if "birth_year" in extracted and not user.birth_year: + user.birth_year = extracted["birth_year"] + if "birth_place" in extracted and not user.birth_place: + user.birth_place = extracted["birth_place"] + if "grew_up_place" in extracted and not user.grew_up_place: + user.grew_up_place = extracted["grew_up_place"] + if "occupation" in extracted and not user.occupation: + user.occupation = extracted["occupation"] await db.refresh(user) diff --git a/api/app/features/conversation/ws/router.py b/api/app/features/conversation/ws/router.py index 16f3fda..44f78fe 100644 --- a/api/app/features/conversation/ws/router.py +++ b/api/app/features/conversation/ws/router.py @@ -12,13 +12,11 @@ from starlette.websockets import WebSocketState from app.agents.chat.background_voice import infer_background_voice from app.agents.chat.prompts_profile import format_user_profile_context -from app.agents.stage_constants import STAGE_TO_ORDER from app.core.config import settings from app.core.db import AsyncSessionLocal from app.core.dependencies import get_asr_provider from app.core.logging import get_logger from app.core.security import verify_token -from app.features.conversation.history_store import ConversationHistoryStore from app.features.conversation.service import ConversationService from app.features.conversation.ws.connection_manager import manager from app.features.conversation.ws.message_types import MessageType @@ -30,7 +28,6 @@ from app.features.conversation.ws.pipeline import ( cleanup_segment_states, get_or_create_segment_state, handle_tts_request_on_demand, - memoir_ingest_scheduler, process_audio_segment, process_conversation_segments, process_persisted_user_segment_response, @@ -40,9 +37,10 @@ from app.features.conversation.ws.pipeline import ( from app.features.conversation.ws.profile_collector import get_missing_profile_fields from app.features.conversation.ws.quota_guard import check_ws_quota from app.features.conversation.ws.topic_chips_push import maybe_send_topic_chips_ws -from app.features.memoir.state_service import get_or_create_state +from app.features.memoir.service import MemoirService from app.features.quota.service import QuotaService -from app.features.user.models import User +from app.features.user.service import UserService +from app.features.conversation.constants import chat logger = get_logger(__name__) @@ -92,7 +90,8 @@ async def websocket_endpoint( return async with AsyncSessionLocal() as db: - user = await db.get(User, user_id) + user_service = UserService(db) + user = await user_service.get_by_id(user_id) if not user: await websocket.close( code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在" @@ -108,6 +107,7 @@ async def websocket_endpoint( quota_service = QuotaService(db=db) conversation_service = ConversationService(db=db, quota_service=quota_service) + memoir_service = MemoirService(db=db) try: await manager.send_message( @@ -158,15 +158,10 @@ async def websocket_endpoint( # 冷启动对齐 conversation_stage 与 MemoirState.current_stage; # 若对话行已有更靠前的人生阶段(STAGE_TO_ORDER 更大),不覆盖以免回退。 - memoir_state = await get_or_create_state(user_id, db) - ms = (memoir_state.current_stage or "").strip() - cs = (conversation.conversation_stage or "").strip() - if ms: - if not cs: - conversation.conversation_stage = ms - elif STAGE_TO_ORDER.get(ms, -1) >= STAGE_TO_ORDER.get(cs, -1): - conversation.conversation_stage = ms - await db.commit() + memoir_state = await memoir_service.get_or_create_memoir_state(user_id) + await conversation_service.align_conversation_stage_from_memoir( + conversation, memoir_state.current_stage or "" + ) await db.refresh(conversation) history = await conversation_service.ensure_redis_history_from_db( @@ -184,7 +179,7 @@ async def websocket_endpoint( """统一:把一组 AI 消息落库并按 [SPLIT] 分段下发。""" if not texts: return - ai_msg_id = await ConversationHistoryStore(db).record_ai_only_turn( + ai_msg_id = await conversation_service.record_ai_only_turn( conversation_id, texts ) if not ai_msg_id: @@ -271,9 +266,9 @@ async def websocket_endpoint( else: # 历史非空:判断是否需要回访问候(距上次消息超过阈值) idle_hours = _idle_hours_since(conversation.last_message_at) - threshold = float(settings.chat_re_greeting_idle_hours) + threshold = float(chat.re_greeting_idle_hours) if ( - settings.chat_re_greeting_enabled + chat.re_greeting_enabled and not get_missing_profile_fields(user) and idle_hours is not None and idle_hours >= threshold @@ -383,11 +378,6 @@ async def websocket_endpoint( user_id, text_message, ) - await memoir_ingest_scheduler.queue_segment( - conversation.user_id, - segment.id, - text_char_count=len(text_message.strip()), - ) task = asyncio.create_task( process_persisted_user_segment_response( @@ -645,11 +635,6 @@ async def websocket_endpoint( audio_duration_seconds=ads if ads > 0 else None, ) ) - await memoir_ingest_scheduler.queue_segment( - conversation.user_id, - segment.id, - text_char_count=len((asr_text or "").strip()), - ) if asr_text and not asr_text.startswith("转写失败"): task = asyncio.create_task( diff --git a/api/app/features/conversation/ws/topic_chips_push.py b/api/app/features/conversation/ws/topic_chips_push.py index e48ffb8..100b3f6 100644 --- a/api/app/features/conversation/ws/topic_chips_push.py +++ b/api/app/features/conversation/ws/topic_chips_push.py @@ -14,6 +14,7 @@ from app.features.conversation.ws.connection_manager import manager from app.features.conversation.ws.message_types import MessageType from app.features.conversation.ws.profile_collector import get_missing_profile_fields from app.features.user.models import User +from app.features.conversation.constants import chat log = get_logger(__name__) @@ -27,7 +28,7 @@ async def maybe_send_topic_chips_ws( language: str, ) -> None: """资料齐备且开关开启时,按当前回忆录阶段下发 topic_suggestions。失败静默。""" - if not settings.chat_topic_chips_enabled: + if not chat.topic_chips_enabled: return if get_missing_profile_fields(user): return @@ -41,7 +42,7 @@ async def maybe_send_topic_chips_ws( chips = build_topic_chips( stage, empty_slots, - max_chips=settings.chat_topic_chips_max, + max_chips=chat.topic_chips_max, language=language, ) if not chips: diff --git a/api/app/features/evaluation/constants.py b/api/app/features/evaluation/constants.py new file mode 100644 index 0000000..3170c60 --- /dev/null +++ b/api/app/features/evaluation/constants.py @@ -0,0 +1,5 @@ +"""Internal evaluation / judge 产品常量 — 值来自 config/*.toml(SSOT)。""" + +from app.core.app_config import app_config + +eval_cfg = app_config.eval diff --git a/api/app/features/evaluation/deps.py b/api/app/features/evaluation/deps.py index fa0860b..daa1628 100644 --- a/api/app/features/evaluation/deps.py +++ b/api/app/features/evaluation/deps.py @@ -3,9 +3,8 @@ from __future__ import annotations from typing import Annotated from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import get_async_db +from app.core.deps_types import DbDep from app.features.evaluation.admin_service import EvaluationAdminService from app.features.evaluation.judge_manual_service import EvalJudgeManualService from app.features.evaluation.memoir_readiness_service import MemoirReadinessService @@ -14,26 +13,26 @@ from app.features.quota.deps import get_quota_service from app.features.quota.service import QuotaService -def get_evaluation_admin_service( - db: Annotated[AsyncSession, Depends(get_async_db)], -) -> EvaluationAdminService: +def get_session_catalog_service(db: DbDep) -> "SessionCatalogService": + from app.features.evaluation.session_catalog_service import SessionCatalogService + + return SessionCatalogService(db) + + +def get_evaluation_admin_service(db: DbDep) -> EvaluationAdminService: return EvaluationAdminService(db) def get_replay_conversation_service( - db: Annotated[AsyncSession, Depends(get_async_db)], + db: DbDep, quota: Annotated[QuotaService, Depends(get_quota_service)], ) -> ReplayConversationService: return ReplayConversationService(db, quota) -def get_eval_judge_manual_service( - db: Annotated[AsyncSession, Depends(get_async_db)], -) -> EvalJudgeManualService: +def get_eval_judge_manual_service(db: DbDep) -> EvalJudgeManualService: return EvalJudgeManualService(db) -def get_memoir_readiness_service( - db: Annotated[AsyncSession, Depends(get_async_db)], -) -> MemoirReadinessService: +def get_memoir_readiness_service(db: DbDep) -> MemoirReadinessService: return MemoirReadinessService(db) diff --git a/api/app/features/evaluation/errors.py b/api/app/features/evaluation/errors.py index 67e9e16..00c4d2f 100644 --- a/api/app/features/evaluation/errors.py +++ b/api/app/features/evaluation/errors.py @@ -1,13 +1,21 @@ -"""评测 API 领域异常(由 router 映射为 HTTP 状态码)。""" +"""评测 API 领域异常(继承 AppError,由全局 handler 统一映射)。""" + +from app.core.errors import BadRequestError, NotFoundError -class EvaluationNotFoundError(Exception): +class EvaluationNotFoundError(NotFoundError): def __init__(self, detail: str = "not found") -> None: - self.detail = detail super().__init__(detail) + @property + def detail(self) -> str: + return self.message -class EvaluationBadRequestError(Exception): + +class EvaluationBadRequestError(BadRequestError): def __init__(self, detail: str) -> None: - self.detail = detail super().__init__(detail) + + @property + def detail(self) -> str: + return self.message diff --git a/api/app/features/evaluation/eval_trace_format.py b/api/app/features/evaluation/eval_trace_format.py index 97ee0f6..48e8387 100644 --- a/api/app/features/evaluation/eval_trace_format.py +++ b/api/app/features/evaluation/eval_trace_format.py @@ -2,14 +2,15 @@ from __future__ import annotations -from app.core.config import settings from app.features.conversation.models import Segment +from app.features.evaluation.constants import eval_cfg from app.features.evaluation.eval_trace_schemas import ( ChapterEvidenceBundle, EvidenceFormatMeta, FormattedMemoirEvidence, StoryEvidenceBundle, ) +from app.features.memoir.constants import memoir from app.features.memory.models import ( MemoryChunk, MemoryFact, @@ -20,7 +21,7 @@ from app.features.memory.models import ( def _memoir_evidence_char_cap() -> int: """与 ``Settings.eval_judge_memoir_evidence_max_chars`` 对齐。""" - return max(1000, int(settings.eval_judge_memoir_evidence_max_chars)) + return max(1000, int(eval_cfg.judge_memoir_evidence_max_chars)) def _approx_tokens(chars: int) -> int: diff --git a/api/app/features/evaluation/gating_service.py b/api/app/features/evaluation/gating_service.py index cd51867..3705fb0 100644 --- a/api/app/features/evaluation/gating_service.py +++ b/api/app/features/evaluation/gating_service.py @@ -7,6 +7,7 @@ from typing import Any from app.core.config import settings from app.features.evaluation.models import EvalCase, EvalRun +from app.features.evaluation.constants import eval_cfg @dataclass @@ -28,7 +29,7 @@ def compute_gate( thr = ( regression_threshold if regression_threshold is not None - else settings.eval_gate_protected_regression_threshold + else eval_cfg.gate_protected_regression_threshold ) by_case: dict[str, dict[str, EvalRun]] = {} for r in runs: diff --git a/api/app/features/evaluation/internal_auth.py b/api/app/features/evaluation/internal_auth.py index 96c64b9..b97a6f3 100644 --- a/api/app/features/evaluation/internal_auth.py +++ b/api/app/features/evaluation/internal_auth.py @@ -2,9 +2,10 @@ from typing import Annotated -from fastapi import Depends, Header, HTTPException, status +from fastapi import Depends, Header from app.core.config import settings +from app.core.errors import AuthenticationError, ServiceUnavailableError from app.core.logging import get_logger logger = get_logger(__name__) @@ -22,9 +23,8 @@ class InternalEvalPrincipal: def require_internal_eval_enabled() -> None: if not (settings.internal_eval_api_key or "").strip(): logger.warning("internal_eval_api_key 未配置,内部评测 API 拒绝访问") - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="内部评测服务未启用(缺少 INTERNAL_EVAL_API_KEY)", + raise ServiceUnavailableError( + "内部评测服务未启用(缺少 INTERNAL_EVAL_API_KEY)" ) @@ -37,16 +37,10 @@ def verify_internal_eval_key( require_internal_eval_enabled() expected = (settings.internal_eval_api_key or "").strip() if not expected: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="内部评测服务未启用", - ) + raise ServiceUnavailableError("内部评测服务未启用") provided = (header_value or query_value or "").strip() if not provided or provided != expected: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="无效的内部评测密钥", - ) + raise AuthenticationError("无效的内部评测密钥") return InternalEvalPrincipal() diff --git a/api/app/features/evaluation/judge_manual_service.py b/api/app/features/evaluation/judge_manual_service.py index dc4d75c..d6682e9 100644 --- a/api/app/features/evaluation/judge_manual_service.py +++ b/api/app/features/evaluation/judge_manual_service.py @@ -12,6 +12,7 @@ from typing import Any from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings +from app.core.db import transactional from app.core.dependencies import ( EvalJudgeProvider, build_eval_judge_llm_spec, @@ -47,6 +48,8 @@ from app.features.evaluation.transcript_for_judge import ( from app.features.evaluation.user_export_fixtures import read_user_export_fixture from app.features.memoir.repo import get_chapters_for_memoir_list from app.features.story.repo import get_stories_for_user +from app.features.evaluation.constants import eval_cfg +from app.features.memoir.constants import memoir logger = get_logger(__name__) @@ -55,8 +58,7 @@ _MAX_EVAL_STORIES = 40 # memoir_snapshot 等仍限幅 _PRIOR_TRANSCRIPT_MAX_CHARS = 8000 _JUDGE_CONFIG_HINT = ( - "评审未配置:智谱需 eval_judge_api_key 或 zhipu_api_key;" - "DeepSeek 需 deepseek_api_key(或 llm_api_key)" + "评审未配置:智谱需 ZHIPU_API_KEY;DeepSeek 需 DEEPSEEK_API_KEY" ) @@ -126,7 +128,7 @@ def _clip_md_for_judge(text: str, max_chars: int | None = None) -> str: cap = ( max_chars if max_chars is not None - else max(1000, int(settings.eval_judge_memoir_body_max_chars)) + else max(1000, int(eval_cfg.judge_memoir_body_max_chars)) ) s = (text or "").strip() if len(s) <= cap: @@ -170,11 +172,12 @@ class EvalJudgeManualService: self, conversation_id: str, bundle: dict[str, Any] ) -> None: try: - row = await conversation_repo.set_playground_conversation_judge_json( - conversation_id, self._db, bundle - ) - if row is not None: - await self._db.commit() + async with transactional(self._db): + row = await conversation_repo.set_playground_conversation_judge_json( + conversation_id, self._db, bundle + ) + if row is None: + return except Exception: logger.exception( "persist playground_conversation_judge_json failed conversation_id={}", @@ -717,7 +720,7 @@ class EvalJudgeManualService: if (getattr(x, "canonical_markdown", None) or "").strip() ) - conc = max(1, min(32, int(settings.eval_judge_memoir_chapter_concurrency))) + conc = max(1, min(32, int(eval_cfg.judge_memoir_chapter_concurrency))) logger.info( "event=eval_memoir_judge_start user_id={} judge_provider={} judge_model={} " "chapters_total={} chapters_nonempty={} chapter_concurrency={}", @@ -741,7 +744,7 @@ class EvalJudgeManualService: baseline_excerpt = _clip_md_for_judge( bl.body, max_chars=max( - 1000, int(settings.eval_judge_memoir_evidence_max_chars) + 1000, int(eval_cfg.judge_memoir_evidence_max_chars) ), ) md = f"# 章节:{ch.title}\n\n{_clip_md_for_judge(body)}" @@ -959,7 +962,7 @@ class EvalJudgeManualService: baseline_excerpt = _clip_md_for_judge( bl.body, max_chars=max( - 1000, int(settings.eval_judge_memoir_evidence_max_chars) + 1000, int(eval_cfg.judge_memoir_evidence_max_chars) ), ) md = f"# 章节:{ch.title}\n\n{_clip_md_for_judge(body)}" @@ -999,7 +1002,7 @@ class EvalJudgeManualService: yield {"event": "chapters_prepared", "count": len(prepared)} - conc = max(1, min(32, int(settings.eval_judge_memoir_chapter_concurrency))) + conc = max(1, min(32, int(eval_cfg.judge_memoir_chapter_concurrency))) sem = asyncio.Semaphore(conc) result_queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue() diff --git a/api/app/features/evaluation/judge_service.py b/api/app/features/evaluation/judge_service.py index d9f9929..22d1169 100644 --- a/api/app/features/evaluation/judge_service.py +++ b/api/app/features/evaluation/judge_service.py @@ -21,6 +21,8 @@ from app.features.evaluation.rubrics.conversation_v1 import ( TURN_JUDGE_INSTRUCTIONS, ) from app.features.evaluation.rubrics.memoir_v1 import MEMOIR_JUDGE_INSTRUCTIONS +from app.features.evaluation.constants import eval_cfg +from app.features.memoir.constants import memoir logger = get_logger(__name__) @@ -38,23 +40,23 @@ def _eval_judge_prompt_char_pool_for_context(context_window_tokens: int) -> int: """整段请求的字符预算(由评审模型 context window 推导,保守)。""" toks = ( int(context_window_tokens) - - settings.eval_judge_completion_reserve_tokens - - settings.eval_judge_prompt_budget_safety_tokens + - eval_cfg.judge_completion_reserve_tokens + - eval_cfg.judge_prompt_budget_safety_tokens ) toks = max(1, toks) - return max(1, int(toks / settings.eval_judge_approx_tokens_per_char)) + return max(1, int(toks / eval_cfg.judge_approx_tokens_per_char)) def _eval_judge_prompt_char_pool() -> int: return _eval_judge_prompt_char_pool_for_context( - settings.eval_judge_context_window_tokens + eval_cfg.judge_context_window_tokens ) def eval_judge_conversation_transcript_max_chars() -> int: """整段对话评审:【完整对话】transcript 最大字符数(默认 GLM 上下文)。""" - if settings.eval_judge_max_transcript_chars > 0: - return settings.eval_judge_max_transcript_chars + if eval_cfg.judge_max_transcript_chars > 0: + return eval_cfg.judge_max_transcript_chars overhead = len(CONV_JUDGE_INSTRUCTIONS) + len(_CONV_HEADER) + 32 return max(1, _eval_judge_prompt_char_pool() - overhead) @@ -62,8 +64,8 @@ def eval_judge_conversation_transcript_max_chars() -> int: def eval_judge_conversation_transcript_max_chars_for_context( context_window_tokens: int, ) -> int: - if settings.eval_judge_max_transcript_chars > 0: - return settings.eval_judge_max_transcript_chars + if eval_cfg.judge_max_transcript_chars > 0: + return eval_cfg.judge_max_transcript_chars overhead = len(CONV_JUDGE_INSTRUCTIONS) + len(_CONV_HEADER) + 32 pool = _eval_judge_prompt_char_pool_for_context(context_window_tokens) return max(1, pool - overhead) @@ -71,8 +73,8 @@ def eval_judge_conversation_transcript_max_chars_for_context( def eval_judge_turn_prior_transcript_max_chars() -> int: """逐轮评审:截至上一轮的 transcript 节选上限(默认 GLM 上下文)。""" - if settings.eval_judge_max_transcript_chars > 0: - return settings.eval_judge_max_transcript_chars + if eval_cfg.judge_max_transcript_chars > 0: + return eval_cfg.judge_max_transcript_chars static = len(TURN_JUDGE_INSTRUCTIONS) + 8800 return max(1, _eval_judge_prompt_char_pool() - static) @@ -80,17 +82,17 @@ def eval_judge_turn_prior_transcript_max_chars() -> int: def eval_judge_turn_prior_transcript_max_chars_for_context( context_window_tokens: int, ) -> int: - if settings.eval_judge_max_transcript_chars > 0: - return settings.eval_judge_max_transcript_chars + if eval_cfg.judge_max_transcript_chars > 0: + return eval_cfg.judge_max_transcript_chars static = len(TURN_JUDGE_INSTRUCTIONS) + 8800 pool = _eval_judge_prompt_char_pool_for_context(context_window_tokens) return max(1, pool - static) def eval_judge_compare_transcript_each_max_chars() -> int: - """单侧对称参考上限(默认与 settings.eval_judge_context_window_tokens 一致)。""" + """单侧对称参考上限(默认与 eval_cfg.judge_context_window_tokens 一致)。""" return eval_judge_compare_transcript_each_max_chars_for_context( - settings.eval_judge_context_window_tokens + eval_cfg.judge_context_window_tokens ) @@ -98,18 +100,18 @@ def eval_judge_compare_transcript_pair_total_budget_for_context( context_window_tokens: int, ) -> int: """A/B 同 prompt 时,两份 transcript 合计最大字符数(已扣对比模板与双份 JSON 等开销)。""" - if settings.eval_judge_max_compare_transcript_chars_each > 0: - return max(1, 2 * int(settings.eval_judge_max_compare_transcript_chars_each)) + if eval_cfg.judge_max_compare_transcript_chars_each > 0: + return max(1, 2 * int(eval_cfg.judge_max_compare_transcript_chars_each)) pool = _eval_judge_prompt_char_pool_for_context(context_window_tokens) - return max(1, pool - int(settings.eval_judge_compare_prompt_overhead_chars)) + return max(1, pool - int(eval_cfg.judge_compare_prompt_overhead_chars)) def eval_judge_compare_transcript_each_max_chars_for_context( context_window_tokens: int, ) -> int: """单侧对称上限的参考值(auto 模式下约为合计预算的一半;供兼容与展示)。""" - if settings.eval_judge_max_compare_transcript_chars_each > 0: - return int(settings.eval_judge_max_compare_transcript_chars_each) + if eval_cfg.judge_max_compare_transcript_chars_each > 0: + return int(eval_cfg.judge_max_compare_transcript_chars_each) total = eval_judge_compare_transcript_pair_total_budget_for_context( context_window_tokens ) @@ -120,7 +122,7 @@ def eval_judge_compare_bundle_caps( context_window_tokens: int, ) -> tuple[int, int | None]: """返回 (compare_cap_total, per_side_cap|None),供 Playground 摘要与流式对比共用。""" - per = int(settings.eval_judge_max_compare_transcript_chars_each or 0) + per = int(eval_cfg.judge_max_compare_transcript_chars_each or 0) if per > 0: return max(1, 2 * per), per return eval_judge_compare_transcript_pair_total_budget_for_context( @@ -249,8 +251,8 @@ def _build_memoir_judge_prompt( "若证据不足,须保守打分并写 `insufficient_evidence`。", "", ] - ev_cap = max(1, int(settings.eval_judge_memoir_evidence_max_chars)) - body_cap = max(1, int(settings.eval_judge_memoir_body_max_chars)) + ev_cap = max(1, int(eval_cfg.judge_memoir_evidence_max_chars)) + body_cap = max(1, int(eval_cfg.judge_memoir_body_max_chars)) if notes: sections.extend(["【评审说明】", notes[:1200], ""]) if source: @@ -290,7 +292,7 @@ class EvalJudgeService: self._llm = judge_llm self._http_error_vendor: EvalJudgeProvider = http_error_vendor self._ctx_tokens = int( - context_window_tokens or settings.eval_judge_context_window_tokens + context_window_tokens or eval_cfg.judge_context_window_tokens ) def _conv_transcript_cap(self) -> int: @@ -382,7 +384,7 @@ class EvalJudgeService: ) -> AsyncIterator[str]: """流式输出中文对比与建议(非 JSON)。""" if not self._llm: - yield "[错误] 未配置评审模型 API Key(智谱:eval_judge_api_key / zhipu_api_key;DeepSeek:deepseek_api_key)" + yield "[错误] 未配置评审模型 API Key(智谱:ZHIPU_API_KEY;DeepSeek:DEEPSEEK_API_KEY)" return cap_total, per_side = eval_judge_compare_bundle_caps(self._ctx_tokens) cap_single = self._conv_transcript_cap() @@ -507,7 +509,7 @@ class EvalJudgeService: prompt, MemoirJudgeOutput, max_tokens=max( - 512, int(settings.eval_judge_memoir_completion_max_tokens) + 512, int(eval_cfg.judge_memoir_completion_max_tokens) ), agent="EvalJudgeService.judge_memoir", http_error_vendor=self._http_error_vendor, diff --git a/api/app/features/evaluation/replay_service.py b/api/app/features/evaluation/replay_service.py index 5dc3d57..43bc543 100644 --- a/api/app/features/evaluation/replay_service.py +++ b/api/app/features/evaluation/replay_service.py @@ -10,7 +10,7 @@ from datetime import datetime from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import utc_now +from app.core.db import transactional, utc_now from app.core.logging import get_logger from app.core.security import hash_password from app.features.auth import repo as auth_repo @@ -64,8 +64,8 @@ class ReplayConversationService: subscription_type="free", created_at=utc_now(), ) - await auth_repo.create_user(user, self._db) - await self._db.commit() + async with transactional(self._db): + await auth_repo.create_user(user, self._db) await self._db.refresh(user) conversation_id = str(uuid.uuid4()) @@ -159,13 +159,6 @@ class ReplayConversationService: segment = await conv_service.create_user_segment(conv, conv.user_id, text) segment_ids.append(segment.id) ts = segment.created_at or conv.last_message_at - if not skip_memoir: - await memoir_ingest_scheduler.queue_segment( - conv.user_id, - segment.id, - text_char_count=len(text), - trigger="evaluation_replay", - ) await process_user_message( conversation_id=cid, user_message=text, @@ -175,6 +168,8 @@ class ReplayConversationService: user=user, user_message_timestamp=ts, force_skip_tts=skip_tts, + memoir_trigger="evaluation_replay", + schedule_memoir=not skip_memoir, ) count += 1 diff --git a/api/app/features/evaluation/router.py b/api/app/features/evaluation/router.py index c8eb89e..1cfcf15 100644 --- a/api/app/features/evaluation/router.py +++ b/api/app/features/evaluation/router.py @@ -5,11 +5,11 @@ from __future__ import annotations import json from typing import Annotated -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, Query from fastapi.responses import StreamingResponse -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import get_async_db +from app.core.deps_types import DbDep +from app.core.errors import BadRequestError, NotFoundError from app.core.memoir_pipeline_progress import get_pipeline_run_for_eval from app.features.evaluation.admin_service import EvaluationAdminService from app.features.evaluation.deps import ( @@ -17,11 +17,9 @@ from app.features.evaluation.deps import ( get_evaluation_admin_service, get_memoir_readiness_service, get_replay_conversation_service, + get_session_catalog_service, ) -from app.features.evaluation.errors import ( - EvaluationBadRequestError, - EvaluationNotFoundError, -) +from app.features.evaluation.errors import EvaluationBadRequestError from app.features.evaluation.importers.user_export_markdown import ( extract_memoir_chapter_sections_from_export_md, extract_source_user_id_from_export_md, @@ -60,6 +58,8 @@ from app.features.evaluation.schemas import ( from app.features.evaluation.session_catalog_service import SessionCatalogService from app.features.evaluation.user_export_fixtures import read_user_export_fixture +SessionCatalogDep = Annotated[SessionCatalogService, Depends(get_session_catalog_service)] + router = APIRouter(tags=["internal-evaluation"]) @@ -69,18 +69,12 @@ async def eval_api_ping() -> dict[str, str | bool]: return {"ok": True, "service": "life-echo-internal-eval"} -def _eval_http_exc( - e: EvaluationNotFoundError | EvaluationBadRequestError, -) -> HTTPException: - if isinstance(e, EvaluationNotFoundError): - return HTTPException(status_code=404, detail=e.detail) - return HTTPException(status_code=400, detail=e.detail) @router.get("/sessions", response_model=SessionListResponse) async def list_sessions( _auth: InternalEvalAuth, - db: Annotated[AsyncSession, Depends(get_async_db)], + catalog: SessionCatalogDep, offset: int = Query(0, ge=0), limit: int = Query(50, ge=1, le=200), user_id: str | None = Query(None), @@ -90,7 +84,6 @@ async def list_sessions( description="按会话 status 过滤,如 active", ), ): - catalog = SessionCatalogService(db) rows, total = await catalog.list_sessions( offset=offset, limit=limit, user_id=user_id, q=q, status=status ) @@ -119,12 +112,11 @@ async def list_sessions( async def get_session_dialogue( conversation_id: str, _auth: InternalEvalAuth, - db: Annotated[AsyncSession, Depends(get_async_db)], + catalog: SessionCatalogDep, ): - catalog = SessionCatalogService(db) out = await catalog.get_session_dialogue(conversation_id) if not out: - raise HTTPException(status_code=404, detail="conversation not found") + raise NotFoundError("conversation not found") return out @@ -134,12 +126,11 @@ async def get_session_dialogue( async def get_session_transcript( conversation_id: str, _auth: InternalEvalAuth, - db: Annotated[AsyncSession, Depends(get_async_db)], + catalog: SessionCatalogDep, ): - catalog = SessionCatalogService(db) tr = await catalog.get_transcript(conversation_id) if not tr: - raise HTTPException(status_code=404, detail="conversation not found") + raise NotFoundError("conversation not found") return SessionTranscriptOut( conversation_id=tr.conversation_id, user_id=tr.user_id, @@ -155,12 +146,11 @@ async def get_session_transcript( async def get_playground_conversation_judge( conversation_id: str, _auth: InternalEvalAuth, - db: Annotated[AsyncSession, Depends(get_async_db)], + catalog: SessionCatalogDep, ): - catalog = SessionCatalogService(db) tr = await catalog.get_transcript(conversation_id) if not tr: - raise HTTPException(status_code=404, detail="conversation not found") + raise NotFoundError("conversation not found") judge = await catalog.get_playground_conversation_judge_json(conversation_id) return PlaygroundConversationJudgeOut( conversation_id=conversation_id, @@ -185,22 +175,16 @@ async def get_memoir_pipeline_run( ] = None, ): if not phase1_task_id and not memoir_correlation_id: - raise HTTPException( - status_code=400, - detail="provide phase1_task_id or memoir_correlation_id", - ) + raise BadRequestError("provide phase1_task_id or memoir_correlation_id") if phase1_task_id and memoir_correlation_id: - raise HTTPException( - status_code=400, - detail="provide only one of phase1_task_id or memoir_correlation_id", - ) + raise BadRequestError("provide only one of phase1_task_id or memoir_correlation_id") snap = get_pipeline_run_for_eval( user_id.strip(), memoir_correlation_id=memoir_correlation_id, phase1_task_id=phase1_task_id, ) if not snap: - raise HTTPException(status_code=404, detail="pipeline snapshot not found") + raise NotFoundError("pipeline snapshot not found") return MemoirPipelineRunOut.model_validate(snap) @@ -220,15 +204,10 @@ async def memoir_phase1_ready( ), ], ): - try: - return await svc.memoir_phase1_ready_for_segments( - conversation_id=conversation_id, - segment_ids=segment_ids, - ) - except EvaluationNotFoundError as e: - raise _eval_http_exc(e) from e - except EvaluationBadRequestError as e: - raise _eval_http_exc(e) from e + return await svc.memoir_phase1_ready_for_segments( + conversation_id=conversation_id, + segment_ids=segment_ids, + ) @router.post( @@ -240,14 +219,9 @@ async def memoir_submit_phase1( _auth: InternalEvalAuth, svc: Annotated[MemoirReadinessService, Depends(get_memoir_readiness_service)], ): - try: - return await svc.submit_memoir_phase1_for_conversation( - conversation_id=conversation_id, - ) - except EvaluationNotFoundError as e: - raise _eval_http_exc(e) from e - except EvaluationBadRequestError as e: - raise _eval_http_exc(e) from e + return await svc.submit_memoir_phase1_for_conversation( + conversation_id=conversation_id, + ) @router.post("/sessions/replay-bootstrap", response_model=ReplayBootstrapOut) @@ -258,10 +232,7 @@ async def replay_bootstrap( ReplayConversationService, Depends(get_replay_conversation_service) ], ): - try: - cid = await replay.bootstrap_conversation(body.user_id) - except EvaluationBadRequestError as e: - raise _eval_http_exc(e) from e + cid = await replay.bootstrap_conversation(body.user_id) return ReplayBootstrapOut(conversation_id=cid) @@ -272,10 +243,7 @@ async def create_eval_sandbox( ReplayConversationService, Depends(get_replay_conversation_service) ], ): - try: - uid, cid, phone, nick = await replay.create_eval_sandbox() - except EvaluationBadRequestError as e: - raise _eval_http_exc(e) from e + uid, cid, phone, nick = await replay.create_eval_sandbox() return EvalSandboxOut( user_id=uid, conversation_id=cid, @@ -293,42 +261,34 @@ async def replay_conversation( ], ): if body.fixture_filename and body.user_utterances: - raise HTTPException( - status_code=400, - detail="provide only one of fixture_filename or user_utterances", + raise BadRequestError("provide only one of fixture_filename or user_utterances") + segment_ids: list[str] = [] + timing = None + if body.fixture_filename: + fn = body.fixture_filename.strip() + n, echo, segment_ids, timing = await replay.replay_fixture( + conversation_id=body.conversation_id, + fixture_filename=fn, + flush_memoir_after=body.flush_memoir_after, + skip_memoir=body.skip_memoir, + skip_tts=body.skip_tts, + ) + elif body.user_utterances is not None: + utt = [str(u) for u in body.user_utterances if str(u).strip()] + if not utt: + raise EvaluationBadRequestError("user_utterances is empty") + n, segment_ids, timing = await replay.replay_utterances( + conversation_id=body.conversation_id, + utterances=utt, + flush_memoir_after=body.flush_memoir_after, + skip_memoir=body.skip_memoir, + skip_tts=body.skip_tts, + ) + echo = utt + else: + raise EvaluationBadRequestError( + "fixture_filename or user_utterances required" ) - try: - segment_ids: list[str] = [] - timing = None - if body.fixture_filename: - fn = body.fixture_filename.strip() - n, echo, segment_ids, timing = await replay.replay_fixture( - conversation_id=body.conversation_id, - fixture_filename=fn, - flush_memoir_after=body.flush_memoir_after, - skip_memoir=body.skip_memoir, - skip_tts=body.skip_tts, - ) - elif body.user_utterances is not None: - utt = [str(u) for u in body.user_utterances if str(u).strip()] - if not utt: - raise EvaluationBadRequestError("user_utterances is empty") - n, segment_ids, timing = await replay.replay_utterances( - conversation_id=body.conversation_id, - utterances=utt, - flush_memoir_after=body.flush_memoir_after, - skip_memoir=body.skip_memoir, - skip_tts=body.skip_tts, - ) - echo = utt - else: - raise EvaluationBadRequestError( - "fixture_filename or user_utterances required" - ) - except EvaluationNotFoundError as e: - raise _eval_http_exc(e) from e - except EvaluationBadRequestError as e: - raise _eval_http_exc(e) from e return ReplayConversationOut( conversation_id=body.conversation_id, turns_replayed=n, @@ -348,17 +308,12 @@ async def judge_conversation_manual( EvalJudgeManualService, Depends(get_eval_judge_manual_service) ], ): - try: - payload = await judge_svc.judge_conversation( - body.conversation_id, - body.fixture_filename, - judge_provider=body.judge_provider, - judge_model=body.judge_model, - ) - except EvaluationNotFoundError as e: - raise _eval_http_exc(e) from e - except EvaluationBadRequestError as e: - raise _eval_http_exc(e) from e + payload = await judge_svc.judge_conversation( + body.conversation_id, + body.fixture_filename, + judge_provider=body.judge_provider, + judge_model=body.judge_model, + ) return ManualJudgeConversationOut.model_validate(payload) @@ -411,18 +366,13 @@ async def retry_baseline_conversation_judge( EvalJudgeManualService, Depends(get_eval_judge_manual_service) ], ): - try: - payload = await judge_svc.retry_baseline_conversation_judge( - body.conversation_id, - body.fixture_filename, - include_baseline_turn_judges=body.include_baseline_turn_judges, - judge_provider=body.judge_provider, - judge_model=body.judge_model, - ) - except EvaluationNotFoundError as e: - raise _eval_http_exc(e) from e - except EvaluationBadRequestError as e: - raise _eval_http_exc(e) from e + payload = await judge_svc.retry_baseline_conversation_judge( + body.conversation_id, + body.fixture_filename, + include_baseline_turn_judges=body.include_baseline_turn_judges, + judge_provider=body.judge_provider, + judge_model=body.judge_model, + ) return RetryBaselineJudgeOut.model_validate(payload) @@ -434,15 +384,12 @@ async def judge_memoir_chapters_manual( EvalJudgeManualService, Depends(get_eval_judge_manual_service) ], ): - try: - payload = await judge_svc.judge_memoir_for_user( - body.user_id, - body.baseline_sections, - judge_provider=body.judge_provider, - judge_model=body.judge_model, - ) - except EvaluationBadRequestError as e: - raise _eval_http_exc(e) from e + payload = await judge_svc.judge_memoir_for_user( + body.user_id, + body.baseline_sections, + judge_provider=body.judge_provider, + judge_model=body.judge_model, + ) return ManualJudgeMemoirOut.model_validate(payload) @@ -490,10 +437,7 @@ async def get_user_memoir_snapshot( EvalJudgeManualService, Depends(get_eval_judge_manual_service) ], ): - try: - payload = await judge_svc.memoir_snapshot(user_id) - except EvaluationBadRequestError as e: - raise _eval_http_exc(e) from e + payload = await judge_svc.memoir_snapshot(user_id) return UserMemoirSnapshotOut.model_validate(payload) @@ -519,11 +463,9 @@ async def get_user_export_fixture( try: turns, raw_md = read_user_export_fixture(filename) except ValueError: - raise HTTPException( - status_code=400, detail="invalid fixture filename" - ) from None + raise BadRequestError("invalid fixture filename") from None except FileNotFoundError: - raise HTTPException(status_code=404, detail="fixture not found") from None + raise NotFoundError("fixture not found") memoir_tuples = extract_memoir_chapter_sections_from_export_md(raw_md) return UserExportFixtureDetailOut( filename=filename, diff --git a/api/app/features/memoir/background_runner.py b/api/app/features/memoir/background_runner.py index 74afd61..d698fc5 100644 --- a/api/app/features/memoir/background_runner.py +++ b/api/app/features/memoir/background_runner.py @@ -7,9 +7,9 @@ import time from dataclasses import dataclass, field from typing import Dict, List, Sequence -from app.core.config import settings from app.core.logging import get_logger from app.core.task_tracker import task_tracker +from app.features.memoir.constants import memoir logger = get_logger(__name__) @@ -113,8 +113,8 @@ class BackgroundTaskRunner: if not batch or not batch.segment_ids: return - min_c = int(settings.memoir_segment_batch_min_chars) - max_w = float(settings.memoir_segment_batch_max_wait_seconds) + min_c = int(memoir.segment_batch_min_chars) + max_w = float(memoir.segment_batch_max_wait_seconds) if min_c <= 0: segment_ids = self._pop_batch(user_id) diff --git a/api/app/features/memoir/constants.py b/api/app/features/memoir/constants.py new file mode 100644 index 0000000..b4bab4f --- /dev/null +++ b/api/app/features/memoir/constants.py @@ -0,0 +1,5 @@ +"""Memoir 流水线产品常量 — 值来自 config/*.toml(SSOT)。""" + +from app.core.app_config import app_config + +memoir = app_config.memoir diff --git a/api/app/features/memoir/cover_eligibility.py b/api/app/features/memoir/cover_eligibility.py index 16a09db..60e0587 100644 --- a/api/app/features/memoir/cover_eligibility.py +++ b/api/app/features/memoir/cover_eligibility.py @@ -9,6 +9,7 @@ from app.features.memoir.asset_resolver import ( parse_asset_refs, strip_image_placeholders, ) +from app.features.memoir.constants import memoir from app.features.memoir.memoir_images.schema import ( IMAGE_STATUS_FAILED, IMAGE_STATUS_PENDING, @@ -61,7 +62,7 @@ def chapter_eligible_for_cover_by_inline_body_image_count( chapter: Any, *, markdown: str | None = None ) -> bool: """正文内 asset:// 数量 ≥ 配置阈值时允许封面;markdown 非 None 时仅用该串计数。""" - min_required = int(settings.memoir_min_inline_images_for_chapter_cover) + min_required = int(memoir.min_inline_images_for_chapter_cover) return count_chapter_inline_body_images(chapter, markdown=markdown) >= min_required diff --git a/api/app/features/memoir/deps.py b/api/app/features/memoir/deps.py index 9866c5e..6f41a82 100644 --- a/api/app/features/memoir/deps.py +++ b/api/app/features/memoir/deps.py @@ -1,18 +1,17 @@ """Memoir feature dependencies: get_memoir_service(注入 MemoryService 供章节生成使用 evidence)。""" from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import get_async_db from app.core.dependencies import get_object_storage from app.features.memoir.service import MemoirService from app.features.memory.deps import get_memory_service from app.features.memory.service import MemoryService from app.ports.storage import ObjectStorage +from app.core.deps_types import DbDep def get_memoir_service( - db: AsyncSession = Depends(get_async_db), + db: DbDep, memory_service: MemoryService = Depends(get_memory_service), object_storage: ObjectStorage = Depends(get_object_storage), ) -> MemoirService: diff --git a/api/app/features/memoir/helpers.py b/api/app/features/memoir/helpers.py index 1dbb089..490cdc8 100644 --- a/api/app/features/memoir/helpers.py +++ b/api/app/features/memoir/helpers.py @@ -27,6 +27,7 @@ from app.features.memoir.memoir_images.storage import ( resolve_image_storage_key, ) from app.features.memoir.models import Chapter +from app.core.runtime_constants import misc_defaults from app.features.memoir.reading_segment_materialize import ( resolve_reading_segments_for_chapter_detail, ) @@ -44,10 +45,10 @@ def first_normalized_image_for_api(img: dict | None) -> dict | None: def normalize_image_assets_for_api(images: list[dict] | None) -> list[dict]: bucket = settings.tencent_cos_bucket or "" - region = settings.tencent_cos_region or "" + region = misc_defaults.tencent_cos_region or "" base_url = settings.tencent_cos_base_url or "" storage = TencentCosStorageService.from_settings(settings) - img_settings = MemoirImageSettings.from_settings(settings) + img_settings = MemoirImageSettings.from_env() source_assets = normalize_image_assets(images) if not img_settings.enabled: source_assets = completed_image_assets(source_assets) diff --git a/api/app/features/memoir/memoir_images/prompting.py b/api/app/features/memoir/memoir_images/prompting.py index d907640..a5620f5 100644 --- a/api/app/features/memoir/memoir_images/prompting.py +++ b/api/app/features/memoir/memoir_images/prompting.py @@ -6,6 +6,7 @@ from app.core.config import settings from app.core.json_utils import extract_json_payload from app.core.langchain_llm import invoke_json_object from app.core.logging import get_logger +from app.features.memoir.constants import memoir from .settings import MemoirImageSettings @@ -84,7 +85,7 @@ class MemoirImagePromptService: "prompt_context": prompt_context, } except Exception as exc: - if settings.image_prompt_fallback_disabled: + if memoir.image_prompt_fallback_disabled: raise logger.warning( "图片 prompt 生成回退到默认模板: chapter_category={}, title={}, error={}", @@ -92,7 +93,7 @@ class MemoirImagePromptService: chapter_title, exc, ) - elif settings.image_prompt_fallback_disabled: + elif memoir.image_prompt_fallback_disabled: raise RuntimeError( "MemoirImagePromptService.build_prompt requires LLM when " "image_prompt_fallback_disabled is True" @@ -121,7 +122,7 @@ class MemoirImagePromptService: ) -> dict[str, str]: """生成章节封面图的 image-generation prompt。""" excerpt = (context_excerpt or "").strip() - if settings.image_prompt_fallback_disabled and not excerpt: + if memoir.image_prompt_fallback_disabled and not excerpt: raise RuntimeError( "Chapter cover prompt requires non-empty context_excerpt when " "image_prompt_fallback_disabled is True" @@ -165,7 +166,7 @@ class MemoirImagePromptService: "prompt_context": prompt_context, } except Exception as exc: - if settings.image_prompt_fallback_disabled: + if memoir.image_prompt_fallback_disabled: raise logger.warning( "封面 prompt 生成回退到默认模板: chapter_category={}, title={}, error={}", @@ -173,7 +174,7 @@ class MemoirImagePromptService: chapter_title, exc, ) - elif settings.image_prompt_fallback_disabled: + elif memoir.image_prompt_fallback_disabled: raise RuntimeError( "MemoirImagePromptService.build_cover_prompt requires LLM when " "image_prompt_fallback_disabled is True" @@ -208,7 +209,7 @@ class MemoirImagePromptService: from app.agents.stage_constants import STAGE_TO_DEFAULT_CATEGORY brief = (prompt_brief or "").strip() - if settings.image_prompt_fallback_disabled and not brief: + if memoir.image_prompt_fallback_disabled and not brief: raise RuntimeError( "Story image prompt requires non-empty prompt_brief when " "image_prompt_fallback_disabled is True" @@ -258,7 +259,7 @@ class MemoirImagePromptService: "prompt_context": prompt_context, } except Exception as exc: - if settings.image_prompt_fallback_disabled: + if memoir.image_prompt_fallback_disabled: raise logger.warning( "story 主图 prompt 生成回退到默认模板: stage={}, title={}, error={}", @@ -266,7 +267,7 @@ class MemoirImagePromptService: story_title, exc, ) - elif settings.image_prompt_fallback_disabled: + elif memoir.image_prompt_fallback_disabled: raise RuntimeError( "MemoirImagePromptService.build_story_primary_prompt requires LLM when " "image_prompt_fallback_disabled is True" diff --git a/api/app/features/memoir/memoir_images/settings.py b/api/app/features/memoir/memoir_images/settings.py index 29094f7..d9dc9cc 100644 --- a/api/app/features/memoir/memoir_images/settings.py +++ b/api/app/features/memoir/memoir_images/settings.py @@ -1,8 +1,7 @@ -from dataclasses import dataclass -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from app.core.config import Settings +from app.core.config import settings +from app.core.runtime_constants import misc_defaults +from app.features.memoir.constants import memoir +from app.features.story.constants import story DEFAULT_LIBLIB_TEMPLATE_UUID = "5d7e67009b344550bc1aa6ccbfa1d7f4" DEFAULT_IMAGE_PROVIDER = "liblib" @@ -12,7 +11,6 @@ DEFAULT_POLL_INTERVAL_SECONDS = 5 DEFAULT_MAX_ATTEMPTS = 60 -@dataclass(frozen=True) class MemoirImageSettings: enabled: bool = False provider: str = DEFAULT_IMAGE_PROVIDER @@ -23,24 +21,37 @@ class MemoirImageSettings: liblib_template_uuid: str = DEFAULT_LIBLIB_TEMPLATE_UUID story_image_min_body_chars: int = 400 - @classmethod - def from_settings(cls, settings: "Settings") -> "MemoirImageSettings": - s = settings - return cls( - enabled=bool(s.memoir_image_enabled), - provider=s.memoir_image_provider or DEFAULT_IMAGE_PROVIDER, - default_style=s.memoir_image_style_default or DEFAULT_IMAGE_STYLE, - default_size=s.memoir_image_size_default or DEFAULT_IMAGE_SIZE, - poll_interval_seconds=s.memoir_image_poll_interval, - max_attempts=s.memoir_image_max_attempts, - liblib_template_uuid=s.liblib_template_uuid or DEFAULT_LIBLIB_TEMPLATE_UUID, - story_image_min_body_chars=int( - getattr(s, "story_image_min_body_chars", 800) or 0 - ), - ) + def __init__( + self, + *, + enabled: bool = False, + provider: str = DEFAULT_IMAGE_PROVIDER, + default_style: str = DEFAULT_IMAGE_STYLE, + default_size: str = DEFAULT_IMAGE_SIZE, + poll_interval_seconds: int = DEFAULT_POLL_INTERVAL_SECONDS, + max_attempts: int = DEFAULT_MAX_ATTEMPTS, + liblib_template_uuid: str = DEFAULT_LIBLIB_TEMPLATE_UUID, + story_image_min_body_chars: int = 400, + ) -> None: + self.enabled = enabled + self.provider = provider + self.default_style = default_style + self.default_size = default_size + self.poll_interval_seconds = poll_interval_seconds + self.max_attempts = max_attempts + self.liblib_template_uuid = liblib_template_uuid + self.story_image_min_body_chars = story_image_min_body_chars @classmethod def from_env(cls) -> "MemoirImageSettings": - from app.core.config import settings as _s - - return cls.from_settings(_s) + return cls( + enabled=bool(settings.memoir_image_enabled), + provider=memoir.image_provider or DEFAULT_IMAGE_PROVIDER, + default_style=memoir.image_style_default or DEFAULT_IMAGE_STYLE, + default_size=memoir.image_size_default or DEFAULT_IMAGE_SIZE, + poll_interval_seconds=memoir.image_poll_interval, + max_attempts=memoir.image_max_attempts, + liblib_template_uuid=settings.liblib_template_uuid + or DEFAULT_LIBLIB_TEMPLATE_UUID, + story_image_min_body_chars=int(story.image_min_body_chars or 0), + ) diff --git a/api/app/features/memoir/memoir_images/storage.py b/api/app/features/memoir/memoir_images/storage.py index 61a542e..d70144b 100644 --- a/api/app/features/memoir/memoir_images/storage.py +++ b/api/app/features/memoir/memoir_images/storage.py @@ -190,13 +190,15 @@ class TencentCosStorageService: @classmethod def from_settings(cls, settings) -> "TencentCosStorageService": + from app.core.runtime_constants import misc_defaults + config = ( - getattr(settings, "tencent_cos_secret_id", "") or "", - getattr(settings, "tencent_cos_secret_key", "") or "", - getattr(settings, "tencent_cos_region", "") or "", - getattr(settings, "tencent_cos_bucket", "") or "", - getattr(settings, "tencent_cos_base_url", "") or "", - getattr(settings, "tencent_cos_token", "") or "", + (getattr(settings, "tencent_secret_id", "") or "").strip(), + (getattr(settings, "tencent_secret_key", "") or "").strip(), + misc_defaults.tencent_cos_region, + (getattr(settings, "tencent_cos_bucket", "") or "").strip(), + (getattr(settings, "tencent_cos_base_url", "") or "").strip(), + "", ) if cls._instance is None or cls._instance_config != config: cls._instance = cls( diff --git a/api/app/features/memoir/oral_normalize.py b/api/app/features/memoir/oral_normalize.py index 5af3676..4b53a3a 100644 --- a/api/app/features/memoir/oral_normalize.py +++ b/api/app/features/memoir/oral_normalize.py @@ -12,6 +12,7 @@ from typing import Any from app.core.config import settings from app.core.text_normalize import apply_oral_rules, llm_normalize_text +from app.features.memoir.constants import memoir def _llm_normalize_oral(text: str, llm: Any) -> str | None: @@ -19,8 +20,8 @@ def _llm_normalize_oral(text: str, llm: Any) -> str | None: return llm_normalize_text( text, llm, - max_input_chars=int(settings.memoir_oral_normalize_llm_max_input_chars), - max_tokens=int(settings.memoir_oral_normalize_llm_max_tokens), + max_input_chars=int(memoir.oral_normalize_llm_max_input_chars), + max_tokens=int(memoir.oral_normalize_llm_max_tokens), agent_name="oral_normalize.llm", ) @@ -33,9 +34,9 @@ def normalize_oral_for_memoir(text: str, *, llm: Any | None = None) -> str: - rules:仅规则 - rules + LLM 分支:先规则,再(可选)LLM;LLM 失败则保留规则结果 """ - if not settings.memoir_oral_normalize_enabled: + if not memoir.oral_normalize_enabled: return text or "" - mode = (settings.memoir_oral_normalize_mode or "rules").strip().lower() + mode = (memoir.oral_normalize_mode or "rules").strip().lower() if mode == "off": return text or "" diff --git a/api/app/features/memoir/router.py b/api/app/features/memoir/router.py index 3d92c85..d345e86 100644 --- a/api/app/features/memoir/router.py +++ b/api/app/features/memoir/router.py @@ -2,29 +2,34 @@ 回忆录 feature — books / chapters / memoir-state 合并路由 """ -from typing import List, Optional +from typing import Optional from fastapi import APIRouter, Body, Depends, Query -from app.core.dependencies import get_current_user +from app.core.deps_types import CurrentUserDep from app.core.logging import get_logger +from app.core.openapi import error_responses from app.features.memoir.deps import get_memoir_service from app.features.memoir.schemas import ( + BookResponse, + ChapterDetailResponse, + ChapterListItemResponse, + CoverGenerationResponse, ExportPdfRequest, + ExportPdfResponse, + MemoirStateResponse, + NextQuestionContextResponse, SetChapterStoryOrderRequest, + StatusMessageResponse, + StoryOrderResponse, UpdateBookRequest, ) from app.features.memoir.service import MemoirService -from app.features.user.models import User router = APIRouter( prefix="/api", tags=["memoir"], - responses={ - 401: {"description": "认证失败"}, - 403: {"description": "权限不足"}, - 404: {"description": "资源不存在"}, - }, + responses=error_responses(401, 403, 404), ) logger = get_logger(__name__) @@ -34,39 +39,39 @@ logger = get_logger(__name__) # =========================================================================== -@router.get("/books/current") +@router.get("/books/current", response_model=BookResponse) async def get_current_book( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: MemoirService = Depends(get_memoir_service), ): """获取当前回忆录(需要认证)""" return await service.get_current_book(current_user.id) -@router.post("/books/clear-update") +@router.post("/books/clear-update", response_model=StatusMessageResponse) async def clear_book_update( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: MemoirService = Depends(get_memoir_service), ): """清除回忆录更新标记""" return await service.clear_book_update(current_user.id) -@router.put("/books/{book_id}") +@router.put("/books/{book_id}", response_model=BookResponse) async def update_book( book_id: str, + current_user: CurrentUserDep, request: UpdateBookRequest = Body(...), - current_user: User = Depends(get_current_user), service: MemoirService = Depends(get_memoir_service), ): """更新书籍标题(需要认证,只能更新自己的回忆录)""" return await service.update_book(book_id, current_user.id, request.title) -@router.post("/books/export-pdf") +@router.post("/books/export-pdf", response_model=ExportPdfResponse) async def export_pdf( + current_user: CurrentUserDep, request: ExportPdfRequest = Body(...), - current_user: User = Depends(get_current_user), service: MemoirService = Depends(get_memoir_service), ): """导出 PDF(需要认证,只能导出自己的回忆录)""" @@ -78,9 +83,9 @@ async def export_pdf( # =========================================================================== -@router.get("/chapters", response_model=List[dict]) +@router.get("/chapters", response_model=list[ChapterListItemResponse]) async def get_chapters( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, is_new: Optional[bool] = Query(None, description="仅返回未读章节"), service: MemoirService = Depends(get_memoir_service), ): @@ -90,19 +95,19 @@ async def get_chapters( return await service.get_chapters(current_user.id, is_new=is_new) -@router.get("/chapters/{chapter_id}", response_model=dict) +@router.get("/chapters/{chapter_id}", response_model=ChapterDetailResponse) async def get_chapter( chapter_id: str, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: MemoirService = Depends(get_memoir_service), ): """获取章节详情(需要认证,只能访问自己的章节)""" return await service.get_chapter(chapter_id, current_user.id) -@router.post("/chapters/check-cover-generation") +@router.post("/chapters/check-cover-generation", response_model=CoverGenerationResponse) async def check_cover_generation( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: MemoirService = Depends(get_memoir_service), ): """ @@ -112,21 +117,21 @@ async def check_cover_generation( return await service.check_and_trigger_cover_generation(current_user.id) -@router.delete("/chapters/{chapter_id}") +@router.delete("/chapters/{chapter_id}", response_model=StatusMessageResponse) async def disable_chapter( chapter_id: str, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: MemoirService = Depends(get_memoir_service), ): """清除章节(将章节标记为 disabled,需要认证,只能操作自己的章节)""" return await service.disable_chapter(chapter_id, current_user.id) -@router.put("/chapters/{chapter_id}/story-order") +@router.put("/chapters/{chapter_id}/story-order", response_model=StoryOrderResponse) async def set_chapter_story_order( chapter_id: str, + current_user: CurrentUserDep, request: SetChapterStoryOrderRequest = Body(...), - current_user: User = Depends(get_current_user), service: MemoirService = Depends(get_memoir_service), ): """ @@ -142,27 +147,27 @@ async def set_chapter_story_order( # =========================================================================== -@router.get("/memoir-state") +@router.get("/memoir-state", response_model=MemoirStateResponse) async def get_memoir_state( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: MemoirService = Depends(get_memoir_service), ): """获取当前用户回忆录状态""" return await service.get_memoir_state(current_user.id) -@router.get("/memoir-state/next-question") +@router.get("/memoir-state/next-question", response_model=NextQuestionContextResponse) async def get_next_question_context( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: MemoirService = Depends(get_memoir_service), ): """获取下一步问题的上下文(当前阶段与空 slot)""" return await service.get_next_question_context(current_user.id) -@router.post("/memoir-state/mark-read") +@router.post("/memoir-state/mark-read", response_model=StatusMessageResponse) async def mark_memoir_read( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: MemoirService = Depends(get_memoir_service), ): """标记回忆录更新已读""" diff --git a/api/app/features/memoir/schemas.py b/api/app/features/memoir/schemas.py index 4dbcf09..747beca 100644 --- a/api/app/features/memoir/schemas.py +++ b/api/app/features/memoir/schemas.py @@ -1,4 +1,45 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field + +from app.agents.state_schema import MemoirStateSchema as MemoirStateResponse + +__all__ = [ + "BookResponse", + "ChapterDetailResponse", + "ChapterListItemResponse", + "CoverGenerationResponse", + "ExportPdfRequest", + "ExportPdfResponse", + "ImageAssetResponse", + "MemoirStateResponse", + "NextQuestionContextResponse", + "ReadingSegmentResponse", + "SetChapterStoryOrderRequest", + "StatusMessageResponse", + "StoryOrderResponse", + "UpdateBookRequest", +] + + +class ImageAssetResponse(BaseModel): + placeholder: str = "" + description: str = "" + index: int = 0 + status: str = "pending" + prompt: str | None = None + url: str | None = None + provider: str | None = None + style: str | None = None + size: str | None = None + error: str | None = None + retryable: bool | None = None + created_at: str | None = None + updated_at: str | None = None + + +class ReadingSegmentResponse(BaseModel): + story_id: str + body_markdown: str = "" + cover_asset: ImageAssetResponse | None = None class UpdateBookRequest(BaseModel): @@ -14,3 +55,75 @@ class SetChapterStoryOrderRequest(BaseModel): """按顺序绑定本章节要收录的 stories(覆盖原有 chapter_story_links)。""" story_ids: list[str] + + +class BookResponse(BaseModel): + id: str | None = None + title: str | None = None + total_pages: int | None = None + total_words: int | None = None + cover_image_url: str | None = None + has_update: bool | None = None + last_update_chapter_id: str | None = None + message: str | None = None + + +class StatusMessageResponse(BaseModel): + status: str + message: str | None = None + + +class ExportPdfResponse(BaseModel): + pdf_base64: str + filename: str + + +class ChapterListItemResponse(BaseModel): + model_config = ConfigDict(extra="ignore") + + id: str + title: str | None = None + category: str | None = None + order_index: int | None = None + status: str = "draft" + summary: str = "" + canonical_markdown: str = "" + cover_asset: ImageAssetResponse | None = None + images: list[ImageAssetResponse] = Field(default_factory=list) + word_count: int = 0 + updated_at: str | None = None + is_new: bool = False + source_segments: list[str] = Field(default_factory=list) + + +class ChapterDetailResponse(BaseModel): + model_config = ConfigDict(extra="ignore") + + id: str + title: str | None = None + canonical_markdown: str = "" + order_index: int | None = None + status: str | None = None + category: str | None = None + images: list[ImageAssetResponse] = Field(default_factory=list) + cover_asset: ImageAssetResponse | None = None + reading_segments: list[ReadingSegmentResponse] = Field(default_factory=list) + updated_at: str | None = None + is_new: bool = False + source_segments: list[str] = Field(default_factory=list) + + +class CoverGenerationResponse(BaseModel): + triggered: list[str] = Field(default_factory=list) + + +class StoryOrderResponse(BaseModel): + status: str + chapter_id: str + story_count: int + + +class NextQuestionContextResponse(BaseModel): + current_stage: str | None = None + empty_slots: list[str] = Field(default_factory=list) + covered_stages: list[str] = Field(default_factory=list) diff --git a/api/app/features/memoir/service.py b/api/app/features/memoir/service.py index c40700d..beffd15 100644 --- a/api/app/features/memoir/service.py +++ b/api/app/features/memoir/service.py @@ -3,12 +3,13 @@ import asyncio from typing import List, Optional -from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload -from app.agents.state_schema import narrative_coverage_state +from app.agents.state_schema import MemoirStateSchema, narrative_coverage_state +from app.core.db import transactional +from app.core.errors import AuthorizationError, BadRequestError, NotFoundError from app.core.logging import get_logger from app.core.storage_purge import delete_object_storage_keys_best_effort from app.features.memoir import repo @@ -104,18 +105,18 @@ class MemoirService: book = await repo.get_current_book(user_id, self._db) if not book: return {"status": "ok", "message": "No book found"} - book.has_update = False - await self._db.commit() + async with transactional(self._db): + book.has_update = False return {"status": "ok"} async def update_book(self, book_id: str, user_id: str, title: str) -> dict: book = await self._db.get(Book, book_id) if not book: - raise HTTPException(status_code=404, detail="Book not found") + raise NotFoundError("Book not found") if book.user_id != user_id: - raise HTTPException(status_code=403, detail="无权更新此回忆录") - book.title = title - await self._db.commit() + raise AuthorizationError("无权更新此回忆录") + async with transactional(self._db): + book.title = title await self._db.refresh(book) return { "id": book.id, @@ -132,9 +133,9 @@ class MemoirService: book = await self._db.get(Book, book_id) if not book: - raise HTTPException(status_code=404, detail="Book not found") + raise NotFoundError("Book not found") if book.user_id != user_id: - raise HTTPException(status_code=403, detail="无权导出此回忆录") + raise AuthorizationError("无权导出此回忆录") stmt = ( select(Chapter) .where(Chapter.user_id == user_id, Chapter.is_active == True) @@ -195,16 +196,16 @@ class MemoirService: async def get_chapter(self, chapter_id: str, user_id: str) -> dict: chapter = await repo.get_chapter_by_id(chapter_id, self._db) if not chapter: - raise HTTPException(status_code=404, detail="Chapter not found") + raise NotFoundError("Chapter not found") if chapter.user_id != user_id: - raise HTTPException(status_code=403, detail="无权访问此章节") + raise AuthorizationError("无权访问此章节") if not chapter.is_active: - raise HTTPException(status_code=404, detail="Chapter not found") + raise NotFoundError("Chapter not found") chapter, md_override = prepare_chapter_read_view(chapter) if not chapter_meets_minimum_display( chapter, canonical_markdown_override=md_override ): - raise HTTPException(status_code=404, detail="Chapter not found") + raise NotFoundError("Chapter not found") asset_map = await signed_urls_for_asset_ids( self._db, collect_asset_ids_for_chapter(chapter) ) @@ -217,12 +218,12 @@ class MemoirService: async def disable_chapter(self, chapter_id: str, user_id: str) -> dict: chapter = await repo.get_chapter_by_id(chapter_id, self._db) if not chapter: - raise HTTPException(status_code=404, detail="Chapter not found") + raise NotFoundError("Chapter not found") if chapter.user_id != user_id: - raise HTTPException(status_code=403, detail="无权操作此章节") + raise AuthorizationError("无权操作此章节") cos_keys = await repo.collect_cos_storage_keys_for_chapter(self._db, chapter) - chapter.is_active = False - await self._db.commit() + async with transactional(self._db): + chapter.is_active = False delete_object_storage_keys_best_effort( self._object_storage, cos_keys, @@ -235,32 +236,41 @@ class MemoirService: ) -> dict: chapter = await self._db.get(Chapter, chapter_id) if not chapter: - raise HTTPException(status_code=404, detail="Chapter not found") + raise NotFoundError("Chapter not found") if chapter.user_id != user_id: - raise HTTPException(status_code=403, detail="无权操作此章节") + raise AuthorizationError("无权操作此章节") if not chapter.is_active: - raise HTTPException(status_code=404, detail="Chapter not found") + raise NotFoundError("Chapter not found") try: - await repo.replace_chapter_story_links_async( - self._db, - chapter_id=chapter_id, - user_id=user_id, - story_ids=story_ids, - ) + async with transactional(self._db): + await repo.replace_chapter_story_links_async( + self._db, + chapter_id=chapter_id, + user_id=user_id, + story_ids=story_ids, + ) + ch = await repo.get_chapter_with_story_links_for_compose( + chapter_id, self._db + ) + if not ch: + raise NotFoundError("Chapter not found") + if not ch.story_links: + md = "" + else: + md = materialize_chapter_markdown_from_loaded_chapter(ch) + await repo.append_chapter_compose_version_async(self._db, ch, md) except ValueError as exc: - raise HTTPException(status_code=400, detail=str(exc)) from exc - - ch = await repo.get_chapter_with_story_links_for_compose(chapter_id, self._db) - if not ch: - raise HTTPException(status_code=404, detail="Chapter not found") - if not ch.story_links: - md = "" - else: - md = materialize_chapter_markdown_from_loaded_chapter(ch) - await repo.append_chapter_compose_version_async(self._db, ch, md) - await self._db.commit() + msg = str(exc) + if "not found" in msg.lower() or "access denied" in msg.lower(): + raise NotFoundError(msg) from exc + raise BadRequestError(msg) from exc return {"status": "ok", "chapter_id": chapter_id, "story_count": len(story_ids)} + async def get_or_create_memoir_state(self, user_id: str) -> MemoirStateSchema: + from app.features.memoir.state_service import get_or_create_state + + return await get_or_create_state(user_id, self._db) + async def get_memoir_state(self, user_id: str) -> dict: from app.features.memoir.state_service import get_or_create_state @@ -316,14 +326,14 @@ class MemoirService: async def mark_memoir_read(self, user_id: str) -> dict: stmt = select(Chapter).where(Chapter.user_id == user_id, Chapter.is_new == True) result = await self._db.execute(stmt) - for chapter in result.scalars().all(): - chapter.is_new = False stmt_book = ( select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()) ) result_book = await self._db.execute(stmt_book) book = result_book.scalar_one_or_none() - if book: - book.has_update = False - await self._db.commit() + async with transactional(self._db): + for chapter in result.scalars().all(): + chapter.is_new = False + if book: + book.has_update = False return {"status": "ok"} diff --git a/api/app/features/memoir/state_service.py b/api/app/features/memoir/state_service.py index 8246d9e..5eeaec8 100644 --- a/api/app/features/memoir/state_service.py +++ b/api/app/features/memoir/state_service.py @@ -24,7 +24,9 @@ from app.agents.state_schema import ( narrative_coverage_state, ) from app.core.config import settings +from app.core.db import transactional, transactional_sync from app.features.memoir.models import MemoirState as MemoirStateModel +from app.features.memoir.constants import memoir def _slots_snapshot_for_merge(raw: Dict[str, Dict] | None) -> Dict[str, Dict]: @@ -83,8 +85,8 @@ async def get_or_create_state(user_id: str, db: AsyncSession) -> MemoirStateSche for k, v in default.slots.items() }, ) - db.add(state) - await db.commit() + async with transactional(db): + db.add(state) await db.refresh(state) return coerce_memoir_state(state) @@ -101,7 +103,7 @@ def _apply_current_stage_policy( state.current_stage = stage_norm return - if not settings.memoir_extraction_updates_current_stage: + if not memoir.extraction_updates_current_stage: return cur_b = chat_bucket(state.current_stage or current_from_db) new_b = chat_bucket(stage_norm) @@ -140,20 +142,20 @@ async def update_slot( if slot_name not in allowed_slot_names_for_stage(stage_norm, current_from_db): return coerce_memoir_state(state) - slots = _slots_snapshot_for_merge( - state.slots if isinstance(state.slots, dict) else None - ) - stage_slots = dict(slots.get(stage_norm, {}) or {}) - existing = stage_slots.get(slot_name, {}) + async with transactional(db): + slots = _slots_snapshot_for_merge( + state.slots if isinstance(state.slots, dict) else None + ) + stage_slots = dict(slots.get(stage_norm, {}) or {}) + existing = stage_slots.get(slot_name, {}) - merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) - stage_slots[slot_name] = SlotData( - snippet=snippet, segment_ids=merged_segment_ids - ).model_dump() - slots[stage_norm] = stage_slots - state.slots = slots - _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch) - await db.commit() + merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) + stage_slots[slot_name] = SlotData( + snippet=snippet, segment_ids=merged_segment_ids + ).model_dump() + slots[stage_norm] = stage_slots + state.slots = slots + _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch) await db.refresh(state) return coerce_memoir_state(state) @@ -168,19 +170,19 @@ async def mark_stage_complete( if not state: return await get_or_create_state(user_id, db) - covered = state.covered_stages or [] - if stage not in covered: - covered.append(stage) - state.covered_stages = covered + async with transactional(db): + covered = state.covered_stages or [] + if stage not in covered: + covered.append(stage) + state.covered_stages = covered - stage_order = state.stage_order or default_state().stage_order - if state.current_stage == stage: - try: - idx = stage_order.index(stage) - state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)] - except ValueError: - state.current_stage = default_state().current_stage - await db.commit() + stage_order = state.stage_order or default_state().stage_order + if state.current_stage == stage: + try: + idx = stage_order.index(stage) + state.current_stage = stage_order[min(idx + 1, len(stage_order) - 1)] + except ValueError: + state.current_stage = default_state().current_stage await db.refresh(state) return coerce_memoir_state(state) @@ -205,11 +207,11 @@ async def switch_stage( result = await db.execute(stmt) state = result.scalar_one() - fb = state.current_stage or "childhood" - state.current_stage = normalize_chat_stage( - new_stage, fallback=fb, log_context={"user_id": user_id} - ) - await db.commit() + async with transactional(db): + fb = state.current_stage or "childhood" + state.current_stage = normalize_chat_stage( + new_stage, fallback=fb, log_context={"user_id": user_id} + ) await db.refresh(state) return coerce_memoir_state(state) @@ -234,10 +236,10 @@ async def save_interview_state_meta( result = await db.execute(stmt) state = result.scalar_one() - state.known_facts_json = [x.model_dump() for x in known_facts] - state.persona_threads_json = [x.model_dump() for x in persona_threads] - state.recent_questions_json = list(recent_questions) - await db.commit() + async with transactional(db): + state.known_facts_json = [x.model_dump() for x in known_facts] + state.persona_threads_json = [x.model_dump() for x in persona_threads] + state.recent_questions_json = list(recent_questions) await db.refresh(state) return coerce_memoir_state(state) @@ -261,8 +263,8 @@ def get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema: for k, v in default.slots.items() }, ) - db.add(state) - db.commit() + with transactional_sync(db): + db.add(state) db.refresh(state) return coerce_memoir_state(state) @@ -298,19 +300,19 @@ def update_slot_sync( if slot_name not in allowed_slot_names_for_stage(stage_norm, current_from_db): return coerce_memoir_state(state) - slots = _slots_snapshot_for_merge( - state.slots if isinstance(state.slots, dict) else None - ) - stage_slots = dict(slots.get(stage_norm, {}) or {}) - existing = stage_slots.get(slot_name, {}) + with transactional_sync(db): + slots = _slots_snapshot_for_merge( + state.slots if isinstance(state.slots, dict) else None + ) + stage_slots = dict(slots.get(stage_norm, {}) or {}) + existing = stage_slots.get(slot_name, {}) - merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) - stage_slots[slot_name] = SlotData( - snippet=snippet, segment_ids=merged_segment_ids - ).model_dump() - slots[stage_norm] = stage_slots - state.slots = slots - _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch) - db.commit() + merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids}) + stage_slots[slot_name] = SlotData( + snippet=snippet, segment_ids=merged_segment_ids + ).model_dump() + slots[stage_norm] = stage_slots + state.slots = slots + _apply_current_stage_policy(state, stage_norm, memoir_batch=memoir_batch) db.refresh(state) return coerce_memoir_state(state) diff --git a/api/app/features/memoir/story_pipeline_sync.py b/api/app/features/memoir/story_pipeline_sync.py index dc127d6..eecf269 100644 --- a/api/app/features/memoir/story_pipeline_sync.py +++ b/api/app/features/memoir/story_pipeline_sync.py @@ -27,7 +27,6 @@ from app.agents.memoir.story_route_agent import ( StoryRouteAgent, default_append_target_story_id, ) -from app.core.business_telemetry import business_span from app.agents.stage_constants import ( CATEGORY_TO_CHAT_STAGE, CHAPTER_CATEGORIES, @@ -35,6 +34,7 @@ from app.agents.stage_constants import ( STAGE_TO_ORDER, ) from app.agents.state_schema import MemoirStateSchema +from app.core.business_telemetry import business_span from app.core.config import settings from app.core.logging import get_logger from app.features.conversation.lineage_schemas import aggregate_lineage_from_segments @@ -61,6 +61,8 @@ from app.features.memoir.repo import ( ) from app.features.memory.evidence_format import format_evidence_chunks_for_prompt from app.features.story.models import Story, StoryVersion +from app.features.memoir.constants import memoir +from app.features.story.constants import story from app.features.story.sync_write import ( append_story_version_sync, count_story_versions_sync, @@ -267,7 +269,7 @@ def _title_slots_filtered_for_generation( slot_snippets: dict[str, str], *, md: str, oral_scope: str ) -> dict[str, str]: """仅保留与正文或本批口述有文本重叠的 slot,降低档案/历史 slot 串台到标题。""" - if not settings.memoir_title_slots_require_body_or_oral_match: + if not memoir.title_slots_require_body_or_oral_match: return dict(slot_snippets) hay = f"{(md or '').strip()}\n{(oral_scope or '').strip()}" if not hay.strip(): @@ -311,7 +313,7 @@ def _strip_ungrounded_title_segments( """ 按 · / • 分节丢弃含未落地履历短语的小节;全部丢弃则占位。 """ - if not settings.memoir_title_hay_grounding_strict_phrases_enabled: + if not memoir.title_hay_grounding_strict_phrases_enabled: return (title or "").strip() or _placeholder_title( chapter_category, language=language ) @@ -358,7 +360,7 @@ def _maybe_generate_title( ) -> str: """Generate a title only when body is long enough; otherwise return placeholder.""" body_len = len((md or "").strip()) - if body_len < settings.story_title_min_body_chars: + if body_len < story.title_min_body_chars: return _placeholder_title(chapter_category, language=language) content_excerpt = (md or "").strip()[:300] merged_slots = _title_slots_filtered_for_generation( @@ -392,8 +394,8 @@ def _route_segment_texts(category_segments: list) -> list[tuple[str, str]]: for seg in category_segments: raw = seg.user_input_text or "" if ( - settings.memoir_oral_normalize_enabled - and (settings.memoir_oral_normalize_mode or "rules").strip().lower() + memoir.oral_normalize_enabled + and (memoir.oral_normalize_mode or "rules").strip().lower() != "off" ): t = apply_oral_rules(raw) @@ -435,7 +437,7 @@ def _gate_narrative_fidelity( from app.agents.memoir.fidelity_check_agent import FidelityCheckAgent check_llm = fidelity_llm if fidelity_llm is not None else llm - if not settings.memoir_fidelity_check_enabled or not check_llm: + if not memoir.fidelity_check_enabled or not check_llm: return narrative_raw, "none" agent = FidelityCheckAgent() ex = (existing_canonical or "").strip() or None @@ -471,7 +473,7 @@ def _apply_narrative_body_safety( m = (md or "").strip() ex = (existing_for_narrative or "").strip() o = (oral or "").strip() - min_len = int(settings.memoir_narrative_evidence_overlap_min_chars) + min_len = int(memoir.narrative_evidence_overlap_min_chars) ev_plain = strip_evidence_for_overlap_check(evidence_text) if m and body_contains_prompt_artifact(m): logger.warning( @@ -494,7 +496,7 @@ def _apply_narrative_body_safety( "evidence_leak_heuristic" ) if ( - settings.memoir_evidence_scene_anchor_check_enabled + memoir.evidence_scene_anchor_check_enabled and m and evidence_text.strip() and evidence_scene_anchor_leak(m, ev_plain, o, ex) @@ -667,8 +669,8 @@ def _resolve_append_target( memoir_correlation_id: str | None, ) -> tuple[str | None, str, str]: """Resolve append target and return (target_story_id, existing_for_narrative, decision_source).""" - max_chars = int(settings.story_append_max_canonical_chars) - max_ver = int(settings.story_append_max_versions) + max_chars = int(story.append_max_canonical_chars) + max_ver = int(story.append_max_versions) target_story_id: str | None = None existing_for_narrative = "" @@ -696,7 +698,7 @@ def _resolve_append_target( and candidate_stories and decision_source not in FALLBACK_NEW_STORY_REASONS and len(oral_norm) - <= int(settings.memoir_story_route_append_guardrail_oral_chars) + <= int(memoir.story_route_append_guardrail_oral_chars) ): tid_g = default_append_target_story_id(candidate_stories, story_meta, settings) if tid_g: @@ -817,7 +819,13 @@ def _execute_narrative_unit( if target_story_id: sid_s = str(target_story_id) - ver = append_story_version_sync(session, sid_s, md) + try: + ver = append_story_version_sync(session, sid_s, md) + except ValueError as exc: + logger.warning( + "append_story_version_sync failed story_id={}: {}", sid_s, exc + ) + return None _persist_story_lineage_sync( session, story_id=sid_s, @@ -1049,9 +1057,9 @@ def _run_story_pipeline_batch_inner( source_ids = [seg.id for seg in category_segments] n_units = len(category_segments) - top_k = int(settings.evidence_top_k_default) - if n_units > int(settings.evidence_large_batch_threshold): - top_k = int(settings.evidence_top_k_large_batch) + top_k = int(story.evidence_top_k_default) + if n_units > int(story.evidence_large_batch_threshold): + top_k = int(story.evidence_top_k_large_batch) def _oral_job() -> tuple[str, float]: with business_span("memoir.story_pipeline.oral_normalize"): @@ -1178,7 +1186,7 @@ def _run_story_pipeline_batch_inner( plan is None and single_route is not None and single_route.reason in FALLBACK_NEW_STORY_REASONS - and bool(settings.memoir_route_defer_enabled) + and bool(memoir.route_defer_enabled) ): defer_ids = [str(s.id) for s in category_segments] logger.info( diff --git a/api/app/features/memory/chat_memory_injection.py b/api/app/features/memory/chat_memory_injection.py index 1af20e0..9d769fe 100644 --- a/api/app/features/memory/chat_memory_injection.py +++ b/api/app/features/memory/chat_memory_injection.py @@ -10,7 +10,8 @@ from __future__ import annotations from dataclasses import dataclass from typing import Any -from app.core.config import settings +from app.features.conversation.constants import chat +from app.features.memory.constants import memory from app.features.memory.evidence_format import ( dedupe_evidence_chunk_rows, format_evidence_chunks_for_chat_prompt, @@ -153,7 +154,7 @@ def build_planner_preview( t = (text or "").strip() if not t: return "" - max_c = min(int(settings.chat_memory_evidence_max_chars), 2000) + max_c = min(int(chat.memory_evidence_max_chars), 2000) if len(t) > max_c: return t[: max_c - 3] + "..." return t @@ -174,7 +175,7 @@ def slice_interview_memory( had_retrieval=False, ) - use_safe = settings.chat_memory_safe_evidence_format_enabled + use_safe = chat.memory_safe_evidence_format_enabled planner_preview = build_planner_preview(evidence, use_safe_chat_format=use_safe) had = bool(planner_preview.strip()) diff --git a/api/app/features/memory/compaction_service.py b/api/app/features/memory/compaction_service.py index bcc9ad8..e01604a 100644 --- a/api/app/features/memory/compaction_service.py +++ b/api/app/features/memory/compaction_service.py @@ -16,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.core.logging import get_logger from app.features.memory.models import MemoryChunk, MemorySource +from app.features.memory.constants import memory from app.features.memory.repo import ( create_curation_action, get_first_chunk_after_cursor, @@ -227,7 +228,7 @@ async def run_memory_compaction( cand_sources = [str(x) for x in cand_sources] has_candidate_filter = "candidate_chunk_ids" in ctx or "candidate_source_ids" in ctx - max_chunks = settings.memory_compaction_max_chunks_per_run + max_chunks = memory.compaction_max_chunks_per_run incremental = await list_incremental_chunks_for_compaction( db, user_id=user_id, @@ -248,7 +249,7 @@ async def run_memory_compaction( cursor_ts, cursor_id, max_steps=max( - settings.memory_compaction_max_chunks_per_run * 2, + memory.compaction_max_chunks_per_run * 2, 500, ), ) @@ -288,12 +289,12 @@ async def run_memory_compaction( "skipped_reason": "empty_incremental", } - sim_th = settings.memory_compaction_chunk_similarity_threshold - min_layers = settings.memory_compaction_min_layers_for_exclude - jaccard_min = settings.memory_compaction_text_jaccard_min - year_w = settings.memory_compaction_metadata_event_year_window - max_neighbors = settings.memory_compaction_max_neighbors_per_chunk - max_excludes = settings.memory_compaction_max_excludes_per_run + sim_th = memory.compaction_chunk_similarity_threshold + min_layers = memory.compaction_min_layers_for_exclude + jaccard_min = memory.compaction_text_jaccard_min + year_w = memory.compaction_metadata_event_year_window + max_neighbors = memory.compaction_max_neighbors_per_chunk + max_excludes = memory.compaction_max_excludes_per_run local_excluded: set[str] = set() excludes_done = 0 diff --git a/api/app/features/memory/constants.py b/api/app/features/memory/constants.py new file mode 100644 index 0000000..0c57d54 --- /dev/null +++ b/api/app/features/memory/constants.py @@ -0,0 +1,5 @@ +"""Memory 富化 / compaction 产品常量 — 值来自 config/*.toml(SSOT)。""" + +from app.core.app_config import app_config + +memory = app_config.memory diff --git a/api/app/features/memory/deps.py b/api/app/features/memory/deps.py index b69acb2..1596445 100644 --- a/api/app/features/memory/deps.py +++ b/api/app/features/memory/deps.py @@ -1,9 +1,8 @@ from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import get_async_db from app.core.dependencies import get_embedding_provider from app.features.memory.service import MemoryService +from app.core.deps_types import DbDep def _get_embedding_provider(): @@ -11,7 +10,7 @@ def _get_embedding_provider(): async def get_memory_service( - db: AsyncSession = Depends(get_async_db), + db: DbDep, embedding_provider=Depends(_get_embedding_provider), ) -> MemoryService: return MemoryService(db=db, embedding_provider=embedding_provider) diff --git a/api/app/features/memory/enrichment.py b/api/app/features/memory/enrichment.py index 135d355..71c0311 100644 --- a/api/app/features/memory/enrichment.py +++ b/api/app/features/memory/enrichment.py @@ -30,6 +30,7 @@ from app.features.memory.repo import ( set_source_enrichment_status, ) from app.features.user.models import User +from app.features.memory.constants import memory logger = get_logger(__name__) @@ -65,7 +66,7 @@ def _resolve_gateway_llm() -> Any | None: def _max_enrichment_chars() -> int: from app.core.config import settings - return settings.memory_enrichment_max_chars + return memory.enrichment_max_chars def _enrichment_prompt(numbered_blocks: str, narrator_label: str) -> str: @@ -118,7 +119,7 @@ async def enrich_memory_after_ingest_async( ) -> dict: from app.core.config import settings - if not settings.memory_enrichment_enabled: + if not memory.enrichment_enabled: await set_source_enrichment_status( db, source_id=source_id, diff --git a/api/app/features/memory/ingest_service.py b/api/app/features/memory/ingest_service.py index c791e1a..a6b4cde 100644 --- a/api/app/features/memory/ingest_service.py +++ b/api/app/features/memory/ingest_service.py @@ -5,6 +5,8 @@ from __future__ import annotations from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings +from app.core.db import transactional +from app.core.errors import BadRequestError from app.core.logging import get_logger from app.features.conversation.lineage_schemas import ( primary_user_message_id_from_lineage, @@ -22,8 +24,10 @@ from app.features.memory.enrichment_scheduler import ( from app.features.memory.repo import ( create_chunk, create_source, + get_transcript_source_by_segment_id, ) from app.ports.embedding import EmbeddingProvider +from app.features.memory.constants import memory logger = get_logger(__name__) @@ -53,34 +57,32 @@ class MemoryIngestService: lineage_json: dict | None = None, ) -> str: if not transcript or not transcript.strip(): - raise ValueError("transcript cannot be empty") + raise BadRequestError("transcript cannot be empty") primary_mid = ( primary_user_message_id_from_lineage(lineage_json) if lineage_json else None ) - source = await create_source( - self._db, - user_id=user_id, - source_type="transcript", - raw_text=transcript.strip(), - conversation_id=conversation_id, - lineage_json=lineage_json, - primary_user_message_id=primary_mid, - ) - - chunk_records: list[tuple[str, str]] = [] - for i, content in enumerate(chunk_transcript(transcript.strip())): - chunk = await create_chunk( + async with transactional(self._db): + source = await create_source( self._db, - source_id=source.id, user_id=user_id, - content=content, - chunk_index=i, + source_type="transcript", + raw_text=transcript.strip(), + conversation_id=conversation_id, + lineage_json=lineage_json, + primary_user_message_id=primary_mid, ) - chunk_records.append((chunk.id, content)) - await self._db.flush() - await self._db.commit() + chunk_records: list[tuple[str, str]] = [] + for i, content in enumerate(chunk_transcript(transcript.strip())): + chunk = await create_chunk( + self._db, + source_id=source.id, + user_id=user_id, + content=content, + chunk_index=i, + ) + chunk_records.append((chunk.id, content)) embedding_result = await MemoryEmbeddingService( self._db, @@ -108,7 +110,7 @@ class MemoryIngestService: embedding_result.get("status"), emb_ok, embedding_task_id, - settings.memory_enrichment_enabled, + memory.enrichment_enabled, enrichment_task_id, ) return source.id @@ -116,50 +118,63 @@ class MemoryIngestService: async def ingest_transcripts_batch( self, user_id: str, - items: list[tuple[str, str, dict | None]], + items: list[tuple[str, str, dict | None, str | None]], *, memoir_correlation_id: str | None = None, ) -> list[str]: """ Batch ingest transcript items through the async memory path. - items: (conversation_id, transcript, lineage_json). Empty transcripts are skipped. + items: (conversation_id, transcript, lineage_json, segment_id). + Empty transcripts are skipped. When segment_id is set and a transcript + source already exists for the user, returns the existing source id. """ source_ids: list[str] = [] chunk_records: list[tuple[str, str]] = [] + new_source_ids: list[str] = [] - for conversation_id, transcript, lineage_json in items: - text = (transcript or "").strip() - if not text: - continue - primary_mid = ( - primary_user_message_id_from_lineage(lineage_json) - if lineage_json - else None - ) - source = await create_source( - self._db, - user_id=user_id, - source_type="transcript", - raw_text=text, - conversation_id=conversation_id or None, - lineage_json=lineage_json, - primary_user_message_id=primary_mid, - ) - source_ids.append(source.id) - - for i, content in enumerate(chunk_transcript(text)): - chunk = await create_chunk( - self._db, - source_id=source.id, - user_id=user_id, - content=content, - chunk_index=i, + async with transactional(self._db): + for conversation_id, transcript, lineage_json, segment_id in items: + text = (transcript or "").strip() + if not text: + continue + sid = (segment_id or "").strip() or None + if sid: + existing = await get_transcript_source_by_segment_id( + self._db, + user_id=user_id, + segment_id=sid, + ) + if existing is not None: + source_ids.append(existing.id) + continue + primary_mid = ( + primary_user_message_id_from_lineage(lineage_json) + if lineage_json + else None ) - chunk_records.append((chunk.id, content)) + source = await create_source( + self._db, + user_id=user_id, + source_type="transcript", + raw_text=text, + conversation_id=conversation_id or None, + segment_id=sid, + lineage_json=lineage_json, + primary_user_message_id=primary_mid, + ) + source_ids.append(source.id) + new_source_ids.append(source.id) - await self._db.flush() - await self._db.commit() + for i, content in enumerate(chunk_transcript(text)): + chunk = await create_chunk( + self._db, + source_id=source.id, + user_id=user_id, + content=content, + chunk_index=i, + ) + chunk_records.append((chunk.id, content)) vectors_written = 0 embedding_retry_task_ids: list[str] = [] @@ -168,7 +183,7 @@ class MemoryIngestService: self._db, embedding_provider=self._embedding, ) - for source_id in source_ids: + for source_id in new_source_ids: result = await embedding_service.embed_source(user_id, source_id) vectors_written += int(result.get("vectors_written") or 0) status = str(result.get("status") or "unknown") @@ -185,7 +200,7 @@ class MemoryIngestService: emb_ok = self._embedding.is_available() if self._embedding else False task_ids = self._enrichment_scheduler.schedule_many( user_id, - source_ids, + new_source_ids, memoir_correlation_id=memoir_correlation_id, ) @@ -200,7 +215,7 @@ class MemoryIngestService: emb_ok, embedding_statuses, len(embedding_retry_task_ids), - settings.memory_enrichment_enabled, + memory.enrichment_enabled, len(task_ids), ) return source_ids diff --git a/api/app/features/memory/models.py b/api/app/features/memory/models.py index 586c318..1fb3f97 100644 --- a/api/app/features/memory/models.py +++ b/api/app/features/memory/models.py @@ -33,6 +33,7 @@ class MemorySource(Base): enrichment_status = Column(String, default="pending") enrichment_error = Column(Text, nullable=True) conversation_id = Column(String, ForeignKey("conversations.id"), nullable=True) + segment_id = Column(String, ForeignKey("segments.id", ondelete="SET NULL"), nullable=True) lineage_json = Column(JSON, nullable=True) primary_user_message_id = Column(String, nullable=True) created_at = Column(DateTime(timezone=True), default=utc_now) diff --git a/api/app/features/memory/repo.py b/api/app/features/memory/repo.py index 45c433f..5ffdda4 100644 --- a/api/app/features/memory/repo.py +++ b/api/app/features/memory/repo.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta, timezone from sqlalchemy import cast, literal, or_, select, text, tuple_, update from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from sqlalchemy.types import String as SqlString from app.features.memory.models import ( @@ -27,6 +28,7 @@ async def create_source( source_type: str, raw_text: str | None = None, conversation_id: str | None = None, + segment_id: str | None = None, captured_at: datetime | None = None, lineage_json: dict | None = None, primary_user_message_id: str | None = None, @@ -40,6 +42,7 @@ async def create_source( embedding_status="pending", enrichment_status="pending", conversation_id=conversation_id, + segment_id=segment_id, lineage_json=lineage_json, primary_user_message_id=primary_user_message_id, captured_at=captured_at or datetime.now(timezone.utc), @@ -48,6 +51,34 @@ async def create_source( return source +async def get_transcript_source_by_segment_id( + db: AsyncSession, + *, + user_id: str, + segment_id: str, +) -> MemorySource | None: + stmt = select(MemorySource).where( + MemorySource.user_id == user_id, + MemorySource.segment_id == segment_id, + MemorySource.source_type == "transcript", + ) + return (await db.execute(stmt)).scalar_one_or_none() + + +def get_transcript_source_by_segment_id_sync( + db: Session, + *, + user_id: str, + segment_id: str, +) -> MemorySource | None: + stmt = select(MemorySource).where( + MemorySource.user_id == user_id, + MemorySource.segment_id == segment_id, + MemorySource.source_type == "transcript", + ) + return db.execute(stmt).scalar_one_or_none() + + async def create_chunk( db: AsyncSession, *, diff --git a/api/app/features/memory/router.py b/api/app/features/memory/router.py index 1152bef..f2ea4d3 100644 --- a/api/app/features/memory/router.py +++ b/api/app/features/memory/router.py @@ -1,9 +1,10 @@ """Memory 策展与内部扩展 API。""" -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, status from pydantic import BaseModel, Field -from app.core.dependencies import get_current_user +from app.core.deps_types import CurrentUserDep +from app.core.errors import NotFoundError from app.features.memory.deps import get_memory_service from app.features.memory.service import MemoryService from app.features.user.models import User @@ -22,50 +23,46 @@ class RejectFactBody(BaseModel): @router.post("/chunks/{chunk_id}/exclude", status_code=status.HTTP_204_NO_CONTENT) async def exclude_chunk( chunk_id: str, - body: ExcludeBody | None = None, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, memory: MemoryService = Depends(get_memory_service), + body: ExcludeBody | None = None, ): reason = (body.reason if body else "") or "" ok = await memory.exclude_chunk(current_user.id, chunk_id, reason=reason) if not ok: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="chunk 不存在" - ) + raise NotFoundError("chunk 不存在") @router.post("/chunks/{chunk_id}/restore", status_code=status.HTTP_204_NO_CONTENT) async def restore_chunk( chunk_id: str, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, memory: MemoryService = Depends(get_memory_service), ): ok = await memory.restore_chunk(current_user.id, chunk_id) if not ok: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="chunk 不存在" - ) + raise NotFoundError("chunk 不存在") @router.post("/facts/{fact_id}/confirm", status_code=status.HTTP_204_NO_CONTENT) async def confirm_fact( fact_id: str, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, memory: MemoryService = Depends(get_memory_service), ): ok = await memory.confirm_fact(current_user.id, fact_id) if not ok: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="fact 不存在") + raise NotFoundError("fact 不存在") @router.post("/facts/{fact_id}/reject", status_code=status.HTTP_204_NO_CONTENT) async def reject_fact( fact_id: str, - body: RejectFactBody | None = None, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, memory: MemoryService = Depends(get_memory_service), + body: RejectFactBody | None = None, ): reason = (body.reason if body else "") or "" ok = await memory.reject_fact(current_user.id, fact_id, reason=reason) if not ok: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="fact 不存在") + raise NotFoundError("fact 不存在") diff --git a/api/app/features/memory/service.py b/api/app/features/memory/service.py index b1e140c..f379f95 100644 --- a/api/app/features/memory/service.py +++ b/api/app/features/memory/service.py @@ -7,6 +7,7 @@ Celery task 只能作为同步入口包装 async service,不再维护 sync mem from sqlalchemy.ext.asyncio import AsyncSession +from app.core.db import transactional from app.core.logging import get_logger from app.features.memory.embedding_service import MemoryEmbeddingService from app.features.memory.ingest_service import MemoryIngestService @@ -57,7 +58,7 @@ class MemoryService: async def ingest_transcripts_batch( self, user_id: str, - items: list[tuple[str, str, dict | None]], + items: list[tuple[str, str, dict | None, str | None]], *, memoir_correlation_id: str | None = None, ) -> list[str]: @@ -112,71 +113,71 @@ class MemoryService: async def exclude_chunk( self, user_id: str, chunk_id: str, *, reason: str = "" ) -> bool: - ok = await set_chunk_excluded(self._db, chunk_id, user_id, True) - if not ok: - return False - stale_count = await mark_facts_stale_for_excluded_chunk( - self._db, - user_id=user_id, - chunk_id=chunk_id, - ) - await create_curation_action( - self._db, - user_id=user_id, - action_type="exclude", - target_type="chunk", - target_id=chunk_id, - details={ - **({"reason": reason} if reason else {}), - "staled_fact_count": stale_count, - }, - ) - await self._db.commit() + async with transactional(self._db): + ok = await set_chunk_excluded(self._db, chunk_id, user_id, True) + if not ok: + return False + stale_count = await mark_facts_stale_for_excluded_chunk( + self._db, + user_id=user_id, + chunk_id=chunk_id, + ) + await create_curation_action( + self._db, + user_id=user_id, + action_type="exclude", + target_type="chunk", + target_id=chunk_id, + details={ + **({"reason": reason} if reason else {}), + "staled_fact_count": stale_count, + }, + ) return True async def restore_chunk(self, user_id: str, chunk_id: str) -> bool: - ok = await set_chunk_excluded(self._db, chunk_id, user_id, False) - if not ok: - return False - await create_curation_action( - self._db, - user_id=user_id, - action_type="restore", - target_type="chunk", - target_id=chunk_id, - details={"fact_restore_policy": "requires_reenrichment"}, - ) - await self._db.commit() + async with transactional(self._db): + ok = await set_chunk_excluded(self._db, chunk_id, user_id, False) + if not ok: + return False + await create_curation_action( + self._db, + user_id=user_id, + action_type="restore", + target_type="chunk", + target_id=chunk_id, + details={"fact_restore_policy": "requires_reenrichment"}, + ) return True async def confirm_fact(self, user_id: str, fact_id: str) -> bool: - ok = await set_memory_fact_status(self._db, fact_id, user_id, "confirmed") - if not ok: - return False - await create_curation_action( - self._db, - user_id=user_id, - action_type="confirm", - target_type="fact", - target_id=fact_id, - details=None, - ) - await self._db.commit() + async with transactional(self._db): + ok = await set_memory_fact_status(self._db, fact_id, user_id, "confirmed") + if not ok: + return False + await create_curation_action( + self._db, + user_id=user_id, + action_type="confirm", + target_type="fact", + target_id=fact_id, + details=None, + ) return True async def reject_fact( self, user_id: str, fact_id: str, *, reason: str = "" ) -> bool: - ok = await set_memory_fact_status(self._db, fact_id, user_id, "rejected") - if not ok: - return False - await create_curation_action( - self._db, - user_id=user_id, - action_type="reject", - target_type="fact", - target_id=fact_id, - details={"reason": reason} if reason else None, - ) - await self._db.commit() - return ok + async with transactional(self._db): + ok = await set_memory_fact_status(self._db, fact_id, user_id, "rejected") + if not ok: + return False + await create_curation_action( + self._db, + user_id=user_id, + action_type="reject", + target_type="fact", + target_id=fact_id, + details={"reason": reason} if reason else None, + ) + return True diff --git a/api/app/features/payment/deps.py b/api/app/features/payment/deps.py index 60f30a6..66b15a0 100644 --- a/api/app/features/payment/deps.py +++ b/api/app/features/payment/deps.py @@ -1,12 +1,11 @@ from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import get_async_db from app.features.payment.order_service import PaymentOrderService from app.features.payment.payment_config import PaymentConfig from app.features.payment.payment_facade import PaymentService from app.features.plan.deps import get_plan_service from app.features.plan.service import PlanService +from app.core.deps_types import DbDep _payment_service = None @@ -21,7 +20,7 @@ def get_payment_service() -> PaymentService: def get_payment_order_service( - db: AsyncSession = Depends(get_async_db), + db: DbDep, plan_service: PlanService = Depends(get_plan_service), ) -> PaymentOrderService: """Payment order facade: create_order, callbacks, list/status.""" diff --git a/api/app/features/payment/order_service.py b/api/app/features/payment/order_service.py index b77f044..ea1c910 100644 --- a/api/app/features/payment/order_service.py +++ b/api/app/features/payment/order_service.py @@ -7,11 +7,17 @@ import time import uuid from datetime import timedelta -from fastapi import HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import utc_now +from app.core.db import transactional, utc_now +from app.core.errors import ( + AppError, + BadRequestError, + GatewayTimeoutError, + NotFoundError, + ServiceUnavailableError, +) from app.core.logging import get_logger from app.features.payment.models import Order from app.features.payment.schemas import ( @@ -47,6 +53,11 @@ class PaymentOrderService: self._db = db self._plan_service = plan_service + async def _persist_failed_order(self, order: Order) -> None: + async with transactional(self._db): + self._db.add(order) + order.status = "failed" + async def create_order( self, user_id: str, @@ -62,23 +73,18 @@ class PaymentOrderService: plans = self._plan_service.get_plans_for_api() plan = next((p for p in plans if p.id == plan_id), None) if plan is None: - raise HTTPException(status_code=400, detail="无效的套餐 ID") + raise BadRequestError("无效的套餐 ID") if plan.price <= 0: - raise HTTPException(status_code=400, detail="免费套餐无需支付") + raise BadRequestError("免费套餐无需支付") if payment_method not in ("wechat", "alipay"): - raise HTTPException( - status_code=400, detail="不支持的支付方式,仅支持 wechat / alipay" - ) + raise BadRequestError("不支持的支付方式,仅支持 wechat / alipay") client = _get_payment_service_client() if not client.is_method_available(payment_method): if payment_method == "alipay": - raise HTTPException( - status_code=503, detail="支付宝支付接口正在开发中,暂时不可用" - ) - raise HTTPException( - status_code=503, - detail=f"{payment_method} 支付暂不可用,请选择其他支付方式", + raise ServiceUnavailableError("支付宝支付接口正在开发中,暂时不可用") + raise ServiceUnavailableError( + f"{payment_method} 支付暂不可用,请选择其他支付方式" ) amount_fen = int(plan_price * 100) @@ -96,8 +102,6 @@ class PaymentOrderService: created_at=now, expired_at=now + timedelta(minutes=ORDER_EXPIRE_MINUTES), ) - self._db.add(order) - await self._db.flush() if payment_method == "wechat": try: @@ -106,16 +110,12 @@ class PaymentOrderService: timeout=WECHAT_INIT_TIMEOUT_SEC, ) except asyncio.TimeoutError: - order.status = "failed" - await self._db.flush() - raise HTTPException( - status_code=504, detail="微信支付初始化超时,请稍后重试。" - ) + await self._persist_failed_order(order) + raise GatewayTimeoutError("微信支付初始化超时,请稍后重试。") from None except Exception as e: - order.status = "failed" - await self._db.flush() + await self._persist_failed_order(order) logger.exception("微信支付客户端初始化失败: {}", e) - raise HTTPException(status_code=503, detail=f"微信支付暂不可用: {e!s}") + raise ServiceUnavailableError(f"微信支付暂不可用: {e!s}") from e try: payment_result = await asyncio.wait_for( @@ -124,32 +124,33 @@ class PaymentOrderService: payment_method, order_no, amount_fen, - f"岁月时书 - {plan_display_name}", + f"岁月留书 - {plan_display_name}", ), timeout=PREPAY_TIMEOUT_SEC, ) except asyncio.TimeoutError: - order.status = "failed" - await self._db.flush() - raise HTTPException( - status_code=504, - detail="创建预支付超时,请检查网络或稍后重试。若为微信支付,请确认商户配置与网络可达微信服务器。", - ) + await self._persist_failed_order(order) + raise GatewayTimeoutError( + "创建预支付超时,请检查网络或稍后重试。若为微信支付,请确认商户配置与网络可达微信服务器。" + ) from None except PaymentError as e: - order.status = "failed" - await self._db.flush() - raise HTTPException( - status_code=500, detail=f"创建支付订单失败: {e.message}" - ) + await self._persist_failed_order(order) + raise AppError( + f"创建支付订单失败: {e.message}", + status_code=500, + error_code="PAYMENT_FAILED", + ) from e except Exception as e: - order.status = "failed" - await self._db.flush() + await self._persist_failed_order(order) logger.exception("创建支付订单异常: {}", e) - raise HTTPException( - status_code=500, detail=f"创建支付订单异常: {type(e).__name__}: {e!s}" - ) + raise AppError( + f"创建支付订单异常: {type(e).__name__}: {e!s}", + status_code=500, + error_code="INTERNAL_ERROR", + ) from e - await self._db.commit() + async with transactional(self._db): + self._db.add(order) logger.info( "订单创建成功: order_no={}, payment_method={}, amount_fen={}", order_no, @@ -173,29 +174,29 @@ class PaymentOrderService: logger.info("支付回调: 订单已处理过 {}", out_trade_no) return now = utc_now() - order.status = "paid" - order.trade_no = trade_no - order.paid_at = now - user_result = await self._db.execute( - select(User).where(User.id == order.user_id) - ) - user = user_result.scalar_one_or_none() - if user: - duration_days = SUBSCRIPTION_DURATION_DAYS.get(order.plan_id, 365) - if user.subscription_expires_at and user.subscription_expires_at > now: - user.subscription_expires_at = user.subscription_expires_at + timedelta( - days=duration_days - ) - else: - user.subscription_expires_at = now + timedelta(days=duration_days) - user.subscription_type = order.plan_id - logger.info( - "用户 {} 订阅已升级为 {},到期: {}", - user.id, - order.plan_id, - user.subscription_expires_at, + async with transactional(self._db): + order.status = "paid" + order.trade_no = trade_no + order.paid_at = now + user_result = await self._db.execute( + select(User).where(User.id == order.user_id) ) - await self._db.commit() + user = user_result.scalar_one_or_none() + if user: + duration_days = SUBSCRIPTION_DURATION_DAYS.get(order.plan_id, 365) + if user.subscription_expires_at and user.subscription_expires_at > now: + user.subscription_expires_at = user.subscription_expires_at + timedelta( + days=duration_days + ) + else: + user.subscription_expires_at = now + timedelta(days=duration_days) + user.subscription_type = order.plan_id + logger.info( + "用户 {} 订阅已升级为 {},到期: {}", + user.id, + order.plan_id, + user.subscription_expires_at, + ) logger.info( "支付成功处理完成: 订单 {}, 第三方交易号 {}", out_trade_no, trade_no ) @@ -232,7 +233,7 @@ class PaymentOrderService: ) order = result.scalar_one_or_none() if order is None: - raise HTTPException(status_code=404, detail="订单不存在") + raise NotFoundError("订单不存在") return OrderStatusResponse( order_id=order.id, plan_id=order.plan_id, diff --git a/api/app/features/payment/payment_config.py b/api/app/features/payment/payment_config.py index a10db56..8c6a816 100644 --- a/api/app/features/payment/payment_config.py +++ b/api/app/features/payment/payment_config.py @@ -72,6 +72,8 @@ class PaymentConfig: @classmethod def from_settings(cls, settings) -> "PaymentConfig": + from app.core.runtime_constants import misc_defaults + wechat_private_key = ( getattr(settings, "wechat_pay_private_key", "") or "" ).strip() @@ -90,9 +92,7 @@ class PaymentConfig: wechat_private_key_path = ( getattr(settings, "wechat_pay_private_key_path", "") or "" ).strip() - alipay_under = ( - getattr(settings, "alipay_under_development", "true") or "true" - ).lower() + alipay_under = misc_defaults.alipay_under_development.lower() config = cls( wechat=WeChatPayConfig( app_id=getattr(settings, "wechat_pay_app_id", "") or "", @@ -115,7 +115,7 @@ class PaymentConfig: private_key=getattr(settings, "alipay_private_key", "") or "", alipay_public_key=getattr(settings, "alipay_public_key", "") or "", notify_url=getattr(settings, "alipay_notify_url", "") or "", - sign_type=getattr(settings, "alipay_sign_type", "RSA2") or "RSA2", + sign_type=misc_defaults.alipay_sign_type, ), alipay_under_development=alipay_under in ("true", "1", "yes"), ) diff --git a/api/app/features/payment/payment_exceptions.py b/api/app/features/payment/payment_exceptions.py index ed48f8d..cf06052 100644 --- a/api/app/features/payment/payment_exceptions.py +++ b/api/app/features/payment/payment_exceptions.py @@ -1,11 +1,21 @@ -"""支付模块异常定义(从 payment 迁入 app)""" +"""支付模块异常定义(继承 AppError,由全局 handler 统一映射)。""" + +from app.core.errors import AppError + +_PAYMENT_CODE_MAP: dict[str, tuple[int, str]] = { + "PAYMENT_CONFIG_ERROR": (502, "PROVIDER_ERROR"), + "PAYMENT_CREATE_ERROR": (502, "PROVIDER_ERROR"), + "PAYMENT_NOTIFY_ERROR": (400, "BAD_REQUEST"), + "PAYMENT_QUERY_ERROR": (502, "PROVIDER_ERROR"), + "PAYMENT_ERROR": (400, "BAD_REQUEST"), +} -class PaymentError(Exception): +class PaymentError(AppError): def __init__(self, message: str = "支付异常", code: str = "PAYMENT_ERROR"): - self.message = message + status_code, error_code = _PAYMENT_CODE_MAP.get(code, (400, code)) + super().__init__(message, status_code=status_code, error_code=error_code) self.code = code - super().__init__(self.message) class PaymentConfigError(PaymentError): diff --git a/api/app/features/payment/router.py b/api/app/features/payment/router.py index b598d1c..c010636 100644 --- a/api/app/features/payment/router.py +++ b/api/app/features/payment/router.py @@ -1,10 +1,12 @@ from typing import List -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, Request from fastapi.responses import PlainTextResponse -from app.core.dependencies import get_current_user +from app.core.deps_types import CurrentUserDep +from app.core.errors import BadRequestError from app.core.logging import get_logger +from app.core.openapi import error_responses from app.features.payment.deps import get_payment_order_service from app.features.payment.order_service import PaymentOrderService from app.features.payment.schemas import ( @@ -15,25 +17,20 @@ from app.features.payment.schemas import ( ) from app.features.plan.deps import get_plan_service from app.features.plan.service import PlanService -from app.features.user.models import User logger = get_logger(__name__) router = APIRouter( prefix="/api/payment", tags=["payment"], - responses={ - 401: {"description": "认证失败"}, - 404: {"description": "订单不存在"}, - 503: {"description": "支付服务暂不可用"}, - }, + responses=error_responses(401, 404, 500, 502, 503, 504), ) @router.post("/create-order", response_model=CreateOrderResponse) async def create_order( request: CreateOrderRequest, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: PaymentOrderService = Depends(get_payment_order_service), plan_service: PlanService = Depends(get_plan_service), ): @@ -41,7 +38,7 @@ async def create_order( (p for p in plan_service.get_plans_for_api() if p.id == request.plan_id), None ) if plan is None: - raise HTTPException(status_code=400, detail="无效的套餐 ID") + raise BadRequestError("无效的套餐 ID") return await service.create_order( user_id=current_user.id, user_subscription_type=current_user.subscription_type, @@ -87,7 +84,7 @@ async def alipay_notify( @router.get("/order/{order_id}/status", response_model=OrderStatusResponse) async def get_order_status( order_id: str, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: PaymentOrderService = Depends(get_payment_order_service), ): return await service.get_order_status(order_id, current_user.id) @@ -95,7 +92,7 @@ async def get_order_status( @router.get("/orders", response_model=List[OrderListResponse]) async def list_orders( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: PaymentOrderService = Depends(get_payment_order_service), ): return await service.list_orders(current_user.id) diff --git a/api/app/features/plan/catalog.py b/api/app/features/plan/catalog.py index 4cdf5ed..6f3a337 100644 --- a/api/app/features/plan/catalog.py +++ b/api/app/features/plan/catalog.py @@ -3,7 +3,7 @@ from app.core.config import settings from app.features.plan.schemas import PlanResponse -ENABLE_TEST_PLAN = (settings.enable_test_plan or "").lower() in ("1", "true", "yes") +ENABLE_TEST_PLAN = settings.enable_test_plan AVAILABLE_PLANS = [ PlanResponse( diff --git a/api/app/features/plan/deps.py b/api/app/features/plan/deps.py index 69f05ed..183177b 100644 --- a/api/app/features/plan/deps.py +++ b/api/app/features/plan/deps.py @@ -1,13 +1,22 @@ """Plan feature dependencies: get_plan_service.""" +from typing import Annotated + from fastapi import Depends +from app.core.dependencies import get_current_user from app.features.plan.service import PlanService from app.features.quota.deps import get_quota_service from app.features.quota.service import QuotaService +from app.features.user.models import User +from app.core.deps_types import DbDep def get_plan_service( - quota_service: QuotaService = Depends(get_quota_service), + quota_service: Annotated[QuotaService, Depends(get_quota_service)], ) -> PlanService: return PlanService(quota_service=quota_service) + + +PlanServiceDep = Annotated[PlanService, Depends(get_plan_service)] +CurrentUserDep = Annotated[User, Depends(get_current_user)] diff --git a/api/app/features/plan/router.py b/api/app/features/plan/router.py index 0c5d8f0..5e0e44c 100644 --- a/api/app/features/plan/router.py +++ b/api/app/features/plan/router.py @@ -2,36 +2,29 @@ 订阅计划路由。 """ -from typing import List +from fastapi import APIRouter -from fastapi import APIRouter, Depends - -from app.core.dependencies import get_current_user -from app.features.plan.deps import get_plan_service +from app.core.openapi import error_responses +from app.features.plan.deps import CurrentUserDep, PlanServiceDep from app.features.plan.schemas import CurrentPlanResponse, PlanResponse -from app.features.plan.service import PlanService, get_plans_for_api -from app.features.user.models import User router = APIRouter( prefix="/api/plans", tags=["plans"], - responses={ - 401: {"description": "认证失败"}, - 404: {"description": "资源不存在"}, - }, + responses=error_responses(401, 404), ) -@router.get("", response_model=List[PlanResponse]) -async def get_plans(): +@router.get("", response_model=list[PlanResponse]) +def get_plans(service: PlanServiceDep) -> list[PlanResponse]: """获取所有可用的订阅计划(开发环境 ENABLE_TEST_PLAN=1 时包含「一分钱测试版」)。""" - return get_plans_for_api() + return service.get_plans_for_api() @router.get("/current", response_model=CurrentPlanResponse) async def get_current_plan( - current_user: User = Depends(get_current_user), - service: PlanService = Depends(get_plan_service), -): + current_user: CurrentUserDep, + service: PlanServiceDep, +) -> CurrentPlanResponse: """获取当前用户的订阅计划信息""" return await service.get_current_plan_response(current_user) diff --git a/api/app/features/plan/schemas.py b/api/app/features/plan/schemas.py index 84ebec9..7f891b9 100644 --- a/api/app/features/plan/schemas.py +++ b/api/app/features/plan/schemas.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from pydantic import BaseModel @@ -9,17 +7,24 @@ class PlanResponse(BaseModel): display_name: str price: float currency: str - features: List[str] - max_conversations: Optional[int] = None - max_chapters: Optional[int] = None - max_words: Optional[int] = None + features: list[str] + max_conversations: int | None = None + max_chapters: int | None = None + max_words: int | None = None is_popular: bool = False +class PlanUsageResponse(BaseModel): + conversations: int + chapters: int + max_conversations: int | None = None + max_chapters: int | None = None + + class CurrentPlanResponse(BaseModel): plan_id: str plan_name: str subscription_type: str - expires_at: Optional[str] = None - features: List[str] - usage: dict + expires_at: str | None = None + features: list[str] + usage: PlanUsageResponse diff --git a/api/app/features/plan/service.py b/api/app/features/plan/service.py index b793ef1..20f6544 100644 --- a/api/app/features/plan/service.py +++ b/api/app/features/plan/service.py @@ -7,7 +7,11 @@ from app.features.plan.catalog import ( get_plan_by_type, get_plans_for_api, ) -from app.features.plan.schemas import CurrentPlanResponse, PlanResponse +from app.features.plan.schemas import ( + CurrentPlanResponse, + PlanResponse, + PlanUsageResponse, +) from app.features.quota.service import QuotaService from app.features.user.models import User @@ -32,12 +36,12 @@ class PlanService: async def get_current_plan_response(self, user: User) -> CurrentPlanResponse: plan = get_plan_by_type(user.subscription_type) segment_count, chapter_count = await self._quota.get_usage(user.id) - usage = { - "conversations": segment_count, - "chapters": chapter_count, - "max_conversations": plan.max_conversations, - "max_chapters": plan.max_chapters, - } + usage = PlanUsageResponse( + conversations=segment_count, + chapters=chapter_count, + max_conversations=plan.max_conversations, + max_chapters=plan.max_chapters, + ) expires_at = None if user.subscription_expires_at: expires_at = user.subscription_expires_at.isoformat() diff --git a/api/app/features/quota/deps.py b/api/app/features/quota/deps.py index 7000a79..5119f04 100644 --- a/api/app/features/quota/deps.py +++ b/api/app/features/quota/deps.py @@ -1,13 +1,12 @@ """Quota feature dependencies: get_quota_service.""" from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import get_async_db from app.features.quota.service import QuotaService +from app.core.deps_types import DbDep def get_quota_service( - db: AsyncSession = Depends(get_async_db), + db: DbDep, ) -> QuotaService: return QuotaService(db=db) diff --git a/api/app/features/quota/router.py b/api/app/features/quota/router.py index 5485c62..12e949f 100644 --- a/api/app/features/quota/router.py +++ b/api/app/features/quota/router.py @@ -5,6 +5,7 @@ from fastapi import APIRouter, Depends from app.core.dependencies import get_current_user +from app.core.openapi import error_responses from app.features.quota.deps import get_quota_service from app.features.quota.schemas import QuotaCheckResponse from app.features.quota.service import QuotaService @@ -13,12 +14,7 @@ from app.features.user.models import User router = APIRouter( prefix="/api/quota", tags=["quota"], - responses={ - 401: {"description": "认证失败"}, - 403: {"description": "权限不足"}, - 404: {"description": "资源不存在"}, - 429: {"description": "配额已用尽"}, - }, + responses=error_responses(401, 403, 404, 429), ) diff --git a/api/app/features/story/constants.py b/api/app/features/story/constants.py new file mode 100644 index 0000000..0555f92 --- /dev/null +++ b/api/app/features/story/constants.py @@ -0,0 +1,5 @@ +"""Story / chapter / evidence 流水线产品常量 — 值来自 config/*.toml(SSOT)。""" + +from app.core.app_config import app_config + +story = app_config.story diff --git a/api/app/features/story/deps.py b/api/app/features/story/deps.py index dd48524..bb8867e 100644 --- a/api/app/features/story/deps.py +++ b/api/app/features/story/deps.py @@ -1,9 +1,8 @@ from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import get_async_db from app.features.story.service import StoryService +from app.core.deps_types import DbDep -async def get_story_service(db: AsyncSession = Depends(get_async_db)) -> StoryService: +async def get_story_service(db: DbDep) -> StoryService: return StoryService(db=db) diff --git a/api/app/features/story/post_commit.py b/api/app/features/story/post_commit.py index d5367a2..c4d1003 100644 --- a/api/app/features/story/post_commit.py +++ b/api/app/features/story/post_commit.py @@ -6,37 +6,26 @@ enqueue 失败不回滚已提交数据,仅记录日志;依赖后续触发或 from __future__ import annotations -import threading from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any, cast -import redis - -from app.core.config import settings from app.core.logging import get_logger from app.core.memoir_pipeline_progress import merge_pipeline_run from app.core.memory_compaction_schedule import schedule_memory_compaction_run +from app.core.redis_sync import get_sync_redis +from app.features.memoir.constants import memoir +from app.features.story.constants import story logger = get_logger(__name__) -_redis_client: redis.Redis | None = None -_redis_lock = threading.Lock() def _story_image_enqueue_key(story_id: str) -> str: return f"enqueue:story-image:{story_id}" -def _get_redis() -> redis.Redis: - """进程内复用单个 Redis 客户端,避免重复创建连接池。""" - global _redis_client - if _redis_client is None: - with _redis_lock: - if _redis_client is None: - _redis_client = redis.from_url( - settings.redis_url, decode_responses=True - ) - return _redis_client +def _get_redis(): + return get_sync_redis(decode_responses=True) @dataclass @@ -68,7 +57,7 @@ def enqueue_story_post_commit_effects( """ result = PostCommitResult() r = _get_redis() - ttl = int(settings.story_image_enqueue_dedup_ttl) + ttl = int(story.image_enqueue_dedup_ttl) if need_image and story_ids: from app.tasks.story_image_tasks import ( @@ -131,7 +120,7 @@ def enqueue_story_post_commit_effects( recompose_chapter as recompose_chapter_task, ) - cd = int(settings.recompose_chapter_delay_seconds) + cd = int(story.recompose_chapter_delay_seconds) for cid in sorted(chapter_ids): try: rkwargs: dict[str, Any] = {} @@ -201,13 +190,13 @@ def enqueue_story_post_commit_effects( ) result.errors.append(f"compaction:{exc}") - if need_quality_pass and settings.memoir_quality_pass_enabled and story_ids: + if need_quality_pass and memoir.quality_pass_enabled and story_ids: try: from app.tasks.memoir_quality_pass_tasks import ( memoir_quality_pass as quality_pass_task, ) - cd = int(settings.memoir_quality_pass_delay_seconds) + cd = int(memoir.quality_pass_delay_seconds) qp_ar = cast(Any, quality_pass_task).apply_async( args=[user_id, sorted(story_ids), sorted(chapter_ids)], kwargs={"memoir_correlation_id": memoir_correlation_id}, diff --git a/api/app/features/story/service.py b/api/app/features/story/service.py index 1f84b75..86e534c 100644 --- a/api/app/features/story/service.py +++ b/api/app/features/story/service.py @@ -10,6 +10,8 @@ from datetime import datetime, timezone from sqlalchemy.ext.asyncio import AsyncSession +from app.core.db import transactional +from app.core.errors import NotFoundError from app.core.logging import get_logger from app.features.memoir import repo as memoir_repo from app.features.memoir.asset_resolver import strip_asset_image_refs_from_markdown @@ -124,45 +126,46 @@ class StoryService: ) -> str: """Create story, commit, return story_id.""" md = strip_asset_image_refs_from_markdown(canonical_markdown or "") - story = await create_story( - self._db, - user_id=user_id, - title=title, - stage=stage, - story_type=story_type, - summary=summary, - canonical_markdown=md, - ) - await self._db.flush() - apply_infer_story_time_start_to_model(story) - if md.strip(): - version = await create_story_version( + async with transactional(self._db): + story = await create_story( self._db, - story_id=story.id, - version_no=1, - markdown_snapshot=md, - actor_type="ai", - source_type="generate", + user_id=user_id, + title=title, + stage=stage, + story_type=story_type, + summary=summary, + canonical_markdown=md, ) await self._db.flush() - story.current_version_id = version.id - await _extract_and_store_image_intent( - self._db, - story=story, - version=version, - markdown=md, - ) - if md.strip(): - await memoir_repo.mark_chapters_dirty_for_story(self._db, story.id) - await self._db.commit() + apply_infer_story_time_start_to_model(story) + if md.strip(): + version = await create_story_version( + self._db, + story_id=story.id, + version_no=1, + markdown_snapshot=md, + actor_type="ai", + source_type="generate", + ) + await self._db.flush() + story.current_version_id = version.id + await _extract_and_store_image_intent( + self._db, + story=story, + version=version, + markdown=md, + ) + if md.strip(): + await memoir_repo.mark_chapters_dirty_for_story(self._db, story.id) + story_id = story.id if md.strip(): from app.features.memoir.repo import get_chapter_ids_linked_to_story from app.features.story.post_commit import enqueue_story_post_commit_effects - chapter_ids = set(await get_chapter_ids_linked_to_story(self._db, story.id)) + chapter_ids = set(await get_chapter_ids_linked_to_story(self._db, story_id)) pc = enqueue_story_post_commit_effects( user_id=user_id, - story_ids={story.id}, + story_ids={story_id}, chapter_ids=chapter_ids, trigger_source="manual_api", need_compaction=False, @@ -176,7 +179,7 @@ class StoryService: pc.enqueued_chapter_recompose_count, pc.errors, ) - return story.id + return story_id async def append_version( self, @@ -191,32 +194,33 @@ class StoryService: """Append new version, update canonical_markdown, return version_id.""" story = await get_story_by_id(self._db, story_id) if not story: - raise ValueError(f"Story {story_id} not found") + raise NotFoundError(f"Story {story_id} not found") md = strip_asset_image_refs_from_markdown(markdown_snapshot or "") parent_id = story.current_version_id version_no = (await count_story_versions(self._db, story_id)) + 1 - version = await create_story_version( - self._db, - story_id=story_id, - version_no=version_no, - markdown_snapshot=md, - actor_type=actor_type, - source_type=source_type, - parent_version_id=parent_id, - prompt_meta=prompt_meta, - ) - version.change_summary = change_summary - story.current_version_id = version.id - story.canonical_markdown = md - apply_infer_story_time_start_to_model(story) - await _extract_and_store_image_intent( - self._db, - story=story, - version=version, - markdown=md, - ) - await memoir_repo.mark_chapters_dirty_for_story(self._db, story_id) - await self._db.commit() + async with transactional(self._db): + version = await create_story_version( + self._db, + story_id=story_id, + version_no=version_no, + markdown_snapshot=md, + actor_type=actor_type, + source_type=source_type, + parent_version_id=parent_id, + prompt_meta=prompt_meta, + ) + version.change_summary = change_summary + story.current_version_id = version.id + story.canonical_markdown = md + apply_infer_story_time_start_to_model(story) + await _extract_and_store_image_intent( + self._db, + story=story, + version=version, + markdown=md, + ) + await memoir_repo.mark_chapters_dirty_for_story(self._db, story_id) + version_id = version.id from app.features.memoir.repo import get_chapter_ids_linked_to_story from app.features.story.post_commit import enqueue_story_post_commit_effects @@ -236,7 +240,7 @@ class StoryService: pc.enqueued_chapter_recompose_count, pc.errors, ) - return version.id + return version_id async def link_evidence( self, @@ -248,15 +252,15 @@ class StoryService: weight: float | None = None, ) -> None: """Add evidence link. Caller must ensure story exists.""" - await create_story_evidence_link( - self._db, - story_id=story_id, - evidence_type=evidence_type, - evidence_id=evidence_id, - role=role, - weight=weight, - ) - await self._db.commit() + async with transactional(self._db): + await create_story_evidence_link( + self._db, + story_id=story_id, + evidence_type=evidence_type, + evidence_id=evidence_id, + role=role, + weight=weight, + ) async def get_stories( self, user_id: str, *, status: str | None = "active" diff --git a/api/app/features/tasks/deps.py b/api/app/features/tasks/deps.py index 1299347..f1abfd5 100644 --- a/api/app/features/tasks/deps.py +++ b/api/app/features/tasks/deps.py @@ -1,6 +1,7 @@ """Tasks feature 依赖:提供 get_tasks_service。""" from app.features.tasks.service import TasksService +from app.core.deps_types import DbDep def get_tasks_service() -> TasksService: diff --git a/api/app/features/tasks/router.py b/api/app/features/tasks/router.py index 99beb77..1fb9b41 100644 --- a/api/app/features/tasks/router.py +++ b/api/app/features/tasks/router.py @@ -8,6 +8,7 @@ from fastapi import APIRouter, Depends from pydantic import BaseModel from app.core.dependencies import get_current_user +from app.core.openapi import error_responses from app.features.tasks.deps import get_tasks_service from app.features.tasks.service import TasksService from app.features.user.models import User @@ -15,7 +16,7 @@ from app.features.user.models import User router = APIRouter( prefix="/api/tasks", tags=["tasks"], - responses={401: {"description": "认证失败"}}, + responses=error_responses(401), ) diff --git a/api/app/features/user/deps.py b/api/app/features/user/deps.py index 71b43b7..d480ac3 100644 --- a/api/app/features/user/deps.py +++ b/api/app/features/user/deps.py @@ -1,13 +1,12 @@ """User feature dependencies: get_user_service.""" from fastapi import Depends -from sqlalchemy.ext.asyncio import AsyncSession -from app.core.db import get_async_db from app.features.user.service import UserService +from app.core.deps_types import DbDep def get_user_service( - db: AsyncSession = Depends(get_async_db), + db: DbDep, ) -> UserService: return UserService(db=db) diff --git a/api/app/features/user/router.py b/api/app/features/user/router.py index b212d0e..adfdc61 100644 --- a/api/app/features/user/router.py +++ b/api/app/features/user/router.py @@ -1,13 +1,15 @@ import uuid -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, status from app.core.config import settings from app.core.cos_url_keys import avatar_url_for_api_response -from app.core.dependencies import get_current_user, get_object_storage +from app.core.dependencies import get_object_storage +from app.core.deps_types import CurrentUserDep +from app.core.errors import BadRequestError, NotFoundError from app.core.logging import get_logger +from app.core.openapi import error_responses from app.features.user.deps import get_user_service -from app.features.user.models import User from app.features.user.schemas import ( FeedbackResponse, PurgeUserDataRequest, @@ -18,33 +20,28 @@ from app.features.user.schemas import ( UpdateUserProfileRequest, UserProfileResponse, ) -from app.features.user.service import UserService, _coerce_language as _coerce_language_token +from app.features.user.service import UserService +from app.features.user.service import _coerce_language as _coerce_language_token from app.ports.storage import ObjectStorage logger = get_logger(__name__) -_SHARED_RESPONSES = { - 401: {"description": "认证失败"}, - 403: {"description": "权限不足"}, - 404: {"description": "资源不存在"}, -} - router = APIRouter( prefix="/api/user", tags=["user"], - responses=_SHARED_RESPONSES, + responses=error_responses(401, 403, 404), ) feedback_router = APIRouter( prefix="/api/feedback", tags=["feedback"], - responses=_SHARED_RESPONSES, + responses=error_responses(401, 403, 404), ) @router.get("/profile", response_model=UserProfileResponse) async def get_user_profile( - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, ): return UserProfileResponse( id=current_user.id, @@ -67,7 +64,7 @@ async def get_user_profile( @router.put("/profile", response_model=UserProfileResponse) async def update_user_profile( body: UpdateUserProfileRequest, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: UserService = Depends(get_user_service), ): logger.info( @@ -81,7 +78,7 @@ async def update_user_profile( @router.post("/data/purge", response_model=PurgeUserDataResponse) async def purge_user_data( body: PurgeUserDataRequest, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: UserService = Depends(get_user_service), object_storage: ObjectStorage = Depends(get_object_storage), ): @@ -92,26 +89,23 @@ async def purge_user_data( 保留 users 表中的账号与登录字段(手机号、密码等),并清空出生年/出生地/成长地/职业等档案字段。 口令见请求体 schema 说明。 """ - try: - return await service.purge_all_user_data( - current_user.id, - confirmation=body.confirmation, - object_storage=object_storage, - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e + return await service.purge_all_user_data( + current_user.id, + confirmation=body.confirmation, + object_storage=object_storage, + ) @router.post("/test-subscription", response_model=TestSubscriptionResponse) async def test_subscription( body: TestSubscriptionRequest, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, service: UserService = Depends(get_user_service), ): if not settings.enable_test_subscription: - raise HTTPException(status_code=404, detail="测试订阅功能未开放") + raise NotFoundError("测试订阅功能未开放") if body.action == "activate" and body.plan_id not in ("pro", "pro_plus"): - raise HTTPException(status_code=400, detail="plan_id 仅支持 pro 或 pro_plus") + raise BadRequestError("plan_id 仅支持 pro 或 pro_plus") return await service.toggle_test_subscription( current_user.id, body.action, body.plan_id ) @@ -122,7 +116,7 @@ async def test_subscription( ) async def submit_feedback( request: SubmitFeedbackRequest, - current_user: User = Depends(get_current_user), + current_user: CurrentUserDep, ): """提交用户反馈。用户可通过此接口提交反馈意见或联系客服。""" feedback_id = str(uuid.uuid4()) diff --git a/api/app/features/user/schemas.py b/api/app/features/user/schemas.py index 6a0f1a2..1e5cfc2 100644 --- a/api/app/features/user/schemas.py +++ b/api/app/features/user/schemas.py @@ -41,7 +41,7 @@ class TestSubscriptionResponse(BaseModel): class SubmitFeedbackRequest(BaseModel): """提交反馈请求""" - content: str = Field(..., min_length=1, max_length=2000, description="反馈内容") + content: str = Field(min_length=1, max_length=2000, description="反馈内容") contact: Optional[str] = Field(None, max_length=100, description="联系方式(可选)") diff --git a/api/app/features/user/service.py b/api/app/features/user/service.py index cee51e7..e63aa80 100644 --- a/api/app/features/user/service.py +++ b/api/app/features/user/service.py @@ -3,7 +3,8 @@ from datetime import timedelta from sqlalchemy.ext.asyncio import AsyncSession from app.core.cos_url_keys import avatar_url_for_api_response -from app.core.db import utc_now +from app.core.db import transactional, utc_now +from app.core.errors import BadRequestError, NotFoundError from app.core.logging import get_logger from app.core.redis import redis_service from app.core.task_tracker import task_tracker @@ -50,16 +51,19 @@ class UserService: def __init__(self, db: AsyncSession): self._db = db + async def get_by_id(self, user_id: str) -> User | None: + return await repo.get_user_by_id(user_id, self._db) + async def update_profile( self, user_id: str, body: UpdateUserProfileRequest ) -> UserProfileResponse: user = await repo.get_user_by_id(user_id, self._db) if not user: - raise ValueError("用户不存在") - for field in ("birth_year", "birth_place", "grew_up_place", "occupation"): - if field in body.model_fields_set: - setattr(user, field, getattr(body, field)) - await self._db.commit() + raise NotFoundError("用户不存在") + async with transactional(self._db): + for field in ("birth_year", "birth_place", "grew_up_place", "occupation"): + if field in body.model_fields_set: + setattr(user, field, getattr(body, field)) await self._db.refresh(user) return _user_to_profile(user) @@ -68,24 +72,23 @@ class UserService: ) -> TestSubscriptionResponse: user = await repo.get_user_by_id(user_id, self._db) if not user: - raise ValueError("用户不存在") + raise NotFoundError("用户不存在") now = utc_now() - if action == "activate": - user.subscription_type = plan_id - user.subscription_expires_at = now + timedelta(days=365) - await self._db.commit() - return TestSubscriptionResponse( - success=True, - message=f"已开启测试订阅:{plan_id}", - subscription_type=plan_id, - ) - user.subscription_type = "free" - user.subscription_expires_at = None - await self._db.commit() + async with transactional(self._db): + if action == "activate": + user.subscription_type = plan_id + user.subscription_expires_at = now + timedelta(days=365) + subscription_type = plan_id + message = f"已开启测试订阅:{plan_id}" + else: + user.subscription_type = "free" + user.subscription_expires_at = None + subscription_type = "free" + message = "已关闭测试订阅,恢复免费体验版" return TestSubscriptionResponse( success=True, - message="已关闭测试订阅,恢复免费体验版", - subscription_type="free", + message=message, + subscription_type=subscription_type, ) async def purge_all_user_data( @@ -97,11 +100,11 @@ class UserService: ) -> PurgeUserDataResponse: """物理删除该用户业务数据(保留 users 行与登录字段);并清空出生年/出生地等档案字段;提交后再清 Redis 等。""" if confirmation != PURGE_USER_DATA_CONFIRMATION: - raise ValueError("确认文案不正确,请按提示完整输入口令") + raise BadRequestError("确认文案不正确,请按提示完整输入口令") user = await repo.get_user_by_id(user_id, self._db) if not user: - raise ValueError("用户不存在") + raise NotFoundError("用户不存在") logger.info("用户数据清空开始 user_id={}", user_id) @@ -120,10 +123,10 @@ class UserService: len(story_ids), ) - await repo.purge_user_related_rows(self._db, user_id) - await repo.clear_user_demographics(self._db, user_id) - await repo.clear_user_avatar_url(self._db, user_id) - await self._db.commit() + async with transactional(self._db): + await repo.purge_user_related_rows(self._db, user_id) + await repo.clear_user_demographics(self._db, user_id) + await repo.clear_user_avatar_url(self._db, user_id) logger.info("用户数据 DB 行已删除、档案字段已清空并提交 user_id={}", user_id) if object_storage and storage_keys: diff --git a/api/app/internal_main.py b/api/app/internal_main.py index 55e8354..0ffd027 100644 --- a/api/app/internal_main.py +++ b/api/app/internal_main.py @@ -8,6 +8,7 @@ from __future__ import annotations +from contextlib import asynccontextmanager from pathlib import Path from app.core.logging import get_logger, setup_logging @@ -15,10 +16,11 @@ from app.core.logging import get_logger, setup_logging setup_logging() from app.core.config import settings +from app.core.runtime_constants import otel_defaults from app.core.telemetry import instrument_fastapi_app, setup_telemetry setup_telemetry( - service_name=settings.otel_service_name or "life-echo-internal-api", + service_name=otel_defaults.service_name or "life-echo-internal-api", ) from fastapi import FastAPI @@ -30,60 +32,13 @@ from app.core.errors import register_exception_handlers from app.core.middleware import RequestIdMiddleware from app.features.evaluation import models as _eval_models # noqa: F401 from app.features.evaluation.router import router as eval_router +from app.features.evaluation.constants import eval_cfg logger = get_logger(__name__) -internal_app = FastAPI( - title="Life Echo Internal Evaluation API", - version="0.1.0", - docs_url="/docs" if settings.internal_eval_enable_docs else None, - redoc_url="/redoc" if settings.internal_eval_enable_docs else None, - openapi_url="/openapi.json" if settings.internal_eval_enable_docs else None, -) -instrument_fastapi_app(internal_app) - -internal_app.add_middleware(RequestIdMiddleware) -_origins = [ - o.strip() - for o in (settings.internal_eval_cors_origins or "").split(",") - if o.strip() -] -# 浏览器不允许 Origin=* 与 credentials 同时出现;未配置显式白名单时关闭 credentials。 -_allow_creds = bool(_origins) -internal_app.add_middleware( - CORSMiddleware, - allow_origins=_origins if _origins else ["*"], - allow_credentials=_allow_creds, - allow_methods=["*"], - allow_headers=["*"], -) -register_exception_handlers(internal_app) - - -@internal_app.get("/", include_in_schema=False, response_class=HTMLResponse) -async def internal_eval_landing(): - """浏览器打开内部评测 API 根路径时提示:界面在 Vite(默认 5174),本进程仅为 API。""" - docs_hint = ( - '

OpenAPI 文档 /docs

' - if settings.internal_eval_enable_docs - else "

(未开启文档;设置 INTERNAL_EVAL_ENABLE_DOCS=1 后可访问 /docs)

" - ) - return f""" -内部评测 API - -

Life Echo · 内部回归评测 API

-

这里是 HTTP API(端口由启动命令决定),没有内置网页。 -浏览「回归评测台」请在仓库执行 ./development.shcd app-eval-web && npm run dev, -在终端里打开 Vite 给出的地址(一般为 http://127.0.0.1:5174/)。

-

健康检查:/health

-{docs_hint} -

会话与对比接口前缀:/internal/api/evaluation/

-""" - - -@internal_app.on_event("startup") -async def _startup(): +@asynccontextmanager +async def lifespan(app: FastAPI): import asyncio from app.core.alembic_startup import run_alembic_upgrade_at_startup @@ -100,10 +55,15 @@ async def _startup(): except Exception as e: logger.warning("Redis 连接失败: {}", e) + yield -@internal_app.on_event("shutdown") -async def _shutdown(): logger.info("内部评测 API 关闭中…") + try: + from app.core.telemetry import shutdown_telemetry + + shutdown_telemetry() + except Exception as e: + logger.warning("关闭 OpenTelemetry 失败: {}", e) try: from app.core.redis import redis_service @@ -112,6 +72,55 @@ async def _shutdown(): logger.warning("关闭 Redis 失败: {}", e) +internal_app = FastAPI( + title="Life Echo Internal Evaluation API", + version="0.1.0", + docs_url="/docs" if eval_cfg.internal_enable_docs else None, + redoc_url="/redoc" if eval_cfg.internal_enable_docs else None, + openapi_url="/openapi.json" if eval_cfg.internal_enable_docs else None, + lifespan=lifespan, +) + +instrument_fastapi_app(internal_app) + +internal_app.add_middleware(RequestIdMiddleware) +_origins = [ + o.strip() + for o in (eval_cfg.internal_cors_origins or "").split(",") + if o.strip() +] +_allow_creds = bool(_origins) +internal_app.add_middleware( + CORSMiddleware, + allow_origins=_origins if _origins else ["*"], + allow_credentials=_allow_creds, + allow_methods=["*"], + allow_headers=["*"], +) +register_exception_handlers(internal_app) + + +@internal_app.get("/", include_in_schema=False, response_class=HTMLResponse) +async def internal_eval_landing(): + """浏览器打开内部评测 API 根路径时提示:界面在 Vite(默认 5174),本进程仅为 API。""" + docs_hint = ( + '

OpenAPI 文档 /docs

' + if eval_cfg.internal_enable_docs + else "

(未开启文档;在 config/development.toml 的 [eval] internal_enable_docs = true)

" + ) + return f""" +内部评测 API + +

Life Echo · 内部回归评测 API

+

这里是 HTTP API(端口由启动命令决定),没有内置网页。 +浏览「回归评测台」请在仓库执行 ./development.shcd app-eval-web && npm run dev, +在终端里打开 Vite 给出的地址(一般为 http://127.0.0.1:5174/)。

+

健康检查:/health

+{docs_hint} +

会话与对比接口前缀:/internal/api/evaluation/

+""" + + internal_app.include_router(eval_router, prefix="/internal/api/evaluation") _static_dir = Path(__file__).resolve().parent.parent / "static" diff --git a/api/app/main.py b/api/app/main.py index 18a7b65..5cd2224 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -2,6 +2,7 @@ FastAPI 应用入口(app 内主入口,符合架构计划) """ +from contextlib import asynccontextmanager from pathlib import Path from app.core.logging import get_logger, setup_logging @@ -9,10 +10,11 @@ from app.core.logging import get_logger, setup_logging setup_logging() from app.core.config import settings +from app.core.runtime_constants import asr_defaults, otel_defaults from app.core.telemetry import instrument_fastapi_app, setup_telemetry setup_telemetry( - service_name=settings.otel_service_name or "life-echo-api", + service_name=otel_defaults.service_name or "life-echo-api", ) from fastapi import FastAPI @@ -44,44 +46,23 @@ from app.features.payment import models as _payment_models # noqa: F401 from app.features.story import models as _story_models # noqa: F401 from app.features.user import models as _user_models # noqa: F401 -app = FastAPI( - title="Life Echo API", - version="1.0.0", - docs_url="/docs" if settings.enable_docs else None, - redoc_url="/redoc" if settings.enable_docs else None, - openapi_url="/openapi.json" if settings.enable_docs else None, -) - -instrument_fastapi_app(app) - -# OpenAPI 全局增强 -app.openapi = lambda: custom_openapi(app) # type: ignore[assignment] - logger = get_logger(__name__) -# Middleware(注册顺序:LIFO,先注册的后执行) -app.add_middleware(RequestIdMiddleware) -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) -# 全局异常处理 -register_exception_handlers(app) - - -@app.on_event("startup") -async def startup_event(): - """应用启动事件:Alembic 迁移(可重试、可配置 fail-fast)、Redis、ASR、支付预初始化。""" +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期:启动迁移/Redis/ASR/支付预初始化,关闭时释放连接。""" import asyncio from app.core.alembic_startup import run_alembic_upgrade_at_startup logger.info("Life Echo API 正在启动...") - + if settings.app_environment == "production" and not ( + settings.api_cors_origins or "" + ).strip(): + logger.warning( + "生产环境未配置 API_CORS_ORIGINS;浏览器跨域请求将无法携带 credentials" + ) await asyncio.to_thread(run_alembic_upgrade_at_startup) try: @@ -104,11 +85,9 @@ async def startup_event(): else: asr_ready = True if asr_ready: - from app.core.config import settings - name = ( "腾讯云一句话识别" - if settings.asr_provider == "tencent" + if asr_defaults.provider == "tencent" else "本地 Whisper" ) logger.info("ASR 服务已就绪({})", name) @@ -131,12 +110,15 @@ async def startup_event(): except Exception as e: logger.warning("微信支付预初始化失败(首次下单时再初始化): {}", e) + yield -@app.on_event("shutdown") -async def shutdown_event(): - """应用关闭事件""" logger.info("Life Echo API 正在关闭...") + try: + from app.core.telemetry import shutdown_telemetry + shutdown_telemetry() + except Exception as e: + logger.warning("关闭 OpenTelemetry 失败: {}", e) try: from app.core.redis import redis_service @@ -146,6 +128,38 @@ async def shutdown_event(): logger.warning("关闭 Redis 连接失败: {}", e) +app = FastAPI( + title="Life Echo API", + version="1.0.0", + docs_url="/docs" if settings.enable_docs else None, + redoc_url="/redoc" if settings.enable_docs else None, + openapi_url="/openapi.json" if settings.enable_docs else None, + lifespan=lifespan, +) + +instrument_fastapi_app(app) + +# OpenAPI 全局增强 +app.openapi = lambda: custom_openapi(app) # type: ignore[assignment] + +# Middleware(注册顺序:LIFO,先注册的后执行) +app.add_middleware(RequestIdMiddleware) +_origins = [ + o.strip() for o in (settings.api_cors_origins or "").split(",") if o.strip() +] +_allow_creds = bool(_origins) +app.add_middleware( + CORSMiddleware, + allow_origins=_origins if _origins else ["*"], + allow_credentials=_allow_creds, + allow_methods=["*"], + allow_headers=["*"], +) + +# 全局异常处理 +register_exception_handlers(app) + + # ── Feature routers ────────────────────────────────────────── app.include_router(auth_router) app.websocket("/ws/conversation/{conversation_id}")(websocket_endpoint) diff --git a/api/app/ports/llm.py b/api/app/ports/llm.py index 8e436cc..83b2196 100644 --- a/api/app/ports/llm.py +++ b/api/app/ports/llm.py @@ -16,6 +16,10 @@ class LLMProvider(Protocol): ) -> str: """Single-turn completion, returns full response text. + Implementations MUST wrap underlying SDK calls with their own telemetry + (e.g. ``langchain_invoke_span``) so callers such as ``LlmGateway`` can + invoke ``complete()`` without adding a second observability layer. + ``max_tokens`` when set is passed to the underlying chat API (adapter-specific). """ diff --git a/api/app/tasks/celery_app.py b/api/app/tasks/celery_app.py index 62a439c..1a4068b 100644 --- a/api/app/tasks/celery_app.py +++ b/api/app/tasks/celery_app.py @@ -22,7 +22,13 @@ setup_telemetry(service_name="life-echo-celery-worker") instrument_celery() from celery import Celery -from celery.signals import task_failure, task_postrun, task_prerun, task_success +from celery.signals import ( + task_failure, + task_postrun, + task_prerun, + task_success, + worker_shutting_down, +) from app.core.celery_log_context import clear_celery_log_extras, set_celery_log_extras from app.core.log_events import celery_prerun_extras @@ -35,14 +41,24 @@ from app.features.memory import models as _memory_models # noqa: F401 from app.features.payment import models as _payment_models # noqa: F401 from app.features.story import models as _story_models # noqa: F401 from app.features.user import models as _user_models # noqa: F401 +from app.core.runtime_constants import celery_defaults +from app.features.memory.constants import memory -REDIS_URL = settings.redis_url +CELERY_REDIS_URL = settings.celery_redis_url_resolved + +_celery_lifecycle_log = get_logger(__name__) +_celery_lifecycle_log.info( + "event=celery_redis_urls business_redis_url={} celery_redis_url={} " + "msg=Celery broker/backend URL resolved", + settings.redis_url_resolved, + CELERY_REDIS_URL, +) # 创建 Celery 应用 celery_app = Celery( "life_echo", - broker=REDIS_URL, - backend=REDIS_URL, + broker=CELERY_REDIS_URL, + backend=CELERY_REDIS_URL, include=[ "app.tasks.memoir_tasks", "app.tasks.story_title_tasks", @@ -69,32 +85,63 @@ celery_app.conf.update( # 任务结果过期时间(1小时) result_expires=3600, # 任务执行设置 - task_soft_time_limit=300, # 5分钟软超时 + task_soft_time_limit=300, # 5分钟软超时(默认;重任务见 task_annotations) task_time_limit=600, # 10分钟硬超时 # 并发设置 worker_prefetch_multiplier=1, # 每次只预取一个任务 worker_concurrency=4, # 并发 worker 数量 + # Broker 连接 + broker_pool_limit=celery_defaults.broker_pool_limit, + broker_connection_retry_on_startup=celery_defaults.broker_connection_retry_on_startup, # 任务重试设置 task_acks_late=True, # 任务完成后再确认 task_reject_on_worker_lost=True, # worker 丢失时拒绝任务 task_routes={ "app.tasks.memory_enrichment_tasks.embed_memory_source": { - "queue": settings.celery_memory_enrichment_queue, + "queue": celery_defaults.memory_enrichment_queue, }, "app.tasks.memory_enrichment_tasks.enrich_memory_source": { - "queue": settings.celery_memory_enrichment_queue, + "queue": celery_defaults.memory_enrichment_queue, }, }, ) celery_app.conf.task_annotations = { "app.tasks.memory_enrichment_tasks.embed_memory_source": { - "soft_time_limit": 660, - "time_limit": 960, + "soft_time_limit": celery_defaults.enrichment_soft_time_limit, + "time_limit": celery_defaults.enrichment_hard_time_limit, }, "app.tasks.memory_enrichment_tasks.enrich_memory_source": { - "soft_time_limit": 660, - "time_limit": 960, + "soft_time_limit": celery_defaults.enrichment_soft_time_limit, + "time_limit": celery_defaults.enrichment_hard_time_limit, + }, + "app.tasks.memoir_tasks.process_memoir_phase1": { + "soft_time_limit": celery_defaults.memoir_soft_time_limit, + "time_limit": celery_defaults.memoir_hard_time_limit, + }, + "app.tasks.memoir_tasks.process_memoir_phase2": { + "soft_time_limit": celery_defaults.memoir_soft_time_limit, + "time_limit": celery_defaults.memoir_hard_time_limit, + }, + "app.tasks.memoir_tasks.generate_chapter_content": { + "soft_time_limit": celery_defaults.memoir_soft_time_limit, + "time_limit": celery_defaults.memoir_hard_time_limit, + }, + "app.tasks.story_image_tasks.generate_story_image": { + "soft_time_limit": celery_defaults.image_soft_time_limit, + "time_limit": celery_defaults.image_hard_time_limit, + }, + "app.tasks.chapter_cover_tasks.generate_chapter_cover": { + "soft_time_limit": celery_defaults.image_soft_time_limit, + "time_limit": celery_defaults.image_hard_time_limit, + }, + "app.tasks.chapter_compose_tasks.recompose_chapter": { + "soft_time_limit": celery_defaults.image_soft_time_limit, + "time_limit": celery_defaults.image_hard_time_limit, + }, + "app.tasks.memory_compaction_tasks.memory_compaction_sweep": { + "soft_time_limit": celery_defaults.compaction_sweep_soft_time_limit, + "time_limit": celery_defaults.compaction_sweep_hard_time_limit, }, } @@ -105,7 +152,11 @@ celery_app.conf.beat_schedule = { }, } -_celery_lifecycle_log = get_logger(__name__) +@worker_shutting_down.connect +def _shutdown_otel_on_worker_exit(**_: object) -> None: + from app.core.telemetry import shutdown_telemetry + + shutdown_telemetry() def _summarize_task_return(retval: object) -> str: diff --git a/api/app/tasks/chapter_compose_tasks.py b/api/app/tasks/chapter_compose_tasks.py index 98f71c8..7a9b591 100644 --- a/api/app/tasks/chapter_compose_tasks.py +++ b/api/app/tasks/chapter_compose_tasks.py @@ -10,18 +10,20 @@ from app.core.chapter_pipeline_lock import ( release_chapter_pipeline_lock, ) from app.core.config import settings -from app.core.db import get_sync_db +from app.core.db import get_sync_db, transactional_sync from app.core.logging import get_logger from app.core.memoir_pipeline_progress import merge_fanout_item from app.core.memoir_pipeline_trace import new_memoir_correlation_id from app.core.memory_compaction_schedule import schedule_memory_compaction_run from app.features.memoir import repo as memoir_repo from app.features.memoir.models import Chapter +from app.features.memoir.constants import memoir +from app.features.story.constants import story logger = get_logger(__name__) -@shared_task(bind=True, max_retries=8, default_retry_delay=30) +@shared_task(bind=True, max_retries=8, default_retry_delay=30, ignore_result=True) def recompose_chapter( self, chapter_id: str, memoir_correlation_id: str | None = None ) -> dict: @@ -29,7 +31,7 @@ def recompose_chapter( 按章节物化 canonical_markdown:仅当 markdown_compose_dirty 为 True 时执行; 与 pipeline 共用章节级 Redis 锁,拿不到锁则跳过(依赖后续触发重试)。 """ - lock_ttl = int(settings.chapter_pipeline_lock_ttl_seconds) + lock_ttl = int(story.chapter_pipeline_lock_ttl_seconds) tid = str(self.request.id) t0 = time.perf_counter() merge_fanout_item( @@ -90,10 +92,10 @@ def recompose_chapter( chapter_id, uid, stage, - settings.memoir_recompose_retry_on_lock_contention, + memoir.recompose_retry_on_lock_contention, ms, ) - if settings.memoir_recompose_retry_on_lock_contention: + if memoir.recompose_retry_on_lock_contention: countdown = max(15, min(120, lock_ttl // 4)) raise self.retry(countdown=countdown) merge_fanout_item( @@ -106,13 +108,12 @@ def recompose_chapter( ) return {"status": "skip_lock_contention"} try: - composed = memoir_repo.compose_chapter_from_story_links_sync( - session, chapter_id - ) - session.commit() + with transactional_sync(session): + composed = memoir_repo.compose_chapter_from_story_links_sync( + session, chapter_id + ) user_id = uid except Exception as exc: - session.rollback() logger.warning( "recompose_chapter failed chapter_id={} err={}", chapter_id, exc ) diff --git a/api/app/tasks/chapter_cover_enqueue.py b/api/app/tasks/chapter_cover_enqueue.py index 4fec856..afc25cc 100644 --- a/api/app/tasks/chapter_cover_enqueue.py +++ b/api/app/tasks/chapter_cover_enqueue.py @@ -6,13 +6,12 @@ from __future__ import annotations from typing import Literal -import redis from sqlalchemy import select from sqlalchemy.orm import joinedload -from app.core.config import settings from app.core.db import get_sync_db from app.core.logging import get_logger +from app.core.redis_sync import get_sync_redis from app.features.memoir.asset_resolver import strip_image_placeholders from app.features.memoir.cover_eligibility import ( chapter_eligible_for_cover_by_inline_body_image_count, @@ -94,8 +93,8 @@ def try_enqueue_generate_chapter_cover( return False key = _enqueue_dedup_key(chapter_id) + client = get_sync_redis(decode_responses=True) try: - client = redis.from_url(settings.redis_url, decode_responses=True) if not client.set( key, "1", nx=True, ex=CHAPTER_COVER_ENQUEUE_DEDUP_TTL_SECONDS ): @@ -123,7 +122,6 @@ def try_enqueue_generate_chapter_cover( exc, ) try: - client = redis.from_url(settings.redis_url, decode_responses=True) client.delete(key) except Exception: pass diff --git a/api/app/tasks/chapter_cover_tasks.py b/api/app/tasks/chapter_cover_tasks.py index 5dde8ce..ec91eee 100644 --- a/api/app/tasks/chapter_cover_tasks.py +++ b/api/app/tasks/chapter_cover_tasks.py @@ -16,7 +16,7 @@ from sqlalchemy import and_, func, or_, select, update from sqlalchemy.orm import joinedload from app.agents.image_prompt import get_image_prompt_orchestrator -from app.core.db import get_sync_db +from app.core.db import get_sync_db, transactional_sync from app.core.dependencies import get_image_generator from app.core.logging import get_logger from app.core.redis_lock import acquire_redis_lock, release_redis_lock @@ -38,6 +38,10 @@ CHAPTER_COVER_LOCK_TTL_SECONDS = 1800 CHAPTER_COVER_CLAIM_TTL_SECONDS = 1800 +class _ClaimSkipped(Exception): + """Concurrent worker won the intent claim; abort transactional block.""" + + def _build_cover_cos_key(user_id: str, chapter_id: str, prompt: str) -> str: short_hash = hashlib.sha1(prompt.encode("utf-8")).hexdigest()[:10] return f"chapters/{user_id}/{chapter_id}/cover-{short_hash}.png" @@ -100,24 +104,29 @@ def _claim_chapter_cover_intent_sync(db, chapter: Chapter, claim_token: str): .limit(1) ).scalar_one_or_none() if candidate_id: - claimed = db.execute( - update(ChapterCoverIntent) - .where(ChapterCoverIntent.id == candidate_id) - .where(_chapter_cover_claimable_clause(now)) - .values( - status="processing", - claim_token=claim_token, - claimed_at=now, - updated_at=now, - error=None, - attempt_count=func.coalesce(ChapterCoverIntent.attempt_count, 0) + 1, - ) - ) - if (claimed.rowcount or 0) != 1: - db.rollback() + try: + with transactional_sync(db): + claimed = db.execute( + update(ChapterCoverIntent) + .where(ChapterCoverIntent.id == candidate_id) + .where(_chapter_cover_claimable_clause(now)) + .values( + status="processing", + claim_token=claim_token, + claimed_at=now, + updated_at=now, + error=None, + attempt_count=func.coalesce( + ChapterCoverIntent.attempt_count, 0 + ) + + 1, + ) + ) + if (claimed.rowcount or 0) != 1: + raise _ClaimSkipped() + intent = db.get(ChapterCoverIntent, candidate_id) + except _ClaimSkipped: return None - intent = db.get(ChapterCoverIntent, candidate_id) - db.commit() return intent cutoff = now - timedelta(seconds=CHAPTER_COVER_CLAIM_TTL_SECONDS) @@ -141,13 +150,13 @@ def _claim_chapter_cover_intent_sync(db, chapter: Chapter, claim_token: str): claimed_at=now, attempt_count=1, ) - db.add(intent) - db.flush() - db.commit() + with transactional_sync(db): + db.add(intent) + db.flush() return intent -@shared_task(bind=True, max_retries=3, default_retry_delay=30) +@shared_task(bind=True, max_retries=3, default_retry_delay=30, ignore_result=True) def generate_chapter_cover(self, chapter_id: str): """ 为 chapter 生成封面。 @@ -263,30 +272,29 @@ def generate_chapter_cover(self, chapter_id: str): ) return {"status": "superseded_or_cancelled"} - asset = Asset( - id=asset_id, - asset_type="chapter_cover", - storage_key=cos_key, - url=url, - provider=settings.provider, - style_profile=style_for_image, - prompt_final=prompt_final, - status="completed", - ) - db.add(asset) - db.flush() + with transactional_sync(db): + asset = Asset( + id=asset_id, + asset_type="chapter_cover", + storage_key=cos_key, + url=url, + provider=settings.provider, + style_profile=style_for_image, + prompt_final=prompt_final, + status="completed", + ) + db.add(asset) + db.flush() - intent_db.asset_id = asset_id - intent_db.status = "completed" - intent_db.claim_token = None - intent_db.claimed_at = None - intent_db.error = None - intent_db.updated_at = datetime.now(timezone.utc) + intent_db.asset_id = asset_id + intent_db.status = "completed" + intent_db.claim_token = None + intent_db.claimed_at = None + intent_db.error = None + intent_db.updated_at = datetime.now(timezone.utc) - chapter_db = db.get(Chapter, chapter_id) - chapter_db.cover_asset_id = asset_id - - db.commit() + chapter_db = db.get(Chapter, chapter_id) + chapter_db.cover_asset_id = asset_id ms = (time.perf_counter() - t0) * 1000 logger.info( @@ -308,17 +316,17 @@ def generate_chapter_cover(self, chapter_id: str): except Exception as exc: if intent is not None: with get_sync_db() as db: - intent_db = db.get(ChapterCoverIntent, intent.id) - if ( - intent_db - and (intent_db.claim_token or "").strip() == claim_token - ): - intent_db.status = "failed" - intent_db.claim_token = None - intent_db.claimed_at = None - intent_db.error = str(exc) - intent_db.updated_at = datetime.now(timezone.utc) - db.commit() + with transactional_sync(db): + intent_db = db.get(ChapterCoverIntent, intent.id) + if ( + intent_db + and (intent_db.claim_token or "").strip() == claim_token + ): + intent_db.status = "failed" + intent_db.claim_token = None + intent_db.claimed_at = None + intent_db.error = str(exc) + intent_db.updated_at = datetime.now(timezone.utc) ms = (time.perf_counter() - t0) * 1000 logger.warning( "event=chapter_cover_task_failed chapter_id={} duration_ms={:.1f} error={} " diff --git a/api/app/tasks/memoir_quality_pass_tasks.py b/api/app/tasks/memoir_quality_pass_tasks.py index b6dfaa3..45e62f1 100644 --- a/api/app/tasks/memoir_quality_pass_tasks.py +++ b/api/app/tasks/memoir_quality_pass_tasks.py @@ -17,13 +17,15 @@ from sqlalchemy.orm import Session from app.agents.memoir.narrative_agent import NarrativeAgent from app.core.config import settings -from app.core.db import get_sync_db +from app.core.db import get_sync_db, transactional_sync from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger from app.core.memoir_pipeline_progress import merge_pipeline_run from app.features.memoir.models import Chapter from app.features.memoir.repo import mark_chapter_dirty_sync from app.features.story.models import Story +from app.features.memoir.constants import memoir +from app.features.story.constants import story logger = get_logger(__name__) @@ -54,7 +56,7 @@ def _polish_story_title( return False body = (story.canonical_markdown or "").strip() - if len(body) < settings.story_title_min_body_chars: + if len(body) < story.title_min_body_chars: return False narrative_agent = NarrativeAgent() @@ -76,7 +78,7 @@ def _polish_story_title( return True -@shared_task(bind=True, max_retries=2, default_retry_delay=30) +@shared_task(bind=True, max_retries=2, default_retry_delay=30, ignore_result=True) def memoir_quality_pass( self, user_id: str, @@ -89,7 +91,7 @@ def memoir_quality_pass( Runs asynchronously after the fast draft is committed and visible. """ qptid = str(self.request.id) - if not settings.memoir_quality_pass_enabled: + if not memoir.quality_pass_enabled: if memoir_correlation_id: merge_pipeline_run( memoir_correlation_id, @@ -152,34 +154,32 @@ def memoir_quality_pass( == "en" else "zh" ) - for sid in story_ids: - story = db.get(Story, sid) - if not story or story.user_id != user_id: - continue + with transactional_sync(db): + for sid in story_ids: + story = db.get(Story, sid) + if not story or story.user_id != user_id: + continue - chapter_category = story.stage or "summary" - if _polish_story_title( - db, - story, - llm, - chapter_category=chapter_category, - language=user_language, - ): - titles_polished += 1 - stmt = select(Chapter.id).where( - Chapter.user_id == user_id, - Chapter.category == chapter_category, - Chapter.is_active == True, # noqa: E712 - ) - ch_id = db.execute(stmt).scalar_one_or_none() - if ch_id: - chapters_dirtied.add(str(ch_id)) + chapter_category = story.stage or "summary" + if _polish_story_title( + db, + story, + llm, + chapter_category=chapter_category, + language=user_language, + ): + titles_polished += 1 + stmt = select(Chapter.id).where( + Chapter.user_id == user_id, + Chapter.category == chapter_category, + Chapter.is_active == True, # noqa: E712 + ) + ch_id = db.execute(stmt).scalar_one_or_none() + if ch_id: + chapters_dirtied.add(str(ch_id)) - for ch_id in chapters_dirtied: - mark_chapter_dirty_sync(db, ch_id) - - if titles_polished > 0: - db.commit() + for ch_id in chapters_dirtied: + mark_chapter_dirty_sync(db, ch_id) elapsed = time.perf_counter() - t0 duration_ms = elapsed * 1000 diff --git a/api/app/tasks/memoir_tasks.py b/api/app/tasks/memoir_tasks.py index 0d6b6da..6500a10 100644 --- a/api/app/tasks/memoir_tasks.py +++ b/api/app/tasks/memoir_tasks.py @@ -14,21 +14,22 @@ from celery import shared_task from celery.exceptions import Retry from celery.result import AsyncResult from sqlalchemy import func, select +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from app.agents.chat.background_voice import infer_background_voice from app.agents.chat.prompts_profile import format_user_profile_context from app.agents.memoir import MemoirOrchestrator from app.agents.stage_constants import normalize_chapter_category +from app.core.business_telemetry import business_span from app.core.chapter_pipeline_lock import ( acquire_chapter_pipeline_lock as _acquire_chapter_lock, ) from app.core.chapter_pipeline_lock import ( release_chapter_pipeline_lock as _release_chapter_lock, ) -from app.core.business_telemetry import business_span from app.core.config import settings -from app.core.db import AsyncSessionLocal, get_sync_db +from app.core.db import AsyncSessionLocal, get_sync_db, transactional_sync from app.core.dependencies import get_embedding_provider from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger @@ -69,11 +70,19 @@ from app.features.memoir.story_pipeline_sync import ( run_story_pipeline_for_category_batch, ) from app.features.memory.service import MemoryService +from app.features.memory.repo import get_transcript_source_by_segment_id_sync from app.features.user.models import User from app.tasks.celery_app import celery_app +from app.core.redis_sync import get_sync_redis +from app.core.runtime_constants import llm_defaults, redis_defaults +from app.features.memoir.constants import memoir +from app.features.story.constants import story logger = get_logger(__name__) -_REDIS_CLIENTS: dict[bool, redis.Redis] = {} + + +def _get_redis_client(*, decode_responses: bool = False) -> redis.Redis: + return get_sync_redis(decode_responses=decode_responses) def _run_post_pipeline_commit( @@ -146,7 +155,7 @@ def _get_llm_fast(): async def _memory_ingest_transcripts_batch( user_id: str, - items: list[tuple[str, str, dict | None]], + items: list[tuple[str, str, dict | None, str | None]], *, memoir_correlation_id: str, ) -> list[str]: @@ -159,6 +168,47 @@ async def _memory_ingest_transcripts_batch( ) +def _phase1_memory_ingest_batch_sync( + db: Session, + user_id: str, + ingest_items: list[tuple[str, str, dict | None, str | None]], + *, + memoir_correlation_id: str, +) -> list[str]: + """Run phase1 batch memory ingest; resolve segment unique races, else fail the task.""" + if not ingest_items: + return [] + try: + return asyncio.run( + _memory_ingest_transcripts_batch( + user_id, + ingest_items, + memoir_correlation_id=memoir_correlation_id, + ) + ) + except IntegrityError: + logger.warning( + "event=memoir_phase1_memory_ingest_race user_id={} item_count={} " + "msg=Concurrent segment ingest; resolving existing sources", + user_id, + len(ingest_items), + ) + resolved: list[str] = [] + for _conv_id, _text, _lineage, segment_id in ingest_items: + sid = (segment_id or "").strip() + if not sid: + continue + existing = get_transcript_source_by_segment_id_sync( + db, + user_id=user_id, + segment_id=sid, + ) + if existing is None: + raise + resolved.append(existing.id) + return resolved + + async def _memory_retrieve_evidence( user_id: str, query: str, @@ -171,21 +221,8 @@ async def _memory_retrieve_evidence( return bundle.model_dump() -def _get_redis_client(*, decode_responses: bool = False) -> redis.Redis: - from app.core.config import settings - - client = _REDIS_CLIENTS.get(decode_responses) - if client is None: - client = redis.from_url( - settings.redis_url, - decode_responses=decode_responses, - ) - _REDIS_CLIENTS[decode_responses] = client - return client - - def _chapter_lock_ttl() -> int: - return int(settings.chapter_pipeline_lock_ttl_seconds) + return int(story.chapter_pipeline_lock_ttl_seconds) def _update_task_status_sync( @@ -210,7 +247,7 @@ def _update_task_status_sync( task_info["result"] = result r.hset(key, task_id, json.dumps(task_info)) - r.expire(key, 3600) # 1小时过期 + r.expire(key, redis_defaults.task_tracker_ttl_seconds) logger.debug("任务状态已更新: task_id={} status={}", task_id, status) except Exception as e: @@ -309,7 +346,7 @@ def _should_trigger_phase2( chapter_category: str, current_segment_chars: int, ) -> bool: - if current_segment_chars >= int(settings.memoir_narrative_immediate_char_threshold): + if current_segment_chars >= int(memoir.narrative_immediate_char_threshold): return True user_convs = select(Conversation.id).where( Conversation.user_id == user_id, @@ -326,9 +363,9 @@ def _should_trigger_phase2( ) row = db.execute(stmt).one() count, total_chars = int(row[0] or 0), int(row[1] or 0) - if count >= int(settings.memoir_narrative_batch_min_segments): + if count >= int(memoir.narrative_batch_min_segments): return True - if total_chars >= int(settings.memoir_narrative_batch_min_chars): + if total_chars >= int(memoir.narrative_batch_min_chars): return True return False @@ -385,8 +422,8 @@ def _persist_phase2_route_defer( 返回 Celery 任务的 result dict(``status=deferred``)。 """ now_ts = datetime.now(timezone.utc) - max_attempts = int(settings.memoir_route_defer_max_attempts) - defer_seconds = float(settings.memoir_route_defer_seconds) + max_attempts = int(memoir.route_defer_max_attempts) + defer_seconds = float(memoir.route_defer_seconds) deferred_until_ts = now_ts + timedelta(seconds=max(defer_seconds, 1.0)) rows: list[Segment] = [] @@ -396,19 +433,18 @@ def _persist_phase2_route_defer( saturated_segments = 0 new_max_attempts_reached = False - for seg in rows: - prev_count = int(seg.narrative_defer_count or 0) - seg.narrative_defer_count = prev_count + 1 - seg.narrative_defer_reason = defer_reason - seg.narrative_last_attempt_at = now_ts - if seg.narrative_defer_count >= max_attempts: - seg.narrative_deferred_until = None - saturated_segments += 1 - new_max_attempts_reached = True - else: - seg.narrative_deferred_until = deferred_until_ts - - db.commit() + with transactional_sync(db): + for seg in rows: + prev_count = int(seg.narrative_defer_count or 0) + seg.narrative_defer_count = prev_count + 1 + seg.narrative_defer_reason = defer_reason + seg.narrative_last_attempt_at = now_ts + if seg.narrative_defer_count >= max_attempts: + seg.narrative_deferred_until = None + saturated_segments += 1 + new_max_attempts_reached = True + else: + seg.narrative_deferred_until = deferred_until_ts next_task_id: str | None = None if rows and not new_max_attempts_reached: @@ -469,7 +505,7 @@ def _schedule_phase2_timeout( ) -> str | None: """Reset countdown for Phase 2 narrative for one category。返回 Celery task_id。""" _revoke_phase2_timeout(user_id, chapter_category) - countdown = float(max(1.0, settings.memoir_narrative_batch_max_wait_seconds)) + countdown = float(max(1.0, memoir.narrative_batch_max_wait_seconds)) p2_kwargs: dict = {} if memoir_correlation_id: p2_kwargs["memoir_correlation_id"] = memoir_correlation_id @@ -504,7 +540,7 @@ def _dispatch_phase2_immediate( "kwargs": p2_kwargs, } fixed_tid: str | None = None - if settings.memoir_phase2_singleflight_immediate: + if memoir.phase2_singleflight_immediate: fixed_tid = _phase2_immediate_task_id(user_id, chapter_category) send_kw["task_id"] = fixed_tid ar = celery_app.send_task("app.tasks.memoir_tasks.process_memoir_phase2", **send_kw) @@ -515,7 +551,7 @@ def _dispatch_phase2_immediate( user_id, chapter_category, memoir_correlation_id or "", - "singleflight" if settings.memoir_phase2_singleflight_immediate else "unique", + "singleflight" if memoir.phase2_singleflight_immediate else "unique", out_tid or "", ) return out_tid @@ -580,7 +616,7 @@ def dispatch_pending_memoir_phase2_for_user(user_id: str) -> None: ) -@shared_task(bind=True, max_retries=3, default_retry_delay=30) +@shared_task(bind=True, max_retries=3, default_retry_delay=30, ignore_result=True) def process_memoir_phase2( self, user_id: str, @@ -734,9 +770,9 @@ def process_memoir_phase2( segment_texts = [seg.user_input_text or "" for seg in category_segments] combined_text = "\n\n".join(segment_texts) n_units = len(category_segments) - evidence_top_k = int(settings.evidence_top_k_default) - if n_units > int(settings.evidence_large_batch_threshold): - evidence_top_k = int(settings.evidence_top_k_large_batch) + evidence_top_k = int(story.evidence_top_k_default) + if n_units > int(story.evidence_large_batch_threshold): + evidence_top_k = int(story.evidence_top_k_large_batch) try: memory_evidence = asyncio.run( _memory_retrieve_evidence( @@ -813,34 +849,33 @@ def process_memoir_phase2( image_settings.enabled and chapter_needs_cover_enqueue(chapter) ) - stmt_book = ( - select(Book) - .where(Book.user_id == user_id) - .order_by(Book.updated_at.desc()) - ) - result_book = db.execute(stmt_book) - book = result_book.scalar_one_or_none() - if not book: - book = Book( - id=str(uuid.uuid4()), - user_id=user_id, - title="我的回忆录", - total_pages=0, - total_words=0, - cover_image_url=None, + with transactional_sync(db): + stmt_book = ( + select(Book) + .where(Book.user_id == user_id) + .order_by(Book.updated_at.desc()) ) - db.add(book) - book.has_update = True - book.last_update_chapter_id = chapter.id + result_book = db.execute(stmt_book) + book = result_book.scalar_one_or_none() + if not book: + book = Book( + id=str(uuid.uuid4()), + user_id=user_id, + title="我的回忆录", + total_pages=0, + total_words=0, + cover_image_url=None, + ) + db.add(book) + book.has_update = True + book.last_update_chapter_id = chapter.id - if needs_cover_enqueue: - chapters_to_enqueue.add(chapter.id) + if needs_cover_enqueue: + chapters_to_enqueue.add(chapter.id) - for seg in category_segments: - seg.narrated = True - seg.processed = True - - db.commit() + for seg in category_segments: + seg.narrated = True + seg.processed = True _run_post_pipeline_commit( user_id=user_id, @@ -925,7 +960,7 @@ def process_memoir_phase2( raise self.retry(exc=e) from e -@shared_task(bind=True, max_retries=3, default_retry_delay=60) +@shared_task(bind=True, max_retries=3, default_retry_delay=60, ignore_result=True) def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): """ Phase 1:记忆 ingest + 抽取/分类;持久化 topic_category / skip_narrative; @@ -991,6 +1026,30 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): ) return {"status": "no_segments"} + for seg in segments: + db.refresh(seg) + + lineage_missing = [ + str(seg.id) + for seg in segments + if (seg.agent_response or "").strip() + and not isinstance(getattr(seg, "lineage_json", None), dict) + ] + if lineage_missing: + logger.warning( + "event=memoir_phase1_lineage_pending user_id={} task_id={} " + "segment_ids={} msg=Agent response persisted without lineage; retrying", + user_id, + task_id, + lineage_missing, + ) + raise self.retry( + countdown=15, + exc=RuntimeError( + f"memoir_phase1_lineage_pending: {len(lineage_missing)} segments" + ), + ) + merge_pipeline_run( memoir_correlation_id, { @@ -1002,8 +1061,9 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): ) ingest_t0 = time.perf_counter() with business_span("memoir.phase1.ingest"): - ingest_items: list[tuple[str, str, dict | None]] = [] + ingest_items: list[tuple[str, str, dict | None, str | None]] = [] non_empty_segments: list = [] + ingested_source_ids: list[str] = [] for seg in segments: text = (seg.user_input_text or "").strip() if not text: @@ -1011,37 +1071,46 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): conv_id = getattr(seg, "conversation_id", None) or "" ln = getattr(seg, "lineage_json", None) lineage_payload = ln if isinstance(ln, dict) else None - ingest_items.append((conv_id, text, lineage_payload)) + if lineage_payload is None and not (seg.agent_response or "").strip(): + logger.debug( + "event=memoir_phase1_skip_memory_ingest segment_id={} " + "msg=No lineage and no agent response yet", + seg.id, + ) + continue + existing = get_transcript_source_by_segment_id_sync( + db, + user_id=user_id, + segment_id=str(seg.id), + ) + if existing is not None: + ingested_source_ids.append(existing.id) + continue + ingest_items.append( + (conv_id, text, lineage_payload, str(seg.id)), + ) non_empty_segments.append(seg) - ingested_source_ids: list[str] = [] if ingest_items: - try: - ingested_source_ids = asyncio.run( - _memory_ingest_transcripts_batch( - user_id, - ingest_items, - memoir_correlation_id=memoir_correlation_id, - ) - ) - for seg, sid in zip( - non_empty_segments, ingested_source_ids, strict=True - ): - logger.info( - "event=memory_transcript_ingested user_id={} task_id={} " - "source_id={} conversation_id={} segment_id={} transcript_chars={}", - user_id, - task_id, - sid, - getattr(seg, "conversation_id", None) or "", - seg.id, - len((seg.user_input_text or "").strip()), - ) - except Exception as e: - logger.warning( - "Memory batch ingest 失败: {} exc_type={}", - e, - type(e).__name__, + new_source_ids = _phase1_memory_ingest_batch_sync( + db, + user_id, + ingest_items, + memoir_correlation_id=memoir_correlation_id, + ) + ingested_source_ids.extend(new_source_ids) + for seg, sid in zip( + non_empty_segments, new_source_ids, strict=True + ): + logger.info( + "event=memory_transcript_ingested user_id={} task_id={} " + "source_id={} conversation_id={} segment_id={} transcript_chars={}", + user_id, + task_id, + sid, + getattr(seg, "conversation_id", None) or "", + seg.id, + len((seg.user_input_text or "").strip()), ) ingest_elapsed = time.perf_counter() - ingest_t0 merge_pipeline_run( @@ -1059,10 +1128,10 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): llm = _get_llm() llm_fast = _get_llm_fast() or llm - if (settings.llm_fast_model or "").strip(): + if (llm_defaults.fast_model or "").strip(): logger.info( "event=llm_fast_tier_used pipeline=memoir_prepare_batches model={}", - settings.llm_fast_model, + llm_defaults.fast_model, ) prep_t0 = time.perf_counter() @@ -1118,42 +1187,43 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): f"memoir_phase1_missing_category: {len(missing_cat)} segments" ) - for seg in segments: - cat = prepared.segment_chapter_category[str(seg.id)] - seg.topic_category = cat - is_skip = str(seg.id) in skip_ids - seg.skip_narrative = is_skip - seg.narrated = False - if is_skip: - seg.processed = True - - db.flush() - categories_for_phase2: Set[str] = set() phase2_immediate: list[str] = [] phase2_timeout: list[str] = [] woke_up_by_category: dict[str, int] = {} - for chapter_category, cat_segments in prepared.category_to_segments.items(): - batch_non_skip = [ - s - for s in cat_segments - if str(s.id) not in prepared.segment_skip_story_ids - ] - if not batch_non_skip: - continue - woke = _wake_deferred_segments_for_category( - db, user_id, chapter_category - ) - if woke: - woke_up_by_category[chapter_category] = woke - max_chars = max( - len((s.user_input_text or "").strip()) for s in batch_non_skip - ) - categories_for_phase2.add(chapter_category) - if _should_trigger_phase2(db, user_id, chapter_category, max_chars): - phase2_immediate.append(chapter_category) - else: - phase2_timeout.append(chapter_category) + with transactional_sync(db): + for seg in segments: + cat = prepared.segment_chapter_category[str(seg.id)] + seg.topic_category = cat + is_skip = str(seg.id) in skip_ids + seg.skip_narrative = is_skip + seg.narrated = False + if is_skip: + seg.processed = True + + db.flush() + + for chapter_category, cat_segments in prepared.category_to_segments.items(): + batch_non_skip = [ + s + for s in cat_segments + if str(s.id) not in prepared.segment_skip_story_ids + ] + if not batch_non_skip: + continue + woke = _wake_deferred_segments_for_category( + db, user_id, chapter_category + ) + if woke: + woke_up_by_category[chapter_category] = woke + max_chars = max( + len((s.user_input_text or "").strip()) for s in batch_non_skip + ) + categories_for_phase2.add(chapter_category) + if _should_trigger_phase2(db, user_id, chapter_category, max_chars): + phase2_immediate.append(chapter_category) + else: + phase2_timeout.append(chapter_category) if woke_up_by_category: logger.info( @@ -1163,8 +1233,6 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): woke_up_by_category, ) - db.commit() - merge_pipeline_run( memoir_correlation_id, { @@ -1278,7 +1346,7 @@ def process_memoir_phase1(self, user_id: str, segment_ids: List[str]): _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)}) raise self.retry(exc=e) from e -@shared_task(bind=True, max_retries=3, default_retry_delay=30) +@shared_task(bind=True, max_retries=3, default_retry_delay=30, ignore_result=True) def generate_chapter_content(self, user_id: str, stage: str, new_content: str): """ 单独生成章节内容的任务(用于实时更新) @@ -1360,7 +1428,8 @@ def generate_chapter_content(self, user_id: str, stage: str, new_content: str): exc=RuntimeError("story_pipeline returned no chapter"), countdown=30, ) - db.commit() + with transactional_sync(db): + pass # commit pending pipeline writes db.refresh(chapter) ch_ids: set[str] = {str(chapter.id)} diff --git a/api/app/tasks/memory_compaction_tasks.py b/api/app/tasks/memory_compaction_tasks.py index 4002cf7..d69e326 100644 --- a/api/app/tasks/memory_compaction_tasks.py +++ b/api/app/tasks/memory_compaction_tasks.py @@ -10,8 +10,7 @@ from typing import Any from celery import shared_task from app.core.business_telemetry import business_span -from app.core.config import settings -from app.core.db import AsyncSessionLocal +from app.core.db import AsyncSessionLocal, transactional from app.core.logging import get_logger from app.core.memory_compaction_schedule import ( finalize_memory_compaction_run, @@ -23,6 +22,7 @@ from app.core.memory_compaction_schedule import ( from app.core.redis_lock import acquire_redis_lock, release_redis_lock from app.features.memory.repo import list_users_with_recent_chunks from app.features.memory.service import MemoryService +from app.features.memory.constants import memory logger = get_logger(__name__) @@ -37,41 +37,53 @@ async def _run_memory_compaction_async( context: dict[str, Any] | None, ) -> dict[str, Any]: async with AsyncSessionLocal() as db: - service = MemoryService(db) - out = await service.compact_user(user_id, context) - await db.commit() - return out + async with transactional(db): + service = MemoryService(db) + return await service.compact_user(user_id, context) -@shared_task -def memory_compaction_sweep() -> dict[str, Any]: +@shared_task(bind=True, ignore_result=True) +def memory_compaction_sweep(self) -> dict[str, Any]: """Beat:为近期有记忆写入的用户调度 compaction(debounce 仍由 schedule 合并)。""" t0 = time.perf_counter() - if not settings.memory_compaction_enabled: + if not memory.compaction_enabled: return {"skipped": True, "reason": "disabled"} - hours = int(settings.memory_compaction_sweep_recent_hours) + hours = int(memory.compaction_sweep_recent_hours) with business_span("memory.compaction.sweep", hours=hours): user_ids = asyncio.run(_list_users_with_recent_chunks_async(hours)) ctx_base: dict[str, Any] = {"trigger_source": "beat", "sweep_hours": hours} + scheduled = 0 + failed = 0 for uid in user_ids: - schedule_memory_compaction_run(uid, dict(ctx_base)) + try: + schedule_memory_compaction_run(uid, dict(ctx_base)) + scheduled += 1 + except Exception as exc: + failed += 1 + logger.warning( + "event=memory_compaction_sweep_schedule_failed user_id={} exc={} " + "msg=单用户 compaction 调度失败,继续扫描", + uid, + exc, + ) ms = (time.perf_counter() - t0) * 1000 logger.info( - "event=memory_compaction_sweep_done hours={} scheduled_users={} duration_ms={:.1f} " - "msg=记忆压缩定时扫描已调度", + "event=memory_compaction_sweep_done hours={} scheduled_users={} failed_users={} " + "duration_ms={:.1f} msg=记忆压缩定时扫描已调度", hours, - len(user_ids), + scheduled, + failed, ms, ) - return {"scheduled": len(user_ids), "user_ids": user_ids} + return {"scheduled": scheduled, "failed": failed, "hours": hours} -@shared_task(bind=True, max_retries=12, default_retry_delay=20) +@shared_task(bind=True, max_retries=12, default_retry_delay=20, ignore_result=True) def memory_compaction_run( self, user_id: str, context: dict[str, Any] | None = None ) -> dict[str, Any]: run_t0 = time.perf_counter() - if not settings.memory_compaction_enabled: + if not memory.compaction_enabled: return {"skipped": True, "reason": "disabled"} ctx = dict(context or {}) @@ -83,7 +95,7 @@ def memory_compaction_run( lock = acquire_redis_lock( f"lock:memory_compaction:{user_id}", - ttl_seconds=settings.memory_compaction_lock_ttl_seconds, + ttl_seconds=memory.compaction_lock_ttl_seconds, ) if lock is None: ms = (time.perf_counter() - run_t0) * 1000 diff --git a/api/app/tasks/memory_enrichment_tasks.py b/api/app/tasks/memory_enrichment_tasks.py index 2a70136..4a1c296 100644 --- a/api/app/tasks/memory_enrichment_tasks.py +++ b/api/app/tasks/memory_enrichment_tasks.py @@ -1,7 +1,7 @@ """ Memory pipeline Celery tasks — retry embedding and enrichment after durable ingest. -Tasks are routed to ``settings.celery_memory_enrichment_queue`` (default ``memory_idle``); +Tasks are routed to ``celery_defaults.memory_enrichment_queue`` (default ``memory_idle``); run workers with ``-Q celery,memory_idle`` or a dedicated low-priority worker for that queue. """ @@ -18,6 +18,8 @@ from app.core.dependencies import get_embedding_provider from app.core.logging import get_logger from app.core.memoir_pipeline_progress import merge_fanout_item from app.features.memory.service import MemoryService +from app.core.runtime_constants import celery_defaults +from app.features.memory.constants import memory logger = get_logger(__name__) @@ -29,7 +31,6 @@ async def _enrich_memory_source_async( async with AsyncSessionLocal() as db: service = MemoryService(db) await service.enrich_source(user_id, source_id, llm=None) - await db.commit() async def _embed_memory_source_async( @@ -43,7 +44,6 @@ async def _embed_memory_source_async( source_id, raise_on_failure=True, ) - await db.commit() return result @@ -58,7 +58,7 @@ def schedule_memory_embedding( sid = (source_id or "").strip() if not uid or not sid: return None - q = (settings.celery_memory_enrichment_queue or "").strip() or "memory_idle" + q = (celery_defaults.memory_enrichment_queue or "").strip() or "memory_idle" try: task = cast(Any, embed_memory_source) ar = task.apply_async( @@ -103,13 +103,13 @@ def schedule_memory_enrichment( When ``memoir_correlation_id`` is set, records ``fanout.memory_enrichment`` as enqueued for eval / pipeline progress (same as the former Phase1 loop). """ - if not settings.memory_enrichment_enabled: + if not memory.enrichment_enabled: return None uid = (user_id or "").strip() sid = (source_id or "").strip() if not uid or not sid: return None - q = (settings.celery_memory_enrichment_queue or "").strip() or "memory_idle" + q = (celery_defaults.memory_enrichment_queue or "").strip() or "memory_idle" try: task = cast(Any, enrich_memory_source) ar = task.apply_async( @@ -222,7 +222,7 @@ def enrich_memory_source( Post-ingest enrichment: one LLM call → session summary + structured facts. Runs outside the memoir Phase1 hot path so narrative generation isn't blocked. """ - if not settings.memory_enrichment_enabled: + if not memory.enrichment_enabled: return {"status": "disabled"} tid = str(self.request.id) diff --git a/api/app/tasks/story_image_tasks.py b/api/app/tasks/story_image_tasks.py index 50ad6f8..5acd473 100644 --- a/api/app/tasks/story_image_tasks.py +++ b/api/app/tasks/story_image_tasks.py @@ -15,7 +15,7 @@ from PIL import Image from sqlalchemy import and_, func, or_, select, update from app.agents.image_prompt import get_image_prompt_orchestrator -from app.core.db import get_sync_db +from app.core.db import get_sync_db, transactional_sync from app.core.dependencies import get_image_generator from app.core.logging import get_logger from app.core.memoir_pipeline_progress import merge_fanout_item @@ -34,21 +34,25 @@ STORY_IMAGE_LOCK_TTL_SECONDS = 1800 STORY_IMAGE_CLAIM_TTL_SECONDS = 1800 +class _ClaimSkipped(Exception): + """Concurrent worker won the intent claim; abort transactional block.""" + + def _enqueue_chapter_effects_after_image_backfill(story_id: str) -> None: """主图回填后标记关联章节 dirty,并经统一 post-commit 入口派发章节物化与 compaction。""" try: with get_sync_db() as session: from app.features.memoir import repo as memoir_repo - story = session.get(Story, story_id) - if not story: - return - uid = str(story.user_id) - memoir_repo.mark_chapters_dirty_for_story_sync(session, story_id) - chapter_ids = memoir_repo.get_chapter_ids_linked_to_story_sync( - session, story_id - ) - session.commit() + with transactional_sync(session): + story = session.get(Story, story_id) + if not story: + return + uid = str(story.user_id) + memoir_repo.mark_chapters_dirty_for_story_sync(session, story_id) + chapter_ids = memoir_repo.get_chapter_ids_linked_to_story_sync( + session, story_id + ) user_id = uid except Exception as exc: logger.warning( @@ -120,37 +124,40 @@ def _claim_story_image_intent_sync(db, story_id: str, claim_token: str): if not candidate_id: return None - claimed = db.execute( - update(StoryImageIntent) - .where(StoryImageIntent.id == candidate_id) - .where(_story_image_claimable_clause(now)) - .values( - status="processing", - claim_token=claim_token, - claimed_at=now, - updated_at=now, - error=None, - attempt_count=func.coalesce(StoryImageIntent.attempt_count, 0) + 1, - ) - ) - if (claimed.rowcount or 0) != 1: - db.rollback() - return None + try: + with transactional_sync(db): + claimed = db.execute( + update(StoryImageIntent) + .where(StoryImageIntent.id == candidate_id) + .where(_story_image_claimable_clause(now)) + .values( + status="processing", + claim_token=claim_token, + claimed_at=now, + updated_at=now, + error=None, + attempt_count=func.coalesce(StoryImageIntent.attempt_count, 0) + + 1, + ) + ) + if (claimed.rowcount or 0) != 1: + raise _ClaimSkipped() - row = ( - db.execute( - select(StoryImageIntent, Story) - .join(Story, StoryImageIntent.story_id == Story.id) - .where(StoryImageIntent.id == candidate_id) - ) - .unique() - .first() - ) - db.commit() + row = ( + db.execute( + select(StoryImageIntent, Story) + .join(Story, StoryImageIntent.story_id == Story.id) + .where(StoryImageIntent.id == candidate_id) + ) + .unique() + .first() + ) + except _ClaimSkipped: + return None return row -@shared_task(bind=True, max_retries=3, default_retry_delay=30) +@shared_task(bind=True, max_retries=3, default_retry_delay=30, ignore_result=True) def generate_story_image(self, story_id: str, memoir_correlation_id: str | None = None): """ 为 story 生成主插图。 @@ -208,14 +215,14 @@ def generate_story_image(self, story_id: str, memoir_correlation_id: str | None ).strip() if len(plain) < min_body: with get_sync_db() as db: - intent_db = db.get(StoryImageIntent, intent.id) - if intent_db and (intent_db.status or "").strip() == "processing": - intent_db.status = "skipped" - intent_db.error = f"body_below_min_chars:{len(plain)}" - intent_db.claim_token = None - intent_db.claimed_at = None - intent_db.updated_at = datetime.now(timezone.utc) - db.commit() + with transactional_sync(db): + intent_db = db.get(StoryImageIntent, intent.id) + if intent_db and (intent_db.status or "").strip() == "processing": + intent_db.status = "skipped" + intent_db.error = f"body_below_min_chars:{len(plain)}" + intent_db.claim_token = None + intent_db.claimed_at = None + intent_db.updated_at = datetime.now(timezone.utc) logger.info( "generate_story_image: skipped body too short story={} len={} min={}", story_id, @@ -299,96 +306,93 @@ def generate_story_image(self, story_id: str, memoir_correlation_id: str | None ) return {"status": "superseded_or_cancelled"} - asset = Asset( - id=asset_id, - asset_type="story_image", - storage_key=cos_key, - url=url, - provider=settings.provider, - style_profile=style_for_image, - prompt_final=prompt_final, - status="completed", - ) - db.add(asset) - db.flush() - - story_db = db.get(Story, story_id) - target_vid = intent_db.story_version_id or story_db.current_version_id - current_vid = story_db.current_version_id - - intent_db.asset_id = asset_id - intent_db.status = "completed" - intent_db.claim_token = None - intent_db.claimed_at = None - intent_db.error = None - intent_db.updated_at = datetime.now(timezone.utc) - db.flush() - - # 仅当 intent 仍指向当前版本时回填正文,避免慢任务/重试把图插到新版本上 - if not target_vid or target_vid != current_vid: - db.commit() - logger.debug( - "generate_story_image: stale intent skip backfill story={} " - "intent_ver={} current={} url={} asset={}", - story_id, - target_vid, - current_vid, - url, - asset_id, + with transactional_sync(db): + asset = Asset( + id=asset_id, + asset_type="story_image", + storage_key=cos_key, + url=url, + provider=settings.provider, + style_profile=style_for_image, + prompt_final=prompt_final, + status="completed", ) - merge_fanout_item( - memoir_correlation_id, - list_name="story_images", - id_field="story_id", - item_id=story_id, - task_id=celery_tid, - status="success_stale", + db.add(asset) + db.flush() + + story_db = db.get(Story, story_id) + target_vid = intent_db.story_version_id or story_db.current_version_id + current_vid = story_db.current_version_id + + intent_db.asset_id = asset_id + intent_db.status = "completed" + intent_db.claim_token = None + intent_db.claimed_at = None + intent_db.error = None + intent_db.updated_at = datetime.now(timezone.utc) + db.flush() + + # 仅当 intent 仍指向当前版本时回填正文,避免慢任务/重试把图插到新版本上 + if not target_vid or target_vid != current_vid: + logger.debug( + "generate_story_image: stale intent skip backfill story={} " + "intent_ver={} current={} url={} asset={}", + story_id, + target_vid, + current_vid, + url, + asset_id, + ) + merge_fanout_item( + memoir_correlation_id, + list_name="story_images", + id_field="story_id", + item_id=story_id, + task_id=celery_tid, + status="success_stale", + ) + return {"status": "success_stale", "asset_id": asset_id} + + ver = db.get(StoryVersion, target_vid) + if not ver: + merge_fanout_item( + memoir_correlation_id, + list_name="story_images", + id_field="story_id", + item_id=story_id, + task_id=celery_tid, + status="success_no_snapshot", + ) + return {"status": "success_no_snapshot", "asset_id": asset_id} + + base_md = strip_asset_image_refs_from_markdown(ver.markdown_snapshot or "") + alt_text = (getattr(intent_db, "prompt_brief", None) or "").strip() + if not alt_text: + alt_text = (getattr(intent_db, "caption", None) or "").strip() + backfilled_md = backfill_image_into_markdown( + base_md, + asset_id=asset_id, + image_alt=alt_text or "主插图", ) - return {"status": "success_stale", "asset_id": asset_id} - - ver = db.get(StoryVersion, target_vid) - if not ver: - db.commit() - merge_fanout_item( - memoir_correlation_id, - list_name="story_images", - id_field="story_id", - item_id=story_id, - task_id=celery_tid, - status="success_no_snapshot", + max_stmt = select(func.max(StoryVersion.version_no)).where( + StoryVersion.story_id == story_id ) - return {"status": "success_no_snapshot", "asset_id": asset_id} - - base_md = strip_asset_image_refs_from_markdown(ver.markdown_snapshot or "") - alt_text = (getattr(intent_db, "prompt_brief", None) or "").strip() - if not alt_text: - alt_text = (getattr(intent_db, "caption", None) or "").strip() - backfilled_md = backfill_image_into_markdown( - base_md, - asset_id=asset_id, - image_alt=alt_text or "主插图", - ) - max_stmt = select(func.max(StoryVersion.version_no)).where( - StoryVersion.story_id == story_id - ) - max_no = db.execute(max_stmt).scalar() - version_no = (max_no or 0) + 1 - new_ver = StoryVersion( - id=str(uuid.uuid4()), - story_id=story_id, - version_no=version_no, - markdown_snapshot=backfilled_md, - change_summary="主插图回填", - actor_type="system", - source_type="image_backfill", - parent_version_id=story_db.current_version_id, - ) - db.add(new_ver) - db.flush() - story_db.current_version_id = new_ver.id - story_db.canonical_markdown = backfilled_md - - db.commit() + max_no = db.execute(max_stmt).scalar() + version_no = (max_no or 0) + 1 + new_ver = StoryVersion( + id=str(uuid.uuid4()), + story_id=story_id, + version_no=version_no, + markdown_snapshot=backfilled_md, + change_summary="主插图回填", + actor_type="system", + source_type="image_backfill", + parent_version_id=story_db.current_version_id, + ) + db.add(new_ver) + db.flush() + story_db.current_version_id = new_ver.id + story_db.canonical_markdown = backfilled_md _enqueue_chapter_effects_after_image_backfill(story_id) @@ -420,18 +424,18 @@ def generate_story_image(self, story_id: str, memoir_correlation_id: str | None except Exception as exc: if intent is not None: with get_sync_db() as db: - intent_db = db.get(StoryImageIntent, intent.id) - if ( - intent_db - and (intent_db.status or "").strip() != "completed" - and (intent_db.claim_token or "").strip() == claim_token - ): - intent_db.status = "failed" - intent_db.claim_token = None - intent_db.claimed_at = None - intent_db.error = str(exc) - intent_db.updated_at = datetime.now(timezone.utc) - db.commit() + with transactional_sync(db): + intent_db = db.get(StoryImageIntent, intent.id) + if ( + intent_db + and (intent_db.status or "").strip() != "completed" + and (intent_db.claim_token or "").strip() == claim_token + ): + intent_db.status = "failed" + intent_db.claim_token = None + intent_db.claimed_at = None + intent_db.error = str(exc) + intent_db.updated_at = datetime.now(timezone.utc) merge_fanout_item( memoir_correlation_id, list_name="story_images", diff --git a/api/app/tasks/story_title_tasks.py b/api/app/tasks/story_title_tasks.py index 7452c91..bcca9f7 100644 --- a/api/app/tasks/story_title_tasks.py +++ b/api/app/tasks/story_title_tasks.py @@ -4,14 +4,14 @@ import time from celery import shared_task -from app.core.db import get_sync_db +from app.core.db import get_sync_db, transactional_sync from app.core.llm_gateway import LlmGateway, LlmUseCase from app.core.logging import get_logger logger = get_logger(__name__) -@shared_task(bind=True, max_retries=2, default_retry_delay=15) +@shared_task(bind=True, max_retries=2, default_retry_delay=15, ignore_result=True) def generate_story_title_after_create( self, story_id: str, @@ -128,8 +128,8 @@ def generate_story_title_after_create( ms, ) return {"status": "skip_placeholder"} - st.title = new_title - db.commit() + with transactional_sync(db): + st.title = new_title ms = (time.perf_counter() - t0) * 1000 logger.info( "event=story_title_task_done story_id={} user_id={} duration_ms={:.1f} " diff --git a/api/config/default.toml b/api/config/default.toml new file mode 100644 index 0000000..7788fc6 --- /dev/null +++ b/api/config/default.toml @@ -0,0 +1,239 @@ +# Life Echo API — default configuration (SSOT base layer) +# Environment overlays: config/development.toml | staging.toml | production.toml + +[deploy] +alembic_startup_fail_fast = false +access_token_expire_minutes = 120 +refresh_token_expire_days = 30 +refresh_token_reuse_grace_seconds = 30 +mock_sms_login_enabled = false +tencent_sms_sdk_app_id = "" +tencent_sms_sign_name = "" +tencent_sms_template_id = "" +tencent_cos_bucket = "" +tencent_cos_base_url = "" +enable_tts = true +memoir_image_enabled = false +enable_docs = true +wechat_pay_app_id = "" +wechat_pay_mch_id = "" +wechat_pay_private_key_path = "certs/apiclient_key.pem" +wechat_pay_cert_serial_no = "" +wechat_pay_notify_url = "" +wechat_pay_platform_public_key_path = "" +wechat_pay_platform_public_key_id = "" +alipay_app_id = "" +alipay_notify_url = "" +liblib_template_uuid = "" +log_level = "INFO" +otel_enabled = false +otel_exporter_otlp_endpoint = "http://localhost:48317" +api_cors_origins = "" + +[chat] +interview_max_tokens = 512 +interview_max_segments = 2 +interview_max_chars_per_segment = 380 +opening_max_tokens = 380 +profile_followup_max_tokens = 280 +history_max_pairs = 15 +history_max_chars = 6000 +era_context_enabled = true +stage_detection_enabled = true +stage_detection_max_tokens = 128 +interview_persona = "default" +interview_temperature = 0.93 +memory_retrieval_enabled = true +memory_top_k = 8 +memory_evidence_max_chars = 4096 +memory_safe_evidence_format_enabled = true +reply_planner_llm_enabled = false +reply_planner_max_tokens = 256 +reply_planner_temperature = 0.2 +re_greeting_enabled = true +re_greeting_idle_hours = 6.0 +topic_chips_enabled = true +topic_chips_max = 4 +input_normalize_enabled = true +input_normalize_mode = "rules" +input_normalize_llm_max_tokens = 512 +input_normalize_llm_max_input_chars = 8000 +input_normalize_llm_voice_only = true +profile_max_turns = 8 +profile_extract_max_tokens = 512 + +[memoir] +fidelity_check_enabled = true +fidelity_check_max_tokens = 512 +oral_normalize_enabled = true +oral_normalize_mode = "rules" +oral_normalize_llm_max_tokens = 512 +oral_normalize_llm_max_input_chars = 8000 +phase1_batch_llm_enabled = true +phase1_batch_llm_max_tokens = 4096 +phase1_batch_llm_chunk_size = 24 +pipeline_run_ttl_seconds = 172800 +extraction_max_tokens = 1024 +classification_max_tokens = 256 +narrative_max_tokens = 4096 +narrative_merge_max_tokens = 8192 +title_max_tokens = 256 +story_route_max_tokens = 1024 +story_batch_plan_max_tokens = 4096 +segment_batch_min_chars = 50 +segment_batch_max_wait_seconds = 60.0 +narrative_immediate_char_threshold = 50 +narrative_batch_min_segments = 3 +narrative_batch_min_chars = 80 +narrative_batch_max_wait_seconds = 120.0 +extraction_updates_current_stage = false +fidelity_fail_open_on_parse_error = false +narrative_evidence_overlap_min_chars = 14 +evidence_scene_anchor_check_enabled = true +title_slots_require_body_or_oral_match = true +title_hay_grounding_strict_phrases_enabled = true +recompose_retry_on_lock_contention = true +phase2_singleflight_immediate = true +route_defer_enabled = true +route_defer_seconds = 120.0 +route_defer_max_attempts = 3 +quality_pass_enabled = true +quality_pass_delay_seconds = 5 +story_route_append_guardrail_oral_chars = 1800 +min_inline_images_for_chapter_cover = 1 +image_poll_interval = 3 +image_max_attempts = 20 +image_provider = "liblib" +image_style_default = "watercolor" +image_size_default = "1280x720" +image_download_hosts = "" +image_prompt_fallback_disabled = false + +[memory] +enrichment_enabled = true +enrichment_max_chars = 12000 +compaction_enabled = true +compaction_debounce_seconds = 105 +compaction_lock_ttl_seconds = 600 +compaction_chunk_similarity_threshold = 0.92 +compaction_min_layers_for_exclude = 2 +compaction_max_chunks_per_run = 200 +compaction_max_excludes_per_run = 50 +compaction_max_neighbors_per_chunk = 25 +compaction_text_jaccard_min = 0.55 +compaction_metadata_event_year_window = 1 +compaction_sweep_recent_hours = 24 + +[story] +image_min_body_chars = 400 +image_enqueue_dedup_ttl = 300 +recompose_chapter_delay_seconds = 8 +chapter_pipeline_lock_ttl_seconds = 360 +append_max_canonical_chars = 12000 +append_max_versions = 20 +route_candidate_body_max_chars = 2200 +route_candidate_total_max_chars = 20000 +route_long_body_head_chars = 700 +route_long_body_tail_chars = 700 +route_summary_min_chars = 30 +route_index_preview_chars = 140 +title_min_body_chars = 60 +evidence_top_k_default = 10 +evidence_top_k_large_batch = 5 +evidence_large_batch_threshold = 3 + +[eval] +judge_base_url = "https://open.bigmodel.cn/api/paas/v4" +judge_model = "glm-5" +judge_temperature = 0.3 +judge_deepseek_model = "deepseek-v4-flash" +judge_deepseek_thinking_enabled = false +judge_deepseek_context_window_tokens = 64000 +judge_context_window_tokens = 200000 +judge_completion_reserve_tokens = 4096 +judge_prompt_budget_safety_tokens = 2048 +judge_approx_tokens_per_char = 1.0 +judge_max_transcript_chars = 0 +judge_max_compare_transcript_chars_each = 0 +judge_compare_prompt_overhead_chars = 10000 +judge_memoir_chapter_concurrency = 4 +judge_memoir_body_max_chars = 36000 +judge_memoir_evidence_max_chars = 32000 +judge_memoir_completion_max_tokens = 3072 +candidate_temperature = 0.7 +gate_protected_regression_threshold = 2.0 +execution_enabled = true +internal_enable_docs = false +internal_cors_origins = "" + +[llm] +deepseek_base_url = "https://api.deepseek.com" +deepseek_model = "deepseek-v4-flash" +deepseek_thinking_enabled = false +temperature = 0.7 +fast_model = "" +embedding_base_url = "https://open.bigmodel.cn/api/paas/v4" +embedding_model = "embedding-3" + +[asr] +provider = "whisper" +model_size = "small" +device = "auto" +compute_type = "auto" +model_cache_dir = "" + +[tts] +provider = "tencent" +voice_type = 501004 +voice_type_en = 501004 +codec = "mp3" + +[redis] +socket_timeout_seconds = 5.0 +socket_connect_timeout_seconds = 2.0 +health_check_interval_seconds = 30 +task_tracker_ttl_seconds = 86400 + +[celery] +memory_enrichment_queue = "memory_idle" +broker_pool_limit = 10 +broker_connection_retry_on_startup = true +memoir_soft_time_limit = 1800 +memoir_hard_time_limit = 2400 +image_soft_time_limit = 600 +image_hard_time_limit = 900 +compaction_sweep_soft_time_limit = 300 +compaction_sweep_hard_time_limit = 600 +enrichment_soft_time_limit = 660 +enrichment_hard_time_limit = 960 + +[alembic] +run_on_startup = true +max_retries = 3 +retry_base_seconds = 1.0 + +[agent_log] +agent_verbose = false +max_chars = 4096 +omit_system_message_body = true +json_prompt_prefix_chars = 0 +json_prompt_prefix_only_if_len_gt = 4000 +prompt_mode = "preview" +prompt_dedup = false +celery_log_level = "" +httpx_log_level = "" +log_json_file = "" + +[otel] +exporter_insecure = true +service_name = "life-echo-api" +metric_export_interval_ms = 10000 + +[misc] +algorithm = "HS256" +redis_session_ttl = 86400 +tencent_sms_template_param_count = 2 +tencent_cos_region = "ap-shanghai" +liblib_base_url = "https://openapi.liblibai.cloud" +alipay_sign_type = "RSA2" +alipay_under_development = "true" diff --git a/api/config/development.toml b/api/config/development.toml new file mode 100644 index 0000000..bcf3492 --- /dev/null +++ b/api/config/development.toml @@ -0,0 +1,21 @@ +# development overlay — merged onto config/default.toml + +[deploy] +mock_sms_login_enabled = true +memoir_image_enabled = true +otel_enabled = true +otel_exporter_otlp_endpoint = "http://localhost:48317" +tencent_sms_sdk_app_id = "1401010099" +tencent_sms_sign_name = "上海华嘎科技有限公司" +tencent_sms_template_id = "2592163" +tencent_cos_bucket = "life-echo-dev-1319381411" +tencent_cos_base_url = "https://life-echo-dev-1319381411.cos.ap-shanghai.myqcloud.com" +wechat_pay_app_id = "wx1df508452e06cfb8" +wechat_pay_mch_id = "1662979099" +wechat_pay_cert_serial_no = "1AA82328AC1456C6F115B014606F22CD621D2032" +wechat_pay_notify_url = "https://lifecho.worldsplats.com/api/payment/notify/wechat" +alipay_notify_url = "https://lifecho.worldsplats.com/api/payment/notify/alipay" +liblib_template_uuid = "5d7e67009b344550bc1aa6ccbfa1d7f4" + +[eval] +internal_enable_docs = true diff --git a/api/config/production.toml b/api/config/production.toml new file mode 100644 index 0000000..49178f2 --- /dev/null +++ b/api/config/production.toml @@ -0,0 +1,18 @@ +# production overlay — merged onto config/default.toml + +[deploy] +alembic_startup_fail_fast = true +memoir_image_enabled = true +tencent_sms_sdk_app_id = "1401010099" +tencent_sms_sign_name = "上海华嘎科技有限公司" +tencent_sms_template_id = "2592163" +tencent_cos_bucket = "life-echo-prod-1319381411" +tencent_cos_base_url = "https://life-echo-prod-1319381411.cos.ap-shanghai.myqcloud.com" +wechat_pay_app_id = "wx1df508452e06cfb8" +wechat_pay_mch_id = "1662979099" +wechat_pay_cert_serial_no = "1AA82328AC1456C6F115B014606F22CD621D2032" +wechat_pay_notify_url = "https://lifecho.worldsplats.com/api/payment/notify/wechat" +alipay_notify_url = "https://lifecho.worldsplats.com/api/payment/notify/alipay" +liblib_template_uuid = "5d7e67009b344550bc1aa6ccbfa1d7f4" +otel_exporter_otlp_endpoint = "http://otel-collector:4317" +api_cors_origins = "https://lifecho.worldsplats.com" diff --git a/api/config/staging.toml b/api/config/staging.toml new file mode 100644 index 0000000..a38bee8 --- /dev/null +++ b/api/config/staging.toml @@ -0,0 +1,19 @@ +# staging overlay — merged onto config/default.toml + +[deploy] +alembic_startup_fail_fast = true +mock_sms_login_enabled = true +memoir_image_enabled = true +tencent_sms_sdk_app_id = "1401010099" +tencent_sms_sign_name = "上海华嘎科技有限公司" +tencent_sms_template_id = "2592163" +tencent_cos_bucket = "life-echo-dev-1319381411" +tencent_cos_base_url = "https://life-echo-dev-1319381411.cos.ap-shanghai.myqcloud.com" +wechat_pay_app_id = "wx1df508452e06cfb8" +wechat_pay_mch_id = "1662979099" +wechat_pay_cert_serial_no = "1AA82328AC1456C6F115B014606F22CD621D2032" +wechat_pay_notify_url = "https://lifecho.worldsplats.com/api/payment/notify/wechat" +alipay_notify_url = "https://lifecho.worldsplats.com/api/payment/notify/alipay" +liblib_template_uuid = "5d7e67009b344550bc1aa6ccbfa1d7f4" +otel_exporter_otlp_endpoint = "http://otel-collector:4317" +api_cors_origins = "https://lifecho.worldsplats.com" diff --git a/api/deploy.sh b/api/deploy.sh index 8ccbaa9..14c77ff 100755 --- a/api/deploy.sh +++ b/api/deploy.sh @@ -79,16 +79,13 @@ main() { exit 1 fi - # 检查必要的环境变量 + # 检查必要的环境变量(密钥与连接串;SMS 模板等见 config/production.toml) REQUIRED_VARS=( "DATABASE_URL" "REDIS_URL" - "JWT_SECRET_KEY" - "TENCENT_SMS_SECRET_ID" - "TENCENT_SMS_SECRET_KEY" - "TENCENT_SMS_SDK_APP_ID" - "TENCENT_SMS_SIGN_NAME" - "TENCENT_SMS_TEMPLATE_ID" + "SECRET_KEY" + "TENCENT_SECRET_ID" + "TENCENT_SECRET_KEY" ) ALL_VARS_SET=true diff --git a/api/development.sh b/api/development.sh index ceb3576..8f45e4a 100755 --- a/api/development.sh +++ b/api/development.sh @@ -25,9 +25,10 @@ API_PORT="${API_PORT:-8000}" CELERY_POOL="${CELERY_POOL:-solo}" SKIP_INSTALL="${SKIP_INSTALL:-0}" SKIP_INFRA="${SKIP_INFRA:-0}" -# 可观测性:空=若 .env 中 OTEL_ENABLED=true 则启动 compose;0=不启;1=强制启动 +# 可观测性:空=读 config/*.toml deploy.otel_enabled(或 .env 中 OTEL_ENABLED 覆盖);0=不启;1=强制启动 START_OBSERVABILITY="${START_OBSERVABILITY:-}" SHUTDOWN_TIMEOUT="${SHUTDOWN_TIMEOUT:-12}" +CELERY_SHUTDOWN_TIMEOUT="${CELERY_SHUTDOWN_TIMEOUT:-25}" # 与 docker-compose.observability.yml / .env.example 默认宿主机端口一致 OTEL_GRPC_HOST_PORT="${OTEL_GRPC_HOST_PORT:-48317}" @@ -171,6 +172,7 @@ stop_process_gracefully() { local name="$1" local pid="$2" local timeout="${3:-10}" + local signal="${4:-TERM}" if ! is_pid_alive "${pid}"; then print_ok "${name} 已退出" @@ -179,13 +181,21 @@ stop_process_gracefully() { print_warn "正在停止 ${name}(PID: ${pid})..." kill_children_term "${pid}" - kill -TERM "${pid}" 2>/dev/null || true + kill "-${signal}" "${pid}" 2>/dev/null || true if wait_pid_exit "${pid}" "${timeout}"; then print_ok "${name} 已停止" return 0 fi + if [[ "${signal}" != "TERM" ]]; then + kill -TERM "${pid}" 2>/dev/null || true + if wait_pid_exit "${pid}" 5; then + print_ok "${name} 已停止" + return 0 + fi + fi + print_warn "${name} 在 ${timeout}s 内未退出,准备强制结束" kill -KILL "${pid}" 2>/dev/null || true wait_pid_exit "${pid}" 3 || true @@ -201,19 +211,19 @@ cleanup() { print_header "正在关闭开发环境" if is_pid_alive "${EVAL_WEB_PID}"; then - stop_process_gracefully "eval-web (Vite)" "${EVAL_WEB_PID}" "${SHUTDOWN_TIMEOUT}" - fi - - if is_pid_alive "${INTERNAL_EVAL_PID}"; then - stop_process_gracefully "Internal Eval API (:${INTERNAL_EVAL_PORT})" "${INTERNAL_EVAL_PID}" "${SHUTDOWN_TIMEOUT}" + stop_process_gracefully "eval-web (Vite)" "${EVAL_WEB_PID}" "${SHUTDOWN_TIMEOUT}" INT fi if is_pid_alive "${API_PID}"; then - stop_process_gracefully "FastAPI" "${API_PID}" "${SHUTDOWN_TIMEOUT}" + stop_process_gracefully "FastAPI" "${API_PID}" "${SHUTDOWN_TIMEOUT}" INT + fi + + if is_pid_alive "${INTERNAL_EVAL_PID}"; then + stop_process_gracefully "Internal Eval API (:${INTERNAL_EVAL_PORT})" "${INTERNAL_EVAL_PID}" "${SHUTDOWN_TIMEOUT}" INT fi if is_pid_alive "${CELERY_PID}"; then - stop_process_gracefully "Celery" "${CELERY_PID}" "${SHUTDOWN_TIMEOUT}" + stop_process_gracefully "Celery" "${CELERY_PID}" "${CELERY_SHUTDOWN_TIMEOUT}" INT fi if [[ "${INFRA_STARTED}" == "1" ]]; then @@ -256,12 +266,73 @@ read_env_bool() { esac } +read_app_env_from_dotenv() { + local app_env="${APP_ENV:-development}" + if [[ -f "${ROOT_DIR}/.env" ]]; then + local env_line + env_line="$(grep -E '^APP_ENV=' "${ROOT_DIR}/.env" | tail -1 | cut -d= -f2- | tr -d '\r' | sed 's/^"//;s/"$//')" + if [[ -n "${env_line}" ]]; then + app_env="${env_line}" + fi + fi + printf '%s\n' "${app_env}" +} + +read_toml_bool_field() { + local section="$1" + local field="$2" + local default="${3:-0}" + local app_env + app_env="$(read_app_env_from_dotenv)" + local enabled + enabled="$( + cd "${ROOT_DIR}" && uv run python -c " +from app.core.app_config_loader import load_app_config +cfg = load_app_config('${app_env}') +print('1' if getattr(getattr(cfg, '${section}'), '${field}') else '0') +" 2>/dev/null | tail -1 + )" + case "${enabled}" in + 1) return 0 ;; + *) [[ "${default}" == "1" ]] ;; + esac +} + +read_deploy_otel_enabled() { + local default="${1:-0}" + + if [[ -n "${OTEL_ENABLED:-}" ]]; then + read_env_bool "OTEL_ENABLED" "${default}" + return + fi + if [[ -f "${ROOT_DIR}/.env" ]] && grep -qE '^OTEL_ENABLED=' "${ROOT_DIR}/.env" 2>/dev/null; then + read_env_bool "OTEL_ENABLED" "${default}" + return + fi + + read_toml_bool_field "deploy" "otel_enabled" "${default}" +} + +read_eval_internal_enable_docs() { + if [[ -n "${INTERNAL_EVAL_ENABLE_DOCS:-}" ]]; then + case "${INTERNAL_EVAL_ENABLE_DOCS}" in + 1 | true | TRUE | yes | YES | on | ON) return 0 ;; + *) return 1 ;; + esac + fi + if [[ -f "${ROOT_DIR}/.env" ]] && grep -qE '^INTERNAL_EVAL_ENABLE_DOCS=' "${ROOT_DIR}/.env" 2>/dev/null; then + read_env_bool "INTERNAL_EVAL_ENABLE_DOCS" "0" + return + fi + read_toml_bool_field "eval" "internal_enable_docs" "0" +} + should_start_observability() { case "${START_OBSERVABILITY}" in 0 | false | FALSE | no | NO | off | OFF) return 1 ;; 1 | true | TRUE | yes | YES | on | ON) return 0 ;; esac - read_env_bool "OTEL_ENABLED" "0" + read_deploy_otel_enabled "0" } docker_compose_cmd() { @@ -294,7 +365,7 @@ wait_otel_collector_ready() { } check_otel_collector_ready() { - if ! read_env_bool "OTEL_ENABLED" "0"; then + if ! read_deploy_otel_enabled "0"; then return 0 fi if is_port_listening "${OTEL_GRPC_HOST_PORT}"; then @@ -308,10 +379,10 @@ check_otel_collector_ready() { return 0 fi fi - print_warn "OTEL_ENABLED=true 但 :${OTEL_GRPC_HOST_PORT} 未监听" + print_warn "deploy.otel_enabled=true 但 :${OTEL_GRPC_HOST_PORT} 未监听" print_warn "请确认本次启动日志中有「启动可观测性栈」;或手动执行:" print_warn " docker compose -f docker-compose.dev.yml -f docker-compose.observability.yml up -d" - print_warn "不需要可观测性时在 .env.development 设 OTEL_ENABLED=false" + print_warn "不需要可观测性时在 config/development.toml 设 otel_enabled=false" return 1 } @@ -328,7 +399,7 @@ start_infra() { if [[ "${OBSERVABILITY_STARTED}" == "1" ]]; then print_ok "Grafana http://127.0.0.1:${GRAFANA_HOST_PORT} (admin/admin)" print_ok "Prometheus http://127.0.0.1:${PROMETHEUS_HOST_PORT}" - print_ok "OTLP gRPC 127.0.0.1:${OTEL_GRPC_HOST_PORT}(应用读 .env 中 OTEL_*,无需 export)" + print_ok "OTLP gRPC 127.0.0.1:${OTEL_GRPC_HOST_PORT}(应用读 config/*.toml deploy.otel_*)" print_ok "详见 docs/observability.md" schedule_observability_browser fi @@ -690,7 +761,7 @@ start_internal_eval_http() { echo "评测 Web UI: http://127.0.0.1:${EVAL_WEB_PORT}/ (Vite /internal → :${INTERNAL_EVAL_PORT})" echo "内部评测 API: http://127.0.0.1:${INTERNAL_EVAL_PORT}/health" echo "评测 REST: http://127.0.0.1:${INTERNAL_EVAL_PORT}/internal/api/evaluation" - if [[ "${INTERNAL_EVAL_ENABLE_DOCS:-}" == "1" ]] || grep -qE '^INTERNAL_EVAL_ENABLE_DOCS=true' "${ROOT_DIR}/.env" 2>/dev/null; then + if read_eval_internal_enable_docs; then echo "内部评测文档: http://127.0.0.1:${INTERNAL_EVAL_PORT}/docs" fi echo "说明: api/docs/internal-eval.md" @@ -765,7 +836,7 @@ start_services() { echo "评测 Web UI: http://127.0.0.1:${EVAL_WEB_PORT}/" echo "内部评测 API: http://127.0.0.1:${INTERNAL_EVAL_PORT}/health" fi - if read_env_bool "OTEL_ENABLED" "0"; then + if read_deploy_otel_enabled "0"; then echo "可观测性: Grafana http://127.0.0.1:${GRAFANA_HOST_PORT} | Prometheus http://127.0.0.1:${PROMETHEUS_HOST_PORT}" if is_port_listening "${GRAFANA_HOST_PORT}"; then schedule_observability_browser @@ -792,7 +863,7 @@ main() { trap cleanup EXIT INT TERM ensure_venv - # 必须在 start_infra 之前同步,否则 should_start_observability 读不到 .env.development 里的 OTEL_ENABLED + # 必须在 start_infra 之前同步 .env,以便 read_deploy_otel_enabled 读到 APP_ENV ensure_dotenv_from_development if [[ "${SKIP_INFRA}" != "1" ]]; then diff --git a/api/docker-compose.dev.yml b/api/docker-compose.dev.yml index c474fcc..0611162 100644 --- a/api/docker-compose.dev.yml +++ b/api/docker-compose.dev.yml @@ -31,12 +31,20 @@ services: container_name: life-echo-redis-dev ports: - "48307:6379" + environment: + REDIS_PASSWORD: ${REDIS_PASSWORD:-} volumes: - redis_data_dev:/data - command: redis-server --appendonly yes + command: > + sh -c 'exec redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru + $${REDIS_PASSWORD:+--requirepass "$$REDIS_PASSWORD"}' restart: unless-stopped healthcheck: - test: ["CMD", "redis-cli", "ping"] + test: + [ + "CMD-SHELL", + 'if [ -n "$$REDIS_PASSWORD" ]; then redis-cli -a "$$REDIS_PASSWORD" ping | grep -q PONG; else redis-cli ping | grep -q PONG; fi', + ] interval: 10s timeout: 5s retries: 5 diff --git a/api/docker-compose.yml b/api/docker-compose.yml index a678f32..f69b16c 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -25,18 +25,26 @@ services: max-size: "10m" max-file: "3" - # Redis 服务(用于会话存储和 Celery 消息队列) + # Redis 服务(业务 key DB/0;Celery broker/backend 由应用自动使用 DB/1) redis: image: m.daocloud.io/docker.io/library/redis:7-alpine container_name: life-echo-redis # ports: # - "6379:6379" # 不暴露到宿主机,仅在 Docker 网络内部访问 + environment: + REDIS_PASSWORD: ${REDIS_PASSWORD:-} volumes: - redis_data:/data - command: redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru + command: > + sh -c 'exec redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru + $${REDIS_PASSWORD:+--requirepass "$$REDIS_PASSWORD"}' restart: always healthcheck: - test: ["CMD", "redis-cli", "ping"] + test: + [ + "CMD-SHELL", + 'if [ -n "$$REDIS_PASSWORD" ]; then redis-cli -a "$$REDIS_PASSWORD" ping | grep -q PONG; else redis-cli ping | grep -q PONG; fi', + ] interval: 10s timeout: 5s retries: 5 @@ -64,8 +72,10 @@ services: - .env environment: - ASR_MODEL_CACHE_DIR=/app/models/whisper - - ALEMBIC_STARTUP_FAIL_FAST=true - APP_ENV=${APP_ENV:-production} + - REDIS_URL=redis://redis:6379/0 + - CELERY_REDIS_URL=redis://redis:6379/1 + - REDIS_PASSWORD=${REDIS_PASSWORD:-} volumes: - /root/apiclient_key.pem:/app/certs/apiclient_key.pem:ro restart: always @@ -100,6 +110,9 @@ services: - .env environment: - APP_ENV=${APP_ENV:-production} + - REDIS_URL=redis://redis:6379/0 + - CELERY_REDIS_URL=redis://redis:6379/1 + - REDIS_PASSWORD=${REDIS_PASSWORD:-} restart: always depends_on: postgres: @@ -133,6 +146,9 @@ services: - .env environment: - APP_ENV=${APP_ENV:-production} + - REDIS_URL=redis://redis:6379/0 + - CELERY_REDIS_URL=redis://redis:6379/1 + - REDIS_PASSWORD=${REDIS_PASSWORD:-} restart: always depends_on: postgres: @@ -149,26 +165,39 @@ services: max-size: "10m" max-file: "3" - # Flower(Celery 监控面板,可选) - # flower: - # build: - # context: . - # dockerfile: Dockerfile - # image: life-echo-api:latest - # container_name: life-echo-flower - # command: celery -A app.tasks.celery_app flower --port=5555 - # ports: - # - "5555:5555" - # env_file: - # - .env - # environment: - # - REDIS_URL=redis://redis:6379/0 - # restart: always - # depends_on: - # redis: - # condition: service_healthy - # networks: - # - life-echo-network + flower: + build: + context: . + dockerfile: Dockerfile + image: life-echo-api:latest + container_name: life-echo-flower + command: > + sh -c 'uv run celery -A app.tasks.celery_app flower --port=5555 + --basic_auth=$${FLOWER_USER:-admin}:$${FLOWER_PASSWORD:-changeme}' + ports: + - "127.0.0.1:${FLOWER_HOST_PORT:-5555}:5555" + env_file: + - .env + environment: + - APP_ENV=${APP_ENV:-production} + - REDIS_URL=redis://redis:6379/0 + - CELERY_REDIS_URL=redis://redis:6379/1 + - REDIS_PASSWORD=${REDIS_PASSWORD:-} + - FLOWER_USER=${FLOWER_USER:-admin} + - FLOWER_PASSWORD=${FLOWER_PASSWORD:-changeme} + restart: always + depends_on: + redis: + condition: service_healthy + celery-worker: + condition: service_started + networks: + - life-echo-network + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" networks: life-echo-network: diff --git a/api/docs/configuration.md b/api/docs/configuration.md new file mode 100644 index 0000000..cc2a146 --- /dev/null +++ b/api/docs/configuration.md @@ -0,0 +1,111 @@ +# 配置 SSOT(TOML + .env) + +Life Echo API 配置分为两层,避免 env 膨胀,同时保留密钥安全边界。 + +## 两层 SSOT + +| 层 | 来源 | 改什么 | 如何生效 | +|----|------|--------|----------| +| **Secrets / bootstrap** | [`.env.example`](../.env.example) → `.env.development` / staging / production | `DATABASE_URL`、`SECRET_KEY`、API 密钥、支付/Liblib 私钥 | 改 env,重启进程 | +| **非密钥配置** | [`config/default.toml`](../config/default.toml) + `config/{APP_ENV}.toml` | 功能开关、token 过期、SMS 模板 ID、Chat/Memoir 调参、OTel 等 | 改 TOML,**发版**(随镜像/代码部署) | + +`APP_ENV` 必须在 `.env` 中设置,用于选择 overlay:`development` / `staging` / `production`。 + +加载逻辑:[`app/core/app_config_loader.py`](../app/core/app_config_loader.py) 深合并 `default.toml` 与 `{APP_ENV}.toml`,经 Pydantic 校验后暴露为 [`app_config`](../app/core/app_config.py)。 + +## 文件说明 + +``` +api/config/ + default.toml # 全量默认值(所有 section) + development.toml # 本地差异(OTel、mock 登录、dev COS 等) + staging.toml # 预发差异(CORS、SMS/COS 业务 ID 等) + production.toml # 生产差异 +``` + +TOML 顶层 section 与代码模块对应: + +| Section | 用途 | +|---------|------| +| `[deploy]` | 原 Settings 非密钥项(开关、SMS 模板、CORS、OTel endpoint 等) | +| `[chat]` | 访谈 / 对话 | +| `[memoir]` | 回忆录流水线 | +| `[memory]` | Memory 富化 / compaction | +| `[story]` | Story / 章节 | +| `[eval]` | 内网评测 / judge | +| `[llm]` `[asr]` `[tts]` `[celery]` `[alembic]` `[agent_log]` `[otel]` `[misc]` | 运行时默认 | + +业务代码仍可通过原有 import 读取(薄 re-export): + +- `from app.features.conversation.constants import chat` +- `from app.core.runtime_constants import llm_defaults` +- `settings.enable_tts` 等 deploy 字段通过 [`SettingsFacade`](../app/core/config.py) 代理到 TOML + +## 常见操作 + +**改访谈温度** + +```toml +# config/default.toml 或 config/staging.toml +[chat] +interview_temperature = 0.7 +``` + +**改生产 CORS** + +```toml +# config/production.toml +[deploy] +api_cors_origins = "https://your-domain.com" +``` + +**改 DeepSeek API Key** — 仍在 `.env`: + +```env +DEEPSEEK_API_KEY=sk-... +``` + +## 旧 env 键对照(节选) + +| 旧 env 键 | 新位置 | +|-----------|--------| +| `CHAT_INTERVIEW_PERSONA` | `[chat] interview_persona` | +| `CHAT_INTERVIEW_TEMPERATURE` | `[chat] interview_temperature` | +| `MEMOIR_ORAL_NORMALIZE_MODE` | `[memoir] oral_normalize_mode` | +| `MEMORY_COMPACTION_ENABLED` | `[memory] compaction_enabled` | +| `EVAL_JUDGE_MODEL` | `[eval] judge_model` | +| `DEEPSEEK_BASE_URL` | `[llm] deepseek_base_url` | +| `ENABLE_TTS` | `[deploy] enable_tts` | +| `TENCENT_SMS_SDK_APP_ID` | `[deploy] tencent_sms_sdk_app_id` | +| `OTEL_ENABLED` | `[deploy] otel_enabled` | +| `SECRET_KEY` | 仍在 `.env` | + +## 测试 + +- 设置 `CONFIG_DIR` 指向 fixture 目录可隔离加载:见 [`tests/test_app_config_loader.py`](../tests/test_app_config_loader.py) +- `tests/test_settings_allowlist.py` 防止 Settings 字段反弹 + +## Docker + +`config/` 随 `COPY . .` 打入镜像;compose 仍 `env_file: .env` 注入密钥。确保容器内 `APP_ENV` 与目标 overlay 文件名一致。 + +### Redis / Celery 分离 + +| 用途 | 默认 DB | 环境变量 | +|------|---------|----------| +| 会话、对话历史、task tracker 等业务 key | DB/0 | `REDIS_URL` | +| Celery broker + result backend | DB/1 | `CELERY_REDIS_URL`(compose 显式注入;未设时由 `REDIS_URL` 自动 +1) | + +**升级注意:** 启用 DB 分离后,旧版在 DB/0 上未消费的 Celery 消息会被丢弃(一次性 cutover,无需迁移脚本)。 + +若 `REDIS_URL` 使用 DB/15,必须显式设置 `CELERY_REDIS_URL`(Redis 仅支持 logical DB 0–15,无法 auto +1)。 + +## Shell 脚本与 TOML + +| 脚本 | 仍读 `.env` | 已改读 TOML | +|------|-------------|-------------| +| `development.sh` | `APP_ENV`、密钥、`OTEL_ENABLED`(legacy 覆盖) | `deploy.otel_enabled` → 是否起 Grafana 栈;`eval.internal_enable_docs` → 是否打印 `/docs` 链接 | +| `deploy.sh` | 密钥、`DATABASE_URL` 等 | SMS 模板等不再检查 env(在 `config/production.toml`) | +| `verify_observability_metrics.sh` | — | 提示文案指向 `deploy.otel_enabled` | + +**勿再依赖** `.env` 中的 `OTEL_ENABLED`、`MOCK_SMS_LOGIN_ENABLED`、`INTERNAL_EVAL_ENABLE_DOCS` 等(应用运行时已从 TOML 读取);测试里 `os.environ["OTEL_ENABLED"]` 亦已无效,见 `tests/conftest.py`。 diff --git a/api/docs/internal-eval.md b/api/docs/internal-eval.md index d6824e1..81a576c 100644 --- a/api/docs/internal-eval.md +++ b/api/docs/internal-eval.md @@ -4,7 +4,7 @@ ## 启动 -**推荐一条命令**:`./development.sh` 默认启动主站(**8000**)、Celery、内部评测 API(默认 **7999**)、评测 Web(**5174**);`.env` 中 `OTEL_ENABLED=true` 时并起 Grafana 且自动打开浏览器。`./internal-eval.sh` 仅为兼容转发。 +**推荐一条命令**:`./development.sh` 默认启动主站(**8000**)、Celery、内部评测 API(默认 **7999**)、评测 Web(**5174**);`config/development.toml` 中 `deploy.otel_enabled=true` 时并起 Grafana 且自动打开浏览器。`./internal-eval.sh` 仅为兼容转发。 | | `./development.sh`(默认) | |---|-------------------------------| @@ -34,7 +34,7 @@ SKIP_INFRA=1 SKIP_INSTALL=1 EVAL_ATTACH_ONLY=1 ./development.sh ```bash cd api export INTERNAL_EVAL_API_KEY='your-long-random-secret' -export INTERNAL_EVAL_ENABLE_DOCS=1 # 可选,开 /docs +export INTERNAL_EVAL_ENABLE_DOCS=1 # 可选 legacy;推荐 config/development.toml → [eval] internal_enable_docs = true # 评测评审(Playground / Memoir 手动的对话与成稿打分) # 智谱:默认 EVAL_JUDGE_API_KEY,否则回退 ZHIPU_API_KEY export EVAL_JUDGE_API_KEY='...' # 可选 @@ -83,7 +83,7 @@ VITE_EVAL_API_BASE=http://127.0.0.1:8001 VITE_EVAL_API_KEY=与上同 npm run dev ### Mock 登录(仅非 production) -在主站 `.env` / `.env.development` 中设置 **`MOCK_SMS_LOGIN_ENABLED=1`**(或 `true`)。`APP_ENV=production` 时 **`POST /api/auth/mock/sms-login` 始终返回 404**。请求体:`phone`(11 位)、`agreed_to_terms: true`,可选 `nickname`(新用户);响应与正式短信登录相同(`access_token` + `refresh_token`)。**切勿在生产环境开启。** +在 **`config/development.toml`** 的 `[deploy]` 中设 **`mock_sms_login_enabled = true`**(legacy 仍可在 `.env` 设 `MOCK_SMS_LOGIN_ENABLED=1`)。`APP_ENV=production` 时 **`POST /api/auth/mock/sms-login` 始终返回 404**。请求体:`phone`(11 位)、`agreed_to_terms: true`,可选 `nickname`(新用户);响应与正式短信登录相同(`access_token` + `refresh_token`)。**切勿在生产环境开启。** ## 真实链路透传回放(与 App 一致) diff --git a/api/docs/observability.md b/api/docs/observability.md index 4c67c29..d62345a 100644 --- a/api/docs/observability.md +++ b/api/docs/observability.md @@ -2,7 +2,7 @@ 本地开发使用 **OpenTelemetry** 采集 traces / metrics / logs,经 **OTel Collector** 写入 **Tempo / Prometheus / Loki**,在 **Grafana** 统一查看。 -配置写在 **`.env`**(由 `.env.development` 经 `development.sh` 同步,或从 [`.env.example`](../.env.example) 复制),`app.core.config.settings` 启动时自动读取,**无需**在 shell 里 `export OTEL_*`。 +配置 SSOT:**[`config/default.toml`](../config/default.toml)** + **`config/{APP_ENV}.toml`** 的 `[deploy]` section(`otel_enabled`、`otel_exporter_otlp_endpoint`)。密钥与 `APP_ENV` 仍在 `.env`。 ## 启动栈 @@ -37,103 +37,45 @@ docker compose -f docker-compose.dev.yml -f docker-compose.observability.yml up ## 启用应用导出 -在 [`.env.example`](../.env.example) 已给出本地默认值,同步到 `.env` 即可,例如: +在 **`config/development.toml`**(或对应环境 overlay)中: -```env -OTEL_ENABLED=true -OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:48317 -OTEL_TRACES_SAMPLER=always_on -OTEL_SERVICE_NAME=life-echo-api +```toml +[deploy] +otel_enabled = true +otel_exporter_otlp_endpoint = "http://localhost:48317" ``` -推荐与全栈一并启动(`./development.sh` 在 `.env` 里 `OTEL_ENABLED=true` 时会起 observability compose,并默认打开 Grafana 浏览器标签): +推荐与全栈一并启动(`./development.sh` 在 TOML 中 `deploy.otel_enabled=true` 时会起 observability compose,并默认打开 Grafana 浏览器标签): ```bash cd api ./development.sh ``` -仅手动起 API(不自动开 Grafana): +若 API 跑在 **Docker compose** 里,应设 `otel_exporter_otlp_endpoint = "http://otel-collector:4317"`(服务名 + 容器内端口),见 `config/staging.toml` / `config/production.toml`。 -```bash -cd api -uv run uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 -``` +不需要可观测性时:在对应 `config/*.toml` 设 `otel_enabled = false`(或未启动 observability compose)。 -Celery worker 同一 `.env`;未设 `OTEL_SERVICE_NAME` 时 worker 默认为 `life-echo-celery-worker`。 +legacy:仍可在 `.env` 设 `OTEL_ENABLED=true/false` 覆盖 TOML,供 `development.sh` 决策是否启动 compose。 -若 API 跑在 **Docker compose** 里,应设 `OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317`(服务名 + 容器内端口),而不是 `localhost`。 +## 采样与其它 OTel 项 -不需要可观测性时:`.env` 中 `OTEL_ENABLED=false`(或未启动 observability compose)。 +采样策略等在 `[otel]` section 与 `OtelConfig.traces_sampler()`(按 `APP_ENV` 推导),见 [`docs/configuration.md`](configuration.md)。 -## 采集内容 +## 排查 -| 类型 | 来源 | -|------|------| -| HTTP | FastAPI 自动 instrumentation(`/health` 排除) | -| DB | SQLAlchemy | -| Redis | redis-py | -| 出站 HTTP | httpx(DeepSeek 等) | -| Celery | 任务 span + W3C trace 传播 | -| LLM | `llm_telemetry`(LangChain / DeepSeek / `llm_call`)+ `llm.call.*` / `llm.tokens.*` metrics | -| 业务 | `business_telemetry`:WS 回合、回忆录 phase、ASR/TTS、支付等子 span | -| 日志 | loguru patcher 注入 `trace_id`;Promtail 解析 `event` / `tid=`;可选 `LOG_JSON_FILE` JSON sink | +1. **Grafana 无数据**:确认 `./development.sh` 日志含「启动可观测性栈」,或手动 `docker compose -f docker-compose.dev.yml -f docker-compose.observability.yml up -d`。 +2. **应用报 UNAVAILABLE localhost:48317**:`deploy.otel_enabled=true` 但 Collector 未起 — 与 Grafana 问题同源。 +3. **Prometheus 无 LLM 指标**: `./scripts/verify_observability_metrics.sh`(需 observability compose + 有流量)。 +4. **Collector health**:`http://127.0.0.1:48333`;Prometheus target `otel-collector:8889` 应为 UP。 +5. **Celery 任务失败/延迟**:Grafana → **Life Echo Business**(`memory.compaction.*`、`memoir` 等业务 span);生产栈另可开 Flower `http://127.0.0.1:5555`(需 `FLOWER_PASSWORD`)。 +6. **关闭 telemetry**:`config/*.toml` → `[deploy] otel_enabled = false`。 -日志字段:`request_id`、`trace_id`、`span_id`。HTTP 由中间件 `contextualize`;**Celery / 后台**由 loguru **patcher** 从当前 OTel span 合并,无需经过 HTTP 中间件。 +## Checklist(本地) -## 常用排查 +- [ ] `config/development.toml` 中 `deploy.otel_enabled=true` +- [ ] `./development.sh` 或 observability compose 已启动 +- [ ] API + Celery worker 在跑并产生请求 +- [ ] Grafana http://127.0.0.1:48300 可打开 -1. **API 慢**:Grafana → Tempo,按 `service.name=life-echo-api` 查 trace;看 DB / httpx / `llm.*` / `conversation.ws.*` 子 span。 -2. **LLM 慢**:**Life Echo LLM** Dashboard,或 Loki:`{compose_service=~".+"} |= "event=llm_json_call"`。 -3. **回忆录卡阶段**:Tempo 搜 `memoir.phase1` / `memoir.phase2` / `memoir.story_pipeline.*`;**Life Echo Business** Dashboard 看 `business_operation_duration_milliseconds`。 -4. **日志 ↔ Trace**:在 Tempo 复制 `trace_id` → Loki:`{compose_service=~".+"} |= "tid=<前12位>"`(控制台短格式);Promtail 将 `trace_id` 写入 **structured metadata**(非高基数 label)。 -5. **Celery 堆积**:Tempo 过滤 `life-echo-celery-worker`;Loki `event=celery_task_failed`。 -6. **无数据**:`.env` 中 `OTEL_ENABLED=true`、`OTEL_EXPORTER_OTLP_ENDPOINT` 端口与 `OTEL_GRPC_HOST_PORT` 一致;Collector health `http://127.0.0.1:48333`;Prometheus target `otel-collector:8889` UP。 - -### LOG_JSON_FILE 与 Promtail - -- **默认**:loguru 人类可读行 → Docker stdout → Promtail **regex** 提取 `tid` / `event` / `duration_ms`;`trace_id` 进 structured metadata,**不作为 Loki label**。 -- **可选**:`LOG_JSON_FILE=/path/to/app.jsonl` 开启 JSON sink(`serialize=true`),便于与 OTLP logs 或自建采集对齐;与 Promtail 可**并存**(同一容器 stdout 仍走 regex)。 - -## 采样(staging/prod 第二阶段) - -| 环境 | 建议 | -|------|------| -| development | `OTEL_TRACES_SAMPLER=always_on` | -| staging/production | `OTEL_TRACES_SAMPLER=parentbased_traceidratio`,`OTEL_TRACES_SAMPLER_ARG=0.1` | - -关闭 telemetry:`OTEL_ENABLED=false`,无 exporter 开销。 - -## Prometheus 指标名(OTel → Prometheus) - -| OTel 仪器 | Prometheus 系列(histogram) | -|-----------|------------------------------| -| `llm.call.duration` (ms) | `llm_call_duration_milliseconds_bucket` | -| `business.operation.duration` (ms) | `business_operation_duration_milliseconds_bucket` | -| `http.server.request.duration` (s) | `http_server_request_duration_seconds_bucket` | -| `db.client.operation.duration` (s) | `db_client_operation_duration_seconds_bucket` | -| `http.client.request.duration` (s) | `http_client_request_duration_seconds_bucket` | - -Counter 示例:`llm_call_total`、`llm_tokens_input_total`。 - -校验脚本(需 observability compose + 有流量): - -```bash -chmod +x scripts/verify_observability_metrics.sh -./scripts/verify_observability_metrics.sh -``` - -## 验收清单(本地 E2E) - -- [ ] `OTEL_ENABLED=true`,启动 compose + API + Celery worker -- [ ] 跑一条 WS 对话;Tempo 可见 `conversation.ws.process_turn`、`llm.chat_invoke` -- [ ] 触发 memoir phase1;Tempo 可见 `memoir.phase1.*`、`memoir.story_pipeline.*` -- [ ] Prometheus:`call_type` label 存在;真实 LLM 后 `llm_tokens_input_total` > 0 -- [ ] Loki:`|= "tid="` 能查到同次请求日志 -- [ ] `./scripts/verify_observability_metrics.sh` 通过 -- [ ] Grafana Alerting 页无 provisioning 错误(通知渠道可空) - -## 配置目录 - -- [`deploy/observability/`](../deploy/observability/):Collector、Tempo、Loki、Prometheus、Grafana provisioning -- [`docker-compose.observability.yml`](../docker-compose.observability.yml):本地 overlay +更多配置分层说明见 [`configuration.md`](configuration.md)。 diff --git a/api/docs/本地开发环境配置.md b/api/docs/本地开发环境配置.md index 2082b89..734c173 100644 --- a/api/docs/本地开发环境配置.md +++ b/api/docs/本地开发环境配置.md @@ -66,9 +66,9 @@ DEEPSEEK_BASE_URL=https://api.deepseek.com # LLM_MODEL=gpt-4 # LLM_BASE_URL=https://api.openai.com -# Redis 配置(宿主 48307,见 docker-compose.dev.yml) +# Redis 配置(宿主 48307,见 docker-compose.dev.yml;业务 DB/0,Celery 自动 DB/1) REDIS_URL=redis://localhost:48307/0 -REDIS_SESSION_TTL=86400 # 会话过期时间(秒),默认 24 小时 +# 会话 TTL 见 config/default.toml [misc] redis_session_ttl(默认 86400 秒),非 .env 项 # 数据库配置(宿主 48291,见 docker-compose.dev.yml) DATABASE_URL=postgresql://postgres:postgres@localhost:48291/life_echo @@ -163,11 +163,11 @@ docker compose up -d --scale celery-worker=3 ### 监控(可选) -启用 Flower 监控面板: +生产 `docker-compose.yml` 已包含 Flower(仅绑定 `127.0.0.1:5555`,需设置 `FLOWER_USER` / `FLOWER_PASSWORD`): -1. 编辑 `docker-compose.yml`,取消 `flower` 服务的注释 -2. 重启服务:`docker compose up -d` -3. 访问 http://localhost:5555 查看 Celery 任务监控 +1. 在 `.env` 中设置 `FLOWER_PASSWORD`(及可选 `FLOWER_USER`) +2. 启动:`docker compose up -d flower` +3. 访问 http://127.0.0.1:5555 查看 Celery 任务监控 ## 常见问题 @@ -214,8 +214,8 @@ Redis 连接失败: Error connecting to redis://localhost:48307/0 # Celery 并发数(根据 CPU 核心数调整) # 在启动命令中配置:--concurrency=4 -# 会话 TTL(根据业务需求调整) -REDIS_SESSION_TTL=86400 +# 会话 TTL:config/default.toml [misc] redis_session_ttl(默认 86400) +# TaskTracker TTL:config/default.toml [redis] task_tracker_ttl_seconds ``` ## API 端点 diff --git a/api/docs/部署指南.md b/api/docs/部署指南.md index 75f3a1d..5a66d79 100644 --- a/api/docs/部署指南.md +++ b/api/docs/部署指南.md @@ -61,44 +61,60 @@ ### 2. 环境变量配置 +配置分两层,详见 [configuration.md](configuration.md): + +- **`.env`(secrets + bootstrap)**:数据库/Redis 连接串、`SECRET_KEY`、腾讯云 API 密钥等 +- **`config/{APP_ENV}.toml`(非密钥)**:SMS 应用 ID/签名/模板、JWT 过期时间、CORS 等 + #### 2.1 编辑生产环境配置文件 -编辑 `api/.env.production`: +编辑 `api/.env.production`(密钥与连接串): ```bash -# 数据库配置 +APP_ENV=production DATABASE_URL=postgresql://username:password@host:5432/life_echo -# Redis配置 -REDIS_URL=redis://localhost:6379/0 +# Redis(业务 DB/0;Celery broker/backend 使用 DB/1) +REDIS_URL=redis://redis:6379/0 +REDIS_PASSWORD=your_redis_password_here +# 生产 compose 已显式注入;本地或自定义部署请设: +CELERY_REDIS_URL=redis://:your_redis_password_here@redis:6379/1 +# 升级至 DB 分离后,DB/0 上未消费的 Celery 队列会被丢弃(一次性 cutover)。 +# REDIS_URL 使用 DB/15 时必须显式设置 CELERY_REDIS_URL(无法 auto DB+1)。 -# JWT配置 -JWT_SECRET_KEY=your_jwt_secret_key_here -JWT_ALGORITHM=HS256 -ACCESS_TOKEN_EXPIRE_MINUTES=30 -REFRESH_TOKEN_EXPIRE_DAYS=7 +# Auth(生产务必:openssl rand -hex 32) +SECRET_KEY=your_strong_random_secret_here -# 腾讯云短信服务配置 -TENCENT_SMS_SECRET_ID=your_secret_id_here -TENCENT_SMS_SECRET_KEY=your_secret_key_here -TENCENT_SMS_SDK_APP_ID=your_app_id_here -TENCENT_SMS_SIGN_NAME=人生回响 -# 统一使用一个短信模板ID(所有场景共用) -TENCENT_SMS_TEMPLATE_ID=your_template_id_here +# 腾讯云 API 密钥(SMS / ASR / TTS / COS 共用) +TENCENT_SECRET_ID=your_secret_id_here +TENCENT_SECRET_KEY=your_secret_key_here +``` -# 其他配置 -CORS_ORIGINS=["https://your-domain.com"] +编辑 `api/config/production.toml`(SMS 业务 ID 等非密钥项): + +```toml +[deploy] +tencent_sms_sdk_app_id = "1400xxxxxx" +tencent_sms_sign_name = "人生回响" +tencent_sms_template_id = "123456" +api_cors_origins = "https://your-domain.com" +access_token_expire_minutes = 30 +refresh_token_expire_days = 7 ``` #### 2.2 配置说明 -| 配置项 | 说明 | 示例 | -|--------|------|------| -| `TENCENT_SMS_SECRET_ID` | 腾讯云API密钥ID | `AKIDxxxxxxxxxxxxx` | -| `TENCENT_SMS_SECRET_KEY` | 腾讯云API密钥Key | `xxxxxxxxxxxxxxxx` | -| `TENCENT_SMS_SDK_APP_ID` | 短信应用ID | `1400xxxxxx` | -| `TENCENT_SMS_SIGN_NAME` | 短信签名(不含【】) | `人生回响` | -| `TENCENT_SMS_TEMPLATE_ID` | 短信模板ID(所有场景共用) | `123456` | +| 配置项 | 层级 | 说明 | 示例 | +|--------|------|------|------| +| `SECRET_KEY` | `.env` | JWT 签名密钥 | `openssl rand -hex 32` 输出 | +| `TENCENT_SECRET_ID` | `.env` | 腾讯云 API 密钥 ID | `AKIDxxxxxxxxxxxxx` | +| `TENCENT_SECRET_KEY` | `.env` | 腾讯云 API 密钥 Key | `xxxxxxxxxxxxxxxx` | +| `tencent_sms_sdk_app_id` | `config/production.toml` | 短信应用 ID | `1400xxxxxx` | +| `tencent_sms_sign_name` | `config/production.toml` | 短信签名(不含【】) | `人生回响` | +| `tencent_sms_template_id` | `config/production.toml` | 短信模板 ID(所有场景共用) | `123456` | +| `api_cors_origins` | `config/production.toml` | 浏览器跨域允许来源(逗号分隔) | `https://lifecho.worldsplats.com` | +| `access_token_expire_minutes` | `config/production.toml` | Access token 有效期(分钟) | `30` | +| `refresh_token_expire_days` | `config/production.toml` | Refresh token 有效期(天) | `7` | ### 3. 数据库迁移 @@ -171,15 +187,16 @@ services: build: . ports: - "8000:8000" + env_file: + - .env environment: + - APP_ENV=production - DATABASE_URL=${DATABASE_URL} - REDIS_URL=${REDIS_URL} - - JWT_SECRET_KEY=${JWT_SECRET_KEY} - - TENCENT_SMS_SECRET_ID=${TENCENT_SMS_SECRET_ID} - - TENCENT_SMS_SECRET_KEY=${TENCENT_SMS_SECRET_KEY} - - TENCENT_SMS_SDK_APP_ID=${TENCENT_SMS_SDK_APP_ID} - - TENCENT_SMS_SIGN_NAME=${TENCENT_SMS_SIGN_NAME} - - TENCENT_SMS_TEMPLATE_ID=${TENCENT_SMS_TEMPLATE_ID} + - REDIS_PASSWORD=${REDIS_PASSWORD} + - SECRET_KEY=${SECRET_KEY} + - TENCENT_SECRET_ID=${TENCENT_SECRET_ID} + - TENCENT_SECRET_KEY=${TENCENT_SECRET_KEY} depends_on: - postgres - redis @@ -251,11 +268,13 @@ sudo systemctl status life-echo-api | Secret名称 | 说明 | |-----------|------| -| `TENCENT_SMS_SECRET_ID` | 腾讯云API密钥ID | -| `TENCENT_SMS_SECRET_KEY` | 腾讯云API密钥Key | -| `TENCENT_SMS_SDK_APP_ID` | 短信应用ID | -| `TENCENT_SMS_SIGN_NAME` | 短信签名 | -| `TENCENT_SMS_TEMPLATE_ID` | 短信模板ID(所有场景共用) | +| `SECRET_KEY` | JWT 签名密钥 | +| `TENCENT_SECRET_ID` | 腾讯云 API 密钥 ID | +| `TENCENT_SECRET_KEY` | 腾讯云 API 密钥 Key | +| `DATABASE_URL` | 数据库连接串 | +| `REDIS_PASSWORD` | Redis 密码(若启用) | + +SMS 应用 ID、签名、模板 ID 等非密钥项在 `config/production.toml` 的 `[deploy]` 中维护,随代码发版。 参考文档:`github-actions-secrets.md` @@ -403,10 +422,10 @@ gunzip < backup_file.sql.gz | psql -U postgres -d life_echo **检查步骤:** -1. 验证环境变量配置 +1. 验证环境变量与 TOML 配置 ```bash -echo $TENCENT_SMS_SECRET_ID -echo $TENCENT_SMS_SDK_APP_ID +echo $TENCENT_SECRET_ID +# SMS 业务 ID 见 config/production.toml [deploy] tencent_sms_sdk_app_id ``` 2. 检查腾讯云账户余额 diff --git a/api/pyproject.toml b/api/pyproject.toml index 95d2598..7a12f4e 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "cos-python-sdk-v5>=1.9.41", "fastapi[standard]>=0.135.1", "faster-whisper>=1.2.1", + "flower>=2.0.1", "greenlet>=3.3.2", "httpx>=0.28.1", "langchain>=1.2.12", @@ -73,6 +74,9 @@ quote-style = "double" indent-style = "space" line-ending = "auto" +[tool.fastapi] +entrypoint = "app.main:app" + [tool.pytest.ini_options] testpaths = ["tests"] pythonpath = ["."] diff --git a/api/scripts/verify_observability_metrics.sh b/api/scripts/verify_observability_metrics.sh index 1ea22c5..0b5266e 100755 --- a/api/scripts/verify_observability_metrics.sh +++ b/api/scripts/verify_observability_metrics.sh @@ -34,7 +34,7 @@ done if [[ "${fail}" -ne 0 ]]; then echo "" - echo "Some metrics missing. Ensure OTEL_ENABLED=true, API/worker running, and traffic generated." + echo "Some metrics missing. Ensure config deploy.otel_enabled=true, observability compose running, API/worker up, and traffic generated." exit 1 fi echo "All required metrics present." diff --git a/api/skills-lock.json b/api/skills-lock.json new file mode 100644 index 0000000..c890d85 --- /dev/null +++ b/api/skills-lock.json @@ -0,0 +1,17 @@ +{ + "version": 1, + "skills": { + "celery-expert": { + "source": "martinholovsky/claude-skills-generator", + "sourceType": "github", + "skillPath": "skills/celery-expert/SKILL.md", + "computedHash": "a0bc038bb3a9c930024720a2f34eb5fa7aad2152171cc7f7943afdf8c98f7f60" + }, + "redis-development": { + "source": "redis/agent-skills", + "sourceType": "github", + "skillPath": "skills/redis-development/SKILL.md", + "computedHash": "f48a1eedb507058ddf1cc4c320ddb5e36e8a2b6be9ed9165a85d7438ed9fab48" + } + } +} diff --git a/api/static/home.html b/api/static/home/index.html similarity index 100% rename from api/static/home.html rename to api/static/home/index.html diff --git a/api/static/legal/privacy.html b/api/static/legal/privacy.html new file mode 100644 index 0000000..25cc8c2 --- /dev/null +++ b/api/static/legal/privacy.html @@ -0,0 +1,141 @@ + + + + + + 隐私政策 - 岁月留书 + + + +
+

隐私政策

+
岁月留书隐私保护政策
+ +
+ 服务提供方:上海华嘎科技有限公司
+ 产品名称:岁月留书
+ 生效日期:2026年1月27日 +
+ +

上海华嘎科技有限公司(以下简称"我们")非常重视用户的隐私保护。本隐私政策说明了我们如何收集、使用、存储和保护您的个人信息。请您仔细阅读本隐私政策,以了解我们对您个人信息的处理方式。

+ +

一、信息收集

+

为了向您提供更好的服务,我们可能会收集以下信息:

+

1. 账户信息:当您注册账户时,我们会收集您的手机号码、密码、昵称、邮箱(可选)等信息。

+

2. 设备信息:我们可能会收集您的设备型号、操作系统版本、设备标识符、IP地址等信息,用于提供更好的服务体验和保障账户安全。

+

3. 使用信息:我们会收集您使用本服务时产生的信息,包括但不限于对话记录、语音内容、文字内容、操作日志等。

+

4. 位置信息:在您授权的情况下,我们可能会收集您的位置信息,用于提供基于位置的服务。

+ +

二、信息使用

+

我们收集您的个人信息主要用于以下目的:

+

1. 提供服务:使用您的信息来提供、维护、改进我们的服务,包括处理您的对话请求、生成回忆录内容等。

+

2. 账户管理:用于账户注册、登录验证、密码重置、账户安全保护等。

+

3. 客户服务:用于响应您的咨询、处理您的反馈、解决技术问题等。

+

4. 安全保护:用于检测、预防、处理欺诈、滥用、安全风险和技术问题。

+

5. 法律合规:遵守适用的法律法规、法律程序或政府要求。

+

6. 服务改进:分析用户使用情况,改进我们的产品和服务质量。

+ +

三、信息存储

+

1. 存储地点:您的个人信息将存储在中华人民共和国境内。如需跨境传输,我们将严格按照相关法律法规执行。

+

2. 存储期限:我们仅在为实现本政策所述目的所必需的期间内保留您的个人信息。在您注销账户后,我们将删除或匿名化处理您的个人信息,法律法规另有规定的除外。

+

3. 安全措施:我们采用行业标准的安全技术和措施来保护您的个人信息,包括但不限于数据加密、访问控制、安全审计等。

+ +

四、信息共享与披露

+

我们承诺不会向第三方出售、出租或以其他方式披露您的个人信息,但以下情况除外:

+

1. 获得您的同意:在获得您明确同意的情况下,我们可能会与第三方共享您的信息。

+

2. 服务提供商:我们可能会与为我们提供服务的第三方(如云服务提供商、数据分析服务商等)共享必要的信息,但这些第三方必须遵守严格的保密义务。

+

3. 法律要求:根据法律法规、法律程序、诉讼或政府主管部门的要求,我们可能需要披露您的个人信息。

+

4. 紧急情况:为保护我们、用户或公众的权利、财产或安全,我们可能会在必要时披露相关信息。

+ +

五、您的权利

+

根据相关法律法规,您对自己的个人信息享有以下权利:

+

1. 访问权:您有权访问我们持有的您的个人信息。

+

2. 更正权:您有权要求更正不准确或不完整的个人信息。

+

3. 删除权:在特定情况下,您有权要求删除您的个人信息。

+

4. 撤回同意:您有权撤回之前给予我们的同意,但这可能影响您使用部分服务功能。

+

5. 注销账户:您有权注销您的账户。账户注销后,我们将删除或匿名化处理您的个人信息。

+

如您需要行使上述权利,请通过我们提供的联系方式与我们联系。

+ +

六、未成年人保护

+

我们非常重视对未成年人个人信息的保护。如果您是18周岁以下的未成年人,请在您的监护人同意和指导下使用本服务。如果我们发现自己在未事先获得可证实的监护人同意的情况下收集了未成年人的个人信息,我们会设法尽快删除相关数据。

+ +

七、Cookie和类似技术

+

我们可能会使用Cookie和类似技术来收集信息、改善用户体验、分析服务使用情况等。您可以通过浏览器设置管理Cookie,但请注意,禁用Cookie可能会影响部分服务功能的使用。

+ +

八、第三方服务

+

我们的服务可能包含指向第三方网站、产品和服务的链接。我们不对这些第三方的隐私做法负责,建议您仔细阅读这些第三方的隐私政策。

+ +

九、隐私政策的更新

+

我们可能会不时更新本隐私政策。更新后,我们会在相关页面公布最新版本的隐私政策,并通过适当方式通知您。如您不同意更新后的隐私政策,请停止使用本服务;如您继续使用,则视为接受更新后的隐私政策。

+ +

十、联系我们

+

如您对本隐私政策有任何疑问、意见或建议,或需要行使您的相关权利,请通过以下方式与我们联系:

+

公司名称:上海华嘎科技有限公司
+ 产品名称:岁月留书

+

我们将在收到您的请求后,尽快予以回复。

+ +
+ 最后更新时间:2026年1月27日 +
+
+ + + \ No newline at end of file diff --git a/api/static/legal/terms.html b/api/static/legal/terms.html new file mode 100644 index 0000000..50f66d0 --- /dev/null +++ b/api/static/legal/terms.html @@ -0,0 +1,139 @@ + + + + + + 用户协议 - 岁月留书 + + + +
+

用户协议

+
岁月留书用户服务协议
+ +
+ 服务提供方:上海华嘎科技有限公司
+ 产品名称:岁月留书
+ 生效日期:2026年1月27日 +
+ +

一、协议的接受

+

欢迎使用岁月留书(以下简称"本服务")。本协议是您与上海华嘎科技有限公司(以下简称"我们"或"公司")之间关于使用本服务的法律协议。

+

请您仔细阅读本协议的全部内容,特别是涉及免除或限制责任的条款、法律适用和争议解决条款。当您点击"同意"按钮或实际使用本服务时,即表示您已充分阅读、理解并同意接受本协议的全部内容。

+ +

二、服务说明

+

1. 岁月留书是一款帮助用户记录和整理人生回忆的智能应用服务。

+

2. 我们有权根据业务发展需要调整、变更或终止部分或全部服务内容。

+

3. 我们保留随时修改或中断服务而不需通知用户的权利。

+ +

三、用户账户

+

1. 您需要注册账户才能使用本服务的部分功能。注册时,您应当提供真实、准确、完整的个人信息。

+

2. 您有责任维护账户信息的安全性和准确性,并对账户下的所有活动负责。

+

3. 如发现账户被盗用或存在安全漏洞,请立即通知我们。

+ +

四、用户行为规范

+

1. 您在使用本服务时,应当遵守国家法律法规,不得利用本服务从事违法违规活动。

+

2. 您不得上传、发布、传播含有以下内容的信息:

+

+ (1)违反国家法律法规、危害国家安全、破坏社会稳定的内容;
+ (2)侵犯他人知识产权、隐私权、名誉权等合法权益的内容;
+ (3)色情、暴力、赌博、诈骗等不良信息;
+ (4)其他违反公序良俗的内容。 +

+

3. 您应当尊重他人的合法权益,不得恶意干扰、破坏本服务的正常运行。

+ +

五、知识产权

+

1. 本服务的所有知识产权,包括但不限于商标、专利、著作权等,均归我们所有。

+

2. 您在使用本服务过程中产生的内容(包括但不限于文字、图片、音频等),其知识产权归您所有,但您授予我们在提供服务所必需的范围内使用这些内容的权利。

+

3. 未经我们书面许可,您不得以任何形式复制、传播、展示、镜像、上传、下载本服务的任何内容。

+ +

六、隐私保护

+

我们非常重视您的隐私保护。关于我们如何收集、使用、存储和保护您的个人信息,请详细阅读我们的《隐私政策》。

+ +

七、免责声明

+

1. 本服务基于现有技术和条件提供,我们不对服务的及时性、准确性、完整性、可靠性作任何明示或暗示的保证。

+

2. 因不可抗力、计算机病毒、黑客攻击、系统不稳定、用户设备故障等原因导致的服务中断或数据丢失,我们不承担责任。

+

3. 您因使用本服务而产生的任何直接或间接损失,我们均不承担责任。

+ +

八、服务变更与终止

+

1. 我们有权根据业务发展需要,随时变更、中断或终止部分或全部服务。

+

2. 如您违反本协议,我们有权立即终止向您提供服务,并保留追究法律责任的权利。

+

3. 服务终止后,您账户内的数据可能被删除,请您提前备份重要数据。

+ +

九、协议修改

+

我们有权随时修改本协议。协议修改后,我们会在相关页面公布修改后的协议内容。如您不同意修改后的协议,请停止使用本服务;如您继续使用,则视为接受修改后的协议。

+ +

十、法律适用与争议解决

+

1. 本协议的订立、生效、解释、履行和争议解决均适用中华人民共和国大陆地区法律法规。

+

2. 如因本协议产生任何争议,双方应友好协商解决;协商不成的,任何一方均可向我们住所地有管辖权的人民法院提起诉讼。

+ +

十一、其他

+

1. 如本协议的任何条款被认定为无效或不可执行,不影响其他条款的效力。

+

2. 本协议的标题仅为方便阅读而设,不影响本协议任何条款的含义或解释。

+

3. 如您对本协议有任何疑问,可通过我们提供的联系方式与我们联系。

+ +
+ 最后更新时间:2026年1月27日 +
+
+ + + \ No newline at end of file diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 9a408e9..9471646 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -2,15 +2,46 @@ from __future__ import annotations +pytest_plugins = ["tests.support.auth_async_sqlite"] + +import os + +import pytest + +# TOML SSOT:OTEL_ENABLED env 已无效;在 pytest 最早阶段关闭 deploy.otel_enabled。 +os.environ.setdefault("APP_ENV", "development") + + +def pytest_configure(config: pytest.Config) -> None: + from app.core.app_config import get_app_config + + get_app_config().deploy.otel_enabled = False + import uuid from datetime import datetime, timezone from typing import Callable -import pytest +from fastapi import FastAPI from app.features.user.models import User +def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None: + from app.core.telemetry import shutdown_telemetry + + shutdown_telemetry() + + +def install_test_error_handlers(app: FastAPI) -> FastAPI: + """为最小测试 FastAPI 应用挂载与生产一致的异常处理。""" + from app.core.errors import register_exception_handlers + from app.core.middleware import RequestIdMiddleware + + app.add_middleware(RequestIdMiddleware) + register_exception_handlers(app) + return app + + @pytest.fixture def unique_phone() -> str: """避免与测试库中已存在手机号冲突(11 位)。""" diff --git a/api/tests/evaluation/test_internal_router_auth.py b/api/tests/evaluation/test_internal_router_auth.py index 8fc4a1d..9e66049 100644 --- a/api/tests/evaluation/test_internal_router_auth.py +++ b/api/tests/evaluation/test_internal_router_auth.py @@ -4,6 +4,7 @@ import pytest from httpx import ASGITransport, AsyncClient from app.features.evaluation.internal_auth import get_internal_eval_principal +from tests.conftest import install_test_error_handlers @pytest.mark.asyncio @@ -19,7 +20,7 @@ async def test_internal_eval_list_fixtures_requires_config( ) from app.features.evaluation.router import router - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(router, prefix="/internal/api/evaluation") transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://t") as client: @@ -50,7 +51,7 @@ async def test_internal_eval_with_override_lists_fixtures( from app.features.evaluation.router import router - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(router, prefix="/internal/api/evaluation") async def _override_auth(): diff --git a/api/tests/evaluation/test_memoir_pipeline_run_router.py b/api/tests/evaluation/test_memoir_pipeline_run_router.py index 4155a75..7309ec3 100644 --- a/api/tests/evaluation/test_memoir_pipeline_run_router.py +++ b/api/tests/evaluation/test_memoir_pipeline_run_router.py @@ -1,6 +1,7 @@ """GET /users/{user_id}/memoir-pipeline-run(快照读取)。""" import pytest +from tests.conftest import install_test_error_handlers from httpx import ASGITransport, AsyncClient from app.features.evaluation.internal_auth import get_internal_eval_principal @@ -42,7 +43,7 @@ async def test_memoir_pipeline_run_ok_by_phase1_task( _fake_eval, ) - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(router, prefix="/internal/api/evaluation") async def _override_auth(): @@ -78,7 +79,7 @@ async def test_memoir_pipeline_run_400_both_ids( ) from app.features.evaluation.router import router - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(router, prefix="/internal/api/evaluation") async def _override_auth(): diff --git a/api/tests/evaluation/test_memoir_readiness_router.py b/api/tests/evaluation/test_memoir_readiness_router.py index dacee38..4409ef1 100644 --- a/api/tests/evaluation/test_memoir_readiness_router.py +++ b/api/tests/evaluation/test_memoir_readiness_router.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone import pytest +from tests.conftest import install_test_error_handlers from httpx import ASGITransport, AsyncClient from app.features.evaluation.internal_auth import get_internal_eval_principal @@ -33,7 +34,7 @@ async def test_memoir_phase1_ready_returns_bundle( pending_segment_ids=[], ) - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(router, prefix="/internal/api/evaluation") async def _override_auth(): @@ -79,7 +80,7 @@ async def test_memoir_phase1_ready_404_propagates( ) -> MemoirPhase1ReadyOut: raise EvaluationNotFoundError("conversation not found") - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(router, prefix="/internal/api/evaluation") async def _override_auth(): diff --git a/api/tests/evaluation/test_replay_router.py b/api/tests/evaluation/test_replay_router.py index 23724a7..dd4c51c 100644 --- a/api/tests/evaluation/test_replay_router.py +++ b/api/tests/evaluation/test_replay_router.py @@ -1,6 +1,7 @@ """回放 / 评审路由参数校验(最小 HTTP)。""" import pytest +from tests.conftest import install_test_error_handlers from httpx import ASGITransport, AsyncClient from app.features.evaluation.internal_auth import get_internal_eval_principal @@ -19,7 +20,7 @@ async def test_replay_conversation_requires_fixture_or_utterances( ) from app.features.evaluation.router import router - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(router, prefix="/internal/api/evaluation") async def _override_auth(): @@ -51,7 +52,7 @@ async def test_replay_conversation_rejects_both_fixture_and_utterances( ) from app.features.evaluation.router import router - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(router, prefix="/internal/api/evaluation") async def _override_auth(): diff --git a/api/tests/fixtures/config/merge/default.toml b/api/tests/fixtures/config/merge/default.toml new file mode 100644 index 0000000..02e632a --- /dev/null +++ b/api/tests/fixtures/config/merge/default.toml @@ -0,0 +1,7 @@ +[deploy] +enable_tts = true +mock_sms_login_enabled = false + +[chat] +interview_persona = "default" +interview_temperature = 0.93 diff --git a/api/tests/fixtures/config/merge/staging.toml b/api/tests/fixtures/config/merge/staging.toml new file mode 100644 index 0000000..99b2544 --- /dev/null +++ b/api/tests/fixtures/config/merge/staging.toml @@ -0,0 +1,5 @@ +[deploy] +mock_sms_login_enabled = true + +[chat] +interview_persona = "staging_persona" diff --git a/api/tests/fixtures/config/minimal/default.toml b/api/tests/fixtures/config/minimal/default.toml new file mode 100644 index 0000000..d015ad7 --- /dev/null +++ b/api/tests/fixtures/config/minimal/default.toml @@ -0,0 +1,7 @@ +# minimal default for loader tests + +[deploy] +enable_tts = true + +[chat] +interview_persona = "default" diff --git a/api/tests/support/__init__.py b/api/tests/support/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/support/auth_async_sqlite.py b/api/tests/support/auth_async_sqlite.py new file mode 100644 index 0000000..fe28df0 --- /dev/null +++ b/api/tests/support/auth_async_sqlite.py @@ -0,0 +1,77 @@ +"""Async SQLite helpers for auth HTTP integration tests.""" + +from __future__ import annotations + +import uuid +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) + +from app.core.db import Base, utc_now +from app.core.security import get_token_expires_at +from app.features.auth.models import RefreshToken +from app.features.user.models import User + + +def _auth_tables() -> list: + return [User.__table__, RefreshToken.__table__] + + +@pytest.fixture +async def auth_async_engine() -> AsyncGenerator[AsyncEngine, None]: + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync( + lambda sync_conn: Base.metadata.create_all( + sync_conn, tables=_auth_tables() + ) + ) + yield engine + await engine.dispose() + + +@pytest.fixture +async def auth_session_factory( + auth_async_engine: AsyncEngine, +) -> async_sessionmaker[AsyncSession]: + return async_sessionmaker( + auth_async_engine, class_=AsyncSession, expire_on_commit=False + ) + + +async def seed_user_with_refresh_token( + session: AsyncSession, + *, + refresh_token: str = "refresh-old", + user_id: str | None = None, + phone: str | None = None, +) -> tuple[User, str]: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + phone=phone or f"138{uuid.uuid4().int % 100_000_000:08d}", + password_hash="hashed", + nickname="Test", + subscription_type="free", + created_at=utc_now(), + language_preference="zh", + ) + token_row = RefreshToken( + id=str(uuid.uuid4()), + user_id=uid, + token=refresh_token, + expires_at=get_token_expires_at(), + created_at=utc_now(), + is_revoked=False, + ) + session.add(user) + session.add(token_row) + await session.commit() + return user, refresh_token diff --git a/api/tests/test_agent_logging.py b/api/tests/test_agent_logging.py index f97b235..7448071 100644 --- a/api/tests/test_agent_logging.py +++ b/api/tests/test_agent_logging.py @@ -3,6 +3,7 @@ from __future__ import annotations import app.core.agent_logging as agent_logging +from app.core.runtime_constants import agent_log_defaults class _StubLogger: @@ -27,9 +28,9 @@ def test_log_agent_payload_skips_when_not_verbose(monkeypatch: object) -> None: def test_log_agent_payload_preview_includes_sha12(monkeypatch: object) -> None: monkeypatch.setattr("app.core.config.settings.log_level", "DEBUG") - monkeypatch.setattr("app.core.config.settings.agent_log_prompt_mode", "preview") - monkeypatch.setattr("app.core.config.settings.agent_log_prompt_dedup", False) - monkeypatch.setattr("app.core.config.settings.agent_log_max_chars", 100) + monkeypatch.setattr("app.core.runtime_constants.agent_log_defaults.prompt_mode", "preview") + monkeypatch.setattr("app.core.runtime_constants.agent_log_defaults.prompt_dedup", False) + monkeypatch.setattr("app.core.runtime_constants.agent_log_defaults.max_chars", 100) _clear_dedup() log = _StubLogger() agent_logging.log_agent_payload(log, "Unit.prompt", "hello world") @@ -43,8 +44,8 @@ def test_log_agent_payload_preview_includes_sha12(monkeypatch: object) -> None: def test_log_agent_payload_hash_only_no_preview(monkeypatch: object) -> None: monkeypatch.setattr("app.core.config.settings.log_level", "DEBUG") - monkeypatch.setattr("app.core.config.settings.agent_log_prompt_mode", "hash_only") - monkeypatch.setattr("app.core.config.settings.agent_log_prompt_dedup", False) + monkeypatch.setattr("app.core.runtime_constants.agent_log_defaults.prompt_mode", "hash_only") + monkeypatch.setattr("app.core.runtime_constants.agent_log_defaults.prompt_dedup", False) _clear_dedup() log = _StubLogger() body = "x" * 500 @@ -59,9 +60,9 @@ def test_log_agent_payload_hash_only_no_preview(monkeypatch: object) -> None: def test_log_agent_payload_dedup_second_call_skipped(monkeypatch: object) -> None: monkeypatch.setattr("app.core.config.settings.log_level", "DEBUG") - monkeypatch.setattr("app.core.config.settings.agent_log_prompt_mode", "preview") - monkeypatch.setattr("app.core.config.settings.agent_log_prompt_dedup", True) - monkeypatch.setattr("app.core.config.settings.agent_log_max_chars", 200) + monkeypatch.setattr("app.core.runtime_constants.agent_log_defaults.prompt_mode", "preview") + monkeypatch.setattr("app.core.runtime_constants.agent_log_defaults.prompt_dedup", True) + monkeypatch.setattr("app.core.runtime_constants.agent_log_defaults.max_chars", 200) _clear_dedup() log = _StubLogger() agent_logging.log_agent_payload(log, "DedupLabel.prompt", "same text") diff --git a/api/tests/test_alembic_migration_policy.py b/api/tests/test_alembic_migration_policy.py index 6be073f..79e3a15 100644 --- a/api/tests/test_alembic_migration_policy.py +++ b/api/tests/test_alembic_migration_policy.py @@ -39,7 +39,7 @@ def _script_dir() -> ScriptDirectory: def test_single_alembic_head() -> None: heads = _script_dir().get_heads() - assert heads == ["0019_align_legacy_schema"], f"unexpected heads: {heads}" + assert heads == ["0021_memory_source_segment_id"], f"unexpected heads: {heads}" def test_no_withdrawn_revision_ids_in_tree() -> None: @@ -78,11 +78,11 @@ def test_all_revisions_have_unique_ids() -> None: assert len(ids) == len(set(ids)), "duplicate revision ids" -def test_revision_chain_reaches_0019_from_0018() -> None: +def test_revision_chain_reaches_0021_from_0020() -> None: script = _script_dir() - rev = script.get_revision("0019_align_legacy_schema") + rev = script.get_revision("0021_memory_source_segment_id") assert rev is not None - assert rev.down_revision == "0018_users_language_preference" + assert rev.down_revision == "0020_refresh_rt_lineage" def test_no_autogenerate_introspection_backfill_pattern() -> None: diff --git a/api/tests/test_app_config_loader.py b/api/tests/test_app_config_loader.py new file mode 100644 index 0000000..7e6e263 --- /dev/null +++ b/api/tests/test_app_config_loader.py @@ -0,0 +1,33 @@ +"""TOML configuration loader tests.""" + +from pathlib import Path + +import pytest + +from app.core.app_config_loader import load_app_config + + +FIXTURES = Path(__file__).resolve().parent / "fixtures" / "config" + + +def test_load_default_only() -> None: + cfg = load_app_config("development", config_dir=FIXTURES / "minimal") + assert cfg.chat.interview_persona == "default" + assert cfg.deploy.enable_tts is True + + +def test_staging_overlay_merges_chat_section() -> None: + cfg = load_app_config("staging", config_dir=FIXTURES / "merge") + assert cfg.chat.interview_persona == "staging_persona" + assert cfg.chat.interview_temperature == 0.93 + assert cfg.deploy.mock_sms_login_enabled is True + + +def test_unknown_top_level_key_rejected(tmp_path: Path) -> None: + bad = tmp_path / "default.toml" + bad.write_text( + '[deploy]\nenable_tts = true\n\n[typo_section]\nfoo = 1\n', + encoding="utf-8", + ) + with pytest.raises(Exception): + load_app_config("development", config_dir=tmp_path) diff --git a/api/tests/test_app_error_contract.py b/api/tests/test_app_error_contract.py new file mode 100644 index 0000000..61a6d86 --- /dev/null +++ b/api/tests/test_app_error_contract.py @@ -0,0 +1,241 @@ +"""AppError 统一错误契约 HTTP 场景测试。""" + +import pytest +from fastapi import FastAPI, HTTPException +from httpx import ASGITransport, AsyncClient + +from app.core.errors import ( + AuthenticationError, + NotFoundError, + QuotaExceededError, + RateLimitedError, + ValidationError, + register_exception_handlers, +) +from app.core.middleware import RequestIdMiddleware +from app.features.auth.service import AuthError +from app.features.payment.payment_exceptions import PaymentError + + +def _test_app() -> FastAPI: + app = FastAPI() + app.add_middleware(RequestIdMiddleware) + register_exception_handlers(app) + + @app.get("/not-found") + async def _not_found(): + raise NotFoundError("资源不存在") + + @app.get("/auth-failed") + async def _auth_failed(): + raise AuthenticationError("无法验证凭据") + + @app.get("/quota") + async def _quota(): + raise QuotaExceededError("配额已用尽") + + @app.get("/rate-limited") + async def _rate_limited(): + raise RateLimitedError("发送过于频繁,请30秒后再试") + + @app.get("/validation") + async def _validation(): + raise ValidationError("参数无效") + + @app.get("/auth-domain") + async def _auth_domain(): + raise AuthError("该邮箱已被注册", "EMAIL_EXISTS") + + @app.get("/payment-domain") + async def _payment_domain(): + raise PaymentError("支付配置错误", code="PAYMENT_CONFIG_ERROR") + + @app.get("/http-string") + async def _http_string(): + raise HTTPException(status_code=400, detail="请求无效") + + @app.get("/http-429") + async def _http_429(): + raise HTTPException(status_code=429, detail="发送过于频繁") + + @app.get("/http-list") + async def _http_list(): + raise HTTPException( + status_code=422, + detail=[{"loc": ["body", "phone"], "msg": "field required"}], + ) + + @app.get("/http-unknown-status") + async def _http_unknown_status(): + raise HTTPException(status_code=418, detail="teapot") + + @app.get("/http-unknown-5xx") + async def _http_unknown_5xx(): + raise HTTPException(status_code=599, detail="upstream glitch") + + @app.get("/boom") + async def _boom(): + raise RuntimeError("secret_internal_xyz") + + return app + + +@pytest.mark.asyncio +async def test_not_found_error_contract() -> None: + app = _test_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/not-found") + assert r.status_code == 404 + body = r.json() + assert body["error_code"] == "NOT_FOUND" + assert body["message"] == "资源不存在" + assert "request_id" in body + + +@pytest.mark.asyncio +async def test_authentication_error_contract() -> None: + app = _test_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/auth-failed") + assert r.status_code == 401 + body = r.json() + assert body["error_code"] == "AUTHENTICATION_FAILED" + assert body["message"] == "无法验证凭据" + assert r.headers.get("www-authenticate") == "Bearer" + + +@pytest.mark.asyncio +async def test_quota_exceeded_error_contract() -> None: + app = _test_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/quota") + assert r.status_code == 429 + body = r.json() + assert body["error_code"] == "QUOTA_EXCEEDED" + assert body["message"] == "配额已用尽" + + +@pytest.mark.asyncio +async def test_rate_limited_error_contract() -> None: + app = _test_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/rate-limited") + assert r.status_code == 429 + body = r.json() + assert body["error_code"] == "RATE_LIMITED" + assert body["message"] == "发送过于频繁,请30秒后再试" + + +@pytest.mark.asyncio +async def test_validation_error_contract() -> None: + app = _test_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/validation") + assert r.status_code == 422 + body = r.json() + assert body["error_code"] == "VALIDATION_ERROR" + assert body["message"] == "参数无效" + + +@pytest.mark.asyncio +async def test_auth_error_domain_contract() -> None: + app = _test_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/auth-domain") + assert r.status_code == 400 + body = r.json() + assert body["error_code"] == "EMAIL_EXISTS" + assert body["message"] == "该邮箱已被注册" + + +@pytest.mark.asyncio +async def test_payment_error_domain_contract() -> None: + app = _test_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/payment-domain") + assert r.status_code == 502 + body = r.json() + assert body["error_code"] == "PROVIDER_ERROR" + assert body["message"] == "支付配置错误" + + +@pytest.mark.asyncio +async def test_http_exception_string_detail_contract() -> None: + app = _test_app() + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/http-string") + assert r.status_code == 400 + body = r.json() + assert body["error_code"] == "BAD_REQUEST" + assert body["message"] == "请求无效" + assert "request_id" in body + + +@pytest.mark.asyncio +async def test_http_exception_429_maps_to_rate_limited() -> None: + app = _test_app() + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/http-429") + assert r.status_code == 429 + body = r.json() + assert body["error_code"] == "RATE_LIMITED" + assert body["message"] == "发送过于频繁" + + +@pytest.mark.asyncio +async def test_http_exception_list_detail_contract() -> None: + app = _test_app() + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/http-list") + assert r.status_code == 422 + body = r.json() + assert body["error_code"] == "VALIDATION_ERROR" + assert "body.phone" in body["message"] + assert "field required" in body["message"] + + +@pytest.mark.asyncio +async def test_http_exception_unknown_5xx_maps_to_internal_error() -> None: + app = _test_app() + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/http-unknown-5xx") + assert r.status_code == 599 + body = r.json() + assert body["error_code"] == "INTERNAL_ERROR" + assert body["message"] == "upstream glitch" + + +@pytest.mark.asyncio +async def test_http_exception_unknown_status_maps_to_bad_request() -> None: + app = _test_app() + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/http-unknown-status") + assert r.status_code == 418 + body = r.json() + assert body["error_code"] == "BAD_REQUEST" + assert body["message"] == "teapot" + + +@pytest.mark.asyncio +async def test_unhandled_exception_contract() -> None: + app = _test_app() + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/boom") + assert r.status_code == 500 + body = r.json() + assert body["error_code"] == "INTERNAL_ERROR" + assert body["message"] == "服务器内部错误" + assert "secret_internal_xyz" not in r.text diff --git a/api/tests/test_auth_refresh_http.py b/api/tests/test_auth_refresh_http.py new file mode 100644 index 0000000..05f33ed --- /dev/null +++ b/api/tests/test_auth_refresh_http.py @@ -0,0 +1,180 @@ +"""Refresh token rotation HTTP scenarios (sqlite + real AuthService).""" + +from __future__ import annotations + +import asyncio +import uuid +from datetime import timedelta +from unittest.mock import MagicMock + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from app.core.config import settings +from app.core.db import get_async_db, utc_now +from app.core.dependencies import get_sms_sender +from app.features.auth import repo +from app.features.auth.router import router as auth_router +from tests.conftest import install_test_error_handlers +from tests.support.auth_async_sqlite import seed_user_with_refresh_token + + +@pytest.fixture +async def refresh_http_app( + auth_session_factory: async_sessionmaker[AsyncSession], +) -> FastAPI: + async def _override_db(): + async with auth_session_factory() as session: + yield session + + app = install_test_error_handlers(FastAPI()) + app.include_router(auth_router) + app.dependency_overrides[get_async_db] = _override_db + app.dependency_overrides[get_sms_sender] = lambda: MagicMock() + return app + + +@pytest.mark.asyncio +async def test_refresh_rotates_token_http( + refresh_http_app: FastAPI, + auth_session_factory: async_sessionmaker[AsyncSession], +) -> None: + async with auth_session_factory() as session: + await seed_user_with_refresh_token(session, refresh_token="token-v1") + + transport = ASGITransport(app=refresh_http_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/auth/refresh", + json={"refresh_token": "token-v1"}, + ) + + assert response.status_code == 200 + body = response.json() + assert body["access_token"] + assert body["refresh_token"] != "token-v1" + + +@pytest.mark.asyncio +async def test_refresh_concurrent_same_token_within_grace( + refresh_http_app: FastAPI, + auth_session_factory: async_sessionmaker[AsyncSession], +) -> None: + async with auth_session_factory() as session: + await seed_user_with_refresh_token(session, refresh_token="token-concurrent") + + transport = ASGITransport(app=refresh_http_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + first, second = await asyncio.gather( + client.post( + "/api/auth/refresh", + json={"refresh_token": "token-concurrent"}, + ), + client.post( + "/api/auth/refresh", + json={"refresh_token": "token-concurrent"}, + ), + ) + + assert first.status_code in (200, 401) + assert second.status_code in (200, 401) + assert first.status_code == 200 or second.status_code == 200 + responses = [first, second] + success = [r for r in responses if r.status_code == 200] + assert len(success) >= 1 + assert all( + r.json().get("error_code") != "REFRESH_TOKEN_REUSE" + for r in responses + if r.status_code != 200 + ) + if len(success) == 2: + assert success[0].json()["refresh_token"] == success[1].json()["refresh_token"] + assert success[0].json()["refresh_token"] != "token-concurrent" + + +@pytest.mark.asyncio +async def test_refresh_sequential_reuse_within_grace( + refresh_http_app: FastAPI, + auth_session_factory: async_sessionmaker[AsyncSession], +) -> None: + """Network retry: second call with old refresh token succeeds within grace.""" + async with auth_session_factory() as session: + await seed_user_with_refresh_token(session, refresh_token="token-retry") + + transport = ASGITransport(app=refresh_http_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + first = await client.post( + "/api/auth/refresh", + json={"refresh_token": "token-retry"}, + ) + second = await client.post( + "/api/auth/refresh", + json={"refresh_token": "token-retry"}, + ) + + assert first.status_code == 200 + assert second.status_code == 200 + assert first.json()["refresh_token"] == second.json()["refresh_token"] + assert first.json()["refresh_token"] != "token-retry" + + +@pytest.mark.asyncio +async def test_refresh_reuse_after_grace_returns_401( + refresh_http_app: FastAPI, + auth_session_factory: async_sessionmaker[AsyncSession], +) -> None: + async with auth_session_factory() as session: + await seed_user_with_refresh_token(session, refresh_token="token-grace") + + transport = ASGITransport(app=refresh_http_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + first = await client.post( + "/api/auth/refresh", + json={"refresh_token": "token-grace"}, + ) + assert first.status_code == 200 + + async with auth_session_factory() as session: + old = await repo.get_refresh_token_by_token("token-grace", session) + assert old is not None + old.rotated_at = utc_now() - timedelta( + seconds=settings.refresh_token_reuse_grace_seconds + 5 + ) + await session.commit() + + async with AsyncClient(transport=transport, base_url="http://test") as client: + reuse = await client.post( + "/api/auth/refresh", + json={"refresh_token": "token-grace"}, + ) + + assert reuse.status_code == 401 + body = reuse.json() + assert body["error_code"] == "REFRESH_TOKEN_REUSE" + assert body["message"] + assert "request_id" in body + + +@pytest.mark.asyncio +async def test_refresh_unknown_token_401( + refresh_http_app: FastAPI, + auth_session_factory: async_sessionmaker[AsyncSession], +) -> None: + async with auth_session_factory() as session: + await seed_user_with_refresh_token( + session, refresh_token="unused", user_id=str(uuid.uuid4()) + ) + + transport = ASGITransport(app=refresh_http_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/auth/refresh", + json={"refresh_token": "missing-token"}, + ) + + assert response.status_code == 401 + body = response.json() + assert body["error_code"] == "AUTHENTICATION_FAILED" + assert "request_id" in body diff --git a/api/tests/test_auth_refresh_rotation.py b/api/tests/test_auth_refresh_rotation.py new file mode 100644 index 0000000..78edd57 --- /dev/null +++ b/api/tests/test_auth_refresh_rotation.py @@ -0,0 +1,234 @@ +"""Refresh token rotation and reuse detection.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.features.auth import repo +from app.features.auth.models import RefreshToken +from app.features.auth.service import AuthError, AuthService + + +def _refresh_record( + *, + token: str = "old-refresh", + user_id: str = "user-1", + is_revoked: bool = False, + expired: bool = False, + replaced_by_token_id: str | None = None, + rotated_at: datetime | None = None, +) -> MagicMock: + now = datetime.now(timezone.utc) + record = MagicMock(spec=RefreshToken) + record.id = "rt-1" + record.user_id = user_id + record.token = token + record.expires_at = ( + now - timedelta(days=1) if expired else now + timedelta(days=30) + ) + record.created_at = now + record.is_revoked = is_revoked + record.device_info = "iphone" + record.replaced_by_token_id = replaced_by_token_id + record.rotated_at = rotated_at + return record + + +def _db_mock() -> MagicMock: + db = MagicMock() + db.commit = AsyncMock() + db.rollback = AsyncMock() + db.refresh = AsyncMock() + db.flush = AsyncMock() + return db + + +@pytest.mark.asyncio +async def test_refresh_rotates_token_in_transaction(monkeypatch) -> None: + db = _db_mock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + consumed = _refresh_record() + + monkeypatch.setattr( + repo, + "try_consume_refresh_token", + AsyncMock(return_value=consumed), + ) + monkeypatch.setattr( + repo, + "get_user_by_id", + AsyncMock(return_value=MagicMock(id="user-1")), + ) + monkeypatch.setattr( + svc, + "_issue_tokens", + AsyncMock( + return_value={ + "access_token": "new-access", + "refresh_token": "new-refresh", + "refresh_token_id": "rt-2", + } + ), + ) + monkeypatch.setattr(repo, "link_refresh_rotation", AsyncMock()) + + result = await svc.refresh_tokens("old-refresh") + + assert result["access_token"] == "new-access" + assert result["refresh_token"] == "new-refresh" + assert result["refresh_token"] != "old-refresh" + db.commit.assert_awaited_once() + svc._issue_tokens.assert_awaited_once_with("user-1", "iphone") + repo.try_consume_refresh_token.assert_awaited_once_with("old-refresh", db) + repo.link_refresh_rotation.assert_awaited_once() + assert repo.link_refresh_rotation.await_args.args[0] == "rt-1" + assert repo.link_refresh_rotation.await_args.args[1] == "rt-2" + + +@pytest.mark.asyncio +async def test_refresh_grace_reuse_returns_idempotent_tokens(monkeypatch) -> None: + db = _db_mock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + now = datetime.now(timezone.utc) + revoked = _refresh_record( + is_revoked=True, + replaced_by_token_id="rt-2", + rotated_at=now, + ) + replacement = _refresh_record(token="new-refresh", is_revoked=False) + replacement.id = "rt-2" + + monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None)) + monkeypatch.setattr( + repo, + "get_refresh_token_by_token", + AsyncMock(return_value=revoked), + ) + monkeypatch.setattr( + repo, + "get_refresh_token_by_id", + AsyncMock(return_value=replacement), + ) + monkeypatch.setattr( + repo, + "get_user_by_id", + AsyncMock(return_value=MagicMock(id="user-1")), + ) + monkeypatch.setattr( + svc, + "_revoke_all_active_tokens_in_session", + AsyncMock(return_value=2), + ) + monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) + + result = await svc.refresh_tokens("old-refresh") + + assert result["refresh_token"] == "new-refresh" + assert result["access_token"] + svc._revoke_all_active_tokens_in_session.assert_not_awaited() + svc._issue_tokens.assert_not_awaited() + db.commit.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_refresh_reuse_revokes_all_sessions(monkeypatch) -> None: + db = _db_mock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + now = datetime.now(timezone.utc) + revoked = _refresh_record( + is_revoked=True, + replaced_by_token_id="rt-2", + rotated_at=now - timedelta(seconds=120), + ) + + monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None)) + monkeypatch.setattr( + repo, + "get_refresh_token_by_token", + AsyncMock(return_value=revoked), + ) + monkeypatch.setattr( + svc, + "_revoke_all_active_tokens_in_session", + AsyncMock(return_value=2), + ) + monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) + + with pytest.raises(AuthError) as exc_info: + await svc.refresh_tokens("old-refresh") + + assert exc_info.value.code == "REFRESH_TOKEN_REUSE" + assert exc_info.value.error_code == "REFRESH_TOKEN_REUSE" + svc._revoke_all_active_tokens_in_session.assert_awaited_once_with("user-1") + db.commit.assert_awaited_once() + svc._issue_tokens.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_refresh_unknown_token(monkeypatch) -> None: + db = _db_mock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + + monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None)) + monkeypatch.setattr(repo, "get_refresh_token_by_token", AsyncMock(return_value=None)) + monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) + + with pytest.raises(AuthError) as exc_info: + await svc.refresh_tokens("missing") + + assert exc_info.value.code == "INVALID_TOKEN" + db.rollback.assert_awaited_once() + svc._issue_tokens.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_refresh_expired_token(monkeypatch) -> None: + db = _db_mock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + expired = _refresh_record(expired=True) + + monkeypatch.setattr(repo, "try_consume_refresh_token", AsyncMock(return_value=None)) + monkeypatch.setattr( + repo, + "get_refresh_token_by_token", + AsyncMock(return_value=expired), + ) + monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) + + with pytest.raises(AuthError) as exc_info: + await svc.refresh_tokens("old-refresh") + + assert exc_info.value.code == "TOKEN_EXPIRED" + db.rollback.assert_awaited_once() + svc._issue_tokens.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_refresh_user_deleted_rolls_back_consume(monkeypatch) -> None: + db = _db_mock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + consumed = _refresh_record() + + monkeypatch.setattr( + repo, + "try_consume_refresh_token", + AsyncMock(return_value=consumed), + ) + monkeypatch.setattr(repo, "get_user_by_id", AsyncMock(return_value=None)) + monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) + + with pytest.raises(AuthError) as exc_info: + await svc.refresh_tokens("old-refresh") + + assert exc_info.value.code == "USER_NOT_FOUND" + db.rollback.assert_awaited_once() + svc._issue_tokens.assert_not_awaited() diff --git a/api/tests/test_auth_sms_login_nested_transaction.py b/api/tests/test_auth_sms_login_nested_transaction.py new file mode 100644 index 0000000..9fff350 --- /dev/null +++ b/api/tests/test_auth_sms_login_nested_transaction.py @@ -0,0 +1,103 @@ +"""SMS login user-create race uses transactional_nested savepoint isolation.""" + +from __future__ import annotations + +import uuid +from datetime import timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from app.core.db import utc_now +from app.features.auth import repo +from app.features.auth.models import SmsVerificationCode +from app.features.auth.service import AuthService +from app.features.user.models import User +from tests.support.auth_async_sqlite import auth_async_engine, auth_session_factory + + +@pytest.fixture +def sms_async_engine(auth_async_engine): + return auth_async_engine + + +@pytest.fixture +def sms_session_factory(auth_session_factory): + return auth_session_factory + + +@pytest.mark.asyncio +async def test_sms_login_user_create_race_still_issues_tokens( + sms_session_factory: async_sessionmaker[AsyncSession], + monkeypatch: pytest.MonkeyPatch, +) -> None: + phone = f"138{uuid.uuid4().int % 100_000_000:08d}" + existing = User( + id=str(uuid.uuid4()), + phone=phone, + password_hash="hashed", + nickname="Existing", + subscription_type="free", + created_at=utc_now(), + language_preference="zh", + ) + async with sms_session_factory() as session: + session.add(existing) + await session.commit() + + record = SmsVerificationCode( + id=str(uuid.uuid4()), + phone=phone, + code="123456", + purpose="login", + is_used=False, + is_expired=False, + expires_at=utc_now() + timedelta(minutes=5), + created_at=utc_now(), + ) + + lookup_calls = {"count": 0} + real_get_user_by_phone = repo.get_user_by_phone + + async def get_user_by_phone_side_effect(p: str, db: AsyncSession): + lookup_calls["count"] += 1 + if lookup_calls["count"] == 1: + return None + return await real_get_user_by_phone(p, db) + + async with sms_session_factory() as session: + sms = MagicMock() + svc = AuthService(db=session, sms=sms) + monkeypatch.setattr( + svc, + "_precheck_sms_code_for_purposes", + AsyncMock(return_value=None), + ) + monkeypatch.setattr( + repo, + "get_user_by_phone", + AsyncMock(side_effect=get_user_by_phone_side_effect), + ) + monkeypatch.setattr( + repo, + "try_consume_verification_code", + AsyncMock(return_value=record), + ) + issue_calls: list[str] = [] + + async def fake_issue_tokens(user_id: str, device_info: str = ""): + issue_calls.append(user_id) + return { + "access_token": "access", + "refresh_token": "refresh", + "refresh_token_id": str(uuid.uuid4()), + } + + monkeypatch.setattr(svc, "_issue_tokens", fake_issue_tokens) + + result = await svc.login_with_sms(phone, "123456") + + assert result["access_token"] == "access" + assert issue_calls == [existing.id] + assert result["user"].id == existing.id diff --git a/api/tests/test_auth_sms_rate_limit.py b/api/tests/test_auth_sms_rate_limit.py new file mode 100644 index 0000000..82f0de2 --- /dev/null +++ b/api/tests/test_auth_sms_rate_limit.py @@ -0,0 +1,95 @@ +"""SMS 发送失败后事务回滚,不应留下验证码记录阻塞立即重试。""" + +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.core.errors import ProviderError, RateLimitedError +from app.features.auth import repo +from app.features.auth.models import SmsVerificationCode +from app.features.auth import service as auth_service_mod +from app.features.auth.service import AuthService, CODE_EXPIRE_MINUTES + + +def _make_service(*, sms_send_ok: bool) -> AuthService: + db = MagicMock() + db.commit = AsyncMock(return_value=None) + db.rollback = AsyncMock(return_value=None) + sms = MagicMock() + sms.send_verification_code = MagicMock(return_value=sms_send_ok) + return AuthService(db=db, sms=sms) + + +@pytest.mark.asyncio +async def test_send_sms_after_provider_failure_not_rate_limited(monkeypatch) -> None: + phone = "13800138000" + + async def fake_get_user_by_phone(p: str, db): + return None + + async def fake_get_recent_code_for_rate_limit(p: str, db): + return None + + async def fake_create_verification_code(record, db): + record.id = "new-record" + + expire_calls: list[str] = [] + + async def fake_mark_expired(code_id, db): + expire_calls.append(code_id) + + monkeypatch.setattr(repo, "get_user_by_phone", fake_get_user_by_phone) + monkeypatch.setattr( + repo, "get_recent_code_for_rate_limit", fake_get_recent_code_for_rate_limit + ) + monkeypatch.setattr(repo, "create_verification_code", fake_create_verification_code) + monkeypatch.setattr( + repo, "mark_verification_code_expired", fake_mark_expired + ) + monkeypatch.setattr(auth_service_mod, "_sms_is_configured", lambda: True) + + svc_fail = _make_service(sms_send_ok=False) + with pytest.raises(ProviderError) as exc_info: + await svc_fail.send_sms_code(phone, "register") + assert "失败" in exc_info.value.message + assert exc_info.value.error_code == "PROVIDER_ERROR" + assert exc_info.value.status_code == 502 + assert expire_calls == ["new-record"] + svc_fail._sms.send_verification_code.assert_called_once() + + svc_ok = _make_service(sms_send_ok=True) + success2, message2, expires_in2 = await svc_ok.send_sms_code(phone, "register") + assert success2 is True + assert message2 == "验证码已发送" + assert expires_in2 == CODE_EXPIRE_MINUTES * 60 + + +@pytest.mark.asyncio +async def test_send_sms_rate_limited_raises_rate_limited_error(monkeypatch) -> None: + phone = "13800138000" + now = datetime.now(timezone.utc) + recent = SmsVerificationCode( + id="recent-1", + phone=phone, + code="111111", + purpose="register", + expires_at=now, + created_at=now, + ) + + monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None)) + monkeypatch.setattr( + repo, "get_recent_code_for_rate_limit", AsyncMock(return_value=recent) + ) + monkeypatch.setattr(auth_service_mod, "_sms_is_configured", lambda: True) + + svc = _make_service(sms_send_ok=True) + with pytest.raises(RateLimitedError) as exc_info: + await svc.send_sms_code(phone, "register") + + assert "频繁" in exc_info.value.message + assert exc_info.value.error_code == "RATE_LIMITED" + assert exc_info.value.status_code == 429 diff --git a/api/tests/test_auth_sms_verify_transactional.py b/api/tests/test_auth_sms_verify_transactional.py new file mode 100644 index 0000000..9d8b3e8 --- /dev/null +++ b/api/tests/test_auth_sms_verify_transactional.py @@ -0,0 +1,371 @@ +"""SMS 验证码原子消耗与业务事务。""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock + +import pytest +from sqlalchemy.exc import IntegrityError + +from app.features.auth import repo +from app.features.auth.models import SmsVerificationCode +from app.features.auth.service import AuthError, AuthService + + +def _sms_record(*, phone: str = "13800138000", code: str = "123456") -> SmsVerificationCode: + now = datetime.now(timezone.utc) + return SmsVerificationCode( + id="code-1", + phone=phone, + code=code, + purpose="login", + is_used=False, + is_expired=False, + expires_at=now + timedelta(minutes=5), + created_at=now, + ) + + +def _phone_integrity_error() -> IntegrityError: + orig = MagicMock() + orig.diag.constraint_name = "ix_users_phone" + return IntegrityError("insert", {}, orig) + + +def _db_mock() -> MagicMock: + db = MagicMock() + db.commit = AsyncMock() + db.rollback = AsyncMock() + db.refresh = AsyncMock() + db.flush = AsyncMock() + return db + + +@pytest.mark.asyncio +async def test_login_with_sms_consumes_code_in_same_transaction_as_tokens(monkeypatch) -> None: + db = MagicMock() + db.commit = AsyncMock() + db.rollback = AsyncMock() + db.refresh = AsyncMock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + record = _sms_record() + + async def fake_check(phone: str, code: str, purpose: str): + if purpose == "login": + return record, "验证成功" + return None, "验证码不存在或已使用" + + issue_calls: list[str] = [] + + async def fake_issue_tokens(user_id: str, device_info: str = ""): + issue_calls.append(user_id) + return {"access_token": "a", "refresh_token": "r"} + + monkeypatch.setattr(svc, "_check_sms_code", fake_check) + monkeypatch.setattr(svc, "_issue_tokens", fake_issue_tokens) + monkeypatch.setattr( + repo, + "get_user_by_phone", + AsyncMock(return_value=MagicMock(id="user-1")), + ) + monkeypatch.setattr( + repo, + "try_consume_verification_code", + AsyncMock(return_value=record), + ) + + result = await svc.login_with_sms("13800138000", "123456") + + assert result["access_token"] == "a" + db.commit.assert_awaited_once() + assert issue_calls == ["user-1"] + repo.try_consume_verification_code.assert_awaited() + + +@pytest.mark.asyncio +async def test_login_with_sms_does_not_issue_tokens_when_consume_fails(monkeypatch) -> None: + db = MagicMock() + db.commit = AsyncMock() + db.rollback = AsyncMock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + + monkeypatch.setattr( + svc, + "_check_sms_code", + AsyncMock(return_value=(_sms_record(), "验证成功")), + ) + monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) + monkeypatch.setattr( + repo, + "get_user_by_phone", + AsyncMock(return_value=MagicMock(id="user-1")), + ) + monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None)) + + with pytest.raises(AuthError): + await svc.login_with_sms("13800138000", "123456") + + db.rollback.assert_awaited_once() + db.commit.assert_not_awaited() + svc._issue_tokens.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_login_with_sms_does_not_issue_tokens_when_token_issue_fails(monkeypatch) -> None: + db = MagicMock() + db.commit = AsyncMock() + db.rollback = AsyncMock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + + monkeypatch.setattr( + svc, + "_check_sms_code", + AsyncMock(return_value=(_sms_record(), "验证成功")), + ) + + async def failing_issue_tokens(user_id: str, device_info: str = ""): + raise RuntimeError("token store down") + + monkeypatch.setattr(svc, "_issue_tokens", failing_issue_tokens) + monkeypatch.setattr( + repo, + "get_user_by_phone", + AsyncMock(return_value=MagicMock(id="user-1")), + ) + monkeypatch.setattr( + repo, + "try_consume_verification_code", + AsyncMock(return_value=_sms_record()), + ) + + with pytest.raises(RuntimeError, match="token store down"): + await svc.login_with_sms("13800138000", "123456") + + db.rollback.assert_awaited_once() + db.commit.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_register_with_sms_uses_atomic_consume(monkeypatch) -> None: + db = _db_mock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + record = _sms_record() + + monkeypatch.setattr( + svc, + "_check_sms_code", + AsyncMock(return_value=(record, "验证成功")), + ) + monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None)) + monkeypatch.setattr(repo, "get_user_by_email", AsyncMock(return_value=None)) + monkeypatch.setattr(repo, "create_user", AsyncMock()) + monkeypatch.setattr( + svc, + "_issue_tokens", + AsyncMock(return_value={"access_token": "a", "refresh_token": "r"}), + ) + consume = AsyncMock(return_value=record) + monkeypatch.setattr(repo, "try_consume_verification_code", consume) + + await svc.register_with_sms( + "13800138000", + "123456", + "password1", + "nick", + ) + + consume.assert_awaited_once_with("13800138000", "123456", "register", db) + db.commit.assert_awaited_once() + db.flush.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_register_with_sms_maps_phone_integrity_error_to_phone_exists( + monkeypatch, +) -> None: + db = _db_mock() + db.flush = AsyncMock(side_effect=_phone_integrity_error()) + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + record = _sms_record() + record.purpose = "register" + + monkeypatch.setattr( + svc, + "_check_sms_code", + AsyncMock(return_value=(record, "验证成功")), + ) + monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None)) + monkeypatch.setattr(repo, "get_user_by_email", AsyncMock(return_value=None)) + monkeypatch.setattr(repo, "create_user", AsyncMock()) + monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=record)) + monkeypatch.setattr(svc, "_issue_tokens", AsyncMock()) + + with pytest.raises(AuthError) as exc_info: + await svc.register_with_sms( + "13800138000", + "123456", + "password1", + "nick", + ) + + assert exc_info.value.code == "PHONE_EXISTS" + db.rollback.assert_awaited_once() + db.commit.assert_not_awaited() + svc._issue_tokens.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_login_with_sms_recovers_when_concurrent_registration_wins( + monkeypatch, +) -> None: + db = _db_mock() + db.flush = AsyncMock(side_effect=_phone_integrity_error()) + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + record = _sms_record() + existing_user = MagicMock(id="existing-user") + + monkeypatch.setattr( + svc, + "_check_sms_code", + AsyncMock(return_value=(record, "验证成功")), + ) + monkeypatch.setattr( + repo, + "get_user_by_phone", + AsyncMock(side_effect=[None, existing_user]), + ) + monkeypatch.setattr(repo, "create_user", AsyncMock()) + monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=record)) + + issue_calls: list[str] = [] + + async def fake_issue_tokens(user_id: str, device_info: str = ""): + issue_calls.append(user_id) + return {"access_token": "a", "refresh_token": "r"} + + monkeypatch.setattr(svc, "_issue_tokens", fake_issue_tokens) + + result = await svc.login_with_sms("13800138000", "123456") + + assert result["access_token"] == "a" + assert result["is_new_user"] is False + assert issue_calls == ["existing-user"] + db.commit.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_change_phone_maps_phone_integrity_error_to_phone_taken(monkeypatch) -> None: + db = _db_mock() + db.flush = AsyncMock(side_effect=_phone_integrity_error()) + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + record = _sms_record() + record.purpose = "change_phone" + user = MagicMock(id="user-1", phone="13800138001") + + monkeypatch.setattr( + svc, + "_check_sms_code", + AsyncMock(return_value=(record, "验证成功")), + ) + monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None)) + monkeypatch.setattr(repo, "get_user_by_id", AsyncMock(return_value=user)) + monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=record)) + + with pytest.raises(AuthError) as exc_info: + await svc.change_phone("user-1", "13800138000", "123456") + + assert exc_info.value.code == "PHONE_TAKEN" + db.rollback.assert_awaited_once() + db.commit.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_register_with_sms_consume_race_returns_fresh_invalid_message( + monkeypatch, +) -> None: + db = _db_mock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + + monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None)) + monkeypatch.setattr( + svc, + "_check_sms_code", + AsyncMock(return_value=(None, "验证码不存在或已使用")), + ) + monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None)) + monkeypatch.setattr(repo, "get_user_by_email", AsyncMock(return_value=None)) + monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None)) + + with pytest.raises(AuthError) as exc_info: + await svc.register_with_sms( + "13800138000", + "123456", + "password1", + "nick", + ) + + assert exc_info.value.code == "INVALID_SMS_CODE" + assert exc_info.value.message != "验证成功" + assert exc_info.value.message == "验证码不存在或已使用" + + +@pytest.mark.asyncio +async def test_reset_password_consume_race_returns_fresh_invalid_message( + monkeypatch, +) -> None: + db = _db_mock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + user = MagicMock(id="user-1") + + monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None)) + monkeypatch.setattr( + svc, + "_check_sms_code", + AsyncMock(return_value=(None, "验证码已过期")), + ) + monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=user)) + monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None)) + + with pytest.raises(AuthError) as exc_info: + await svc.reset_password("13800138000", "123456", "newpass1") + + assert exc_info.value.code == "INVALID_SMS_CODE" + assert exc_info.value.message != "验证成功" + assert exc_info.value.message == "验证码已过期" + + +@pytest.mark.asyncio +async def test_change_phone_consume_race_returns_fresh_invalid_message( + monkeypatch, +) -> None: + db = _db_mock() + sms = MagicMock() + svc = AuthService(db=db, sms=sms) + user = MagicMock(id="user-1", phone="13800138001") + + monkeypatch.setattr(svc, "_precheck_sms_code", AsyncMock(return_value=None)) + monkeypatch.setattr( + svc, + "_check_sms_code", + AsyncMock(return_value=(None, "验证码不存在或已使用")), + ) + monkeypatch.setattr(repo, "get_user_by_phone", AsyncMock(return_value=None)) + monkeypatch.setattr(repo, "get_user_by_id", AsyncMock(return_value=user)) + monkeypatch.setattr(repo, "try_consume_verification_code", AsyncMock(return_value=None)) + + with pytest.raises(AuthError) as exc_info: + await svc.change_phone("user-1", "13800138000", "123456") + + assert exc_info.value.code == "INVALID_SMS_CODE" + assert exc_info.value.message != "验证成功" + assert exc_info.value.message == "验证码不存在或已使用" diff --git a/api/tests/test_avatar_preset_http.py b/api/tests/test_avatar_preset_http.py index 61a1957..87fe4dd 100644 --- a/api/tests/test_avatar_preset_http.py +++ b/api/tests/test_avatar_preset_http.py @@ -18,6 +18,7 @@ from app.features.auth.deps import get_auth_service from app.features.auth.router import router as auth_router from app.features.auth.service import AuthService from app.features.user.models import User +from tests.conftest import install_test_error_handlers def _mock_current_user() -> User: @@ -35,7 +36,7 @@ def _mock_current_user() -> User: @pytest.fixture def preset_auth_app() -> FastAPI: - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(auth_router) fixed_user = _mock_current_user() @@ -197,10 +198,9 @@ async def test_upload_avatar_cos_calls_storage_and_presigns( uid = str(uuid.uuid4()) public = f"https://{bucket}.cos.{region}.myqcloud.com/avatars/{uid}.jpg" for attr, val in ( - ("tencent_cos_secret_id", "sid"), - ("tencent_cos_secret_key", "sk"), + ("tencent_secret_id", "sid"), + ("tencent_secret_key", "sk"), ("tencent_cos_bucket", bucket), - ("tencent_cos_region", region), ("tencent_cos_base_url", f"https://{bucket}.cos.{region}.myqcloud.com"), ): monkeypatch.setattr(settings, attr, val, raising=False) @@ -214,19 +214,27 @@ async def test_upload_avatar_cos_calls_storage_and_presigns( fixed_user.id = uid fixed_user.avatar_url = None - async def _fake_update_avatar(u: str, url: str): + async def _fake_upload_avatar( + user_id: str, + file_content: bytes, + content_type: str, + *, + old_avatar_url: str | None, + ): + _ = (user_id, file_content, content_type, old_avatar_url) + url = mock_storage.upload(f"avatars/{uid}.jpg", b"jpeg", "image/jpeg") fixed_user.avatar_url = url return fixed_user mock_service = MagicMock(spec=AuthService) - mock_service.update_avatar_url = AsyncMock(side_effect=_fake_update_avatar) + mock_service.upload_avatar = AsyncMock(side_effect=_fake_upload_avatar) - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(auth_router) app.dependency_overrides[get_auth_service] = lambda: mock_service app.dependency_overrides[get_current_user] = lambda: fixed_user - transport = ASGITransport(app=app) + transport = ASGITransport(app=app, raise_app_exceptions=False) async with AsyncClient(transport=transport, base_url="http://test") as ac: r = await ac.post( "/api/auth/me/avatar", @@ -235,8 +243,7 @@ async def test_upload_avatar_cos_calls_storage_and_presigns( ) assert r.status_code == 200 - assert r.json()["avatar_url"] == "https://example.com/signed-avatar" + body = r.json() + assert body["avatar_url"] == "https://example.com/signed-avatar" mock_storage.upload.assert_called_once() assert mock_storage.upload.call_args[0][0] == f"avatars/{uid}.jpg" - mock_storage.get_url.assert_called_once() - assert mock_storage.get_url.call_args[0][0] == f"avatars/{uid}.jpg" diff --git a/api/tests/test_background_runner.py b/api/tests/test_background_runner.py index 4a0eede..a0f0bfe 100644 --- a/api/tests/test_background_runner.py +++ b/api/tests/test_background_runner.py @@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, patch import pytest from app.features.memoir import background_runner as br +from app.features.memoir.constants import memoir def test_batch_ready_for_submit_min_chars_zero() -> None: @@ -56,8 +57,8 @@ def test_next_retry_sleep_seconds() -> None: async def test_flush_pending_submits_without_gate( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(br.settings, "memoir_segment_batch_min_chars", 9999) - monkeypatch.setattr(br.settings, "memoir_segment_batch_max_wait_seconds", 9999.0) + monkeypatch.setattr(memoir, "segment_batch_min_chars", 9999) + monkeypatch.setattr(memoir, "segment_batch_max_wait_seconds", 9999.0) submitted: list[tuple[str, list[str]]] = [] @@ -91,8 +92,8 @@ async def test_flush_pending_submits_without_gate( async def test_flush_pending_merges_batch_and_extra_deduped( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(br.settings, "memoir_segment_batch_min_chars", 9999) - monkeypatch.setattr(br.settings, "memoir_segment_batch_max_wait_seconds", 9999.0) + monkeypatch.setattr(memoir, "segment_batch_min_chars", 9999) + monkeypatch.setattr(memoir, "segment_batch_max_wait_seconds", 9999.0) submitted: list[tuple[str, list[str]]] = [] @@ -126,8 +127,8 @@ async def test_flush_pending_merges_batch_and_extra_deduped( async def test_queue_message_min_chars_zero_submits_after_debounce( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(br.settings, "memoir_segment_batch_min_chars", 0) - monkeypatch.setattr(br.settings, "memoir_segment_batch_max_wait_seconds", 60.0) + monkeypatch.setattr(memoir, "segment_batch_min_chars", 0) + monkeypatch.setattr(memoir, "segment_batch_max_wait_seconds", 60.0) submitted: list[tuple[str, list[str]]] = [] @@ -147,8 +148,8 @@ async def test_queue_message_min_chars_zero_submits_after_debounce( async def test_queue_message_not_ready_then_max_wait_submits( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(br.settings, "memoir_segment_batch_min_chars", 100) - monkeypatch.setattr(br.settings, "memoir_segment_batch_max_wait_seconds", 0.12) + monkeypatch.setattr(memoir, "segment_batch_min_chars", 100) + monkeypatch.setattr(memoir, "segment_batch_max_wait_seconds", 0.12) submitted: list[tuple[str, list[str]]] = [] @@ -169,8 +170,8 @@ async def test_queue_message_not_ready_then_max_wait_submits( async def test_queue_message_not_ready_before_debounce_no_submit( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(br.settings, "memoir_segment_batch_min_chars", 100) - monkeypatch.setattr(br.settings, "memoir_segment_batch_max_wait_seconds", 60.0) + monkeypatch.setattr(memoir, "segment_batch_min_chars", 100) + monkeypatch.setattr(memoir, "segment_batch_max_wait_seconds", 60.0) submitted: list[tuple[str, list[str]]] = [] @@ -190,8 +191,8 @@ async def test_queue_message_not_ready_before_debounce_no_submit( async def test_queue_message_chars_met_submits_after_debounce( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(br.settings, "memoir_segment_batch_min_chars", 10) - monkeypatch.setattr(br.settings, "memoir_segment_batch_max_wait_seconds", 60.0) + monkeypatch.setattr(memoir, "segment_batch_min_chars", 10) + monkeypatch.setattr(memoir, "segment_batch_max_wait_seconds", 60.0) submitted: list[tuple[str, list[str]]] = [] diff --git a/api/tests/test_chapter_cover_enqueue_redis.py b/api/tests/test_chapter_cover_enqueue_redis.py new file mode 100644 index 0000000..a2e6bee --- /dev/null +++ b/api/tests/test_chapter_cover_enqueue_redis.py @@ -0,0 +1,76 @@ +import pytest + +from app.tasks import chapter_cover_enqueue as enqueue_mod + + +def test_chapter_cover_enqueue_reuses_sync_redis_client( + monkeypatch: pytest.MonkeyPatch, +) -> None: + class FakeRedis: + def __init__(self) -> None: + self.set_calls = 0 + + def set(self, *args, **kwargs): + self.set_calls += 1 + return True + + def delete(self, key: str) -> None: + return None + + shared_client = FakeRedis() + + def fake_get_sync_redis(*, decode_responses: bool): + assert decode_responses is True + return shared_client + + class Chapter: + category = "childhood" + cover_asset_id = None + story_links = [object()] + images = [] + + class FakeCoverTask: + def delay(self, chapter_id: str) -> None: + return None + + monkeypatch.setattr(enqueue_mod, "get_sync_redis", fake_get_sync_redis) + monkeypatch.setattr( + enqueue_mod, + "_load_chapter_for_enqueue_sync", + lambda chapter_id: Chapter(), + ) + monkeypatch.setattr( + enqueue_mod, + "chapter_has_story_links", + lambda chapter: True, + ) + monkeypatch.setattr( + enqueue_mod, + "effective_chapter_markdown_for_cover_gates", + lambda chapter: "body", + ) + monkeypatch.setattr( + enqueue_mod, + "strip_image_placeholders", + lambda text: text, + ) + monkeypatch.setattr( + enqueue_mod, + "chapter_eligible_for_cover_by_inline_body_image_count", + lambda chapter, markdown: True, + ) + monkeypatch.setattr( + enqueue_mod, + "primary_chapter_memoir_image", + lambda chapter: None, + ) + import app.tasks.chapter_cover_tasks as cover_tasks + + monkeypatch.setattr(cover_tasks, "generate_chapter_cover", FakeCoverTask()) + + ok = enqueue_mod.try_enqueue_generate_chapter_cover("chapter-1", source="http") + assert ok is True + assert shared_client.set_calls == 1 + + enqueue_mod.try_enqueue_generate_chapter_cover("chapter-2", source="http") + assert shared_client.set_calls == 2 diff --git a/api/tests/test_chat_input_normalize.py b/api/tests/test_chat_input_normalize.py index f8d59c1..955ffd3 100644 --- a/api/tests/test_chat_input_normalize.py +++ b/api/tests/test_chat_input_normalize.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch +from app.features.conversation.constants import chat from app.features.conversation.input_normalize import ( apply_conversation_input_rules, normalize_chat_input_for_agent, @@ -15,39 +16,39 @@ def test_apply_conversation_rules_matches_memoir_mei_kanshang() -> None: def test_normalize_chat_rules_mode() -> None: raw = "美看上我" - with patch("app.features.conversation.input_normalize.settings") as m: - m.chat_input_normalize_enabled = True - m.chat_input_normalize_mode = "rules" - m.chat_input_normalize_llm_max_tokens = 512 - m.chat_input_normalize_llm_max_input_chars = 8000 + with patch("app.features.conversation.input_normalize.chat") as m: + m.input_normalize_enabled = True + m.input_normalize_mode = "rules" + m.input_normalize_llm_max_tokens = 512 + m.input_normalize_llm_max_input_chars = 8000 assert normalize_chat_input_for_agent(raw, llm=None) == "没看上我" def test_normalize_chat_disabled_returns_raw() -> None: raw = "美看上我" - with patch("app.features.conversation.input_normalize.settings") as m: - m.chat_input_normalize_enabled = False - m.chat_input_normalize_mode = "rules" + with patch("app.features.conversation.input_normalize.chat") as m: + m.input_normalize_enabled = False + m.input_normalize_mode = "rules" assert normalize_chat_input_for_agent(raw, llm=None) == raw def test_normalize_chat_off_mode() -> None: raw = "美看上我" - with patch("app.features.conversation.input_normalize.settings") as m: - m.chat_input_normalize_enabled = True - m.chat_input_normalize_mode = "off" + with patch("app.features.conversation.input_normalize.chat") as m: + m.input_normalize_enabled = True + m.input_normalize_mode = "off" assert normalize_chat_input_for_agent(raw, llm=None) == raw def test_normalize_llm_mode_voice_only_passes_no_llm_for_typing() -> None: raw = "美看上我" fake = MagicMock() - with patch("app.features.conversation.input_normalize.settings") as m: - m.chat_input_normalize_enabled = True - m.chat_input_normalize_mode = "llm" - m.chat_input_normalize_llm_voice_only = True - m.chat_input_normalize_llm_max_tokens = 512 - m.chat_input_normalize_llm_max_input_chars = 8000 + with patch("app.features.conversation.input_normalize.chat") as m: + m.input_normalize_enabled = True + m.input_normalize_mode = "llm" + m.input_normalize_llm_voice_only = True + m.input_normalize_llm_max_tokens = 512 + m.input_normalize_llm_max_input_chars = 8000 with patch( "app.features.conversation.input_normalize._llm_normalize_chat_input" ) as llm_norm: diff --git a/api/tests/test_chat_stage_detection_gates.py b/api/tests/test_chat_stage_detection_gates.py index 2ccd8af..4392b5f 100644 --- a/api/tests/test_chat_stage_detection_gates.py +++ b/api/tests/test_chat_stage_detection_gates.py @@ -6,6 +6,7 @@ import pytest from app.agents.chat.schemas import StageDetectionOutput from app.agents.chat.stage_detection import detect_primary_life_stage +from app.features.conversation.constants import chat @pytest.mark.asyncio @@ -19,7 +20,7 @@ async def test_short_message_still_calls_stage_llm( return StageDetectionOutput(detected_stage="career") monkeypatch.setattr( - "app.agents.chat.stage_detection.settings.chat_stage_detection_enabled", + "app.agents.chat.stage_detection.chat.stage_detection_enabled", True, ) monkeypatch.setattr( diff --git a/api/tests/test_content_static_http.py b/api/tests/test_content_static_http.py new file mode 100644 index 0000000..2ef4ee9 --- /dev/null +++ b/api/tests/test_content_static_http.py @@ -0,0 +1,37 @@ +"""Content 静态页 HTTP smoke:legal 与官网主页。""" + +from __future__ import annotations + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from app.features.content.router import router as content_router +from tests.conftest import install_test_error_handlers + + +@pytest.fixture +def content_app() -> FastAPI: + app = install_test_error_handlers(FastAPI()) + app.include_router(content_router) + return app + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "path,expected_fragment", + [ + ("/api/legal/terms", "用户协议"), + ("/api/legal/privacy", "隐私"), + ("/", "岁月留书"), + ], +) +async def test_content_static_pages_return_html( + content_app: FastAPI, path: str, expected_fragment: str +) -> None: + transport = ASGITransport(app=content_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get(path) + assert r.status_code == 200 + assert r.headers.get("content-type", "").startswith("text/html") + assert expected_fragment in r.text diff --git a/api/tests/test_conversation_history_list.py b/api/tests/test_conversation_history_list.py new file mode 100644 index 0000000..58ffa79 --- /dev/null +++ b/api/tests/test_conversation_history_list.py @@ -0,0 +1,70 @@ +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.core.redis import RedisService + + +@pytest.mark.asyncio +async def test_add_message_uses_rpush_pipeline() -> None: + service = RedisService() + client = AsyncMock() + pipe = MagicMock() + pipe.execute = AsyncMock(return_value=[]) + client.pipeline = MagicMock(return_value=pipe) + service._client = client + + ok = await service.add_message("conv-1", "user", "hello") + + assert ok is True + pipe.rpush.assert_called_once() + pipe.expire.assert_called_once_with( + "conversation:history:conv-1", service.session_ttl + ) + pipe.execute.assert_awaited_once() + client.get.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_conversation_history_migrates_legacy_string() -> None: + service = RedisService() + client = AsyncMock() + legacy = [{"role": "user", "content": "hi", "messageType": "text"}] + client.exists.return_value = 1 + client.type.return_value = "string" + client.get.return_value = json.dumps(legacy, ensure_ascii=False) + pipe = MagicMock() + pipe.execute = AsyncMock(return_value=[]) + client.pipeline = MagicMock(return_value=pipe) + service._client = client + + history = await service.get_conversation_history("conv-legacy") + + assert history == legacy + pipe.execute.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_append_tts_updates_list_entry() -> None: + service = RedisService() + client = AsyncMock() + item = { + "role": "ai", + "content": "reply", + "messageType": "text", + "timestamp": "2026-01-01T00:00:00+00:00", + } + client.exists.return_value = 1 + client.type.return_value = "list" + client.lrange.return_value = [json.dumps(item, ensure_ascii=False)] + service._client = client + + ok = await service.append_tts_audio_url_to_last_ai_message( + "conv-1", "https://cdn.example/audio.mp3" + ) + + assert ok is True + client.lset.assert_awaited_once() + updated = json.loads(client.lset.await_args.args[2]) + assert updated["ttsAudioUrls"] == ["https://cdn.example/audio.mp3"] diff --git a/api/tests/test_conversation_history_turn_ids.py b/api/tests/test_conversation_history_turn_ids.py index 9fd2183..8507381 100644 --- a/api/tests/test_conversation_history_turn_ids.py +++ b/api/tests/test_conversation_history_turn_ids.py @@ -2,7 +2,8 @@ from __future__ import annotations -from unittest.mock import AsyncMock, MagicMock +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch import pytest from sqlalchemy.ext.asyncio import AsyncSession @@ -10,8 +11,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.features.conversation.history_store import HumanAiTurnIds +@asynccontextmanager +async def _capture_transactional(db): + yield db + await db.commit() + + @pytest.mark.asyncio -async def test_record_human_ai_turn_returns_both_message_ids(monkeypatch) -> None: +async def test_record_human_ai_turn_returns_both_message_ids() -> None: conv_id = "conv-1" captured: list[object] = [] @@ -25,20 +32,19 @@ async def test_record_human_ai_turn_returns_both_message_ids(monkeypatch) -> Non def add_conversation_message(msg: object, db) -> None: captured.append(msg) - monkeypatch.setattr( - "app.features.conversation.history_store.ConversationMessage", - FakeMsg, - ) - db = MagicMock(spec=AsyncSession) db.commit = AsyncMock() - db.refresh = AsyncMock() - import app.features.conversation.history_store as hs_mod - - orig_repo = hs_mod.repo - hs_mod.repo = _FakeRepo # type: ignore[misc] - try: + with patch( + "app.features.conversation.history_store.transactional", + _capture_transactional, + ), patch( + "app.features.conversation.history_store.ConversationMessage", + FakeMsg, + ), patch( + "app.features.conversation.history_store.repo", + _FakeRepo, + ): from app.features.conversation import history_store as hs store = hs.ConversationHistoryStore(db) @@ -57,8 +63,6 @@ async def test_record_human_ai_turn_returns_both_message_ids(monkeypatch) -> Non segment_id="seg-1", memory_retrieval_trace=None, ) - finally: - hs_mod.repo = orig_repo # type: ignore[misc] assert isinstance(out, HumanAiTurnIds) assert len(captured) == 2 @@ -67,3 +71,4 @@ async def test_record_human_ai_turn_returns_both_message_ids(monkeypatch) -> Non assert captured[0].segment_id == "seg-1" assert out.human_message_id == captured[0].id assert out.assistant_message_id == captured[1].id + db.commit.assert_awaited_once() diff --git a/api/tests/test_cors_and_sms_http.py b/api/tests/test_cors_and_sms_http.py new file mode 100644 index 0000000..ebac2da --- /dev/null +++ b/api/tests/test_cors_and_sms_http.py @@ -0,0 +1,122 @@ +"""CORS 默认策略与 SMS 发送失败/限流 HTTP 契约。""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from app.core.config import settings +from app.core.errors import ProviderError, RateLimitedError, ServiceUnavailableError +from app.features.auth.deps import get_auth_service +from app.features.auth.router import router as auth_router +from app.features.auth.service import AuthService +from tests.conftest import install_test_error_handlers + + +def test_default_api_cors_disables_credentials() -> None: + origins = [ + o.strip() for o in (settings.api_cors_origins or "").split(",") if o.strip() + ] + assert origins == [] + assert not bool(origins) + + +def test_explicit_api_cors_enables_credentials() -> None: + origins = [ + o.strip() + for o in "https://app.example.com, https://admin.example.com".split(",") + if o.strip() + ] + assert origins == ["https://app.example.com", "https://admin.example.com"] + assert bool(origins) + + +@pytest.fixture +def sms_app() -> FastAPI: + app = install_test_error_handlers(FastAPI()) + app.include_router(auth_router) + mock_service = MagicMock(spec=AuthService) + app.dependency_overrides[get_auth_service] = lambda: mock_service + app.state._mock_service = mock_service + return app + + +@pytest.mark.asyncio +async def test_sms_send_provider_failure_returns_502_unified_error( + sms_app: FastAPI, +) -> None: + mock_service: MagicMock = sms_app.state._mock_service + mock_service.send_sms_code = AsyncMock( + side_effect=ProviderError("短信发送失败,请稍后重试") + ) + transport = ASGITransport(app=sms_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.post( + "/api/auth/sms/send", + json={"phone": "13800138000", "purpose": "register"}, + ) + assert r.status_code == 502 + body = r.json() + assert body["error_code"] == "PROVIDER_ERROR" + assert body["message"] == "短信发送失败,请稍后重试" + assert "request_id" in body + + +@pytest.mark.asyncio +async def test_sms_send_not_configured_returns_503_unified_error( + sms_app: FastAPI, +) -> None: + mock_service: MagicMock = sms_app.state._mock_service + mock_service.send_sms_code = AsyncMock( + side_effect=ServiceUnavailableError("短信服务未配置,请稍后再试") + ) + transport = ASGITransport(app=sms_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.post( + "/api/auth/sms/send", + json={"phone": "13800138000", "purpose": "register"}, + ) + assert r.status_code == 503 + body = r.json() + assert body["error_code"] == "SERVICE_UNAVAILABLE" + assert body["message"] == "短信服务未配置,请稍后再试" + assert "request_id" in body + + +@pytest.mark.asyncio +async def test_sms_send_rate_limited_returns_429(sms_app: FastAPI) -> None: + mock_service: MagicMock = sms_app.state._mock_service + mock_service.send_sms_code = AsyncMock( + side_effect=RateLimitedError("发送过于频繁,请30秒后再试") + ) + transport = ASGITransport(app=sms_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.post( + "/api/auth/sms/send", + json={"phone": "13800138000", "purpose": "register"}, + ) + assert r.status_code == 429 + body = r.json() + assert body["error_code"] == "RATE_LIMITED" + assert "频繁" in body["message"] + + +@pytest.mark.asyncio +async def test_main_app_cors_preflight_allows_origin_without_credentials() -> None: + from app.main import app + + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.options( + "/health", + headers={ + "Origin": "https://example.com", + "Access-Control-Request-Method": "GET", + }, + ) + assert r.status_code == 200 + assert r.headers.get("access-control-allow-origin") == "*" + assert r.headers.get("access-control-allow-credentials") in (None, "false") diff --git a/api/tests/test_db_transactional.py b/api/tests/test_db_transactional.py new file mode 100644 index 0000000..2829e0a --- /dev/null +++ b/api/tests/test_db_transactional.py @@ -0,0 +1,147 @@ +"""transactional / transactional_sync commit and rollback behavior.""" + +from __future__ import annotations + +from contextlib import asynccontextmanager, contextmanager +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.core.db import ( + transactional, + transactional_nested, + transactional_nested_sync, + transactional_sync, +) + + +@pytest.mark.asyncio +async def test_transactional_commits_on_success() -> None: + session = MagicMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + + async with transactional(session): + pass + + session.commit.assert_awaited_once() + session.rollback.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_transactional_rolls_back_on_error() -> None: + session = MagicMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + + with pytest.raises(RuntimeError, match="boom"): + async with transactional(session): + raise RuntimeError("boom") + + session.commit.assert_not_awaited() + session.rollback.assert_awaited_once() + + +def test_transactional_sync_commits_on_success() -> None: + session = MagicMock() + session.commit = MagicMock() + session.rollback = MagicMock() + + with transactional_sync(session): + pass + + session.commit.assert_called_once() + session.rollback.assert_not_called() + + +def test_transactional_sync_rolls_back_on_error() -> None: + session = MagicMock() + session.commit = MagicMock() + session.rollback = MagicMock() + + with pytest.raises(RuntimeError, match="boom"): + with transactional_sync(session): + raise RuntimeError("boom") + + session.commit.assert_not_called() + session.rollback.assert_called_once() + + +@pytest.mark.asyncio +async def test_transactional_nested_releases_savepoint_on_success() -> None: + session = MagicMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + + @asynccontextmanager + async def fake_begin_nested(): + yield session + + session.begin_nested = MagicMock(return_value=fake_begin_nested()) + + async with transactional_nested(session): + pass + + session.begin_nested.assert_called_once() + session.commit.assert_not_awaited() + session.rollback.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_transactional_nested_rolls_back_savepoint_on_error() -> None: + session = MagicMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + + @asynccontextmanager + async def fake_begin_nested(): + yield session + + session.begin_nested = MagicMock(return_value=fake_begin_nested()) + + with pytest.raises(RuntimeError, match="boom"): + async with transactional_nested(session): + raise RuntimeError("boom") + + session.begin_nested.assert_called_once() + session.commit.assert_not_awaited() + session.rollback.assert_not_awaited() + + +def test_transactional_nested_sync_releases_savepoint_on_success() -> None: + session = MagicMock() + session.commit = MagicMock() + session.rollback = MagicMock() + + @contextmanager + def fake_begin_nested(): + yield session + + session.begin_nested = MagicMock(return_value=fake_begin_nested()) + + with transactional_nested_sync(session): + pass + + session.begin_nested.assert_called_once() + session.commit.assert_not_called() + session.rollback.assert_not_called() + + +def test_transactional_nested_sync_rolls_back_savepoint_on_error() -> None: + session = MagicMock() + session.commit = MagicMock() + session.rollback = MagicMock() + + @contextmanager + def fake_begin_nested(): + yield session + + session.begin_nested = MagicMock(return_value=fake_begin_nested()) + + with pytest.raises(RuntimeError, match="boom"): + with transactional_nested_sync(session): + raise RuntimeError("boom") + + session.begin_nested.assert_called_once() + session.commit.assert_not_called() + session.rollback.assert_not_called() diff --git a/api/tests/test_default_toml_legacy_parity.py b/api/tests/test_default_toml_legacy_parity.py new file mode 100644 index 0000000..1ae0edd --- /dev/null +++ b/api/tests/test_default_toml_legacy_parity.py @@ -0,0 +1,23 @@ +"""Guard: config/default.toml defaults stay aligned with pre-TOML Settings.""" + +from app.core.app_config_loader import default_config_dir, load_app_config + + +def test_default_toml_matches_legacy_settings_defaults() -> None: + cfg = load_app_config("development", config_dir=default_config_dir()) + + assert cfg.chat.interview_persona == "default" + assert cfg.chat.interview_temperature == 0.93 + assert cfg.chat.memory_top_k == 8 + assert cfg.chat.memory_evidence_max_chars == 4096 + assert cfg.chat.reply_planner_llm_enabled is False + + assert cfg.memoir.oral_normalize_mode == "rules" + + assert cfg.story.image_min_body_chars == 400 + + assert cfg.asr.provider == "whisper" + assert cfg.asr.device == "auto" + assert cfg.asr.compute_type == "auto" + + assert cfg.misc.tencent_sms_template_param_count == 2 diff --git a/api/tests/test_dialogue_lineage_memory_ingest.py b/api/tests/test_dialogue_lineage_memory_ingest.py index c303c0a..3a959cb 100644 --- a/api/tests/test_dialogue_lineage_memory_ingest.py +++ b/api/tests/test_dialogue_lineage_memory_ingest.py @@ -3,10 +3,78 @@ from __future__ import annotations from types import SimpleNamespace +from unittest.mock import MagicMock import pytest from app.features.memory.ingest_service import MemoryIngestService +from app.features.memory.constants import memory + + +@pytest.mark.asyncio +async def test_ingest_batch_passes_segment_id(monkeypatch) -> None: + captured: dict = {} + + class FakeSession: + async def commit(self) -> None: + pass + + async def flush(self) -> None: + pass + + async def fake_get(*_args, **_kwargs): + return None + + async def fake_create_source(session, **kwargs): + captured.update(kwargs) + return SimpleNamespace(id="src-1") + + async def fake_create_chunk(*_args, **_kwargs): + return SimpleNamespace(id="ch-0") + + class FakeEmbeddingService: + def __init__(self, *_args, **_kwargs) -> None: + pass + + async def embed_source(self, user_id: str, source_id: str) -> dict: + return {"status": "success", "vectors_written": 1} + + monkeypatch.setattr( + "app.features.memory.ingest_service.get_transcript_source_by_segment_id", + fake_get, + ) + monkeypatch.setattr( + "app.features.memory.ingest_service.create_source", + fake_create_source, + ) + monkeypatch.setattr( + "app.features.memory.ingest_service.create_chunk", + fake_create_chunk, + ) + monkeypatch.setattr( + "app.features.memory.ingest_service.MemoryEmbeddingService", + FakeEmbeddingService, + ) + monkeypatch.setattr("app.features.memory.constants.memory.enrichment_enabled", False) + + lineage = { + "schema_version": 1, + "conversation_id": "c9", + "turns": [{"user_message_id": "um-1", "assistant_message_id": "as-1"}], + "primary_user_message_id": "um-1", + } + service = MemoryIngestService( + FakeSession(), # type: ignore[arg-type] + embedding_provider=None, + enrichment_scheduler=MagicMock(schedule_many=MagicMock(return_value=[])), + ) + ids = await service.ingest_transcripts_batch( + "u1", + [("c9", "hello there", lineage, "seg-9")], + ) + assert ids == ["src-1"] + assert captured.get("segment_id") == "seg-9" + assert captured.get("lineage_json") == lineage @pytest.mark.asyncio @@ -59,7 +127,7 @@ async def test_memory_ingest_passes_lineage(monkeypatch) -> None: "app.features.memory.ingest_service.MemoryEmbeddingService", FakeEmbeddingService, ) - monkeypatch.setattr("app.core.config.settings.memory_enrichment_enabled", False) + monkeypatch.setattr("app.features.memory.constants.memory.enrichment_enabled", False) lineage = { "schema_version": 1, diff --git a/api/tests/test_error_code_registry.py b/api/tests/test_error_code_registry.py new file mode 100644 index 0000000..fba9480 --- /dev/null +++ b/api/tests/test_error_code_registry.py @@ -0,0 +1,120 @@ +"""Ensure runtime error_code values stay within the OpenAPI registry.""" + +from __future__ import annotations + +import inspect +import re +from pathlib import Path + +import pytest + +from app.core.error_codes import ALL_ERROR_CODES, ERROR_CODE_ENUM +from app.core.errors import ( + AppError, + AuthenticationError, + AuthorizationError, + BadRequestError, + ConflictError, + GatewayTimeoutError, + NotFoundError, + ProviderError, + QuotaExceededError, + RateLimitedError, + ServiceUnavailableError, + ValidationError, + _STATUS_TO_ERROR_CODE, +) +from app.features.auth.service import _AUTH_CODE_MAP +from app.features.payment.payment_exceptions import _PAYMENT_CODE_MAP + +_APP_FEATURES_ROOT = Path(__file__).resolve().parents[1] / "app" / "features" +_LITERAL_ERROR_CODE_RE = re.compile(r"""error_code\s*=\s*["']([A-Z][A-Z0-9_]*)["']""") + + +def _app_error_subclass_codes() -> set[str]: + codes: set[str] = set() + for cls in ( + NotFoundError, + BadRequestError, + AuthenticationError, + AuthorizationError, + ValidationError, + ConflictError, + ServiceUnavailableError, + GatewayTimeoutError, + ProviderError, + QuotaExceededError, + RateLimitedError, + ): + sig = inspect.signature(cls.__init__) + default = sig.parameters.get("message") + # Instantiate with defaults to read resolved error_code from AppError base. + instance = cls() + codes.add(instance.error_code) + return codes + + +def _auth_runtime_codes() -> set[str]: + return {external for _, external in _AUTH_CODE_MAP.values()} + + +def _payment_runtime_codes() -> set[str]: + return {external for _, external in _PAYMENT_CODE_MAP.values()} + + +def _literal_feature_error_codes() -> set[str]: + codes: set[str] = set() + for path in _APP_FEATURES_ROOT.rglob("*.py"): + text = path.read_text(encoding="utf-8") + codes.update(_LITERAL_ERROR_CODE_RE.findall(text)) + return codes + + +def _runtime_error_codes() -> set[str]: + return ( + _app_error_subclass_codes() + | _auth_runtime_codes() + | _payment_runtime_codes() + | set(_STATUS_TO_ERROR_CODE.values()) + | _literal_feature_error_codes() + ) + + +def test_runtime_error_codes_are_registered_in_openapi_enum() -> None: + runtime = _runtime_error_codes() + registry = set(ERROR_CODE_ENUM) + missing = runtime - registry + assert not missing, f"Unregistered runtime error_code values: {sorted(missing)}" + + +def test_auth_and_payment_registry_http_status_matches_runtime_maps() -> None: + registry_by_code = {entry["code"]: entry for entry in ALL_ERROR_CODES} + for internal, (status_code, external) in {**_AUTH_CODE_MAP, **_PAYMENT_CODE_MAP}.items(): + if external not in registry_by_code: + continue + entry = registry_by_code[external] + assert entry["http_status"] == status_code, ( + f"{internal} maps to {external} with HTTP {status_code}, " + f"but registry lists {entry['http_status']}" + ) + + +@pytest.mark.parametrize( + "error_cls,expected_code", + [ + (NotFoundError, "NOT_FOUND"), + (BadRequestError, "BAD_REQUEST"), + (AuthenticationError, "AUTHENTICATION_FAILED"), + (AuthorizationError, "FORBIDDEN"), + (ValidationError, "VALIDATION_ERROR"), + (ConflictError, "CONFLICT"), + (ServiceUnavailableError, "SERVICE_UNAVAILABLE"), + (GatewayTimeoutError, "GATEWAY_TIMEOUT"), + (ProviderError, "PROVIDER_ERROR"), + (QuotaExceededError, "QUOTA_EXCEEDED"), + (RateLimitedError, "RATE_LIMITED"), + ], +) +def test_app_error_subclasses_use_registered_codes(error_cls: type[AppError], expected_code: str) -> None: + assert error_cls().error_code == expected_code + assert expected_code in ERROR_CODE_ENUM diff --git a/api/tests/test_eval_judge_llm_spec.py b/api/tests/test_eval_judge_llm_spec.py index a07947d..f2e4fa2 100644 --- a/api/tests/test_eval_judge_llm_spec.py +++ b/api/tests/test_eval_judge_llm_spec.py @@ -4,6 +4,7 @@ import pytest from app.core.config import settings from app.core.dependencies import build_eval_judge_llm_spec +from app.features.evaluation.constants import eval_cfg from app.features.evaluation.judge_service import ( eval_judge_compare_transcript_each_max_chars_for_context, eval_judge_conversation_transcript_max_chars_for_context, @@ -13,22 +14,21 @@ from app.features.evaluation.judge_service import ( def test_build_eval_judge_zhipu_uses_bigmodel_defaults( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "eval_judge_api_key", "") monkeypatch.setattr(settings, "zhipu_api_key", "z-test") - monkeypatch.setattr(settings, "eval_judge_model", "glm-5") + monkeypatch.setattr(eval_cfg, "judge_model", "glm-5") spec = build_eval_judge_llm_spec("zhipu", None) assert spec is not None assert spec.provider == "zhipu" assert spec.resolved_model == "glm-5" assert spec.llm is not None - assert spec.context_window_tokens == settings.eval_judge_context_window_tokens + assert spec.context_window_tokens == eval_cfg.judge_context_window_tokens def test_build_eval_judge_zhipu_request_model_override( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "eval_judge_api_key", "e-test") - monkeypatch.setattr(settings, "eval_judge_model", "glm-5") + monkeypatch.setattr(settings, "zhipu_api_key", "e-test") + monkeypatch.setattr(eval_cfg, "judge_model", "glm-5") spec = build_eval_judge_llm_spec("zhipu", "glm-4-plus") assert spec is not None assert spec.resolved_model == "glm-4-plus" @@ -38,7 +38,6 @@ def test_build_eval_judge_deepseek_requires_key( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setattr(settings, "deepseek_api_key", "") - monkeypatch.setattr(settings, "llm_api_key", "") assert build_eval_judge_llm_spec("deepseek", None) is None @@ -47,8 +46,8 @@ def test_build_eval_judge_deepseek_v4_flash_non_thinking_default_path( ) -> None: """默认 deepseek-v4-flash 且关闭 thinking 时显式传 disabled(避免 API 默认 enabled)。""" monkeypatch.setattr(settings, "deepseek_api_key", "d-test") - monkeypatch.setattr(settings, "eval_judge_deepseek_model", "deepseek-v4-flash") - monkeypatch.setattr(settings, "eval_judge_deepseek_thinking_enabled", False) + monkeypatch.setattr(eval_cfg, "judge_deepseek_model", "deepseek-v4-flash") + monkeypatch.setattr(eval_cfg, "judge_deepseek_thinking_enabled", False) spec = build_eval_judge_llm_spec("deepseek", None) assert spec is not None assert spec.resolved_model == "deepseek-v4-flash" @@ -60,8 +59,8 @@ def test_build_eval_judge_deepseek_context_budget( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setattr(settings, "deepseek_api_key", "d-test") - monkeypatch.setattr(settings, "eval_judge_deepseek_model", "deepseek-reasoner") - monkeypatch.setattr(settings, "eval_judge_deepseek_context_window_tokens", 64_000) + monkeypatch.setattr(eval_cfg, "judge_deepseek_model", "deepseek-reasoner") + monkeypatch.setattr(eval_cfg, "judge_deepseek_context_window_tokens", 64_000) spec = build_eval_judge_llm_spec("deepseek", None) assert spec is not None assert spec.provider == "deepseek" diff --git a/api/tests/test_fidelity_gate.py b/api/tests/test_fidelity_gate.py index 9c099ad..ea4f17c 100644 --- a/api/tests/test_fidelity_gate.py +++ b/api/tests/test_fidelity_gate.py @@ -9,13 +9,14 @@ import pytest from app.agents.memoir.fidelity_check_agent import FidelityCheckAgent from app.core.config import settings from app.core.llm_call import LLMCallError +from app.features.memoir.constants import memoir def test_fidelity_fail_closed_on_parse_when_not_append( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_fidelity_check_enabled", True) - monkeypatch.setattr(settings, "memoir_fidelity_fail_open_on_parse_error", False) + monkeypatch.setattr(memoir, "fidelity_check_enabled", True) + monkeypatch.setattr(memoir, "fidelity_fail_open_on_parse_error", False) agent = FidelityCheckAgent() llm = MagicMock() with patch( @@ -36,8 +37,8 @@ def test_fidelity_fail_closed_on_parse_when_not_append( def test_fidelity_fail_open_on_parse_when_append( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_fidelity_check_enabled", True) - monkeypatch.setattr(settings, "memoir_fidelity_fail_open_on_parse_error", False) + monkeypatch.setattr(memoir, "fidelity_check_enabled", True) + monkeypatch.setattr(memoir, "fidelity_fail_open_on_parse_error", False) agent = FidelityCheckAgent() llm = MagicMock() with patch( @@ -58,8 +59,8 @@ def test_fidelity_fail_open_on_parse_when_append( def test_fidelity_fail_open_global_flag_overrides_append( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_fidelity_check_enabled", True) - monkeypatch.setattr(settings, "memoir_fidelity_fail_open_on_parse_error", True) + monkeypatch.setattr(memoir, "fidelity_check_enabled", True) + monkeypatch.setattr(memoir, "fidelity_fail_open_on_parse_error", True) agent = FidelityCheckAgent() llm = MagicMock() with patch( diff --git a/api/tests/test_history_store_transactional.py b/api/tests/test_history_store_transactional.py new file mode 100644 index 0000000..ec6838f --- /dev/null +++ b/api/tests/test_history_store_transactional.py @@ -0,0 +1,141 @@ +"""ConversationHistoryStore transactional boundaries.""" + +from __future__ import annotations + +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.conversation.history_store import ConversationHistoryStore + + +@asynccontextmanager +async def _capture_transactional(db): + yield db + await db.commit() + + +@pytest.mark.asyncio +async def test_record_ai_only_turn_commits_before_redis_sync() -> None: + db = MagicMock(spec=AsyncSession) + db.commit = AsyncMock() + redis_sync = AsyncMock() + captured: list[object] = [] + + class FakeMsg: + def __init__(self, **kwargs) -> None: + self.id = "ai-1" + for k, v in kwargs.items(): + setattr(self, k, v) + + class _FakeRepo: + @staticmethod + def add_conversation_message(msg: object, _db) -> None: + captured.append(msg) + + with patch( + "app.features.conversation.history_store.transactional", + _capture_transactional, + ), patch( + "app.features.conversation.history_store.ConversationMessage", + FakeMsg, + ), patch( + "app.features.conversation.history_store.repo", + _FakeRepo, + ): + store = ConversationHistoryStore(db) + store._sync_redis_best_effort = redis_sync # type: ignore[method-assign] + store._touch_conversation = AsyncMock() # type: ignore[method-assign] + + msg_id = await store.record_ai_only_turn("conv-1", ["hello"]) + + assert msg_id is not None + assert len(captured) == 1 + assert captured[0].id == msg_id + db.commit.assert_awaited_once() + redis_sync.assert_awaited_once_with("conv-1") + + +@pytest.mark.asyncio +async def test_attach_ai_tts_commits_repo_update_before_redis_sync() -> None: + db = MagicMock(spec=AsyncSession) + db.commit = AsyncMock() + redis_sync = AsyncMock() + repo_calls: list[tuple] = [] + + async def fake_set_latest(*args, **kwargs): + repo_calls.append((args, kwargs)) + return object() + + with patch( + "app.features.conversation.history_store.transactional", + _capture_transactional, + ), patch( + "app.features.conversation.history_store.repo.set_latest_ai_message_tts_audio_urls", + fake_set_latest, + ): + store = ConversationHistoryStore(db) + store._sync_redis_best_effort = redis_sync # type: ignore[method-assign] + + await store.attach_ai_tts_audio_urls( + "conv-1", + tts_audio_urls=["https://example.com/a.mp3"], + segment_id="seg-1", + ) + + assert len(repo_calls) == 1 + db.commit.assert_awaited_once() + redis_sync.assert_awaited_once_with("conv-1") + + +@pytest.mark.asyncio +async def test_record_human_ai_turn_commits_pair_before_redis_sync() -> None: + db = MagicMock(spec=AsyncSession) + db.commit = AsyncMock() + redis_sync = AsyncMock() + captured: list[object] = [] + + class FakeMsg: + def __init__(self, **kwargs) -> None: + self.id = kwargs.get("id") or f"msg-{len(captured)}" + for k, v in kwargs.items(): + setattr(self, k, v) + + class _FakeRepo: + @staticmethod + def add_conversation_message(msg: object, _db) -> None: + captured.append(msg) + + with patch( + "app.features.conversation.history_store.transactional", + _capture_transactional, + ), patch( + "app.features.conversation.history_store.ConversationMessage", + FakeMsg, + ), patch( + "app.features.conversation.history_store.repo", + _FakeRepo, + ): + store = ConversationHistoryStore(db) + store._sync_redis_best_effort = redis_sync # type: ignore[method-assign] + store._touch_conversation = AsyncMock() # type: ignore[method-assign] + + out = await store.record_human_ai_turn( + "conv-1", + "hello", + ["reply"], + user_message_timestamp=datetime.now(timezone.utc), + is_from_voice=False, + voice_session_id=None, + audio_duration_seconds=None, + tts_audio_urls=None, + segment_id="seg-1", + ) + + assert out is not None + assert len(captured) == 2 + db.commit.assert_awaited_once() + redis_sync.assert_awaited_once_with("conv-1") diff --git a/api/tests/test_http_contract_errors.py b/api/tests/test_http_contract_errors.py index bddf8ae..bc498d6 100644 --- a/api/tests/test_http_contract_errors.py +++ b/api/tests/test_http_contract_errors.py @@ -7,6 +7,8 @@ import pytest from httpx import ASGITransport, AsyncClient from app.core.dependencies import get_current_user +from app.core.errors import register_exception_handlers +from app.core.middleware import RequestIdMiddleware from app.features.auth.deps import get_auth_service from app.features.auth.router import router as auth_router from app.features.payment.deps import get_payment_order_service @@ -22,6 +24,8 @@ async def test_wechat_notify_returns_fixed_message_on_service_error() -> None: raise RuntimeError("wechat_sdk_secret_123") app = FastAPI() + app.add_middleware(RequestIdMiddleware) + register_exception_handlers(app) app.include_router(payment_router) app.dependency_overrides[get_payment_order_service] = lambda: BoomOrderService() @@ -62,12 +66,9 @@ async def test_avatar_upload_500_detail_sanitized( return_value="https://test-bucket.cos.ap-shanghai.myqcloud.com/avatars/user-contract-test/abc.jpg" ) monkeypatch.setattr(deps, "get_object_storage", lambda: mock_storage) - monkeypatch.setattr(settings, "tencent_cos_secret_id", "sid", raising=False) - monkeypatch.setattr(settings, "tencent_cos_secret_key", "sk", raising=False) + monkeypatch.setattr(settings, "tencent_secret_id", "sid", raising=False) + monkeypatch.setattr(settings, "tencent_secret_key", "sk", raising=False) monkeypatch.setattr(settings, "tencent_cos_bucket", "test-bucket", raising=False) - monkeypatch.setattr( - settings, "tencent_cos_region", "ap-shanghai", raising=False - ) monkeypatch.setattr( settings, "tencent_cos_base_url", @@ -76,22 +77,25 @@ async def test_avatar_upload_500_detail_sanitized( ) class BoomAuth: - async def update_avatar_url(self, user_id: str, avatar_url: str): + async def upload_avatar(self, user_id, file_content, content_type, **kwargs): raise RuntimeError("db_connection_secret_xyz") app = FastAPI() + app.add_middleware(RequestIdMiddleware) + register_exception_handlers(app) app.include_router(auth_router) app.dependency_overrides[get_current_user] = lambda: fake_user app.dependency_overrides[get_auth_service] = lambda: BoomAuth() - transport = ASGITransport(app=app) + transport = ASGITransport(app=app, raise_app_exceptions=False) files = {"file": ("a.jpg", BytesIO(_minimal_jpeg_bytes()), "image/jpeg")} async with AsyncClient(transport=transport, base_url="http://test") as client: r = await client.post("/api/auth/me/avatar", files=files) assert r.status_code == 500 body = r.json() - detail = body.get("detail", "") - assert detail == "处理图片失败,请重试" - assert "secret" not in str(detail).lower() + assert body.get("error_code") == "INTERNAL_ERROR" + message = body.get("message", "") + assert message == "服务器内部错误" + assert "secret" not in str(message).lower() assert "db_connection" not in r.text diff --git a/api/tests/test_http_router_error_contract.py b/api/tests/test_http_router_error_contract.py new file mode 100644 index 0000000..d6c114b --- /dev/null +++ b/api/tests/test_http_router_error_contract.py @@ -0,0 +1,139 @@ +"""真实 feature router 的错误契约 HTTP 场景测试。""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient + +from app.core.errors import BadRequestError, NotFoundError +from app.features.auth.deps import get_auth_service +from app.features.auth.router import router as auth_router +from app.features.auth.service import AuthError +from app.features.conversation.deps import get_conversation_service +from app.features.conversation.router import router as conversation_router +from tests.conftest import install_test_error_handlers + + +@pytest.mark.asyncio +async def test_auth_register_validation_returns_unified_422() -> None: + app = install_test_error_handlers(FastAPI()) + app.include_router(auth_router) + + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.post( + "/api/auth/register", + json={"phone": "123", "password": "x", "nickname": ""}, + ) + + assert r.status_code == 422 + body = r.json() + assert body["error_code"] == "VALIDATION_ERROR" + assert "request_id" in body + assert body["message"] + + +@pytest.mark.asyncio +async def test_conversation_list_requires_auth_unified_401() -> None: + app = install_test_error_handlers(FastAPI()) + app.include_router(conversation_router) + + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/api/conversations") + + assert r.status_code == 401 + body = r.json() + assert body["error_code"] == "AUTHENTICATION_FAILED" + assert r.headers.get("www-authenticate") == "Bearer" + + +@pytest.mark.asyncio +async def test_conversation_detail_not_found_unified_404() -> None: + app = install_test_error_handlers(FastAPI()) + app.include_router(conversation_router) + + mock_service = MagicMock() + mock_service.get_one = AsyncMock( + side_effect=NotFoundError("Conversation not found") + ) + app.dependency_overrides[get_conversation_service] = lambda: mock_service + + fake_user = MagicMock() + fake_user.id = "user-1" + from app.core.dependencies import get_current_user + + app.dependency_overrides[get_current_user] = lambda: fake_user + + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/api/conversations/conv-missing") + + assert r.status_code == 404 + body = r.json() + assert body["error_code"] == "NOT_FOUND" + assert body["message"] == "Conversation not found" + + +@pytest.mark.asyncio +async def test_auth_login_invalid_credentials_unified_401() -> None: + app = install_test_error_handlers(FastAPI()) + app.include_router(auth_router) + + mock_service = MagicMock() + mock_service.login = AsyncMock( + side_effect=AuthError("手机号或密码错误", "INVALID_CREDENTIALS") + ) + app.dependency_overrides[get_auth_service] = lambda: mock_service + + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.post( + "/api/auth/login", + json={ + "phone": "13800138000", + "password": "wrong-password", + "agreed_to_terms": True, + }, + ) + + assert r.status_code == 401 + body = r.json() + assert body["error_code"] == "AUTHENTICATION_FAILED" + assert body["message"] == "手机号或密码错误" + assert "request_id" in body + assert r.headers.get("www-authenticate") == "Bearer" + + +@pytest.mark.asyncio +async def test_auth_avatar_invalid_image_unified_400() -> None: + app = install_test_error_handlers(FastAPI()) + app.include_router(auth_router) + + mock_service = MagicMock() + mock_service.upload_avatar = AsyncMock( + side_effect=BadRequestError("无效的图片文件格式。文件头: deadbeef") + ) + app.dependency_overrides[get_auth_service] = lambda: mock_service + + fake_user = MagicMock() + fake_user.id = "user-1" + fake_user.avatar_url = None + from app.core.dependencies import get_current_user + + app.dependency_overrides[get_current_user] = lambda: fake_user + + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.post( + "/api/auth/me/avatar", + files={"file": ("x.png", b"not-an-image", "image/png")}, + ) + + assert r.status_code == 400 + body = r.json() + assert body["error_code"] == "BAD_REQUEST" + assert "无效的图片" in body["message"] diff --git a/api/tests/test_image_prompt_policy.py b/api/tests/test_image_prompt_policy.py index 005e103..72f8e13 100644 --- a/api/tests/test_image_prompt_policy.py +++ b/api/tests/test_image_prompt_policy.py @@ -6,6 +6,7 @@ import pytest from app.features.memoir.memoir_images.prompting import MemoirImagePromptService from app.features.memoir.memoir_images.settings import MemoirImageSettings +from app.features.memoir.constants import memoir def _svc(llm=None) -> MemoirImagePromptService: @@ -26,7 +27,7 @@ def test_story_primary_fallback_uses_placeholder_when_llm_disabled(): def test_story_primary_fallback_disabled_requires_brief(monkeypatch): monkeypatch.setattr( - "app.features.memoir.memoir_images.prompting.settings.image_prompt_fallback_disabled", + "app.features.memoir.memoir_images.prompting.memoir.image_prompt_fallback_disabled", True, ) with pytest.raises(RuntimeError, match="prompt_brief"): @@ -46,7 +47,7 @@ def test_story_primary_style_from_chat_stage_when_no_intent_style(): def test_story_primary_fallback_disabled_requires_llm(monkeypatch): monkeypatch.setattr( - "app.features.memoir.memoir_images.prompting.settings.image_prompt_fallback_disabled", + "app.features.memoir.memoir_images.prompting.memoir.image_prompt_fallback_disabled", True, ) with pytest.raises(RuntimeError, match="requires LLM"): @@ -99,7 +100,7 @@ def test_cover_fallback_uses_service_template_when_llm_disabled(): def test_cover_fallback_when_excerpt_empty(monkeypatch): monkeypatch.setattr( - "app.features.memoir.memoir_images.prompting.settings.image_prompt_fallback_disabled", + "app.features.memoir.memoir_images.prompting.memoir.image_prompt_fallback_disabled", False, ) out = _svc(llm=None).build_cover_prompt( @@ -112,7 +113,7 @@ def test_cover_fallback_when_excerpt_empty(monkeypatch): def test_cover_fallback_disabled_requires_excerpt(monkeypatch): monkeypatch.setattr( - "app.features.memoir.memoir_images.prompting.settings.image_prompt_fallback_disabled", + "app.features.memoir.memoir_images.prompting.memoir.image_prompt_fallback_disabled", True, ) with pytest.raises(RuntimeError, match="context_excerpt"): @@ -131,7 +132,7 @@ def test_image_prompt_orchestrator_provider_failure_uses_fallback(monkeypatch): raise RuntimeError("provider missing") monkeypatch.setattr( - "app.agents.image_prompt.orchestrator.settings.image_prompt_fallback_disabled", + "app.agents.image_prompt.orchestrator.memoir.image_prompt_fallback_disabled", False, ) monkeypatch.setattr("app.core.llm_gateway.LlmGateway", lambda: BoomGateway()) @@ -155,7 +156,7 @@ def test_image_prompt_orchestrator_provider_failure_raises_when_disabled( raise RuntimeError("provider missing") monkeypatch.setattr( - "app.agents.image_prompt.orchestrator.settings.image_prompt_fallback_disabled", + "app.agents.image_prompt.orchestrator.memoir.image_prompt_fallback_disabled", True, ) monkeypatch.setattr("app.core.llm_gateway.LlmGateway", lambda: BoomGateway()) diff --git a/api/tests/test_infra_regressions.py b/api/tests/test_infra_regressions.py index 13c6337..61b7a29 100644 --- a/api/tests/test_infra_regressions.py +++ b/api/tests/test_infra_regressions.py @@ -42,26 +42,19 @@ def test_chapter_pipeline_lock_delegates_to_token_lock( def test_post_commit_reuses_singleton_redis_client( monkeypatch: pytest.MonkeyPatch, ) -> None: - created: list[object] = [] + client = object() - class FakeRedis: - pass - - def fake_from_url(url: str, *, decode_responses: bool) -> FakeRedis: - client = FakeRedis() - created.append(client) + def fake_get_sync_redis(*, decode_responses: bool): assert decode_responses is True - assert url return client - monkeypatch.setattr(post_commit, "_redis_client", None) - monkeypatch.setattr(post_commit.redis, "from_url", fake_from_url) + monkeypatch.setattr(post_commit, "get_sync_redis", fake_get_sync_redis) first = post_commit._get_redis() second = post_commit._get_redis() assert first is second - assert created == [first] + assert first is client @pytest.mark.asyncio diff --git a/api/tests/test_interview_turn_plan.py b/api/tests/test_interview_turn_plan.py index 7e03d35..05fc942 100644 --- a/api/tests/test_interview_turn_plan.py +++ b/api/tests/test_interview_turn_plan.py @@ -296,3 +296,11 @@ def test_build_topic_chips_english_uses_slot_name_map_en(): place_chip = next(c for c in chips if c["id"] == "place") assert place_chip["label"] == "where you grew up" assert place_chip["text"].startswith("I'd like to talk about") + + +def test_build_topic_chips_belief_stage_not_empty(): + from app.agents.chat.prompts_conversation import build_topic_chips + + chips = build_topic_chips("belief", [], max_chips=4, language="en") + assert len(chips) == 4 + assert chips[0]["label"] diff --git a/api/tests/test_judge_service.py b/api/tests/test_judge_service.py index 65b1b77..9e090ab 100644 --- a/api/tests/test_judge_service.py +++ b/api/tests/test_judge_service.py @@ -4,6 +4,7 @@ import pytest from app.core.config import settings from app.core.llm_call import LLMCallError +from app.features.evaluation.constants import eval_cfg from app.features.evaluation.conversation_compare_summary import ( build_conversation_compare_summary, ) @@ -91,8 +92,8 @@ async def test_judge_conversation_wrapper_keeps_legacy_shape( def test_eval_judge_transcript_budget_exceeds_legacy_8192( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "eval_judge_max_transcript_chars", 0) - monkeypatch.setattr(settings, "eval_judge_context_window_tokens", 200_000) + monkeypatch.setattr(eval_cfg, "judge_max_transcript_chars", 0) + monkeypatch.setattr(eval_cfg, "judge_context_window_tokens", 200_000) n = eval_judge_conversation_transcript_max_chars() assert n > 90_000 each = eval_judge_compare_transcript_each_max_chars() @@ -127,7 +128,7 @@ def test_trim_compare_transcript_pair_prefers_trimming_longer_side() -> None: def test_eval_judge_transcript_budget_respects_explicit_cap( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "eval_judge_max_transcript_chars", 12_000) + monkeypatch.setattr(eval_cfg, "judge_max_transcript_chars", 12_000) assert eval_judge_conversation_transcript_max_chars() == 12_000 diff --git a/api/tests/test_main_app_smoke.py b/api/tests/test_main_app_smoke.py new file mode 100644 index 0000000..1b4f341 --- /dev/null +++ b/api/tests/test_main_app_smoke.py @@ -0,0 +1,30 @@ +"""Production FastAPI app smoke tests (handlers + middleware + routers).""" + +from __future__ import annotations + +import pytest +from httpx import ASGITransport, AsyncClient + +from app.main import app + + +@pytest.mark.asyncio +async def test_main_app_health() -> None: + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/health") + assert r.status_code == 200 + assert r.json() == {"status": "ok"} + + +@pytest.mark.asyncio +async def test_main_app_unauthenticated_conversations_401() -> None: + transport = ASGITransport(app=app, raise_app_exceptions=False) + async with AsyncClient(transport=transport, base_url="http://test") as client: + r = await client.get("/api/conversations") + assert r.status_code == 401 + body = r.json() + assert body["error_code"] == "AUTHENTICATION_FAILED" + assert "request_id" in body + assert r.headers.get("x-request-id") + assert r.headers.get("www-authenticate") == "Bearer" diff --git a/api/tests/test_memoir_phase1_ingest_idempotency.py b/api/tests/test_memoir_phase1_ingest_idempotency.py new file mode 100644 index 0000000..9eb8e5b --- /dev/null +++ b/api/tests/test_memoir_phase1_ingest_idempotency.py @@ -0,0 +1,250 @@ +"""Memoir queue scheduling and phase1 memory ingest idempotency.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.conversation.models import Conversation, Segment +from app.features.user.models import User +from app.features.conversation.ws import pipeline as ws_pipeline +from app.features.memory.ingest_service import MemoryIngestService +from app.tasks import memoir_tasks +from sqlalchemy.exc import IntegrityError + + +@pytest.mark.asyncio +async def test_process_user_message_queues_memoir_after_lineage() -> None: + db = MagicMock(spec=AsyncSession) + user = User(id="u1", language_preference="zh") + conversation = Conversation(id="conv-1", user_id="u1") + segment = Segment( + id="seg-1", + conversation_id="conv-1", + user_input_text="hello", + processed=False, + ) + call_order: list[str] = [] + + async def _record_turn(*_args, **_kwargs): + call_order.append("record_turn") + segment.lineage_json = {"primary_user_message_id": "hum-1"} + return SimpleNamespace( + human_message_id="hum-1", + assistant_message_id="ai-1", + ) + + async def _queue_segment(*_args, **_kwargs): + call_order.append("queue_segment") + + turn = SimpleNamespace( + messages=["reply"], + skip_tts=True, + memory_retrieval_trace=None, + ) + + with ( + patch.object( + ws_pipeline.chat_turn_service, + "process_turn", + AsyncMock(return_value=turn), + ), + patch.object( + ws_pipeline.ConversationHistoryStore, + "record_human_ai_turn_with_segment", + AsyncMock(side_effect=_record_turn), + ), + patch.object( + ws_pipeline, + "_schedule_memoir_ingest_for_segment", + AsyncMock(side_effect=_queue_segment), + ), + patch.object(ws_pipeline.manager, "active_connections", {}), + patch.object(ws_pipeline.manager, "send_message", AsyncMock()), + patch.object( + ws_pipeline, + "get_or_create_state", + AsyncMock(return_value=SimpleNamespace(current_stage="childhood")), + ), + patch.object(ws_pipeline, "maybe_send_topic_chips_ws", AsyncMock()), + ): + await ws_pipeline.process_user_message( + "conv-1", + "hello", + conversation, + segment, + db, + user=user, + user_message_timestamp=datetime.now(timezone.utc), + ) + + assert call_order == ["record_turn", "queue_segment"] + + +@pytest.mark.asyncio +async def test_process_user_message_queues_memoir_on_ai_failure() -> None: + db = MagicMock(spec=AsyncSession) + user = User(id="u1", language_preference="zh") + conversation = Conversation(id="conv-1", user_id="u1") + segment = Segment( + id="seg-1", + conversation_id="conv-1", + user_input_text="hello", + processed=False, + ) + queued = False + + async def _queue_segment(*_args, **_kwargs): + nonlocal queued + queued = True + + turn = SimpleNamespace(messages=[], skip_tts=True, memory_retrieval_trace=None) + + with ( + patch.object( + ws_pipeline.chat_turn_service, + "process_turn", + AsyncMock(return_value=turn), + ), + patch.object( + ws_pipeline.ConversationHistoryStore, + "record_human_ai_turn_with_segment", + AsyncMock(return_value=None), + ), + patch.object( + ws_pipeline, + "_schedule_memoir_ingest_for_segment", + AsyncMock(side_effect=_queue_segment), + ), + patch.object(ws_pipeline.manager, "active_connections", {"conv-1": MagicMock()}), + patch.object(ws_pipeline.manager, "send_message", AsyncMock()), + ): + await ws_pipeline.process_user_message( + "conv-1", + "hello", + conversation, + segment, + db, + user=user, + user_message_timestamp=datetime.now(timezone.utc), + ) + + assert queued is True + + +@pytest.mark.asyncio +async def test_ingest_batch_idempotent_by_segment_id(monkeypatch) -> None: + stored: dict[str, SimpleNamespace] = {} + create_calls = 0 + + class FakeSession: + async def commit(self) -> None: + pass + + async def flush(self) -> None: + pass + + async def fake_get(db, *, user_id: str, segment_id: str): + return stored.get(segment_id) + + async def fake_create_source(session, **kwargs): + nonlocal create_calls + create_calls += 1 + sid = kwargs["segment_id"] + src = SimpleNamespace(id=f"src-{create_calls}", segment_id=sid) + stored[sid] = src + return src + + async def fake_create_chunk(*_args, **kwargs): + return SimpleNamespace(id=f"ch-{kwargs.get('chunk_index')}") + + class FakeEmbeddingService: + def __init__(self, *_args, **_kwargs) -> None: + pass + + async def embed_source(self, user_id: str, source_id: str) -> dict: + return {"status": "success", "vectors_written": 1} + + monkeypatch.setattr( + "app.features.memory.ingest_service.get_transcript_source_by_segment_id", + fake_get, + ) + monkeypatch.setattr( + "app.features.memory.ingest_service.create_source", + fake_create_source, + ) + monkeypatch.setattr( + "app.features.memory.ingest_service.create_chunk", + fake_create_chunk, + ) + monkeypatch.setattr( + "app.features.memory.ingest_service.MemoryEmbeddingService", + FakeEmbeddingService, + ) + monkeypatch.setattr("app.features.memory.constants.memory.enrichment_enabled", False) + + service = MemoryIngestService( + FakeSession(), # type: ignore[arg-type] + embedding_provider=None, + enrichment_scheduler=MagicMock(schedule_many=MagicMock(return_value=[])), + ) + items = [ + ("c1", "hello", {"primary_user_message_id": "m1"}, "seg-1"), + ] + first = await service.ingest_transcripts_batch("u1", items) + second = await service.ingest_transcripts_batch("u1", items) + + assert first == ["src-1"] + assert second == ["src-1"] + assert create_calls == 1 + + +def test_phase1_memory_ingest_batch_sync_reraises_on_failure(monkeypatch) -> None: + async def _fail(*_args, **_kwargs): + raise RuntimeError("ingest unavailable") + + monkeypatch.setattr(memoir_tasks, "_memory_ingest_transcripts_batch", _fail) + db = MagicMock() + items = [("c1", "hello", None, "seg-1")] + + with pytest.raises(RuntimeError, match="ingest unavailable"): + memoir_tasks._phase1_memory_ingest_batch_sync( + db, + "u1", + items, + memoir_correlation_id="corr-1", + ) + + +def test_phase1_memory_ingest_batch_sync_resolves_integrity_race(monkeypatch) -> None: + async def _race(*_args, **_kwargs): + raise IntegrityError("insert", {}, Exception("unique")) + + monkeypatch.setattr(memoir_tasks, "_memory_ingest_transcripts_batch", _race) + db = MagicMock() + existing = SimpleNamespace(id="src-existing") + + def _lookup(_db, *, user_id: str, segment_id: str): + assert user_id == "u1" + assert segment_id == "seg-1" + return existing + + monkeypatch.setattr( + memoir_tasks, + "get_transcript_source_by_segment_id_sync", + _lookup, + ) + items = [("c1", "hello", None, "seg-1")] + + ids = memoir_tasks._phase1_memory_ingest_batch_sync( + db, + "u1", + items, + memoir_correlation_id="corr-1", + ) + + assert ids == ["src-existing"] diff --git a/api/tests/test_memoir_pipeline_optimization.py b/api/tests/test_memoir_pipeline_optimization.py index fa1f0dd..8d69b98 100644 --- a/api/tests/test_memoir_pipeline_optimization.py +++ b/api/tests/test_memoir_pipeline_optimization.py @@ -18,6 +18,7 @@ from app.agents.memoir.classification_agent import ChapterClassifyResult from app.agents.memoir.extraction_agent import ExtractionResult from app.agents.memoir.orchestrator import MemoirOrchestrator from app.agents.state_schema import MemoirStateSchema +from app.features.memoir.constants import memoir # --------------------------------------------------------------------------- # Phase1 batch path defaults @@ -26,18 +27,12 @@ from app.agents.state_schema import MemoirStateSchema def test_phase1_batch_enabled_by_default() -> None: """memoir_phase1_batch_llm_enabled should default to True after optimization.""" - from app.core.config import Settings - - s = Settings() - assert s.memoir_phase1_batch_llm_enabled is True - assert s.memoir_phase1_batch_llm_chunk_size >= 1 + assert memoir.phase1_batch_llm_enabled is True + assert memoir.phase1_batch_llm_chunk_size >= 1 def test_quality_pass_enabled_by_default() -> None: - from app.core.config import Settings - - s = Settings() - assert s.memoir_quality_pass_enabled is True + assert memoir.quality_pass_enabled is True # --------------------------------------------------------------------------- @@ -48,7 +43,7 @@ def test_quality_pass_enabled_by_default() -> None: def test_orchestrator_tries_batch_first(monkeypatch: pytest.MonkeyPatch) -> None: """When batch LLM is enabled and LLM is available, batch path should be attempted.""" monkeypatch.setattr( - "app.agents.memoir.orchestrator.settings.memoir_phase1_batch_llm_enabled", + "app.agents.memoir.orchestrator.memoir.phase1_batch_llm_enabled", True, ) @@ -95,7 +90,7 @@ def test_orchestrator_tries_batch_first(monkeypatch: pytest.MonkeyPatch) -> None def test_orchestrator_fallback_to_sequential(monkeypatch: pytest.MonkeyPatch) -> None: """If batch path raises, should fall back to sequential extraction.""" monkeypatch.setattr( - "app.agents.memoir.orchestrator.settings.memoir_phase1_batch_llm_enabled", + "app.agents.memoir.orchestrator.memoir.phase1_batch_llm_enabled", True, ) @@ -142,7 +137,7 @@ def test_orchestrator_sequential_filters_invalid_slots( ) -> None: """Sequential fallback should match batch path slot validation.""" monkeypatch.setattr( - "app.agents.memoir.orchestrator.settings.memoir_phase1_batch_llm_enabled", + "app.agents.memoir.orchestrator.memoir.phase1_batch_llm_enabled", False, ) diff --git a/api/tests/test_memoir_pipeline_progress.py b/api/tests/test_memoir_pipeline_progress.py index 81c55e8..26f6589 100644 --- a/api/tests/test_memoir_pipeline_progress.py +++ b/api/tests/test_memoir_pipeline_progress.py @@ -21,7 +21,7 @@ class _FakeRedis: @pytest.fixture def fake_redis(monkeypatch: pytest.MonkeyPatch) -> _FakeRedis: fr = _FakeRedis() - monkeypatch.setattr(mpp, "_client", fr) + monkeypatch.setattr(mpp, "_redis", lambda: fr) return fr diff --git a/api/tests/test_memoir_route_defer.py b/api/tests/test_memoir_route_defer.py index 6fa51d4..28aa782 100644 --- a/api/tests/test_memoir_route_defer.py +++ b/api/tests/test_memoir_route_defer.py @@ -30,6 +30,7 @@ from app.features.payment import models as _payment_models # noqa: F401 from app.features.story import models as _story_models # noqa: F401 from app.features.user import models as _user_models # noqa: F401 from app.features.user.models import User +from app.features.memoir.constants import memoir from app.tasks.memoir_tasks import ( _persist_phase2_route_defer, _wake_deferred_segments_for_category, @@ -185,7 +186,7 @@ def test_pipeline_does_not_defer_when_disabled( monkeypatch: pytest.MonkeyPatch, ) -> None: """关闭开关后,旧行为:直接写 new_story(不再延迟)。""" - monkeypatch.setattr(settings, "memoir_route_defer_enabled", False) + monkeypatch.setattr(memoir, "route_defer_enabled", False) seg = SimpleNamespace(id="seg-no-defer", user_input_text="一句简短的口述") decide_return = StoryRouteDecision( @@ -298,8 +299,8 @@ def test_persist_phase2_route_defer_marks_segment_and_schedules_next( monkeypatch: pytest.MonkeyPatch, ) -> None: """首次延迟:写入 defer 元数据并安排下一次 timeout(未达上限)。""" - monkeypatch.setattr(settings, "memoir_route_defer_seconds", 30.0) - monkeypatch.setattr(settings, "memoir_route_defer_max_attempts", 3) + monkeypatch.setattr(memoir, "route_defer_seconds", 30.0) + monkeypatch.setattr(memoir, "route_defer_max_attempts", 3) db = sqlite_session_factory() seg = _seed_user_segment( @@ -345,8 +346,8 @@ def test_persist_phase2_route_defer_stops_scheduling_at_max_attempts( monkeypatch: pytest.MonkeyPatch, ) -> None: """达到 max_attempts 后不再继续派发 timeout,segment 仍保留 defer 元数据。""" - monkeypatch.setattr(settings, "memoir_route_defer_seconds", 30.0) - monkeypatch.setattr(settings, "memoir_route_defer_max_attempts", 2) + monkeypatch.setattr(memoir, "route_defer_seconds", 30.0) + monkeypatch.setattr(memoir, "route_defer_max_attempts", 2) db = sqlite_session_factory() seg = _seed_user_segment( diff --git a/api/tests/test_memoir_skip_story.py b/api/tests/test_memoir_skip_story.py index 69c80eb..1e10b13 100644 --- a/api/tests/test_memoir_skip_story.py +++ b/api/tests/test_memoir_skip_story.py @@ -14,6 +14,7 @@ from app.agents.memoir.extraction_agent import ExtractionResult from app.agents.memoir.orchestrator import MemoirOrchestrator from app.agents.stage_constants import CHAT_STAGES from app.agents.state_schema import DEFAULT_STAGE_ORDER, MemoirStateSchema +from app.features.memoir.constants import memoir def _empty_state() -> MemoirStateSchema: @@ -123,7 +124,7 @@ def test_prepare_batches_batch_llm_path_matches_per_segment_skip_logic( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setattr( - "app.agents.memoir.orchestrator.settings.memoir_phase1_batch_llm_enabled", + "app.agents.memoir.orchestrator.memoir.phase1_batch_llm_enabled", True, ) diff --git a/api/tests/test_memoir_two_phase.py b/api/tests/test_memoir_two_phase.py index e1a6ef5..d956ca0 100644 --- a/api/tests/test_memoir_two_phase.py +++ b/api/tests/test_memoir_two_phase.py @@ -10,6 +10,7 @@ from app.agents.memoir.classification_agent import ChapterClassifyResult from app.agents.state_schema import MemoirStateSchema from app.tasks.memoir_tasks import _should_trigger_phase2 +from app.features.memoir.constants import memoir def test_segment_chapter_category_populated() -> None: @@ -69,15 +70,15 @@ def test_should_trigger_phase2_matrix( expect: bool, ) -> None: monkeypatch.setattr( - "app.tasks.memoir_tasks.settings.memoir_narrative_immediate_char_threshold", + "app.tasks.memoir_tasks.memoir.narrative_immediate_char_threshold", 50, ) monkeypatch.setattr( - "app.tasks.memoir_tasks.settings.memoir_narrative_batch_min_segments", + "app.tasks.memoir_tasks.memoir.narrative_batch_min_segments", 3, ) monkeypatch.setattr( - "app.tasks.memoir_tasks.settings.memoir_narrative_batch_min_chars", + "app.tasks.memoir_tasks.memoir.narrative_batch_min_chars", 80, ) db = MagicMock() diff --git a/api/tests/test_memory_compaction.py b/api/tests/test_memory_compaction.py index 1918413..b53263e 100644 --- a/api/tests/test_memory_compaction.py +++ b/api/tests/test_memory_compaction.py @@ -17,6 +17,7 @@ from app.features.memory.compaction_service import ( run_memory_compaction, text_layer_match, ) +from app.features.memory.constants import memory from app.tasks.memory_compaction_tasks import ( memory_compaction_run, memory_compaction_sweep, @@ -114,8 +115,8 @@ def test_schedule_merges_subsequent_triggers(monkeypatch) -> None: fake_redis = FakeRedis() calls: list[tuple[str, dict, int]] = [] - monkeypatch.setattr(settings, "memory_compaction_enabled", True) - monkeypatch.setattr(settings, "memory_compaction_debounce_seconds", 30) + monkeypatch.setattr(memory, "compaction_enabled", True) + monkeypatch.setattr(memory, "compaction_debounce_seconds", 30) monkeypatch.setattr(schedule, "_get_redis", lambda: fake_redis) monkeypatch.setattr(schedule.time, "time", lambda: 100.0) monkeypatch.setattr( @@ -140,7 +141,7 @@ def test_finalize_reschedules_when_deadline_extended(monkeypatch) -> None: fake_redis = FakeRedis() calls: list[tuple[str, dict, int]] = [] - monkeypatch.setattr(settings, "memory_compaction_debounce_seconds", 30) + monkeypatch.setattr(memory, "compaction_debounce_seconds", 30) monkeypatch.setattr(schedule, "_get_redis", lambda: fake_redis) monkeypatch.setattr(schedule.time, "time", lambda: 140.0) monkeypatch.setattr( @@ -169,7 +170,7 @@ def test_finalize_reschedules_when_deadline_extended(monkeypatch) -> None: def test_finalize_clears_stale_deadline_when_not_extended(monkeypatch) -> None: fake_redis = FakeRedis() - monkeypatch.setattr(settings, "memory_compaction_debounce_seconds", 30) + monkeypatch.setattr(memory, "compaction_debounce_seconds", 30) monkeypatch.setattr(schedule, "_get_redis", lambda: fake_redis) fake_redis.set(schedule.debounce_key("u1"), "130.0") @@ -353,7 +354,7 @@ def test_memory_compaction_run_releases_gate_and_retries_on_failure( events.append(f"retry:{type(exc).__name__}") raise RetryTriggered("retried") - monkeypatch.setattr(settings, "memory_compaction_enabled", True) + monkeypatch.setattr(memory, "compaction_enabled", True) monkeypatch.setattr( "app.tasks.memory_compaction_tasks.read_debounce_deadline_ts", lambda user_id: 100.0, @@ -392,7 +393,7 @@ def test_memory_compaction_run_releases_gate_and_retries_on_failure( def test_memory_compaction_sweep_skipped_when_disabled( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memory_compaction_enabled", False) + monkeypatch.setattr(memory, "compaction_enabled", False) out = memory_compaction_sweep() assert out == {"skipped": True, "reason": "disabled"} @@ -400,8 +401,8 @@ def test_memory_compaction_sweep_skipped_when_disabled( def test_memory_compaction_sweep_schedules_recent_users( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memory_compaction_enabled", True) - monkeypatch.setattr(settings, "memory_compaction_sweep_recent_hours", 24) + monkeypatch.setattr(memory, "compaction_enabled", True) + monkeypatch.setattr(memory, "compaction_sweep_recent_hours", 24) scheduled: list[tuple[str, dict]] = [] async def fake_list(hours: int): @@ -420,8 +421,53 @@ def test_memory_compaction_sweep_schedules_recent_users( out = memory_compaction_sweep() assert out["scheduled"] == 2 - assert set(out["user_ids"]) == {"user-a", "user-b"} + assert out["failed"] == 0 + assert out["hours"] == 24 assert {u for u, _ in scheduled} == {"user-a", "user-b"} for _, ctx in scheduled: assert ctx.get("trigger_source") == "beat" assert ctx.get("sweep_hours") == 24 + + +@pytest.mark.asyncio +async def test_run_memory_compaction_async_wraps_transactional(monkeypatch) -> None: + commit_calls: list[str] = [] + compact_calls: list[tuple[str, dict | None]] = [] + + class FakeSession: + async def commit(self) -> None: + commit_calls.append("commit") + + async def rollback(self) -> None: + commit_calls.append("rollback") + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + class FakeMemoryService: + def __init__(self, db) -> None: + self._db = db + + async def compact_user(self, user_id: str, context: dict | None): + compact_calls.append((user_id, context)) + return {"chunks_excluded": 1} + + monkeypatch.setattr( + "app.tasks.memory_compaction_tasks.AsyncSessionLocal", + lambda: FakeSession(), + ) + monkeypatch.setattr( + "app.tasks.memory_compaction_tasks.MemoryService", + FakeMemoryService, + ) + + from app.tasks.memory_compaction_tasks import _run_memory_compaction_async + + out = await _run_memory_compaction_async("u1", {"trigger_source": "test"}) + + assert out == {"chunks_excluded": 1} + assert compact_calls == [("u1", {"trigger_source": "test"})] + assert commit_calls == ["commit"] diff --git a/api/tests/test_memory_compaction_sweep.py b/api/tests/test_memory_compaction_sweep.py new file mode 100644 index 0000000..ac57d4b --- /dev/null +++ b/api/tests/test_memory_compaction_sweep.py @@ -0,0 +1,30 @@ +from unittest.mock import MagicMock + +import pytest + +from app.tasks import memory_compaction_tasks as tasks + + +def test_memory_compaction_sweep_continues_after_schedule_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(tasks.memory, "compaction_enabled", True) + monkeypatch.setattr(tasks.memory, "compaction_sweep_recent_hours", 6) + + async def _fake_list(hours: int) -> list[str]: + return ["u1", "u2", "u3"] + + monkeypatch.setattr( + tasks, "_list_users_with_recent_chunks_async", _fake_list + ) + + def fake_schedule(user_id: str, ctx: dict) -> None: + if user_id == "u2": + raise RuntimeError("boom") + + monkeypatch.setattr(tasks, "schedule_memory_compaction_run", fake_schedule) + monkeypatch.setattr(tasks, "business_span", lambda *a, **k: MagicMock()) + + result = tasks.memory_compaction_sweep.run() + + assert result == {"scheduled": 2, "failed": 1, "hours": 6} diff --git a/api/tests/test_memory_enrichment_baseline.py b/api/tests/test_memory_enrichment_baseline.py index 76b1460..eaadee4 100644 --- a/api/tests/test_memory_enrichment_baseline.py +++ b/api/tests/test_memory_enrichment_baseline.py @@ -10,6 +10,7 @@ from app.features.memory.enrichment import enrich_memory_after_ingest_async from app.features.memory.llm_schemas import EnrichmentPayload from app.features.memory.models import MemorySource from app.features.user.models import User +from app.features.memory.constants import memory def test_enrichment_payload_roundtrip() -> None: @@ -31,7 +32,7 @@ async def test_enrich_memory_after_ingest_async_single_llm_call( ) -> None: from app.features.memory import enrichment as mod - monkeypatch.setattr("app.core.config.settings.memory_enrichment_enabled", True) + monkeypatch.setattr("app.features.memory.constants.memory.enrichment_enabled", True) invoke_count = {"n": 0} @@ -113,7 +114,7 @@ async def test_enrich_memory_skips_when_parse_returns_none( ) -> None: from app.features.memory import enrichment as mod - monkeypatch.setattr("app.core.config.settings.memory_enrichment_enabled", True) + monkeypatch.setattr("app.features.memory.constants.memory.enrichment_enabled", True) async def fake_run(*_args, **_kwargs): return None diff --git a/api/tests/test_mock_sms_login_http.py b/api/tests/test_mock_sms_login_http.py index b9917a7..18634e5 100644 --- a/api/tests/test_mock_sms_login_http.py +++ b/api/tests/test_mock_sms_login_http.py @@ -14,11 +14,12 @@ from app.core.security import create_access_token, verify_token from app.features.auth.deps import get_auth_service from app.features.auth.router import router as auth_router from app.features.auth.service import AuthService +from tests.conftest import install_test_error_handlers @pytest.fixture def mock_sms_app() -> FastAPI: - app = FastAPI() + app = install_test_error_handlers(FastAPI()) app.include_router(auth_router) mock_service = MagicMock(spec=AuthService) uid = str(uuid.uuid4()) diff --git a/api/tests/test_openapi_error_response.py b/api/tests/test_openapi_error_response.py new file mode 100644 index 0000000..b84a39e --- /dev/null +++ b/api/tests/test_openapi_error_response.py @@ -0,0 +1,112 @@ +"""OpenAPI 统一 ErrorResponse 组件与 router 引用。""" + +from __future__ import annotations + +from fastapi import FastAPI + +from app.core.error_codes import AUTH_ERROR_CODES, ERROR_CODE_ENUM +from app.core.openapi import COMMON_ERROR_RESPONSES, custom_openapi, error_responses +from app.features.auth.router import router as auth_router +from app.features.memoir.router import router as memoir_router +from app.features.payment.router import router as payment_router + + +def _openapi_app(*routers) -> FastAPI: + app = FastAPI() + for router in routers: + app.include_router(router) + app.openapi = lambda: custom_openapi(app) # type: ignore[assignment] + return app + + +def test_openapi_includes_error_response_schema() -> None: + app = _openapi_app(payment_router) + app.openapi_schema = None + schema = custom_openapi(app) + + error_schema = schema["components"]["schemas"]["ErrorResponse"] + assert set(error_schema["required"]) == {"error_code", "message", "request_id"} + error_code_prop = error_schema["properties"]["error_code"] + assert error_code_prop["allOf"][0]["$ref"].endswith("ErrorCode") + + error_code_enum = schema["components"]["schemas"]["ErrorCode"]["enum"] + assert "NOT_FOUND" in error_code_enum + assert "PHONE_EXISTS" in error_code_enum + assert set(error_code_enum) == set(ERROR_CODE_ENUM) + + +def test_openapi_description_includes_error_code_table() -> None: + app = _openapi_app(auth_router) + app.openapi_schema = None + schema = custom_openapi(app) + description = schema["info"]["description"] + assert "PHONE_EXISTS" in description + assert "INVALID_SMS_CODE" in description + assert AUTH_ERROR_CODES[0]["code"] in description + + +def test_openapi_domain_error_code_schema() -> None: + app = _openapi_app(auth_router) + app.openapi_schema = None + schema = custom_openapi(app) + domain_enum = schema["components"]["schemas"]["DomainErrorCode"]["enum"] + assert "PHONE_EXISTS" in domain_enum + assert "NOT_FOUND" not in domain_enum + + +def test_error_responses_reference_error_response_component() -> None: + responses = error_responses(401, 404) + assert responses[401]["content"]["application/json"]["schema"]["$ref"].endswith( + "ErrorResponse" + ) + assert 404 in responses + assert responses[404] == COMMON_ERROR_RESPONSES[404] + + +def test_memoir_router_openapi_uses_error_response_ref() -> None: + app = _openapi_app(memoir_router) + app.openapi_schema = None + schema = custom_openapi(app) + get_chapters = schema["paths"]["/api/chapters"]["get"] + assert "401" in get_chapters["responses"] + ref = get_chapters["responses"]["401"]["content"]["application/json"]["schema"]["$ref"] + assert ref.endswith("ErrorResponse") + + +def test_auth_register_openapi_uses_error_response_ref() -> None: + app = _openapi_app(auth_router) + app.openapi_schema = None + schema = custom_openapi(app) + register = schema["paths"]["/api/auth/register"]["post"] + assert "400" in register["responses"] + ref = register["responses"]["400"]["content"]["application/json"]["schema"]["$ref"] + assert ref.endswith("ErrorResponse") + + +def test_error_responses_includes_502() -> None: + responses = error_responses(502) + assert responses[502] == COMMON_ERROR_RESPONSES[502] + assert responses[502]["content"]["application/json"]["schema"]["$ref"].endswith( + "ErrorResponse" + ) + + +def test_payment_router_openapi_includes_502() -> None: + app = _openapi_app(payment_router) + app.openapi_schema = None + schema = custom_openapi(app) + create_order = schema["paths"]["/api/payment/create-order"]["post"] + assert "502" in create_order["responses"] + ref = create_order["responses"]["502"]["content"]["application/json"]["schema"]["$ref"] + assert ref.endswith("ErrorResponse") + + +def test_payment_router_openapi_includes_504_and_500() -> None: + app = _openapi_app(payment_router) + app.openapi_schema = None + schema = custom_openapi(app) + create_order = schema["paths"]["/api/payment/create-order"]["post"] + for status in ("500", "504"): + assert status in create_order["responses"] + ref = create_order["responses"][status]["content"]["application/json"]["schema"]["$ref"] + assert ref.endswith("ErrorResponse") diff --git a/api/tests/test_oral_normalize.py b/api/tests/test_oral_normalize.py index 8ddcbf1..ab8993e 100644 --- a/api/tests/test_oral_normalize.py +++ b/api/tests/test_oral_normalize.py @@ -22,27 +22,27 @@ def test_apply_rules_no_false_positive_rong() -> None: def test_normalize_respects_global_off() -> None: raw = "美看上我" - with patch("app.features.memoir.oral_normalize.settings") as m: - m.memoir_oral_normalize_enabled = False - m.memoir_oral_normalize_mode = "rules" + with patch("app.features.memoir.oral_normalize.memoir") as m: + m.oral_normalize_enabled = False + m.oral_normalize_mode = "rules" assert normalize_oral_for_memoir(raw, llm=None) == raw def test_normalize_rules_mode_no_llm() -> None: raw = "美看上我" - with patch("app.features.memoir.oral_normalize.settings") as m: - m.memoir_oral_normalize_enabled = True - m.memoir_oral_normalize_mode = "rules" - m.memoir_oral_normalize_llm_max_tokens = 512 - m.memoir_oral_normalize_llm_max_input_chars = 8000 + with patch("app.features.memoir.oral_normalize.memoir") as m: + m.oral_normalize_enabled = True + m.oral_normalize_mode = "rules" + m.oral_normalize_llm_max_tokens = 512 + m.oral_normalize_llm_max_input_chars = 8000 assert normalize_oral_for_memoir(raw, llm=None) == "没看上我" def test_normalize_mode_off_string() -> None: raw = "美看上我" - with patch("app.features.memoir.oral_normalize.settings") as m: - m.memoir_oral_normalize_enabled = True - m.memoir_oral_normalize_mode = "off" - m.memoir_oral_normalize_llm_max_tokens = 512 - m.memoir_oral_normalize_llm_max_input_chars = 8000 + with patch("app.features.memoir.oral_normalize.memoir") as m: + m.oral_normalize_enabled = True + m.oral_normalize_mode = "off" + m.oral_normalize_llm_max_tokens = 512 + m.oral_normalize_llm_max_input_chars = 8000 assert normalize_oral_for_memoir(raw, llm=None) == raw diff --git a/api/tests/test_pipeline_tts_cancel_emits_all_segments.py b/api/tests/test_pipeline_tts_cancel_emits_all_segments.py index ca480e7..5ad1ac3 100644 --- a/api/tests/test_pipeline_tts_cancel_emits_all_segments.py +++ b/api/tests/test_pipeline_tts_cancel_emits_all_segments.py @@ -73,8 +73,8 @@ def _patch_common(monkeypatch: pytest.MonkeyPatch) -> tuple[list[dict], MagicMoc monkeypatch.setattr(ws_pipeline.manager, "active_connections", {}) fake_store = MagicMock() - fake_store.record_human_ai_turn = AsyncMock() - fake_store.attach_ai_tts_audio_urls = AsyncMock(return_value=None) + fake_store.record_human_ai_turn_with_segment = AsyncMock() + fake_store.attach_ai_tts_for_turn = AsyncMock(return_value=None) monkeypatch.setattr( ws_pipeline, "ConversationHistoryStore", lambda _db: fake_store ) @@ -96,7 +96,7 @@ async def test_tts_cancel_mid_flight_still_emits_all_agent_response_segments( sent_messages, fake_store = _patch_common(monkeypatch) conversation_id = "conv-cancel-mid" ws_pipeline.manager.active_connections[conversation_id] = object() - fake_store.record_human_ai_turn.return_value = _ids_for(conversation_id) + fake_store.record_human_ai_turn_with_segment.return_value = _ids_for(conversation_id) turn_result = ChatTurnResult( messages=["第一段", "第二段", "第三段"], @@ -166,7 +166,7 @@ async def test_tts_cancel_before_any_segment_still_emits_agent_response( sent_messages, fake_store = _patch_common(monkeypatch) conversation_id = "conv-cancel-pre" ws_pipeline.manager.active_connections[conversation_id] = object() - fake_store.record_human_ai_turn.return_value = _ids_for(conversation_id) + fake_store.record_human_ai_turn_with_segment.return_value = _ids_for(conversation_id) turn_result = ChatTurnResult( messages=["唯一段"], @@ -225,7 +225,7 @@ async def test_empty_responses_emits_terminal_error( sent_messages, fake_store = _patch_common(monkeypatch) conversation_id = "conv-empty" ws_pipeline.manager.active_connections[conversation_id] = object() - fake_store.record_human_ai_turn.return_value = None + fake_store.record_human_ai_turn_with_segment.return_value = None monkeypatch.setattr( ws_pipeline.chat_turn_service, @@ -291,7 +291,7 @@ async def test_tts_disabled_emits_all_segments_without_tts_calls( sent_messages, fake_store = _patch_common(monkeypatch) conversation_id = "conv-text-only" ws_pipeline.manager.active_connections[conversation_id] = object() - fake_store.record_human_ai_turn.return_value = _ids_for(conversation_id) + fake_store.record_human_ai_turn_with_segment.return_value = _ids_for(conversation_id) monkeypatch.setattr( ws_pipeline.chat_turn_service, diff --git a/api/tests/test_recompose_retry_policy.py b/api/tests/test_recompose_retry_policy.py index a99a7d9..dd7b9e0 100644 --- a/api/tests/test_recompose_retry_policy.py +++ b/api/tests/test_recompose_retry_policy.py @@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch import pytest from celery.exceptions import Retry -from app.core.config import settings +from app.features.memoir.constants import memoir from app.tasks.chapter_compose_tasks import recompose_chapter @@ -29,7 +29,7 @@ def test_recompose_retries_when_lock_busy_and_flag_on( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setattr( - settings, "memoir_recompose_retry_on_lock_contention", True, raising=False + memoir, "recompose_retry_on_lock_contention", True, raising=False ) session = MagicMock() ch = MagicMock() @@ -62,7 +62,7 @@ def test_recompose_skips_when_lock_busy_and_flag_off( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setattr( - settings, "memoir_recompose_retry_on_lock_contention", False, raising=False + memoir, "recompose_retry_on_lock_contention", False, raising=False ) session = MagicMock() ch = MagicMock() diff --git a/api/tests/test_redis_sync_client.py b/api/tests/test_redis_sync_client.py new file mode 100644 index 0000000..11021c3 --- /dev/null +++ b/api/tests/test_redis_sync_client.py @@ -0,0 +1,27 @@ +import pytest + +from app.core import redis_sync + + +def test_sync_redis_reuses_singleton(monkeypatch: pytest.MonkeyPatch) -> None: + created: list[object] = [] + + class FakeRedis: + def close(self) -> None: + pass + + def fake_from_url(*args, **kwargs): + client = FakeRedis() + created.append(client) + return client + + redis_sync.reset_sync_redis_clients_for_tests() + monkeypatch.setattr(redis_sync.redis, "from_url", fake_from_url) + + first = redis_sync.get_sync_redis(decode_responses=True) + second = redis_sync.get_sync_redis(decode_responses=True) + + assert first is second + assert len(created) == 1 + + redis_sync.reset_sync_redis_clients_for_tests() diff --git a/api/tests/test_redis_urls.py b/api/tests/test_redis_urls.py new file mode 100644 index 0000000..9e20493 --- /dev/null +++ b/api/tests/test_redis_urls.py @@ -0,0 +1,46 @@ +from app.core.redis_urls import ( + derive_celery_redis_url, + inject_redis_password, + resolve_redis_urls, +) + + +def test_inject_redis_password_when_missing() -> None: + url = inject_redis_password("redis://localhost:6379/0", "secret") + assert url == "redis://:secret@localhost:6379/0" + + +def test_inject_redis_password_skips_when_present() -> None: + url = inject_redis_password("redis://:existing@localhost:6379/0", "secret") + assert url == "redis://:existing@localhost:6379/0" + + +def test_derive_celery_redis_url_increments_db() -> None: + url = derive_celery_redis_url("redis://localhost:6379/0") + assert url == "redis://localhost:6379/1" + + +def test_resolve_redis_urls_applies_password_to_both() -> None: + business, celery = resolve_redis_urls( + "redis://localhost:6379/0", + redis_password="secret", + ) + assert business == "redis://:secret@localhost:6379/0" + assert celery == "redis://:secret@localhost:6379/1" + + +def test_celery_redis_url_override() -> None: + business, celery = resolve_redis_urls( + "redis://localhost:6379/0", + redis_password="secret", + celery_redis_url_override="redis://broker:6380/2", + ) + assert business == "redis://:secret@localhost:6379/0" + assert celery == "redis://broker:6380/2" + + +def test_derive_celery_redis_url_rejects_db_15() -> None: + import pytest + + with pytest.raises(ValueError, match="CELERY_REDIS_URL"): + derive_celery_redis_url("redis://localhost:6379/15") diff --git a/api/tests/test_settings_allowlist.py b/api/tests/test_settings_allowlist.py new file mode 100644 index 0000000..e72ae8f --- /dev/null +++ b/api/tests/test_settings_allowlist.py @@ -0,0 +1,55 @@ +"""Settings 字段数量守卫:防止 env 反弹为巨型配置。""" + +from app.core.config import Settings + +ALLOWLIST_MAX_FIELDS = 22 + +EXPECTED_PREFIXES = ( + "database_", + "redis_", + "celery_", + "app_", + "secret_", + "deepseek_", + "zhipu_", + "tencent_secret_", + "wechat_pay_", + "alipay_", + "liblib_", + "internal_eval_", +) + + +def test_settings_field_count_within_allowlist() -> None: + assert len(Settings.model_fields) <= ALLOWLIST_MAX_FIELDS + + +def test_settings_has_only_secrets_and_bootstrap_fields() -> None: + for name in Settings.model_fields: + assert name.startswith(EXPECTED_PREFIXES), ( + f"unexpected Settings field {name!r}; " + "non-secret deploy/product config belongs in config/*.toml" + ) + + +def test_settings_has_no_product_tuning_field_names() -> None: + blocked = ( + "chat_", + "memory_", + "eval_", + "story_", + "memoir_", + "agent_log_", + "enable_", + "otel_", + "log_", + "access_", + "refresh_", + "mock_", + "tencent_sms_", + "tencent_cos_", + "api_cors_", + "alembic_", + ) + for name in Settings.model_fields: + assert not any(name.startswith(p) for p in blocked) diff --git a/api/tests/test_sms_login_new_user_persists_language.py b/api/tests/test_sms_login_new_user_persists_language.py index b6da060..4a07f74 100644 --- a/api/tests/test_sms_login_new_user_persists_language.py +++ b/api/tests/test_sms_login_new_user_persists_language.py @@ -26,6 +26,8 @@ from app.features.user.models import User def _make_service() -> AuthService: db = MagicMock() db.commit = AsyncMock(return_value=None) + db.rollback = AsyncMock(return_value=None) + db.flush = AsyncMock(return_value=None) db.refresh = AsyncMock(return_value=None) sms = MagicMock() return AuthService(db=db, sms=sms) diff --git a/api/tests/test_state_service_batch_stage_policy.py b/api/tests/test_state_service_batch_stage_policy.py index eb6f4ef..dd10921 100644 --- a/api/tests/test_state_service_batch_stage_policy.py +++ b/api/tests/test_state_service_batch_stage_policy.py @@ -26,6 +26,7 @@ from app.features.memoir.state_service import ( update_slot_sync, ) from app.features.user.models import User +from app.features.memoir.constants import memoir @pytest.fixture @@ -76,7 +77,7 @@ def _add_user_and_state( def test_apply_current_stage_policy_live_path_always_writes( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", False) + monkeypatch.setattr(memoir, "extraction_updates_current_stage", False) state = SimpleNamespace(current_stage="childhood") _apply_current_stage_policy(state, "career", memoir_batch=False) assert state.current_stage == "career" @@ -85,7 +86,7 @@ def test_apply_current_stage_policy_live_path_always_writes( def test_apply_current_stage_policy_batch_flag_off_short_circuit( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", False) + monkeypatch.setattr(memoir, "extraction_updates_current_stage", False) state = SimpleNamespace(current_stage="childhood") _apply_current_stage_policy(state, "career", memoir_batch=True) assert state.current_stage == "childhood" @@ -94,7 +95,7 @@ def test_apply_current_stage_policy_batch_flag_off_short_circuit( def test_apply_current_stage_policy_batch_same_bucket_updates( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + monkeypatch.setattr(memoir, "extraction_updates_current_stage", True) state = SimpleNamespace(current_stage="career") _apply_current_stage_policy(state, "career", memoir_batch=True) assert state.current_stage == "career" @@ -103,7 +104,7 @@ def test_apply_current_stage_policy_batch_same_bucket_updates( def test_apply_current_stage_policy_batch_same_bucket_repairs_chapter_key( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + monkeypatch.setattr(memoir, "extraction_updates_current_stage", True) state = SimpleNamespace(current_stage="career_early") _apply_current_stage_policy(state, "career", memoir_batch=True) assert state.current_stage == "career" @@ -112,7 +113,7 @@ def test_apply_current_stage_policy_batch_same_bucket_repairs_chapter_key( def test_apply_current_stage_policy_batch_cross_bucket_blocked( monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + monkeypatch.setattr(memoir, "extraction_updates_current_stage", True) state = SimpleNamespace(current_stage="childhood") _apply_current_stage_policy(state, "career", memoir_batch=True) assert state.current_stage == "childhood" @@ -122,7 +123,7 @@ def test_update_slot_sync_batch_respects_flag_false( sqlite_session_factory, monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", False) + monkeypatch.setattr(memoir, "extraction_updates_current_stage", False) uid = "u-batch-off" db = sqlite_session_factory() _add_user_and_state(db, user_id=uid, current_stage="childhood") @@ -147,7 +148,7 @@ def test_update_slot_sync_batch_flag_true_same_bucket_updates_row( sqlite_session_factory, monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + monkeypatch.setattr(memoir, "extraction_updates_current_stage", True) uid = "u-batch-on" db = sqlite_session_factory() _add_user_and_state(db, user_id=uid, current_stage="career") @@ -172,7 +173,7 @@ def test_update_slot_sync_ignores_invalid_slot_name( sqlite_session_factory, monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + monkeypatch.setattr(memoir, "extraction_updates_current_stage", True) uid = "u-invalid-slot" db = sqlite_session_factory() _add_user_and_state(db, user_id=uid, current_stage="childhood") @@ -197,7 +198,7 @@ def test_update_slot_sync_batch_flag_true_cross_bucket_unchanged( sqlite_session_factory, monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.setattr(settings, "memoir_extraction_updates_current_stage", True) + monkeypatch.setattr(memoir, "extraction_updates_current_stage", True) uid = "u-cross" db = sqlite_session_factory() _add_user_and_state(db, user_id=uid, current_stage="childhood") diff --git a/api/tests/test_story_route_payload.py b/api/tests/test_story_route_payload.py index a6f5c1b..55c8cc1 100644 --- a/api/tests/test_story_route_payload.py +++ b/api/tests/test_story_route_payload.py @@ -13,6 +13,7 @@ from app.agents.memoir.story_route_payload import ( _truncate_body_for_route, ) from app.core.config import Settings +from app.features.story.constants import story def _story(**kwargs): @@ -102,8 +103,8 @@ def test_long_body_uses_head_tail(): def test_total_budget_downgrades_tail_rows(monkeypatch): settings = Settings() - monkeypatch.setattr(settings, "story_route_candidate_total_max_chars", 800) - monkeypatch.setattr(settings, "story_route_index_preview_chars", 40) + monkeypatch.setattr(story, "route_candidate_total_max_chars", 800) + monkeypatch.setattr(story, "route_index_preview_chars", 40) stories = [ _story( id="1", diff --git a/api/tests/test_task_tracker_ttl.py b/api/tests/test_task_tracker_ttl.py new file mode 100644 index 0000000..0ef5d27 --- /dev/null +++ b/api/tests/test_task_tracker_ttl.py @@ -0,0 +1,51 @@ +import json +from unittest.mock import AsyncMock + +import pytest + +from app.core.task_tracker import TaskTracker + + +@pytest.mark.asyncio +async def test_update_task_status_refreshes_ttl(monkeypatch: pytest.MonkeyPatch) -> None: + client = AsyncMock() + client.hget.return_value = json.dumps({"task_id": "t1", "status": "pending"}) + client.exists.return_value = 1 + + async def fake_get_client(): + return client + + monkeypatch.setattr( + "app.core.task_tracker.redis_service.get_client", + fake_get_client, + ) + + tracker = TaskTracker() + ok = await tracker.update_task_status("user-1", "t1", "running") + + assert ok is True + client.hset.assert_awaited_once() + client.expire.assert_awaited_once_with("task:user:user-1:tasks", tracker.task_ttl) + + +@pytest.mark.asyncio +async def test_remove_task_refreshes_ttl_when_hash_remains( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = AsyncMock() + client.exists.return_value = 1 + + async def fake_get_client(): + return client + + monkeypatch.setattr( + "app.core.task_tracker.redis_service.get_client", + fake_get_client, + ) + + tracker = TaskTracker() + ok = await tracker.remove_task("user-1", "t1") + + assert ok is True + client.hdel.assert_awaited_once() + client.expire.assert_awaited_once_with("task:user:user-1:tasks", tracker.task_ttl) diff --git a/api/tests/test_ws_pipeline_transactional.py b/api/tests/test_ws_pipeline_transactional.py new file mode 100644 index 0000000..d7c3619 --- /dev/null +++ b/api/tests/test_ws_pipeline_transactional.py @@ -0,0 +1,237 @@ +"""WS pipeline / history_store 原子持久化边界;memoir 调度顺序。""" + +from __future__ import annotations + +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch +import uuid + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from app.features.auth import models as _auth_models # noqa: F401 +from app.features.conversation import models as _conv_models # noqa: F401 +from app.features.memory import models as _memory_models # noqa: F401 +from app.features.memoir import models as _memoir_models # noqa: F401 +from app.features.payment import models as _payment_models # noqa: F401 +from app.features.story import models as _story_models # noqa: F401 +from app.features.user import models as _user_models # noqa: F401 +from app.features.conversation.history_store import ConversationHistoryStore +from app.features.conversation.models import ConversationMessage, Segment +from app.features.conversation.ws import persist + + +@asynccontextmanager +async def _capture_transactional(db): + yield db + await db.commit() + + +@pytest.mark.asyncio +async def test_persist_message_tts_url_segment_commits_once() -> None: + db = MagicMock(spec=AsyncSession) + db.commit = AsyncMock() + msg = ConversationMessage( + id="msg-1", + conversation_id="conv-1", + role="ai", + content="hi", + tts_audio_urls=[], + ) + + with patch("app.features.conversation.ws.persist.transactional", _capture_transactional): + await persist.persist_message_tts_url_segment(db, msg, 0, "https://cos/0.mp3") + + assert msg.tts_audio_urls == ["https://cos/0.mp3"] + db.commit.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_persist_voice_segment_row_commits_segment_and_activity() -> None: + from app.features.conversation.models import Conversation + + db = MagicMock(spec=AsyncSession) + db.commit = AsyncMock() + db.add = MagicMock() + conv = Conversation(id="conv-1", user_id="u1") + segment = Segment( + id="seg-1", + conversation_id="conv-1", + user_input_text="hello", + processed=False, + ) + + with patch("app.features.conversation.ws.persist.transactional", _capture_transactional): + await persist.persist_voice_segment_row(db, segment, conv) + + db.add.assert_called_once_with(segment) + assert conv.last_message_at is not None + db.commit.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_record_human_ai_turn_with_segment_single_commit() -> None: + db = MagicMock(spec=AsyncSession) + db.commit = AsyncMock() + db.rollback = AsyncMock() + db.flush = AsyncMock() + commit_count = 0 + + @asynccontextmanager + async def counting_transactional(session): + nonlocal commit_count + yield session + await session.commit() + commit_count += 1 + + store = ConversationHistoryStore(db) + segment = Segment( + id="seg-1", + conversation_id="conv-1", + user_input_text="hi", + processed=False, + ) + + with ( + patch( + "app.features.conversation.history_store.transactional", + counting_transactional, + ), + patch.object(store, "_touch_conversation", AsyncMock()), + patch.object(store, "_sync_redis_best_effort", AsyncMock()), + patch("app.features.conversation.history_store.repo.add_conversation_message"), + ): + turn_ids = await store.record_human_ai_turn_with_segment( + "conv-1", + "hello", + ["reply"], + segment, + user_message_timestamp=datetime.now(timezone.utc), + is_from_voice=True, + voice_session_id="vs-1", + audio_duration_seconds=3, + agent_response="reply", + ) + + assert turn_ids is not None + assert commit_count == 1 + db.flush.assert_awaited_once() + assert segment.agent_response == "reply" + assert segment.user_message_id == turn_ids.human_message_id + assert segment.lineage_json is not None + assert segment.lineage_json["primary_user_message_id"] == turn_ids.human_message_id + assert segment.lineage_json["turns"][0]["user_message_id"] == turn_ids.human_message_id + + +@pytest.mark.asyncio +async def test_record_human_ai_turn_with_segment_postgres_flush_order() -> None: + """Regression: Postgres FK on segments.user_message_id requires message INSERT first.""" + from app.core.config import settings + + if not settings.database_url.startswith("postgresql"): + pytest.skip("requires PostgreSQL") + + from app.core.db import AsyncSessionLocal, transactional + from app.features.conversation.models import Conversation + from app.features.user.models import User + + uid = str(uuid.uuid4()) + cid = str(uuid.uuid4()) + sid = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + + async with AsyncSessionLocal() as db: + db.add( + User( + id=uid, + phone=f"138{uuid.uuid4().int % 100_000_000:08d}", + password_hash="x", + nickname="t", + subscription_type="free", + created_at=now, + ) + ) + conv = Conversation(id=cid, user_id=uid, last_message_at=now) + db.add(conv) + segment = Segment( + id=sid, + conversation_id=cid, + user_input_text="hi", + processed=False, + created_at=now, + ) + async with transactional(db): + db.add(segment) + await db.refresh(segment) + + async with AsyncSessionLocal() as db: + segment = await db.get(Segment, sid) + store = ConversationHistoryStore(db) + turn_ids = await store.record_human_ai_turn_with_segment( + cid, + "hello", + ["reply"], + segment, + user_message_timestamp=now, + is_from_voice=False, + voice_session_id=None, + audio_duration_seconds=None, + agent_response="reply", + ) + + assert turn_ids is not None + assert segment.user_message_id == turn_ids.human_message_id + + +@pytest.mark.asyncio +async def test_attach_ai_tts_for_turn_single_commit() -> None: + db = MagicMock(spec=AsyncSession) + commit_count = 0 + + @asynccontextmanager + async def counting_transactional(session): + nonlocal commit_count + yield session + await session.commit() + commit_count += 1 + + store = ConversationHistoryStore(db) + segment = Segment( + id="seg-1", + conversation_id="conv-1", + user_input_text="hi", + processed=False, + ) + ai_row = ConversationMessage( + id="ai-1", + conversation_id="conv-1", + role="ai", + content="reply", + ) + + async def _set_tts(*_args, **_kwargs): + ai_row.tts_audio_urls = ["https://cos/a.mp3"] + return ai_row + + with ( + patch( + "app.features.conversation.history_store.transactional", + counting_transactional, + ), + patch.object(store, "_sync_redis_best_effort", AsyncMock()), + patch( + "app.features.conversation.history_store.repo.set_latest_ai_message_tts_audio_urls", + _set_tts, + ), + ): + await store.attach_ai_tts_for_turn( + "conv-1", + tts_audio_urls=["https://cos/a.mp3"], + segment=segment, + ) + + assert commit_count == 1 + assert segment.tts_audio_urls == ["https://cos/a.mp3"] + assert ai_row.tts_audio_urls == ["https://cos/a.mp3"] diff --git a/api/uv.lock b/api/uv.lock index 1645a62..f50ecd1 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -82,6 +82,7 @@ dependencies = [ { name = "cos-python-sdk-v5" }, { name = "fastapi", extra = ["standard"] }, { name = "faster-whisper" }, + { name = "flower" }, { name = "greenlet" }, { name = "httpx" }, { name = "langchain" }, @@ -132,6 +133,7 @@ requires-dist = [ { name = "cos-python-sdk-v5", specifier = ">=1.9.41" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.135.1" }, { name = "faster-whisper", specifier = ">=1.2.1" }, + { name = "flower", specifier = ">=2.0.1" }, { name = "greenlet", specifier = ">=3.3.2" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "langchain", specifier = ">=1.2.12" }, @@ -898,6 +900,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl", hash = "sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4", size = 26661, upload-time = "2025-12-19T23:16:13.622Z" }, ] +[[package]] +name = "flower" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "celery" }, + { name = "humanize" }, + { name = "prometheus-client" }, + { name = "pytz" }, + { name = "tornado" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/09/a1/357f1b5d8946deafdcfdd604f51baae9de10aafa2908d0b7322597155f92/flower-2.0.1.tar.gz", hash = "sha256:5ab717b979530770c16afb48b50d2a98d23c3e9fe39851dcf6bc4d01845a02a0", size = 3220408, upload-time = "2023-08-13T14:37:46.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/ff/ee2f67c0ff146ec98b5df1df637b2bc2d17beeb05df9f427a67bd7a7d79c/flower-2.0.1-py2.py3-none-any.whl", hash = "sha256:9db2c621eeefbc844c8dd88be64aef61e84e2deb29b271e02ab2b5b9f01068e2", size = 383553, upload-time = "2023-08-13T14:37:41.552Z" }, +] + [[package]] name = "fonttools" version = "4.62.0" @@ -1135,6 +1153,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/92/e3/e3a44f54c8e2f28983fcf07f13d4260b37bd6a0d3a081041bc60b91d230e/huggingface_hub-1.6.0-py3-none-any.whl", hash = "sha256:ef40e2d5cb85e48b2c067020fa5142168342d5108a1b267478ed384ecbf18961", size = 612874, upload-time = "2026-03-06T14:19:16.844Z" }, ] +[[package]] +name = "humanize" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/66/a3921783d54be8a6870ac4ccffcd15c4dc0dd7fcce51c6d63b8c63935276/humanize-4.15.0.tar.gz", hash = "sha256:1dd098483eb1c7ee8e32eb2e99ad1910baefa4b75c3aff3a82f4d78688993b10", size = 83599, upload-time = "2025-12-20T20:16:13.19Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/7b/bca5613a0c3b542420cf92bd5e5fb8ebd5435ce1011a091f66bb7693285e/humanize-4.15.0-py3-none-any.whl", hash = "sha256:b1186eb9f5a9749cd9cb8565aee77919dd7c8d076161cf44d70e59e3301e1769", size = 132203, upload-time = "2025-12-20T20:16:11.67Z" }, +] + [[package]] name = "idna" version = "3.11" @@ -1961,6 +1988,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "prometheus-client" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/fb/d9aa83ffe43ce1f19e557c0971d04b90561b0cfd50762aafb01968285553/prometheus_client-0.25.0.tar.gz", hash = "sha256:5e373b75c31afb3c86f1a52fa1ad470c9aace18082d39ec0d2f918d11cc9ba28", size = 86035, upload-time = "2026-04-09T19:53:42.359Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl", hash = "sha256:d5aec89e349a6ec230805d0df882f3807f74fd6c1a2fa86864e3c2279059fed1", size = 64154, upload-time = "2026-04-09T19:53:41.324Z" }, +] + [[package]] name = "prompt-toolkit" version = "3.0.52" @@ -2355,6 +2391,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/d0/397f9626e711ff749a95d96b7af99b9c566a9bb5129b8e4c10fc4d100304/python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155", size = 24579, upload-time = "2026-01-25T10:15:54.811Z" }, ] +[[package]] +name = "pytz" +version = "2026.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/46/dd499ec9038423421951e4fad73051febaa13d2df82b4064f87af8b8c0c3/pytz-2026.2.tar.gz", hash = "sha256:0e60b47b29f21574376f218fe21abc009894a2321ea16c6754f3cad6eb7cdd6a", size = 320861, upload-time = "2026-05-04T01:35:29.667Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/dd/96da98f892250475bdf2328112d7468abdd4acc7b902b6af23f4ed958ea0/pytz-2026.2-py2.py3-none-any.whl", hash = "sha256:04156e608bee23d3792fd45c94ae47fae1036688e75032eea2e3bf0323d1f126", size = 510141, upload-time = "2026-05-04T01:35:27.408Z" }, +] + [[package]] name = "pyyaml" version = "6.0.3" @@ -2840,6 +2885,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/72/f4/0de46cfa12cdcbcd464cc59fde36912af405696f687e53a091fb432f694c/tokenizers-0.22.2-cp39-abi3-win_arm64.whl", hash = "sha256:9ce725d22864a1e965217204946f830c37876eee3b2ba6fc6255e8e903d5fcbc", size = 2612133, upload-time = "2026-01-05T10:45:17.232Z" }, ] +[[package]] +name = "tornado" +version = "6.5.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/f1/3173dfa4a18db4a9b03e5d55325559dab51ee653763bb8745a75af491286/tornado-6.5.5.tar.gz", hash = "sha256:192b8f3ea91bd7f1f50c06955416ed76c6b72f96779b962f07f911b91e8d30e9", size = 516006, upload-time = "2026-03-10T21:31:02.067Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/8c/77f5097695f4dd8255ecbd08b2a1ed8ba8b953d337804dd7080f199e12bf/tornado-6.5.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:487dc9cc380e29f58c7ab88f9e27cdeef04b2140862e5076a66fb6bb68bb1bfa", size = 445983, upload-time = "2026-03-10T21:30:44.28Z" }, + { url = "https://files.pythonhosted.org/packages/ab/5e/7625b76cd10f98f1516c36ce0346de62061156352353ef2da44e5c21523c/tornado-6.5.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:65a7f1d46d4bb41df1ac99f5fcb685fb25c7e61613742d5108b010975a9a6521", size = 444246, upload-time = "2026-03-10T21:30:46.571Z" }, + { url = "https://files.pythonhosted.org/packages/b2/04/7b5705d5b3c0fab088f434f9c83edac1573830ca49ccf29fb83bf7178eec/tornado-6.5.5-cp39-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e74c92e8e65086b338fd56333fb9a68b9f6f2fe7ad532645a290a464bcf46be5", size = 447229, upload-time = "2026-03-10T21:30:48.273Z" }, + { url = "https://files.pythonhosted.org/packages/34/01/74e034a30ef59afb4097ef8659515e96a39d910b712a89af76f5e4e1f93c/tornado-6.5.5-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:435319e9e340276428bbdb4e7fa732c2d399386d1de5686cb331ec8eee754f07", size = 448192, upload-time = "2026-03-10T21:30:51.22Z" }, + { url = "https://files.pythonhosted.org/packages/be/00/fe9e02c5a96429fce1a1d15a517f5d8444f9c412e0bb9eadfbe3b0fc55bf/tornado-6.5.5-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:3f54aa540bdbfee7b9eb268ead60e7d199de5021facd276819c193c0fb28ea4e", size = 448039, upload-time = "2026-03-10T21:30:53.52Z" }, + { url = "https://files.pythonhosted.org/packages/82/9e/656ee4cec0398b1d18d0f1eb6372c41c6b889722641d84948351ae19556d/tornado-6.5.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:36abed1754faeb80fbd6e64db2758091e1320f6bba74a4cf8c09cd18ccce8aca", size = 447445, upload-time = "2026-03-10T21:30:55.541Z" }, + { url = "https://files.pythonhosted.org/packages/5a/76/4921c00511f88af86a33de770d64141170f1cfd9c00311aea689949e274e/tornado-6.5.5-cp39-abi3-win32.whl", hash = "sha256:dd3eafaaeec1c7f2f8fdcd5f964e8907ad788fe8a5a32c4426fbbdda621223b7", size = 448582, upload-time = "2026-03-10T21:30:57.142Z" }, + { url = "https://files.pythonhosted.org/packages/2c/23/f6c6112a04d28eed765e374435fb1a9198f73e1ec4b4024184f21faeb1ad/tornado-6.5.5-cp39-abi3-win_amd64.whl", hash = "sha256:6443a794ba961a9f619b1ae926a2e900ac20c34483eea67be4ed8f1e58d3ef7b", size = 448990, upload-time = "2026-03-10T21:30:58.857Z" }, + { url = "https://files.pythonhosted.org/packages/b7/c8/876602cbc96469911f0939f703453c1157b0c826ecb05bdd32e023397d4e/tornado-6.5.5-cp39-abi3-win_arm64.whl", hash = "sha256:2c9a876e094109333f888539ddb2de4361743e5d21eece20688e3e351e4990a6", size = 448016, upload-time = "2026-03-10T21:31:00.43Z" }, +] + [[package]] name = "tqdm" version = "4.67.3" diff --git a/app-eval-web/src/api.ts b/app-eval-web/src/api.ts index 1700ccc..b374936 100644 --- a/app-eval-web/src/api.ts +++ b/app-eval-web/src/api.ts @@ -1,11 +1,19 @@ import { apiBase, apiKey } from "./config"; +import { parseApiError } from "./parseApiError"; export { apiBase, apiKey }; export async function api( path: string, init?: RequestInit, -): Promise<{ ok: boolean; data?: T; error?: string; status: number }> { +): Promise<{ + ok: boolean; + data?: T; + error?: string; + errorCode?: string; + requestId?: string; + status: number; +}> { const url = `${apiBase}${path.startsWith("/") ? path : `/${path}`}`; try { const r = await fetch(url, { @@ -25,16 +33,13 @@ export async function api( /* ignore */ } if (!r.ok) { + const parsed = parseApiError(data, text || r.statusText); return { ok: false, status: r.status, - error: - typeof data === "object" && - data && - "detail" in (data as object) && - data !== null - ? String((data as unknown as { detail: unknown }).detail) - : text || r.statusText, + error: parsed.message, + errorCode: parsed.errorCode, + requestId: parsed.requestId, }; } return { ok: true, data, status: r.status }; diff --git a/app-eval-web/src/mainApi.ts b/app-eval-web/src/mainApi.ts index aeb130d..1bdc7a7 100644 --- a/app-eval-web/src/mainApi.ts +++ b/app-eval-web/src/mainApi.ts @@ -1,4 +1,5 @@ import { mainApiBase } from "./config"; +import { parseApiError } from "./parseApiError"; export const LIVE_ACCESS_TOKEN_KEY = "life-echo-eval-live-access-token"; export const LIVE_REFRESH_TOKEN_KEY = "life-echo-eval-live-refresh-token"; @@ -11,7 +12,14 @@ export function mainApiUrl(path: string): string { export async function mainApiFetch( path: string, init: RequestInit & { accessToken: string | null; jsonBody?: unknown }, -): Promise<{ ok: boolean; status: number; data?: T; error?: string }> { +): Promise<{ + ok: boolean; + status: number; + data?: T; + error?: string; + errorCode?: string; + requestId?: string; +}> { const { accessToken, jsonBody, ...rest } = init; const headers: Record = { ...(rest.headers as Record | undefined), @@ -38,14 +46,14 @@ export async function mainApiFetch( /* ignore */ } if (!r.ok) { - const detail = - typeof data === "object" && - data && - "detail" in (data as object) && - (data as { detail: unknown }).detail !== undefined - ? String((data as { detail: unknown }).detail) - : text || r.statusText; - return { ok: false, status: r.status, error: detail }; + const parsed = parseApiError(data, text || r.statusText); + return { + ok: false, + status: r.status, + error: parsed.message, + errorCode: parsed.errorCode, + requestId: parsed.requestId, + }; } return { ok: true, data, status: r.status }; } catch (e: unknown) { diff --git a/app-eval-web/src/pages/LiveTesterPage.tsx b/app-eval-web/src/pages/LiveTesterPage.tsx index ff59abf..4c0a623 100644 --- a/app-eval-web/src/pages/LiveTesterPage.tsx +++ b/app-eval-web/src/pages/LiveTesterPage.tsx @@ -366,8 +366,15 @@ export default function LiveTesterPage() { return; } if (parsed.type === "error") { - const text = String(d.message ?? d.detail ?? "error"); - setLiveEvents((prev) => [...prev, { id, kind: "error", text, at }]); + const text = String(d.message ?? "error"); + const code = + typeof d.error_code === "string" + ? d.error_code + : typeof d.code === "string" + ? d.code + : undefined; + const display = code ? `[${code}] ${text}` : text; + setLiveEvents((prev) => [...prev, { id, kind: "error", text: display, at }]); return; } if (parsed.type === "memoir_update") { diff --git a/app-eval-web/src/parseApiError.test.ts b/app-eval-web/src/parseApiError.test.ts new file mode 100644 index 0000000..9bc3955 --- /dev/null +++ b/app-eval-web/src/parseApiError.test.ts @@ -0,0 +1,49 @@ +import { describe, expect, it } from "vitest"; + +import { parseApiError, parseApiErrorMessage } from "./parseApiError"; + +describe("parseApiError", () => { + it("prefers unified message field", () => { + expect( + parseApiError( + { error_code: "NOT_FOUND", message: "资源不存在", request_id: "r1" }, + "fallback", + ), + ).toEqual({ + message: "资源不存在", + errorCode: "NOT_FOUND", + requestId: "r1", + }); + }); + + it("uses fallback when message is missing", () => { + expect(parseApiError({ error_code: "INTERNAL_ERROR" }, "HTTP 500")).toEqual({ + message: "HTTP 500", + errorCode: "INTERNAL_ERROR", + requestId: undefined, + }); + }); + + it("uses fallback when body is empty", () => { + expect(parseApiError(undefined, "HTTP 500")).toEqual({ message: "HTTP 500" }); + }); + + it("falls back to legacy detail string", () => { + expect(parseApiError({ detail: "请求无效" }, "HTTP 400")).toEqual({ + message: "请求无效", + errorCode: undefined, + requestId: undefined, + }); + }); +}); + +describe("parseApiErrorMessage", () => { + it("returns message from parseApiError", () => { + expect( + parseApiErrorMessage( + { error_code: "NOT_FOUND", message: "资源不存在", request_id: "r1" }, + "fallback", + ), + ).toBe("资源不存在"); + }); +}); diff --git a/app-eval-web/src/parseApiError.ts b/app-eval-web/src/parseApiError.ts new file mode 100644 index 0000000..57ee217 --- /dev/null +++ b/app-eval-web/src/parseApiError.ts @@ -0,0 +1,53 @@ +/** Parse unified API error body `{ message, error_code, request_id }` or legacy `{ detail }`. */ +export interface ParsedApiError { + message: string; + errorCode?: string; + requestId?: string; +} + +function messageFromDetail(detail: unknown): string | null { + if (typeof detail === "string" && detail.trim()) { + return detail.trim(); + } + if (Array.isArray(detail)) { + const parts = detail + .map((item) => { + if (typeof item === "string") return item; + if (item && typeof item === "object" && "msg" in item) { + const msg = (item as { msg?: string }).msg; + return typeof msg === "string" ? msg : ""; + } + return ""; + }) + .filter(Boolean); + if (parts.length) return parts.join("; "); + } + return null; +} + +export function parseApiError( + data: unknown, + fallback: string, +): ParsedApiError { + if (typeof data === "object" && data !== null) { + const body = data as Record; + const unified = + typeof body.message === "string" && body.message.trim() + ? body.message.trim() + : null; + const message = unified ?? messageFromDetail(body.detail) ?? fallback; + return { + message, + errorCode: + typeof body.error_code === "string" ? body.error_code : undefined, + requestId: + typeof body.request_id === "string" ? body.request_id : undefined, + }; + } + return { message: fallback }; +} + +/** Backward-compatible helper for call sites that only need the message string. */ +export function parseApiErrorMessage(data: unknown, fallback: string): string { + return parseApiError(data, fallback).message; +} diff --git a/app-expo/.env.example b/app-expo/.env.example index f952c6c..7eea498 100644 --- a/app-expo/.env.example +++ b/app-expo/.env.example @@ -12,7 +12,7 @@ # 变量在构建时注入;修改后需重新 prebuild/打包客户端。 # # 助手朗读:无独立 EXPO_PUBLIC_* TTS 开关。会话页顶栏在每轮 WebSocket 中带 `tts_this_turn`; -# 服务端是否具备合成能力见 api/.env 中 ENABLE_TTS 等(模板见 api/.env.example)。 +# 服务端是否具备合成能力见 api/config/*.toml 中 [deploy] enable_tts(密钥见 api/.env.example)。 # --- development(本地,关于页显示版本 + API)--- # APP_VARIANT=development diff --git a/app-expo/src/app/(auth)/login.tsx b/app-expo/src/app/(auth)/login.tsx index 42af415..d2eb9d8 100644 --- a/app-expo/src/app/(auth)/login.tsx +++ b/app-expo/src/app/(auth)/login.tsx @@ -54,12 +54,8 @@ export default function LoginScreen() { return () => clearInterval(id); }, [countdown]); - const handleGetCode = useCallback(() => { - if (!termsAccepted) { - setShowTermsDialog(true); - return; - } - if (!canGetCode) return; + const requestSmsCode = useCallback(() => { + if (phone.length !== PHONE_LENGTH || countdown > 0) return; Keyboard.dismiss(); sendCode.mutate( { phone, purpose: 'login' }, @@ -68,7 +64,21 @@ export default function LoginScreen() { onError: () => {}, }, ); - }, [canGetCode, phone, sendCode, termsAccepted]); + }, [phone, countdown, sendCode]); + + const handleGetCode = useCallback(() => { + if (!termsAccepted) { + setShowTermsDialog(true); + return; + } + if (!canGetCode) return; + requestSmsCode(); + }, [canGetCode, requestSmsCode, termsAccepted]); + + const handleTermsAgree = useCallback(() => { + setTermsAccepted(true); + requestSmsCode(); + }, [requestSmsCode]); const handleLogin = useCallback(() => { if (!canLogin) return; @@ -294,7 +304,9 @@ export default function LoginScreen() { onOpenChange={setShowTermsDialog} title={t('login.termsRequiredTitle')} description={t('login.termsRequired')} - confirmLabel={t('login.termsRequiredConfirm')} + cancelLabel={t('login.termsRequiredCancel')} + confirmLabel={t('login.termsRequiredAgree')} + onConfirm={handleTermsAgree} /> diff --git a/app-expo/src/components/info-dialog.tsx b/app-expo/src/components/info-dialog.tsx index 81026b5..eb229ef 100644 --- a/app-expo/src/components/info-dialog.tsx +++ b/app-expo/src/components/info-dialog.tsx @@ -3,6 +3,7 @@ import React from 'react'; import { AlertDialog, AlertDialogAction, + AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, @@ -18,10 +19,12 @@ export interface InfoDialogProps { title: string; description: string; confirmLabel?: string; + cancelLabel?: string; + onConfirm?: () => void; } /** - * Reusable info dialog with a single confirm button. + * Reusable info dialog with a single confirm button, or cancel + confirm pair. * Use for prompts like "please agree to terms", "please fill in X", etc. */ export function InfoDialog({ @@ -30,6 +33,8 @@ export function InfoDialog({ title, description, confirmLabel = 'OK', + cancelLabel, + onConfirm, }: InfoDialogProps) { const colors = useThemeColors(); @@ -41,11 +46,24 @@ export function InfoDialog({ {description} - - - {confirmLabel} - - + {cancelLabel ? ( + <> + + {cancelLabel} + + + + {confirmLabel} + + + + ) : ( + + + {confirmLabel} + + + )} diff --git a/app-expo/src/core/api/client.ts b/app-expo/src/core/api/client.ts index 12d9c74..693918d 100644 --- a/app-expo/src/core/api/client.ts +++ b/app-expo/src/core/api/client.ts @@ -1,5 +1,6 @@ import { config } from '@/core/config'; +import { parseApiError } from './parseApiError'; import { ApiError, AuthError, NetworkError, type ApiErrorBody } from './types'; type HttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH'; @@ -32,49 +33,15 @@ export function initApiClient(d: ClientDeps) { deps = d; } -let isRefreshing = false; -let refreshQueue: { - resolve: (token: string | null) => void; - reject: (err: unknown) => void; -}[] = []; - -function drainQueue(token: string | null, error?: unknown) { - const queue = refreshQueue; - refreshQueue = []; - for (const { resolve, reject } of queue) { - if (error) reject(error); - else resolve(token); - } -} - -async function waitForRefresh(): Promise { - return new Promise((resolve, reject) => { - refreshQueue.push({ resolve, reject }); - }); -} - async function handleTokenRefresh(): Promise { if (!deps) throw new AuthError('API client not initialized'); - if (isRefreshing) return waitForRefresh(); - - isRefreshing = true; - try { - const ok = await deps.refreshTokens(); - if (!ok) { - deps.onAuthFailure(); - drainQueue(null, new AuthError()); - throw new AuthError(); - } - const newToken = await deps.getAccessToken(); - drainQueue(newToken); - return newToken; - } catch (err) { - drainQueue(null, err); - throw err; - } finally { - isRefreshing = false; + const ok = await deps.refreshTokens(); + if (!ok) { + deps.onAuthFailure(); + throw new AuthError(); } + return deps.getAccessToken(); } /** FormData detection without relying on global FormData (RN-safe). */ @@ -204,24 +171,12 @@ async function request( if (!response.ok) { const body = await parseErrorBody(response); - let message = body?.message ?? `HTTP ${response.status}`; - if (body?.detail != null) { - if (typeof body.detail === 'string') { - message = body.detail; - } else if (Array.isArray(body.detail)) { - const parts = body.detail - .map((d: unknown) => - typeof d === 'string' ? d : ((d as { msg?: string })?.msg ?? ''), - ) - .filter(Boolean); - if (parts.length) message = parts.join('; '); - } - } + const parsed = parseApiError(body, `HTTP ${response.status}`); throw new ApiError( - message, + parsed.message, response.status, - body?.error_code, - body?.request_id, + parsed.errorCode, + parsed.requestId, ); } diff --git a/app-expo/src/core/api/parseApiError.ts b/app-expo/src/core/api/parseApiError.ts new file mode 100644 index 0000000..2e2343e --- /dev/null +++ b/app-expo/src/core/api/parseApiError.ts @@ -0,0 +1,49 @@ +import type { ApiErrorBody } from './types'; + +export interface ParsedApiError { + message: string; + errorCode?: string; + requestId?: string; +} + +function messageFromDetail(detail: ApiErrorBody['detail']): string | null { + if (typeof detail === 'string' && detail.trim()) { + return detail.trim(); + } + if (Array.isArray(detail)) { + const parts = detail + .map((item) => { + if (typeof item === 'string') return item; + if (item && typeof item === 'object' && 'msg' in item) { + const msg = (item as { msg?: string }).msg; + return typeof msg === 'string' ? msg : ''; + } + return ''; + }) + .filter(Boolean); + if (parts.length) return parts.join('; '); + } + return null; +} + +/** Parse unified `{ message, error_code, request_id }` or legacy FastAPI `{ detail }`. */ +export function parseApiError( + body: ApiErrorBody | null | undefined, + fallback: string, +): ParsedApiError { + if (body == null) { + return { message: fallback }; + } + const unified = + typeof body.message === 'string' && body.message.trim() + ? body.message.trim() + : null; + const message = unified ?? messageFromDetail(body.detail) ?? fallback; + return { + message, + errorCode: + typeof body.error_code === 'string' ? body.error_code : undefined, + requestId: + typeof body.request_id === 'string' ? body.request_id : undefined, + }; +} diff --git a/app-expo/src/core/api/types.ts b/app-expo/src/core/api/types.ts index 806c0d8..7bdda83 100644 --- a/app-expo/src/core/api/types.ts +++ b/app-expo/src/core/api/types.ts @@ -42,7 +42,7 @@ export class NetworkError extends Error { export interface ApiErrorBody { error_code?: string; message?: string; - /** FastAPI HTTPException uses "detail" for error message */ - detail?: string | string[]; + /** Legacy FastAPI HTTPException body */ + detail?: string | Array; request_id?: string; } diff --git a/app-expo/src/core/auth/refresh-lock.ts b/app-expo/src/core/auth/refresh-lock.ts new file mode 100644 index 0000000..1abddf9 --- /dev/null +++ b/app-expo/src/core/auth/refresh-lock.ts @@ -0,0 +1,47 @@ +/** Serialize refresh-token HTTP calls across api client and providers. */ + +let isRefreshing = false; +const waitQueue: { + resolve: (value: unknown) => void; + reject: (reason?: unknown) => void; +}[] = []; + +function drainQueue(error?: unknown, value?: unknown) { + const queue = waitQueue.splice(0); + for (const { resolve, reject } of queue) { + if (error !== undefined) reject(error); + else resolve(value); + } +} + +/** + * Ensures only one refresh runs at a time; concurrent callers await the same result. + */ +export async function withRefreshLock(fn: () => Promise): Promise { + if (isRefreshing) { + return new Promise((resolve, reject) => { + waitQueue.push({ + resolve: (value) => resolve(value as T), + reject, + }); + }); + } + + isRefreshing = true; + try { + const result = await fn(); + drainQueue(undefined, result); + return result; + } catch (error) { + drainQueue(error); + throw error; + } finally { + isRefreshing = false; + } +} + +/** Test-only reset for unit tests. */ +export function resetRefreshLockForTests(): void { + isRefreshing = false; + waitQueue.length = 0; +} diff --git a/app-expo/src/core/providers.tsx b/app-expo/src/core/providers.tsx index 674c7e6..c7ad30a 100644 --- a/app-expo/src/core/providers.tsx +++ b/app-expo/src/core/providers.tsx @@ -5,6 +5,7 @@ import { AppSettingsProvider } from '@/core/app-settings-context'; import { MemoirReadingSettingsProvider } from '@/core/memoir-reading-settings-context'; import { NetworkError } from '@/core/api/types'; import { tokenManager } from '@/core/auth/token-manager'; +import { withRefreshLock } from '@/core/auth/refresh-lock'; import { config } from '@/core/config'; import { authKeys } from '@/features/auth/auth-query-keys'; import { AppQueryProvider, queryClient } from '@/core/query'; @@ -15,7 +16,7 @@ import { AppQueryProvider, queryClient } from '@/core/query'; * Throws NetworkError on transport-level failures so the caller * can distinguish "session dead" from "network down". */ -async function refreshTokens(): Promise { +async function performRefreshFetch(): Promise { const refreshToken = await tokenManager.getRefreshToken(); if (!refreshToken) return false; @@ -43,6 +44,10 @@ async function refreshTokens(): Promise { return true; } +async function refreshTokens(): Promise { + return withRefreshLock(performRefreshFetch); +} + /** * Called by the API client when token refresh is explicitly rejected. * Must synchronously flip query caches so useSession() immediately diff --git a/app-expo/src/features/conversation/realtime-session.ts b/app-expo/src/features/conversation/realtime-session.ts index bafc7e2..6b95e67 100644 --- a/app-expo/src/features/conversation/realtime-session.ts +++ b/app-expo/src/features/conversation/realtime-session.ts @@ -166,11 +166,7 @@ export class RealtimeSession { if (!this.assistantTurnTtsSync && this.streamingBuffer.trim().length > 0) { this.onStreamingText?.(this.streamingBuffer, false); } - if ( - this.uiOwner && - this.pendingTopicSuggestionsPayload && - this.onTopicSuggestions - ) { + if (this.pendingTopicSuggestionsPayload && this.onTopicSuggestions) { const p = this.pendingTopicSuggestionsPayload; this.pendingTopicSuggestionsPayload = null; this.onTopicSuggestions(p); @@ -413,7 +409,7 @@ export class RealtimeSession { stage: event.stage, suggestions: event.suggestions, }; - if (this.uiOwner && this.onTopicSuggestions) { + if (this.onTopicSuggestions && this.uiOwner) { this.pendingTopicSuggestionsPayload = null; this.onTopicSuggestions(payload); } else { diff --git a/app-expo/src/i18n/generated/resources.ts b/app-expo/src/i18n/generated/resources.ts index a1481f9..9fbfc65 100644 --- a/app-expo/src/i18n/generated/resources.ts +++ b/app-expo/src/i18n/generated/resources.ts @@ -31,7 +31,8 @@ interface Resources { "termsAnd": "and", "termsIntro": "I agree to the", "termsRequired": "Please agree to the User Agreement and Privacy Policy first", - "termsRequiredConfirm": "OK", + "termsRequiredAgree": "Agree", + "termsRequiredCancel": "Cancel", "termsRequiredTitle": "Agreement Required", "userAgreement": "User Agreement", "welcomeSubtitle": "Some lives grow richer the more you savor them.", @@ -177,7 +178,12 @@ interface Resources { "profile": { "about": { "aboutUs": "About Us", - "title": "About" + "appName": "Life Echo", + "appSubtitle": "岁月留书", + "backend": "API endpoint", + "tagline": "Capture your life story and turn memories into a book.", + "title": "About", + "version": "Version {{version}}" }, "appExperience": { "language": "Language", diff --git a/app-expo/src/i18n/locales/en/auth.json b/app-expo/src/i18n/locales/en/auth.json index fe95423..1d011e3 100644 --- a/app-expo/src/i18n/locales/en/auth.json +++ b/app-expo/src/i18n/locales/en/auth.json @@ -11,7 +11,8 @@ "termsAnd": "and", "termsIntro": "I agree to the", "termsRequired": "Please agree to the User Agreement and Privacy Policy first", - "termsRequiredConfirm": "OK", + "termsRequiredAgree": "Agree", + "termsRequiredCancel": "Cancel", "termsRequiredTitle": "Agreement Required", "userAgreement": "User Agreement", "welcomeSubtitle": "Some lives grow richer the more you savor them.", diff --git a/app-expo/src/i18n/locales/en/profile.json b/app-expo/src/i18n/locales/en/profile.json index d042542..54ef913 100644 --- a/app-expo/src/i18n/locales/en/profile.json +++ b/app-expo/src/i18n/locales/en/profile.json @@ -2,7 +2,7 @@ "about": { "aboutUs": "About Us", "appName": "Life Echo", - "appSubtitle": "岁月时书", + "appSubtitle": "岁月留书", "backend": "API endpoint", "tagline": "Capture your life story and turn memories into a book.", "title": "About", diff --git a/app-expo/src/i18n/locales/zh/app.json b/app-expo/src/i18n/locales/zh/app.json index ac7464e..39a6186 100644 --- a/app-expo/src/i18n/locales/zh/app.json +++ b/app-expo/src/i18n/locales/zh/app.json @@ -4,7 +4,7 @@ "system": "跟随系统", "zh": "中文" }, - "name": "岁月时书", + "name": "岁月留书", "tabs": { "conversations": "对话", "explore": "探索", diff --git a/app-expo/src/i18n/locales/zh/auth.json b/app-expo/src/i18n/locales/zh/auth.json index d74f2cb..9e509c3 100644 --- a/app-expo/src/i18n/locales/zh/auth.json +++ b/app-expo/src/i18n/locales/zh/auth.json @@ -11,7 +11,8 @@ "termsAnd": "和", "termsIntro": "我已阅读并同意", "termsRequired": "请先同意用户协议和隐私政策", - "termsRequiredConfirm": "知道了", + "termsRequiredAgree": "同意", + "termsRequiredCancel": "取消", "termsRequiredTitle": "需要同意协议", "userAgreement": "《用户协议》", "welcomeSubtitle": "有些人生,越嚼越有味道。", diff --git a/app-expo/src/i18n/locales/zh/profile.json b/app-expo/src/i18n/locales/zh/profile.json index e9e05fd..36996f1 100644 --- a/app-expo/src/i18n/locales/zh/profile.json +++ b/app-expo/src/i18n/locales/zh/profile.json @@ -1,7 +1,7 @@ { "about": { "aboutUs": "关于我们", - "appName": "岁月时书", + "appName": "岁月留书", "appSubtitle": "Life Echo", "backend": "连接的后端", "tagline": "记录你的人生故事,让回忆成书。", diff --git a/app-expo/tests/core/api/parseApiError.test.ts b/app-expo/tests/core/api/parseApiError.test.ts new file mode 100644 index 0000000..c8504e8 --- /dev/null +++ b/app-expo/tests/core/api/parseApiError.test.ts @@ -0,0 +1,55 @@ +import { parseApiError } from '@/core/api/parseApiError'; + +describe('parseApiError', () => { + it('prefers unified message field', () => { + expect( + parseApiError( + { + error_code: 'NOT_FOUND', + message: '资源不存在', + request_id: 'r1', + }, + 'fallback', + ), + ).toEqual({ + message: '资源不存在', + errorCode: 'NOT_FOUND', + requestId: 'r1', + }); + }); + + it('falls back to legacy detail string', () => { + expect(parseApiError({ detail: '请求无效' }, 'HTTP 400')).toEqual({ + message: '请求无效', + errorCode: undefined, + requestId: undefined, + }); + }); + + it('parses legacy detail validation list', () => { + expect( + parseApiError( + { + detail: [{ loc: ['body', 'phone'], msg: 'field required' }], + }, + 'HTTP 422', + ), + ).toEqual({ + message: 'field required', + errorCode: undefined, + requestId: undefined, + }); + }); + + it('uses fallback when message is missing', () => { + expect(parseApiError({ error_code: 'INTERNAL_ERROR' }, 'HTTP 500')).toEqual({ + message: 'HTTP 500', + errorCode: 'INTERNAL_ERROR', + requestId: undefined, + }); + }); + + it('uses fallback when body is null', () => { + expect(parseApiError(null, 'HTTP 502')).toEqual({ message: 'HTTP 502' }); + }); +}); diff --git a/app-expo/tests/core/auth/refresh-lock.test.ts b/app-expo/tests/core/auth/refresh-lock.test.ts new file mode 100644 index 0000000..7cd3e0f --- /dev/null +++ b/app-expo/tests/core/auth/refresh-lock.test.ts @@ -0,0 +1,46 @@ +import { + resetRefreshLockForTests, + withRefreshLock, +} from '@/core/auth/refresh-lock'; + +describe('withRefreshLock', () => { + beforeEach(() => { + resetRefreshLockForTests(); + }); + + it('runs concurrent callers through a single in-flight refresh', async () => { + let runs = 0; + const fn = jest.fn(async () => { + runs += 1; + await new Promise((resolve) => setTimeout(resolve, 20)); + return 'ok'; + }); + + const [a, b, c] = await Promise.all([ + withRefreshLock(fn), + withRefreshLock(fn), + withRefreshLock(fn), + ]); + + expect(a).toBe('ok'); + expect(b).toBe('ok'); + expect(c).toBe('ok'); + expect(fn).toHaveBeenCalledTimes(1); + expect(runs).toBe(1); + }); + + it('propagates errors to queued waiters', async () => { + const fn = jest.fn(async () => { + throw new Error('refresh failed'); + }); + + const results = await Promise.allSettled([ + withRefreshLock(fn), + withRefreshLock(fn), + ]); + + expect(fn).toHaveBeenCalledTimes(1); + expect(results[0].status).toBe('rejected'); + expect(results[1].status).toBe('rejected'); + }); +}); diff --git a/assets/demo.html b/assets/demo.html index 8a22243..9ffa28f 100644 --- a/assets/demo.html +++ b/assets/demo.html @@ -3,7 +3,7 @@ - 岁月时书 - Demo + 岁月留书 - Demo