diff --git a/api/services/memoir_images/json_payload.py b/api/services/memoir_images/json_payload.py new file mode 100644 index 0000000..b91f7d2 --- /dev/null +++ b/api/services/memoir_images/json_payload.py @@ -0,0 +1,23 @@ +import re + +_MARKDOWN_JSON_FENCE_RE = re.compile( + r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", + re.IGNORECASE | re.DOTALL, +) + + +def extract_json_payload(raw_response: str | None) -> str: + cleaned = (raw_response or "").strip() + fenced_match = _MARKDOWN_JSON_FENCE_RE.match(cleaned) + if fenced_match: + cleaned = fenced_match.group(1).strip() + + if cleaned.startswith("{") and cleaned.endswith("}"): + return cleaned + + start = cleaned.find("{") + end = cleaned.rfind("}") + if start != -1 and end != -1 and end > start: + return cleaned[start : end + 1].strip() + + return cleaned diff --git a/api/services/memoir_images/prompting.py b/api/services/memoir_images/prompting.py index c5aa34b..c9fab30 100644 --- a/api/services/memoir_images/prompting.py +++ b/api/services/memoir_images/prompting.py @@ -3,11 +3,11 @@ import logging import re from typing import Any, Optional +from .json_payload import extract_json_payload from .settings import MemoirImageSettings logger = logging.getLogger(__name__) _CJK_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]") -_MARKDOWN_JSON_FENCE_RE = re.compile(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", re.IGNORECASE | re.DOTALL) class MemoirImagePromptService: @@ -63,7 +63,7 @@ class MemoirImagePromptService: + json.dumps(llm_input, ensure_ascii=False) ) raw_response = response.content - parsed = json.loads(_extract_json_payload(raw_response)) + parsed = json.loads(extract_json_payload(raw_response)) return { "prompt": _ensure_style_in_prompt(parsed["prompt"], parsed.get("style", style)), "style": parsed.get("style", style), @@ -121,23 +121,6 @@ def _contains_cjk(value: str) -> bool: return bool(_CJK_RE.search(value or "")) -def _extract_json_payload(raw_response: str | None) -> str: - cleaned = (raw_response or "").strip() - fenced_match = _MARKDOWN_JSON_FENCE_RE.match(cleaned) - if fenced_match: - cleaned = fenced_match.group(1).strip() - - if cleaned.startswith("{") and cleaned.endswith("}"): - return cleaned - - start = cleaned.find("{") - end = cleaned.rfind("}") - if start != -1 and end != -1 and end > start: - return cleaned[start : end + 1].strip() - - return cleaned - - def _ensure_style_in_prompt(prompt: str, style: str) -> str: cleaned_prompt = (prompt or "").strip() cleaned_style = (style or "").strip() diff --git a/api/tasks/memoir_tasks.py b/api/tasks/memoir_tasks.py index 181fb23..dc0f171 100644 --- a/api/tasks/memoir_tasks.py +++ b/api/tasks/memoir_tasks.py @@ -31,6 +31,7 @@ from agents.prompts.profile_prompts import format_user_profile_context import hashlib from services.memoir_images.parser import build_initial_image_assets, parse_image_placeholders +from services.memoir_images.json_payload import extract_json_payload from services.memoir_images.prompting import MemoirImagePromptService from services.memoir_images.provider import LiblibImageProvider from services.memoir_images.schema import ( @@ -389,8 +390,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): stage_slots=state.slots.get(detected_stage, {}), ) response = llm.invoke(prompt) - content = response.content.strip() - parsed = json.loads(content) + parsed = json.loads(extract_json_payload(response.content)) detected_stage = parsed.get("detected_stage", detected_stage) extracted_slots = parsed.get("slots", {}) or {} except (json.JSONDecodeError, Exception) as e: diff --git a/api/tests/test_process_memoir_segments_image_enqueue.py b/api/tests/test_process_memoir_segments_image_enqueue.py index 7bbf1cd..7ea4b3a 100644 --- a/api/tests/test_process_memoir_segments_image_enqueue.py +++ b/api/tests/test_process_memoir_segments_image_enqueue.py @@ -9,20 +9,22 @@ class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase): @patch("api.tasks.memoir_tasks._update_task_status_sync") @patch("api.tasks.memoir_tasks._release_chapter_lock") @patch("api.tasks.memoir_tasks._acquire_chapter_lock", return_value=True) + @patch("api.tasks.memoir_tasks._update_slot_sync") @patch("api.tasks.memoir_tasks._classify_chapter_category", return_value="childhood") @patch("api.tasks.memoir_tasks._get_or_create_state_sync") - @patch("api.tasks.memoir_tasks.llm_service.get_llm", return_value=None) + @patch("api.tasks.memoir_tasks.llm_service.get_llm") @patch("api.tasks.memoir_tasks.generate_chapter_images.delay") @patch("api.tasks.memoir_tasks.SessionLocal") @patch("api.tasks.memoir_tasks.MemoirImageSettings.from_env") - def test_process_memoir_segments_enqueues_image_jobs_after_commit( + def test_process_memoir_segments_parses_markdown_wrapped_state_extraction_json( self, settings_from_env, session_local_cls, delay_mock, - _get_llm, + get_llm_mock, get_state_mock, _classify_mock, + update_slot_mock, _acquire_lock_mock, _release_lock_mock, _update_status_mock, @@ -38,6 +40,23 @@ class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase): liblib_template_uuid="tpl-uuid", ) get_state_mock.return_value = SimpleNamespace(current_stage="childhood", slots={}) + update_slot_mock.return_value = SimpleNamespace(current_stage="childhood", slots={}) + llm = Mock() + llm.invoke.side_effect = [ + SimpleNamespace( + content="""```json +{ + "detected_stage": "childhood", + "slots": { + "family_memory": "外婆总在门口等我" + } +} +```""" + ), + SimpleNamespace(content="童年的门前"), + SimpleNamespace(content="新的章节正文\n\n{{IMAGE:南方小镇的青石板路}}"), + ] + get_llm_mock.return_value = llm segment = SimpleNamespace( id="segment-1", @@ -69,6 +88,14 @@ class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase): ) process_memoir_segments.run.__func__(task_self, "user-1", ["segment-1"]) + update_slot_mock.assert_called_once_with( + user_id="user-1", + stage="childhood", + slot_name="family_memory", + snippet="外婆总在门口等我", + segment_ids=["segment-1"], + db=db, + ) self.assertIn("commit", events) delay_events = [event for event in events if event.startswith("delay:")] self.assertEqual(len(delay_events), 1)