agent init
This commit is contained in:
@@ -8,7 +8,8 @@ from langchain.memory import ConversationBufferMemory
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
from services.llm_service import llm_service
|
||||
from .prompts import ConversationStage, get_conversation_prompt
|
||||
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt
|
||||
from .state_schema import MemoirStateSchema
|
||||
|
||||
|
||||
class ConversationAgent:
|
||||
@@ -82,6 +83,60 @@ class ConversationAgent:
|
||||
response = chain.predict(input=user_message)
|
||||
|
||||
return response
|
||||
|
||||
def generate_response_with_state(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
memoir_state: MemoirStateSchema
|
||||
) -> List[str]:
|
||||
"""
|
||||
基于共享状态生成引导式回复
|
||||
|
||||
Args:
|
||||
conversation_id: 对话 ID
|
||||
user_message: 用户消息
|
||||
memoir_state: 共享状态
|
||||
|
||||
Returns:
|
||||
Agent 回应文本列表(支持多条消息)
|
||||
"""
|
||||
if not self.llm:
|
||||
return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
|
||||
|
||||
empty_slots = memoir_state.empty_slots_for_current_stage()
|
||||
filled_slots = {
|
||||
key: value.snippet
|
||||
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
|
||||
if value.snippet
|
||||
}
|
||||
|
||||
system_prompt = get_guided_conversation_prompt(
|
||||
current_stage=memoir_state.current_stage,
|
||||
empty_slots=empty_slots,
|
||||
filled_slots=filled_slots,
|
||||
user_message=user_message,
|
||||
)
|
||||
|
||||
memory = self._get_memory(conversation_id)
|
||||
prompt_template = PromptTemplate(
|
||||
input_variables=["history", "input"],
|
||||
template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:"
|
||||
)
|
||||
|
||||
chain = ConversationChain(
|
||||
llm=self.llm,
|
||||
prompt=prompt_template,
|
||||
memory=memory,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
response = chain.predict(input=user_message)
|
||||
|
||||
# 支持多条消息,用 [SPLIT] 分隔
|
||||
messages = [msg.strip() for msg in response.split("[SPLIT]") if msg.strip()]
|
||||
# 最多返回 3 条
|
||||
return messages[:3] if messages else [response]
|
||||
|
||||
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user