Files
life-echo/api/database/database.py
iammm0 a261d9da27 refactor: 优化后端数据库与依赖
- 优化 api/database/database.py
- 更新 api/requirements.txt

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-12 13:33:14 +08:00

82 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
数据库连接和初始化
仅支持 PostgreSQL
"""
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from .models import Base
# Database URL taken from the environment. Surrounding whitespace (e.g. a
# trailing newline from a .env file or mounted secret) is stripped so it cannot
# leak into the database name, and an unset or blank variable falls back to the
# local development database.
raw_database_url = (
    os.getenv("DATABASE_URL", "").strip()
    or "postgresql://postgres:postgres@localhost:5432/life_echo"
)
def _redact_password(url: str) -> str:
    """Return *url* with the password of its ``user:password@`` part replaced by ``***``.

    Used so that diagnostics never echo database credentials. URLs that carry
    no ``user:password@`` section are returned unchanged.
    """
    scheme, sep, rest = url.partition("://")
    if not sep:
        # No scheme: treat the whole string as the authority/path part.
        scheme, rest = "", url
    netloc, slash, tail = rest.partition("/")
    # rpartition: the host follows the LAST "@" of the authority section.
    userinfo, at, hostport = netloc.rpartition("@")
    user, colon, _password = userinfo.partition(":")
    if not (at and colon):
        return url
    return f"{scheme}{sep}{user}:***@{hostport}{slash}{tail}"


def parse_database_url(url: str) -> tuple[str, str]:
    """Parse a database URL and return its synchronous and asynchronous forms.

    Accepted PostgreSQL schemes (everything after the scheme is kept verbatim):

    - ``postgresql://user:pass@host:port/db``
    - ``postgres://user:pass@host:port/db`` (legacy alias, normalised here)
    - ``postgresql+asyncpg://...``, ``postgresql+psycopg://...``,
      ``postgresql+psycopg2://...``

    Args:
        url: The raw URL, typically the ``DATABASE_URL`` environment variable.

    Returns:
        ``(sync_url, async_url)``. ``sync_url`` always uses the plain
        ``postgresql://`` scheme, and ``async_url`` always uses
        ``postgresql+asyncpg://``.

    Any other value prints a warning, with the password redacted, and falls
    back to the local default database ``postgres:postgres@localhost:5432/life_echo``.
    """
    # All accepted schemes share the same credentials/host/database part, so
    # only the scheme prefix is swapped. str.replace would also rewrite any later
    # occurrence of the scheme text inside the URL (e.g. in a query parameter).
    for scheme in (
        "postgresql+asyncpg://",
        "postgresql+psycopg://",
        "postgresql+psycopg2://",
        "postgresql://",
        "postgres://",
    ):
        if url.startswith(scheme):
            rest = url[len(scheme):]
            return f"postgresql://{rest}", f"postgresql+asyncpg://{rest}"

    # Never print the raw URL: it usually contains the database password.
    print(f"警告: DATABASE_URL 格式不正确 ({_redact_password(url)}),使用默认 PostgreSQL")
    return (
        "postgresql://postgres:postgres@localhost:5432/life_echo",
        "postgresql+asyncpg://postgres:postgres@localhost:5432/life_echo",
    )
DATABASE_URL, ASYNC_DATABASE_URL = parse_database_url(raw_database_url)

# Synchronous engine (used for migrations, Celery tasks, etc.), driven by psycopg.
# Only the leading "postgresql://" scheme is rewritten; str.replace would also
# alter any later occurrence of that text inside the URL.
sync_url = (
    "postgresql+psycopg://" + DATABASE_URL.removeprefix("postgresql://")
    if DATABASE_URL.startswith("postgresql://")
    else DATABASE_URL
)
engine = create_engine(sync_url, pool_size=5, max_overflow=10)

# Asynchronous engine (used by FastAPI), built from ASYNC_DATABASE_URL.
async_engine = create_async_engine(
    ASYNC_DATABASE_URL,
    echo=False,
    pool_size=5,
    max_overflow=10,
)

# Session factories for the sync and async engines.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# expire_on_commit=False keeps loaded attribute values after commit, so objects
# can still be read without an implicit refresh.
AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
def init_db():
    """Create every table registered on ``Base.metadata`` through the synchronous engine."""
    metadata = Base.metadata
    metadata.create_all(engine)
def get_db():
    """Yield a synchronous session from ``SessionLocal`` (for migrations and similar).

    The session is closed once the generator finishes, whether normally or
    through an exception. Nothing is committed here; callers commit explicitly.
    """
    with SessionLocal() as session:
        yield session
async def get_async_db():
    """Yield an ``AsyncSession`` for application code (the async engine serves FastAPI).

    After the consumer resumes the generator without error, the session is
    committed. If an ``Exception`` is thrown into the generator, or the commit
    itself fails, the session is rolled back and the exception is re-raised.
    The ``async with`` block closes the session in every case, including
    non-``Exception`` exits such as cancellation, so no separate ``close()``
    is needed.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise