Files
life-echo/api/app/core/dependencies.py
Kevin bb16d3a5c9 refactor(agents): 抽取阶段常量与对话上下文;快档 LLM;图片 prompt 可禁止回退
访谈与阶段
- 新增 app/agents/stage_constants.py:集中 CHAT_STAGES、章节分类/顺序、阶段到默认 memoir 类别等,与 MemoirState 默认槽位顺序对齐;减少散落在 prompts 内的重复常量。
- 新增 app/agents/chat/prompt_context.py:以 ChatPromptContext 汇总 guided 系统提示所需字段(阶段、槽位、轮次、人设、记忆证据、回复长度模式、背景声线、职业等),统一走 get_guided_conversation_prompt。
- 大幅收敛 app/agents/chat/prompts_conversation.py;调整 prompts.py、stage_prompts.py、stage_detection.py;同步 interview_agent、profile_agent、helpers 与 state_schema,使对话侧构造提示的方式一致、可测。

回忆录流水线
- memoir/prompts.py 删除已迁至 stage_constants / 独立模板的大段常量与图片占位相关逻辑;classification / extraction / fidelity / narrative agents 与 orchestrator 同步调整(对话历史:全量历史仍可用于计数,注入模型时按轮次与字符上限截断);新增 image_prompt_fallback_disabled 开关。
- dependencies 增加 get_llm_provider_fast(LRU 缓存,可与默认共用密钥与 base_url)。

任务与编排
- memoir_tasks:prepare_batches 注入 llm_fast;开启独立快档模型时打结构化日志。
- chapter_cover_tasks、story_image_tasks:与图片 prompt / JSON 工具路径或策略变更对齐(import 与行为一致)。
- story_pipeline_sync 等小处同步。

其它核心
- langchain_llm、text_normalize 随上述调用链微调。

开发者体验
- .cursor/settings.json:启用 redis-development、postman 插件。

测试
- 新增 test_image_prompt_policy:覆盖「禁止回退」等图片 prompt 策略。
- 更新 test_interview_prompts、test_interview_reply_length、test_experience_regressions、test_json_and_memory_utils,匹配新常量位置、json_utils 与对话/长度行为。
2026-04-02 12:00:00 +08:00

