"""Database connection and initialisation.

Only PostgreSQL is supported.
"""

import logging
import os
from collections.abc import AsyncIterator, Iterator

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session, sessionmaker

from .models import Base

logger = logging.getLogger(__name__)

# Fallback connection URL (sync form), used when DATABASE_URL is unset or
# has an unrecognised scheme.
_DEFAULT_DATABASE_URL = "postgresql://postgres:postgres@localhost:5432/life_echo"

# Canonical scheme prefixes produced by parse_database_url.
_SYNC_SCHEME = "postgresql://"
_ASYNC_SCHEME = "postgresql+asyncpg://"

# Scheme prefixes accepted from DATABASE_URL. ``postgres://`` is a legacy
# alias that many hosting providers still emit; SQLAlchemy 1.4+ no longer
# accepts it, so it is normalised here rather than silently replaced by the
# localhost default.
_ACCEPTED_SCHEMES = (
    "postgresql+asyncpg://",
    "postgresql+psycopg://",
    "postgresql://",
    "postgres://",
)

# Read the database URL from the environment.
raw_database_url = os.getenv("DATABASE_URL", _DEFAULT_DATABASE_URL)


def _redact_credentials(url: str) -> str:
    """Return *url* with its user-info part (``user:password@``) masked.

    Everything between ``://`` (or the start of the string when there is no
    scheme separator) and the last ``@`` is replaced by ``***``. This lets a
    malformed URL be reported without writing the password to logs.
    """
    scheme, sep, rest = url.partition("://")
    if not sep:
        # partition() found no separator; treat the whole string as the rest.
        scheme, rest = "", url
    _, at, host_part = rest.rpartition("@")
    return f"{scheme}{sep}{'***@' if at else ''}{host_part}"


def parse_database_url(url: str) -> tuple[str, str]:
    """Split a PostgreSQL URL into its sync and async driver variants.

    Accepted formats:
    - PostgreSQL: postgresql://user:pass@host:port/db
    - PostgreSQL async: postgresql+asyncpg://user:pass@host:port/db
    - PostgreSQL psycopg: postgresql+psycopg://user:pass@host:port/db
    - Legacy alias: postgres://user:pass@host:port/db

    Only the scheme prefix is rewritten. The rest of the URL (credentials,
    host, database, query string) is carried over unchanged.

    Args:
        url: Raw database URL, typically from the DATABASE_URL environment
            variable. Surrounding whitespace is ignored.

    Returns:
        ``(sync_url, async_url)``. ``sync_url`` uses the ``postgresql://``
        scheme and ``async_url`` uses ``postgresql+asyncpg://``. If the scheme
        is not recognised, a warning is logged (with the password masked) and
        the local default URLs are returned.
    """
    cleaned = url.strip()
    for scheme in _ACCEPTED_SCHEMES:
        if cleaned.startswith(scheme):
            remainder = cleaned[len(scheme):]
            return _SYNC_SCHEME + remainder, _ASYNC_SCHEME + remainder

    # Never echo the raw URL: it may contain the database password.
    logger.warning(
        "警告: DATABASE_URL 格式不正确 (%s),使用默认 PostgreSQL",
        _redact_credentials(url),
    )
    default_remainder = _DEFAULT_DATABASE_URL[len(_SYNC_SCHEME):]
    return _DEFAULT_DATABASE_URL, _ASYNC_SCHEME + default_remainder


DATABASE_URL, ASYNC_DATABASE_URL = parse_database_url(raw_database_url)

# Sync engine (migrations, Celery tasks, etc.), using the psycopg driver.
# Only the scheme prefix is swapped, so matching text elsewhere in the URL
# (e.g. inside a password) is left untouched.
sync_url = "postgresql+psycopg://" + DATABASE_URL.removeprefix(_SYNC_SCHEME)
engine = create_engine(sync_url, pool_size=5, max_overflow=10)

# Async engine (used by FastAPI).
async_engine = create_async_engine(
    ASYNC_DATABASE_URL,
    echo=False,
    pool_size=5,
    max_overflow=10,
)

# Session factories.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)


def init_db() -> None:
    """Create every table registered on ``Base.metadata`` using the sync engine."""
    Base.metadata.create_all(bind=engine)


def get_db() -> Iterator[Session]:
    """Yield a synchronous session (for migrations and similar) and always close it."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


async def get_async_db() -> AsyncIterator[AsyncSession]:
    """Yield an async session for application code.

    The transaction is committed once the consumer resumes the generator
    without error. If an exception is thrown back into the generator, the
    transaction is rolled back and the exception re-raised. ``async with``
    closes the session on exit, so no explicit ``close()`` is needed.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise