""" 全局共享依赖: - 认证:get_current_user / get_optional_user - Port DI factory:get_sms_sender / get_llm_provider / get_tts_provider / ... """ from functools import lru_cache from typing import Optional from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.core.db import get_async_db from app.core.security import verify_token from app.ports.asr import ASRProvider from app.ports.embedding import EmbeddingProvider from app.ports.image_gen import ImageGenerator from app.ports.llm import LLMProvider from app.ports.sms import SmsSender from app.ports.storage import ObjectStorage from app.ports.tts import TTSProvider oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login") # ── Port DI factories ─────────────────────────────────────── @lru_cache def get_sms_sender() -> SmsSender: from app.adapters.sms.tencent import TencentSmsSender return TencentSmsSender( secret_id=settings.tencent_sms_secret_id, secret_key=settings.tencent_sms_secret_key, sdk_app_id=settings.tencent_sms_sdk_app_id, sign_name=settings.tencent_sms_sign_name, template_id=settings.tencent_sms_template_id, template_param_count=settings.tencent_sms_template_param_count, ) @lru_cache def get_llm_provider() -> LLMProvider: from app.adapters.llm.deepseek import DeepSeekLLMProvider api_key = settings.deepseek_api_key or settings.llm_api_key base_url = settings.deepseek_base_url or settings.llm_base_url model = settings.deepseek_model or settings.llm_model or "deepseek-chat" return DeepSeekLLMProvider( api_key=api_key, base_url=base_url, model=model, temperature=settings.llm_temperature, ) @lru_cache def get_llm_provider_fast() -> LLMProvider: """快档位:与默认共用密钥与 base_url,仅模型名可单独配置。""" fast = (settings.llm_fast_model or "").strip() if not fast: return get_llm_provider() from app.adapters.llm.deepseek import DeepSeekLLMProvider api_key = settings.deepseek_api_key or settings.llm_api_key base_url = settings.deepseek_base_url or settings.llm_base_url return DeepSeekLLMProvider( api_key=api_key, base_url=base_url, model=fast, temperature=settings.llm_temperature, ) @lru_cache def get_tts_provider() -> TTSProvider: if settings.tts_provider == "tencent": from app.adapters.tts.tencent_tts import TencentTTSProvider return TencentTTSProvider( secret_id=settings.tencent_secret_id, secret_key=settings.tencent_secret_key, voice_type=settings.tts_voice_type, codec=settings.tts_codec, ) from app.adapters.tts.openai_tts import OpenAITTSProvider return OpenAITTSProvider(api_key=settings.openai_api_key) @lru_cache def get_asr_provider() -> ASRProvider: if settings.asr_provider == "tencent": from app.adapters.asr.tencent_asr import TencentASRProvider return TencentASRProvider( secret_id=settings.tencent_secret_id, secret_key=settings.tencent_secret_key, ) from app.adapters.asr.whisper_local import WhisperASRProvider return WhisperASRProvider( model_size=settings.asr_model_size, device=settings.asr_device, compute_type=settings.asr_compute_type, cache_dir=settings.asr_model_cache_dir, ) @lru_cache def get_image_generator() -> ImageGenerator: from app.adapters.image_gen.liblib import LiblibImageGenerator return LiblibImageGenerator( access_key=settings.liblib_access_key, secret_key=settings.liblib_secret_key, base_url=settings.liblib_base_url, template_uuid=settings.liblib_template_uuid, poll_interval=settings.memoir_image_poll_interval, max_attempts=settings.memoir_image_max_attempts, ) @lru_cache def get_object_storage() -> ObjectStorage: from app.adapters.storage.tencent_cos import TencentCosStorage return TencentCosStorage( secret_id=settings.tencent_cos_secret_id, secret_key=settings.tencent_cos_secret_key, region=settings.tencent_cos_region, bucket=settings.tencent_cos_bucket, base_url=settings.tencent_cos_base_url, token=settings.tencent_cos_token, ) @lru_cache def get_embedding_provider() -> EmbeddingProvider: from app.adapters.embedding.zhipu import ZhipuEmbeddingProvider return ZhipuEmbeddingProvider( api_key=settings.zhipu_api_key, base_url=settings.embedding_base_url or None, model=settings.embedding_model, ) # ── Auth dependencies ──────────────────────────────────────── async def get_current_user( token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_async_db), ): """Resolve authenticated user from JWT access token.""" from app.features.user.models import User # deferred to avoid circular import credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无法验证凭据", headers={"WWW-Authenticate": "Bearer"}, ) payload = verify_token(token) if payload is None: raise credentials_exception user_id: str | None = payload.get("sub") if user_id is None: raise credentials_exception if payload.get("type") != "access": raise credentials_exception user = await db.get(User, user_id) if user is None: raise credentials_exception return user async def get_optional_user( token: Optional[str] = Depends(oauth2_scheme), db: AsyncSession = Depends(get_async_db), ): """Return user if a valid token is provided, else None.""" if token is None: return None try: return await get_current_user(token, db) except HTTPException: return None