""" 数据库引擎、会话工厂、Base 声明基类。 - 异步(FastAPI)与同步(Celery、Alembic、脚本)均使用 psycopg(psycopg3)。 - database_url 由 settings 指定;若为 postgresql://... 则在此拼接为 postgresql+psycopg:// 以使用 psycopg3。 事务规则: - get_async_db() 只负责创建和关闭 session,不自动 commit/rollback。 - 事务提交由 service 层显式调用 await db.commit()。 - repo 层禁止调用 commit() / rollback()。 """ from contextlib import contextmanager from typing import AsyncGenerator from sqlalchemy import create_engine, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, sessionmaker from app.core.config import settings class Base(DeclarativeBase): pass def utc_now(): """返回当前 UTC 时间(带时区信息)""" from datetime import datetime, timezone return datetime.now(timezone.utc) # ── Database URL(纯 postgresql:// 时拼接为 postgresql+psycopg://)──────────────── def ensure_psycopg_url(url: str) -> str: """若为 postgresql://... 则改为 postgresql+psycopg://...,否则原样返回。""" if url.startswith("postgresql://") and not url.startswith("postgresql+psycopg://"): return "postgresql+psycopg://" + url[len("postgresql://") :] return url def _database_url() -> str: return ensure_psycopg_url(settings.database_url) # ── Async engine & session (FastAPI) ─────────────────────────── async_engine = create_async_engine( _database_url(), echo=False, pool_size=5, max_overflow=10, pool_pre_ping=True, ) AsyncSessionLocal = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False ) async def get_async_db() -> AsyncGenerator[AsyncSession, None]: """Yield a clean async session — callers are responsible for commit/rollback.""" async with AsyncSessionLocal() as session: try: yield session finally: await session.close() # ── Sync engine & session (Celery, Alembic, scripts) ───────── sync_engine = create_engine( _database_url(), echo=False, pool_size=5, max_overflow=10, pool_pre_ping=True, ) SyncSessionLocal = sessionmaker( bind=sync_engine, autocommit=False, autoflush=False, expire_on_commit=False, ) SessionLocal = SyncSessionLocal def init_db_schema() -> None: """根据当前已注册的 Model 创建缺失的表(开发/空库时自动建表,与 api-bak 行为一致)。 不执行 DROP/ALTER,仅 create_all;生产环境仍建议用 Alembic 做可控迁移。 memory 模块依赖 pgvector,建表前先确保扩展已启用。 """ with sync_engine.connect() as conn: conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) conn.commit() Base.metadata.create_all(bind=sync_engine) @contextmanager def get_sync_db(): """Context-managed synchronous session for Celery tasks.""" db = SessionLocal() try: yield db except Exception: db.rollback() raise finally: db.close()