Files
life-echo/api/app/features/conversation/ws/pipeline.py
Kevin 69a673e6c6 feat(api): 访谈人格/回复长度策略、口述归一、背景语气与输入净稿全链路
Chat 访谈
- 新增 persona 系统(default / warm_listener / curious_guide)与 background_voice 语气层
- 回复长度由 compute_reply_plan 统一决策(brief / standard / expanded),融合信息密度启发式
- 输入净稿(input_normalize):编排层可选 rules/llm 归一用户口语后再喂模型与记忆检索
- 记忆证据注入:按用户话检索 memory evidence 并注入 prompt

Memoir 回忆录
- 口述归一(oral_normalize):segment 原文保留,story 管线取派生净稿作叙事输入
- segment 入队批次门闸:累计字数 + 最长等待秒数,减少零碎提交
- fidelity_check / prompts / narrative_agent 微调
- Alembic 0005:清理跨章节 story 外键

Infra
- Dockerfile 加入 ffmpeg
- pyproject.toml 新增依赖并同步 uv.lock
- .env.example / .env.production 补全新配置项

Tests
- 新增 test_background_voice、test_chat_input_normalize、test_experience_regressions
- 扩展 test_interview_prompts、test_interview_reply_length、test_story_route_oral_invariant

Made-with: Cursor
2026-03-31 23:55:26 +08:00

