""" 数据库连接和初始化 """ import os from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from .models import Base # 数据库文件路径 # 从环境变量获取,如果格式不正确则使用默认值 raw_database_url = os.getenv("DATABASE_URL", "sqlite:///./life_echo.db") # 处理数据库 URL # 如果已经是异步格式,需要提取同步格式用于同步引擎 if raw_database_url.startswith("sqlite+aiosqlite://"): # 提取文件路径(移除协议部分) file_path = raw_database_url.replace("sqlite+aiosqlite://", "") DATABASE_URL = f"sqlite://{file_path}" ASYNC_DATABASE_URL = raw_database_url elif raw_database_url.startswith("sqlite://"): DATABASE_URL = raw_database_url ASYNC_DATABASE_URL = raw_database_url.replace("sqlite://", "sqlite+aiosqlite://") else: # 如果格式不正确,使用默认值并打印警告 print(f"警告: DATABASE_URL 格式不正确 ({raw_database_url}),使用默认值") DATABASE_URL = "sqlite:///./life_echo.db" ASYNC_DATABASE_URL = "sqlite+aiosqlite:///./life_echo.db" # 创建同步引擎(用于迁移等) engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) # 创建异步引擎(用于实际应用) async_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False) # 创建会话工厂 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) def init_db(): """初始化数据库,创建所有表""" Base.metadata.create_all(bind=engine) def get_db(): """获取同步数据库会话(用于迁移等)""" db = SessionLocal() try: yield db finally: db.close() async def get_async_db(): """获取异步数据库会话(用于实际应用)""" async with AsyncSessionLocal() as session: try: yield session await session.commit() except Exception: await session.rollback() raise finally: await session.close()