diff --git a/api/database/__init__.py b/api/database/__init__.py index 3067e8a..92376aa 100644 --- a/api/database/__init__.py +++ b/api/database/__init__.py @@ -2,13 +2,14 @@ 数据库模块 """ from .database import get_db, get_async_db, init_db -from .models import User, Conversation, Segment, Chapter, Book +from .models import User, Conversation, Segment, Chapter, Book, RefreshToken __all__ = [ "get_db", "get_async_db", "init_db", "User", + "RefreshToken", "Conversation", "Segment", "Chapter", diff --git a/api/database/database.py b/api/database/database.py index 0ea70ec..5a40a0d 100644 --- a/api/database/database.py +++ b/api/database/database.py @@ -9,8 +9,24 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sess from .models import Base # 数据库文件路径 -DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./life_echo.db") -ASYNC_DATABASE_URL = DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://") +# 从环境变量获取,如果格式不正确则使用默认值 +raw_database_url = os.getenv("DATABASE_URL", "sqlite:///./life_echo.db") + +# 处理数据库 URL +# 如果已经是异步格式,需要提取同步格式用于同步引擎 +if raw_database_url.startswith("sqlite+aiosqlite://"): + # 提取文件路径(移除协议部分) + file_path = raw_database_url.replace("sqlite+aiosqlite://", "") + DATABASE_URL = f"sqlite://{file_path}" + ASYNC_DATABASE_URL = raw_database_url +elif raw_database_url.startswith("sqlite://"): + DATABASE_URL = raw_database_url + ASYNC_DATABASE_URL = raw_database_url.replace("sqlite://", "sqlite+aiosqlite://") +else: + # 如果格式不正确,使用默认值并打印警告 + print(f"警告: DATABASE_URL 格式不正确 ({raw_database_url}),使用默认值") + DATABASE_URL = "sqlite:///./life_echo.db" + ASYNC_DATABASE_URL = "sqlite+aiosqlite:///./life_echo.db" # 创建同步引擎(用于迁移等) engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) diff --git a/api/database/models.py b/api/database/models.py index 57570e5..56a5f0e 100644 --- a/api/database/models.py +++ b/api/database/models.py @@ -15,7 +15,10 @@ class User(Base): __tablename__ = "users" id = Column(String, primary_key=True) - openid = Column(String, unique=True, nullable=True) # 微信 OpenID + phone = Column(String, unique=True, nullable=False, index=True) # 手机号(唯一,必填) + password_hash = Column(String, nullable=False) # 密码哈希 + email = Column(String, unique=True, nullable=True) # 邮箱(可选) + openid = Column(String, unique=True, nullable=True) # 微信 OpenID(可选) nickname = Column(String, nullable=False) avatar_url = Column(String, nullable=True) subscription_type = Column(String, default="free") # free, premium @@ -25,6 +28,7 @@ class User(Base): conversations = relationship("Conversation", back_populates="user") chapters = relationship("Chapter", back_populates="user") books = relationship("Book", back_populates="user") + refresh_tokens = relationship("RefreshToken", back_populates="user", cascade="all, delete-orphan") class Conversation(Base): @@ -96,3 +100,18 @@ class Book(Base): # Relationships user = relationship("User", back_populates="books") + +class RefreshToken(Base): + """刷新令牌表""" + __tablename__ = "refresh_tokens" + + id = Column(String, primary_key=True) + user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True) + token = Column(String, unique=True, nullable=False, index=True) # 刷新令牌(唯一) + expires_at = Column(DateTime, nullable=False) # 过期时间(30天后) + created_at = Column(DateTime, default=datetime.utcnow) + is_revoked = Column(Boolean, default=False) # 是否已撤销 + + # Relationships + user = relationship("User", back_populates="refresh_tokens") +