840 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""核心消息处理管道Agent 调用、ASR 转写、分段有序聚合"""
import asyncio
import base64
import io
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from app.core.logging import get_logger
if TYPE_CHECKING:
from app.features.quota.service import QuotaService
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat import ChatOrchestrator
from app.core.agent_logging import agent_summary_enabled
from app.core.config import settings
from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
from app.features.conversation.history_store import (
AI_RESPONSE_SEGMENT_JOIN,
ConversationHistoryStore,
)
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.profile_collector import (
apply_extracted_profile,
get_filled_profile_fields,
get_missing_profile_fields,
)
from app.features.memoir.background_runner import BackgroundTaskRunner
from app.features.user.models import User
# Module-level logger from the project's logging helper. Calls in this module
# use ``{}`` placeholders with positional arguments.
logger = get_logger(__name__)
# Per-conversation TTS cancellation counter. It is bumped whenever the client
# sends ``tts_cancel``. ``process_user_message`` snapshots the value before its
# TTS loop and compares it around each synthesis, so the remaining chunks are
# skipped once a cancel has been observed.
_tts_cancel_epoch: dict[str, int] = {}


def _tts_epoch_value(conversation_id: str) -> int:
    """Return the current cancel epoch for *conversation_id* (0 if never bumped)."""
    try:
        return _tts_cancel_epoch[conversation_id]
    except KeyError:
        return 0


def bump_tts_cancel_epoch(conversation_id: str) -> None:
    """Advance the cancel epoch so in-flight TTS work for this conversation stops."""
    _tts_cancel_epoch[conversation_id] = _tts_epoch_value(conversation_id) + 1
def _tts_object_ext(codec: str) -> str:
    """Map a TTS codec name to the file extension used for the stored object.

    ``None``/empty input defaults to ``"mp3"``, and so does a codec that is only
    dots. ``"wave"`` is normalised to ``"wav"``. Anything else is returned
    lower-cased with leading dots removed.
    """
    normalized = (codec or "mp3").lower().lstrip(".")
    if not normalized:
        return "mp3"
    return "wav" if normalized == "wave" else normalized
def _tts_codec_to_content_type(codec: str) -> str:
    """Return the MIME type used when uploading TTS audio for *codec*.

    Only mp3 and wav/wave are recognised; ``None``/empty defaults to mp3. Any
    other codec is uploaded as a generic binary stream.
    """
    known_types = {
        "mp3": "audio/mpeg",
        "wav": "audio/wav",
        "wave": "audio/wav",
    }
    normalized = (codec or "mp3").lower().lstrip(".")
    return known_types.get(normalized, "application/octet-stream")
async def _send_tts_audio(
    conversation_id: str,
    text: str,
    *,
    chunk_index: int,
    chunk_total: int,
    assistant_message_id: str | None,
    tts_epoch_start: int,
) -> str | None:
    """Synthesize one TTS chunk, upload it, and push a TTS_AUDIO message.

    Args:
        conversation_id: Conversation the audio belongs to; also the WS target.
        text: Text to synthesize.
        chunk_index: Zero-based index of this chunk within the reply.
        chunk_total: Number of chunks in the reply.
        assistant_message_id: Assistant message id, forwarded to the client
            when set.
        tts_epoch_start: Cancel epoch captured by the caller before its TTS
            loop. If the epoch has moved (client sent ``tts_cancel``), the
            chunk is dropped.

    Returns:
        The URL returned by ``storage.upload``. Returns ``None`` when TTS is
        disabled, cancelled, produced no audio, or failed (failures are logged,
        not raised).
    """
    if not settings.enable_tts:
        return None
    if _tts_epoch_value(conversation_id) != tts_epoch_start:
        return None
    try:
        tts = get_tts_provider()
        audio_bytes = await tts.synthesize(text)
        if not audio_bytes:
            logger.warning(
                "TTS skipped: synthesize returned empty. Check TTS config in .env"
            )
            return None
        # Re-check after synthesis: a cancel may have arrived while awaiting.
        if _tts_epoch_value(conversation_id) != tts_epoch_start:
            return None
        ext = _tts_object_ext(settings.tts_codec)
        content_type = _tts_codec_to_content_type(settings.tts_codec)
        storage = get_object_storage()
        key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}"
        # The storage client's upload/get_url are synchronous calls. Run them in
        # a worker thread so the object-storage upload does not block the event
        # loop that serves every other WebSocket connection.
        public_url = await asyncio.to_thread(
            storage.upload, key, audio_bytes, content_type
        )
        # Send a playable (presigned) URL, consistent with
        # `tts_delivery.apply_presigned_tts_urls_to_messages` and the memoir
        # image presigning.
        playback_url = await asyncio.to_thread(
            storage.get_url, key, expires=TTS_PRESIGNED_EXPIRES_SEC
        )
        payload_data: Dict[str, Any] = {
            "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
            "format": settings.tts_codec,
            "audio_url": playback_url,
            "index": chunk_index,
            "total": chunk_total,
        }
        if assistant_message_id:
            payload_data["assistant_message_id"] = assistant_message_id
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TTS_AUDIO,
                "conversation_id": conversation_id,
                "data": payload_data,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
        return public_url
    except Exception as e:
        err_str = str(e)
        if "PkgExhausted" in err_str:
            # Tencent Cloud TTS resource package exhausted: warn, don't error.
            logger.warning(
                "TTS skipped: 腾讯云语音合成资源包已用尽,请在控制台购买或开通后付费: {}",
                err_str[:100],
            )
        else:
            logger.error("TTS synthesize failed: {}", e)
        return None
# ── Agent instances (moved out of ConnectionManager) ─────────────────────
# Module-level singletons used by every call in this module.
chat_orchestrator = ChatOrchestrator()
background_runner = BackgroundTaskRunner()
# ── Segment stream state ──────────────────────────────────────────────────
@dataclass
class SegmentStreamState:
    """Per-(conversation, voice session) state for segmented voice input.

    Lets segment ASR run in parallel while transcripts are still handed to the
    agent in ``segment_index`` order.
    """

    # Guards the mutable fields below across concurrent segment tasks.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Indices in flight; discarded in process_audio_segment's ``finally``.
    # NOTE(review): nothing in this module adds to it — presumably the caller does.
    pending_indices: Set[int] = field(default_factory=set)
    # Indices whose transcript was persisted (or found already persisted).
    processed_indices: Set[int] = field(default_factory=set)
    # Out-of-order transcripts waiting for earlier indices: index -> (text, segment).
    buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
    # Highest contiguous index already handed to the agent (or restored from the
    # DB on a fresh state); -1 means none yet.
    consumed_index: int = -1
    # Running per-segment tasks; the state is dropped once idle and disconnected.
    active_tasks: Set[asyncio.Task] = field(default_factory=set)
    # Whether the one-shot "I'm listening" transition message has been sent.
    listening_feedback_sent: bool = False
    # Scheduled delayed-feedback task, if any.
    listening_feedback_task: Optional[asyncio.Task] = None
# Live segment-stream states keyed by (conversation_id, voice_session_id).
_segment_states: Dict[Tuple[str, str], SegmentStreamState] = {}


def get_or_create_segment_state(
    conversation_id: str,
    voice_session_id: str,
) -> SegmentStreamState:
    """Return the state for this voice session, registering a new one on first use."""
    key = (conversation_id, voice_session_id)
    state = _segment_states.get(key)
    if state is None:
        state = SegmentStreamState()
        _segment_states[key] = state
    return state
def register_segment_task(
    conversation_id: str,
    voice_session_id: str,
    task: asyncio.Task,
) -> None:
    """Track *task* on its segment state; clean up and log when it finishes.

    When the last active task of a voice session finishes and the conversation
    has no active WebSocket connection, the session's state is removed from
    ``_segment_states``. Exceptions raised by the task are logged; cancellation
    is not treated as an error.
    """
    state_key = (conversation_id, voice_session_id)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    state.active_tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        state.active_tasks.discard(done_task)
        if not state.active_tasks and conversation_id not in manager.active_connections:
            _segment_states.pop(state_key, None)
        if done_task.cancelled():
            return
        exc = done_task.exception()
        if exc is not None:
            # Pass values as arguments instead of building the message with an
            # f-string. Other calls in this module use ``{}`` placeholders, so a
            # template containing the exception text could have stray braces
            # parsed as format fields. A done-callback has no *active*
            # exception, so hand the exception object itself to exc_info.
            logger.error(
                "分段处理任务异常 (conversation_id={}, voice_session_id={}): {}",
                conversation_id,
                voice_session_id,
                exc,
                exc_info=exc,
            )

    task.add_done_callback(_cleanup)
def cleanup_segment_states(conversation_id: str) -> None:
    """After a disconnect, drop this conversation's segment states that have no running tasks."""
    for state_key in list(_segment_states):
        owner_conversation, _voice_session = state_key
        if owner_conversation != conversation_id:
            continue
        if _segment_states[state_key].active_tasks:
            continue
        del _segment_states[state_key]
