feat: 添加PostgreSQL支持并更新数据库配置

- 新增PostgreSQL服务支持,使用最新版17
- 更新Docker Compose配置以支持PostgreSQL和Redis
- 修改数据库连接逻辑,支持PostgreSQL和SQLite
- 更新文档以反映新的数据库配置和使用方法
- 优化数据模型,确保时间戳字段支持时区
This commit is contained in:
penghanyuan
2026-01-21 23:21:36 +01:00
parent dbbb924625
commit 0591e9d7c1
7 changed files with 170 additions and 48 deletions

View File

@@ -1,5 +1,6 @@
"""
数据库连接和初始化
支持 PostgreSQL推荐和 SQLite本地开发
"""
import os
from sqlalchemy import create_engine
@@ -8,31 +9,64 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sess
from .models import Base
# 数据库文件路径
# 从环境变量获取,如果格式不正确则使用默认值
raw_database_url = os.getenv("DATABASE_URL", "sqlite:///./life_echo.db")
# 从环境变量获取数据库 URL
raw_database_url = os.getenv("DATABASE_URL", "postgresql://postgres:postgres@localhost:5432/life_echo")
# 处理数据库 URL
# 如果已经是异步格式,需要提取同步格式用于同步引擎
if raw_database_url.startswith("sqlite+aiosqlite://"):
# 提取文件路径(移除协议部分)
file_path = raw_database_url.replace("sqlite+aiosqlite://", "")
DATABASE_URL = f"sqlite://{file_path}"
ASYNC_DATABASE_URL = raw_database_url
elif raw_database_url.startswith("sqlite://"):
DATABASE_URL = raw_database_url
ASYNC_DATABASE_URL = raw_database_url.replace("sqlite://", "sqlite+aiosqlite://")
def parse_database_url(url: str) -> tuple[str, str]:
"""
解析数据库 URL返回同步和异步版本
支持格式:
- PostgreSQL: postgresql://user:pass@host:port/db
- PostgreSQL async: postgresql+asyncpg://user:pass@host:port/db
- SQLite: sqlite:///./path/to/db.db
- SQLite async: sqlite+aiosqlite:///./path/to/db.db
"""
# PostgreSQL
if url.startswith("postgresql+asyncpg://"):
async_url = url
sync_url = url.replace("postgresql+asyncpg://", "postgresql://")
elif url.startswith("postgresql://"):
sync_url = url
async_url = url.replace("postgresql://", "postgresql+asyncpg://")
# SQLite
elif url.startswith("sqlite+aiosqlite://"):
async_url = url
sync_url = url.replace("sqlite+aiosqlite://", "sqlite://")
elif url.startswith("sqlite://"):
sync_url = url
async_url = url.replace("sqlite://", "sqlite+aiosqlite://")
else:
# 默认使用 PostgreSQL
print(f"警告: DATABASE_URL 格式不正确 ({url}),使用默认 PostgreSQL")
sync_url = "postgresql://postgres:postgres@localhost:5432/life_echo"
async_url = "postgresql+asyncpg://postgres:postgres@localhost:5432/life_echo"
return sync_url, async_url
DATABASE_URL, ASYNC_DATABASE_URL = parse_database_url(raw_database_url)
# 创建同步引擎用于迁移、Celery 任务等)
# SQLite 需要特殊的 connect_args
if DATABASE_URL.startswith("sqlite"):
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
else:
# 如果格式不正确,使用默认值并打印警告
print(f"警告: DATABASE_URL 格式不正确 ({raw_database_url}),使用默认值")
DATABASE_URL = "sqlite:///./life_echo.db"
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///./life_echo.db"
# 使用 psycopg (v3) 驱动
sync_url = DATABASE_URL.replace("postgresql://", "postgresql+psycopg://")
engine = create_engine(sync_url, pool_size=5, max_overflow=10)
# 创建步引擎(用于迁移等
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
# 创建异步引擎(用于实际应用)
async_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False)
# 创建步引擎(用于 FastAPI
if ASYNC_DATABASE_URL.startswith("sqlite"):
async_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False)
else:
async_engine = create_async_engine(
ASYNC_DATABASE_URL,
echo=False,
pool_size=5,
max_overflow=10
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View File

@@ -1,7 +1,7 @@
"""
数据库模型定义
"""
from datetime import datetime
from datetime import datetime, timezone
from typing import Optional, List
from sqlalchemy import Column, String, Integer, DateTime, Boolean, Text, ForeignKey, JSON
from sqlalchemy.ext.declarative import declarative_base
@@ -10,6 +10,11 @@ from sqlalchemy.orm import relationship
Base = declarative_base()
def utc_now():
"""返回当前 UTC 时间(带时区信息)"""
return datetime.now(timezone.utc)
class User(Base):
"""用户表"""
__tablename__ = "users"
@@ -22,7 +27,7 @@ class User(Base):
nickname = Column(String, nullable=False)
avatar_url = Column(String, nullable=True)
subscription_type = Column(String, default="free") # free, premium
created_at = Column(DateTime, default=datetime.utcnow)
created_at = Column(DateTime(timezone=True), default=utc_now)
# Relationships
conversations = relationship("Conversation", back_populates="user")
@@ -38,8 +43,8 @@ class Conversation(Base):
id = Column(String, primary_key=True)
user_id = Column(String, ForeignKey("users.id"), nullable=False)
started_at = Column(DateTime, default=datetime.utcnow)
ended_at = Column(DateTime, nullable=True)
started_at = Column(DateTime(timezone=True), default=utc_now)
ended_at = Column(DateTime(timezone=True), nullable=True)
duration_seconds = Column(Integer, default=0)
summary = Column(Text, nullable=True)
status = Column(String, default="active") # active, ended, processing
@@ -59,7 +64,7 @@ class Segment(Base):
conversation_id = Column(String, ForeignKey("conversations.id"), nullable=False)
audio_url = Column(String, nullable=True)
transcript_text = Column(Text, nullable=False)
created_at = Column(DateTime, default=datetime.utcnow)
created_at = Column(DateTime(timezone=True), default=utc_now)
processed = Column(Boolean, default=False)
topic_category = Column(String, nullable=True)
agent_response = Column(Text, nullable=True)
@@ -79,7 +84,7 @@ class Chapter(Base):
order_index = Column(Integer, nullable=False)
status = Column(String, default="draft") # draft, completed
images = Column(JSON, nullable=True) # 图片 URL 列表
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)
category = Column(String, nullable=True) # 章节分类
is_new = Column(Boolean, default=True) # 是否为新内容(未读)
source_segments = Column(JSON, nullable=True) # 来源 segment IDs 列表
@@ -98,7 +103,7 @@ class Book(Base):
total_pages = Column(Integer, default=0)
total_words = Column(Integer, default=0)
cover_image_url = Column(String, nullable=True)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)
has_update = Column(Boolean, default=False) # 是否有新内容
last_update_chapter_id = Column(String, nullable=True) # 最近更新的章节 ID
@@ -116,7 +121,7 @@ class MemoirState(Base):
current_stage = Column(String, default="childhood") # 当前阶段
covered_stages = Column(JSON, default=list) # 已完成阶段列表
slots = Column(JSON, nullable=False) # 各阶段 slot 信息
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)
# Relationships
user = relationship("User", back_populates="memoir_state")
@@ -129,8 +134,8 @@ class RefreshToken(Base):
id = Column(String, primary_key=True)
user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True)
token = Column(String, unique=True, nullable=False, index=True) # 刷新令牌(唯一)
expires_at = Column(DateTime, nullable=False) # 过期时间30天后
created_at = Column(DateTime, default=datetime.utcnow)
expires_at = Column(DateTime(timezone=True), nullable=False) # 过期时间30天后
created_at = Column(DateTime(timezone=True), default=utc_now)
is_revoked = Column(Boolean, default=False) # 是否已撤销
# Relationships