Files
life-echo/api/app/core/db.py
2026-03-20 15:15:35 +08:00

114 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
数据库引擎、会话工厂、Base 声明基类。
- 异步 (FastAPI) 与同步 (Celery、Alembic、脚本) 均使用 psycopg (psycopg3)。
- database_url 由 settings 指定;若为 postgresql://... 则在此拼接为 postgresql+psycopg:// 以使用 psycopg3。
事务规则:
- get_async_db() 只负责创建和关闭 session, 不自动 commit/rollback。
- 事务提交由 service 层显式调用 await db.commit()。
- repo 层禁止调用 commit() / rollback()。
"""
from contextlib import contextmanager
from typing import AsyncGenerator
from sqlalchemy import create_engine, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from app.core.config import settings
class Base(DeclarativeBase):
    """Declarative base class for the ORM models.

    Tables registered on ``Base.metadata`` are the ones ``init_db_schema``
    creates via ``create_all``.
    """
def utc_now():
    """Return the current moment as a timezone-aware ``datetime`` in UTC."""
    import datetime as _dt

    utc = _dt.timezone.utc
    return _dt.datetime.now(tz=utc)
# ── Database URL (纯 postgresql:// 时拼接为 postgresql+psycopg://) ────────────────
def ensure_psycopg_url(url: str) -> str:
    """Give a driver-less PostgreSQL URL an explicit psycopg (v3) driver.

    ``postgresql://...`` is rewritten to ``postgresql+psycopg://...``. The
    legacy alias ``postgres://...`` is rewritten the same way, because
    SQLAlchemy 1.4+ no longer accepts ``postgres`` as a dialect name. Any other
    URL is returned unchanged. That includes URLs that already name a driver,
    such as ``postgresql+psycopg://`` or ``postgresql+asyncpg://``.

    Args:
        url: Database URL, e.g. ``settings.database_url``.

    Returns:
        The URL with ``+psycopg`` added when no driver was specified.
    """
    # A string that starts with "postgresql://" can never also start with
    # "postgresql+psycopg://", so a single prefix test per scheme suffices.
    for plain_scheme in ("postgresql://", "postgres://"):
        if url.startswith(plain_scheme):
            return "postgresql+psycopg://" + url[len(plain_scheme):]
    return url
def _database_url() -> str:
    """Return ``settings.database_url`` normalised by ``ensure_psycopg_url``."""
    configured = settings.database_url
    return ensure_psycopg_url(configured)
# ── Async engine & session (FastAPI) ───────────────────────────
# Built once at import time. pool_pre_ping checks each pooled connection on
# checkout so that stale connections are replaced instead of failing a query.
async_engine = create_async_engine(
    _database_url(),
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=False,
)
# expire_on_commit=False keeps loaded attributes readable after commit,
# avoiding an implicit reload (which would need I/O) on attribute access.
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a fresh ``AsyncSession``; callers own commit/rollback.

    Used as the FastAPI request-scoped session provider. This generator never
    commits. The service layer must call ``await db.commit()`` (or
    ``rollback()``) explicitly. When the generator finishes, normally or via
    an exception, the session is closed. Closing releases its connection and
    discards any transaction still in progress.

    Yields:
        AsyncSession: a new session from ``AsyncSessionLocal``.
    """
    # ``async with`` already awaits ``session.close()`` on exit (and, in
    # SQLAlchemy 2.x, shields that close from task cancellation). The former
    # extra try/finally simply closed the session a second time.
    async with AsyncSessionLocal() as session:
        yield session
# ── Sync engine & session (Celery, Alembic, scripts) ─────────
# Blocking engine for code that does not run on the event loop. It uses the
# same normalised URL and pool sizing as the async engine.
sync_engine = create_engine(
    _database_url(),
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=False,
)
SyncSessionLocal = sessionmaker(
    bind=sync_engine,
    expire_on_commit=False,
    autoflush=False,
    autocommit=False,
)
# Shorter alias for the same factory; get_sync_db() builds sessions from it.
SessionLocal = SyncSessionLocal
def init_db_schema() -> None:
    """Create any missing tables for the models currently registered on ``Base``.

    Meant for development / empty databases (same behavior as api-bak). Only
    ``create_all`` runs, so nothing is dropped or altered. Production schema
    changes should still go through Alembic migrations.

    The memory module depends on pgvector, so the ``vector`` extension is
    enabled (if not already present) before any table is created.
    """
    # engine.begin() commits when the block exits cleanly, which matches the
    # explicit connect()/commit() pair. The extension is committed before
    # create_all runs on its own connection.
    with sync_engine.begin() as connection:
        connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    Base.metadata.create_all(bind=sync_engine)
@contextmanager
def get_sync_db():
    """Provide a synchronous ``Session`` for Celery tasks and other blocking code.

    If the ``with`` body raises, the session is rolled back and the exception
    propagates. In every case the session is closed afterwards. Committing is
    the caller's job.
    """
    # Session's own context manager calls close() on exit.
    with SessionLocal() as session:
        try:
            yield session
        except Exception:
            session.rollback()
            raise