Files
life-echo/api/database/database.py
徐在坤 3a4c9f0838 refactor: 更新数据库和模型
- 更新数据库连接和会话管理
- 更新数据模型以支持用户认证
- 添加RefreshToken模型用于刷新令牌管理
2026-01-18 15:57:47 +08:00

68 lines
2.2 KiB
Python

"""
数据库连接和初始化
"""
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from .models import Base
# 数据库文件路径
# 从环境变量获取,如果格式不正确则使用默认值
raw_database_url = os.getenv("DATABASE_URL", "sqlite:///./life_echo.db")
# 处理数据库 URL
# 如果已经是异步格式,需要提取同步格式用于同步引擎
if raw_database_url.startswith("sqlite+aiosqlite://"):
# 提取文件路径(移除协议部分)
file_path = raw_database_url.replace("sqlite+aiosqlite://", "")
DATABASE_URL = f"sqlite://{file_path}"
ASYNC_DATABASE_URL = raw_database_url
elif raw_database_url.startswith("sqlite://"):
DATABASE_URL = raw_database_url
ASYNC_DATABASE_URL = raw_database_url.replace("sqlite://", "sqlite+aiosqlite://")
else:
# 如果格式不正确,使用默认值并打印警告
print(f"警告: DATABASE_URL 格式不正确 ({raw_database_url}),使用默认值")
DATABASE_URL = "sqlite:///./life_echo.db"
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///./life_echo.db"
# 创建同步引擎(用于迁移等)
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
# 创建异步引擎(用于实际应用)
async_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
def init_db():
"""初始化数据库,创建所有表"""
Base.metadata.create_all(bind=engine)
def get_db():
"""获取同步数据库会话(用于迁移等)"""
db = SessionLocal()
try:
yield db
finally:
db.close()
async def get_async_db():
"""获取异步数据库会话(用于实际应用)"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()