2026-03-18 17:18:23 +08:00
|
|
|
|
"""
|
|
|
|
|
|
数据库引擎、会话工厂、Base 声明基类。

- 异步(FastAPI)与同步(Celery、Alembic、脚本)均使用 psycopg(psycopg3)。
- database_url 由 settings 指定;若为 postgresql://... 则在此改写为 postgresql+psycopg://...,以使用 psycopg3 驱动。

事务规则:

- get_async_db() 只负责创建和关闭 session,不自动 commit/rollback。
- service / Celery task 层优先使用 transactional() / transactional_sync() 管理多步写操作。
- repo 层禁止调用 commit() / rollback()。

transactional 语义:

- transactional() / transactional_sync() 是顶层事务边界:成功退出时 commit 整个 session,发生异常时 rollback 整个 session。
- 不支持嵌套自身:若在一个 transactional() 内部再调用 transactional(),内层退出时的 commit 会提前提交外层尚未完成的全部改动。同一 session 上先后(顺序)调用两次 transactional() 则是两次独立的 commit(WS pipeline 分段持久化属于此模式)。
- 需要嵌套回滚时:在已开启的事务内使用 transactional_nested() / transactional_nested_sync()(基于 SQLAlchemy begin_nested() 的 savepoint)。
- 选择指南:单步 / 整段业务需要原子提交 → transactional;长生命周期 session 内可局部试错、可独立回滚的子步骤 → transactional_nested(必须在外层事务 active 期间使用)。

transactional_nested 示例(外层提交、内层失败仅回滚 savepoint)::

    async with transactional(session):
        session.add(parent_row)
        try:
            async with transactional_nested(session):
                await attempt_optional_side_effect(session)
        except RecoverableError:
            pass  # savepoint rolled back; parent_row still commits
"""
|
|
|
|
|
|
|
2026-05-22 13:44:50 +08:00
|
|
|
|
from contextlib import asynccontextmanager, contextmanager
|
2026-03-18 17:18:23 +08:00
|
|
|
|
from typing import AsyncGenerator
|
|
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import create_engine, text
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
2026-05-22 13:44:50 +08:00
|
|
|
|
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
2026-03-18 17:18:23 +08:00
|
|
|
|
|
|
|
|
|
|
from app.core.config import settings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base class for the application's ORM models.

    Models that subclass it are registered on ``Base.metadata``.
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def utc_now():
    """Return the current moment as a timezone-aware ``datetime`` in UTC."""
    import datetime as _dt

    return _dt.datetime.now(tz=_dt.timezone.utc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── Database URL(纯 postgresql:// 时拼接为 postgresql+psycopg://)────────────────
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
def ensure_psycopg_url(url: str) -> str:
    """Rewrite a plain PostgreSQL URL so SQLAlchemy uses the psycopg (v3) driver.

    ``postgresql://...`` and its legacy alias ``postgres://...`` (still emitted by
    some hosting providers, but rejected by SQLAlchemy >= 1.4) are rewritten to
    ``postgresql+psycopg://...``. A URL that already names a driver
    (``postgresql+psycopg://``, ``postgresql+asyncpg://``, ...) or targets another
    database is returned unchanged.

    Args:
        url: Database URL, e.g. ``settings.database_url``.

    Returns:
        The URL with the psycopg3 driver selected where applicable.
    """
    for plain_scheme in ("postgresql://", "postgres://"):
        # A "postgresql+<driver>://" URL can never start with either prefix, so an
        # explicitly chosen driver is always left alone.
        if url.startswith(plain_scheme):
            return "postgresql+psycopg://" + url[len(plain_scheme):]
    return url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _database_url() -> str:
    """Return ``settings.database_url`` normalised to the psycopg3 driver scheme."""
    configured_url = settings.database_url
    return ensure_psycopg_url(configured_url)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── Async engine & session (FastAPI) ───────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
# Async engine for the FastAPI side; the URL is normalised to psycopg3.
async_engine = create_async_engine(
    _database_url(),
    pool_pre_ping=True,  # test pooled connections on checkout and replace dead ones
    pool_size=5,  # connections kept open in the pool
    max_overflow=10,  # extra connections allowed beyond pool_size under load
    echo=False,
)
|
|
|
|
|
|
|
|
|
|
|
|
# Factory for AsyncSession objects bound to async_engine.
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    # Keep loaded attributes usable after commit instead of expiring them, which
    # would otherwise require a refresh (I/O) on next attribute access.
    expire_on_commit=False,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a clean async session — callers are responsible for commit/rollback.

    Async session provider for the FastAPI side. The session is neither committed
    nor rolled back here; it is only closed once the caller is done with it.

    Yields:
        A fresh ``AsyncSession`` from ``AsyncSessionLocal``.
    """
    async with AsyncSessionLocal() as session:
        # Leaving the ``async with`` block calls AsyncSession.close(), which ends any
        # in-progress transaction and releases the connection; an extra explicit
        # close() in a ``finally`` here would be redundant.
        yield session
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-05-22 13:44:50 +08:00
|
|
|
|
@asynccontextmanager
async def transactional(session: AsyncSession):
    """Top-level async transaction: commit on success, roll back on any exception.

    Do not nest transactional() on the same session; each call commits independently.
    For partial rollback within an active transaction, use transactional_nested().

    The rollback handler catches ``BaseException`` rather than ``Exception`` so that
    task cancellation (``asyncio.CancelledError``) and other non-``Exception`` exits
    also discard the block's pending writes. Otherwise, on a long-lived session, a
    later transactional() call would commit the half-finished work of a cancelled
    block. A failure raised by ``commit()`` itself is likewise followed by a rollback.

    Args:
        session: The session whose transaction this block owns.

    Yields:
        The same ``session``, for convenience.
    """
    try:
        yield session
        await session.commit()
    except BaseException:
        await session.rollback()
        raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def transactional_nested(session: AsyncSession):
    """SAVEPOINT-scoped block: an error here rolls back only this block's work.

    Wraps ``session.begin_nested()``. If the block raises, the savepoint is rolled
    back and the exception propagates to the caller; otherwise the savepoint is
    released when the block exits.

    Intended to run while the session already has an active transaction, e.g.
    inside transactional() before it commits, or after autobegin from an earlier
    write.
    """
    savepoint = session.begin_nested()
    async with savepoint:
        yield session
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
# ── Sync engine & session (Celery, Alembic, scripts) ─────────
|
|
|
|
|
|
|
|
|
|
|
|
# Synchronous engine for Celery, Alembic and scripts; same URL and pool settings
# as the async engine.
sync_engine = create_engine(
    _database_url(),
    pool_pre_ping=True,  # test pooled connections on checkout and replace dead ones
    pool_size=5,  # connections kept open in the pool
    max_overflow=10,  # extra connections allowed beyond pool_size under load
    echo=False,
)
|
|
|
|
|
|
|
|
|
|
|
|
# Factory for synchronous Session objects bound to sync_engine.
SyncSessionLocal = sessionmaker(
    bind=sync_engine,
    expire_on_commit=False,  # keep loaded attributes usable after commit
    autoflush=False,  # flush only on explicit flush()/commit()
    autocommit=False,
)
|
|
|
|
|
|
|
|
|
|
|
|
# Alias for SyncSessionLocal; get_sync_db() creates its sessions through this name.
# NOTE(review): presumably kept so existing `SessionLocal` imports keep working — verify before removing.
SessionLocal = SyncSessionLocal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_db_schema() -> None:
    """Create any tables missing for the currently registered models.

    Auto-creates tables on a development or empty database, matching the api-bak
    behaviour. Only ``create_all`` is issued (no DROP/ALTER), so production schema
    changes should still go through Alembic migrations.

    The pgvector ``vector`` extension is enabled first because the memory module's
    models depend on it.
    """
    # engine.begin() commits when the block exits cleanly, replacing the manual
    # connect() + commit() pair.
    with sync_engine.begin() as connection:
        connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    Base.metadata.create_all(bind=sync_engine)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-05-22 13:44:50 +08:00
|
|
|
|
@contextmanager
def transactional_sync(session: Session):
    """Top-level sync transaction: commit on success, roll back on any exception.

    Do not nest transactional_sync() on the same session; each call commits independently.
    For partial rollback within an active transaction, use transactional_nested_sync().

    ``BaseException`` (not just ``Exception``) triggers the rollback. That way a
    ``KeyboardInterrupt``, ``SystemExit`` or other non-``Exception`` abort inside the
    block does not leave its pending writes in the session for a later commit to
    pick up. A failing ``commit()`` is also followed by a rollback.

    Args:
        session: The session whose transaction this block owns.

    Yields:
        The same ``session``, for convenience.
    """
    try:
        yield session
        session.commit()
    except BaseException:
        session.rollback()
        raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def transactional_nested_sync(session: Session):
    """SAVEPOINT-scoped block for sync Celery tasks and scripts.

    Wraps ``session.begin_nested()``. If the block raises, only the work done since
    the savepoint is rolled back and the exception propagates; otherwise the
    savepoint is released when the block exits.
    """
    savepoint = session.begin_nested()
    with savepoint:
        yield session
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
@contextmanager
def get_sync_db():
    """Provide a synchronous session for Celery tasks and always close it.

    If the ``with`` body raises an ``Exception``, the session is rolled back before
    the exception is re-raised. The session is closed on every exit path.

    Yields:
        A new session created by ``SessionLocal``.
    """
    db_session = SessionLocal()
    try:
        yield db_session
    except Exception:
        # Discard the failed unit of work before propagating the error.
        db_session.rollback()
        raise
    finally:
        db_session.close()
|