2026-01-07 11:56:33 +08:00
|
|
|
|
"""
|
|
|
|
|
|
数据库连接和初始化
|
2026-02-12 13:33:14 +08:00
|
|
|
|
仅支持 PostgreSQL
|
2026-01-07 11:56:33 +08:00
|
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
|
|
|
from sqlalchemy import create_engine
|
|
|
|
|
|
from sqlalchemy.orm import sessionmaker, Session
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
|
|
|
|
|
|
|
|
|
|
from .models import Base
|
|
|
|
|
|
|
2026-01-21 23:21:36 +01:00
|
|
|
|
# Read the database URL from the environment; fall back to a local PostgreSQL URL.
raw_database_url = os.environ.get(
    "DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/life_echo",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_database_url(url: str) -> tuple[str, str]:
    """Split a PostgreSQL URL into its synchronous and asynchronous variants.

    Accepted forms (the scheme is matched case-sensitively):

    - ``postgresql://user:pass@host:port/db``
    - ``postgres://user:pass@host:port/db`` (alias scheme)
    - ``postgresql+<driver>://user:pass@host:port/db`` (e.g. ``asyncpg``, ``psycopg``)

    Args:
        url: The raw database URL.

    Returns:
        A ``(sync_url, async_url)`` tuple. ``sync_url`` always uses the plain
        ``postgresql://`` scheme and ``async_url`` always uses
        ``postgresql+asyncpg://``. When ``url`` is not a PostgreSQL URL, a
        warning is printed and both entries point at the local default database.
    """
    # Split off the scheme once so only the prefix is rewritten; everything
    # after "://" (credentials, host, query string) is passed through untouched.
    scheme, separator, remainder = url.partition("://")
    # Ignore any "+driver" suffix: each variant gets its own scheme below.
    dialect = scheme.partition("+")[0]

    if not separator or dialect not in ("postgresql", "postgres"):
        # NOTE(review): this message echoes the raw URL, which may contain credentials.
        print(f"警告: DATABASE_URL 格式不正确 ({url}),使用默认 PostgreSQL")
        remainder = "postgres:postgres@localhost:5432/life_echo"

    return f"postgresql://{remainder}", f"postgresql+asyncpg://{remainder}"
|
|
|
|
|
|
|
2026-01-07 11:56:33 +08:00
|
|
|
|
|
2026-01-21 23:21:36 +01:00
|
|
|
|
# Derive the sync and async URL variants from the raw setting.
DATABASE_URL, ASYNC_DATABASE_URL = parse_database_url(raw_database_url)

# Synchronous engine (migrations, Celery tasks, etc.), using the psycopg driver.
sync_url = DATABASE_URL.replace("postgresql://", "postgresql+psycopg://")
engine = create_engine(
    sync_url,
    pool_size=5,
    max_overflow=10,
)
|
|
|
|
|
|
|
|
|
|
|
|
# Asynchronous engine (used by FastAPI).
async_engine = create_async_engine(
    ASYNC_DATABASE_URL,
    pool_size=5,
    max_overflow=10,
    echo=False,
)
|
|
|
|
|
|
|
|
|
|
|
|
# Session factories: one bound to the sync engine, one to the async engine.
SessionLocal = sessionmaker(
    bind=engine,
    autocommit=False,
    autoflush=False,
)
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_db():
    """Initialise the database by creating all tables known to ``Base.metadata``."""
    metadata = Base.metadata
    # Tables are created through the synchronous engine.
    metadata.create_all(bind=engine)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_db():
    """Yield a synchronous database session (for migrations and similar tasks).

    The session is closed once the consumer is finished with it, whether or
    not an exception was raised in the meantime.
    """
    # The Session context manager calls close() on exit.
    with SessionLocal() as session:
        yield session
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_async_db():
    """Yield an asynchronous database session (used by the running application).

    The session is committed after the consumer finishes without raising. If
    an exception propagates, the session is rolled back and the exception is
    re-raised. The session is closed when the ``async with`` block exits.
    """
    # ``async with`` closes the session on exit, so no explicit close() is needed.
    async with AsyncSessionLocal() as session:
        try:
            yield session
            # Commit stays inside the ``try`` so a failed commit is rolled back too.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|