199 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
全局共享依赖:
- 认证: get_current_user / get_optional_user
- Port DI factory: get_sms_sender / get_llm_provider / get_tts_provider / ...
"""
from functools import lru_cache
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.db import get_async_db
from app.core.security import verify_token
from app.ports.asr import ASRProvider
from app.ports.embedding import EmbeddingProvider
from app.ports.image_gen import ImageGenerator
from app.ports.llm import LLMProvider
from app.ports.sms import SmsSender
from app.ports.storage import ObjectStorage
from app.ports.tts import TTSProvider
# OAuth2 password-flow bearer scheme used by ``get_current_user``.
# ``tokenUrl`` tells OpenAPI clients where to obtain a token (the login route).
# Constructed with FastAPI's default ``auto_error`` behaviour.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/api/auth/login",
)
# ── Port DI factories ───────────────────────────────────────
@lru_cache
def get_sms_sender() -> SmsSender:
    """Build the shared SMS sender backed by the Tencent SMS adapter.

    The adapter module is imported lazily on the first call; ``lru_cache``
    caches the constructed instance so every later call returns it.
    """
    from app.adapters.sms.tencent import TencentSmsSender

    sender_kwargs = dict(
        secret_id=settings.tencent_sms_secret_id,
        secret_key=settings.tencent_sms_secret_key,
        sdk_app_id=settings.tencent_sms_sdk_app_id,
        sign_name=settings.tencent_sms_sign_name,
        template_id=settings.tencent_sms_template_id,
        template_param_count=settings.tencent_sms_template_param_count,
    )
    return TencentSmsSender(**sender_kwargs)
@lru_cache
def get_llm_provider() -> LLMProvider:
    """Build the default (cached) LLM provider using the DeepSeek adapter.

    Each connection setting prefers its ``deepseek_*`` value and falls back to
    the generic ``llm_*`` value. The model additionally falls back to
    ``"deepseek-chat"`` when neither is set.
    """
    from app.adapters.llm.deepseek import DeepSeekLLMProvider

    resolved_key = settings.deepseek_api_key or settings.llm_api_key
    resolved_url = settings.deepseek_base_url or settings.llm_base_url
    resolved_model = settings.deepseek_model or settings.llm_model or "deepseek-chat"

    provider_kwargs = dict(
        api_key=resolved_key,
        base_url=resolved_url,
        model=resolved_model,
        temperature=settings.llm_temperature,
    )
    return DeepSeekLLMProvider(**provider_kwargs)
@lru_cache
def get_llm_provider_fast() -> LLMProvider:
    """Return the "fast tier" LLM provider.

    The fast tier shares the API key and base URL resolution of the default
    provider; only the model name is configured separately through
    ``settings.llm_fast_model``. When that setting is empty or whitespace-only,
    the default provider from ``get_llm_provider()`` is returned instead.
    """
    fast_model = (settings.llm_fast_model or "").strip()
    if fast_model:
        from app.adapters.llm.deepseek import DeepSeekLLMProvider

        return DeepSeekLLMProvider(
            api_key=settings.deepseek_api_key or settings.llm_api_key,
            base_url=settings.deepseek_base_url or settings.llm_base_url,
            model=fast_model,
            temperature=settings.llm_temperature,
        )

    # No dedicated fast model configured: reuse the default provider.
    return get_llm_provider()
@lru_cache
def get_tts_provider() -> TTSProvider:
    """Build the cached text-to-speech provider selected by ``settings.tts_provider``.

    ``"tencent"`` selects the Tencent TTS adapter. Any other value selects
    the OpenAI TTS adapter.
    """
    if settings.tts_provider != "tencent":
        from app.adapters.tts.openai_tts import OpenAITTSProvider

        return OpenAITTSProvider(api_key=settings.openai_api_key)

    from app.adapters.tts.tencent_tts import TencentTTSProvider

    tencent_kwargs = dict(
        secret_id=settings.tencent_secret_id,
        secret_key=settings.tencent_secret_key,
        voice_type=settings.tts_voice_type,
        codec=settings.tts_codec,
    )
    return TencentTTSProvider(**tencent_kwargs)
@lru_cache
def get_asr_provider() -> ASRProvider:
    """Build the cached speech-recognition provider selected by ``settings.asr_provider``.

    ``"tencent"`` selects the Tencent ASR adapter. Any other value selects
    the local Whisper adapter, configured from the ``asr_*`` settings.
    """
    if settings.asr_provider != "tencent":
        from app.adapters.asr.whisper_local import WhisperASRProvider

        whisper_kwargs = dict(
            model_size=settings.asr_model_size,
            device=settings.asr_device,
            compute_type=settings.asr_compute_type,
            cache_dir=settings.asr_model_cache_dir,
        )
        return WhisperASRProvider(**whisper_kwargs)

    from app.adapters.asr.tencent_asr import TencentASRProvider

    return TencentASRProvider(
        secret_id=settings.tencent_secret_id,
        secret_key=settings.tencent_secret_key,
    )
@lru_cache
def get_image_generator() -> ImageGenerator:
    """Build the cached image generator backed by the Liblib adapter.

    Polling behaviour is taken from the ``memoir_image_poll_interval`` and
    ``memoir_image_max_attempts`` settings.
    """
    from app.adapters.image_gen.liblib import LiblibImageGenerator

    generator_kwargs = dict(
        access_key=settings.liblib_access_key,
        secret_key=settings.liblib_secret_key,
        base_url=settings.liblib_base_url,
        template_uuid=settings.liblib_template_uuid,
        poll_interval=settings.memoir_image_poll_interval,
        max_attempts=settings.memoir_image_max_attempts,
    )
    return LiblibImageGenerator(**generator_kwargs)
@lru_cache
def get_object_storage() -> ObjectStorage:
    """Build the cached object-storage client backed by the Tencent COS adapter."""
    from app.adapters.storage.tencent_cos import TencentCosStorage

    cos_kwargs = dict(
        secret_id=settings.tencent_cos_secret_id,
        secret_key=settings.tencent_cos_secret_key,
        region=settings.tencent_cos_region,
        bucket=settings.tencent_cos_bucket,
        base_url=settings.tencent_cos_base_url,
        token=settings.tencent_cos_token,
    )
    return TencentCosStorage(**cos_kwargs)
@lru_cache
def get_embedding_provider() -> EmbeddingProvider:
    """Build the cached embedding provider backed by the Zhipu adapter."""
    from app.adapters.embedding.zhipu import ZhipuEmbeddingProvider

    # Any falsy base URL (e.g. "") is passed as None. Presumably this lets the
    # adapter use its own default endpoint; TODO confirm in the adapter.
    endpoint = settings.embedding_base_url or None
    return ZhipuEmbeddingProvider(
        api_key=settings.zhipu_api_key,
        base_url=endpoint,
        model=settings.embedding_model,
    )
# ── Auth dependencies ────────────────────────────────────────
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_async_db),
):
    """Resolve the authenticated user from a JWT access token.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token fails ``verify_token``, has no ``"sub"`` claim, is not of
            type ``"access"``, or refers to a user that cannot be loaded.
    """
    from app.features.user.models import User  # deferred to avoid circular import

    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    payload = verify_token(token)
    if payload is None:
        raise unauthorized

    user_id: str | None = payload.get("sub")
    # Only access tokens with a subject may authenticate a request.
    if user_id is None or payload.get("type") != "access":
        raise unauthorized

    user = await db.get(User, user_id)
    if user is None:
        raise unauthorized
    return user
# ``oauth2_scheme`` uses FastAPI's default ``auto_error=True``. With that
# setting, a request without an ``Authorization: Bearer`` header is rejected
# with 401 before the dependent function runs, so it can never inject ``None``.
# Optional authentication therefore needs its own scheme with
# ``auto_error=False``, which yields ``None`` instead.
# Keep ``tokenUrl`` in sync with ``oauth2_scheme``.
oauth2_scheme_optional = OAuth2PasswordBearer(
    tokenUrl="/api/auth/login",
    auto_error=False,
)


async def get_optional_user(
    token: Optional[str] = Depends(oauth2_scheme_optional),
    db: AsyncSession = Depends(get_async_db),
):
    """Return the authenticated user if a valid token is provided, else None.

    A missing bearer token yields ``None``. So does any authentication
    failure raised by ``get_current_user`` as ``HTTPException``. Other
    exceptions (e.g. from the database session) still propagate.
    """
    if token is None:
        return None
    try:
        return await get_current_user(token, db)
    except HTTPException:
        return None