feat: 添加PostgreSQL支持并更新数据库配置
- 新增PostgreSQL服务支持,使用最新版17 - 更新Docker Compose配置以支持PostgreSQL和Redis - 修改数据库连接逻辑,支持PostgreSQL和SQLite - 更新文档以反映新的数据库配置和使用方法 - 优化数据模型,确保时间戳字段支持时区
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""
|
||||
数据库连接和初始化
|
||||
支持 PostgreSQL(推荐)和 SQLite(本地开发)
|
||||
"""
|
||||
import os
|
||||
from sqlalchemy import create_engine
|
||||
@@ -8,31 +9,64 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sess
|
||||
|
||||
from .models import Base
|
||||
|
||||
# 数据库文件路径
|
||||
# 从环境变量获取,如果格式不正确则使用默认值
|
||||
raw_database_url = os.getenv("DATABASE_URL", "sqlite:///./life_echo.db")
|
||||
# 从环境变量获取数据库 URL
|
||||
raw_database_url = os.getenv("DATABASE_URL", "postgresql://postgres:postgres@localhost:5432/life_echo")
|
||||
|
||||
# 处理数据库 URL
|
||||
# 如果已经是异步格式,需要提取同步格式用于同步引擎
|
||||
if raw_database_url.startswith("sqlite+aiosqlite://"):
|
||||
# 提取文件路径(移除协议部分)
|
||||
file_path = raw_database_url.replace("sqlite+aiosqlite://", "")
|
||||
DATABASE_URL = f"sqlite://{file_path}"
|
||||
ASYNC_DATABASE_URL = raw_database_url
|
||||
elif raw_database_url.startswith("sqlite://"):
|
||||
DATABASE_URL = raw_database_url
|
||||
ASYNC_DATABASE_URL = raw_database_url.replace("sqlite://", "sqlite+aiosqlite://")
|
||||
|
||||
def parse_database_url(url: str) -> tuple[str, str]:
|
||||
"""
|
||||
解析数据库 URL,返回同步和异步版本
|
||||
|
||||
支持格式:
|
||||
- PostgreSQL: postgresql://user:pass@host:port/db
|
||||
- PostgreSQL async: postgresql+asyncpg://user:pass@host:port/db
|
||||
- SQLite: sqlite:///./path/to/db.db
|
||||
- SQLite async: sqlite+aiosqlite:///./path/to/db.db
|
||||
"""
|
||||
# PostgreSQL
|
||||
if url.startswith("postgresql+asyncpg://"):
|
||||
async_url = url
|
||||
sync_url = url.replace("postgresql+asyncpg://", "postgresql://")
|
||||
elif url.startswith("postgresql://"):
|
||||
sync_url = url
|
||||
async_url = url.replace("postgresql://", "postgresql+asyncpg://")
|
||||
# SQLite
|
||||
elif url.startswith("sqlite+aiosqlite://"):
|
||||
async_url = url
|
||||
sync_url = url.replace("sqlite+aiosqlite://", "sqlite://")
|
||||
elif url.startswith("sqlite://"):
|
||||
sync_url = url
|
||||
async_url = url.replace("sqlite://", "sqlite+aiosqlite://")
|
||||
else:
|
||||
# 默认使用 PostgreSQL
|
||||
print(f"警告: DATABASE_URL 格式不正确 ({url}),使用默认 PostgreSQL")
|
||||
sync_url = "postgresql://postgres:postgres@localhost:5432/life_echo"
|
||||
async_url = "postgresql+asyncpg://postgres:postgres@localhost:5432/life_echo"
|
||||
|
||||
return sync_url, async_url
|
||||
|
||||
|
||||
DATABASE_URL, ASYNC_DATABASE_URL = parse_database_url(raw_database_url)
|
||||
|
||||
# 创建同步引擎(用于迁移、Celery 任务等)
|
||||
# SQLite 需要特殊的 connect_args
|
||||
if DATABASE_URL.startswith("sqlite"):
|
||||
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
else:
|
||||
# 如果格式不正确,使用默认值并打印警告
|
||||
print(f"警告: DATABASE_URL 格式不正确 ({raw_database_url}),使用默认值")
|
||||
DATABASE_URL = "sqlite:///./life_echo.db"
|
||||
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///./life_echo.db"
|
||||
# 使用 psycopg (v3) 驱动
|
||||
sync_url = DATABASE_URL.replace("postgresql://", "postgresql+psycopg://")
|
||||
engine = create_engine(sync_url, pool_size=5, max_overflow=10)
|
||||
|
||||
# 创建同步引擎(用于迁移等)
|
||||
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
|
||||
# 创建异步引擎(用于实际应用)
|
||||
async_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False)
|
||||
# 创建异步引擎(用于 FastAPI)
|
||||
if ASYNC_DATABASE_URL.startswith("sqlite"):
|
||||
async_engine = create_async_engine(ASYNC_DATABASE_URL, echo=False)
|
||||
else:
|
||||
async_engine = create_async_engine(
|
||||
ASYNC_DATABASE_URL,
|
||||
echo=False,
|
||||
pool_size=5,
|
||||
max_overflow=10
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
数据库模型定义
|
||||
"""
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import Column, String, Integer, DateTime, Boolean, Text, ForeignKey, JSON
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
@@ -10,6 +10,11 @@ from sqlalchemy.orm import relationship
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def utc_now():
|
||||
"""返回当前 UTC 时间(带时区信息)"""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""用户表"""
|
||||
__tablename__ = "users"
|
||||
@@ -22,7 +27,7 @@ class User(Base):
|
||||
nickname = Column(String, nullable=False)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
subscription_type = Column(String, default="free") # free, premium
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
created_at = Column(DateTime(timezone=True), default=utc_now)
|
||||
|
||||
# Relationships
|
||||
conversations = relationship("Conversation", back_populates="user")
|
||||
@@ -38,8 +43,8 @@ class Conversation(Base):
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, ForeignKey("users.id"), nullable=False)
|
||||
started_at = Column(DateTime, default=datetime.utcnow)
|
||||
ended_at = Column(DateTime, nullable=True)
|
||||
started_at = Column(DateTime(timezone=True), default=utc_now)
|
||||
ended_at = Column(DateTime(timezone=True), nullable=True)
|
||||
duration_seconds = Column(Integer, default=0)
|
||||
summary = Column(Text, nullable=True)
|
||||
status = Column(String, default="active") # active, ended, processing
|
||||
@@ -59,7 +64,7 @@ class Segment(Base):
|
||||
conversation_id = Column(String, ForeignKey("conversations.id"), nullable=False)
|
||||
audio_url = Column(String, nullable=True)
|
||||
transcript_text = Column(Text, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
created_at = Column(DateTime(timezone=True), default=utc_now)
|
||||
processed = Column(Boolean, default=False)
|
||||
topic_category = Column(String, nullable=True)
|
||||
agent_response = Column(Text, nullable=True)
|
||||
@@ -79,7 +84,7 @@ class Chapter(Base):
|
||||
order_index = Column(Integer, nullable=False)
|
||||
status = Column(String, default="draft") # draft, completed
|
||||
images = Column(JSON, nullable=True) # 图片 URL 列表
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)
|
||||
category = Column(String, nullable=True) # 章节分类
|
||||
is_new = Column(Boolean, default=True) # 是否为新内容(未读)
|
||||
source_segments = Column(JSON, nullable=True) # 来源 segment IDs 列表
|
||||
@@ -98,7 +103,7 @@ class Book(Base):
|
||||
total_pages = Column(Integer, default=0)
|
||||
total_words = Column(Integer, default=0)
|
||||
cover_image_url = Column(String, nullable=True)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)
|
||||
has_update = Column(Boolean, default=False) # 是否有新内容
|
||||
last_update_chapter_id = Column(String, nullable=True) # 最近更新的章节 ID
|
||||
|
||||
@@ -116,7 +121,7 @@ class MemoirState(Base):
|
||||
current_stage = Column(String, default="childhood") # 当前阶段
|
||||
covered_stages = Column(JSON, default=list) # 已完成阶段列表
|
||||
slots = Column(JSON, nullable=False) # 各阶段 slot 信息
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
updated_at = Column(DateTime(timezone=True), default=utc_now, onupdate=utc_now)
|
||||
|
||||
# Relationships
|
||||
user = relationship("User", back_populates="memoir_state")
|
||||
@@ -129,8 +134,8 @@ class RefreshToken(Base):
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True)
|
||||
token = Column(String, unique=True, nullable=False, index=True) # 刷新令牌(唯一)
|
||||
expires_at = Column(DateTime, nullable=False) # 过期时间(30天后)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
expires_at = Column(DateTime(timezone=True), nullable=False) # 过期时间(30天后)
|
||||
created_at = Column(DateTime(timezone=True), default=utc_now)
|
||||
is_revoked = Column(Boolean, default=False) # 是否已撤销
|
||||
|
||||
# Relationships
|
||||
|
||||
Reference in New Issue
Block a user