Files
life-echo/api/app/core/db.py

178 lines
5.8 KiB
Python
Raw Normal View History

"""
数据库引擎、会话工厂、Base 声明基类。
- 异步(FastAPI)与同步(Celery、Alembic、脚本)均使用 psycopg(psycopg3)。
- database_url 由 settings 指定;若为 postgresql://... 则在此拼接为 postgresql+psycopg:// 以使用 psycopg3。
事务规则
- get_async_db() 只负责创建和关闭 session不自动 commit/rollback
- service / Celery task 层优先使用 transactional() / transactional_sync() 管理多步写操作
- repo 层禁止调用 commit() / rollback()
transactional 语义
- transactional() / transactional_sync() 是顶层事务边界成功 exit commit 整个 session异常时 rollback 整个 session
- 不支持嵌套自身同一 session 上连续两次 transactional() = 两次独立 commitWS pipeline 分段持久化属于此模式
- 需要嵌套回滚时在已开启的事务内使用 transactional_nested() / transactional_nested_sync()基于 SQLAlchemy begin_nested() savepoint
- 选择指南单步/整段业务原子提交 transactional长生命周期 session 内局部试错可独立回滚的子步骤 transactional_nested必须在外层事务 active 期间
transactional_nested 示例(外层提交,内层失败仅回滚 savepoint)::

    async with transactional(session):
        session.add(parent_row)
        try:
            async with transactional_nested(session):
                await attempt_optional_side_effect(session)
        except RecoverableError:
            pass  # savepoint rolled back; parent_row still commits
"""
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncGenerator
from sqlalchemy import create_engine, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from app.core.config import settings
class Base(DeclarativeBase):
    """Declarative base class for this module.

    Tables registered on ``Base.metadata`` are the ones ``init_db_schema()``
    creates via ``create_all``.
    """
def utc_now():
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    # Function-scope import: the module's top-level imports do not bring in datetime.
    import datetime as _dt

    return _dt.datetime.now(tz=_dt.timezone.utc)
# ── Database URL(纯 postgresql:// 时拼接为 postgresql+psycopg://)────────────────
def ensure_psycopg_url(url: str) -> str:
    """Rewrite a driver-less PostgreSQL URL to the ``postgresql+psycopg://`` scheme.

    Args:
        url: A database URL string.

    Returns:
        ``url`` with a leading ``postgresql://`` or ``postgres://`` replaced by
        ``postgresql+psycopg://``. Any other URL is returned unchanged; this
        includes URLs that already name a driver (``postgresql+<driver>://``)
        and URLs for other databases.
    """
    psycopg_scheme = "postgresql+psycopg://"
    # A URL that starts with "postgresql://" can never also start with
    # "postgresql+psycopg://", so no extra guard is needed for the first prefix.
    # "postgres://" is the legacy alias still emitted by some hosting providers;
    # SQLAlchemy no longer accepts it as a dialect name, so normalise it too.
    for bare_scheme in ("postgresql://", "postgres://"):
        if url.startswith(bare_scheme):
            return psycopg_scheme + url[len(bare_scheme):]
    return url
def _database_url() -> str:
    """Return ``settings.database_url`` passed through ``ensure_psycopg_url``."""
    configured_url = settings.database_url
    return ensure_psycopg_url(configured_url)
# ── Async engine & session (FastAPI) ───────────────────────────
# Module-level async engine, built once at import time from the configured URL.
async_engine = create_async_engine(
    _database_url(),
    echo=False,
    pool_size=5,
    max_overflow=10,
    pool_pre_ping=True,
)
# Factory producing AsyncSession objects bound to ``async_engine``.
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a fresh async session and close it when the consumer is done.

    This generator never commits or rolls back; transaction control is the
    caller's responsibility.

    Yields:
        An ``AsyncSession`` created by ``AsyncSessionLocal``.
    """
    # The session's own async context manager closes it on exit, including when
    # an exception is thrown into the generator, so no explicit close() is needed.
    async with AsyncSessionLocal() as session:
        yield session
@asynccontextmanager
async def transactional(session: AsyncSession):
    """Top-level async transaction boundary around ``session``.

    Yields ``session`` to the ``async with`` body. When the body finishes
    without raising, the whole session is committed. If the body or the commit
    raises an ``Exception``, the session is rolled back and the exception is
    re-raised. Exceptions deriving only from ``BaseException`` are not
    intercepted by this helper.

    Do not nest ``transactional()`` on the same session; each call commits
    independently. For partial rollback within an active transaction, use
    ``transactional_nested()``.
    """
    try:
        yield session
        # Reached only when the caller's block completed without raising.
        await session.commit()
    except Exception:
        # Handles a failure in the caller's block as well as a failed commit.
        await session.rollback()
        raise
@asynccontextmanager
async def transactional_nested(session: AsyncSession):
    """Run the ``async with`` body inside ``session.begin_nested()``.

    Acts as a savepoint boundary so that an error rolls back only this block.
    Intended to be used while the session already has an active transaction
    (e.g. inside ``transactional()`` before it commits, or after autobegin from
    a prior write).

    Yields:
        The same ``session`` that was passed in.
    """
    savepoint = session.begin_nested()
    async with savepoint:
        yield session
# ── Sync engine & session (Celery, Alembic, scripts) ─────────
# Module-level synchronous engine, built once at import time from the same
# configured URL as the async engine.
sync_engine = create_engine(
    _database_url(),
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=False,
)
# Factory producing synchronous Session objects bound to ``sync_engine``.
SyncSessionLocal = sessionmaker(
    sync_engine,
    expire_on_commit=False,
    autoflush=False,
    autocommit=False,
)
# Second name for the same factory; ``get_sync_db()`` in this module uses it.
SessionLocal = SyncSessionLocal
def init_db_schema() -> None:
    """Create any missing tables for the models currently registered on ``Base``.

    Meant for development or an empty database (automatic table creation,
    described in the original note as matching the api-bak behaviour). Only
    ``create_all`` is run, with no DROP/ALTER, so production environments should
    still use Alembic for controlled migrations.

    The memory module depends on pgvector, so the ``vector`` extension is
    enabled before the tables are created.
    """
    enable_pgvector = text("CREATE EXTENSION IF NOT EXISTS vector")
    with sync_engine.connect() as conn:
        conn.execute(enable_pgvector)
        # Commit explicitly so the extension exists before create_all runs.
        conn.commit()
    Base.metadata.create_all(bind=sync_engine)
@contextmanager
def transactional_sync(session: Session):
    """Top-level synchronous transaction boundary around ``session``.

    Yields ``session`` to the ``with`` body. When the body finishes without
    raising, the whole session is committed. If the body or the commit raises
    an ``Exception``, the session is rolled back and the exception is re-raised.
    Exceptions deriving only from ``BaseException`` are not intercepted by this
    helper.

    Do not nest ``transactional_sync()`` on the same session; each call commits
    independently. For partial rollback within an active transaction, use
    ``transactional_nested_sync()``.
    """
    try:
        yield session
        # Reached only when the caller's block completed without raising.
        session.commit()
    except Exception:
        # Handles a failure in the caller's block as well as a failed commit.
        session.rollback()
        raise
@contextmanager
def transactional_nested_sync(session: Session):
    """Run the ``with`` body inside ``session.begin_nested()``.

    Synchronous savepoint boundary for Celery tasks and scripts, so that an
    error rolls back only this block.

    Yields:
        The same ``session`` that was passed in.
    """
    savepoint = session.begin_nested()
    with savepoint:
        yield session
@contextmanager
def get_sync_db():
    """Provide a synchronous session as a context manager (for Celery tasks).

    A session is created from ``SessionLocal`` and yielded to the ``with`` body.
    If the body raises an ``Exception``, the session is rolled back and the
    exception is re-raised. The session is closed in every case. Nothing is
    committed here.
    """
    session = SessionLocal()
    try:
        yield session
    except Exception:
        session.rollback()
        raise
    finally:
        # Runs on success and on failure alike.
        session.close()