Merge branch 'refactor/backend-architecture' into development

This commit is contained in:
yangshilin
2026-03-18 17:18:23 +08:00
parent 2070a03d35
commit 48b70e1350
266 changed files with 12386 additions and 9690 deletions

View File

@@ -1,19 +0,0 @@
"""
数据库模块
"""
from .database import get_db, get_async_db, init_db
from .models import User, Conversation, Segment, Chapter, Book, RefreshToken, MemoirState
__all__ = [
"get_db",
"get_async_db",
"init_db",
"User",
"RefreshToken",
"Conversation",
"Segment",
"Chapter",
"Book",
"MemoirState",
]

View File

@@ -1,81 +0,0 @@
"""
数据库连接和初始化
仅支持 PostgreSQL
"""
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from .models import Base
# 从环境变量获取数据库 URL
raw_database_url = os.getenv("DATABASE_URL", "postgresql://postgres:postgres@localhost:5432/life_echo")
def parse_database_url(url: str) -> tuple[str, str]:
"""
解析数据库 URL返回同步和异步版本
支持格式:
- PostgreSQL: postgresql://user:pass@host:port/db
- PostgreSQL async: postgresql+asyncpg://user:pass@host:port/db
"""
if url.startswith("postgresql+asyncpg://"):
async_url = url
sync_url = url.replace("postgresql+asyncpg://", "postgresql://")
elif url.startswith("postgresql://"):
sync_url = url
async_url = url.replace("postgresql://", "postgresql+asyncpg://")
else:
print(f"警告: DATABASE_URL 格式不正确 ({url}),使用默认 PostgreSQL")
sync_url = "postgresql://postgres:postgres@localhost:5432/life_echo"
async_url = "postgresql+asyncpg://postgres:postgres@localhost:5432/life_echo"
return sync_url, async_url
DATABASE_URL, ASYNC_DATABASE_URL = parse_database_url(raw_database_url)
# 同步引擎用于迁移、Celery 任务等),使用 psycopg
sync_url = DATABASE_URL.replace("postgresql://", "postgresql+psycopg://")
engine = create_engine(sync_url, pool_size=5, max_overflow=10)
# 异步引擎(用于 FastAPI
async_engine = create_async_engine(
ASYNC_DATABASE_URL,
echo=False,
pool_size=5,
max_overflow=10,
)
# 会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
AsyncSessionLocal = async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
def init_db():
"""初始化数据库,创建所有表"""
Base.metadata.create_all(bind=engine)
def get_db():
"""获取同步数据库会话(用于迁移等)"""
db = SessionLocal()
try:
yield db
finally:
db.close()
async def get_async_db():
"""获取异步数据库会话(用于实际应用)"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()