""" 数据库连接和初始化 支持 PostgreSQL(推荐)和 SQLite(本地开发) """ import os from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from .models import Base # 从环境变量获取数据库 URL raw_database_url = os.getenv("DATABASE_URL", "postgresql://postgres:postgres@localhost:5432/life_echo") def parse_database_url(url: str) -> tuple[str, str]: """ 解析数据库 URL,返回同步和异步版本 支持格式: - PostgreSQL: postgresql://user:pass@host:port/db - PostgreSQL async: postgresql+asyncpg://user:pass@host:port/db - SQLite: sqlite:///./path/to/db.db - SQLite async: sqlite+aiosqlite:///./path/to/db.db """ # PostgreSQL if url.startswith("postgresql+asyncpg://"): async_url = url sync_url = url.replace("postgresql+asyncpg://", "postgresql://") elif url.startswith("postgresql://"): sync_url = url async_url = url.replace("postgresql://", "postgresql+asyncpg://") # SQLite elif url.startswith("sqlite+aiosqlite://"): async_url = url sync_url = url.replace("sqlite+aiosqlite://", "sqlite://") elif url.startswith("sqlite://"): sync_url = url async_url = url.replace("sqlite://", "sqlite+aiosqlite://") else: # 默认使用 PostgreSQL print(f"警告: DATABASE_URL 格式不正确 ({url}),使用默认 PostgreSQL") sync_url = "postgresql://postgres:postgres@localhost:5432/life_echo" async_url = "postgresql+asyncpg://postgres:postgres@localhost:5432/life_echo" return sync_url, async_url DATABASE_URL, ASYNC_DATABASE_URL = parse_database_url(raw_database_url) # 创建同步引擎(用于迁移、Celery 任务等) # SQLite 需要特殊的 connect_args if DATABASE_URL.startswith("sqlite"): engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) else: # 使用 psycopg (v3) 驱动 sync_url = DATABASE_URL.replace("postgresql://", "postgresql+psycopg://") engine = create_engine(sync_url, pool_size=5, max_overflow=10) # 创建异步引擎(用于 FastAPI) if ASYNC_DATABASE_URL.startswith("sqlite"): async_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False) else: async_engine = create_async_engine( ASYNC_DATABASE_URL, echo=False, pool_size=5, max_overflow=10 ) # 创建会话工厂 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) def init_db(): """初始化数据库,创建所有表""" Base.metadata.create_all(bind=engine) def get_db(): """获取同步数据库会话(用于迁移等)""" db = SessionLocal() try: yield db finally: db.close() async def get_async_db(): """获取异步数据库会话(用于实际应用)""" async with AsyncSessionLocal() as session: try: yield session await session.commit() except Exception: await session.rollback() raise finally: await session.close()