Files
life-echo/api/app/core/db.py

114 lines
3.2 KiB
Python
Raw Normal View History

"""
数据库引擎会话工厂Base 声明基类
- 异步FastAPI与同步CeleryAlembic脚本均使用 psycopgpsycopg3
- database_url settings 指定若为 postgresql://... 则在此拼接为 postgresql+psycopg:// 以使用 psycopg3
事务规则
- get_async_db() 只负责创建和关闭 session不自动 commit/rollback
- 事务提交由 service 层显式调用 await db.commit()
- repo 层禁止调用 commit() / rollback()
"""
from contextlib import contextmanager
from typing import AsyncGenerator
from sqlalchemy import create_engine, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from app.core.config import settings
class Base(DeclarativeBase):
    """Declarative base for the ORM models; their tables are collected in ``Base.metadata``."""
def utc_now():
    """Return the current UTC time as a timezone-aware ``datetime``."""
    import datetime as dt

    aware_now = dt.datetime.now(tz=dt.timezone.utc)
    return aware_now
# ── Database URL (a plain postgresql:// URL is rewritten to postgresql+psycopg://) ──
def ensure_psycopg_url(url: str) -> str:
    """Return *url* rewritten to use SQLAlchemy's psycopg (v3) driver.

    A plain ``postgresql://...`` URL -- or its libpq alias ``postgres://...``,
    which SQLAlchemy no longer accepts as a dialect name -- has its scheme
    replaced by ``postgresql+psycopg://``. Any other URL, including one that
    already names a driver (``postgresql+psycopg://``, ``postgresql+asyncpg://``,
    ...), is returned unchanged.

    Args:
        url: Database URL string.

    Returns:
        The URL with a ``postgresql+psycopg://`` scheme when rewritten,
        otherwise *url* as-is.
    """
    # A URL that starts with "postgresql://" can never also start with
    # "postgresql+psycopg://", so a plain prefix check is sufficient.
    for prefix in ("postgresql://", "postgres://"):
        if url.startswith(prefix):
            return "postgresql+psycopg://" + url[len(prefix):]
    return url
def _database_url() -> str:
    """Return ``settings.database_url`` normalized for the psycopg3 driver."""
    raw_url = settings.database_url
    return ensure_psycopg_url(raw_url)
# ── Async engine & session (FastAPI) ───────────────────────────
# Pool: 5 persistent connections plus up to 10 overflow; pool_pre_ping checks
# each connection on checkout so stale ones are replaced.
async_engine = create_async_engine(
    _database_url(),
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=False,
)

# expire_on_commit=False: ORM attributes stay loaded after commit instead of
# being expired and lazily reloaded.
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a clean async session — callers are responsible for commit/rollback.

    The session comes from ``AsyncSessionLocal`` and is closed when the
    generator finishes. No commit or rollback is issued here; the service
    layer commits explicitly.

    Yields:
        A new ``AsyncSession``.
    """
    # ``async with`` already closes the session on exit, including when an
    # exception propagates, so an extra try/finally + close() would only
    # close it a second time.
    async with AsyncSessionLocal() as session:
        yield session
# ── Sync engine & session (Celery, Alembic, scripts) ─────────
# Same pool settings as the async engine.
sync_engine = create_engine(
    _database_url(),
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=False,
)

# autoflush=False: pending changes are flushed only on explicit flush()/commit();
# expire_on_commit=False: loaded attributes survive commit.
SyncSessionLocal = sessionmaker(
    bind=sync_engine,
    expire_on_commit=False,
    autoflush=False,
    autocommit=False,
)

# Backward-compatible alias for code that imports ``SessionLocal``.
SessionLocal = SyncSessionLocal
def init_db_schema() -> None:
    """Create any missing tables for the models currently registered on ``Base``.

    Meant for development / empty databases (matches the api-bak behavior).
    Only ``create_all`` is run -- no DROP/ALTER -- so production should still
    use Alembic for controlled migrations.

    The memory module depends on pgvector, so the ``vector`` extension is
    enabled before any table is created.
    """
    # engine.begin() commits the extension DDL when the block exits cleanly.
    with sync_engine.begin() as connection:
        connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    Base.metadata.create_all(bind=sync_engine)
@contextmanager
def get_sync_db():
    """Provide a synchronous session for Celery tasks as a context manager.

    If the managed block raises an ``Exception``, the session is rolled back
    and the exception is re-raised. The session is always closed on exit.
    Committing is left to the caller.
    """
    # Session's own context manager closes it when the ``with`` exits.
    with SessionLocal() as session:
        try:
            yield session
        except Exception:
            session.rollback()
            raise