""" 数据库引擎、会话工厂、Base 声明基类。 - 异步(FastAPI)与同步(Celery、Alembic、脚本)均使用 psycopg(psycopg3)。 - database_url 由 settings 指定;若为 postgresql://... 则在此拼接为 postgresql+psycopg:// 以使用 psycopg3。 事务规则: - get_async_db() 只负责创建和关闭 session,不自动 commit/rollback。 - service / Celery task 层优先使用 transactional() / transactional_sync() 管理多步写操作。 - repo 层禁止调用 commit() / rollback()。 transactional 语义: - transactional() / transactional_sync() 是顶层事务边界;成功 exit 时 commit 整个 session,异常时 rollback 整个 session。 - 不支持嵌套自身:同一 session 上连续两次 transactional() = 两次独立 commit(WS pipeline 分段持久化属于此模式)。 - 需要嵌套回滚时:在已开启的事务内使用 transactional_nested() / transactional_nested_sync()(基于 SQLAlchemy begin_nested() savepoint)。 - 选择指南:单步/整段业务原子提交 → transactional;长生命周期 session 内局部试错、可独立回滚的子步骤 → transactional_nested(必须在外层事务 active 期间)。 transactional_nested 示例(外层提交、内层失败仅回滚 savepoint):: async with transactional(session): session.add(parent_row) try: async with transactional_nested(session): await attempt_optional_side_effect(session) except RecoverableError: pass # savepoint rolled back; parent_row still commits """ from contextlib import asynccontextmanager, contextmanager from typing import AsyncGenerator from sqlalchemy import create_engine, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker from app.core.config import settings class Base(DeclarativeBase): pass def utc_now(): """返回当前 UTC 时间(带时区信息)""" from datetime import datetime, timezone return datetime.now(timezone.utc) # ── Database URL(纯 postgresql:// 时拼接为 postgresql+psycopg://)──────────────── def ensure_psycopg_url(url: str) -> str: """若为 postgresql://... 则改为 postgresql+psycopg://...,否则原样返回。""" if url.startswith("postgresql://") and not url.startswith("postgresql+psycopg://"): return "postgresql+psycopg://" + url[len("postgresql://") :] return url def _database_url() -> str: return ensure_psycopg_url(settings.database_url) # ── Async engine & session (FastAPI) ─────────────────────────── async_engine = create_async_engine( _database_url(), echo=False, pool_size=5, max_overflow=10, pool_pre_ping=True, ) AsyncSessionLocal = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False ) async def get_async_db() -> AsyncGenerator[AsyncSession, None]: """Yield a clean async session — callers are responsible for commit/rollback.""" async with AsyncSessionLocal() as session: try: yield session finally: await session.close() @asynccontextmanager async def transactional(session: AsyncSession): """Top-level async transaction: commit on success, rollback on any exception. Do not nest transactional() on the same session; each call commits independently. For partial rollback within an active transaction, use transactional_nested(). """ try: yield session await session.commit() except Exception: await session.rollback() raise @asynccontextmanager async def transactional_nested(session: AsyncSession): """Savepoint boundary; roll back only this block on error. Must be used while the session already has an active transaction (e.g. inside transactional() before it commits, or after autobegin from a prior write). """ async with session.begin_nested(): yield session # ── Sync engine & session (Celery, Alembic, scripts) ───────── sync_engine = create_engine( _database_url(), echo=False, pool_size=5, max_overflow=10, pool_pre_ping=True, ) SyncSessionLocal = sessionmaker( bind=sync_engine, autocommit=False, autoflush=False, expire_on_commit=False, ) SessionLocal = SyncSessionLocal def init_db_schema() -> None: """根据当前已注册的 Model 创建缺失的表(开发/空库时自动建表,与 api-bak 行为一致)。 不执行 DROP/ALTER,仅 create_all;生产环境仍建议用 Alembic 做可控迁移。 memory 模块依赖 pgvector,建表前先确保扩展已启用。 """ with sync_engine.connect() as conn: conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) conn.commit() Base.metadata.create_all(bind=sync_engine) @contextmanager def transactional_sync(session: Session): """Top-level sync transaction: commit on success, rollback on any exception. Do not nest transactional_sync() on the same session; each call commits independently. For partial rollback within an active transaction, use transactional_nested_sync(). """ try: yield session session.commit() except Exception: session.rollback() raise @contextmanager def transactional_nested_sync(session: Session): """Savepoint boundary for sync Celery / scripts; roll back only this block on error.""" with session.begin_nested(): yield session @contextmanager def get_sync_db(): """Context-managed synchronous session for Celery tasks.""" db = SessionLocal() try: yield db except Exception: db.rollback() raise finally: db.close()