Reuse memoir JSON payload parsing
This commit is contained in:
23
api/services/memoir_images/json_payload.py
Normal file
23
api/services/memoir_images/json_payload.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import re
|
||||
|
||||
_MARKDOWN_JSON_FENCE_RE = re.compile(
|
||||
r"^\s*```(?:json)?\s*(.*?)\s*```\s*$",
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def extract_json_payload(raw_response: str | None) -> str:
|
||||
cleaned = (raw_response or "").strip()
|
||||
fenced_match = _MARKDOWN_JSON_FENCE_RE.match(cleaned)
|
||||
if fenced_match:
|
||||
cleaned = fenced_match.group(1).strip()
|
||||
|
||||
if cleaned.startswith("{") and cleaned.endswith("}"):
|
||||
return cleaned
|
||||
|
||||
start = cleaned.find("{")
|
||||
end = cleaned.rfind("}")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return cleaned[start : end + 1].strip()
|
||||
|
||||
return cleaned
|
||||
@@ -3,11 +3,11 @@ import logging
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
|
||||
from .json_payload import extract_json_payload
|
||||
from .settings import MemoirImageSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_CJK_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]")
|
||||
_MARKDOWN_JSON_FENCE_RE = re.compile(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
|
||||
class MemoirImagePromptService:
|
||||
@@ -63,7 +63,7 @@ class MemoirImagePromptService:
|
||||
+ json.dumps(llm_input, ensure_ascii=False)
|
||||
)
|
||||
raw_response = response.content
|
||||
parsed = json.loads(_extract_json_payload(raw_response))
|
||||
parsed = json.loads(extract_json_payload(raw_response))
|
||||
return {
|
||||
"prompt": _ensure_style_in_prompt(parsed["prompt"], parsed.get("style", style)),
|
||||
"style": parsed.get("style", style),
|
||||
@@ -121,23 +121,6 @@ def _contains_cjk(value: str) -> bool:
|
||||
return bool(_CJK_RE.search(value or ""))
|
||||
|
||||
|
||||
def _extract_json_payload(raw_response: str | None) -> str:
|
||||
cleaned = (raw_response or "").strip()
|
||||
fenced_match = _MARKDOWN_JSON_FENCE_RE.match(cleaned)
|
||||
if fenced_match:
|
||||
cleaned = fenced_match.group(1).strip()
|
||||
|
||||
if cleaned.startswith("{") and cleaned.endswith("}"):
|
||||
return cleaned
|
||||
|
||||
start = cleaned.find("{")
|
||||
end = cleaned.rfind("}")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return cleaned[start : end + 1].strip()
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def _ensure_style_in_prompt(prompt: str, style: str) -> str:
|
||||
cleaned_prompt = (prompt or "").strip()
|
||||
cleaned_style = (style or "").strip()
|
||||
|
||||
@@ -31,6 +31,7 @@ from agents.prompts.profile_prompts import format_user_profile_context
|
||||
import hashlib
|
||||
|
||||
from services.memoir_images.parser import build_initial_image_assets, parse_image_placeholders
|
||||
from services.memoir_images.json_payload import extract_json_payload
|
||||
from services.memoir_images.prompting import MemoirImagePromptService
|
||||
from services.memoir_images.provider import LiblibImageProvider
|
||||
from services.memoir_images.schema import (
|
||||
@@ -389,8 +390,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
stage_slots=state.slots.get(detected_stage, {}),
|
||||
)
|
||||
response = llm.invoke(prompt)
|
||||
content = response.content.strip()
|
||||
parsed = json.loads(content)
|
||||
parsed = json.loads(extract_json_payload(response.content))
|
||||
detected_stage = parsed.get("detected_stage", detected_stage)
|
||||
extracted_slots = parsed.get("slots", {}) or {}
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
|
||||
@@ -9,20 +9,22 @@ class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase):
|
||||
@patch("api.tasks.memoir_tasks._update_task_status_sync")
|
||||
@patch("api.tasks.memoir_tasks._release_chapter_lock")
|
||||
@patch("api.tasks.memoir_tasks._acquire_chapter_lock", return_value=True)
|
||||
@patch("api.tasks.memoir_tasks._update_slot_sync")
|
||||
@patch("api.tasks.memoir_tasks._classify_chapter_category", return_value="childhood")
|
||||
@patch("api.tasks.memoir_tasks._get_or_create_state_sync")
|
||||
@patch("api.tasks.memoir_tasks.llm_service.get_llm", return_value=None)
|
||||
@patch("api.tasks.memoir_tasks.llm_service.get_llm")
|
||||
@patch("api.tasks.memoir_tasks.generate_chapter_images.delay")
|
||||
@patch("api.tasks.memoir_tasks.SessionLocal")
|
||||
@patch("api.tasks.memoir_tasks.MemoirImageSettings.from_env")
|
||||
def test_process_memoir_segments_enqueues_image_jobs_after_commit(
|
||||
def test_process_memoir_segments_parses_markdown_wrapped_state_extraction_json(
|
||||
self,
|
||||
settings_from_env,
|
||||
session_local_cls,
|
||||
delay_mock,
|
||||
_get_llm,
|
||||
get_llm_mock,
|
||||
get_state_mock,
|
||||
_classify_mock,
|
||||
update_slot_mock,
|
||||
_acquire_lock_mock,
|
||||
_release_lock_mock,
|
||||
_update_status_mock,
|
||||
@@ -38,6 +40,23 @@ class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase):
|
||||
liblib_template_uuid="tpl-uuid",
|
||||
)
|
||||
get_state_mock.return_value = SimpleNamespace(current_stage="childhood", slots={})
|
||||
update_slot_mock.return_value = SimpleNamespace(current_stage="childhood", slots={})
|
||||
llm = Mock()
|
||||
llm.invoke.side_effect = [
|
||||
SimpleNamespace(
|
||||
content="""```json
|
||||
{
|
||||
"detected_stage": "childhood",
|
||||
"slots": {
|
||||
"family_memory": "外婆总在门口等我"
|
||||
}
|
||||
}
|
||||
```"""
|
||||
),
|
||||
SimpleNamespace(content="童年的门前"),
|
||||
SimpleNamespace(content="新的章节正文\n\n{{IMAGE:南方小镇的青石板路}}"),
|
||||
]
|
||||
get_llm_mock.return_value = llm
|
||||
|
||||
segment = SimpleNamespace(
|
||||
id="segment-1",
|
||||
@@ -69,6 +88,14 @@ class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase):
|
||||
)
|
||||
process_memoir_segments.run.__func__(task_self, "user-1", ["segment-1"])
|
||||
|
||||
update_slot_mock.assert_called_once_with(
|
||||
user_id="user-1",
|
||||
stage="childhood",
|
||||
slot_name="family_memory",
|
||||
snippet="外婆总在门口等我",
|
||||
segment_ids=["segment-1"],
|
||||
db=db,
|
||||
)
|
||||
self.assertIn("commit", events)
|
||||
delay_events = [event for event in events if event.startswith("delay:")]
|
||||
self.assertEqual(len(delay_events), 1)
|
||||
|
||||
Reference in New Issue
Block a user