"""Regression tests for injecting an LLM provider (gateway) into ProfileAgent."""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
from types import SimpleNamespace
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
|
|||
|
|
from app.agents.chat.profile_agent import ProfileAgent
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _Response:
|
|||
|
|
def __init__(self, content: str) -> None:
|
|||
|
|
self.content = content
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _BoundJsonLlm:
    """Fake bound LLM whose ``ainvoke`` ignores the prompt and returns a canned JSON profile."""

    async def ainvoke(self, _prompt: str) -> _Response:
        profile = {
            "birth_year": 1988,
            "birth_place": "杭州",
            "grew_up_place": "杭州",
            "occupation": "工程师",
        }
        # json.dumps defaults are used, so non-ASCII characters come out \u-escaped.
        serialized = json.dumps(profile)
        return _Response(serialized)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _JsonLlm:
    """Fake LLM whose ``bind`` returns a fresh canned-JSON ``_BoundJsonLlm``.

    NOTE(review): presumably mirrors a LangChain-style ``.bind(**kwargs)`` API — confirm
    against how ProfileAgent uses ``llm_provider.langchain_llm``.
    """

    def bind(self, **_bind_options) -> _BoundJsonLlm:  # noqa: ANN003
        # Whatever binding options the caller passes are accepted and discarded.
        bound_llm = _BoundJsonLlm()
        return bound_llm
|
|||
|
|
|
|||
|
|
|
|||
|
|
class _Provider:
    """Fake LLM provider passed to ``ProfileAgent(llm_provider=...)`` in this module's test.

    ``langchain_llm`` is a class attribute, so every ``_Provider`` shares one ``_JsonLlm``.
    ``complete`` keeps the last message list it received so the test can inspect it.
    """

    langchain_llm = _JsonLlm()

    def __init__(self) -> None:
        # Overwritten by ``complete``; empty until the agent sends a chat request.
        self.messages: list[dict] = []

    async def complete(self, messages: list[dict], **_kwargs) -> str:  # noqa: ANN003
        """Record ``messages`` and reply with a fixed follow-up string."""
        self.messages = messages
        return "谢谢分享!还能再说说吗?"

    async def stream(self, *_args, **_kwargs):  # noqa: ANN003
        """Async generator that finishes immediately without yielding anything."""
        return
        yield ""  # never reached; its presence makes this an async generator
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.mark.asyncio
async def test_profile_agent_llm_provider_injection_covers_chat_and_json(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An injected ``llm_provider`` should back both profile extraction and follow-up chat.

    Conversation history is replaced by a stub returning an empty window, so no stored
    turns are involved. The extraction result is compared with the canned payload of the
    fake JSON LLM, and the follow-up call must leave a message list on the fake provider
    whose first entry has the ``system`` role.

    NOTE(review): which provider attribute each ProfileAgent method uses is inferred from
    the fakes and the test name, not visible here — confirm against ProfileAgent.
    """

    async def _empty_history(*_args, **_kwargs):
        return SimpleNamespace(window=[], turn_total=0)

    monkeypatch.setattr(
        "app.agents.chat.profile_agent.get_history_with_window",
        _empty_history,
    )

    fake_provider = _Provider()
    agent = ProfileAgent(llm_provider=fake_provider)

    # Presumably served by fake_provider.langchain_llm.bind(...).ainvoke(...).
    profile = await agent.extract_profile_from_message(
        "我是一名工程师,1988 年出生在杭州。",
        ["birth_year", "birth_place", "occupation"],
    )
    # Presumably served by fake_provider.complete(...), which records the messages.
    reply = await agent.generate_profile_followup(
        conversation_id="c1",
        user_message="我在杭州长大。",
        missing_fields=["grew_up_place"],
        filled_fields={"birth_year": "1988"},
    )

    # Same dict the fake _BoundJsonLlm serializes, including grew_up_place even though
    # that field was not in the requested list.
    expected_profile = {
        "birth_year": 1988,
        "birth_place": "杭州",
        "grew_up_place": "杭州",
        "occupation": "工程师",
    }
    assert profile == expected_profile
    assert reply

    sent_messages = fake_provider.messages
    assert sent_messages
    assert sent_messages[0]["role"] == "system"