Merge branch 'development' into claude/agent-proactive-chat-UYHu9

This commit is contained in:
Sully
2026-05-11 13:07:18 +08:00
committed by GitHub
128 changed files with 4827 additions and 1144 deletions

View File

@@ -22,7 +22,6 @@ from app.agents.chat.reply_limits import (
from app.agents.chat.schemas import ProfileExtractionOutput
from app.core.agent_logging import agent_span, log_agent_payload, log_agent_summary
from app.core.config import settings
from app.core.dependencies import get_llm_provider
from app.core.llm_call import allm_json_call
from app.core.llm_gateway import LlmGateway, LlmUseCase
from app.core.logging import get_logger
@@ -31,11 +30,53 @@ from app.ports.llm import LLMProvider
logger = get_logger(__name__)
def _get_langchain_llm():
try:
return LlmGateway().langchain_llm_for(LlmUseCase("chat.profile"))
except Exception:
return None
class _ProviderBackedProfileGateway:
def __init__(self, provider: LLMProvider) -> None:
self._provider = provider
async def chat_text(
self,
messages: list[dict],
*,
use_case: LlmUseCase | None = None,
temperature: float | None = None,
model: str | None = None,
max_tokens: int | None = None,
) -> str:
resolved_temperature = temperature
if resolved_temperature is None:
resolved_temperature = (
use_case.temperature
if use_case and use_case.temperature is not None
else 0.7
)
return await self._provider.complete(
messages,
temperature=resolved_temperature,
model=model if model is not None else (use_case.model if use_case else None),
max_tokens=(
max_tokens
if max_tokens is not None
else (use_case.max_tokens if use_case else None)
),
)
async def json_object(
self,
prompt: str,
schema: type[ProfileExtractionOutput],
*,
use_case: LlmUseCase,
fallback_factory: Any = None,
) -> ProfileExtractionOutput:
return await allm_json_call(
getattr(self._provider, "langchain_llm", None),
prompt,
schema,
max_tokens=use_case.max_tokens or 1024,
agent=use_case.name,
fallback_factory=fallback_factory,
)
def _langchain_messages_to_port(messages: List[Any]) -> list[dict]:
@@ -66,14 +107,17 @@ def _message_contents_char_count(messages: List[Any]) -> int:
class ProfileAgent:
"""用户资料收集 Specialist Agent"""
def __init__(self, llm_provider: LLMProvider | None = None):
self._llm_provider = llm_provider
self.llm = _get_langchain_llm()
def _provider(self) -> LLMProvider:
if self._llm_provider is not None:
return self._llm_provider
return get_llm_provider()
def __init__(
self,
llm_provider: LLMProvider | None = None,
llm_gateway: Any | None = None,
) -> None:
if llm_gateway is not None:
self._llm_gateway = llm_gateway
elif llm_provider is not None:
self._llm_gateway = _ProviderBackedProfileGateway(llm_provider)
else:
self._llm_gateway = LlmGateway()
async def _invoke_chat(
self,
@@ -88,8 +132,9 @@ class ProfileAgent:
with agent_span(
logger, f"{agent_name}.llm", conversation_id=conversation_id or ""
):
response_text = await self._provider().complete(
response_text = await self._llm_gateway.chat_text(
port_messages,
use_case=LlmUseCase("chat.profile", max_tokens=max_tokens),
max_tokens=max_tokens,
)
logger.info(
@@ -130,7 +175,7 @@ class ProfileAgent:
conversation_id: Optional[str] = None,
) -> Dict[str, Any]:
"""从用户消息中提取资料字段,不持久化"""
if not self.llm or not missing_fields:
if not missing_fields:
return {}
recent_dialogue = ""
if conversation_id:
@@ -151,12 +196,13 @@ class ProfileAgent:
prompt = get_profile_extraction_prompt(
user_message, missing_fields, recent_dialogue=recent_dialogue or None
)
parsed = await allm_json_call(
self.llm,
parsed = await self._llm_gateway.json_object(
prompt,
ProfileExtractionOutput,
max_tokens=settings.chat_profile_extract_max_tokens,
agent="ProfileAgent.extract_profile_from_message",
use_case=LlmUseCase(
"ProfileAgent.extract_profile_from_message",
max_tokens=settings.chat_profile_extract_max_tokens,
),
fallback_factory=lambda: ProfileExtractionOutput(),
)
result = {}
@@ -197,8 +243,6 @@ class ProfileAgent:
interview_stage_hint: str = "",
) -> List[str]:
"""生成资料追问回复,不持久化(由 Orchestrator 负责)"""
if not self.llm:
return ["谢谢!还能告诉我更多吗?"]
try:
prompt = get_profile_followup_prompt(
missing_fields,
@@ -260,8 +304,6 @@ class ProfileAgent:
nickname: str = "",
) -> List[str]:
"""生成资料收集开场白,不持久化(由 Orchestrator 负责)"""
if not self.llm:
return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
try:
prompt = get_profile_greeting_prompt(missing_fields, nickname)
hw = await get_history_with_window(