chore: 更新Docker配置,优化路由

- 更新docker-compose.yml
- 优化conversations.py、plans.py、quota.py、user.py、websocket.py

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
iammm0
2026-02-10 14:23:40 +08:00
parent e39fd97e06
commit 498277aac3
6 changed files with 317 additions and 145 deletions

View File

@@ -60,7 +60,7 @@ async def create_conversation(
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
):
"""创建新对话(需要认证)"""
"""创建新对话(需要认证)。对话轮数在每次发送消息时校验。"""
conversation = ConversationModel(
id=str(uuid.uuid4()),
user_id=current_user.id,
@@ -243,7 +243,16 @@ async def organize_conversation(
if not segments:
raise HTTPException(status_code=400, detail="该对话没有可整理的内容")
# 免费版仅允许 1 个章节整理Pro/Pro+ 无限制
from routers.quota import get_chapter_count, check_can_submit_organize
chapter_count = await get_chapter_count(current_user.id, db)
can_submit, quota_message = check_can_submit_organize(
current_user.subscription_type, chapter_count
)
if not can_submit:
raise HTTPException(status_code=403, detail=quota_message)
# 提交到Celery任务处理
try:
from routers.websocket import manager

View File

@@ -39,47 +39,66 @@ class CurrentPlanResponse(BaseModel):
# 预定义的订阅计划
# 免费版50 轮对话 + 1 个章节整理
# Pro88 元2000 轮对话,无章节限制
# Pro+288 元10000 轮对话,无章节限制
AVAILABLE_PLANS = [
PlanResponse(
id="free",
name="free",
display_name="免费版",
display_name="免费体验",
price=0.0,
currency="CNY",
features=[
"基础对话功能",
"生成回忆录章节",
"最多3次对话",
"最多10个章节"
"50 轮对话",
"1 个章节整理(所有对话整理到一个章节",
"体验回忆录生成流程"
],
max_conversations=3,
max_chapters=10,
max_words=50000,
max_conversations=50,
max_chapters=1,
max_words=None,
is_popular=False
),
PlanResponse(
id="premium",
name="premium",
display_name="高级",
price=99.0,
id="pro",
name="pro",
display_name="Pro ",
price=88.0,
currency="CNY",
features=[
"无限对话",
"章节",
"无限字数",
"优先处理",
"专属客服支持"
"2000 轮对话",
"无章节限制",
"完整回忆录生成"
],
max_conversations=None,
max_conversations=2000,
max_chapters=None,
max_words=None,
is_popular=True
),
PlanResponse(
id="pro_plus",
name="pro_plus",
display_name="Pro+ 版",
price=288.0,
currency="CNY",
features=[
"10000 轮对话",
"无章节限制",
"完整回忆录生成",
"长期创作无忧"
],
max_conversations=10000,
max_chapters=None,
max_words=None,
is_popular=False
)
]
def get_plan_by_type(subscription_type: str) -> Optional[PlanResponse]:
"""根据订阅类型获取计划信息"""
"""根据订阅类型获取计划信息。旧字段 premium 按 pro 展示。"""
if subscription_type == "premium":
subscription_type = "pro"
for plan in AVAILABLE_PLANS:
if plan.id == subscription_type:
return plan
@@ -104,46 +123,28 @@ async def get_current_plan(
"""
plan = get_plan_by_type(current_user.subscription_type)
# 计算使用情况
from sqlalchemy import select, func
# 统计对话数量
from database.models import Conversation
stmt = select(func.count(Conversation.id)).where(
Conversation.user_id == current_user.id
)
result = await db.execute(stmt)
conversation_count = result.scalar() or 0
# 统计章节数量
from database.models import Chapter
stmt = select(func.count(Chapter.id)).where(
Chapter.user_id == current_user.id
)
result = await db.execute(stmt)
chapter_count = result.scalar() or 0
# 统计总字数
stmt = select(func.sum(func.length(Chapter.content))).where(
Chapter.user_id == current_user.id
)
result = await db.execute(stmt)
total_words = result.scalar() or 0
# 计算使用情况(对话轮数 = Segment 数量)
from routers.quota import get_segment_count, get_chapter_count
segment_count = await get_segment_count(current_user.id, db)
chapter_count = await get_chapter_count(current_user.id, db)
usage = {
"conversations": conversation_count,
"conversations": segment_count, # 已用对话轮数
"chapters": chapter_count,
"words": total_words,
"max_conversations": plan.max_conversations,
"max_chapters": plan.max_chapters,
"max_words": plan.max_words
}
expires_at = None
if current_user.subscription_expires_at:
expires_at = current_user.subscription_expires_at.isoformat()
return CurrentPlanResponse(
plan_id=plan.id,
plan_name=plan.display_name,
subscription_type=current_user.subscription_type,
expires_at=None, # 目前没有过期时间概念
expires_at=expires_at,
features=plan.features,
usage=usage
)

View File

@@ -1,7 +1,9 @@
"""
配额检查 API 路由
「对话轮数」的定义每条用户发出的消息Segment 表的记录数)计为 1 轮。
"""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from typing import Optional
@@ -16,111 +18,164 @@ router = APIRouter(prefix="/api/quota", tags=["quota"])
class QuotaCheckResponse(BaseModel):
"""配额检查响应"""
has_quota: bool # 是否有配额
remaining_conversations: Optional[int] = None # 剩余对话
remaining_chapters: Optional[int] = None # 剩余章节数
remaining_words: Optional[int] = None # 剩余字数
message: str # 提示信息
has_quota: bool
remaining_conversations: Optional[int] = None # 剩余对话
remaining_chapters: Optional[int] = None
remaining_words: Optional[int] = None
# 已用量(前端展示 "已用 X / 共 Y"
used_conversations: Optional[int] = None
used_chapters: Optional[int] = None
max_conversations: Optional[int] = None
max_chapters: Optional[int] = None
message: str
# 计划配额限制
# 计划配额限制(与 plans.py 中 AVAILABLE_PLANS 一致)
# 免费50 轮对话 + 1 个章节整理Pro2000 轮无章节限制Pro+10000 轮无章节限制
PLAN_QUOTAS = {
"free": {
"max_conversations": 3,
"max_chapters": 10,
"max_words": 50000
"max_conversations": 50,
"max_chapters": 1,
"max_words": None
},
"pro": {
"max_conversations": 2000,
"max_chapters": None,
"max_words": None
},
"pro_plus": {
"max_conversations": 10000,
"max_chapters": None,
"max_words": None
},
# 兼容旧字段
"premium": {
"max_conversations": None, # 无限制
"max_conversations": None,
"max_chapters": None,
"max_words": None
}
}
async def get_segment_count(user_id: str, db: AsyncSession) -> int:
"""
获取用户已消耗的对话轮数(= 该用户所有 Segment 记录数)。
每条 Segment 对应一次用户发送的消息(文本/语音)。
"""
from database.models import Segment, Conversation
stmt = (
select(func.count(Segment.id))
.join(Conversation, Segment.conversation_id == Conversation.id)
.where(Conversation.user_id == user_id)
)
result = await db.execute(stmt)
return result.scalar() or 0
async def get_chapter_count(user_id: str, db: AsyncSession) -> int:
"""获取用户当前章节数量"""
from database.models import Chapter
stmt = select(func.count(Chapter.id)).where(Chapter.user_id == user_id)
result = await db.execute(stmt)
return result.scalar() or 0
# 保留旧名称别名,避免已有引用报错
async def get_conversation_count(user_id: str, db: AsyncSession) -> int:
"""别名:实际按 Segment 计数"""
return await get_segment_count(user_id, db)
def check_can_send_message(
subscription_type: str,
segment_count: int
) -> tuple[bool, str]:
"""
检查用户是否还能发送消息(对话轮数)。
返回 (是否允许, 提示信息)。
"""
quotas = PLAN_QUOTAS.get(subscription_type, PLAN_QUOTAS["free"])
max_conv = quotas.get("max_conversations")
if max_conv is None:
return True, ""
if segment_count >= max_conv:
return False, f"对话轮数已用完({segment_count}/{max_conv}),请升级 Pro 或 Pro+ 继续使用"
return True, ""
# 兼容旧调用
def check_can_create_conversation(
subscription_type: str,
conversation_count: int
) -> tuple[bool, str]:
return check_can_send_message(subscription_type, conversation_count)
def check_can_submit_organize(
subscription_type: str,
chapter_count: int
) -> tuple[bool, str]:
"""
检查是否可以提交整理任务(生成新章节)。
免费版仅允许 1 个章节。
返回 (是否允许, 提示信息)。
"""
quotas = PLAN_QUOTAS.get(subscription_type, PLAN_QUOTAS["free"])
max_ch = quotas.get("max_chapters")
if max_ch is None:
return True, ""
if chapter_count >= max_ch:
return False, "章节数量已达上限(免费版仅支持 1 个章节整理),请升级后继续"
return True, ""
@router.get("/check", response_model=QuotaCheckResponse)
async def check_quota(
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
):
"""
检查用户配额使用情况
根据用户的订阅计划检查是否还有配额可以使用
"""
plan_type = current_user.subscription_type
quotas = PLAN_QUOTAS.get(plan_type, PLAN_QUOTAS["free"])
# 如果是高级版,无限制
if plan_type == "premium":
return QuotaCheckResponse(
has_quota=True,
remaining_conversations=None,
remaining_chapters=None,
remaining_words=None,
message="高级版用户,无使用限制"
)
# 统计使用情况
async for db in get_async_db():
# 统计对话数量
from database.models import Conversation
stmt = select(func.count(Conversation.id)).where(
Conversation.user_id == current_user.id
)
result = await db.execute(stmt)
conversation_count = result.scalar() or 0
# 统计章节数量
from database.models import Chapter
stmt = select(func.count(Chapter.id)).where(
Chapter.user_id == current_user.id
)
result = await db.execute(stmt)
chapter_count = result.scalar() or 0
# 统计总字数
stmt = select(func.sum(func.length(Chapter.content))).where(
Chapter.user_id == current_user.id
)
result = await db.execute(stmt)
total_words = result.scalar() or 0
# 计算剩余配额
max_conversations = quotas.get("max_conversations")
max_chapters = quotas.get("max_chapters")
max_words = quotas.get("max_words")
remaining_conversations = None
remaining_chapters = None
remaining_words = None
if max_conversations is not None:
remaining_conversations = max(0, max_conversations - conversation_count)
if max_chapters is not None:
remaining_chapters = max(0, max_chapters - chapter_count)
if max_words is not None:
remaining_words = max(0, max_words - total_words)
# 检查是否有配额
has_quota = True
message = "配额充足"
if max_conversations is not None and conversation_count >= max_conversations:
has_quota = False
message = "对话次数已用完,请升级到高级版"
elif max_chapters is not None and chapter_count >= max_chapters:
has_quota = False
message = "章节数量已达上限,请升级到高级版"
elif max_words is not None and total_words >= max_words:
has_quota = False
message = "字数已达上限,请升级到高级版"
return QuotaCheckResponse(
has_quota=has_quota,
remaining_conversations=remaining_conversations,
remaining_chapters=remaining_chapters,
remaining_words=remaining_words,
message=message
)
# 统计已用量
segment_count = await get_segment_count(current_user.id, db)
chapter_count = await get_chapter_count(current_user.id, db)
max_conversations = quotas.get("max_conversations")
max_chapters = quotas.get("max_chapters")
max_words = quotas.get("max_words")
remaining_conversations = None
remaining_chapters = None
remaining_words = None
if max_conversations is not None:
remaining_conversations = max(0, max_conversations - segment_count)
if max_chapters is not None:
remaining_chapters = max(0, max_chapters - chapter_count)
# 检查是否有配额
has_quota = True
message = "配额充足"
if max_conversations is not None and segment_count >= max_conversations:
has_quota = False
message = f"对话轮数已用完({segment_count}/{max_conversations}),请升级 Pro 或 Pro+ 继续使用"
elif max_chapters is not None and chapter_count >= max_chapters:
has_quota = False
message = "章节数量已达上限(免费版仅支持 1 个章节整理),请升级后继续"
return QuotaCheckResponse(
has_quota=has_quota,
remaining_conversations=remaining_conversations,
remaining_chapters=remaining_chapters,
remaining_words=remaining_words,
used_conversations=segment_count,
used_chapters=chapter_count,
max_conversations=max_conversations,
max_chapters=max_chapters,
message=message
)

View File

@@ -1,15 +1,23 @@
"""
用户相关 API 路由
"""
from fastapi import APIRouter, Depends
import os
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from typing import Optional
from typing import Optional, Literal
from middleware.auth import get_current_user
from database.models import User
from database.models import User, utc_now
from database import get_async_db
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter(prefix="/api/user", tags=["user"])
# 是否开启测试订阅(仅用于微信支付审核未通过前的测试)
ENABLE_TEST_SUBSCRIPTION = os.getenv("ENABLE_TEST_SUBSCRIPTION", "").lower() in ("1", "true", "yes")
class UserProfileResponse(BaseModel):
"""用户资料响应"""
@@ -22,13 +30,26 @@ class UserProfileResponse(BaseModel):
created_at: str
class TestSubscriptionRequest(BaseModel):
"""测试订阅请求"""
action: Literal["activate", "deactivate"]
plan_id: Optional[str] = "pro" # activate 时生效,仅支持 pro / pro_plus
class TestSubscriptionResponse(BaseModel):
"""测试订阅响应"""
success: bool
message: str
subscription_type: str
@router.get("/profile", response_model=UserProfileResponse)
async def get_user_profile(
current_user: User = Depends(get_current_user)
):
"""
获取当前用户资料
与 /api/auth/me 功能相同,但路径不同以满足前端需求
"""
return UserProfileResponse(
@@ -40,3 +61,43 @@ async def get_user_profile(
subscription_type=current_user.subscription_type,
created_at=current_user.created_at.isoformat()
)
@router.post("/test-subscription", response_model=TestSubscriptionResponse)
async def test_subscription(
body: TestSubscriptionRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db),
):
"""
测试订阅开关(仅当 ENABLE_TEST_SUBSCRIPTION=1 时可用)。
- activate将当前用户设为付费套餐pro 或 pro_plus用于在微信支付审核通过前测试付费后额度。
- deactivate恢复为免费体验版。
"""
if not ENABLE_TEST_SUBSCRIPTION:
raise HTTPException(status_code=404, detail="测试订阅功能未开放")
now = utc_now()
if body.action == "activate":
if body.plan_id not in ("pro", "pro_plus"):
raise HTTPException(status_code=400, detail="plan_id 仅支持 pro 或 pro_plus")
current_user.subscription_type = body.plan_id
current_user.subscription_expires_at = now + timedelta(days=365)
await db.flush()
return TestSubscriptionResponse(
success=True,
message=f"已开启测试订阅:{body.plan_id}",
subscription_type=body.plan_id,
)
# deactivate
current_user.subscription_type = "free"
current_user.subscription_expires_at = None
await db.flush()
return TestSubscriptionResponse(
success=True,
message="已关闭测试订阅,恢复免费体验版",
subscription_type="free",
)

View File

@@ -170,6 +170,18 @@ async def websocket_endpoint(
text_message = message.get("data", {}).get("text", "")
if text_message:
# 校验对话轮数配额
from routers.quota import get_segment_count, check_can_send_message
seg_count = await get_segment_count(user_id, db)
can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count)
if not can_send:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
# 保存段落到数据库
segment = Segment(
id=str(uuid.uuid4()),
@@ -199,6 +211,18 @@ async def websocket_endpoint(
audio_duration = data.get("duration", 0)
if audio_base64:
# 校验对话轮数配额
from routers.quota import get_segment_count, check_can_send_message
seg_count = await get_segment_count(user_id, db)
can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count)
if not can_send:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": quota_msg, "code": "QUOTA_EXCEEDED"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
continue
logger.info(f"收到音频消息,时长: {audio_duration}s")
try:
@@ -427,7 +451,21 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession):
# 没有未处理的段落,直接 flush 待处理任务
await manager.background_runner.flush_pending(conversation.user_id)
return
# 免费版仅允许 1 个章节整理,提交前校验
from database.models import User as UserModel
from routers.quota import get_chapter_count, check_can_submit_organize
user = await db.get(UserModel, conversation.user_id)
if user:
chapter_count = await get_chapter_count(user.id, db)
can_submit, _ = check_can_submit_organize(user.subscription_type, chapter_count)
if not can_submit:
logger.info(
f"用户 {user.id} 章节配额已用尽,跳过提交整理任务: conversation_id={conversation_id}"
)
await manager.background_runner.flush_pending(conversation.user_id)
return
# 将未处理的段落直接提交到 Celery不通过去抖
segment_ids = [seg.id for seg in segments]
try: