chore/ 删除无用文件

This commit is contained in:
Kevin
2026-03-19 14:36:14 +08:00
parent 2f60858c9c
commit c6e07ce5ca
135 changed files with 2111 additions and 4510 deletions

View File

@@ -1,4 +1,5 @@
"""核心消息处理管道Agent 调用、ASR 转写、分段有序聚合"""
import asyncio
import base64
from app.core.logging import get_logger
@@ -19,7 +20,10 @@ from app.agents.memoir import BackgroundTaskRunner
from app.core.db import AsyncSessionLocal
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import LEGACY_VOICE_SESSION_ID, MessageType
from app.features.conversation.ws.message_types import (
LEGACY_VOICE_SESSION_ID,
MessageType,
)
from app.features.conversation.ws.profile_collector import (
apply_extracted_profile,
get_filled_profile_fields,
@@ -42,15 +46,18 @@ async def _send_tts_audio(conversation_id: str, text: str) -> None:
"TTS skipped: synthesize returned empty. Check TTS config in .env"
)
return
await manager.send_message(conversation_id, {
"type": MessageType.TTS_AUDIO,
"conversation_id": conversation_id,
"data": {
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
"format": settings.tts_codec,
await manager.send_message(
conversation_id,
{
"type": MessageType.TTS_AUDIO,
"conversation_id": conversation_id,
"data": {
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
"format": settings.tts_codec,
},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
)
except Exception as e:
err_str = str(e)
if "PkgExhausted" in err_str:
@@ -61,6 +68,7 @@ async def _send_tts_audio(conversation_id: str, text: str) -> None:
else:
logger.error("TTS synthesize failed: %s", e)
# ── Agent 实例(从 ConnectionManager 移出) ─────────────────────
conversation_agent = ConversationAgent()
chat_orchestrator = ChatOrchestrator()
@@ -70,6 +78,7 @@ background_runner = BackgroundTaskRunner()
# ── 分段流状态 ──────────────────────────────────────────────────
@dataclass
class SegmentStreamState:
"""会话内分段处理状态(用于并行 ASR + 有序聚合)"""
@@ -136,11 +145,14 @@ def cleanup_segment_states(conversation_id: str) -> None:
# ── 工具函数 ────────────────────────────────────────────────────
def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def _mark_conversation_active(conversation: Conversation, at: Optional[datetime] = None) -> datetime:
def _mark_conversation_active(
conversation: Conversation, at: Optional[datetime] = None
) -> datetime:
activity_time = at or _utc_now()
conversation.last_message_at = activity_time
return activity_time
@@ -152,7 +164,9 @@ def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
return LEGACY_VOICE_SESSION_ID
def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]:
def _voice_session_id_from_client_segment_id(
client_segment_id: Optional[str],
) -> Optional[str]:
if not client_segment_id:
return None
session_id, separator, _ = client_segment_id.rpartition("-")
@@ -171,11 +185,14 @@ def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]
prefix = "audio-segment:"
if not audio_url or not audio_url.startswith(prefix):
return None
payload = audio_url[len(prefix):]
payload = audio_url[len(prefix) :]
voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":")
try:
if separator:
return (_normalize_voice_session_id(voice_session_id_raw), int(segment_index_raw))
return (
_normalize_voice_session_id(voice_session_id_raw),
int(segment_index_raw),
)
return (LEGACY_VOICE_SESSION_ID, int(payload))
except ValueError:
return None
@@ -201,14 +218,21 @@ async def _find_existing_segment_by_index(
segment_index: int,
) -> Optional[Segment]:
segment_audio_url = _build_segment_audio_url(voice_session_id, segment_index)
stmt = select(Segment).where(
Segment.conversation_id == conversation_id,
Segment.audio_url == segment_audio_url,
).order_by(Segment.created_at.desc())
stmt = (
select(Segment)
.where(
Segment.conversation_id == conversation_id,
Segment.audio_url == segment_audio_url,
)
.order_by(Segment.created_at.desc())
)
result = await db.execute(stmt)
candidates = result.scalars().all()
for item in candidates:
if item.conversation_id == conversation_id and item.audio_url == segment_audio_url:
if (
item.conversation_id == conversation_id
and item.audio_url == segment_audio_url
):
return item
return None
@@ -252,16 +276,19 @@ async def _send_segment_transition_feedback(
segment_index: int,
) -> None:
"""发送一次「我在认真听」陪伴式过渡反馈(由延迟任务调用)。"""
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {
"text": LISTENING_FEEDBACK_TEXT,
"transition": True,
"segment_index": segment_index,
await manager.send_message(
conversation_id,
{
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {
"text": LISTENING_FEEDBACK_TEXT,
"transition": True,
"segment_index": segment_index,
},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
)
async def _delayed_listening_feedback(
@@ -281,6 +308,7 @@ async def _delayed_listening_feedback(
# ── 分段语音异步处理 ────────────────────────────────────────────
async def process_audio_segment(
conversation_id: str,
user_id: str,
@@ -298,18 +326,24 @@ async def process_audio_segment(
conversation = await db.get(Conversation, conversation_id)
user = await db.get(User, user_id)
if not conversation:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "对话不存在,分段处理已取消"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "对话不存在,分段处理已取消"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
return
if not user:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "用户不存在,分段处理已取消"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": "用户不存在,分段处理已取消"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
return
async with state.lock:
@@ -320,14 +354,18 @@ async def process_audio_segment(
)
if should_prime_state:
persisted_contiguous_index = await _get_persisted_contiguous_segment_index(
db=db,
conversation_id=conversation_id,
voice_session_id=voice_session_id,
persisted_contiguous_index = (
await _get_persisted_contiguous_segment_index(
db=db,
conversation_id=conversation_id,
voice_session_id=voice_session_id,
)
)
if persisted_contiguous_index >= 0:
async with state.lock:
state.consumed_index = max(state.consumed_index, persisted_contiguous_index)
state.consumed_index = max(
state.consumed_index, persisted_contiguous_index
)
try:
audio_bytes = base64.b64decode(audio_base64)
@@ -336,28 +374,34 @@ async def process_audio_segment(
transcript_text = await get_asr_provider().transcribe(
audio_bytes, format="m4a"
)
await manager.send_message(conversation_id, {
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {
"text": transcript_text or "",
"audio_duration": audio_duration,
"voice_session_id": voice_session_id,
"segment_index": segment_index,
"is_last": is_last,
},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
if _is_transcribe_failure(transcript_text):
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
await manager.send_message(
conversation_id,
{
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {
"message": f"分段 {segment_index} 转写失败,请重试该片段",
"text": transcript_text or "",
"audio_duration": audio_duration,
"voice_session_id": voice_session_id,
"segment_index": segment_index,
"is_last": is_last,
},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
},
)
if _is_transcribe_failure(transcript_text):
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": f"分段 {segment_index} 转写失败,请重试该片段",
"segment_index": segment_index,
},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
return
existing_segment = await _find_existing_segment_by_index(
@@ -391,7 +435,10 @@ async def process_audio_segment(
ready_segments: List[Tuple[int, str, Segment]] = []
async with state.lock:
state.processed_indices.add(segment_index)
state.buffered_transcripts[segment_index] = (transcript_text or "", segment)
state.buffered_transcripts[segment_index] = (
transcript_text or "",
segment,
)
next_index = state.consumed_index + 1
while next_index in state.buffered_transcripts:
@@ -408,7 +455,8 @@ async def process_audio_segment(
segment=ordered_segment,
db=db,
user=user,
user_message_timestamp=ordered_segment.created_at or user_message_timestamp,
user_message_timestamp=ordered_segment.created_at
or user_message_timestamp,
)
except Exception as e:
@@ -416,14 +464,17 @@ async def process_audio_segment(
f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
exc_info=True,
)
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {
"message": f"分段处理失败: {str(e)}",
"segment_index": segment_index,
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {
"message": f"分段处理失败: {str(e)}",
"segment_index": segment_index,
},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
)
finally:
async with state.lock:
state.pending_indices.discard(segment_index)
@@ -431,6 +482,7 @@ async def process_audio_segment(
# ── 用户消息处理 ────────────────────────────────────────────────
async def process_user_message(
conversation_id: str,
user_message: str,
@@ -463,12 +515,19 @@ async def process_user_message(
await db.commit()
for i, response_text in enumerate(responses):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": response_text, "index": i, "total": len(responses)},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {
"text": response_text,
"index": i,
"total": len(responses),
},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
await _send_tts_audio(conversation_id, response_text)
if i < len(responses) - 1:
await asyncio.sleep(0.5)
@@ -477,17 +536,21 @@ async def process_user_message(
logger.error(f"处理用户消息失败: {e}", exc_info=True)
if conversation_id in manager.active_connections:
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": f"生成回应失败: {str(e)}"},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await manager.send_message(
conversation_id,
{
"type": MessageType.ERROR,
"data": {"message": f"生成回应失败: {str(e)}"},
"timestamp": datetime.now(timezone.utc).isoformat(),
},
)
except Exception as send_error:
logger.warning(f"发送错误消息失败: {send_error}")
# ── 对话结束处理 ────────────────────────────────────────────────
async def process_conversation_segments(
conversation_id: str, db: AsyncSession, quota_service: "QuotaService"
):
@@ -528,8 +591,11 @@ async def process_conversation_segments(
segment_ids = [seg.id for seg in segments]
try:
from app.tasks.memoir_tasks import process_memoir_segments
process_memoir_segments.delay(conversation.user_id, segment_ids)
logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}")
logger.info(
f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}"
)
except Exception as e:
logger.error(f"提交 Celery 任务失败: {e}")