114 lines
3.2 KiB
Python
114 lines
3.2 KiB
Python
"""
|
||
数据库引擎、会话工厂、Base 声明基类。
|
||
|
||
- 异步(FastAPI)与同步(Celery、Alembic、脚本)均使用 psycopg(psycopg3)。
|
||
- database_url 由 settings 指定;若为 postgresql://... 则在此拼接为 postgresql+psycopg:// 以使用 psycopg3。
|
||
|
||
事务规则:
|
||
- get_async_db() 只负责创建和关闭 session,不自动 commit/rollback。
|
||
- 事务提交由 service 层显式调用 await db.commit()。
|
||
- repo 层禁止调用 commit() / rollback()。
|
||
"""
|
||
|
||
from contextlib import contextmanager
|
||
from typing import AsyncGenerator
|
||
|
||
from sqlalchemy import create_engine, text
|
||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||
|
||
from app.core.config import settings
|
||
|
||
|
||
class Base(DeclarativeBase):
    """Declarative base class; models that subclass it register on ``Base.metadata``."""
|
||
|
||
|
||
def utc_now():
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    # Local import keeps this helper self-contained.
    import datetime as _dt

    return _dt.datetime.now(tz=_dt.timezone.utc)
|
||
|
||
|
||
# ── Database URL(纯 postgresql:// 时拼接为 postgresql+psycopg://)────────────────
|
||
|
||
|
||
def ensure_psycopg_url(url: str) -> str:
    """Rewrite a driverless ``postgresql://`` URL to use the psycopg driver.

    Args:
        url: Database URL as configured.

    Returns:
        ``postgresql+psycopg://...`` when ``url`` starts with the bare
        ``postgresql://`` scheme; any other value (including URLs that
        already name a driver) is returned unchanged.
    """
    bare_scheme = "postgresql://"
    # A URL carrying a "+driver" suffix cannot start with the bare scheme,
    # so this single prefix test is sufficient.
    if not url.startswith(bare_scheme):
        return url
    remainder = url[len(bare_scheme):]
    return "postgresql+psycopg://" + remainder
|
||
|
||
|
||
def _database_url() -> str:
    """Return ``settings.database_url`` passed through ``ensure_psycopg_url``."""
    configured_url = settings.database_url
    return ensure_psycopg_url(configured_url)
|
||
|
||
|
||
# ── Async engine & session (FastAPI) ───────────────────────────
|
||
|
||
# Async engine; its URL is resolved once, at import time, by _database_url().
async_engine = create_async_engine(
    _database_url(),
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=False,
)

# Factory producing AsyncSession objects bound to async_engine.
# expire_on_commit=False: instances are not expired when the session commits.
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
||
|
||
|
||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a fresh ``AsyncSession`` and close it once the caller is done.

    This generator never calls ``commit()`` or ``rollback()`` itself; the
    caller is responsible for transaction control.

    Yields:
        A session created by ``AsyncSessionLocal``.
    """
    # ``async with`` already closes the session on exit, whether the caller
    # finished normally or raised, so a nested try/finally with a second
    # explicit ``close()`` is unnecessary.
    async with AsyncSessionLocal() as session:
        yield session
|
||
|
||
|
||
# ── Sync engine & session (Celery, Alembic, scripts) ─────────
|
||
|
||
# Synchronous engine; its URL is resolved once, at import time, by _database_url().
sync_engine = create_engine(
    _database_url(),
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=False,
)

# Factory producing synchronous sessions bound to sync_engine.
SyncSessionLocal = sessionmaker(
    bind=sync_engine,
    expire_on_commit=False,
    autoflush=False,
    autocommit=False,
)

# Alias for the synchronous session factory.
SessionLocal = SyncSessionLocal
|
||
|
||
|
||
def init_db_schema() -> None:
    """Create any tables that are missing for the currently registered models.

    Meant for development / empty databases (same behaviour as api-bak).
    Only ``create_all`` is run; nothing is dropped or altered. Production
    environments should still use Alembic for controlled migrations.

    The memory module depends on pgvector, so the ``vector`` extension is
    enabled before the tables are created.
    """
    enable_pgvector = text("CREATE EXTENSION IF NOT EXISTS vector")
    with sync_engine.connect() as connection:
        connection.execute(enable_pgvector)
        # Commit explicitly so the extension exists before create_all runs.
        connection.commit()

    Base.metadata.create_all(bind=sync_engine)
|
||
|
||
|
||
@contextmanager
def get_sync_db():
    """Provide a synchronous session as a context manager (for Celery tasks).

    Nothing is committed here; the code inside the ``with`` body decides
    when to commit.

    Yields:
        A session created by ``SessionLocal``.

    Raises:
        Exception: any ``Exception`` raised inside the ``with`` body is
            re-raised after the session has been rolled back.
    """
    session = SessionLocal()
    try:
        yield session
    except Exception:
        # Roll back, then let the error propagate to the caller.
        session.rollback()
        raise
    finally:
        # Runs on every exit path, including the re-raise above.
        session.close()
|