# ── Helpers ───────────────────────────────────────────────────────────────
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
def _mark_conversation_active(
    conversation: Conversation, at: Optional[datetime] = None
) -> datetime:
    """Stamp ``conversation.last_message_at`` and return the timestamp used.

    Uses *at* when it is truthy, otherwise the current UTC time. Only the ORM
    object is modified; committing is left to the caller.
    """
    if at:
        moment = at
    else:
        moment = _utc_now()
    conversation.last_message_at = moment
    return moment
def _voice_session_id_from_client_segment_id(
    client_segment_id: Optional[str],
) -> Optional[str]:
    """Derive the voice session id from a client segment id shaped ``<session>-<n>``.

    Everything before the last ``-`` is the session id. Returns ``None`` for a
    missing id, an id without ``-``, or an empty session part.
    """
    if not client_segment_id:
        return None
    head, dash, _tail = client_segment_id.rpartition("-")
    return head if (dash and head) else None
def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
    """Build the idempotency key stored in ``Segment.audio_url`` for a voice segment.

    Format: ``audio-segment:{voice_session_id}:{segment_index}``. The
    conversation id is not part of the key; lookups pair the key with
    ``Segment.conversation_id`` instead.
    """
    return "audio-segment:{}:{}".format(voice_session_id, segment_index)
def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
    """Parse ``audio-segment:{session_id}:{index}`` into ``(session_id, index)``.

    The session id is everything between the prefix and the last ``:``, with
    surrounding whitespace stripped. Returns ``None`` in these cases:

    - *audio_url* is empty.
    - The prefix or the ``:`` separator is missing.
    - The session id is empty.
    - The index is not an integer.
    """
    marker = "audio-segment:"
    if not (audio_url and audio_url.startswith(marker)):
        return None
    session_part, colon, index_part = audio_url[len(marker):].rpartition(":")
    if not colon:
        return None
    session_id = session_part.strip()
    if not session_id:
        return None
    try:
        index = int(index_part)
    except ValueError:
        return None
    return session_id, index
