refactor: 更新数据库和模型
- 更新数据库连接和会话管理 - 更新数据模型以支持用户认证 - 添加RefreshToken模型用于刷新令牌管理
This commit is contained in:
@@ -2,13 +2,14 @@
|
|||||||
数据库模块
|
数据库模块
|
||||||
"""
|
"""
|
||||||
from .database import get_db, get_async_db, init_db
|
from .database import get_db, get_async_db, init_db
|
||||||
from .models import User, Conversation, Segment, Chapter, Book
|
from .models import User, Conversation, Segment, Chapter, Book, RefreshToken
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_db",
|
"get_db",
|
||||||
"get_async_db",
|
"get_async_db",
|
||||||
"init_db",
|
"init_db",
|
||||||
"User",
|
"User",
|
||||||
|
"RefreshToken",
|
||||||
"Conversation",
|
"Conversation",
|
||||||
"Segment",
|
"Segment",
|
||||||
"Chapter",
|
"Chapter",
|
||||||
|
|||||||
@@ -9,8 +9,24 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sess
|
|||||||
from .models import Base
|
from .models import Base
|
||||||
|
|
||||||
# 数据库文件路径
|
# 数据库文件路径
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./life_echo.db")
|
# 从环境变量获取,如果格式不正确则使用默认值
|
||||||
ASYNC_DATABASE_URL = DATABASE_URL.replace("sqlite://", "sqlite+aiosqlite://")
|
raw_database_url = os.getenv("DATABASE_URL", "sqlite:///./life_echo.db")
|
||||||
|
|
||||||
|
# 处理数据库 URL
|
||||||
|
# 如果已经是异步格式,需要提取同步格式用于同步引擎
|
||||||
|
if raw_database_url.startswith("sqlite+aiosqlite://"):
|
||||||
|
# 提取文件路径(移除协议部分)
|
||||||
|
file_path = raw_database_url.replace("sqlite+aiosqlite://", "")
|
||||||
|
DATABASE_URL = f"sqlite://{file_path}"
|
||||||
|
ASYNC_DATABASE_URL = raw_database_url
|
||||||
|
elif raw_database_url.startswith("sqlite://"):
|
||||||
|
DATABASE_URL = raw_database_url
|
||||||
|
ASYNC_DATABASE_URL = raw_database_url.replace("sqlite://", "sqlite+aiosqlite://")
|
||||||
|
else:
|
||||||
|
# 如果格式不正确,使用默认值并打印警告
|
||||||
|
print(f"警告: DATABASE_URL 格式不正确 ({raw_database_url}),使用默认值")
|
||||||
|
DATABASE_URL = "sqlite:///./life_echo.db"
|
||||||
|
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///./life_echo.db"
|
||||||
|
|
||||||
# 创建同步引擎(用于迁移等)
|
# 创建同步引擎(用于迁移等)
|
||||||
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
|
||||||
|
|||||||
@@ -15,7 +15,10 @@ class User(Base):
|
|||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(String, primary_key=True)
|
||||||
openid = Column(String, unique=True, nullable=True) # 微信 OpenID
|
phone = Column(String, unique=True, nullable=False, index=True) # 手机号(唯一,必填)
|
||||||
|
password_hash = Column(String, nullable=False) # 密码哈希
|
||||||
|
email = Column(String, unique=True, nullable=True) # 邮箱(可选)
|
||||||
|
openid = Column(String, unique=True, nullable=True) # 微信 OpenID(可选)
|
||||||
nickname = Column(String, nullable=False)
|
nickname = Column(String, nullable=False)
|
||||||
avatar_url = Column(String, nullable=True)
|
avatar_url = Column(String, nullable=True)
|
||||||
subscription_type = Column(String, default="free") # free, premium
|
subscription_type = Column(String, default="free") # free, premium
|
||||||
@@ -25,6 +28,7 @@ class User(Base):
|
|||||||
conversations = relationship("Conversation", back_populates="user")
|
conversations = relationship("Conversation", back_populates="user")
|
||||||
chapters = relationship("Chapter", back_populates="user")
|
chapters = relationship("Chapter", back_populates="user")
|
||||||
books = relationship("Book", back_populates="user")
|
books = relationship("Book", back_populates="user")
|
||||||
|
refresh_tokens = relationship("RefreshToken", back_populates="user", cascade="all, delete-orphan")
|
||||||
|
|
||||||
|
|
||||||
class Conversation(Base):
|
class Conversation(Base):
|
||||||
@@ -96,3 +100,18 @@ class Book(Base):
|
|||||||
# Relationships
|
# Relationships
|
||||||
user = relationship("User", back_populates="books")
|
user = relationship("User", back_populates="books")
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Base):
|
||||||
|
"""刷新令牌表"""
|
||||||
|
__tablename__ = "refresh_tokens"
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True)
|
||||||
|
user_id = Column(String, ForeignKey("users.id"), nullable=False, index=True)
|
||||||
|
token = Column(String, unique=True, nullable=False, index=True) # 刷新令牌(唯一)
|
||||||
|
expires_at = Column(DateTime, nullable=False) # 过期时间(30天后)
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow)
|
||||||
|
is_revoked = Column(Boolean, default=False) # 是否已撤销
|
||||||
|
|
||||||
|
# Relationships
|
||||||
|
user = relationship("User", back_populates="refresh_tokens")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user