feat & refactor: 重构agents目录结构;AI回复模块agent结构封装

This commit is contained in:
yangshilin
2026-03-19 10:36:55 +08:00
parent b56fc859cc
commit b16bb2b96c
14 changed files with 754 additions and 781 deletions

View File

@@ -14,8 +14,8 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents import ConversationAgent, MemoryAgent
from app.agents.memoir_processor import BackgroundTaskRunner
from app.agents.prompts.profile_prompts import format_user_profile_context
from app.agents.chat import ChatOrchestrator
from app.agents.memoir import BackgroundTaskRunner
from app.core.db import AsyncSessionLocal
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.ws.connection_manager import manager
@@ -28,7 +28,6 @@ from app.features.conversation.ws.profile_collector import (
from app.features.user.models import User
from app.core.config import settings
from app.core.dependencies import get_asr_provider, get_tts_provider
from app.features.memoir.state_service import get_or_create_state
logger = get_logger(__name__)
@@ -53,6 +52,7 @@ async def _send_tts_audio(conversation_id: str, text: str) -> None:
# ── Agent 实例(从 ConnectionManager 移出) ─────────────────────
conversation_agent = ConversationAgent()
chat_orchestrator = ChatOrchestrator()
memory_agent = MemoryAgent()
background_runner = BackgroundTaskRunner()
@@ -429,82 +429,21 @@ async def process_user_message(
user: User = None,
user_message_timestamp: Optional[datetime] = None,
) -> None:
"""处理用户消息,生成 Agent 回应。支持资料收集模式和正式访谈模式"""
agent = conversation_agent
if user:
missing = get_missing_profile_fields(user)
if missing:
try:
extracted = await agent.extract_profile_from_message(
user_message, missing, conversation_id=conversation_id
)
if extracted:
await apply_extracted_profile(user, extracted, db)
remaining = get_missing_profile_fields(user)
filled = get_filled_profile_fields(user)
is_from_voice = bool(segment.audio_url)
responses = await agent.generate_profile_followup(
conversation_id=conversation_id,
user_message=user_message,
missing_fields=remaining,
filled_fields=filled,
nickname=user.nickname or "",
is_from_voice=is_from_voice,
voice_session_id=_voice_session_id_from_audio_url(segment.audio_url),
user_message_timestamp=user_message_timestamp,
)
segment.agent_response = "\n\n".join(responses)
_mark_conversation_active(conversation)
await db.commit()
for i, response_text in enumerate(responses):
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": response_text, "index": i, "total": len(responses)},
"timestamp": datetime.now(timezone.utc).isoformat(),
})
await _send_tts_audio(conversation_id, response_text)
if i < len(responses) - 1:
await asyncio.sleep(0.5)
return
except Exception as e:
logger.error(f"资料收集处理失败: {e}", exc_info=True)
state = await get_or_create_state(conversation.user_id, db)
if conversation.conversation_stage != state.current_stage:
conversation.conversation_stage = state.current_stage
await db.commit()
stmt_segments = select(Segment).where(
Segment.conversation_id == conversation_id
).order_by(Segment.created_at)
result_segments = await db.execute(stmt_segments)
previous_segments = result_segments.scalars().all()
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
user_profile_context = ""
if user:
user_profile_context = format_user_profile_context(
birth_year=user.birth_year,
birth_place=user.birth_place,
grew_up_place=user.grew_up_place,
occupation=user.occupation,
)
"""处理用户消息,生成 Agent 回应。由 ChatOrchestrator 路由到 ProfileAgent 或 InterviewAgent"""
try:
is_from_voice = bool(segment.audio_url)
responses = await agent.generate_response_with_state(
voice_session_id = _voice_session_id_from_audio_url(segment.audio_url)
responses = await chat_orchestrator.process_user_message(
conversation_id=conversation_id,
user_message=user_message,
memoir_state=state,
user_profile_context=user_profile_context,
user=user,
conversation=conversation,
is_from_voice=is_from_voice,
voice_session_id=_voice_session_id_from_audio_url(segment.audio_url),
voice_session_id=voice_session_id,
db=db,
apply_extracted_profile_fn=apply_extracted_profile,
get_missing_profile_fields_fn=get_missing_profile_fields,
get_filled_profile_fields_fn=get_filled_profile_fields,
user_message_timestamp=user_message_timestamp,
)