def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]:
    """Return the voice session id encoded in a segment ``audio_url``, or ``None``."""
    scope = _extract_segment_scope(audio_url)
    return scope[0] if scope else None
def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
    """Return True for a missing/empty transcript or one starting with the ``转写失败`` marker."""
    return (not transcript_text) or transcript_text.startswith("转写失败")
async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Return the newest persisted Segment for this voice-session index, if any.

    Voice segments are identified by the key from ``_build_segment_audio_url``,
    stored in ``Segment.audio_url`` and scoped by ``Segment.conversation_id``.
    """
    segment_audio_url = _build_segment_audio_url(voice_session_id, segment_index)
    stmt = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.audio_url == segment_audio_url,
        )
        .order_by(Segment.created_at.desc())
        # Only the newest match is used, so don't load every duplicate row.
        .limit(1)
    )
    result = await db.execute(stmt)
    # The SQL already filters on both columns. The Python-side re-check is kept
    # defensively (presumably for test doubles that ignore the WHERE clause).
    return next(
        (
            item
            for item in result.scalars().all()
            if item.conversation_id == conversation_id
            and item.audio_url == segment_audio_url
        ),
        None,
    )
async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the highest segment_index persisted contiguously from 0 for a voice session.

    Used to restore ``consumed_index`` after a reconnect. Returns -1 when
    segment 0 of the session is not persisted.
    """
    # Narrow the query to this voice session's keys instead of loading every
    # Segment of the conversation. autoescape=True stops '%' / '_' in the
    # session id from acting as LIKE wildcards.
    key_prefix = f"audio-segment:{voice_session_id}:"
    stmt = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.audio_url.startswith(key_prefix, autoescape=True),
    )
    result = await db.execute(stmt)
    persisted_indices: Set[int] = set()
    for item in result.scalars().all():
        # Exact re-check: a prefix match may belong to a longer session id
        # (e.g. "a" vs "a:b"). It also guards against sessions/doubles that
        # ignore the WHERE clause.
        if item.conversation_id != conversation_id:
            continue
        segment_scope = _extract_segment_scope(item.audio_url)
        if segment_scope is None:
            continue
        item_voice_session_id, item_index = segment_scope
        if item_voice_session_id == voice_session_id:
            persisted_indices.add(item_index)
    contiguous_index = -1
    while contiguous_index + 1 in persisted_indices:
        contiguous_index += 1
    return contiguous_index
# ── Transition ("I'm listening") feedback ────────────────────────────────
# Seconds to wait after recording starts before sending the one-shot message.
LISTENING_FEEDBACK_DELAY_SEC = 5.0
# User-facing text ("I'm listening carefully; keep going, I'll note the key points as I listen").
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"
async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
) -> None:
    """Push the "I'm listening" AGENT_RESPONSE, flagged as a transition message.

    Invoked by the delayed feedback task.
    """
    payload = {
        "text": LISTENING_FEEDBACK_TEXT,
        "transition": True,
        "segment_index": segment_index,
    }
    envelope = {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": payload,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    await manager.send_message(conversation_id, envelope)
async def _delayed_listening_feedback(
    conversation_id: str,
    voice_session_id: str,
) -> None:
    """After LISTENING_FEEDBACK_DELAY_SEC, send the "I'm listening" message once per session.

    The ``listening_feedback_sent`` flag is checked and claimed under the state
    lock, so concurrent invocations send at most one message. Suppressing it
    when recording ends early presumably relies on the caller cancelling this
    task.
    """
    await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with state.lock:
        if not state.listening_feedback_sent:
            state.listening_feedback_sent = True
            state.listening_feedback_task = None
            await _send_segment_transition_feedback(conversation_id, 0)
# ── Long-audio chunked transcription ──────────────────────────────────────
# Maximum length of one ASR request, in milliseconds.
MAX_ASR_CHUNK_MS = 55_000


def _split_audio_bytes(audio_bytes: bytes, fmt: str) -> list[bytes]:
    """Split audio longer than MAX_ASR_CHUNK_MS into WAV chunks for ASR.

    Audio no longer than the limit is returned unchanged as a one-item list.
    Longer audio is converted to 16 kHz, mono, 16-bit PCM. Each ≤55 s slice is
    exported as a WAV, about 1.76 MB per chunk, which is within the 3 MB
    Tencent ASR limit.
    """
    from pydub import AudioSegment as PydubSegment

    audio = PydubSegment.from_file(io.BytesIO(audio_bytes), format=fmt)
    total_ms = len(audio)
    if total_ms <= MAX_ASR_CHUNK_MS:
        return [audio_bytes]
    normalized = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)

    def _to_wav(offset_ms: int) -> bytes:
        sink = io.BytesIO()
        normalized[offset_ms : offset_ms + MAX_ASR_CHUNK_MS].export(sink, format="wav")
        return sink.getvalue()

    return [_to_wav(offset) for offset in range(0, total_ms, MAX_ASR_CHUNK_MS)]
