Files
life-echo/api/app/core/dependencies.py
Sully 92b7848c48 feat/tts (#15)
Co-authored-by: Kevin <kevin@brighteng.org>
2026-03-19 09:11:25 +08:00

178 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
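"""
Shared global dependencies (全局共享依赖):
- Authentication: get_current_user / get_optional_user
- Port DI factories: get_sms_sender / get_llm_provider / get_tts_provider /
  get_asr_provider / get_image_generator / get_object_storage /
  get_embedding_provider
"""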
from functools import lru_cache
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.db import get_async_db
from app.core.security import verify_token
from app.ports.asr import ASRProvider
from app.ports.embedding import EmbeddingProvider
from app.ports.image_gen import ImageGenerator
from app.ports.llm import LLMProvider
from app.ports.sms import SmsSender
from app.ports.storage import ObjectStorage
from app.ports.tts import TTSProvider
# Bearer-token extractor used by get_current_user. tokenUrl names the login
# endpoint that issues tokens. No auto_error override is passed, so FastAPI's
# default handling of a missing/non-Bearer Authorization header applies.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
# ── Port DI factories ───────────────────────────────────────
# Each factory imports its adapter lazily and is wrapped in lru_cache, so a
# single adapter instance is built on first use and reused afterwards.
@lru_cache
def get_sms_sender() -> SmsSender:
    """Return the shared Tencent Cloud SMS sender.

    The adapter module is imported on first call and configured purely from
    the ``tencent_sms_*`` settings. ``lru_cache`` ensures later calls return
    the very same instance.
    """
    from app.adapters.sms.tencent import TencentSmsSender

    sms_config = dict(
        secret_id=settings.tencent_sms_secret_id,
        secret_key=settings.tencent_sms_secret_key,
        sdk_app_id=settings.tencent_sms_sdk_app_id,
        sign_name=settings.tencent_sms_sign_name,
        template_id=settings.tencent_sms_template_id,
        template_param_count=settings.tencent_sms_template_param_count,
    )
    return TencentSmsSender(**sms_config)
@lru_cache
def get_llm_provider() -> LLMProvider:
    """Return the shared DeepSeek-backed LLM provider.

    Each DeepSeek-specific setting wins when truthy; otherwise the generic
    ``llm_*`` setting is used. The model name finally falls back to
    ``"deepseek-chat"`` when neither setting is truthy.
    """
    from app.adapters.llm.deepseek import DeepSeekLLMProvider

    options = {
        "api_key": settings.deepseek_api_key or settings.llm_api_key,
        "base_url": settings.deepseek_base_url or settings.llm_base_url,
        "model": settings.deepseek_model or settings.llm_model or "deepseek-chat",
        "temperature": settings.llm_temperature,
    }
    return DeepSeekLLMProvider(**options)
@lru_cache
def get_tts_provider() -> TTSProvider:
    """Return the shared text-to-speech provider.

    ``settings.tts_provider == "tencent"`` selects Tencent Cloud TTS, which
    uses the shared Tencent credentials plus the configured voice and codec.
    Any other value selects the OpenAI TTS adapter.
    """
    if settings.tts_provider == "tencent":
        from app.adapters.tts.tencent_tts import TencentTTSProvider

        provider = TencentTTSProvider(
            secret_id=settings.tencent_secret_id,
            secret_key=settings.tencent_secret_key,
            voice_type=settings.tts_voice_type,
            codec=settings.tts_codec,
        )
    else:
        from app.adapters.tts.openai_tts import OpenAITTSProvider

        provider = OpenAITTSProvider(api_key=settings.openai_api_key)
    return provider
@lru_cache
def get_asr_provider() -> ASRProvider:
    """Return the shared speech-recognition provider.

    ``settings.asr_provider == "tencent"`` selects Tencent Cloud ASR using the
    shared Tencent credentials. Any other value selects the local Whisper
    adapter, configured from the ``asr_*`` model/device settings.
    """
    use_tencent = settings.asr_provider == "tencent"
    if not use_tencent:
        from app.adapters.asr.whisper_local import WhisperASRProvider

        return WhisperASRProvider(
            model_size=settings.asr_model_size,
            device=settings.asr_device,
            compute_type=settings.asr_compute_type,
            cache_dir=settings.asr_model_cache_dir,
        )

    from app.adapters.asr.tencent_asr import TencentASRProvider

    return TencentASRProvider(
        secret_id=settings.tencent_secret_id,
        secret_key=settings.tencent_secret_key,
    )
@lru_cache
def get_image_generator() -> ImageGenerator:
    """Return the shared Liblib image generator.

    Credentials, endpoint and template come from the ``liblib_*`` settings.
    The polling cadence and attempt limit come from the ``memoir_image_*``
    settings.
    """
    from app.adapters.image_gen.liblib import LiblibImageGenerator as Generator

    cfg = settings
    return Generator(
        access_key=cfg.liblib_access_key,
        secret_key=cfg.liblib_secret_key,
        base_url=cfg.liblib_base_url,
        template_uuid=cfg.liblib_template_uuid,
        poll_interval=cfg.memoir_image_poll_interval,
        max_attempts=cfg.memoir_image_max_attempts,
    )
@lru_cache
def get_object_storage() -> ObjectStorage:
    """Return the shared Tencent COS object-storage client.

    All parameters come from the ``tencent_cos_*`` settings. ``token`` is
    presumably an optional temporary-credential session token (TODO confirm
    against the adapter).
    """
    from app.adapters.storage.tencent_cos import TencentCosStorage

    credentials = {
        "secret_id": settings.tencent_cos_secret_id,
        "secret_key": settings.tencent_cos_secret_key,
    }
    location = {
        "region": settings.tencent_cos_region,
        "bucket": settings.tencent_cos_bucket,
        "base_url": settings.tencent_cos_base_url,
    }
    return TencentCosStorage(
        **credentials,
        **location,
        token=settings.tencent_cos_token,
    )
@lru_cache
def get_embedding_provider() -> EmbeddingProvider:
    """Return the shared OpenAI embedding provider.

    Uses ``settings.openai_api_key`` when truthy, otherwise falls back to
    ``settings.deepseek_api_key``.
    """
    from app.adapters.embedding.openai import OpenAIEmbeddingProvider

    # NOTE(review): the fallback hands a DeepSeek key to the OpenAI embedding
    # adapter; confirm that adapter targets an endpoint that accepts it.
    return OpenAIEmbeddingProvider(
        api_key=settings.openai_api_key or settings.deepseek_api_key,
    )
# ── Auth dependencies ────────────────────────────────────────
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_async_db),
):
    """Resolve the authenticated user from a JWT access token.

    The token must pass ``verify_token``, carry a ``sub`` claim, and have
    ``type == "access"``. The ``sub`` value is then looked up as the ``User``
    primary key.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header if any
            of the checks fails or no matching user exists.
    """
    # Imported here rather than at module level to avoid a circular import.
    from app.features.user.models import User

    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无法验证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    claims = verify_token(token)
    if claims is None:
        raise unauthorized

    subject: str | None = claims.get("sub")
    # Reject tokens without a subject, and non-access tokens (e.g. any other
    # token type the issuer may mint).
    if subject is None or claims.get("type") != "access":
        raise unauthorized

    # NOTE(review): the raw ``sub`` string is used as the primary key; confirm
    # it matches the User PK type.
    user = await db.get(User, subject)
    if user is None:
        raise unauthorized
    return user
# Same login endpoint as ``oauth2_scheme`` but with ``auto_error=False``.
# With the default (auto_error=True), a request with no Bearer Authorization
# header is rejected with 401 before this dependency body ever runs. That
# made the ``token is None`` branch below unreachable and turned "optional"
# auth into mandatory auth.
_optional_oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/api/auth/login",
    auto_error=False,
)


async def get_optional_user(
    token: Optional[str] = Depends(_optional_oauth2_scheme),
    db: AsyncSession = Depends(get_async_db),
):
    """Return the authenticated user, or ``None`` for anonymous requests.

    A missing Bearer token yields ``None`` (the scheme does not auto-raise).
    A present token is resolved via ``get_current_user``. Any ``HTTPException``
    it raises, such as an invalid/expired token, wrong token type, or unknown
    user, is also mapped to ``None``. Other errors (e.g. database failures)
    propagate unchanged.
    """
    if token is None:
        return None
    try:
        return await get_current_user(token, db)
    except HTTPException:
        # Best-effort by design: bad credentials degrade to anonymous access.
        return None