2026-03-18 17:18:23 +08:00
|
|
|
|
"""
|
|
|
|
|
|
数据库引擎、会话工厂、Base 声明基类。
|
|
|
|
|
|
|
|
|
|
|
|
- 异步(FastAPI)与同步(Celery、Alembic、脚本)均使用 psycopg(psycopg3)。
|
|
|
|
|
|
- database_url 由 settings 指定;若为 postgresql://... 则在此拼接为 postgresql+psycopg:// 以使用 psycopg3。
|
|
|
|
|
|
|
|
|
|
|
|
事务规则:
|
|
|
|
|
|
- get_async_db() 只负责创建和关闭 session,不自动 commit/rollback。
|
|
|
|
|
|
- 事务提交由 service 层显式调用 await db.commit()。
|
|
|
|
|
|
- repo 层禁止调用 commit() / rollback()。
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
|
from typing import AsyncGenerator
|
|
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import create_engine, text
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
|
|
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
|
|
|
|
|
|
|
|
|
|
|
from app.core.config import settings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model in the application.

    Models must subclass this so they register on ``Base.metadata``, which
    ``init_db_schema()`` / Alembic use to discover tables.
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def utc_now():
    """Return the current time as a timezone-aware datetime in UTC."""
    import datetime as _dt

    utc = _dt.timezone.utc
    return _dt.datetime.now(utc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── Database URL(纯 postgresql:// 时拼接为 postgresql+psycopg://)────────────────
|
|
|
|
|
|
|
2026-03-19 14:36:14 +08:00
|
|
|
|
|
2026-03-18 17:18:23 +08:00
|
|
|
|
def ensure_psycopg_url(url: str) -> str:
    """Normalise a bare PostgreSQL URL so it uses the psycopg (v3) driver.

    ``postgresql://...`` is rewritten to ``postgresql+psycopg://...``. The
    legacy ``postgres://...`` alias is rewritten the same way; several hosting
    providers emit it, and SQLAlchemy 1.4+ no longer accepts it.

    URLs that already name a driver (e.g. ``postgresql+psycopg://`` or
    ``postgresql+asyncpg://``) and URLs for other databases are returned
    unchanged.

    Args:
        url: Database URL as configured.

    Returns:
        The URL with a psycopg3 driver prefix when it was a bare PostgreSQL
        URL, otherwise ``url`` itself.
    """
    # A URL that starts with "postgresql+..." can never match these bare
    # prefixes, so no separate "already has a driver" check is needed.
    for bare_prefix in ("postgresql://", "postgres://"):
        if url.startswith(bare_prefix):
            return "postgresql+psycopg://" + url[len(bare_prefix):]
    return url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _database_url() -> str:
    """Return ``settings.database_url`` passed through ``ensure_psycopg_url``."""
    configured = settings.database_url
    return ensure_psycopg_url(configured)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── Async engine & session (FastAPI) ───────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
# Engine backing the FastAPI (asyncio) side of the app.
# pool_pre_ping=True checks each pooled connection on checkout, so a
# connection the server dropped is replaced instead of failing the request.
async_engine = create_async_engine(
    _database_url(),
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=False,
)

# Session factory for request-scoped sessions. expire_on_commit=False keeps
# loaded attributes readable after commit instead of forcing a refresh.
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a fresh AsyncSession; the caller owns commit/rollback.

    This dependency never commits or rolls back on the caller's behalf. When
    the consumer is done (normally or via an exception), leaving the
    ``async with`` block closes the session and returns its connection to
    the pool.

    Yields:
        A new ``AsyncSession`` from ``AsyncSessionLocal``.
    """
    # AsyncSession's async context manager already awaits close() on exit,
    # so a nested try/finally with an explicit close() would close twice.
    async with AsyncSessionLocal() as session:
        yield session
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ── Sync engine & session (Celery, Alembic, scripts) ─────────
|
|
|
|
|
|
|
|
|
|
|
|
# Engine for synchronous consumers (Celery, Alembic, scripts). Pool settings
# mirror the async engine; pool_pre_ping=True validates connections on checkout.
sync_engine = create_engine(
    _database_url(),
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=False,
)

# Synchronous session factory. Autoflush is off, so pending changes are only
# flushed explicitly or on commit. expire_on_commit=False keeps attributes
# readable after commit.
SyncSessionLocal = sessionmaker(
    bind=sync_engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)

# Alias for callers that use the name SessionLocal (get_sync_db below does).
SessionLocal = SyncSessionLocal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_db_schema() -> None:
    """Create any tables missing for the currently registered models.

    Meant for development or an empty database (matches the api-bak
    behaviour). Only ``create_all`` is issued, with no DROP or ALTER, so
    production schema changes should still go through Alembic migrations.

    The memory module depends on pgvector, so the ``vector`` extension is
    enabled before any table is created.
    """
    # engine.begin() commits when the block exits without error; this is the
    # same effect as connect() followed by an explicit commit().
    with sync_engine.begin() as connection:
        connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))

    Base.metadata.create_all(bind=sync_engine)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def get_sync_db():
    """Provide a synchronous Session as a ``with`` block (used by Celery tasks).

    The session is never committed here; callers commit explicitly. If the
    body raises an ``Exception``, the session is rolled back and the
    exception is re-raised. The session is always closed on exit.
    """
    session = SessionLocal()
    try:
        try:
            yield session
        except Exception:
            session.rollback()
            raise
    finally:
        session.close()
|