async def _transcribe_long_audio(audio_bytes: bytes, fmt: str = "m4a") -> str:
    """Transcribe audio, splitting anything longer than 55 s into parallel ASR calls.

    Short audio, and audio pydub cannot split, is sent to ASR unchanged. For
    long audio, chunk transcripts are concatenated in order. Chunks that raise
    or return a failure marker are skipped with a warning.

    If no chunk yields text, the first chunk's failure message is returned
    (when there is one). Callers then see the same ``转写失败…`` detail they
    get for short audio, rather than an empty string.
    """
    asr = get_asr_provider()
    try:
        chunks = await asyncio.to_thread(_split_audio_bytes, audio_bytes, fmt)
    except Exception as exc:
        logger.warning("pydub 切片失败 ({}), 回退到直接转写", exc)
        return await asr.transcribe(audio_bytes, format=fmt)
    if len(chunks) <= 1:
        return await asr.transcribe(audio_bytes, format=fmt)
    logger.info("长音频切片: {}", len(chunks))
    results = await asyncio.gather(
        *(asr.transcribe(chunk, format="wav") for chunk in chunks),
        return_exceptions=True,
    )
    texts: list[str] = []
    first_failure: Optional[str] = None
    for i, r in enumerate(results):
        if isinstance(r, BaseException):
            logger.warning("切片 {} 转写异常: {}", i, r)
            continue
        if _is_transcribe_failure(r):
            if r:
                logger.warning("切片 {} 转写失败: {}", i, r)
                if first_failure is None:
                    first_failure = r
            continue
        texts.append(r)
    if not texts and first_failure is not None:
        return first_failure
    return "".join(texts)
