52 lines
1.4 KiB
Python
52 lines
1.4 KiB
Python
|
|
"""
|
||
|
|
Database connection and initialization.
|
||
|
|
"""
|
||
|
|
import os
|
||
|
|
from sqlalchemy import create_engine
|
||
|
|
from sqlalchemy.orm import sessionmaker, Session
|
||
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||
|
|
|
||
|
|
from .models import Base
|
||
|
|
|
||
|
|
# Database URL; falls back to a local SQLite file when DATABASE_URL is unset.
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./life_echo.db")

# Async variant of the URL: swap the SQLite scheme for the aiosqlite driver.
# Only the leading scheme is rewritten, so text later in the URL is never
# touched and non-SQLite URLs pass through unchanged.
if DATABASE_URL.startswith("sqlite://"):
    ASYNC_DATABASE_URL = "sqlite+aiosqlite://" + DATABASE_URL[len("sqlite://"):]
else:
    ASYNC_DATABASE_URL = DATABASE_URL

# Synchronous engine (used for migrations and similar tasks).
# ``check_same_thread`` is a sqlite3-specific connect argument, so it is only
# supplied when the URL targets SQLite; other drivers would reject it.
engine = create_engine(
    DATABASE_URL,
    connect_args=(
        {"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {}
    ),
)

# Asynchronous engine (used by the actual application).
async_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False)

# Session factories: one bound to the sync engine, one to the async engine.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
||
|
|
|
||
|
|
|
||
|
|
def init_db():
    """Initialize the database by creating all tables known to ``Base.metadata``.

    Uses the synchronous ``engine`` for the DDL.
    """
    schema = Base.metadata
    schema.create_all(bind=engine)
|
||
|
|
|
||
|
|
|
||
|
|
def get_db():
    """Yield a synchronous database session (for migrations and similar tasks).

    The session is closed once the consumer finishes with the generator,
    whether it completes normally or an exception propagates.
    """
    # The Session context manager calls close() on exit.
    with SessionLocal() as session:
        yield session
|
||
|
|
|
||
|
|
|
||
|
|
async def get_async_db():
    """Yield an asynchronous database session (used by the actual application).

    After the consumer resumes the generator without error, the session is
    committed. If an ``Exception`` is raised (by the consumer or by the commit
    itself), the session is rolled back and the exception is re-raised. The
    session is closed in every case.
    """
    async with AsyncSessionLocal() as db_session:
        try:
            # Hand the session to the caller; control returns here afterwards.
            yield db_session
            # Commit stays inside the try so a failed commit is rolled back too.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
        finally:
            await db_session.close()
|
||
|
|
|