# ── Segmented voice: async processing ─────────────────────────────────────
async def process_audio_segment(
    conversation_id: str,
    user_id: str,
    voice_session_id: str,
    segment_index: int,
    audio_base64: str,
    audio_duration: int,
    is_last: bool,
) -> None:
    """Process one voice segment: ASR, idempotent persistence, ordered hand-off to the agent.

    Segments of one voice session may run concurrently. Their transcripts are
    buffered in ``SegmentStreamState`` and passed to ``process_user_message``
    strictly in ``segment_index`` order.

    Args:
        conversation_id: Target conversation; also the WebSocket channel.
        user_id: Id used to load the ``User`` row.
        voice_session_id: Recording session the segment belongs to.
        segment_index: Position of the segment within the voice session.
        audio_base64: Base64-encoded audio, transcribed as m4a.
        audio_duration: Client-reported duration in seconds; stored only when > 0.
        is_last: Client flag echoed back in the TRANSCRIPT message.

    Exceptions are caught, logged, and reported to the client as an ERROR
    message.
    """
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    logger.info(
        "process_audio_segment 开始: conversation_id={} voice_session_id={} "
        "segment_index={} is_last={} duration_s={} audio_b64_len={}",
        conversation_id,
        voice_session_id,
        segment_index,
        is_last,
        audio_duration,
        len(audio_base64 or ""),
    )
    try:
        async with AsyncSessionLocal() as db:
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(User, user_id)
            if not conversation or conversation.deleted_at is not None:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": "对话不存在,分段处理已取消"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return
            if not user:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": "用户不存在,分段处理已取消"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return
            # On a fresh in-memory state (e.g. after a reconnect), seed
            # consumed_index from what is already persisted, so segments handled
            # before the reconnect are not waited for again.
            async with state.lock:
                should_prime_state = (
                    state.consumed_index < 0
                    and not state.processed_indices
                    and not state.buffered_transcripts
                )
            if should_prime_state:
                persisted_contiguous_index = (
                    await _get_persisted_contiguous_segment_index(
                        db=db,
                        conversation_id=conversation_id,
                        voice_session_id=voice_session_id,
                    )
                )
                if persisted_contiguous_index >= 0:
                    async with state.lock:
                        # max(): another task may have advanced it while the
                        # lock was released for the DB query.
                        state.consumed_index = max(
                            state.consumed_index, persisted_contiguous_index
                        )
            try:
                audio_bytes = base64.b64decode(audio_base64)
            except Exception:
                audio_bytes = b""
            if not audio_bytes:
                # Only logged; transcription is still attempted below.
                logger.warning(
                    "process_audio_segment: 解码后音频为空 conversation_id={} segment_index={}",
                    conversation_id,
                    segment_index,
                )
            transcript_text = await _transcribe_long_audio(audio_bytes, fmt="m4a")
            # The transcript (even a failure marker) is echoed to the client first.
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.TRANSCRIPT,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": transcript_text or "",
                        "audio_duration": audio_duration,
                        "voice_session_id": voice_session_id,
                        "segment_index": segment_index,
                        "is_last": is_last,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )
            if _is_transcribe_failure(transcript_text):
                detail = (transcript_text or "").strip()
                if detail.startswith("转写失败"):
                    user_msg = f"分段 {segment_index} {detail}"
                elif not detail:
                    user_msg = f"分段 {segment_index} 转写失败:未识别到内容(请检查后端 ASR 配置)"
                else:
                    user_msg = f"分段 {segment_index} 转写失败:{detail[:400]}"
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {
                            "message": user_msg,
                            "segment_index": segment_index,
                        },
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                # NOTE(review): a failed index is neither buffered nor consumed,
                # so later indices of this session stay in buffered_transcripts
                # until this index is resent and succeeds — confirm clients retry.
                return
            existing_segment = await _find_existing_segment_by_index(
                db=db,
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
                segment_index=segment_index,
            )
            if existing_segment:
                # Idempotent re-delivery: mark processed but do not re-run the agent.
                # NOTE(review): consumed_index is not advanced here; it only moves
                # past this index if the priming step above covered it.
                async with state.lock:
                    state.processed_indices.add(segment_index)
                logger.debug(
                    "分段已存在,按幂等跳过: conversation_id={} voice_session_id={} "
                    "segment_index={} segment_id={} transcript={}",
                    conversation_id,
                    voice_session_id,
                    segment_index,
                    existing_segment.id,
                    existing_segment.user_input_text or "",
                )
                return
            else:
                segment = Segment(
                    id=str(uuid.uuid4()),
                    conversation_id=conversation_id,
                    user_input_text=transcript_text or "",
                    audio_url=_build_segment_audio_url(voice_session_id, segment_index),
                    audio_duration_seconds=audio_duration
                    if audio_duration > 0
                    else None,
                    processed=False,
                )
                db.add(segment)
                user_message_timestamp = _mark_conversation_active(conversation)
                await db.commit()
                await db.refresh(segment)
                # Hand the persisted segment to the background (memoir) runner.
                await background_runner.queue_message(
                    conversation.user_id,
                    segment.id,
                    text_char_count=len((transcript_text or "").strip()),
                )
            # Buffer this transcript, then drain every contiguous index after
            # consumed_index, so the agent always sees segments in order.
            ready_segments: List[Tuple[int, str, Segment]] = []
            async with state.lock:
                state.processed_indices.add(segment_index)
                state.buffered_transcripts[segment_index] = (
                    transcript_text or "",
                    segment,
                )
                next_index = state.consumed_index + 1
                while next_index in state.buffered_transcripts:
                    text, seg = state.buffered_transcripts.pop(next_index)
                    ready_segments.append((next_index, text, seg))
                    state.consumed_index = next_index
                    next_index += 1
            # Agent calls run outside the lock, sequentially, in index order.
            for _, ordered_text, ordered_segment in ready_segments:
                await process_user_message(
                    conversation_id=conversation_id,
                    user_message=ordered_text,
                    conversation=conversation,
                    segment=ordered_segment,
                    db=db,
                    user=user,
                    user_message_timestamp=ordered_segment.created_at
                    or user_message_timestamp,
                )
    except Exception as e:
        # NOTE(review): this f-string is passed with a kwarg; if the logger
        # str.format()s templates (as its ``{}`` usage elsewhere suggests),
        # braces inside ``e`` could break formatting — consider placeholders.
        logger.error(
            f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
            exc_info=True,
        )
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.ERROR,
                "data": {
                    "message": f"分段处理失败: {str(e)}",
                    "segment_index": segment_index,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
    finally:
        async with state.lock:
            state.pending_indices.discard(segment_index)
# ── User message handling ─────────────────────────────────────────────────
async def _attach_tts_urls(
    store: ConversationHistoryStore,
    db: AsyncSession,
    conversation_id: str,
    segment_id: str,
    tts_urls: list[str],
) -> None:
    """Record TTS URLs on the AI history entry and the Segment row (caller commits)."""
    await store.attach_ai_tts_audio_urls(
        conversation_id,
        tts_audio_urls=tts_urls,
        segment_id=segment_id,
    )
    await db.execute(
        update(Segment)
        .where(Segment.id == segment_id)
        .values(tts_audio_urls=tts_urls)
    )


async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    user: Optional[User] = None,
    user_message_timestamp: Optional[datetime] = None,
) -> None:
    """Generate the agent reply for one user message and deliver it to the client.

    ``ChatOrchestrator`` routes the turn to the ProfileAgent or InterviewAgent.

    Steps:
        1. Record the reply on ``segment.agent_response`` and in the
           conversation history.
        2. Send each reply chunk as AGENT_RESPONSE (0.5 s apart), each followed
           by its TTS audio unless the turn says ``skip_tts``.
        3. Stop early if the client cancels TTS (the cancel epoch moves).
        4. Persist collected TTS URLs and commit.

    Errors are caught. TTS URLs produced so far are persisted on a best-effort
    basis, and an ERROR message is sent if the client is still connected.
    """
    store = ConversationHistoryStore(db)
    # Collected outside the loop so the except path can still persist them.
    tts_urls: list[str] = []
    try:
        logger.info(
            "process_user_message 开始: conversation_id={} segment_id={} user_chars={}",
            conversation_id,
            segment.id,
            len(user_message or ""),
        )
        is_from_voice = bool(segment.audio_url)
        voice_session_id = _voice_session_id_from_audio_url(segment.audio_url)
        audio_dur = getattr(segment, "audio_duration_seconds", None)
        t_pipeline = time.perf_counter()
        turn = await chat_orchestrator.process_user_message(
            conversation_id=conversation_id,
            user_message=user_message,
            user=user,
            conversation=conversation,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            db=db,
            apply_extracted_profile_fn=apply_extracted_profile,
            get_missing_profile_fields_fn=get_missing_profile_fields,
            get_filled_profile_fields_fn=get_filled_profile_fields,
            user_message_timestamp=user_message_timestamp,
            audio_duration_seconds=audio_dur,
        )
        if agent_summary_enabled():
            logger.info(
                "pipeline.process_user_message duration_ms={:.2f} "
                "conversation_id={} segment_id={} user_msg_len={} "
                "response_segments={} skip_tts={}",
                (time.perf_counter() - t_pipeline) * 1000,
                conversation_id,
                segment.id,
                len(user_message or ""),
                len(turn.messages),
                turn.skip_tts,
            )
        responses = turn.messages
        skip_tts = turn.skip_tts
        segment.agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses)
        _mark_conversation_active(conversation)
        ai_msg_id = await store.record_human_ai_turn(
            conversation_id=conversation_id,
            user_message=user_message,
            responses=responses,
            user_message_timestamp=user_message_timestamp,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            audio_duration_seconds=audio_dur,
            tts_audio_urls=None,
            segment_id=segment.id,
        )
        if not ai_msg_id:
            logger.warning(
                "process_user_message: 无有效助手段落responses 为空conversation_id={} segment_id={}",
                conversation_id,
                segment.id,
            )
            if conversation_id in manager.active_connections:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {
                            "message": "未生成回复,请重试或稍后再试",
                        },
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
            return
        tts_epoch_start = _tts_epoch_value(conversation_id)
        n = len(responses)
        for i, response_text in enumerate(responses):
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.AGENT_RESPONSE,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": response_text,
                        "index": i,
                        "total": n,
                        "assistant_message_id": ai_msg_id,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )
            if not skip_tts:
                if _tts_epoch_value(conversation_id) != tts_epoch_start:
                    break
                url = await _send_tts_audio(
                    conversation_id,
                    response_text,
                    chunk_index=i,
                    chunk_total=n,
                    assistant_message_id=ai_msg_id,
                    tts_epoch_start=tts_epoch_start,
                )
                if url:
                    tts_urls.append(url)
            # A tts_cancel observed here also stops the remaining text chunks.
            if _tts_epoch_value(conversation_id) != tts_epoch_start:
                break
            if i < n - 1:
                await asyncio.sleep(0.5)
        if tts_urls:
            await _attach_tts_urls(store, db, conversation_id, segment.id, tts_urls)
        # Commit unconditionally so agent_response / last_message_at are saved
        # even when no TTS URL was produced.
        await db.commit()
    except Exception as e:
        if tts_urls:
            # Best effort: keep the audio that was already delivered linked to
            # the history entry even though the turn failed part-way.
            try:
                await _attach_tts_urls(store, db, conversation_id, segment.id, tts_urls)
                await db.commit()
            except Exception as persist_error:
                logger.warning("补写 TTS 元数据失败: {}", persist_error)
        logger.error("处理用户消息失败: {}", e, exc_info=True)
        if conversation_id in manager.active_connections:
            try:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {"message": f"生成回应失败: {str(e)}"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
            except Exception as send_error:
                logger.warning("发送错误消息失败: {}", send_error)
# ── Conversation end ──────────────────────────────────────────────────────
async def process_conversation_segments(
    conversation_id: str, db: AsyncSession, quota_service: "QuotaService"
) -> None:
    """Submit a conversation's unprocessed segments for chapter generation.

    Called when a conversation ends. Most processing already happened
    incrementally through Celery; this submits whatever is still unprocessed
    in one ``process_memoir_segments`` task.

    The organize quota is checked via the injected ``quota_service`` (quota
    internals are not imported directly). The check is skipped when the owner
    ``User`` row cannot be loaded. Every path past the conversation check ends
    with ``background_runner.flush_pending`` for the conversation owner.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation or conversation.deleted_at is not None:
        return
    stmt = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.processed.is_(False),
    )
    result = await db.execute(stmt)
    segments = result.scalars().all()
    if not segments:
        await background_runner.flush_pending(conversation.user_id)
        return
    user = await db.get(User, conversation.user_id)
    if user:
        can_submit, _ = await quota_service.check_can_submit_organize(
            user.id, user.subscription_type
        )
        if not can_submit:
            logger.info(
                "用户 {} 章节配额已用尽,跳过提交整理任务: conversation_id={}",
                user.id,
                conversation_id,
            )
            await background_runner.flush_pending(conversation.user_id)
            return
    segment_ids = [seg.id for seg in segments]
    try:
        # Local import, as before (presumably to avoid an import cycle).
        from app.tasks.memoir_tasks import process_memoir_segments

        process_memoir_segments.delay(conversation.user_id, segment_ids)
        logger.info(
            "对话结束,提交 Celery 任务: conversation_id={}, segments={}",
            conversation_id,
            len(segment_ids),
        )
    except Exception as e:
        logger.error("提交 Celery 任务失败: {}", e)
    await background_runner.flush_pending(conversation.user_id)